新增批量大小和幅值放缩设置,优化J峰预测功能

This commit is contained in:
2025-12-16 16:57:30 +08:00
parent cbf871ca8c
commit 1d791320eb
4 changed files with 238 additions and 55 deletions

View File

@ -338,7 +338,8 @@ class MainWindow_detect_Jpeak(QMainWindow):
# 预测峰值
PublicFunc.progressbar_update(self, 2, 3, Constants.DETECT_JPEAK_PREDICTING_PEAK, 10)
self.model.selected_model = Config["DetectMethod"]
result = self.data.predict_Jpeak(self.model)
scale = self.ui.spinBox_scaleValue.value() if self.ui.checkBox_scaleEnable.isChecked() else 0
result = self.data.predict_Jpeak(self.model, batch_size=int(self.ui.comboBox_batchSize.currentText()), scale=scale)
if not result.status:
PublicFunc.text_output(self.ui, "(2/3)" + result.info, Constants.TIPS_TYPE_ERROR)
PublicFunc.msgbox_output(self, result.info, Constants.MSGBOX_TYPE_ERROR)
@ -469,7 +470,7 @@ class Data:
return Result().failure(info=Constants.PREPROCESS_FAILURE +
Constants.FAILURE_REASON["Preprocess_Exception"] + "\n" + format_exc())
def predict_Jpeak(self, model):
def predict_Jpeak(self, model, batch_size=0, scale=0):
if not (Path(model.model_folder_path) / Path(model.selected_model)).exists():
return Result().failure(info=Constants.DETECT_JPEAK_PREDICT_FAILURE +
Constants.FAILURE_REASON["Model_File_Not_Exist"])
@ -489,7 +490,9 @@ class Data:
Config["IntervalHigh"],
Config["IntervalLow"],
Config["PeaksValue"],
Config["UseCPU"])
Config["UseCPU"],
batch_size,
scale)
except Exception as e:
return Result().failure(info=Constants.DETECT_JPEAK_PREDICT_FAILURE +
Constants.FAILURE_REASON["Predict_Exception"] + "\n" + format_exc())

View File

@ -1,3 +1,4 @@
import torch
from numpy import diff, argwhere, argmax, where, delete, insert, mean, array, full, nan
from numpy import min as np_min
from numpy import max as np_max
@ -5,11 +6,12 @@ from torch import FloatTensor, no_grad, load
from torch import device as torch_device
from torch.cuda import is_available, empty_cache
from torch.nn.functional import sigmoid
from torch.utils.data import TensorDataset, DataLoader
from func.BCGDataset import BCG_Operation
from func.Deep_Model import Unet,Fivelayer_Lstm_Unet,Fivelayer_Unet,Sixlayer_Unet
def evaluate(test_data, model,fs,useCPU):
def evaluate(test_data, model,fs,useCPU, batch_size=0, scale=0):
orgBCG = test_data
operation = BCG_Operation()
# 降采样
@ -17,7 +19,12 @@ def evaluate(test_data, model,fs,useCPU):
# plt.figure()
# plt.plot(orgBCG)
# plt.show()
if scale != 0:
orgBCG_std = orgBCG.std()
orgBCG = orgBCG * scale / orgBCG_std
orgBCG = orgBCG.reshape(-1, 1000)
# test dataset
orgData = FloatTensor(orgBCG).unsqueeze(1)
# predict
@ -30,12 +37,32 @@ def evaluate(test_data, model,fs,useCPU):
# if gpu:
# orgData = orgData.cuda()
# model.cuda()
orgData = orgData.to(device)
model = model.to(device)
model.to(device)
orgData_tensor = FloatTensor(orgBCG).unsqueeze(1)
test_dataset = TensorDataset(orgData_tensor)
batch_size = len(test_dataset) if batch_size == 0 else batch_size
#全部数据放在一个batch里评估
test_loader = DataLoader(
dataset=test_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=0 # 简单评估时设为 0
)
all_y_prob = []
with no_grad():
y_hat = model(orgData)
y_prob = sigmoid(y_hat)
for i, data_batch in enumerate(test_loader):
# data_batch 是一个包含 (batch_size, 1, 1000) 数据的列表
inputs = data_batch[0].to(device)
# 预测
y_hat_batch = model(inputs)
y_prob_batch = sigmoid(y_hat_batch)
# 收集结果,移回 CPU
all_y_prob.append(y_prob_batch.cpu())
y_prob = torch.cat(all_y_prob, dim=0).view(-1)
beat = (y_prob>0.5).float().view(-1).cpu().data.numpy()
beat_diff = diff(beat)
up_index = argwhere(beat_diff==1)
@ -250,7 +277,7 @@ def preprocess(raw_bcg, fs, low_cut, high_cut, amp_value):
bcg = preprocessing.Butterworth(bcg_data, "bandpass", low_cut=low_cut, high_cut=high_cut, order=3) * amp_value
return bcg
def Jpeak_Detection(model_name, model_path, bcg_data, fs, interval_high, interval_low, peaks_value, useCPU):
def Jpeak_Detection(model_name, model_path, bcg_data, fs, interval_high, interval_low, peaks_value, useCPU, batch_size=0, scale=0):
model_name = get_model_name(str(model_name))
if model_name == "Fivelayer_Unet":
@ -267,7 +294,7 @@ def Jpeak_Detection(model_name, model_path, bcg_data, fs, interval_high, interva
model.eval()
# J峰预测
beat, up_index, down_index, y_prob = evaluate(bcg_data, model=model, fs=fs, useCPU=useCPU)
beat, up_index, down_index, y_prob = evaluate(bcg_data, model=model, fs=fs, useCPU=useCPU, batch_size=batch_size, scale=scale)
y_prob = y_prob.cpu().reshape(-1).data.numpy()
predict_J = new_calculate_beat(y_prob, 1, th=0.6, up=fs // 100, th1=interval_high, th2=interval_low)