import logging
import logging.handlers

import torch
from concurrent_log_handler import ConcurrentRotatingFileHandler
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset
from transformers import BertConfig, BertTokenizer

from common.config import config_class
from ner_1.BertSpanForNer import BertSpanForNer
from ner_1.InputExample import InputExample, InputFeature

# Set up rotating file logging; the logger is created before the try block so that
# module-level `log` calls still work even if the log file cannot be opened.
log = logging.getLogger(__name__)
try:
    log_fmt = '%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s'
    formatter = logging.Formatter(log_fmt)
    handler = ConcurrentRotatingFileHandler('./logs/log.log', maxBytes=10000000, backupCount=10, encoding='utf-8')
    handler.setFormatter(formatter)
    logging.basicConfig(level=logging.DEBUG)
    log.addHandler(handler)
except Exception as error:
    result_dict = {"code": 500, "error_msg": "failed to open log file"}


class panjueshu():

    def __init__(self):
        try:
            # Load the NER tag configuration for the "判决书" (judgment document) model.
            model_name = "判决书"
            config_con = config_class(model_name)
            self.label_json = config_con.reload_ner_tag_json()
            self.ner_tag_list = config_con.reload_ner_tag()
            self.id2label = {i: label for i, label in enumerate(self.ner_tag_list)}
            # Fall back to CPU when no GPU is available.
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            self.tokenizer = BertTokenizer.from_pretrained(config_con.model_path, do_lower_case=True)
            config = BertConfig.from_pretrained(config_con.model_path, num_labels=len(self.ner_tag_list))
            self.model = BertSpanForNer.from_pretrained(config_con.model_path, config=config, ignore_mismatched_sizes=True)
            self.model.to(self.device)
        except Exception as error:
            log.error("ner panjueshu main __init__ error:{}".format(error), exc_info=True)

    def load_and_cache_examples(self, tokenizer, f_lines):
        """Convert raw text into a TensorDataset plus the underlying examples."""
        try:
            examples = self.create_examples(self.read_text(f_lines))
            features = self.convert_examples_to_features(examples=examples,
                                                         tokenizer=tokenizer,
                                                         max_seq_length=512,
                                                         cls_token_at_end=False,
                                                         pad_on_left=False,
                                                         cls_token=tokenizer.cls_token,
                                                         cls_token_segment_id=0,
                                                         sep_token=tokenizer.sep_token,
                                                         pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
                                                         pad_token_segment_id=0)
            all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
            all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
            all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
            all_input_lens = torch.tensor([f.input_len for f in features], dtype=torch.long)
            dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_input_lens)
            return dataset, examples
        except Exception as error:
            log.error("ner panjueshu load_and_cache_examples error:{}".format(error), exc_info=True)
            return None, None

    def create_examples(self, lines):
        """Wrap each {"words", "labels"} dict in an InputExample; the "O" labels are
        placeholders at predict time and are not passed on."""
        try:
            examples = []
            for line in lines:
                text_a = line['words']
                examples.append(InputExample(text_a=text_a))
            return examples
        except Exception as error:
            log.error("ner panjueshu create_examples error:{}".format(error), exc_info=True)
            return None

    def read_text(self, f_lines):
        """Split the input character stream into chunks of at most 510 characters, breaking on newlines."""
        try:
            lines = []
            words = []
            labels = []
            count = 0
            for char in f_lines:
                # Skip whitespace characters other than newline.
                if char != '\n' and char.strip() == "":
                    continue
                if char == '\n':
                    lines.append({"words": words, "labels": labels})
                    count = 0
                    words = []
                    labels = []
                    continue
                # Start a new chunk once the 510-character limit is reached
                # (leaving room for [CLS] and [SEP]).
                if count >= 510:
                    lines.append({"words": words, "labels": labels})
                    count = 0
                    words = []
                    labels = []
                words.append(char.strip())
                labels.append("O")
                count = count + 1
            if words:
                lines.append({"words": words, "labels": labels})
            return lines
        except Exception as error:
            log.error("ner panjueshu read_text error:{}".format(error), exc_info=True)
            return None

    def convert_examples_to_features(self, examples, max_seq_length, tokenizer,
                                     cls_token_at_end=False, cls_token="[CLS]", cls_token_segment_id=1,
                                     sep_token="[SEP]", pad_on_left=False, pad_token=0, pad_token_segment_id=0,
                                     sequence_a_segment_id=0, mask_padding_with_zero=True):
        """Tokenize each example and pad/truncate it to max_seq_length, producing InputFeature objects."""
        try:
            features = []
            for (ex_index, example) in enumerate(examples):
                textlist = example.text_a
                if isinstance(textlist, list):
                    textlist = " ".join(textlist)
                tokens = tokenizer.tokenize(textlist)
                # Reserve room for the [CLS] and [SEP] special tokens.
                special_tokens_count = 2
                if len(tokens) > max_seq_length - special_tokens_count:
                    tokens = tokens[: (max_seq_length - special_tokens_count)]
                tokens += [sep_token]
                segment_ids = [sequence_a_segment_id] * len(tokens)
                if cls_token_at_end:
                    tokens += [cls_token]
                    segment_ids += [cls_token_segment_id]
                else:
                    tokens = [cls_token] + tokens
                    segment_ids = [cls_token_segment_id] + segment_ids

                input_ids = tokenizer.convert_tokens_to_ids(tokens)
                # The attention mask marks real tokens with 1 and padding with 0 (or the inverse).
                input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
                input_len = len(input_ids)
                padding_length = max_seq_length - len(input_ids)
                if pad_on_left:
                    input_ids = ([pad_token] * padding_length) + input_ids
                    input_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + input_mask
                    segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids
                else:
                    input_ids += [pad_token] * padding_length
                    input_mask += [0 if mask_padding_with_zero else 1] * padding_length
                    segment_ids += [pad_token_segment_id] * padding_length

                assert len(input_ids) == max_seq_length
                assert len(input_mask) == max_seq_length
                assert len(segment_ids) == max_seq_length

                features.append(InputFeature(input_ids=input_ids,
                                             input_mask=input_mask,
                                             segment_ids=segment_ids,
                                             input_len=input_len))
            return features
        except Exception as error:
            log.error("ner panjueshu convert_examples_to_features error:{}".format(error), exc_info=True)
            return None

    def bert_extract_item(self, start_logits, end_logits):
        """Pair start/end span predictions into (label_id, start, end) tuples, skipping the [CLS]/[SEP] positions."""
        try:
            S = []
            start_pred = torch.argmax(start_logits, -1).cpu().numpy()[0][1:-1]
            end_pred = torch.argmax(end_logits, -1).cpu().numpy()[0][1:-1]
            for i, s_l in enumerate(start_pred):
                if s_l == 0:
                    continue
                # Match each predicted start with the first end position carrying the same label.
                for j, e_l in enumerate(end_pred[i:]):
                    if s_l == e_l:
                        S.append((s_l, i, i + j))
                        break
            return S
        except Exception as error:
            log.error("ner panjueshu bert_extract_item error:{}".format(error), exc_info=True)
            return None

    def collate_fn(self, batch):
        """Trim every tensor in the batch to the length of its longest sequence."""
        try:
            all_input_ids, all_input_mask, all_segment_ids, all_lens = map(torch.stack, zip(*batch))
            max_len = max(all_lens).item()
            all_input_ids = all_input_ids[:, :max_len]
            all_input_mask = all_input_mask[:, :max_len]
            all_segment_ids = all_segment_ids[:, :max_len]
            return all_input_ids, all_input_mask, all_segment_ids, all_lens
        except Exception as error:
            log.error("ner panjueshu collate_fn error:{}".format(error), exc_info=True)
            return None, None, None, None

    def predict(self, text_list):
        """Run span-based NER over each input text and return a list of {"tag", "term"} dicts per text."""
        try:
            ner_json_list_list = []
            for text in text_list:
                ner_json_list = []
                test_dataset, examples = self.load_and_cache_examples(self.tokenizer, text)
                # Skip this text if feature construction failed.
                if test_dataset is None:
                    ner_json_list_list.append(ner_json_list)
                    continue
                test_sampler = SequentialSampler(test_dataset)
                test_dataloader = DataLoader(test_dataset, sampler=test_sampler, batch_size=1, collate_fn=self.collate_fn)
                for step, batch in enumerate(test_dataloader):
                    self.model.eval()
                    batch = tuple(t.to(self.device) for t in batch)
                    with torch.no_grad():
                        inputs = {"input_ids": batch[0], "attention_mask": batch[1], "token_type_ids": batch[2]}
                        outputs = self.model(**inputs)
                    start_logits, end_logits = outputs[:2]
                    R = self.bert_extract_item(start_logits, end_logits)
                    if R:
                        label_entities = [[self.id2label[x[0]], x[1], x[2]] for x in R]
                    else:
                        label_entities = []
                    for entity in label_entities:
                        start = entity[1]
                        end = entity[2]
                        # Recover the entity string from the original characters; a separate name
                        # avoids shadowing the outer loop variable `text`.
                        term = "".join(examples[step].text_a[start:end + 1])
                        ner_json_list.append({"tag": self.label_json[entity[0]], "term": term})
                ner_json_list_list.append(ner_json_list)
            return ner_json_list_list
        except Exception as error:
            log.error("ner panjueshu predict error:{}".format(error), exc_info=True)
            return None
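

# Minimal usage sketch (not part of the original module): it assumes the common.config and
# ner_1 packages are importable and that a trained checkpoint exists at the path returned by
# config_class("判决书").model_path. The sample text below is purely illustrative.
if __name__ == "__main__":
    ner = panjueshu()
    # Each element of the input list is one document's raw text; newlines and the
    # 510-character limit in read_text control how it is chunked before inference.
    sample_texts = ["某某人民法院刑事判决书\n被告人张某，男，1980年出生。"]
    results = ner.predict(sample_texts)
    print(results)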