You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
169 lines
8.4 KiB
169 lines
8.4 KiB
1 year ago
|
import torch
|
||
|
import numpy as np
|
||
|
from ner.Bert_BiLSTM_CRF import Bert_BiLSTM_CRF
|
||
|
from common import config
|
||
|
import os
|
||
|
import logging.handlers
|
||
|
from concurrent_log_handler import ConcurrentRotatingFileHandler
|
||
|
import time
|
||
|
|
||
|
# Bind the module logger before anything that can fail: the error handlers
# elsewhere in this module call ``log.error(...)`` and would raise NameError
# if ``log`` only existed when the file handler could be created.
logging.basicConfig(level=logging.DEBUG)
log = logging.getLogger(__name__)

try:
    log_fmt = '%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s'
    formatter = logging.Formatter(log_fmt)
    # The rotating handler opens ./logs/log.log, so the directory must exist.
    os.makedirs('./logs', exist_ok=True)
    handler = ConcurrentRotatingFileHandler(
        './logs/log.log', maxBytes=10000000, backupCount=10, encoding='utf-8'
    )
    handler.setFormatter(formatter)
    log.addHandler(handler)
except Exception as error:
    # Best effort: keep the module importable without file logging, but make
    # the failure visible instead of silently discarding it.
    result_dict = {"code": 500, "error_msg": "日志文件打开失败"}
    log.warning("file logging disabled: %s", error, exc_info=True)
