You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
169 lines
8.4 KiB
169 lines
8.4 KiB
import torch |
|
import numpy as np |
|
from ner.Bert_BiLSTM_CRF import Bert_BiLSTM_CRF |
|
from common import config |
|
import os |
|
import logging.handlers |
|
from concurrent_log_handler import ConcurrentRotatingFileHandler |
|
import time |
|
|
|
# Module logger. It is created before any file handler is attached so that `log` exists even
# when the rotating log file cannot be opened; otherwise every later log.error() call in this
# module would raise NameError instead of logging.
log = logging.getLogger(__name__)

logging.basicConfig(level=logging.DEBUG)

try:
    log_fmt = '%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s'
    formatter = logging.Formatter(log_fmt)
    # Make sure the target directory exists before the handler tries to open the file in it.
    os.makedirs('./logs', exist_ok=True)
    handler = ConcurrentRotatingFileHandler('./logs/log.log', maxBytes=10000000, backupCount=10, encoding='utf-8')
    handler.setFormatter(formatter)
    log.addHandler(handler)
except Exception as error:
    # Kept for backward compatibility: other code may inspect this module-level dict.
    result_dict = {"code": 500, "error_msg": "日志文件打开失败"}
    # Report the failure through the console handler configured by basicConfig above.
    log.error("failed to open log file ./logs/log.log: {}".format(error))
