import logging
import torch
from concurrent_log_handler import ConcurrentRotatingFileHandler
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset
from transformers import BertConfig, BertTokenizer

from common.config import config_class
from ner_1.BertSpanForNer import BertSpanForNer
from ner_1.InputExample import InputExample, InputFeature
try:
    # Rotating log file shared safely across worker processes.
    log_fmt = '%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s'
    formatter = logging.Formatter(log_fmt)
    handler = ConcurrentRotatingFileHandler('./logs/log.log', maxBytes=10000000, backupCount=10, encoding='utf-8')
    handler.setFormatter(formatter)
    logging.basicConfig(level=logging.DEBUG)
    log = logging.getLogger(__name__)
    log.addHandler(handler)
except Exception as error:
    result_dict = {"code": 500, "error_msg": "failed to open the log file"}
class qisushu():
    """Span-based BERT NER predictor for indictment ("起诉书") documents."""

    def __init__(self):
        try:
            model_name = "起诉书"  # config key for the indictment model
            config_con = config_class(model_name)
            self.label_json = config_con.reload_ner_tag_json()
            self.ner_tag_list = config_con.reload_ner_tag()
            self.id2label = {i: label for i, label in enumerate(self.ner_tag_list)}
            self.device = torch.device("cuda")
            self.tokenizer = BertTokenizer.from_pretrained(config_con.model_path, do_lower_case=True)
            config = BertConfig.from_pretrained(config_con.model_path, num_labels=len(self.ner_tag_list))
            self.model = BertSpanForNer.from_pretrained(config_con.model_path, config=config, ignore_mismatched_sizes=True)
            self.model.to(self.device)
        except Exception as error:
            log.error("ner qisushu main __init__ error:{}".format(error), exc_info=True)
    def load_and_cache_examples(self, tokenizer, f_lines):
        """Turn raw text into a TensorDataset of padded BERT inputs plus the source examples."""
        try:
            examples = self.create_examples(self.read_text(f_lines))
            features = self.convert_examples_to_features(examples=examples,
                                                         tokenizer=tokenizer,
                                                         max_seq_length=512,
                                                         cls_token_at_end=False,
                                                         pad_on_left=False,
                                                         cls_token=tokenizer.cls_token,
                                                         cls_token_segment_id=0,
                                                         sep_token=tokenizer.sep_token,
                                                         pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
                                                         pad_token_segment_id=0)
            all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
            all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
            all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
            all_input_lens = torch.tensor([f.input_len for f in features], dtype=torch.long)
            dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_input_lens)
            return dataset, examples
        except Exception as error:
            log.error("ner qisushu load_and_cache_examples error:{}".format(error), exc_info=True)
            return None, None
    def create_examples(self, lines):
        """Wrap each pre-split chunk into an InputExample (labels are not needed at inference time)."""
        try:
            examples = []
            for (i, line) in enumerate(lines):
                text_a = line['words']
                examples.append(InputExample(text_a=text_a))
            return examples
        except Exception as error:
            log.error("ner qisushu create_examples error:{}".format(error), exc_info=True)
            return None
    def read_text(self, f_lines):
        """Split the input character stream into chunks of at most 510 characters.

        A new chunk starts at every newline or once 510 characters have been
        collected, leaving room for [CLS]/[SEP] within BERT's 512-token limit.
        """
        try:
            lines = []
            words = []
            labels = []
            count = 0
            for char in f_lines:
                if char != '\n' and char.strip() == "":
                    continue  # skip whitespace other than newlines
                if char == '\n':
                    lines.append({"words": words, "labels": labels})
                    count = 0
                    words = []
                    labels = []
                    continue
                if count >= 510:
                    lines.append({"words": words, "labels": labels})
                    count = 0
                    words = []
                    labels = []
                words.append(char.strip())
                labels.append("O")
                count = count + 1
            if words:
                lines.append({"words": words, "labels": labels})
            return lines
        except Exception as error:
            log.error("ner qisushu read_text error:{}".format(error), exc_info=True)
            return None
    def convert_examples_to_features(self, examples, max_seq_length, tokenizer,
                                     cls_token_at_end=False, cls_token="[CLS]", cls_token_segment_id=1,
                                     sep_token="[SEP]", pad_on_left=False, pad_token=0, pad_token_segment_id=0,
                                     sequence_a_segment_id=0, mask_padding_with_zero=True):
        """Tokenize, truncate, add special tokens and pad every example to max_seq_length."""
        try:
            features = []
            for (ex_index, example) in enumerate(examples):
                textlist = example.text_a
                if isinstance(textlist, list):
                    textlist = " ".join(textlist)
                tokens = tokenizer.tokenize(textlist)
                # Reserve room for [CLS] and [SEP].
                special_tokens_count = 2
                if len(tokens) > max_seq_length - special_tokens_count:
                    tokens = tokens[: (max_seq_length - special_tokens_count)]
                tokens += [sep_token]
                segment_ids = [sequence_a_segment_id] * len(tokens)
                if cls_token_at_end:
                    tokens += [cls_token]
                    segment_ids += [cls_token_segment_id]
                else:
                    tokens = [cls_token] + tokens
                    segment_ids = [cls_token_segment_id] + segment_ids
                input_ids = tokenizer.convert_tokens_to_ids(tokens)
                input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
                input_len = len(input_ids)
                # Pad (or left-pad) up to max_seq_length.
                padding_length = max_seq_length - len(input_ids)
                if pad_on_left:
                    input_ids = ([pad_token] * padding_length) + input_ids
                    input_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + input_mask
                    segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids
                else:
                    input_ids += [pad_token] * padding_length
                    input_mask += [0 if mask_padding_with_zero else 1] * padding_length
                    segment_ids += [pad_token_segment_id] * padding_length
                assert len(input_ids) == max_seq_length
                assert len(input_mask) == max_seq_length
                assert len(segment_ids) == max_seq_length
                features.append(InputFeature(input_ids=input_ids,
                                             input_mask=input_mask,
                                             segment_ids=segment_ids,
                                             input_len=input_len))
            return features
        except Exception as error:
            log.error("ner qisushu convert_examples_to_features error:{}".format(error), exc_info=True)
            return None
    def bert_extract_item(self, start_logits, end_logits):
        """Decode (label_id, start, end) spans from start/end position logits.

        For every non-O start position, the span ends at the first later
        position predicted with the same label.
        """
        try:
            S = []
            # Drop the [CLS]/[SEP] positions before pairing starts with ends.
            start_pred = torch.argmax(start_logits, -1).cpu().numpy()[0][1:-1]
            end_pred = torch.argmax(end_logits, -1).cpu().numpy()[0][1:-1]
            for i, s_l in enumerate(start_pred):
                if s_l == 0:
                    continue
                for j, e_l in enumerate(end_pred[i:]):
                    if s_l == e_l:
                        S.append((s_l, i, i + j))
                        break
            return S
        except Exception as error:
            log.error("ner qisushu bert_extract_item error:{}".format(error), exc_info=True)
            return None
    def collate_fn(self, batch):
        """Stack a batch and trim padding to the longest real sequence in it."""
        try:
            all_input_ids, all_input_mask, all_segment_ids, all_lens = map(torch.stack, zip(*batch))
            max_len = max(all_lens).item()
            all_input_ids = all_input_ids[:, :max_len]
            all_input_mask = all_input_mask[:, :max_len]
            all_segment_ids = all_segment_ids[:, :max_len]
            return all_input_ids, all_input_mask, all_segment_ids, all_lens
        except Exception as error:
            log.error("ner qisushu collate_fn error:{}".format(error), exc_info=True)
            return None, None, None, None
    def predict(self, text_list):
        """Run span NER over each text and return a list of {"tag", "term"} dicts per input."""
        try:
            ner_json_list_list = []
            for text in text_list:
                ner_json_list = []
                test_dataset, examples = self.load_and_cache_examples(self.tokenizer, text)
                test_sampler = SequentialSampler(test_dataset)
                test_dataloader = DataLoader(test_dataset, sampler=test_sampler, batch_size=1, collate_fn=self.collate_fn)
                for step, batch in enumerate(test_dataloader):
                    self.model.eval()
                    batch = tuple(t.to(self.device) for t in batch)
                    with torch.no_grad():
                        inputs = {"input_ids": batch[0], "attention_mask": batch[1], "token_type_ids": batch[2]}
                        outputs = self.model(**inputs)
                    start_logits, end_logits = outputs[:2]
                    R = self.bert_extract_item(start_logits, end_logits)
                    if R:
                        label_entities = [[self.id2label[x[0]], x[1], x[2]] for x in R]
                    else:
                        label_entities = []
                    for entity in label_entities:
                        start = entity[1]
                        end = entity[2]
                        # Use a separate name so the outer loop variable `text` is not clobbered.
                        term = "".join(examples[step].text_a[start:end + 1])
                        ner_json_list.append({"tag": self.label_json[entity[0]], "term": term})
                ner_json_list_list.append(ner_json_list)
            return ner_json_list_list
        except Exception as error:
            log.error("ner qisushu predict error:{}".format(error), exc_info=True)
            return None
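

# Minimal usage sketch (an illustration, not part of the original module): it
# assumes the "起诉书" config, a CUDA device and the ner_1 package are available,
# and that predict() receives a list of raw document strings. The sample
# sentence below is a hypothetical indictment fragment.
if __name__ == "__main__":
    ner = qisushu()
    sample = ["某某人民检察院指控被告人某某犯盗窃罪。\n"]
    results = ner.predict(sample)
    # Expected shape: one list per input text, each item like {"tag": ..., "term": ...}
    print(results)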