新增批量大小和幅值放缩设置,优化J峰预测功能
This commit is contained in:
@ -338,7 +338,8 @@ class MainWindow_detect_Jpeak(QMainWindow):
|
||||
# 预测峰值
|
||||
PublicFunc.progressbar_update(self, 2, 3, Constants.DETECT_JPEAK_PREDICTING_PEAK, 10)
|
||||
self.model.selected_model = Config["DetectMethod"]
|
||||
result = self.data.predict_Jpeak(self.model)
|
||||
scale = self.ui.spinBox_scaleValue.value() if self.ui.checkBox_scaleEnable.isChecked() else 0
|
||||
result = self.data.predict_Jpeak(self.model, batch_size=int(self.ui.comboBox_batchSize.currentText()), scale=scale)
|
||||
if not result.status:
|
||||
PublicFunc.text_output(self.ui, "(2/3)" + result.info, Constants.TIPS_TYPE_ERROR)
|
||||
PublicFunc.msgbox_output(self, result.info, Constants.MSGBOX_TYPE_ERROR)
|
||||
@ -469,7 +470,7 @@ class Data:
|
||||
return Result().failure(info=Constants.PREPROCESS_FAILURE +
|
||||
Constants.FAILURE_REASON["Preprocess_Exception"] + "\n" + format_exc())
|
||||
|
||||
def predict_Jpeak(self, model):
|
||||
def predict_Jpeak(self, model, batch_size=0, scale=0):
|
||||
if not (Path(model.model_folder_path) / Path(model.selected_model)).exists():
|
||||
return Result().failure(info=Constants.DETECT_JPEAK_PREDICT_FAILURE +
|
||||
Constants.FAILURE_REASON["Model_File_Not_Exist"])
|
||||
@ -489,7 +490,9 @@ class Data:
|
||||
Config["IntervalHigh"],
|
||||
Config["IntervalLow"],
|
||||
Config["PeaksValue"],
|
||||
Config["UseCPU"])
|
||||
Config["UseCPU"],
|
||||
batch_size,
|
||||
scale)
|
||||
except Exception as e:
|
||||
return Result().failure(info=Constants.DETECT_JPEAK_PREDICT_FAILURE +
|
||||
Constants.FAILURE_REASON["Predict_Exception"] + "\n" + format_exc())
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import torch
|
||||
from numpy import diff, argwhere, argmax, where, delete, insert, mean, array, full, nan
|
||||
from numpy import min as np_min
|
||||
from numpy import max as np_max
|
||||
@ -5,11 +6,12 @@ from torch import FloatTensor, no_grad, load
|
||||
from torch import device as torch_device
|
||||
from torch.cuda import is_available, empty_cache
|
||||
from torch.nn.functional import sigmoid
|
||||
from torch.utils.data import TensorDataset, DataLoader
|
||||
|
||||
from func.BCGDataset import BCG_Operation
|
||||
from func.Deep_Model import Unet,Fivelayer_Lstm_Unet,Fivelayer_Unet,Sixlayer_Unet
|
||||
|
||||
def evaluate(test_data, model,fs,useCPU):
|
||||
def evaluate(test_data, model,fs,useCPU, batch_size=0, scale=0):
|
||||
orgBCG = test_data
|
||||
operation = BCG_Operation()
|
||||
# 降采样
|
||||
@ -17,7 +19,12 @@ def evaluate(test_data, model,fs,useCPU):
|
||||
# plt.figure()
|
||||
# plt.plot(orgBCG)
|
||||
# plt.show()
|
||||
if scale != 0:
|
||||
orgBCG_std = orgBCG.std()
|
||||
orgBCG = orgBCG * scale / orgBCG_std
|
||||
|
||||
orgBCG = orgBCG.reshape(-1, 1000)
|
||||
|
||||
# test dataset
|
||||
orgData = FloatTensor(orgBCG).unsqueeze(1)
|
||||
# predict
|
||||
@ -30,12 +37,32 @@ def evaluate(test_data, model,fs,useCPU):
|
||||
# if gpu:
|
||||
# orgData = orgData.cuda()
|
||||
# model.cuda()
|
||||
orgData = orgData.to(device)
|
||||
model = model.to(device)
|
||||
|
||||
model.to(device)
|
||||
orgData_tensor = FloatTensor(orgBCG).unsqueeze(1)
|
||||
test_dataset = TensorDataset(orgData_tensor)
|
||||
batch_size = len(test_dataset) if batch_size == 0 else batch_size
|
||||
#全部数据放在一个batch里评估
|
||||
test_loader = DataLoader(
|
||||
dataset=test_dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=False,
|
||||
num_workers=0 # 简单评估时设为 0
|
||||
)
|
||||
all_y_prob = []
|
||||
with no_grad():
|
||||
y_hat = model(orgData)
|
||||
y_prob = sigmoid(y_hat)
|
||||
for i, data_batch in enumerate(test_loader):
|
||||
# data_batch 是一个包含 (batch_size, 1, 1000) 数据的列表
|
||||
inputs = data_batch[0].to(device)
|
||||
|
||||
# 预测
|
||||
y_hat_batch = model(inputs)
|
||||
y_prob_batch = sigmoid(y_hat_batch)
|
||||
|
||||
# 收集结果,移回 CPU
|
||||
all_y_prob.append(y_prob_batch.cpu())
|
||||
|
||||
y_prob = torch.cat(all_y_prob, dim=0).view(-1)
|
||||
beat = (y_prob>0.5).float().view(-1).cpu().data.numpy()
|
||||
beat_diff = diff(beat)
|
||||
up_index = argwhere(beat_diff==1)
|
||||
@ -250,7 +277,7 @@ def preprocess(raw_bcg, fs, low_cut, high_cut, amp_value):
|
||||
bcg = preprocessing.Butterworth(bcg_data, "bandpass", low_cut=low_cut, high_cut=high_cut, order=3) * amp_value
|
||||
return bcg
|
||||
|
||||
def Jpeak_Detection(model_name, model_path, bcg_data, fs, interval_high, interval_low, peaks_value, useCPU):
|
||||
def Jpeak_Detection(model_name, model_path, bcg_data, fs, interval_high, interval_low, peaks_value, useCPU, batch_size=0, scale=0):
|
||||
|
||||
model_name = get_model_name(str(model_name))
|
||||
if model_name == "Fivelayer_Unet":
|
||||
@ -267,7 +294,7 @@ def Jpeak_Detection(model_name, model_path, bcg_data, fs, interval_high, interva
|
||||
model.eval()
|
||||
|
||||
# J峰预测
|
||||
beat, up_index, down_index, y_prob = evaluate(bcg_data, model=model, fs=fs, useCPU=useCPU)
|
||||
beat, up_index, down_index, y_prob = evaluate(bcg_data, model=model, fs=fs, useCPU=useCPU, batch_size=batch_size, scale=scale)
|
||||
y_prob = y_prob.cpu().reshape(-1).data.numpy()
|
||||
|
||||
predict_J = new_calculate_beat(y_prob, 1, th=0.6, up=fs // 100, th1=interval_high, th2=interval_low)
|
||||
|
||||
Reference in New Issue
Block a user