import torch
from common.config import config_class
import os
import logging.handlers
from concurrent_log_handler import ConcurrentRotatingFileHandler
from transformers import BertTokenizer, BertConfig, BertModel, BertPreTrainedModel
from torch.utils.data import SequentialSampler, TensorDataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
from ner_1.BertSpanForNer import BertSpanForNer
from ner_1.InputExample import InputExample, InputFeature

try:
    log_fmt = '%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s'
    formatter = logging.Formatter(log_fmt)
    # Make sure the log directory exists before opening the rotating file handler.
    os.makedirs('./logs', exist_ok=True)
    handler = ConcurrentRotatingFileHandler('./logs/log.log', maxBytes=10000000, backupCount=10, encoding='utf-8')
    handler.setFormatter(formatter)
    logging.basicConfig(level=logging.DEBUG)
    log = logging.getLogger(__name__)
    log.addHandler(handler)
except Exception as error:
    # "日志文件打开失败" = "failed to open the log file"
    result_dict = {"code": 500, "error_msg": "日志文件打开失败"}


class panjueshu():
    def __init__(self):
        try:
            model_name = "判决书"  # config entry for judgment documents (court verdicts)
            config_con = config_class(model_name)
            self.label_json = config_con.reload_ner_tag_json()
            self.ner_tag_list = config_con.reload_ner_tag()
            self.id2label = {i: label for i, label in enumerate(self.ner_tag_list)}
            # Fall back to CPU when no GPU is available.
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            self.tokenizer = BertTokenizer.from_pretrained(config_con.model_path, do_lower_case=True, ignore_mismatched_sizes=True)
            config = BertConfig.from_pretrained(config_con.model_path, num_labels=len(self.ner_tag_list), ignore_mismatched_sizes=True)
            self.model = BertSpanForNer.from_pretrained(config_con.model_path, config=config, ignore_mismatched_sizes=True)
            self.model.to(self.device)
        except Exception as error:
            log.error("ner panjueshu main __init__ error:{}".format(error), exc_info=True)

    def load_and_cache_examples(self, tokenizer, f_lines):
        try:
            examples = self.create_examples(self.read_text(f_lines))
            features = self.convert_examples_to_features(examples=examples,
                                                         tokenizer=tokenizer,
                                                         max_seq_length=512,
                                                         cls_token_at_end=False,
                                                         pad_on_left=False,
                                                         cls_token=tokenizer.cls_token,
                                                         cls_token_segment_id=0,
                                                         sep_token=tokenizer.sep_token,
                                                         pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
                                                         pad_token_segment_id=0)
            all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
            all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
            all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
            all_input_lens = torch.tensor([f.input_len for f in features], dtype=torch.long)
            dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_input_lens)
            return dataset, examples
        except Exception as error:
            log.error("ner panjueshu load_and_cache_examples error:{}".format(error), exc_info=True)
            return None, None

    def create_examples(self, lines):
        try:
            examples = []
            for (i, line) in enumerate(lines):
                text_a = line['words']
                labels = []
                for x in line['labels']:
                    labels.append(x)
                # Labels are collected for parity with the training data format but
                # are not passed to InputExample on this prediction-only path.
                examples.append(InputExample(text_a=text_a))
            return examples
        except Exception as error:
            log.error("ner panjueshu create_examples error:{}".format(error), exc_info=True)
            return None

    def read_text(self, f_lines):
        try:
            lines = []
            words = []
            labels = []
            count = 0
            for char in f_lines:
                # Skip whitespace characters other than newlines.
                if char != '\n' and char.strip() == "":
                    continue
                # A newline closes the current sentence.
                if char == '\n':
                    lines.append({"words": words, "labels": labels})
                    count = 0
                    words = []
                    labels = []
                    continue
                # Split overly long sentences so that, together with [CLS]/[SEP],
                # they stay within the 512-token maximum sequence length.
                if count >= 510:
                    lines.append({"words": words, "labels": labels})
                    count = 0
                    words = []
                    labels = []
                words.append(char.strip())
                labels.append("O")
                count = count + 1
            if words:
                lines.append({"words": words, "labels": labels})
            return lines
        except Exception as error:
            log.error("ner panjueshu read_text error:{}".format(error), exc_info=True)
            return None

    def convert_examples_to_features(self, examples, max_seq_length, tokenizer,
                                     cls_token_at_end=False, cls_token="[CLS]", cls_token_segment_id=1,
                                     sep_token="[SEP]", pad_on_left=False, pad_token=0, pad_token_segment_id=0,
                                     sequence_a_segment_id=0, mask_padding_with_zero=True):
        try:
            features = []
            for (ex_index, example) in enumerate(examples):
                textlist = example.text_a
                if isinstance(textlist, list):
                    textlist = " ".join(textlist)
                tokens = tokenizer.tokenize(textlist)
                special_tokens_count = 2
                if len(tokens) > max_seq_length - special_tokens_count:
                    tokens = tokens[: (max_seq_length - special_tokens_count)]
                tokens += [sep_token]
                segment_ids = [sequence_a_segment_id] * len(tokens)
                if cls_token_at_end:
                    tokens += [cls_token]
                    segment_ids += [cls_token_segment_id]
                else:
                    tokens = [cls_token] + tokens
                    segment_ids = [cls_token_segment_id] + segment_ids

                input_ids = tokenizer.convert_tokens_to_ids(tokens)
                input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
                input_len = len(input_ids)
                padding_length = max_seq_length - len(input_ids)
                if pad_on_left:
                    input_ids = ([pad_token] * padding_length) + input_ids
                    input_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + input_mask
                    segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids
                else:
                    input_ids += [pad_token] * padding_length
                    input_mask += [0 if mask_padding_with_zero else 1] * padding_length
                    segment_ids += [pad_token_segment_id] * padding_length

                assert len(input_ids) == max_seq_length
                assert len(input_mask) == max_seq_length
                assert len(segment_ids) == max_seq_length

                features.append(InputFeature(input_ids=input_ids,
                                             input_mask=input_mask,
                                             segment_ids=segment_ids,
                                             input_len=input_len))
            return features
        except Exception as error:
            log.error("ner panjueshu convert_examples_to_features error:{}".format(error), exc_info=True)
            return None

    def bert_extract_item(self, start_logits, end_logits):
        try:
            S = []
            # Drop the [CLS]/[SEP] positions and take the argmax label at each token.
            start_pred = torch.argmax(start_logits, -1).cpu().numpy()[0][1:-1]
            end_pred = torch.argmax(end_logits, -1).cpu().numpy()[0][1:-1]
            for i, s_l in enumerate(start_pred):
                if s_l == 0:
                    continue
                # Pair each predicted start with the nearest later end of the same label.
                for j, e_l in enumerate(end_pred[i:]):
                    if s_l == e_l:
                        S.append((s_l, i, i + j))
                        break
            return S
        except Exception as error:
            log.error("ner panjueshu bert_extract_item error:{}".format(error), exc_info=True)
            return None
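
    # Illustrative sketch (not part of the original file): given
    # start_pred = [0, 2, 0] and end_pred = [0, 0, 2], bert_extract_item pairs the
    # start at position 1 (label 2) with the first later end carrying the same
    # label (position 2), yielding S = [(2, 1, 2)], i.e. one entity of label id 2
    # spanning token positions 1..2 (offsets counted after [CLS] is stripped).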

    def collate_fn(self, batch):
        try:
            # Trim every tensor in the batch down to the longest real sequence length.
            all_input_ids, all_input_mask, all_segment_ids, all_lens = map(torch.stack, zip(*batch))
            max_len = max(all_lens).item()
            all_input_ids = all_input_ids[:, :max_len]
            all_input_mask = all_input_mask[:, :max_len]
            all_segment_ids = all_segment_ids[:, :max_len]
            return all_input_ids, all_input_mask, all_segment_ids, all_lens
        except Exception as error:
            log.error("ner panjueshu collate_fn error:{}".format(error), exc_info=True)
            return None, None, None, None

    def predict(self, text_list):
        try:
            ner_json_list_list = []
            for text in text_list:
                ner_json_list = []
                test_dataset, examples = self.load_and_cache_examples(self.tokenizer, text)
                test_sampler = SequentialSampler(test_dataset)
                test_dataloader = DataLoader(test_dataset, sampler=test_sampler, batch_size=1, collate_fn=self.collate_fn)
                for step, batch in enumerate(test_dataloader):
                    self.model.eval()
                    batch = tuple(t.to(self.device) for t in batch)
                    with torch.no_grad():
                        inputs = {"input_ids": batch[0], "attention_mask": batch[1], "token_type_ids": batch[2]}
                        outputs = self.model(**inputs)
                    start_logits, end_logits = outputs[:2]
                    R = self.bert_extract_item(start_logits, end_logits)
                    if R:
                        label_entities = [[self.id2label[x[0]], x[1], x[2]] for x in R]
                    else:
                        label_entities = []
                    for entity in label_entities:
                        start = entity[1]
                        end = entity[2]
                        # Use a distinct name so the outer loop variable `text` is not shadowed.
                        entity_text = "".join(examples[step].text_a[start:end + 1])
                        ner_json_list.append({"tag": self.label_json[entity[0]], "term": entity_text})
                ner_json_list_list.append(ner_json_list)
            return ner_json_list_list
        except Exception as error:
            log.error("ner panjueshu predict error:{}".format(error), exc_info=True)
            return None
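

if __name__ == "__main__":
    # Minimal usage sketch (an assumption, not part of the original module): it
    # presumes the "判决书" entry in common.config and the model path it points
    # to exist locally; the input string below is a hypothetical placeholder.
    ner = panjueshu()
    results = ner.predict(["原告张某与被告李某买卖合同纠纷一案。\n"])
    print(results)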