|
|
|
class panjueshu():
    """Named-entity recogniser for court judgments ("判决书") backed by a Bert_BiLSTM_CRF model.

    Entity types are read from ``<cwd>/ner/panjueshu/标签.txt``. Line *n* (1-based) becomes the
    code ``A<n>`` with labels ``B-A<n>`` / ``I-A<n>``, and ``self.ner_dict`` maps each code back
    to the text of its line. Texts longer than ``MAX_CHUNK_LEN`` characters are split into
    chunks, and each chunk is tagged separately.
    """

    # Longest chunk fed to the model, in characters (one id per character). With the two ids
    # added around every chunk, this yields at most 512 ids per model input.
    # NOTE(review): presumably chosen to fit BERT's 512-position limit -- confirm with the model.
    MAX_CHUNK_LEN = 510

    def __init__(self):
        """Load the tag list, build the label vocabulary and load the trained model weights.

        Any exception is logged and swallowed. The instance may then be only partially
        initialised, and later calls report errors instead of results.
        """
        try:
            time1 = time.time()
            # Punctuation marks after which a long text may be cut (see _split_text). This is set
            # first so that preprocessing does not depend on the model loading successfully.
            self.bdjs_str_list = ["。", ";", ".", ";", ":", ":", ",", ","]
            VOCAB_list = ['<PAD>', '[CLS]', '[SEP]', 'O']
            self.ner_dict = {}
            self.tag_file_path = os.path.join(os.getcwd(), "ner", "panjueshu", "标签.txt")
            self.line_list = []
            with open(self.tag_file_path, "r", encoding="utf-8") as f:
                self.line_list = f.readlines()
            self.line_len = len(self.line_list)
            for line_number, line in enumerate(self.line_list, start=1):
                tag_code = "A" + str(line_number)
                self.ner_dict[tag_code] = line.replace("\n", "")
                VOCAB_list.append('B-' + tag_code)
                VOCAB_list.append('I-' + tag_code)
            self.VOCAB = tuple(VOCAB_list)
            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
            self.model_name = "判决书"
            tag2idx = {tag: idx for idx, tag in enumerate(self.VOCAB)}
            time2 = time.time()
            print("ner panjueshu main __init__ time:{}s".format(time2 - time1))
            # Move the model to the detected device. An unconditional .cuda() raises on
            # CPU-only hosts.
            self.model = Bert_BiLSTM_CRF(tag2idx).to(self.device)
            time3 = time.time()
            print("ner panjueshu main __init__ time:{}s".format(time3 - time2))
            # map_location lets weights saved from a GPU load on a CPU-only host.
            self.model.load_state_dict(torch.load(os.path.join(config.model_path, self.model_name, "model.pt"),
                                                  map_location=self.device))
            time4 = time.time()
            print("ner panjueshu main __init__ time:{}s".format(time4 - time3))
            self.model.eval()
            time5 = time.time()
            print("ner panjueshu main __init__ time:{}s".format(time5 - time4))
        except Exception as error:
            log.error("ner panjueshu main __init__ error:{}".format(error), exc_info=True)

    def deal_input_data(self, input_str_list):
        """Split and encode texts for the model.

        :param input_str_list: iterable of input texts.
        :return: ``(tensor_list, chunk_list)``. Each chunk of each text contributes one
            ``(1, len(chunk) + 2)`` LongTensor and the chunk string itself, in the same order.
            On error the cause is logged and ``([], [])`` is returned, so the result can
            always be unpacked.
        """
        try:
            return self._build_model_inputs(input_str_list)
        except Exception as error:
            log.error("ner panjueshu main deal_input_data error:{}".format(error), exc_info=True)
            return [], []

    def _build_model_inputs(self, input_str_list):
        """Do the work of deal_input_data, letting exceptions propagate to the caller."""
        time1 = time.time()
        vocab_index = self._get_vocab_index()
        all_words_idx_tensor_list = []
        output_str_list = []
        for input_str in input_str_list:
            for chunk in self._split_text(input_str):
                output_str_list.append(chunk)
                all_words_idx_tensor_list.append(self._encode_chunk(chunk, vocab_index))
        time2 = time.time()
        print("ner panjueshu main deal_input_data time:{}s".format(time2 - time1))
        return all_words_idx_tensor_list, output_str_list

    def _split_text(self, input_str):
        """Split *input_str* into chunks of at most MAX_CHUNK_LEN characters.

        Each cut is made just after the last mark from self.bdjs_str_list within the first
        MAX_CHUNK_LEN characters. If there is no such mark, the text is cut hard at
        MAX_CHUNK_LEN, so the remaining text always shrinks and the loop terminates. The
        last (possibly empty) remainder is always returned as the final chunk.
        """
        max_len = self.MAX_CHUNK_LEN
        chunks = []
        while len(input_str) > max_len:
            cut = max_len - 1
            while cut >= 0 and input_str[cut] not in self.bdjs_str_list:
                cut -= 1
            if cut < 0:
                # No punctuation in the window: fall back to a fixed-size cut.
                cut = max_len - 1
            chunks.append(input_str[:cut + 1])
            input_str = input_str[cut + 1:]
        chunks.append(input_str)
        return chunks

    def _encode_chunk(self, chunk, vocab_index):
        """Encode *chunk* character by character as a ``(1, len(chunk) + 2)`` LongTensor.

        Id 101 is written before the text and id 102 after it. Each character is lower-cased
        and looked up in *vocab_index*; characters not found get id 144.
        """
        # NOTE(review): 101/102 are presumably [CLS]/[SEP] in the pretrained vocab.txt, and 144
        # the id used for unknown characters at training time -- confirm against vocab.txt.
        ids = [101]
        ids.extend(vocab_index.get(ch.lower(), 144) for ch in chunk)
        ids.append(102)
        return torch.tensor([ids], dtype=torch.long)

    def _get_vocab_index(self):
        """Return a cached ``{token: id}`` map, where id is the token's first line index in vocab.txt.

        Replaces the per-character ``list.index`` scan with an O(1) lookup. The vocabulary
        file is read once instead of on every call.
        """
        vocab_index = getattr(self, "_vocab_index", None)
        if vocab_index is None:
            vocab_index = {}
            for idx, token in enumerate(self.load_vocab()):
                # setdefault keeps the first occurrence, matching list.index() semantics.
                vocab_index.setdefault(token, idx)
            # Cache only a successfully loaded vocabulary, so a failed read is retried next call.
            if vocab_index:
                self._vocab_index = vocab_index
        return vocab_index

    def load_vocab(self):
        """Read ``<config.from_pretrained_path>/vocab.txt`` as a list of stripped lines.

        The list index of a token is its id. Returns ``[]`` (after logging) if the file
        cannot be read.
        """
        try:
            time1 = time.time()
            with open(os.path.join(config.from_pretrained_path, "vocab.txt"), "r", encoding="utf-8") as reader:
                vocab_list = [line.strip() for line in reader]
            time2 = time.time()
            print("ner panjueshu main load_vocab time:{}s".format(time2 - time1))
            return vocab_list
        except Exception as error:
            log.error("ner panjueshu main load_vocab error:{}".format(error), exc_info=True)
            return []

    def _decode_entities(self, y_hat, input_str):
        """Convert one sequence of tag ids into ``[{'tag': <tag name>, 'term': <text>}, ...]``.

        Position *i* of *y_hat* corresponds to ``input_str[i - 1]``, because position 0 holds
        the id written before the text. Ids <= 3 (<PAD>, [CLS], [SEP], O) are skipped. Any
        other tag starts a span that absorbs the following ``I-`` tags of the same type.
        Positions with no corresponding character are never read, so a span that reaches the
        end of the sequence terminates cleanly.
        """
        entities = []
        text_len = len(input_str)
        seq_len = len(y_hat)
        pos = 0
        while pos < seq_len:
            y = y_hat[pos]
            if y <= 3 or not 1 <= pos <= text_len:
                pos += 1
                continue
            ner_eng_name = self.VOCAB[y].split("-")[1]
            ner_ch_name = self.ner_dict[ner_eng_name]
            ner_words = input_str[pos - 1]
            pos += 1
            inside_tag = "I-" + ner_eng_name
            while pos < seq_len and pos <= text_len and self.VOCAB[y_hat[pos]] == inside_tag:
                ner_words += input_str[pos - 1]
                pos += 1
            entities.append({'tag': ner_ch_name, 'term': ner_words})
        return entities

    def predict(self, input_str_list):
        """Tag every chunk of every input text.

        :param input_str_list: iterable of input texts.
        :return: ``{'code': 200, 'data': [...], 'msg': ""}``. ``data`` holds one list of
            ``{'tag', 'term'}`` dicts per chunk produced by splitting (not per input text).
            On error it returns ``{'code': 500, 'data': [], 'msg': <error text>}``.
        """
        try:
            time1 = time.time()
            result_ner_json_list = []
            # Call the raising variant so preprocessing failures surface as a 500 below.
            predict_data_list, chunk_list = self._build_model_inputs(input_str_list)
            with torch.no_grad():
                for x, input_str in zip(predict_data_list, chunk_list):
                    # NOTE(review): y_hat presumably holds (batch, seq_len) tag ids -- confirm.
                    _, y_hat = self.model(x.to(self.device))
                    ner_json_list = []
                    for seq in y_hat.cpu().numpy().tolist():
                        ner_json_list.extend(self._decode_entities(seq, input_str))
                    result_ner_json_list.append(ner_json_list)
            time2 = time.time()
            print("ner panjueshu main predict time:{}s".format(time2 - time1))
            return {'code': 200, 'data': result_ner_json_list, 'msg': ""}
        except Exception as error:
            log.error("ner panjueshu main predict error:{}".format(error), exc_info=True)
            # str() keeps the response JSON-serialisable; an exception object is not.
            return {'code': 500, 'data': [], 'msg': str(error)}