新增批量大小和幅值放缩设置,优化J峰预测功能

This commit is contained in:
2025-12-16 16:57:30 +08:00
parent cbf871ca8c
commit 1d791320eb
4 changed files with 238 additions and 55 deletions

View File

@ -338,7 +338,8 @@ class MainWindow_detect_Jpeak(QMainWindow):
# 预测峰值
PublicFunc.progressbar_update(self, 2, 3, Constants.DETECT_JPEAK_PREDICTING_PEAK, 10)
self.model.selected_model = Config["DetectMethod"]
result = self.data.predict_Jpeak(self.model)
scale = self.ui.spinBox_scaleValue.value() if self.ui.checkBox_scaleEnable.isChecked() else 0
result = self.data.predict_Jpeak(self.model, batch_size=int(self.ui.comboBox_batchSize.currentText()), scale=scale)
if not result.status:
PublicFunc.text_output(self.ui, "(2/3)" + result.info, Constants.TIPS_TYPE_ERROR)
PublicFunc.msgbox_output(self, result.info, Constants.MSGBOX_TYPE_ERROR)
@ -469,7 +470,7 @@ class Data:
return Result().failure(info=Constants.PREPROCESS_FAILURE +
Constants.FAILURE_REASON["Preprocess_Exception"] + "\n" + format_exc())
def predict_Jpeak(self, model):
def predict_Jpeak(self, model, batch_size=0, scale=0):
if not (Path(model.model_folder_path) / Path(model.selected_model)).exists():
return Result().failure(info=Constants.DETECT_JPEAK_PREDICT_FAILURE +
Constants.FAILURE_REASON["Model_File_Not_Exist"])
@ -489,7 +490,9 @@ class Data:
Config["IntervalHigh"],
Config["IntervalLow"],
Config["PeaksValue"],
Config["UseCPU"])
Config["UseCPU"],
batch_size,
scale)
except Exception as e:
return Result().failure(info=Constants.DETECT_JPEAK_PREDICT_FAILURE +
Constants.FAILURE_REASON["Predict_Exception"] + "\n" + format_exc())

View File