|
||
|
|
||
|
class panjueshu:
    """Named-entity recogniser backed by a ``Bert_BiLSTM_CRF`` model.

    Tag names are generated from the lines of ``标签.txt``: line *n* (1-based)
    becomes entity code ``A<n>`` with the tags ``B-A<n>`` and ``I-A<n>``.
    """

    # Longest chunk, in characters, that is encoded as one model input.
    # Two extra ids are added around every chunk, giving at most 512 ids
    # (presumably the BERT sequence limit -- confirm against the model).
    _MAX_CHUNK_CHARS = 510
    # Ids wrapped around every chunk (presumably the BERT [CLS]/[SEP] ids).
    _START_ID = 101
    _END_ID = 102
    # Id used for characters missing from the vocabulary.
    # NOTE(review): BERT's [UNK] is conventionally 100; 144 is kept because
    # it is what the original code fed the model -- confirm.
    _FALLBACK_ID = 144
    # Tag indices 0..3 are '<PAD>', '[CLS]', '[SEP]' and 'O' (non-entities).
    _LAST_NON_ENTITY_TAG = 3

    def __init__(self):
        """Load the tag list and the trained model.

        Failures are logged and not re-raised, so after an error the
        instance may be only partially initialised.
        """
        try:
            time1 = time.time()
            VOCAB_list = ['<PAD>', '[CLS]', '[SEP]', 'O']
            self.ner_dict = {}
            # Delimiters at which over-long inputs are split. Set before the
            # model is loaded so it exists even if loading fails.
            # NOTE(review): the duplicate entries look like full-width
            # punctuation that was normalised to ASCII -- confirm.
            self.bdjs_str_list = ["。", ";", ".", ";", ":", ":", ",", ","]
            self.tag_file_path = os.path.join(os.getcwd(), "ner", "panjueshu", "标签.txt")
            self.line_list = []
            with open(self.tag_file_path, "r", encoding="utf-8") as f:
                self.line_list = f.readlines()
            self.line_len = len(self.line_list)
            for line_number, line in enumerate(self.line_list, start=1):
                code = "A" + str(line_number)
                self.ner_dict[code] = line.replace("\n", "")
                VOCAB_list.append('B-' + code)
                VOCAB_list.append('I-' + code)
            self.VOCAB = tuple(VOCAB_list)
            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
            self.model_name = "判决书"
            tag2idx = {tag: idx for idx, tag in enumerate(self.VOCAB)}
            time2 = time.time()
            print("ner panjueshu main __init__ time:{}s".format(time2 - time1))
            # Use the detected device instead of an unconditional .cuda(),
            # which raises on hosts without a GPU.
            self.model = Bert_BiLSTM_CRF(tag2idx).to(self.device)
            time3 = time.time()
            print("ner panjueshu main __init__ time:{}s".format(time3 - time2))
            # map_location lets weights saved on a GPU load on a CPU host.
            state_dict = torch.load(
                os.path.join(config.model_path, self.model_name, "model.pt"),
                map_location=self.device,
            )
            self.model.load_state_dict(state_dict)
            time4 = time.time()
            print("ner panjueshu main __init__ time:{}s".format(time4 - time3))
            self.model.eval()
            time5 = time.time()
            print("ner panjueshu main __init__ time:{}s".format(time5 - time4))
        except Exception as error:
            log.error("ner panjueshu main __init__ error:{}".format(error), exc_info=True)

    def _split_text(self, text):
        """Split *text* into chunks of at most ``_MAX_CHUNK_CHARS`` characters.

        Each cut is made just after the last delimiter from
        ``self.bdjs_str_list`` found within the limit. If there is none, the
        chunk is cut at the limit itself; without that fallback the loop
        would never consume any text and would not terminate.
        """
        limit = self._MAX_CHUNK_CHARS
        chunks = []
        while len(text) > limit:
            cut = limit - 1
            while cut >= 0 and text[cut] not in self.bdjs_str_list:
                cut -= 1
            if cut < 0:
                cut = limit - 1
            chunks.append(text[:cut + 1])
            text = text[cut + 1:]
        chunks.append(text)
        return chunks

    def _encode_inputs(self, input_str_list):
        """Turn texts into id tensors; raises on failure.

        Returns ``(tensors, chunks)``: one ``LongTensor`` of shape
        ``(1, len(chunk) + 2)`` per chunk, and the chunk strings in the same
        order. Inputs longer than the limit contribute several chunks.
        """
        vocab_list = self.load_vocab()
        if not vocab_list:
            # load_vocab() returns [] on failure; encoding with it would map
            # every character to the fallback id and yield meaningless output.
            raise ValueError("vocabulary is empty; cannot encode input")
        # Dict lookup instead of list.index() per character. setdefault keeps
        # the first index of a duplicated token, exactly like list.index().
        token_to_id = {}
        for idx, token in enumerate(vocab_list):
            token_to_id.setdefault(token, idx)

        tensors = []
        chunks = []
        for input_str in input_str_list:
            for chunk in self._split_text(input_str):
                ids = [self._START_ID]
                ids.extend(token_to_id.get(ch.lower(), self._FALLBACK_ID) for ch in chunk)
                ids.append(self._END_ID)
                tensors.append(torch.LongTensor([ids]))
                chunks.append(chunk)
        return tensors, chunks

    def deal_input_data(self, input_str_list):
        """Encode *input_str_list* for the model.

        Returns ``(tensors, chunks)`` as described in ``_encode_inputs``.
        On any error the exception is logged and ``[]`` is returned.
        """
        try:
            time1 = time.time()
            all_words_idx_tensor_list, output_str_list = self._encode_inputs(input_str_list)
            time2 = time.time()
            print("ner panjueshu main deal_input_data time:{}s".format(time2 - time1))
            return all_words_idx_tensor_list, output_str_list
        except Exception as error:
            log.error("ner panjueshu main deal_input_data error:{}".format(error), exc_info=True)
            return []

    def load_vocab(self):
        """Read ``vocab.txt`` from ``config.from_pretrained_path``.

        Returns the stripped lines in file order (list index == token id),
        or ``[]`` after logging if the file cannot be read.
        """
        try:
            time1 = time.time()
            vocab_path = os.path.join(config.from_pretrained_path, "vocab.txt")
            with open(vocab_path, "r", encoding="utf-8") as reader:
                # Blank lines are kept (as "") so later ids do not shift.
                vocab_list = [line.strip() for line in reader]
            time2 = time.time()
            print("ner panjueshu main load_vocab time:{}s".format(time2 - time1))
            return vocab_list
        except Exception as error:
            log.error("ner panjueshu main load_vocab error:{}".format(error), exc_info=True)
            return []

    def _decode_entities(self, tag_ids, text):
        """Collect entities from one row of predicted tag indices.

        Position ``p`` of *tag_ids* corresponds to ``text[p - 1]`` because of
        the leading start id, so only positions ``1..len(text)`` are decoded.
        Tags predicted on the start/end positions are ignored rather than
        indexing outside *text*. An entity starts at any tag index above
        ``_LAST_NON_ENTITY_TAG`` and continues over following ``I-`` tags
        with the same entity code.
        """
        entities = []
        end = min(len(tag_ids), len(text) + 1)
        pos = 1
        while pos < end:
            tag_id = tag_ids[pos]
            if tag_id <= self._LAST_NON_ENTITY_TAG:
                pos += 1
                continue
            code = self.VOCAB[tag_id].split("-")[1]
            inside_tag = "I-" + code
            start = pos
            pos += 1
            while pos < end and self.VOCAB[tag_ids[pos]] == inside_tag:
                pos += 1
            entities.append({'tag': self.ner_dict[code], 'term': text[start - 1:pos - 1]})
        return entities

    def predict(self, input_str_list):
        """Run entity recognition on *input_str_list*.

        Returns ``{'code': 200, 'data': [...], 'msg': ""}`` where ``data``
        holds one list of ``{'tag', 'term'}`` dicts per encoded chunk (an
        input longer than the chunk limit yields several entries). On error
        returns ``{'code': 500, 'data': [], 'msg': <error text>}``.
        """
        try:
            time1 = time.time()
            result_ner_json_list = []
            # Call the raising encoder directly so a preprocessing failure is
            # reported with its real cause instead of an unpacking error.
            predict_data_list, input_str_list = self._encode_inputs(input_str_list)
            with torch.no_grad():
                for predict_data, input_str in zip(predict_data_list, input_str_list):
                    _, y_hat = self.model(predict_data.to(self.device))  # y_hat: (N, T)
                    ner_json_list = []
                    for tag_ids in y_hat.cpu().numpy().tolist():
                        ner_json_list.extend(self._decode_entities(tag_ids, input_str))
                    result_ner_json_list.append(ner_json_list)
            time2 = time.time()
            print("ner panjueshu main predict time:{}s".format(time2 - time1))
            return {'code': 200, 'data': result_ner_json_list, 'msg': ""}
        except Exception as error:
            log.error("ner panjueshu main predict error:{}".format(error), exc_info=True)
            # Return text, not the exception object, so the response can be
            # serialised (e.g. to JSON).
            return {'code': 500, 'data': [], 'msg': str(error)}