Initial Commit.

This commit is contained in:
Yorusora
2025-04-28 11:33:05 +08:00
parent f0eb8083b1
commit f928fa4d9c
28 changed files with 7126 additions and 0 deletions

295
func/utils/detect_Jpeak.py Normal file
View File

@ -0,0 +1,295 @@
from numpy import diff, argwhere, argmax, where, delete, insert, mean, array, full, nan
from numpy import min as np_min
from numpy import max as np_max
from torch import FloatTensor, no_grad, load
from torch import device as torch_device
from torch.cuda import is_available, empty_cache
from torch.nn.functional import sigmoid
from func.BCGDataset import BCG_Operation
from func.Deep_Model import Unet,Fivelayer_Lstm_Unet,Fivelayer_Unet,Sixlayer_Unet
def evaluate(test_data, model,fs,useCPU):
orgBCG = test_data
operation = BCG_Operation()
# 降采样
orgBCG = operation.down_sample(orgBCG, down_radio=int(fs//100)).copy() #一开始没加.copy()会报错,后来加了就没事了,结果没影响
# plt.figure()
# plt.plot(orgBCG)
# plt.show()
orgBCG = orgBCG.reshape(-1, 1000)
# test dataset
orgData = FloatTensor(orgBCG).unsqueeze(1)
# predict
if useCPU == True:
gpu = False
device = torch_device("cpu")
else:
gpu = is_available()
device = torch_device("cuda" if is_available() else "cpu")
# if gpu:
# orgData = orgData.cuda()
# model.cuda()
orgData = orgData.to(device)
model = model.to(device)
with no_grad():
y_hat = model(orgData)
y_prob = sigmoid(y_hat)
beat = (y_prob>0.5).float().view(-1).cpu().data.numpy()
beat_diff = diff(beat)
up_index = argwhere(beat_diff==1)
down_index = argwhere(beat_diff==-1)
return beat,up_index,down_index,y_prob
def find_TPeak(data,peaks,th=50):
"""
找出真实的J峰或R峰
:param data: BCG或ECG数据
:param peaks: 初步峰值从label中导出的location_R
:param th: 范围阈值
:return: 真实峰值
"""
return_peak = []
for peak in peaks:
if peak>len(data):continue
min_win,max_win = max(0,int(peak-th)),min(len(data),int(peak+th))
return_peak.append(argmax(data[min_win:max_win])+min_win)
return return_peak
def new_calculate_beat(y,predict,th=0.5,up=10,th1=100,th2=45): #通过预测计算回原来J峰的坐标 输入y_prob,predict=ture,up*10,降采样多少就乘多少
"""
加上不应期算法,消除误判的峰
:param y: 预测输出值或者标签值label
:param predict: ture or false
:param up: 降采样为多少就多少
:return: 预测的J峰位置
"""
if predict:
beat = where(y>th,1,0)
else:
beat = y
beat_diff = diff(beat) #一阶差分
up_index = argwhere(beat_diff == 1).reshape(-1)
down_index = argwhere(beat_diff == -1).reshape(-1)
# print(up_index,down_index)
# print(y)
# print(y[up_index[4]+1:down_index[4]+1])
if len(up_index)==0:
return [0]
if up_index[0] > down_index[0]:
down_index = delete(down_index, 0)
if up_index[-1] > down_index[-1]:
up_index = delete(up_index, -1)
"""
加上若大于130点都没有一个心跳时降低阈值重新判决一次一般降到0.3就可以了;; 但是对于体动片段降低阈值可能又会造成误判,而且出现体动的话会被丢弃,间隔时间也长
"""
# print("初始:",up_index.shape,down_index.shape)
i = 0
lenth1 = len(up_index)
while i < len(up_index)-1:
if abs(up_index[i+1]-up_index[i]) > th1:
re_prob = y[down_index[i]+15:up_index[i+1]-15] #原本按正常应该是两个都+1的但是由于Unet输出低于0.6时把阈值调小后会在附近一两个点也变为1会影响判断
# print(re_prob.shape)
beat1 = where(re_prob > 0.1, 1, 0)
# print(beat1)
if sum(beat1) != 0 and beat1[0] != 1 and beat1[-1] != 1:
insert_up_index,insert_down_index = add_beat(re_prob,th=0.1)
# print(insert_up_index,insert_down_index,i)
if len(insert_up_index) > 1:
l = i+1
for u,d in zip(insert_up_index,insert_down_index):
up_index = insert(up_index,l,u+down_index[i]+1+15) #np.insert(arr, obj, values, axis) arr原始数组可一可多obj插入元素位置values是插入内容axis是按行按列插入。
down_index = insert(down_index,l,d+down_index[i]+1+15)
l = l+1
# print('l=', l)
elif len(insert_up_index) == 1:
# print(i)
up_index = insert(up_index,i+1,down_index[i]+insert_up_index+1+15)
down_index = insert(down_index,i+1,down_index[i]+insert_down_index+1+15)
i = i + len(insert_up_index) + 1
else:
i = i+1
continue
else:
i = i+1
# print("最终:",up_index.shape,down_index.shape)
"""
添加不应期
"""
new_up_index = up_index
new_down_index = down_index
flag = 0
i = 0
lenth = len(up_index)
while i < lenth:
if abs(up_index[i+1]-up_index[i]) < th2:
prob_forward = y[up_index[i]+1:down_index[i]+1]
prob_backward = y[up_index[i+1]+1:down_index[i+1]+1]
forward_score = 0
back_score = 0
forward_count = down_index[i] - up_index[i]
back_count = down_index[i+1] - up_index[i+1]
forward_max = np_max(prob_forward)
back_max = np_max(prob_backward)
forward_min = np_min(prob_forward)
back_min = np_min(prob_backward)
forward_average = mean(prob_forward)
back_average = mean(prob_backward)
if forward_count > back_count:
forward_score = forward_score + 1
else:back_score = back_score + 1
if forward_max > back_max:
forward_score = forward_score + 1
else:back_score = back_score + 1
if forward_min < back_min:
forward_score = forward_score + 1
else:back_score = back_score + 1
if forward_average > back_average:
forward_score = forward_score + 1
else:back_score = back_score + 1
if forward_score >=3:
up_index = delete(up_index, i+1)
down_index = delete(down_index, i+1)
flag = 1
elif back_score >=3:
up_index = delete(up_index, i)
down_index = delete(down_index, i)
flag = 1
elif forward_score == back_score:
if forward_average > back_average:
up_index = delete(up_index, i + 1)
down_index = delete(down_index, i + 1)
flag = 1
else:
up_index = delete(up_index, i)
down_index = delete(down_index, i)
flag = 1
if flag == 1:
i = i
flag = 0
else: i = i+1
else:i = i + 1
if i > len(up_index)-2:
break
# elif abs(up_index[i+1]-up_index[i]) > 120:
# print("全部处理之后",up_index.shape,down_index.shape)
predict_J = (up_index.reshape(-1) + down_index.reshape(-1)) // 2*up
# predict_J = predict_J.astype(int)
return predict_J
def add_beat(y,th=0.2): #通过预测计算回原来J峰的坐标 输入y_prob,predict=ture,up*10,降采样多少就乘多少
"""
:param y: 预测输出值或者标签值label
:param predict: ture or false
:param up: 降采样为多少就多少
:return: 预测的J峰位置
"""
beat1 = where(y>th,1,0)
beat_diff1 = diff(beat1) #一阶差分
add_up_index = argwhere(beat_diff1 == 1).reshape(-1)
add_down_index = argwhere(beat_diff1 == -1).reshape(-1)
# print(beat1)
# print(add_up_index,add_down_index)
if len(add_up_index) > 0:
if add_up_index[0] > add_down_index[0]:
add_down_index = delete(add_down_index, 0)
if add_up_index[-1] > add_down_index[-1]:
add_up_index = delete(add_up_index, -1)
return add_up_index, add_down_index
else:
return 0
def calculate_beat(y,predict,th=0.5,up=10): #通过预测计算回原来J峰的坐标 输入y_prob,predict=ture,up*10,降采样多少就乘多少
"""
:param y: 预测输出值或者标签值label
:param predict: ture or false
:param up: 降采样为多少就多少
:return: 预测的J峰位置
"""
if predict:
beat = where(y>th,1,0)
else:
beat = y
beat_diff = diff(beat) #一阶差分
up_index = argwhere(beat_diff == 1).reshape(-1)
down_index = argwhere(beat_diff == -1).reshape(-1)
if len(up_index)==0:
return [0]
if up_index[0] > down_index[0]:
down_index = delete(down_index, 0)
if up_index[-1] > down_index[-1]:
up_index = delete(up_index, -1)
predict_J = (up_index.reshape(-1) + down_index.reshape(-1)) // 2*up
# predict_J = predict_J.astype(int)
return predict_J
def preprocess(raw_bcg, fs, low_cut, high_cut, amp_value):
bcg_data = raw_bcg[:len(raw_bcg) // (fs * 10) * fs * 10]
preprocessing = BCG_Operation(sample_rate=fs)
bcg = preprocessing.Butterworth(bcg_data, "bandpass", low_cut=low_cut, high_cut=high_cut, order=3) * amp_value
return bcg
def Jpeak_Detection(model_name, model_path, bcg_data, fs, interval_high, interval_low, peaks_value, useCPU):
model_name = get_model_name(str(model_name))
if model_name == "Fivelayer_Unet":
model = Fivelayer_Unet()
elif model_name == "Fivelayer_Lstm_Unet":
model = Fivelayer_Lstm_Unet()
elif model_name == "Sixlayer_Unet":
model = Sixlayer_Unet()
elif model_name == "U_net":
model = Unet()
else:
raise Exception
model.load_state_dict(load(model_path, map_location=torch_device('cpu')))
model.eval()
# J峰预测
beat, up_index, down_index, y_prob = evaluate(bcg_data, model=model, fs=fs, useCPU=useCPU)
y_prob = y_prob.cpu().reshape(-1).data.numpy()
predict_J = new_calculate_beat(y_prob, 1, th=0.6, up=fs // 100, th1=interval_high, th2=interval_low)
predict_J = find_TPeak(bcg_data, predict_J, th=int(peaks_value * fs / 1000))
predict_J = array(predict_J)
Interval = full(len(bcg_data), nan)
for i in range(len(predict_J) - 1):
Interval[predict_J[i]: predict_J[i + 1]] = predict_J[i + 1] - predict_J[i]
empty_cache()
return predict_J, Interval
def get_model_name(input_string):
# 找到最后一个 "_" 的位置
last_underscore_index = input_string.rfind('_')
# 如果没有找到 "_"
if last_underscore_index == -1:
return input_string # 返回整个字符串
# 返回最后一个 "_" 之前的部分
return input_string[:last_underscore_index]