diff --git a/func/Module_detect_Jpeak.py b/func/Module_detect_Jpeak.py
index 2246aae..c7c4cac 100644
--- a/func/Module_detect_Jpeak.py
+++ b/func/Module_detect_Jpeak.py
@@ -338,7 +338,8 @@ class MainWindow_detect_Jpeak(QMainWindow):
# 预测峰值
PublicFunc.progressbar_update(self, 2, 3, Constants.DETECT_JPEAK_PREDICTING_PEAK, 10)
self.model.selected_model = Config["DetectMethod"]
- result = self.data.predict_Jpeak(self.model)
+ scale = self.ui.spinBox_scaleValue.value() if self.ui.checkBox_scaleEnable.isChecked() else 0
+ result = self.data.predict_Jpeak(self.model, batch_size=int(self.ui.comboBox_batchSize.currentText()), scale=scale)
if not result.status:
PublicFunc.text_output(self.ui, "(2/3)" + result.info, Constants.TIPS_TYPE_ERROR)
PublicFunc.msgbox_output(self, result.info, Constants.MSGBOX_TYPE_ERROR)
@@ -469,7 +470,7 @@ class Data:
return Result().failure(info=Constants.PREPROCESS_FAILURE +
Constants.FAILURE_REASON["Preprocess_Exception"] + "\n" + format_exc())
- def predict_Jpeak(self, model):
+ def predict_Jpeak(self, model, batch_size=0, scale=0):
if not (Path(model.model_folder_path) / Path(model.selected_model)).exists():
return Result().failure(info=Constants.DETECT_JPEAK_PREDICT_FAILURE +
Constants.FAILURE_REASON["Model_File_Not_Exist"])
@@ -489,7 +490,9 @@ class Data:
Config["IntervalHigh"],
Config["IntervalLow"],
Config["PeaksValue"],
- Config["UseCPU"])
+ Config["UseCPU"],
+ batch_size,
+ scale)
except Exception as e:
return Result().failure(info=Constants.DETECT_JPEAK_PREDICT_FAILURE +
Constants.FAILURE_REASON["Predict_Exception"] + "\n" + format_exc())
diff --git a/func/utils/detect_Jpeak.py b/func/utils/detect_Jpeak.py
index ab81263..54b9096 100644
--- a/func/utils/detect_Jpeak.py
+++ b/func/utils/detect_Jpeak.py
@@ -1,3 +1,4 @@
+import torch
from numpy import diff, argwhere, argmax, where, delete, insert, mean, array, full, nan
from numpy import min as np_min
from numpy import max as np_max
@@ -5,11 +6,12 @@ from torch import FloatTensor, no_grad, load
from torch import device as torch_device
from torch.cuda import is_available, empty_cache
from torch.nn.functional import sigmoid
+from torch.utils.data import TensorDataset, DataLoader
from func.BCGDataset import BCG_Operation
from func.Deep_Model import Unet,Fivelayer_Lstm_Unet,Fivelayer_Unet,Sixlayer_Unet
-def evaluate(test_data, model,fs,useCPU):
+def evaluate(test_data, model,fs,useCPU, batch_size=0, scale=0):
orgBCG = test_data
operation = BCG_Operation()
# 降采样
@@ -17,7 +19,12 @@ def evaluate(test_data, model,fs,useCPU):
# plt.figure()
# plt.plot(orgBCG)
# plt.show()
+ if scale != 0:
+ orgBCG_std = orgBCG.std()
+ orgBCG = orgBCG * scale / orgBCG_std
+
orgBCG = orgBCG.reshape(-1, 1000)
+
# test dataset
orgData = FloatTensor(orgBCG).unsqueeze(1)
# predict
@@ -30,12 +37,32 @@ def evaluate(test_data, model,fs,useCPU):
# if gpu:
# orgData = orgData.cuda()
# model.cuda()
- orgData = orgData.to(device)
- model = model.to(device)
+ model.to(device)
+ orgData_tensor = FloatTensor(orgBCG).unsqueeze(1)
+ test_dataset = TensorDataset(orgData_tensor)
+ batch_size = len(test_dataset) if batch_size == 0 else batch_size
+ #全部数据放在一个batch里评估
+ test_loader = DataLoader(
+ dataset=test_dataset,
+ batch_size=batch_size,
+ shuffle=False,
+ num_workers=0 # 简单评估时设为 0
+ )
+ all_y_prob = []
with no_grad():
- y_hat = model(orgData)
- y_prob = sigmoid(y_hat)
+ for i, data_batch in enumerate(test_loader):
+ # data_batch 是一个包含 (batch_size, 1, 1000) 数据的列表
+ inputs = data_batch[0].to(device)
+
+ # 预测
+ y_hat_batch = model(inputs)
+ y_prob_batch = sigmoid(y_hat_batch)
+
+ # 收集结果,移回 CPU
+ all_y_prob.append(y_prob_batch.cpu())
+
+ y_prob = torch.cat(all_y_prob, dim=0).view(-1)
beat = (y_prob>0.5).float().view(-1).cpu().data.numpy()
beat_diff = diff(beat)
up_index = argwhere(beat_diff==1)
@@ -250,7 +277,7 @@ def preprocess(raw_bcg, fs, low_cut, high_cut, amp_value):
bcg = preprocessing.Butterworth(bcg_data, "bandpass", low_cut=low_cut, high_cut=high_cut, order=3) * amp_value
return bcg
-def Jpeak_Detection(model_name, model_path, bcg_data, fs, interval_high, interval_low, peaks_value, useCPU):
+def Jpeak_Detection(model_name, model_path, bcg_data, fs, interval_high, interval_low, peaks_value, useCPU, batch_size=0, scale=0):
model_name = get_model_name(str(model_name))
if model_name == "Fivelayer_Unet":
@@ -267,7 +294,7 @@ def Jpeak_Detection(model_name, model_path, bcg_data, fs, interval_high, interva
model.eval()
# J峰预测
- beat, up_index, down_index, y_prob = evaluate(bcg_data, model=model, fs=fs, useCPU=useCPU)
+ beat, up_index, down_index, y_prob = evaluate(bcg_data, model=model, fs=fs, useCPU=useCPU, batch_size=batch_size, scale=scale)
y_prob = y_prob.cpu().reshape(-1).data.numpy()
predict_J = new_calculate_beat(y_prob, 1, th=0.6, up=fs // 100, th1=interval_high, th2=interval_low)
diff --git a/ui/MainWindow/MainWindow_detect_Jpeak.py b/ui/MainWindow/MainWindow_detect_Jpeak.py
index 1c44af5..704cec7 100644
--- a/ui/MainWindow/MainWindow_detect_Jpeak.py
+++ b/ui/MainWindow/MainWindow_detect_Jpeak.py
@@ -3,7 +3,7 @@
################################################################################
## Form generated from reading UI file 'MainWindow_detect_Jpeak.ui'
##
-## Created by: Qt User Interface Compiler version 6.8.2
+## Created by: Qt User Interface Compiler version 6.9.2
##
## WARNING! All changes made in this file will be lost when recompiling UI file!
################################################################################
@@ -16,10 +16,10 @@ from PySide6.QtGui import (QAction, QBrush, QColor, QConicalGradient,
QIcon, QImage, QKeySequence, QLinearGradient,
QPainter, QPalette, QPixmap, QRadialGradient,
QTransform)
-from PySide6.QtWidgets import (QApplication, QCheckBox, QComboBox, QDoubleSpinBox,
- QGridLayout, QGroupBox, QHBoxLayout, QLabel,
- QMainWindow, QPushButton, QRadioButton, QSizePolicy,
- QSpacerItem, QSpinBox, QStatusBar, QTextBrowser,
+from PySide6.QtWidgets import (QAbstractSpinBox, QApplication, QCheckBox, QComboBox,
+ QDoubleSpinBox, QGridLayout, QGroupBox, QHBoxLayout,
+ QLabel, QMainWindow, QPushButton, QRadioButton,
+ QSizePolicy, QSpinBox, QStatusBar, QTextBrowser,
QVBoxLayout, QWidget)
class Ui_MainWindow_detect_Jpeak(object):
@@ -221,26 +221,77 @@ class Ui_MainWindow_detect_Jpeak(object):
self.groupBox_3 = QGroupBox(self.groupBox_args)
self.groupBox_3.setObjectName(u"groupBox_3")
- self.verticalLayout_2 = QVBoxLayout(self.groupBox_3)
- self.verticalLayout_2.setObjectName(u"verticalLayout_2")
+ self.gridLayout_3 = QGridLayout(self.groupBox_3)
+ self.gridLayout_3.setObjectName(u"gridLayout_3")
+ self.spinBox_scaleValue = QSpinBox(self.groupBox_3)
+ self.spinBox_scaleValue.setObjectName(u"spinBox_scaleValue")
+ self.spinBox_scaleValue.setFont(font)
+ self.spinBox_scaleValue.setButtonSymbols(QAbstractSpinBox.ButtonSymbols.NoButtons)
+ self.spinBox_scaleValue.setMinimum(10)
+ self.spinBox_scaleValue.setMaximum(1000)
+ self.spinBox_scaleValue.setValue(100)
+
+ self.gridLayout_3.addWidget(self.spinBox_scaleValue, 1, 1, 1, 1)
+
self.checkBox_useCPU = QCheckBox(self.groupBox_3)
self.checkBox_useCPU.setObjectName(u"checkBox_useCPU")
self.checkBox_useCPU.setFont(font)
- self.verticalLayout_2.addWidget(self.checkBox_useCPU)
-
- self.label_7 = QLabel(self.groupBox_3)
- self.label_7.setObjectName(u"label_7")
- self.label_7.setFont(font)
- self.label_7.setAlignment(Qt.AlignmentFlag.AlignCenter)
-
- self.verticalLayout_2.addWidget(self.label_7)
+ self.gridLayout_3.addWidget(self.checkBox_useCPU, 0, 0, 1, 1)
self.comboBox_model = QComboBox(self.groupBox_3)
self.comboBox_model.setObjectName(u"comboBox_model")
+ sizePolicy2 = QSizePolicy(QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Fixed)
+ sizePolicy2.setHorizontalStretch(0)
+ sizePolicy2.setVerticalStretch(0)
+ sizePolicy2.setHeightForWidth(self.comboBox_model.sizePolicy().hasHeightForWidth())
+ self.comboBox_model.setSizePolicy(sizePolicy2)
self.comboBox_model.setFont(font)
- self.verticalLayout_2.addWidget(self.comboBox_model)
+ self.gridLayout_3.addWidget(self.comboBox_model, 4, 1, 1, 1)
+
+ self.checkBox_scaleEnable = QCheckBox(self.groupBox_3)
+ self.checkBox_scaleEnable.setObjectName(u"checkBox_scaleEnable")
+ self.checkBox_scaleEnable.setFont(font)
+
+ self.gridLayout_3.addWidget(self.checkBox_scaleEnable, 1, 0, 1, 1)
+
+ self.label_8 = QLabel(self.groupBox_3)
+ self.label_8.setObjectName(u"label_8")
+ sizePolicy3 = QSizePolicy(QSizePolicy.Policy.Minimum, QSizePolicy.Policy.Preferred)
+ sizePolicy3.setHorizontalStretch(0)
+ sizePolicy3.setVerticalStretch(0)
+ sizePolicy3.setHeightForWidth(self.label_8.sizePolicy().hasHeightForWidth())
+ self.label_8.setSizePolicy(sizePolicy3)
+ self.label_8.setFont(font)
+ self.label_8.setAlignment(Qt.AlignmentFlag.AlignCenter)
+
+ self.gridLayout_3.addWidget(self.label_8, 2, 0, 1, 1)
+
+ self.comboBox_batchSize = QComboBox(self.groupBox_3)
+ self.comboBox_batchSize.addItem("")
+ self.comboBox_batchSize.addItem("")
+ self.comboBox_batchSize.addItem("")
+ self.comboBox_batchSize.addItem("")
+ self.comboBox_batchSize.addItem("")
+ self.comboBox_batchSize.addItem("")
+ self.comboBox_batchSize.addItem("")
+ self.comboBox_batchSize.setObjectName(u"comboBox_batchSize")
+ self.comboBox_batchSize.setFont(font)
+
+ self.gridLayout_3.addWidget(self.comboBox_batchSize, 2, 1, 1, 1)
+
+ self.label_7 = QLabel(self.groupBox_3)
+ self.label_7.setObjectName(u"label_7")
+ sizePolicy4 = QSizePolicy(QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Preferred)
+ sizePolicy4.setHorizontalStretch(0)
+ sizePolicy4.setVerticalStretch(0)
+ sizePolicy4.setHeightForWidth(self.label_7.sizePolicy().hasHeightForWidth())
+ self.label_7.setSizePolicy(sizePolicy4)
+ self.label_7.setFont(font)
+ self.label_7.setAlignment(Qt.AlignmentFlag.AlignCenter)
+
+ self.gridLayout_3.addWidget(self.label_7, 4, 0, 1, 1)
self.verticalLayout_5.addWidget(self.groupBox_3)
@@ -253,10 +304,6 @@ class Ui_MainWindow_detect_Jpeak(object):
self.verticalLayout.addWidget(self.groupBox_args)
- self.verticalSpacer = QSpacerItem(20, 40, QSizePolicy.Policy.Minimum, QSizePolicy.Policy.Expanding)
-
- self.verticalLayout.addItem(self.verticalSpacer)
-
self.horizontalLayout_3 = QHBoxLayout()
self.horizontalLayout_3.setObjectName(u"horizontalLayout_3")
self.pushButton_view = QPushButton(self.groupBox_left)
@@ -292,9 +339,8 @@ class Ui_MainWindow_detect_Jpeak(object):
self.verticalLayout.setStretch(0, 1)
self.verticalLayout.setStretch(1, 7)
- self.verticalLayout.setStretch(2, 4)
- self.verticalLayout.setStretch(3, 1)
- self.verticalLayout.setStretch(4, 5)
+ self.verticalLayout.setStretch(2, 1)
+ self.verticalLayout.setStretch(3, 5)
self.gridLayout.addWidget(self.groupBox_left, 0, 0, 1, 1)
@@ -330,7 +376,17 @@ class Ui_MainWindow_detect_Jpeak(object):
self.label_4.setText(QCoreApplication.translate("MainWindow_detect_Jpeak", u"~", None))
self.groupBox_3.setTitle(QCoreApplication.translate("MainWindow_detect_Jpeak", u"\u6a21\u578b\u8bbe\u7f6e", None))
self.checkBox_useCPU.setText(QCoreApplication.translate("MainWindow_detect_Jpeak", u"\u5f3a\u5236\u4f7f\u7528CPU", None))
- self.label_7.setText(QCoreApplication.translate("MainWindow_detect_Jpeak", u"\u68c0\u6d4b\u6a21\u578b\u9009\u62e9", None))
+ self.checkBox_scaleEnable.setText(QCoreApplication.translate("MainWindow_detect_Jpeak", u"\u5e45\u503c\u653e\u7f29", None))
+ self.label_8.setText(QCoreApplication.translate("MainWindow_detect_Jpeak", u"Batch Size:", None))
+ self.comboBox_batchSize.setItemText(0, QCoreApplication.translate("MainWindow_detect_Jpeak", u"0", None))
+ self.comboBox_batchSize.setItemText(1, QCoreApplication.translate("MainWindow_detect_Jpeak", u"64", None))
+ self.comboBox_batchSize.setItemText(2, QCoreApplication.translate("MainWindow_detect_Jpeak", u"128", None))
+ self.comboBox_batchSize.setItemText(3, QCoreApplication.translate("MainWindow_detect_Jpeak", u"256", None))
+ self.comboBox_batchSize.setItemText(4, QCoreApplication.translate("MainWindow_detect_Jpeak", u"512", None))
+ self.comboBox_batchSize.setItemText(5, QCoreApplication.translate("MainWindow_detect_Jpeak", u"1024", None))
+ self.comboBox_batchSize.setItemText(6, QCoreApplication.translate("MainWindow_detect_Jpeak", u"2048", None))
+
+ self.label_7.setText(QCoreApplication.translate("MainWindow_detect_Jpeak", u"\u68c0\u6d4b\u6a21\u578b\u9009\u62e9\uff1a", None))
self.pushButton_view.setText(QCoreApplication.translate("MainWindow_detect_Jpeak", u"\u67e5\u770b\u7ed3\u679c", None))
self.pushButton_save.setText(QCoreApplication.translate("MainWindow_detect_Jpeak", u"\u4fdd\u5b58\u7ed3\u679c", None))
self.groupBox.setTitle(QCoreApplication.translate("MainWindow_detect_Jpeak", u"\u65e5\u5fd7", None))
diff --git a/ui/MainWindow/MainWindow_detect_Jpeak.ui b/ui/MainWindow/MainWindow_detect_Jpeak.ui
index 908c117..8661829 100644
--- a/ui/MainWindow/MainWindow_detect_Jpeak.ui
+++ b/ui/MainWindow/MainWindow_detect_Jpeak.ui
@@ -56,7 +56,7 @@
BCG的J峰算法定位
-
+
-
-
@@ -321,8 +321,29 @@
模型设置
-
-
-
+
+
-
+
+
+
+ 12
+
+
+
+ QAbstractSpinBox::ButtonSymbols::NoButtons
+
+
+ 10
+
+
+ 1000
+
+
+ 100
+
+
+
+ -
@@ -334,28 +355,117 @@
- -
-
+
-
+
+
+
+ 0
+ 0
+
+
+
+
+ 12
+
+
+
+
+ -
+
12
- 检测模型选择
+ 幅值放缩
+
+
+
+ -
+
+
+
+ 0
+ 0
+
+
+
+
+ 12
+
+
+
+ Batch Size:
Qt::AlignmentFlag::AlignCenter
- -
-
+
-
+
12
+
-
+
+ 0
+
+
+ -
+
+ 64
+
+
+ -
+
+ 128
+
+
+ -
+
+ 256
+
+
+ -
+
+ 512
+
+
+ -
+
+ 1024
+
+
+ -
+
+ 2048
+
+
+
+
+ -
+
+
+
+ 0
+ 0
+
+
+
+
+ 12
+
+
+
+ 检测模型选择:
+
+
+ Qt::AlignmentFlag::AlignCenter
+
@@ -364,19 +474,6 @@
- -
-
-
- Qt::Orientation::Vertical
-
-
-
- 20
- 40
-
-
-
-
-
-