diff --git a/func/Module_detect_Jpeak.py b/func/Module_detect_Jpeak.py index 2246aae..c7c4cac 100644 --- a/func/Module_detect_Jpeak.py +++ b/func/Module_detect_Jpeak.py @@ -338,7 +338,8 @@ class MainWindow_detect_Jpeak(QMainWindow): # 预测峰值 PublicFunc.progressbar_update(self, 2, 3, Constants.DETECT_JPEAK_PREDICTING_PEAK, 10) self.model.selected_model = Config["DetectMethod"] - result = self.data.predict_Jpeak(self.model) + scale = self.ui.spinBox_scaleValue.value() if self.ui.checkBox_scaleEnable.isChecked() else 0 + result = self.data.predict_Jpeak(self.model, batch_size=int(self.ui.comboBox_batchSize.currentText()), scale=scale) if not result.status: PublicFunc.text_output(self.ui, "(2/3)" + result.info, Constants.TIPS_TYPE_ERROR) PublicFunc.msgbox_output(self, result.info, Constants.MSGBOX_TYPE_ERROR) @@ -469,7 +470,7 @@ class Data: return Result().failure(info=Constants.PREPROCESS_FAILURE + Constants.FAILURE_REASON["Preprocess_Exception"] + "\n" + format_exc()) - def predict_Jpeak(self, model): + def predict_Jpeak(self, model, batch_size=0, scale=0): if not (Path(model.model_folder_path) / Path(model.selected_model)).exists(): return Result().failure(info=Constants.DETECT_JPEAK_PREDICT_FAILURE + Constants.FAILURE_REASON["Model_File_Not_Exist"]) @@ -489,7 +490,9 @@ class Data: Config["IntervalHigh"], Config["IntervalLow"], Config["PeaksValue"], - Config["UseCPU"]) + Config["UseCPU"], + batch_size, + scale) except Exception as e: return Result().failure(info=Constants.DETECT_JPEAK_PREDICT_FAILURE + Constants.FAILURE_REASON["Predict_Exception"] + "\n" + format_exc()) diff --git a/func/utils/detect_Jpeak.py b/func/utils/detect_Jpeak.py index ab81263..54b9096 100644 --- a/func/utils/detect_Jpeak.py +++ b/func/utils/detect_Jpeak.py @@ -1,3 +1,4 @@ +import torch from numpy import diff, argwhere, argmax, where, delete, insert, mean, array, full, nan from numpy import min as np_min from numpy import max as np_max @@ -5,11 +6,12 @@ from torch import FloatTensor, no_grad, load from torch import device as torch_device from torch.cuda import is_available, empty_cache from torch.nn.functional import sigmoid +from torch.utils.data import TensorDataset, DataLoader from func.BCGDataset import BCG_Operation from func.Deep_Model import Unet,Fivelayer_Lstm_Unet,Fivelayer_Unet,Sixlayer_Unet -def evaluate(test_data, model,fs,useCPU): +def evaluate(test_data, model,fs,useCPU, batch_size=0, scale=0): orgBCG = test_data operation = BCG_Operation() # 降采样 @@ -17,7 +19,12 @@ def evaluate(test_data, model,fs,useCPU): # plt.figure() # plt.plot(orgBCG) # plt.show() + if scale != 0: + orgBCG_std = orgBCG.std() + orgBCG = orgBCG * scale / orgBCG_std + orgBCG = orgBCG.reshape(-1, 1000) + # test dataset orgData = FloatTensor(orgBCG).unsqueeze(1) # predict @@ -30,12 +37,32 @@ def evaluate(test_data, model,fs,useCPU): # if gpu: # orgData = orgData.cuda() # model.cuda() - orgData = orgData.to(device) - model = model.to(device) + model.to(device) + orgData_tensor = FloatTensor(orgBCG).unsqueeze(1) + test_dataset = TensorDataset(orgData_tensor) + batch_size = len(test_dataset) if batch_size == 0 else batch_size + #全部数据放在一个batch里评估 + test_loader = DataLoader( + dataset=test_dataset, + batch_size=batch_size, + shuffle=False, + num_workers=0 # 简单评估时设为 0 + ) + all_y_prob = [] with no_grad(): - y_hat = model(orgData) - y_prob = sigmoid(y_hat) + for i, data_batch in enumerate(test_loader): + # data_batch 是一个包含 (batch_size, 1, 1000) 数据的列表 + inputs = data_batch[0].to(device) + + # 预测 + y_hat_batch = model(inputs) + y_prob_batch = sigmoid(y_hat_batch) + + # 收集结果,移回 CPU + all_y_prob.append(y_prob_batch.cpu()) + + y_prob = torch.cat(all_y_prob, dim=0).view(-1) beat = (y_prob>0.5).float().view(-1).cpu().data.numpy() beat_diff = diff(beat) up_index = argwhere(beat_diff==1) @@ -250,7 +277,7 @@ def preprocess(raw_bcg, fs, low_cut, high_cut, amp_value): bcg = preprocessing.Butterworth(bcg_data, "bandpass", low_cut=low_cut, high_cut=high_cut, order=3) * amp_value return bcg -def Jpeak_Detection(model_name, model_path, bcg_data, fs, interval_high, interval_low, peaks_value, useCPU): +def Jpeak_Detection(model_name, model_path, bcg_data, fs, interval_high, interval_low, peaks_value, useCPU, batch_size=0, scale=0): model_name = get_model_name(str(model_name)) if model_name == "Fivelayer_Unet": @@ -267,7 +294,7 @@ def Jpeak_Detection(model_name, model_path, bcg_data, fs, interval_high, interva model.eval() # J峰预测 - beat, up_index, down_index, y_prob = evaluate(bcg_data, model=model, fs=fs, useCPU=useCPU) + beat, up_index, down_index, y_prob = evaluate(bcg_data, model=model, fs=fs, useCPU=useCPU, batch_size=batch_size, scale=scale) y_prob = y_prob.cpu().reshape(-1).data.numpy() predict_J = new_calculate_beat(y_prob, 1, th=0.6, up=fs // 100, th1=interval_high, th2=interval_low) diff --git a/ui/MainWindow/MainWindow_detect_Jpeak.py b/ui/MainWindow/MainWindow_detect_Jpeak.py index 1c44af5..704cec7 100644 --- a/ui/MainWindow/MainWindow_detect_Jpeak.py +++ b/ui/MainWindow/MainWindow_detect_Jpeak.py @@ -3,7 +3,7 @@ ################################################################################ ## Form generated from reading UI file 'MainWindow_detect_Jpeak.ui' ## -## Created by: Qt User Interface Compiler version 6.8.2 +## Created by: Qt User Interface Compiler version 6.9.2 ## ## WARNING! All changes made in this file will be lost when recompiling UI file! ################################################################################ @@ -16,10 +16,10 @@ from PySide6.QtGui import (QAction, QBrush, QColor, QConicalGradient, QIcon, QImage, QKeySequence, QLinearGradient, QPainter, QPalette, QPixmap, QRadialGradient, QTransform) -from PySide6.QtWidgets import (QApplication, QCheckBox, QComboBox, QDoubleSpinBox, - QGridLayout, QGroupBox, QHBoxLayout, QLabel, - QMainWindow, QPushButton, QRadioButton, QSizePolicy, - QSpacerItem, QSpinBox, QStatusBar, QTextBrowser, +from PySide6.QtWidgets import (QAbstractSpinBox, QApplication, QCheckBox, QComboBox, + QDoubleSpinBox, QGridLayout, QGroupBox, QHBoxLayout, + QLabel, QMainWindow, QPushButton, QRadioButton, + QSizePolicy, QSpinBox, QStatusBar, QTextBrowser, QVBoxLayout, QWidget) class Ui_MainWindow_detect_Jpeak(object): @@ -221,26 +221,77 @@ class Ui_MainWindow_detect_Jpeak(object): self.groupBox_3 = QGroupBox(self.groupBox_args) self.groupBox_3.setObjectName(u"groupBox_3") - self.verticalLayout_2 = QVBoxLayout(self.groupBox_3) - self.verticalLayout_2.setObjectName(u"verticalLayout_2") + self.gridLayout_3 = QGridLayout(self.groupBox_3) + self.gridLayout_3.setObjectName(u"gridLayout_3") + self.spinBox_scaleValue = QSpinBox(self.groupBox_3) + self.spinBox_scaleValue.setObjectName(u"spinBox_scaleValue") + self.spinBox_scaleValue.setFont(font) + self.spinBox_scaleValue.setButtonSymbols(QAbstractSpinBox.ButtonSymbols.NoButtons) + self.spinBox_scaleValue.setMinimum(10) + self.spinBox_scaleValue.setMaximum(1000) + self.spinBox_scaleValue.setValue(100) + + self.gridLayout_3.addWidget(self.spinBox_scaleValue, 1, 1, 1, 1) + self.checkBox_useCPU = QCheckBox(self.groupBox_3) self.checkBox_useCPU.setObjectName(u"checkBox_useCPU") self.checkBox_useCPU.setFont(font) - self.verticalLayout_2.addWidget(self.checkBox_useCPU) - - self.label_7 = QLabel(self.groupBox_3) - self.label_7.setObjectName(u"label_7") - self.label_7.setFont(font) - self.label_7.setAlignment(Qt.AlignmentFlag.AlignCenter) - - self.verticalLayout_2.addWidget(self.label_7) + self.gridLayout_3.addWidget(self.checkBox_useCPU, 0, 0, 1, 1) self.comboBox_model = QComboBox(self.groupBox_3) self.comboBox_model.setObjectName(u"comboBox_model") + sizePolicy2 = QSizePolicy(QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Fixed) + sizePolicy2.setHorizontalStretch(0) + sizePolicy2.setVerticalStretch(0) + sizePolicy2.setHeightForWidth(self.comboBox_model.sizePolicy().hasHeightForWidth()) + self.comboBox_model.setSizePolicy(sizePolicy2) self.comboBox_model.setFont(font) - self.verticalLayout_2.addWidget(self.comboBox_model) + self.gridLayout_3.addWidget(self.comboBox_model, 4, 1, 1, 1) + + self.checkBox_scaleEnable = QCheckBox(self.groupBox_3) + self.checkBox_scaleEnable.setObjectName(u"checkBox_scaleEnable") + self.checkBox_scaleEnable.setFont(font) + + self.gridLayout_3.addWidget(self.checkBox_scaleEnable, 1, 0, 1, 1) + + self.label_8 = QLabel(self.groupBox_3) + self.label_8.setObjectName(u"label_8") + sizePolicy3 = QSizePolicy(QSizePolicy.Policy.Minimum, QSizePolicy.Policy.Preferred) + sizePolicy3.setHorizontalStretch(0) + sizePolicy3.setVerticalStretch(0) + sizePolicy3.setHeightForWidth(self.label_8.sizePolicy().hasHeightForWidth()) + self.label_8.setSizePolicy(sizePolicy3) + self.label_8.setFont(font) + self.label_8.setAlignment(Qt.AlignmentFlag.AlignCenter) + + self.gridLayout_3.addWidget(self.label_8, 2, 0, 1, 1) + + self.comboBox_batchSize = QComboBox(self.groupBox_3) + self.comboBox_batchSize.addItem("") + self.comboBox_batchSize.addItem("") + self.comboBox_batchSize.addItem("") + self.comboBox_batchSize.addItem("") + self.comboBox_batchSize.addItem("") + self.comboBox_batchSize.addItem("") + self.comboBox_batchSize.addItem("") + self.comboBox_batchSize.setObjectName(u"comboBox_batchSize") + self.comboBox_batchSize.setFont(font) + + self.gridLayout_3.addWidget(self.comboBox_batchSize, 2, 1, 1, 1) + + self.label_7 = QLabel(self.groupBox_3) + self.label_7.setObjectName(u"label_7") + sizePolicy4 = QSizePolicy(QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Preferred) + sizePolicy4.setHorizontalStretch(0) + sizePolicy4.setVerticalStretch(0) + sizePolicy4.setHeightForWidth(self.label_7.sizePolicy().hasHeightForWidth()) + self.label_7.setSizePolicy(sizePolicy4) + self.label_7.setFont(font) + self.label_7.setAlignment(Qt.AlignmentFlag.AlignCenter) + + self.gridLayout_3.addWidget(self.label_7, 4, 0, 1, 1) self.verticalLayout_5.addWidget(self.groupBox_3) @@ -253,10 +304,6 @@ class Ui_MainWindow_detect_Jpeak(object): self.verticalLayout.addWidget(self.groupBox_args) - self.verticalSpacer = QSpacerItem(20, 40, QSizePolicy.Policy.Minimum, QSizePolicy.Policy.Expanding) - - self.verticalLayout.addItem(self.verticalSpacer) - self.horizontalLayout_3 = QHBoxLayout() self.horizontalLayout_3.setObjectName(u"horizontalLayout_3") self.pushButton_view = QPushButton(self.groupBox_left) @@ -292,9 +339,8 @@ class Ui_MainWindow_detect_Jpeak(object): self.verticalLayout.setStretch(0, 1) self.verticalLayout.setStretch(1, 7) - self.verticalLayout.setStretch(2, 4) - self.verticalLayout.setStretch(3, 1) - self.verticalLayout.setStretch(4, 5) + self.verticalLayout.setStretch(2, 1) + self.verticalLayout.setStretch(3, 5) self.gridLayout.addWidget(self.groupBox_left, 0, 0, 1, 1) @@ -330,7 +376,17 @@ class Ui_MainWindow_detect_Jpeak(object): self.label_4.setText(QCoreApplication.translate("MainWindow_detect_Jpeak", u"~", None)) self.groupBox_3.setTitle(QCoreApplication.translate("MainWindow_detect_Jpeak", u"\u6a21\u578b\u8bbe\u7f6e", None)) self.checkBox_useCPU.setText(QCoreApplication.translate("MainWindow_detect_Jpeak", u"\u5f3a\u5236\u4f7f\u7528CPU", None)) - self.label_7.setText(QCoreApplication.translate("MainWindow_detect_Jpeak", u"\u68c0\u6d4b\u6a21\u578b\u9009\u62e9", None)) + self.checkBox_scaleEnable.setText(QCoreApplication.translate("MainWindow_detect_Jpeak", u"\u5e45\u503c\u653e\u7f29", None)) + self.label_8.setText(QCoreApplication.translate("MainWindow_detect_Jpeak", u"Batch Size:", None)) + self.comboBox_batchSize.setItemText(0, QCoreApplication.translate("MainWindow_detect_Jpeak", u"0", None)) + self.comboBox_batchSize.setItemText(1, QCoreApplication.translate("MainWindow_detect_Jpeak", u"64", None)) + self.comboBox_batchSize.setItemText(2, QCoreApplication.translate("MainWindow_detect_Jpeak", u"128", None)) + self.comboBox_batchSize.setItemText(3, QCoreApplication.translate("MainWindow_detect_Jpeak", u"256", None)) + self.comboBox_batchSize.setItemText(4, QCoreApplication.translate("MainWindow_detect_Jpeak", u"512", None)) + self.comboBox_batchSize.setItemText(5, QCoreApplication.translate("MainWindow_detect_Jpeak", u"1024", None)) + self.comboBox_batchSize.setItemText(6, QCoreApplication.translate("MainWindow_detect_Jpeak", u"2048", None)) + + self.label_7.setText(QCoreApplication.translate("MainWindow_detect_Jpeak", u"\u68c0\u6d4b\u6a21\u578b\u9009\u62e9\uff1a", None)) self.pushButton_view.setText(QCoreApplication.translate("MainWindow_detect_Jpeak", u"\u67e5\u770b\u7ed3\u679c", None)) self.pushButton_save.setText(QCoreApplication.translate("MainWindow_detect_Jpeak", u"\u4fdd\u5b58\u7ed3\u679c", None)) self.groupBox.setTitle(QCoreApplication.translate("MainWindow_detect_Jpeak", u"\u65e5\u5fd7", None)) diff --git a/ui/MainWindow/MainWindow_detect_Jpeak.ui b/ui/MainWindow/MainWindow_detect_Jpeak.ui index 908c117..8661829 100644 --- a/ui/MainWindow/MainWindow_detect_Jpeak.ui +++ b/ui/MainWindow/MainWindow_detect_Jpeak.ui @@ -56,7 +56,7 @@ BCG的J峰算法定位 - + @@ -321,8 +321,29 @@ 模型设置 - - + + + + + + 12 + + + + QAbstractSpinBox::ButtonSymbols::NoButtons + + + 10 + + + 1000 + + + 100 + + + + @@ -334,28 +355,117 @@ - - + + + + + 0 + 0 + + + + + 12 + + + + + + 12 - 检测模型选择 + 幅值放缩 + + + + + + + + 0 + 0 + + + + + 12 + + + + Batch Size: Qt::AlignmentFlag::AlignCenter - - + + 12 + + + 0 + + + + + 64 + + + + + 128 + + + + + 256 + + + + + 512 + + + + + 1024 + + + + + 2048 + + + + + + + + + 0 + 0 + + + + + 12 + + + + 检测模型选择: + + + Qt::AlignmentFlag::AlignCenter + @@ -364,19 +474,6 @@ - - - - Qt::Orientation::Vertical - - - - 20 - 40 - - - -