import torch
import numpy as np
from ner.Bert_BiLSTM_CRF import Bert_BiLSTM_CRF
from common import config
import os
import logging.handlers
from concurrent_log_handler import ConcurrentRotatingFileHandler
import time

# Bind the module logger before attempting the file handler, so every
# `log.error(...)` below works even when ./logs/log.log cannot be opened.
log = logging.getLogger(__name__)
try:
    log_fmt = '%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s'
    formatter = logging.Formatter(log_fmt)
    handler = ConcurrentRotatingFileHandler('./logs/log.log', maxBytes=10000000, backupCount=10, encoding='utf-8')
    handler.setFormatter(formatter)
    logging.basicConfig(level=logging.DEBUG)
    log.addHandler(handler)
except Exception as error:
    result_dict = {"code": 500, "error_msg": "日志文件打开失败"}
    # Don't fail silently: report why file logging is unavailable.
    log.warning("log file handler could not be opened: %s", error)


class panjueshu:
    """Named-entity recogniser wrapping a ``Bert_BiLSTM_CRF`` model ("判决书").

    Entity labels are read from ``<cwd>/ner/panjueshu/标签.txt``: line ``k``
    (1-based) becomes the internal label ``"A<k>"`` with tags ``B-A<k>`` and
    ``I-A<k>``. Weights are loaded from
    ``<config.model_path>/判决书/model.pt``.
    """

    # Longest chunk fed to the model; two more ids are added around it.
    _MAX_CHUNK_LEN = 510
    # Ids wrapped around every chunk. NOTE(review): presumably the [CLS] and
    # [SEP] entries of vocab.txt — confirm against the pretrained vocabulary.
    _CLS_ID = 101
    _SEP_ID = 102
    # Id used for characters missing from vocab.txt. NOTE(review): BERT
    # vocabularies usually put [UNK] at 100; kept as-is because the model may
    # have been trained with 144 — confirm.
    _UNK_FALLBACK_ID = 144
    # Tag ids 0..3 are '', '[CLS]', '[SEP]', 'O' (see __init__): never entities.
    _LAST_NON_ENTITY_TAG = 3

    def __init__(self):
        """Read the label file, build the tag vocabulary and load the model.

        Errors are logged rather than raised, so a failed construction leaves
        a partially initialised object whose ``predict`` will report 500.
        """
        self._vocab_index = None  # lazily built {token: id} cache
        # Characters after which a long text may be split into chunks.
        self.bdjs_str_list = ["。", ";", ".", ";", ":", ":", ",", ","]
        try:
            time1 = time.time()
            self.ner_dict = {}
            self.tag_file_path = os.path.join(os.getcwd(), "ner", "panjueshu", "标签.txt")
            self.line_list = []
            with open(self.tag_file_path, "r", encoding="utf-8") as f:
                self.line_list = f.readlines()
            self.line_len = len(self.line_list)
            vocab_list = ['', '[CLS]', '[SEP]', 'O']
            for line_number, line in enumerate(self.line_list, start=1):
                label = "A" + str(line_number)
                self.ner_dict[label] = line.replace("\n", "")
                vocab_list.append('B-' + label)
                vocab_list.append('I-' + label)
            self.VOCAB = tuple(vocab_list)
            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
            self.model_name = "判决书"
            tag2idx = {tag: idx for idx, tag in enumerate(self.VOCAB)}
            time2 = time.time()
            print("ner panjueshu main __init__ time:{}s".format(time2 - time1))
            # Follow self.device instead of forcing .cuda(), so CPU-only hosts work.
            self.model = Bert_BiLSTM_CRF(tag2idx).to(self.device)
            time3 = time.time()
            print("ner panjueshu main __init__ time:{}s".format(time3 - time2))
            weights_path = os.path.join(config.model_path, self.model_name, "model.pt")
            # map_location lets a checkpoint saved on GPU load on a CPU-only host.
            self.model.load_state_dict(torch.load(weights_path, map_location=self.device))
            time4 = time.time()
            print("ner panjueshu main __init__ time:{}s".format(time4 - time3))
            self.model.eval()
            time5 = time.time()
            print("ner panjueshu main __init__ time:{}s".format(time5 - time4))
        except Exception as error:
            log.error("ner panjueshu main __init__ error:{}".format(error), exc_info=True)

    def _split_text(self, text):
        """Split *text* into chunks of at most ``_MAX_CHUNK_LEN`` characters.

        Each cut is placed right after the last ``bdjs_str_list`` character in
        the first ``_MAX_CHUNK_LEN`` characters; when that window holds none,
        the text is cut hard at the limit. An empty *text* yields ``[""]``.
        """
        limit = self._MAX_CHUNK_LEN
        chunks = []
        while len(text) > limit:
            cut = limit - 1
            while cut >= 0 and text[cut] not in self.bdjs_str_list:
                cut -= 1
            if cut < 0:
                # No punctuation in the window: hard cut, otherwise the text
                # would never shrink and this loop would spin forever.
                cut = limit - 1
            chunks.append(text[:cut + 1])
            text = text[cut + 1:]
        chunks.append(text)
        return chunks

    def _encode_chunk(self, chunk, vocab_index):
        """Return a ``(1, len(chunk) + 2)`` LongTensor of token ids for *chunk*.

        The ids are ``_CLS_ID``, one id per (lower-cased) character, then
        ``_SEP_ID``.
        """
        ids = [self._CLS_ID]
        ids.extend(vocab_index.get(ch.lower(), self._UNK_FALLBACK_ID) for ch in chunk)
        ids.append(self._SEP_ID)
        return torch.LongTensor(np.array([ids]))

    def _get_vocab_index(self):
        """Return a cached ``{token: line index}`` map built from ``load_vocab``.

        The first occurrence of a duplicated token wins, matching
        ``list.index``. A failed (empty) load is not cached, so it is retried
        on the next call.
        """
        index = getattr(self, "_vocab_index", None)
        if not index:
            index = {}
            for idx, token in enumerate(self.load_vocab()):
                index.setdefault(token, idx)
            if index:
                self._vocab_index = index
        return index

    def _prepare(self, input_str_list):
        """Split and encode every input string; raises on failure.

        Returns ``(tensors, chunk_texts)``: two parallel lists with one entry
        per chunk, in input order.
        """
        time1 = time.time()
        vocab_index = self._get_vocab_index()
        tensors = []
        chunk_texts = []
        for input_str in input_str_list:
            for chunk in self._split_text(input_str):
                chunk_texts.append(chunk)
                tensors.append(self._encode_chunk(chunk, vocab_index))
        time2 = time.time()
        print("ner panjueshu main deal_input_data time:{}s".format(time2 - time1))
        return tensors, chunk_texts

    def deal_input_data(self, input_str_list):
        """Split and encode *input_str_list* for the model.

        Returns ``(all_words_idx_tensor_list, output_str_list)``: two parallel
        lists with one entry per chunk. On error the failure is logged and
        ``([], [])`` is returned, so callers can always unpack two values.
        """
        try:
            return self._prepare(input_str_list)
        except Exception as error:
            log.error("ner panjueshu main deal_input_data error:{}".format(error), exc_info=True)
            return [], []

    def load_vocab(self):
        """Return the stripped lines of ``<config.from_pretrained_path>/vocab.txt``.

        A token's position in the list is its id. On error the failure is
        logged and ``[]`` is returned.
        """
        try:
            time1 = time.time()
            vocab_path = os.path.join(config.from_pretrained_path, "vocab.txt")
            with open(vocab_path, "r", encoding="utf-8") as reader:
                vocab_list = [line.strip() for line in reader]
            time2 = time.time()
            print("ner panjueshu main load_vocab time:{}s".format(time2 - time1))
            return vocab_list
        except Exception as error:
            log.error("ner panjueshu main load_vocab error:{}".format(error), exc_info=True)
            return []

    def _decode_tags(self, tag_ids, text):
        """Turn one row of predicted tag ids into ``{'tag', 'term'}`` dicts.

        Assumes the model emits one tag per input id (TODO confirm). Position
        0 is then the ``_CLS_ID`` slot and position ``len(text) + 1`` the
        ``_SEP_ID`` slot, so character ``i`` of *text* sits at position
        ``i + 1``. Any tag id above ``_LAST_NON_ENTITY_TAG`` starts an entity.
        The entity extends over the following positions tagged ``I-`` with
        the same label.
        """
        entities = []
        n_tags = len(tag_ids)
        last_char_pos = len(text)
        pos = 0
        while pos < n_tags:
            tag_id = tag_ids[pos]
            # Skip non-entity tags and the wrapper slots, which have no character.
            if tag_id <= self._LAST_NON_ENTITY_TAG or not 1 <= pos <= last_char_pos:
                pos += 1
                continue
            label = self.VOCAB[tag_id].split("-")[1]
            inside_tag = "I-" + label
            start = pos
            pos += 1
            # Bounds checks stop an I- run at the end of the text or tag row.
            while (pos < n_tags and pos <= last_char_pos
                   and self.VOCAB[tag_ids[pos]] == inside_tag):
                pos += 1
            entities.append({'tag': self.ner_dict[label], 'term': text[start - 1:pos - 1]})
        return entities

    def predict(self, input_str_list):
        """Run NER over *input_str_list*.

        Returns ``{'code': 200, 'data': [...], 'msg': ""}``. ``data`` holds one
        list of ``{'tag': <label>, 'term': <text>}`` dicts per chunk, not per
        input string: inputs longer than ``_MAX_CHUNK_LEN`` contribute several
        entries. On error, returns code 500 with the error text in ``'msg'``.
        """
        try:
            time1 = time.time()
            result_ner_json_list = []
            # Use _prepare directly so a preprocessing failure is reported with
            # its real cause instead of an unpacking error.
            predict_data_list, chunk_texts = self._prepare(input_str_list)
            with torch.no_grad():
                for predict_data, chunk_text in zip(predict_data_list, chunk_texts):
                    x = predict_data.to(self.device)
                    _, y_hat = self.model(x)  # y_hat: (N, T) tag ids — TODO confirm
                    ner_json_list = []
                    for tag_ids in y_hat.cpu().numpy().tolist():
                        ner_json_list.extend(self._decode_tags(tag_ids, chunk_text))
                    result_ner_json_list.append(ner_json_list)
            time2 = time.time()
            print("ner panjueshu main predict time:{}s".format(time2 - time1))
            return {'code': 200, 'data': result_ner_json_list, 'msg': ""}
        except Exception as error:
            log.error("ner panjueshu main predict error:{}".format(error), exc_info=True)
            # str(): an exception object is not JSON-serialisable.
            return {'code': 500, 'data': [], 'msg': str(error)}