@ -1,3 +1,4 @@
import torch
from numpy import diff, argwhere, argmax, where, delete, insert, mean, array, full, nan
from numpy import min as np_min
from numpy import max as np_max
@ -5,11 +6,12 @@ from torch import FloatTensor, no_grad, load
from torch import device as torch_device
from torch.cuda import is_available, empty_cache
from torch.nn.functional import sigmoid
from torch.utils.data import TensorDataset, DataLoader
from func.BCGDataset import BCG_Operation
from func.Deep_Model import Unet,Fivelayer_Lstm_Unet,Fivelayer_Unet,Sixlayer_Unet
def evaluate(test_data, model,fs,useCPU):
def evaluate(test_data, model,fs,useCPU, batch_size=0, scale=0):
orgBCG = test_data
operation = BCG_Operation()
# 降采样
@ -17,7 +19,12 @@ def evaluate(test_data, model,fs,useCPU):
# plt.figure()
# plt.plot(orgBCG)
# plt.show()
if scale != 0:
orgBCG_std = orgBCG.std()
orgBCG = orgBCG * scale / orgBCG_std
orgBCG = orgBCG.reshape(-1, 1000)
# test dataset
orgData = FloatTensor(orgBCG).unsqueeze(1)
# predict
@ -30,12 +37,32 @@ def evaluate(test_data, model,fs,useCPU):
# if gpu:
# orgData = orgData.cuda()
# model.cuda()
orgData = orgData.to(device)
model = model.to(device)
model.to(device)
orgData_tensor = FloatTensor(orgBCG).unsqueeze(1)
test_dataset = TensorDataset(orgData_tensor)
batch_size = len(test_dataset) if batch_size == 0 else batch_size
#全部数据放在一个batch里评估
test_loader = DataLoader(
dataset=test_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=0 # 简单评估时设为 0
)
all_y_prob = []
with no_grad():
y_hat = model(orgData)
y_prob = sigmoid(y_hat)
for i, data_batch in enumerate(test_loader):
# data_batch 是一个包含 (batch_size, 1, 1000) 数据的列表
inputs = data_batch[0].to(device)
# 预测
y_hat_batch = model(inputs)
y_prob_batch = sigmoid(y_hat_batch)
# 收集结果,移回 CPU
all_y_prob.append(y_prob_batch.cpu())
y_prob = torch.cat(all_y_prob, dim=0).view(-1)
beat = (y_prob>0.5).float().view(-1).cpu().data.numpy()
beat_diff = diff(beat)
up_index = argwhere(beat_diff==1)
@ -250,7 +277,7 @@ def preprocess(raw_bcg, fs, low_cut, high_cut, amp_value):
bcg = preprocessing.Butterworth(bcg_data, "bandpass", low_cut=low_cut, high_cut=high_cut, order=3) * amp_value
return bcg
def Jpeak_Detection(model_name, model_path, bcg_data, fs, interval_high, interval_low, peaks_value, useCPU):
def Jpeak_Detection(model_name, model_path, bcg_data, fs, interval_high, interval_low, peaks_value, useCPU, batch_size=0, scale=0):
model_name = get_model_name(str(model_name))
if model_name == "Fivelayer_Unet":
@ -267,7 +294,7 @@ def Jpeak_Detection(model_name, model_path, bcg_data, fs, interval_high, interva
model.eval()
# J峰预测
beat, up_index, down_index, y_prob = evaluate(bcg_data, model=model, fs=fs, useCPU=useCPU)
beat, up_index, down_index, y_prob = evaluate(bcg_data, model=model, fs=fs, useCPU=useCPU, batch_size=batch_size, scale=scale)
y_prob = y_prob.cpu().reshape(-1).data.numpy()
predict_J = new_calculate_beat(y_prob, 1, th=0.6, up=fs // 100, th1=interval_high, th2=interval_low)

View File

@ -3,7 +3,7 @@
################################################################################
## Form generated from reading UI file 'MainWindow_detect_Jpeak.ui'
##
## Created by: Qt User Interface Compiler version 6.8.2
## Created by: Qt User Interface Compiler version 6.9.2
##
## WARNING! All changes made in this file will be lost when recompiling UI file!
################################################################################
@ -16,10 +16,10 @@ from PySide6.QtGui import (QAction, QBrush, QColor, QConicalGradient,
QIcon, QImage, QKeySequence, QLinearGradient,
QPainter, QPalette, QPixmap, QRadialGradient,
QTransform)
from PySide6.QtWidgets import (QApplication, QCheckBox, QComboBox, QDoubleSpinBox,
QGridLayout, QGroupBox, QHBoxLayout, QLabel,
QMainWindow, QPushButton, QRadioButton, QSizePolicy,
QSpacerItem, QSpinBox, QStatusBar, QTextBrowser,
from PySide6.QtWidgets import (QAbstractSpinBox, QApplication, QCheckBox, QComboBox,
QDoubleSpinBox, QGridLayout, QGroupBox, QHBoxLayout,
QLabel, QMainWindow, QPushButton, QRadioButton,
QSizePolicy, QSpinBox, QStatusBar, QTextBrowser,
QVBoxLayout, QWidget)
class Ui_MainWindow_detect_Jpeak(object):
@ -221,26 +221,77 @@ class Ui_MainWindow_detect_Jpeak(object):
self.groupBox_3 = QGroupBox(self.groupBox_args)
self.groupBox_3.setObjectName(u"groupBox_3")
self.verticalLayout_2 = QVBoxLayout(self.groupBox_3)
self.verticalLayout_2.setObjectName(u"verticalLayout_2")
self.gridLayout_3 = QGridLayout(self.groupBox_3)
self.gridLayout_3.setObjectName(u"gridLayout_3")
self.spinBox_scaleValue = QSpinBox(self.groupBox_3)
self.spinBox_scaleValue.setObjectName(u"spinBox_scaleValue")
self.spinBox_scaleValue.setFont(font)
self.spinBox_scaleValue.setButtonSymbols(QAbstractSpinBox.ButtonSymbols.NoButtons)
self.spinBox_scaleValue.setMinimum(10)
self.spinBox_scaleValue.setMaximum(1000)
self.spinBox_scaleValue.setValue(100)
self.gridLayout_3.addWidget(self.spinBox_scaleValue, 1, 1, 1, 1)
self.checkBox_useCPU = QCheckBox(self.groupBox_3)
self.checkBox_useCPU.setObjectName(u"checkBox_useCPU")
self.checkBox_useCPU.setFont(font)
self.verticalLayout_2.addWidget(self.checkBox_useCPU)
self.label_7 = QLabel(self.groupBox_3)
self.label_7.setObjectName(u"label_7")
self.label_7.setFont(font)
self.label_7.setAlignment(Qt.AlignmentFlag.AlignCenter)
self.verticalLayout_2.addWidget(self.label_7)
self.gridLayout_3.addWidget(self.checkBox_useCPU, 0, 0, 1, 1)
self.comboBox_model = QComboBox(self.groupBox_3)
self.comboBox_model.setObjectName(u"comboBox_model")
sizePolicy2 = QSizePolicy(QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Fixed)
sizePolicy2.setHorizontalStretch(0)
sizePolicy2.setVerticalStretch(0)
sizePolicy2.setHeightForWidth(self.comboBox_model.sizePolicy().hasHeightForWidth())
self.comboBox_model.setSizePolicy(sizePolicy2)
self.comboBox_model.setFont(font)
self.verticalLayout_2.addWidget(self.comboBox_model)
self.gridLayout_3.addWidget(self.comboBox_model, 4, 1, 1, 1)
self.checkBox_scaleEnable = QCheckBox(self.groupBox_3)
self.checkBox_scaleEnable.setObjectName(u"checkBox_scaleEnable")
self.checkBox_scaleEnable.setFont(font)
self.gridLayout_3.addWidget(self.checkBox_scaleEnable, 1, 0, 1, 1)
self.label_8 = QLabel(self.groupBox_3)
self.label_8.setObjectName(u"label_8")
sizePolicy3 = QSizePolicy(QSizePolicy.Policy.Minimum, QSizePolicy.Policy.Preferred)
sizePolicy3.setHorizontalStretch(0)
sizePolicy3.setVerticalStretch(0)
sizePolicy3.setHeightForWidth(self.label_8.sizePolicy().hasHeightForWidth())
self.label_8.setSizePolicy(sizePolicy3)
self.label_8.setFont(font)
self.label_8.setAlignment(Qt.AlignmentFlag.AlignCenter)
self.gridLayout_3.addWidget(self.label_8, 2, 0, 1, 1)
self.comboBox_batchSize = QComboBox(self.groupBox_3)
self.comboBox_batchSize.addItem("")
self.comboBox_batchSize.addItem("")
self.comboBox_batchSize.addItem("")
self.comboBox_batchSize.addItem("")
self.comboBox_batchSize.addItem("")
self.comboBox_batchSize.addItem("")
self.comboBox_batchSize.addItem("")
self.comboBox_batchSize.setObjectName(u"comboBox_batchSize")
self.comboBox_batchSize.setFont(font)
self.gridLayout_3.addWidget(self.comboBox_batchSize, 2, 1, 1, 1)
self.label_7 = QLabel(self.groupBox_3)
self.label_7.setObjectName(u"label_7")
sizePolicy4 = QSizePolicy(QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Preferred)
sizePolicy4.setHorizontalStretch(0)
sizePolicy4.setVerticalStretch(0)
sizePolicy4.setHeightForWidth(self.label_7.sizePolicy().hasHeightForWidth())
self.label_7.setSizePolicy(sizePolicy4)
self.label_7.setFont(font)
self.label_7.setAlignment(Qt.AlignmentFlag.AlignCenter)
self.gridLayout_3.addWidget(self.label_7, 4, 0, 1, 1)
self.verticalLayout_5.addWidget(self.groupBox_3)
@ -253,10 +304,6 @@ class Ui_MainWindow_detect_Jpeak(object):
self.verticalLayout.addWidget(self.groupBox_args)
self.verticalSpacer = QSpacerItem(20, 40, QSizePolicy.Policy.Minimum, QSizePolicy.Policy.Expanding)
self.verticalLayout.addItem(self.verticalSpacer)
self.horizontalLayout_3 = QHBoxLayout()
self.horizontalLayout_3.setObjectName(u"horizontalLayout_3")
self.pushButton_view = QPushButton(self.groupBox_left)
@ -292,9 +339,8 @@ class Ui_MainWindow_detect_Jpeak(object):
self.verticalLayout.setStretch(0, 1)
self.verticalLayout.setStretch(1, 7)
self.verticalLayout.setStretch(2, 4)
self.verticalLayout.setStretch(3, 1)
self.verticalLayout.setStretch(4, 5)
self.verticalLayout.setStretch(2, 1)
self.verticalLayout.setStretch(3, 5)
self.gridLayout.addWidget(self.groupBox_left, 0, 0, 1, 1)
@ -330,7 +376,17 @@ class Ui_MainWindow_detect_Jpeak(object):
self.label_4.setText(QCoreApplication.translate("MainWindow_detect_Jpeak", u"~", None))
self.groupBox_3.setTitle(QCoreApplication.translate("MainWindow_detect_Jpeak", u"\u6a21\u578b\u8bbe\u7f6e", None))
self.checkBox_useCPU.setText(QCoreApplication.translate("MainWindow_detect_Jpeak", u"\u5f3a\u5236\u4f7f\u7528CPU", None))
self.label_7.setText(QCoreApplication.translate("MainWindow_detect_Jpeak", u"\u68c0\u6d4b\u6a21\u578b\u9009\u62e9", None))
self.checkBox_scaleEnable.setText(QCoreApplication.translate("MainWindow_detect_Jpeak", u"\u5e45\u503c\u653e\u7f29", None))
self.label_8.setText(QCoreApplication.translate("MainWindow_detect_Jpeak", u"Batch Size:", None))
self.comboBox_batchSize.setItemText(0, QCoreApplication.translate("MainWindow_detect_Jpeak", u"0", None))
self.comboBox_batchSize.setItemText(1, QCoreApplication.translate("MainWindow_detect_Jpeak", u"64", None))
self.comboBox_batchSize.setItemText(2, QCoreApplication.translate("MainWindow_detect_Jpeak", u"128", None))
self.comboBox_batchSize.setItemText(3, QCoreApplication.translate("MainWindow_detect_Jpeak", u"256", None))
self.comboBox_batchSize.setItemText(4, QCoreApplication.translate("MainWindow_detect_Jpeak", u"512", None))
self.comboBox_batchSize.setItemText(5, QCoreApplication.translate("MainWindow_detect_Jpeak", u"1024", None))
self.comboBox_batchSize.setItemText(6, QCoreApplication.translate("MainWindow_detect_Jpeak", u"2048", None))
self.label_7.setText(QCoreApplication.translate("MainWindow_detect_Jpeak", u"\u68c0\u6d4b\u6a21\u578b\u9009\u62e9\uff1a", None))
self.pushButton_view.setText(QCoreApplication.translate("MainWindow_detect_Jpeak", u"\u67e5\u770b\u7ed3\u679c", None))
self.pushButton_save.setText(QCoreApplication.translate("MainWindow_detect_Jpeak", u"\u4fdd\u5b58\u7ed3\u679c", None))
self.groupBox.setTitle(QCoreApplication.translate("MainWindow_detect_Jpeak", u"\u65e5\u5fd7", None))

View File

@ -56,7 +56,7 @@
<property name="title">
<string>BCG的J峰算法定位</string>
</property>
<layout class="QVBoxLayout" name="verticalLayout" stretch="1,7,4,1,5">
<layout class="QVBoxLayout" name="verticalLayout" stretch="1,7,1,5">
<item>
<layout class="QHBoxLayout" name="horizontalLayout_4">
<item>
@ -321,8 +321,29 @@
<property name="title">
<string>模型设置</string>
</property>
<layout class="QVBoxLayout" name="verticalLayout_2">
<item>
<layout class="QGridLayout" name="gridLayout_3">
<item row="1" column="1">
<widget class="QSpinBox" name="spinBox_scaleValue">
<property name="font">
<font>
<pointsize>12</pointsize>
</font>
</property>
<property name="buttonSymbols">
<enum>QAbstractSpinBox::ButtonSymbols::NoButtons</enum>
</property>
<property name="minimum">
<number>10</number>
</property>
<property name="maximum">
<number>1000</number>
</property>
<property name="value">
<number>100</number>
</property>
</widget>
</item>
<item row="0" column="0">
<widget class="QCheckBox" name="checkBox_useCPU">
<property name="font">
<font>
@ -334,48 +355,124 @@
</property>
</widget>
</item>
<item>
<widget class="QLabel" name="label_7">
<item row="4" column="1">
<widget class="QComboBox" name="comboBox_model">
<property name="sizePolicy">
<sizepolicy hsizetype="Preferred" vsizetype="Fixed">
<horstretch>0</horstretch>
<verstretch>0</verstretch>
</sizepolicy>
</property>
<property name="font">
<font>
<pointsize>12</pointsize>
</font>
</property>
</widget>
</item>
<item row="1" column="0">
<widget class="QCheckBox" name="checkBox_scaleEnable">
<property name="font">
<font>
<pointsize>12</pointsize>
</font>
</property>
<property name="text">
<string>检测模型选择</string>
<string>幅值放缩</string>
</property>
</widget>
</item>
<item row="2" column="0">
<widget class="QLabel" name="label_8">
<property name="sizePolicy">
<sizepolicy hsizetype="Minimum" vsizetype="Preferred">
<horstretch>0</horstretch>
<verstretch>0</verstretch>
</sizepolicy>
</property>
<property name="font">
<font>
<pointsize>12</pointsize>
</font>
</property>
<property name="text">
<string>Batch Size:</string>
</property>
<property name="alignment">
<set>Qt::AlignmentFlag::AlignCenter</set>
</property>
</widget>
</item>
<item>
<widget class="QComboBox" name="comboBox_model">
<item row="2" column="1">
<widget class="QComboBox" name="comboBox_batchSize">
<property name="font">
<font>
<pointsize>12</pointsize>
</font>
</property>
</widget>
</item>
</layout>
</widget>
</item>
</layout>
</widget>
<item>
<property name="text">
<string>0</string>
</property>
</item>
<item>
<spacer name="verticalSpacer">
<property name="orientation">
<enum>Qt::Orientation::Vertical</enum>
<property name="text">
<string>64</string>
</property>
<property name="sizeHint" stdset="0">
<size>
<width>20</width>
<height>40</height>
</size>
</item>
<item>
<property name="text">
<string>128</string>
</property>
</spacer>
</item>
<item>
<property name="text">
<string>256</string>
</property>
</item>
<item>
<property name="text">
<string>512</string>
</property>
</item>
<item>
<property name="text">
<string>1024</string>
</property>
</item>
<item>
<property name="text">
<string>2048</string>
</property>
</item>
</widget>
</item>
<item row="4" column="0">
<widget class="QLabel" name="label_7">
<property name="sizePolicy">
<sizepolicy hsizetype="Fixed" vsizetype="Preferred">
<horstretch>0</horstretch>
<verstretch>0</verstretch>
</sizepolicy>
</property>
<property name="font">
<font>
<pointsize>12</pointsize>
</font>
</property>
<property name="text">
<string>检测模型选择:</string>
</property>
<property name="alignment">
<set>Qt::AlignmentFlag::AlignCenter</set>
</property>
</widget>
</item>
</layout>
</widget>
</item>
</layout>
</widget>
</item>
<item>
<layout class="QHBoxLayout" name="horizontalLayout_3">