from numpy import diff, argwhere, argmax, where, delete, insert, mean, array, full, nan
from numpy import min as np_min
from numpy import max as np_max
from torch import FloatTensor, no_grad, load
from torch import device as torch_device
from torch.cuda import is_available, empty_cache
from torch.nn.functional import sigmoid

from func.BCGDataset import BCG_Operation
from func.Deep_Model import Unet, Fivelayer_Lstm_Unet, Fivelayer_Unet, Sixlayer_Unet


def evaluate(test_data, model, fs, useCPU):
    """Run the segmentation model on a BCG signal and threshold its output.

    The signal is downsampled by ``int(fs // 100)`` with
    ``BCG_Operation.down_sample``, reshaped into rows of 1000 samples and fed
    to ``model`` as a ``(rows, 1, 1000)`` float tensor.

    :param test_data: BCG signal to analyse. The ``reshape(-1, 1000)`` below
        requires the downsampled length to be a multiple of 1000 (presumably
        guaranteed by ``preprocess`` -- confirm for other callers).
    :param model: torch module; its output is passed through ``sigmoid``.
    :param fs: sampling rate of ``test_data``.
    :param useCPU: if truthy, force CPU inference; otherwise use CUDA when
        ``torch.cuda.is_available()``.
    :return: ``(beat, up_index, down_index, y_prob)``. ``beat`` is the
        flattened 0/1 numpy mask ``sigmoid(output) > 0.5``; ``up_index`` and
        ``down_index`` are the ``argwhere`` results (shape ``(k, 1)``) of its
        rising and falling edges; ``y_prob`` is the sigmoid output tensor,
        still on the inference device.
    """
    operation = BCG_Operation()
    # Downsample. Per the original author, omitting .copy() made a later step
    # raise an error; copying does not change the result.
    org_bcg = operation.down_sample(test_data, down_radio=int(fs // 100)).copy()
    org_bcg = org_bcg.reshape(-1, 1000)
    org_data = FloatTensor(org_bcg).unsqueeze(1)

    if useCPU:
        device = torch_device("cpu")
    else:
        device = torch_device("cuda" if is_available() else "cpu")
    org_data = org_data.to(device)
    model = model.to(device)

    with no_grad():
        y_hat = model(org_data)
        y_prob = sigmoid(y_hat)
        beat = (y_prob > 0.5).float().view(-1).cpu().data.numpy()

    beat_diff = diff(beat)
    up_index = argwhere(beat_diff == 1)
    down_index = argwhere(beat_diff == -1)
    return beat, up_index, down_index, y_prob


def find_TPeak(data, peaks, th=50):
    """Snap each candidate peak to the true J peak (BCG) or R peak (ECG).

    For every candidate ``peak``, the position of the maximum of ``data``
    inside ``[peak - th, peak + th)`` (clipped to the signal) is returned.

    :param data: BCG or ECG signal.
    :param peaks: candidate peak positions (e.g. ``location_R`` exported from
        the labels, or the output of ``new_calculate_beat``).
    :param th: half-width of the search window, in samples.
    :return: list of refined peak positions. Candidates greater than
        ``len(data)`` are skipped. With a zero-width window (``th <= 0``)
        in-range candidates are returned unchanged.
    """
    n = len(data)
    return_peak = []
    for peak in peaks:
        if peak > n:
            continue
        min_win, max_win = max(0, int(peak - th)), min(n, int(peak + th))
        if max_win <= min_win:
            # Empty window: argmax would raise ValueError, so keep the
            # candidate itself if it is a valid index.
            if 0 <= peak < n:
                return_peak.append(peak)
            continue
        return_peak.append(argmax(data[min_win:max_win]) + min_win)
    return return_peak


def _pair_edges(up_index, down_index):
    """Trim rising/falling edge indices so they form (rise, fall) pairs.

    A falling edge before the first rising edge (mask starts high) and a
    rising edge after the last falling edge (mask ends high) have no partner
    and are removed. Either input may be empty.
    """
    if len(down_index) and (len(up_index) == 0 or down_index[0] < up_index[0]):
        down_index = delete(down_index, 0)
    if len(up_index) and (len(down_index) == 0 or up_index[-1] > down_index[-1]):
        up_index = delete(up_index, -1)
    return up_index, down_index


def _beat_edges(y, predict, th):
    """Return paired rising/falling edge indices of the beat mask of ``y``.

    The mask is ``y > th`` when ``predict`` is truthy, otherwise ``y`` itself
    (expected to be a 0/1 label). Indices are ``diff`` positions, i.e. the
    sample just before the transition.
    """
    beat = where(y > th, 1, 0) if predict else y
    beat_diff = diff(beat)
    up_index = argwhere(beat_diff == 1).reshape(-1)
    down_index = argwhere(beat_diff == -1).reshape(-1)
    return _pair_edges(up_index, down_index)


def _fill_long_gaps(y, up_index, down_index, th1, margin=15, low_th=0.1):
    """Re-detect beats inside gaps longer than ``th1`` using ``low_th``.

    Each gap between one pulse's falling edge and the next pulse's rising edge
    is trimmed by ``margin`` samples on both sides. If the ``y > low_th`` mask
    of that window contains activity but does not touch either end of the
    window, the pulses found by ``add_beat`` are inserted between the two
    beats.

    :return: updated ``(up_index, down_index)``.
    """
    i = 0
    while i < len(up_index) - 1:
        if abs(up_index[i + 1] - up_index[i]) <= th1:
            i += 1
            continue
        # Strictly both window ends would only need +1. However, when the
        # U-Net output is below the main threshold, the lowered threshold
        # also turns one or two samples next to the existing beats into 1,
        # which would corrupt the decision. Hence the margin.
        start = down_index[i] + margin
        re_prob = y[start:up_index[i + 1] - margin]
        beat1 = where(re_prob > low_th, 1, 0)
        if not beat1.any() or beat1[0] == 1 or beat1[-1] == 1:
            i += 1
            continue
        insert_up_index, insert_down_index = add_beat(re_prob, th=low_th)
        for offset, (u, d) in enumerate(zip(insert_up_index, insert_down_index)):
            up_index = insert(up_index, i + 1 + offset, u + start + 1)
            down_index = insert(down_index, i + 1 + offset, d + start + 1)
        # Skip past the newly inserted pulses.
        i += len(insert_up_index) + 1
    return up_index, down_index


def _apply_refractory(y, up_index, down_index, th2):
    """Merge pulses whose rising edges are closer than ``th2`` samples.

    For each too-close pair, four votes compare the earlier ("forward") pulse
    with the later ("backward") one:
      * longer pulse;
      * higher maximum of ``y``;
      * lower minimum of ``y`` (scores for the forward pulse);
      * higher mean of ``y``.
    A 2-2 tie keeps the pulse with the higher mean. Ties within a single vote
    go to the backward pulse. The losing pulse is deleted and the survivor is
    compared with the next pulse.

    :return: updated ``(up_index, down_index)``.
    """
    i = 0
    while i < len(up_index) - 1:
        if abs(up_index[i + 1] - up_index[i]) >= th2:
            i += 1
            continue
        prob_forward = y[up_index[i] + 1:down_index[i] + 1]
        prob_backward = y[up_index[i + 1] + 1:down_index[i + 1] + 1]
        forward_average = mean(prob_forward)
        back_average = mean(prob_backward)
        votes = (
            down_index[i] - up_index[i] > down_index[i + 1] - up_index[i + 1],
            np_max(prob_forward) > np_max(prob_backward),
            np_min(prob_forward) < np_min(prob_backward),
            forward_average > back_average,
        )
        # Use bool() per vote: adding numpy bools would act as a logical OR.
        forward_score = sum(1 for v in votes if bool(v))
        if forward_score >= 3:
            drop = i + 1
        elif forward_score <= 1:
            drop = i
        else:
            drop = i + 1 if forward_average > back_average else i
        up_index = delete(up_index, drop)
        down_index = delete(down_index, drop)
        # i is not advanced: the surviving pulse is compared with the next one.
    return up_index, down_index


def new_calculate_beat(y, predict, th=0.5, up=10, th1=100, th2=45):
    """Recover J-peak positions from the model output, with gap re-detection
    and a refractory period that removes falsely detected peaks.

    If no beat is found for more than ``th1`` samples, that gap is re-examined
    at a lower threshold (0.1). Per the original author, lowering to about 0.3
    usually suffices. Lowering the threshold in body-movement segments may
    cause false detections, but such segments are discarded later and have
    long intervals anyway.

    :param y: predicted probabilities (``y_prob``) or a 0/1 label array.
    :param predict: truthy to threshold ``y`` at ``th``; falsy to use ``y``
        directly as the beat mask.
    :param th: threshold applied to ``y`` when ``predict`` is truthy.
    :param up: downsampling factor; positions are multiplied by it to map
        back to the original sampling rate.
    :param th1: gap between consecutive rising edges (in samples of ``y``)
        above which the gap is re-searched.
    :param th2: minimum spacing between rising edges (in samples of ``y``);
        closer pulses are merged.
    :return: array of J-peak positions ``(rise + fall) // 2 * up``, or ``[0]``
        when no complete pulse exists.
    """
    up_index, down_index = _beat_edges(y, predict, th)
    if len(up_index) == 0:
        return [0]
    up_index, down_index = _fill_long_gaps(y, up_index, down_index, th1)
    up_index, down_index = _apply_refractory(y, up_index, down_index, th2)
    predict_J = (up_index + down_index) // 2 * up
    return predict_J


def add_beat(y, th=0.2):
    """Find paired rising/falling edges of ``y > th``.

    :param y: probability segment to re-examine.
    :param th: detection threshold.
    :return: ``(add_up_index, add_down_index)``: 1-D integer arrays of
        ``diff`` positions of rising and falling edges, trimmed so they form
        (rise, fall) pairs. Both arrays are empty when there is no complete
        pulse.
    """
    beat1 = where(y > th, 1, 0)
    beat_diff1 = diff(beat1)
    add_up_index = argwhere(beat_diff1 == 1).reshape(-1)
    add_down_index = argwhere(beat_diff1 == -1).reshape(-1)
    return _pair_edges(add_up_index, add_down_index)


def calculate_beat(y, predict, th=0.5, up=10):
    """Recover J-peak positions from the model output (no post-processing).

    Returns the midpoint of each paired (rise, fall) edge of the beat mask,
    scaled back by the downsampling factor.

    :param y: predicted probabilities (``y_prob``) or a 0/1 label array.
    :param predict: truthy to threshold ``y`` at ``th``; falsy to use ``y``
        directly as the beat mask.
    :param th: threshold applied to ``y`` when ``predict`` is truthy.
    :param up: downsampling factor to multiply positions by.
    :return: array of J-peak positions, or ``[0]`` when no complete pulse
        exists.
    """
    up_index, down_index = _beat_edges(y, predict, th)
    if len(up_index) == 0:
        return [0]
    predict_J = (up_index + down_index) // 2 * up
    return predict_J


def preprocess(raw_bcg, fs, low_cut, high_cut, amp_value):
    """Trim a raw BCG recording to whole ``fs * 10``-sample blocks,
    band-pass filter it and scale it.

    :param raw_bcg: raw BCG samples.
    :param fs: sampling rate (integer; also used for the block length).
    :param low_cut: lower band-pass corner, passed to
        ``BCG_Operation.Butterworth`` (order 3).
    :param high_cut: upper band-pass corner, passed to the same filter.
    :param amp_value: gain applied to the filtered signal.
    :return: filtered signal multiplied by ``amp_value``.
    """
    # Whole blocks only; after evaluate() downsamples by fs // 100 this
    # presumably yields rows of exactly 1000 samples.
    block_len = fs * 10
    bcg_data = raw_bcg[:len(raw_bcg) // block_len * block_len]
    preprocessing = BCG_Operation(sample_rate=fs)
    filtered = preprocessing.Butterworth(bcg_data, "bandpass", low_cut=low_cut,
                                         high_cut=high_cut, order=3)
    return filtered * amp_value


def Jpeak_Detection(model_name, model_path, bcg_data, fs, interval_high, interval_low, peaks_value, useCPU):
    """Detect J peaks in a BCG signal with a trained U-Net variant.

    :param model_name: model identifier; the text before its last ``_`` must
        be one of ``Fivelayer_Unet``, ``Fivelayer_Lstm_Unet``,
        ``Sixlayer_Unet`` or ``U_net``.
    :param model_path: path of the saved ``state_dict`` (mapped to CPU on load).
    :param bcg_data: BCG signal (presumably the output of ``preprocess``).
    :param fs: sampling rate.
    :param interval_high: ``th1`` for ``new_calculate_beat`` (gap re-search).
    :param interval_low: ``th2`` for ``new_calculate_beat`` (refractory spacing).
    :param peaks_value: half-width of the peak-refinement window, converted to
        samples as ``int(peaks_value * fs / 1000)`` (presumably milliseconds).
    :param useCPU: force CPU inference when truthy.
    :return: ``(predict_J, Interval)``. ``predict_J`` is an array of J-peak
        sample indices. ``Interval`` has the length of ``bcg_data``; samples
        from each peak up to (not including) the next hold the distance
        between the two peaks, and all others are NaN.
    :raises ValueError: if the model name is not recognised.
    """
    model_classes = {
        "Fivelayer_Unet": Fivelayer_Unet,
        "Fivelayer_Lstm_Unet": Fivelayer_Lstm_Unet,
        "Sixlayer_Unet": Sixlayer_Unet,
        "U_net": Unet,
    }
    base_name = get_model_name(str(model_name))
    if base_name not in model_classes:
        raise ValueError(f"Unknown model name: {base_name!r}")
    model = model_classes[base_name]()

    # NOTE(review): torch.load unpickles the file -- only load trusted models.
    model.load_state_dict(load(model_path, map_location=torch_device('cpu')))
    model.eval()

    # J-peak prediction.
    _, _, _, y_prob = evaluate(bcg_data, model=model, fs=fs, useCPU=useCPU)
    y_prob = y_prob.cpu().reshape(-1).data.numpy()
    predict_J = new_calculate_beat(y_prob, 1, th=0.6, up=fs // 100,
                                   th1=interval_high, th2=interval_low)
    predict_J = find_TPeak(bcg_data, predict_J, th=int(peaks_value * fs / 1000))
    predict_J = array(predict_J)

    Interval = full(len(bcg_data), nan)
    for start, stop in zip(predict_J[:-1], predict_J[1:]):
        Interval[start:stop] = stop - start
    empty_cache()
    return predict_J, Interval


def get_model_name(input_string):
    """Return ``input_string`` up to (not including) its last ``_``.

    The whole string is returned when it contains no ``_``.
    """
    head, sep, _ = input_string.rpartition('_')
    return head if sep else input_string