Files
Signal_Label_Reborn/func/Module_cut_PSG.py
2025-08-26 15:09:08 +08:00

358 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from ast import literal_eval
from gc import collect
from math import floor
from pathlib import Path
from traceback import format_exc
from PySide6.QtWidgets import QMessageBox, QMainWindow, QApplication
from numpy import array
from overrides import overrides
from pandas import read_csv, DataFrame
from yaml import dump, load, FullLoader
from func.utils.ConfigParams import Filename, Params
from func.utils.PublicFunc import PublicFunc
from func.utils.Constants import Constants
from func.utils.Result import Result
from ui.MainWindow.MainWindow_cut_PSG import Ui_MainWindow_cut_PSG
# Module-wide settings shared by the window and Data. Filled from the YAML
# config file in MainWindow_cut_PSG.__read_config__ and extended with the
# per-sample "Path" entries in MainWindow_cut_PSG.show().
Config = {}

# Enabled state of each push button, passed to PublicFunc's button helpers.
# "Default" is the baseline that __reset__ copies into "Current".
# The dict literal is re-evaluated per key, so the two states are distinct dicts.
ButtonState = {
    state_name: {"pushButton_execute": True}
    for state_name in ("Default", "Current")
}
class MainWindow_cut_PSG(QMainWindow):
    """Window of the "cut PSG" module.

    Runs the ``Data`` pipeline for one sample: it locates the channel files,
    reads them, cuts them to the aligned window, maps the SA labels and saves
    the result. Settings are loaded from ``Params.CUT_PSG_CONFIG_FILE_PATH``
    into the module-level ``Config`` dict.
    """

    def __init__(self):
        super(MainWindow_cut_PSG, self).__init__()

        self.ui = Ui_MainWindow_cut_PSG()
        self.ui.setupUi(self)

        self.root_path = None
        self.sampID = None

        self.__read_config__()

        self.data = None

        self.ui.textBrowser_info.setStyleSheet("QTextBrowser { background-color: rgb(255, 255, 200); }")

        PublicFunc.__styleAllButton__(self, ButtonState)

        # Initialise the progress bar
        self.ui.progressbar.setStyleSheet(Constants.PROGRESSBAR_STYLE)
        self.progressbar = self.ui.progressbar

        self.msgBox = QMessageBox()
        self.msgBox.setWindowTitle(Constants.MAINWINDOW_MSGBOX_TITLE)

        # Connect exactly once. Connecting inside show() added one more
        # connection on every show() call, so a single click then ran the
        # whole pipeline several times.
        self.ui.pushButton_execute.clicked.connect(self.__slot_btn_execute__)

    @overrides
    def show(self, root_path, sampID):
        """Show the window for one sample and derive its input/output folders.

        Args:
            root_path: dataset root folder.
            sampID: sample identifier; converted to ``str`` for the sub-folder name.
        """
        super().show()
        self.root_path = root_path
        self.sampID = sampID

        PublicFunc.__resetAllButton__(self, ButtonState)

        Config.update({
            "Path": {
                "InputFolder": str(Path(self.root_path) / Filename.PATH_PSG_TEXT / Path(str(self.sampID))),
                "SaveFolder": str(Path(self.root_path) / Filename.PATH_PSG_ALIGNED / Path(str(self.sampID))),
                "InputAlignInfo": str(Path(self.root_path) / Filename.PATH_LABEL / Path(str(self.sampID)))
            }
        })

        # Echo the configured channel / label names to the UI.
        self.ui.plainTextEdit_channel.setPlainText(', '.join(Config["ChannelInput"].keys()))
        self.ui.plainTextEdit_label.setPlainText(', '.join(Config["LabelInput"].keys()))

    @overrides
    def closeEvent(self, event):
        """Ask for confirmation; on "Yes" release the loaded data and close."""
        reply = QMessageBox.question(self, '确认', '确认退出吗?', QMessageBox.Yes | QMessageBox.No, QMessageBox.No)
        if reply == QMessageBox.Yes:
            PublicFunc.__disableAllButton__(self, ButtonState)

            PublicFunc.statusbar_show_msg(self, PublicFunc.format_status_msg(Constants.SHUTTING_DOWN))
            QApplication.processEvents()

            # Release resources
            del self.data
            self.deleteLater()
            collect()
            event.accept()
        else:
            event.ignore()

    def __reset__(self):
        """Restore the current button state from the default one."""
        ButtonState["Current"].update(ButtonState["Default"].copy())

    def __read_config__(self):
        """Load the YAML config into ``Config``, creating the file from defaults if missing."""
        if not Path(Params.CUT_PSG_CONFIG_FILE_PATH).exists():
            with open(Params.CUT_PSG_CONFIG_FILE_PATH, "w") as f:
                dump(Params.CUT_PSG_CONFIG_NEW_CONTENT, f)

        with open(Params.CUT_PSG_CONFIG_FILE_PATH, "r") as f:
            file_config = load(f.read(), Loader=FullLoader)
            Config.update(file_config)

        # Show the loaded value in the UI
        self.ui.spinBox_ECGFreq.setValue(Config["ECGFreq"])

    def __run_step__(self, step, total, progress_msg, progress_value, action):
        """Run one pipeline step and report its outcome.

        Updates the progress bar, calls ``action()`` and writes its info to the
        text browser prefixed with ``"(step/total)"``. On failure it also shows
        an error message box and calls ``finish_operation``.

        Returns:
            The step's ``Result`` on success, ``None`` on failure.
        """
        PublicFunc.progressbar_update(self, step, total, progress_msg, progress_value)
        result = action()
        prefix = "({}/{})".format(step, total)
        if not result.status:
            PublicFunc.text_output(self.ui, prefix + result.info, Constants.TIPS_TYPE_ERROR)
            PublicFunc.msgbox_output(self, result.info, Constants.MSGBOX_TYPE_ERROR)
            PublicFunc.finish_operation(self, ButtonState)
            return None
        PublicFunc.text_output(self.ui, prefix + result.info, Constants.TIPS_TYPE_INFO)
        return result

    def __slot_btn_execute__(self):
        """Run the five Data steps in order, stopping at the first failure."""
        PublicFunc.__disableAllButton__(self, ButtonState)

        self.data = Data(self.root_path, self.sampID)
        Config["ECGFreq"] = self.ui.spinBox_ECGFreq.value()

        # (progress message, progress value, step)
        steps = (
            # Check the files exist and read their sampling rates
            (Constants.CUT_PSG_GETTING_FILE_AND_FREQ, 0, self.data.get_file_and_freq),
            # Load the data
            (Constants.INPUTTING_DATA, 10, self.data.open_file),
            # Cut the signals
            (Constants.CUT_PSG_CUTTING_DATA, 40, self.data.cut_data),
            # Map the labels onto the cut window
            (Constants.CUT_PSG_ALIGNING_LABEL, 60, self.data.align_label),
            # Save
            (Constants.SAVING_DATA, 70, self.data.save),
        )
        total = len(steps)

        result = None
        for step, (progress_msg, progress_value, action) in enumerate(steps, start=1):
            result = self.__run_step__(step, total, progress_msg, progress_value, action)
            if result is None:
                return
            if step < total:
                PublicFunc.finish_operation(self, ButtonState)

        # Report the length and sampling rate of every saved channel.
        for key, raw in self.data.raw.items():
            info = "保存{}的长度为{},采样率为{}Hz".format(key, str(len(raw)), str(self.data.freq[key]))
            PublicFunc.text_output(self.ui, info, Constants.TIPS_TYPE_INFO)
            QApplication.processEvents()
        # NOTE(review): every other msgbox_output call passes a MSGBOX_TYPE_*
        # constant; this one passes TIPS_TYPE_INFO. Confirm this is intended.
        PublicFunc.msgbox_output(self, result.info, Constants.TIPS_TYPE_INFO)
        PublicFunc.finish_operation(self, ButtonState)
class Data:
    """Per-sample data for the PSG cutting pipeline.

    The window calls ``get_file_and_freq``, ``open_file``, ``cut_data``,
    ``align_label`` and ``save`` in that order. Each returns a ``Result``.
    All folders, file-name prefixes and suffixes come from the module-level
    ``Config`` dict.
    """

    def __init__(self, root_path, sampID):
        self.alignInfo = None  # dict parsed from the precise-align-info file in open_file
        # Per-channel signals and sampling rates, keyed like Config["ChannelInput"];
        # a rate of 0 means no matching file has been found.
        self.raw = {key: array([]) for key in Config["ChannelInput"]}
        self.freq = {key: 0 for key in Config["ChannelInput"]}
        self.SALabel = None  # SA event table (DataFrame) read in open_file
        self.startTime = None  # flattened contents of the start-time file
        self.root_path = root_path
        self.sampID = sampID

    def get_file_and_freq(self):
        """Find each channel file and read its sampling rate from the file name.

        A channel file's stem must start with the channel's configured prefix
        and end with ``_<integer rate>``. The method also checks that the SA
        label file, the start-time file and the alignment-info path exist.

        Returns:
            Result: failure on a malformed channel file name or a missing file,
            success otherwise.
        """
        try:
            entries = list(Path(Config["Path"]["InputFolder"]).glob('*'))
            for file_path in entries:
                if not file_path.is_file():
                    continue
                file_stem = file_path.stem
                for key, prefix in Config["ChannelInput"].items():
                    if not file_stem.startswith(prefix):
                        continue
                    # The rate is the text after the last underscore. A stem
                    # with no underscore is also a format error; rsplit()[1]
                    # used to raise IndexError there instead.
                    _, sep, freq_str = file_stem.rpartition('_')
                    try:
                        freq = int(freq_str) if sep else None
                    except ValueError:
                        freq = None
                    if freq is None:
                        return Result().failure(info=Constants.CUT_PSG_GET_FILE_AND_FREQ_FAILURE +
                                                     Constants.FAILURE_REASON["Filename_Format_not_Correct"])
                    self.freq[key] = freq

            if any(value == 0 for value in self.freq.values()):
                return Result().failure(info=Constants.CUT_PSG_GET_FILE_AND_FREQ_FAILURE +
                                             Constants.FAILURE_REASON["Filename_Format_not_Correct"])

            entry_names = [str(entry) for entry in entries]
            sa_label_name = Config["LabelInput"]["SA Label"] + Config["EndWith"]["SA Label"]
            if not any(sa_label_name in name for name in entry_names):
                return Result().failure(info=Constants.CUT_PSG_GET_FILE_AND_FREQ_FAILURE +
                                             Constants.FAILURE_REASON["File_not_Exist"])
            start_time_name = Config["StartTime"] + Config["EndWith"]["StartTime"]
            if not any(start_time_name in name for name in entry_names):
                return Result().failure(info=Constants.CUT_PSG_GET_FILE_AND_FREQ_FAILURE +
                                             Constants.FAILURE_REASON["File_not_Exist"])
            if not Path(Config["Path"]["InputAlignInfo"]).exists():
                return Result().failure(info=Constants.CUT_PSG_GET_FILE_AND_FREQ_FAILURE +
                                             Constants.FAILURE_REASON["File_not_Exist"])
        except Exception:
            return Result().failure(info=Constants.CUT_PSG_GET_FILE_AND_FREQ_FAILURE +
                                         Constants.FAILURE_REASON["Get_File_and_Freq_Excepetion"] + "\n" + format_exc())

        return Result().success(info=Constants.CUT_PSG_GET_FILE_AND_FREQ_FINISHED)

    def open_file(self):
        """Read the channel signals, SA labels, start time and alignment info.

        Rewrites ``Config["Path"]["InputAlignInfo"]`` so that it points at the
        precise-align-info ``.txt`` file inside the label folder. If it already
        points at a file, its parent folder is used, so repeated runs are safe.

        Returns:
            Result: the failing ``examine_file`` result, a failure on a read
            error, or success.
        """
        path = str(Path(self.root_path) / Filename.PATH_PSG_TEXT / Path(str(self.sampID)))
        for value in Config["ChannelInput"].values():
            result = PublicFunc.examine_file(path, value, Params.ENDSWITH_TXT)
            if not result.status:
                return result

        if Path(Config["Path"]["InputAlignInfo"]).is_file():
            Config["Path"]["InputAlignInfo"] = str(Path(Config["Path"]["InputAlignInfo"]).parent)
        Config["Path"]["InputAlignInfo"] = str(
            Path(Config["Path"]["InputAlignInfo"]) / (Filename.PRECISELY_ALIGN_INFO + Params.ENDSWITH_TXT))

        try:
            input_folder = Path(Config["Path"]["InputFolder"])
            for key in Config["ChannelInput"]:
                self.raw[key] = read_csv(
                    input_folder / (Config["ChannelInput"][key] + str(self.freq[key]) + Config["EndWith"][key]),
                    encoding=Params.UTF8_ENCODING,
                    header=None).to_numpy().reshape(-1)
            self.SALabel = read_csv(
                input_folder / (Config["LabelInput"]["SA Label"] + Config["EndWith"]["SA Label"]),
                encoding=Params.GBK_ENCODING)
            self.startTime = read_csv(
                input_folder / (Config["StartTime"] + Config["EndWith"]["StartTime"]),
                encoding=Params.UTF8_ENCODING,
                header=None).to_numpy().reshape(-1)
            self.alignInfo = read_csv(Path(Config["Path"]["InputAlignInfo"]),
                                      encoding=Params.UTF8_ENCODING,
                                      header=None).to_numpy().reshape(-1)
            # The first cell holds a Python-literal dict (it must contain
            # "cut_index" -> "front_ECG"/"back_ECG"; see cut_data).
            self.alignInfo = literal_eval(self.alignInfo[0])
        except Exception:
            return Result().failure(info=Constants.INPUT_FAILURE +
                                         Constants.FAILURE_REASON["Open_Data_Exception"] + "\n" + format_exc())

        return Result().success(info=Constants.INPUT_FINISHED)

    @staticmethod
    def _duration_seconds(front_ECG, back_ECG, ECG_freq):
        """Return the cut length in whole seconds: (back - front) // ECG_freq + 1."""
        return (back_ECG - front_ECG) // ECG_freq + 1

    @staticmethod
    def _cut_bounds(front_ECG, back_ECG, ECG_freq, raw_freq):
        """Return the ``[start, end)`` slice of the ECG cut window at ``raw_freq``.

        ``front_ECG`` and ``back_ECG`` are ECG sample indices at ``ECG_freq``.
        Every channel gets the same whole-second duration (``_duration_seconds``).
        """
        duration_second = Data._duration_seconds(front_ECG, back_ECG, ECG_freq)
        start_index_cut = floor(front_ECG * (raw_freq / ECG_freq))
        end_index_cut = start_index_cut + duration_second * raw_freq
        return int(start_index_cut), int(end_index_cut)

    def cut_data(self):
        """Cut every channel to the ECG-aligned window from ``alignInfo["cut_index"]``.

        Returns:
            Result: failure if the window starts outside a channel's signal or
            on an unexpected error, success otherwise.
        """
        try:
            ECG_freq = Config["ECGFreq"]
            front_ECG = self.alignInfo["cut_index"]["front_ECG"]
            back_ECG = self.alignInfo["cut_index"]["back_ECG"]
            for key, raw in self.raw.items():
                start_index_cut, end_index_cut = self._cut_bounds(front_ECG, back_ECG, ECG_freq, self.freq[key])
                # NumPy slicing never raises on out-of-range bounds, so the
                # check must be explicit. A start outside the signal would
                # otherwise give an empty slice, or a wrapped one for a negative
                # start. A window running past the end is still truncated, as before.
                if not 0 <= start_index_cut < len(raw):
                    return Result().failure(info=Constants.CUT_PSG_CUT_DATA_FAILURE +
                                                 Constants.FAILURE_REASON["Cut_Data_Length_not_Correct"])
                # Re-assigning an existing key does not change the dict's size,
                # so it is safe while iterating.
                self.raw[key] = raw[start_index_cut:end_index_cut]
        except Exception:
            return Result().failure(info=Constants.CUT_PSG_CUT_DATA_FAILURE +
                                         Constants.FAILURE_REASON["Cut_Data_Exception"] + "\n" + format_exc())

        return Result().success(info=Constants.CUT_PSG_CUT_DATA_FINISHED)

    @staticmethod
    def _map_labels_to_window(labels, offset_seconds, window_seconds):
        """Shift label times by ``offset_seconds`` and clip them to ``[0, window_seconds)``.

        ``labels`` must have integer "Start" and "End" columns in seconds.
        Events ending before the window or starting at/after its end are
        dropped. Starts are clamped to 0 and ends to ``window_seconds - 1``.
        Returns a new DataFrame; the input is not modified.
        """
        labels = labels.copy()
        labels["Start"] = labels["Start"] - offset_seconds
        labels["End"] = labels["End"] - offset_seconds
        labels = labels[labels["End"] >= 0].copy()
        labels.loc[labels["Start"] < 0, "Start"] = 0
        labels = labels[labels["Start"] < window_seconds].copy()
        labels.loc[labels["End"] >= window_seconds, "End"] = window_seconds - 1
        return labels

    def align_label(self):
        """Filter the SA labels and convert them to seconds within the cut window.

        Returns:
            Result: failure if the label table lacks the expected columns or
            on an unexpected error, success otherwise.
        """
        try:
            # Drop pandas' auto-named "Unnamed: N" columns and keep only the
            # configured SA event types.
            labels = self.SALabel.loc[:, ~self.SALabel.columns.str.contains("^Unnamed")]
            # .copy(): the filtered frame is modified below, and without a copy
            # pandas may treat it as a view (SettingWithCopyWarning).
            labels = labels[labels["Event type"].isin(Params.CUT_PSG_SALABEL_EVENT)].copy()
            # Remove parenthesised annotations " (...)" from Duration.
            labels["Duration"] = labels["Duration"].astype(str).str.replace(r' \(.*?\)', '', regex=True)
            self.SALabel = labels
        except Exception:
            return Result().failure(info=Constants.CUT_PSG_ALIGN_LABEL_FAILURE +
                                         Constants.FAILURE_REASON["Align_Label_SALabel_Format_not_Correct"])

        try:
            # Recording start time: the token after the first space of the
            # first start-time entry (assumed "<date> HH:MM:SS"; TODO confirm).
            start_time = self.get_time_to_seconds(str(self.startTime[0]).split(" ")[1])

            labels = self.SALabel
            # Seconds since recording start; a negative difference means the
            # event is after midnight, so add 24 h.
            labels["Start"] = (labels["Time"].apply(self.get_time_to_seconds) - start_time).apply(
                lambda x: x + 24 * 3600 if x < 0 else x).astype(int)
            labels["End"] = labels["Start"] + labels["Duration"].astype(float).round(0).astype(int)

            # Map the labels onto the cut window. Both the offset and the
            # window length are in seconds, the same unit as Start/End. The
            # window length was previously taken in ECG samples, so the end
            # clipping never applied. The rate comes from Config["ECGFreq"],
            # not a hard-coded 1000.
            ECG_freq = Config["ECGFreq"]
            front_ECG = self.alignInfo["cut_index"]["front_ECG"]
            back_ECG = self.alignInfo["cut_index"]["back_ECG"]
            offset_seconds = round(front_ECG / ECG_freq)
            window_seconds = self._duration_seconds(front_ECG, back_ECG, ECG_freq)
            self.SALabel = self._map_labels_to_window(labels, offset_seconds, window_seconds)
        except Exception:
            return Result().failure(info=Constants.CUT_PSG_ALIGN_LABEL_FAILURE +
                                         Constants.FAILURE_REASON["Align_Label_Exception"] + "\n" + format_exc())

        return Result().success(info=Constants.CUT_PSG_ALIGN_LABEL_FINISHED)

    def save(self):
        """Write each cut channel and the mapped SA label table to the save folder.

        Returns:
            Result: failure if any channel is empty or writing fails, success otherwise.
        """
        if any(len(raw) == 0 for raw in self.raw.values()):
            return Result().failure(info=Constants.SAVE_FAILURE +
                                         Constants.FAILURE_REASON["Data_not_Exist"])

        try:
            save_folder = Path(Config["Path"]["SaveFolder"])
            for key, raw in self.raw.items():
                DataFrame(raw.reshape(-1)).to_csv(
                    save_folder / (Config["ChannelSave"][key] + str(self.freq[key]) + Config["EndWith"][key]),
                    index=False, header=False)

            # Sort by start time and number the rows from 1 under the name "Index".
            self.SALabel.sort_values(by=["Start"], inplace=True)
            self.SALabel.reset_index(drop=True, inplace=True)
            self.SALabel.index = self.SALabel.index + 1
            self.SALabel.index.name = "Index"
            # Same encoding constant the label table is read with in open_file.
            self.SALabel.to_csv(save_folder / (Config["LabelSave"]["SA Label"] + Config["EndWith"]["SA Label"]),
                                encoding=Params.GBK_ENCODING)
        except PermissionError:
            return Result().failure(info=Constants.SAVE_FAILURE + Constants.FAILURE_REASON["Save_Permission_Denied"])
        except FileNotFoundError:
            return Result().failure(info=Constants.SAVE_FAILURE + Constants.FAILURE_REASON["Save_File_Not_Found"])
        except Exception:
            return Result().failure(info=Constants.SAVE_FAILURE +
                                         Constants.FAILURE_REASON["Save_Exception"] + "\n" + format_exc())

        return Result().success(info=Constants.SAVE_FINISHED)

    @staticmethod
    def get_time_to_seconds(time_str):
        """Convert an ``"H:M:S"`` string (integer fields) to seconds since 00:00:00."""
        h, m, s = map(int, time_str.split(":"))
        return h * 3600 + m * 60 + s