import logging

import torch
from concurrent_log_handler import ConcurrentRotatingFileHandler
from torch.utils.data import SequentialSampler, TensorDataset, DataLoader
from transformers import BertTokenizer, BertConfig

from common.config import config_class
from ner_1.BertSpanForNer import BertSpanForNer
from ner_1.InputExample import InputExample, InputFeature

# Module-level rotating file logger; if the log file cannot be opened, record the
# failure in result_dict so callers can surface it.
try:
    log_fmt = '%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s'
    formatter = logging.Formatter(log_fmt)
    handler = ConcurrentRotatingFileHandler('./logs/log.log', maxBytes=10000000,
                                            backupCount=10, encoding='utf-8')
    handler.setFormatter(formatter)
    logging.basicConfig(level=logging.DEBUG)
    log = logging.getLogger(__name__)
    log.addHandler(handler)
except Exception as error:
    result_dict = {"code": 500, "error_msg": "日志文件打开失败"}  # failed to open the log file


class qisushu:
    """Span-based BERT NER predictor for indictment (起诉书) documents."""

    def __init__(self):
        try:
            model_name = "起诉书"
            config_con = config_class(model_name)
            self.label_json = config_con.reload_ner_tag_json()
            self.ner_tag_list = config_con.reload_ner_tag()
            self.id2label = {i: label for i, label in enumerate(self.ner_tag_list)}
            # Fall back to CPU when no GPU is available instead of crashing.
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            self.tokenizer = BertTokenizer.from_pretrained(config_con.model_path, do_lower_case=True)
            config = BertConfig.from_pretrained(config_con.model_path,
                                                num_labels=len(self.ner_tag_list),
                                                ignore_mismatched_sizes=True)
            self.model = BertSpanForNer.from_pretrained(config_con.model_path, config=config,
                                                        ignore_mismatched_sizes=True)
            self.model.to(self.device)
        except Exception as error:
            log.error("ner qisushu main __init__ error:{}".format(error), exc_info=True)

    def load_and_cache_examples(self, tokenizer, f_lines):
        """Convert raw text into a TensorDataset of padded BERT inputs plus the source examples."""
        try:
            examples = self.create_examples(self.read_text(f_lines))
            features = self.convert_examples_to_features(examples=examples,
                                                         tokenizer=tokenizer,
                                                         max_seq_length=512,
                                                         cls_token_at_end=False,
                                                         pad_on_left=False,
                                                         cls_token=tokenizer.cls_token,
                                                         cls_token_segment_id=0,
                                                         sep_token=tokenizer.sep_token,
                                                         pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
                                                         pad_token_segment_id=0)
            all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
            all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
            all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
            all_input_lens = torch.tensor([f.input_len for f in features], dtype=torch.long)
            dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_input_lens)
            return dataset, examples
        except Exception as error:
            log.error("ner qisushu load_and_cache_examples error:{}".format(error), exc_info=True)
            return None, None

    def create_examples(self, lines):
        """Wrap each character sequence in an InputExample (labels are not needed at inference time)."""
        try:
            examples = []
            for (i, line) in enumerate(lines):
                text_a = line['words']
                examples.append(InputExample(text_a=text_a))
            return examples
        except Exception as error:
            log.error("ner qisushu create_examples error:{}".format(error), exc_info=True)
            return None

    def read_text(self, f_lines):
        """Split raw text into character lists, starting a new segment on '\\n' or after 510 characters."""
        try:
            lines = []
            words = []
            labels = []
            count = 0
            for char in f_lines:
                # Skip whitespace characters other than newline.
                if char != '\n' and char.strip() == "":
                    continue
                if char == '\n':
                    lines.append({"words": words, "labels": labels})
                    count = 0
                    words = []
                    labels = []
                    continue
                # Keep each segment within BERT's 512-token limit ([CLS] and [SEP] included).
                if count >= 510:
                    lines.append({"words": words, "labels": labels})
                    count = 0
                    words = []
                    labels = []
                words.append(char.strip())
                labels.append("O")
                count = count + 1
            if words:
                lines.append({"words": words, "labels": labels})
            return lines
        except Exception as error:
            log.error("ner qisushu read_text error:{}".format(error), exc_info=True)
            return None

    def convert_examples_to_features(self, examples, max_seq_length, tokenizer,
                                     cls_token_at_end=False, cls_token="[CLS]",
                                     cls_token_segment_id=1, sep_token="[SEP]",
                                     pad_on_left=False, pad_token=0,
                                     pad_token_segment_id=0, sequence_a_segment_id=0,
                                     mask_padding_with_zero=True):
        """Tokenize, truncate, add special tokens and pad every example to max_seq_length."""
        try:
            features = []
            for (ex_index, example) in enumerate(examples):
                textlist = example.text_a
                if isinstance(textlist, list):
                    textlist = " ".join(textlist)
                tokens = tokenizer.tokenize(textlist)
                # Reserve room for [CLS] and [SEP].
                special_tokens_count = 2
                if len(tokens) > max_seq_length - special_tokens_count:
                    tokens = tokens[:(max_seq_length - special_tokens_count)]
                tokens += [sep_token]
                segment_ids = [sequence_a_segment_id] * len(tokens)
                if cls_token_at_end:
                    tokens += [cls_token]
                    segment_ids += [cls_token_segment_id]
                else:
                    tokens = [cls_token] + tokens
                    segment_ids = [cls_token_segment_id] + segment_ids
                input_ids = tokenizer.convert_tokens_to_ids(tokens)
                input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
                input_len = len(input_ids)
                # Pad up to max_seq_length on the requested side.
                padding_length = max_seq_length - len(input_ids)
                if pad_on_left:
                    input_ids = ([pad_token] * padding_length) + input_ids
                    input_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + input_mask
                    segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids
                else:
                    input_ids += [pad_token] * padding_length
                    input_mask += [0 if mask_padding_with_zero else 1] * padding_length
                    segment_ids += [pad_token_segment_id] * padding_length
                assert len(input_ids) == max_seq_length
                assert len(input_mask) == max_seq_length
                assert len(segment_ids) == max_seq_length
                features.append(InputFeature(input_ids=input_ids, input_mask=input_mask,
                                             segment_ids=segment_ids, input_len=input_len))
            return features
        except Exception as error:
            log.error("ner qisushu convert_examples_to_features error:{}".format(error), exc_info=True)
            return None

    def bert_extract_item(self, start_logits, end_logits):
        """Decode (label, start, end) spans from start/end logits, skipping the [CLS]/[SEP] positions."""
        try:
            S = []
            start_pred = torch.argmax(start_logits, -1).cpu().numpy()[0][1:-1]
            end_pred = torch.argmax(end_logits, -1).cpu().numpy()[0][1:-1]
            for i, s_l in enumerate(start_pred):
                if s_l == 0:
                    continue
                # Pair each predicted start with the nearest end carrying the same label.
                for j, e_l in enumerate(end_pred[i:]):
                    if s_l == e_l:
                        S.append((s_l, i, i + j))
                        break
            return S
        except Exception as error:
            log.error("ner qisushu bert_extract_item error:{}".format(error), exc_info=True)
            return None

    def collate_fn(self, batch):
        """Stack a batch and trim padding down to the longest sequence in the batch."""
        try:
            all_input_ids, all_input_mask, all_segment_ids, all_lens = map(torch.stack, zip(*batch))
            max_len = max(all_lens).item()
            all_input_ids = all_input_ids[:, :max_len]
            all_input_mask = all_input_mask[:, :max_len]
            all_segment_ids = all_segment_ids[:, :max_len]
            return all_input_ids, all_input_mask, all_segment_ids, all_lens
        except Exception as error:
            log.error("ner qisushu collate_fn error:{}".format(error), exc_info=True)
            return None, None, None, None

    def predict(self, text_list):
        """Run span extraction on every text and return one entity list per input text."""
        try:
            ner_json_list_list = []
            for text in text_list:
                ner_json_list = []
                test_dataset, examples = self.load_and_cache_examples(self.tokenizer, text)
                test_sampler = SequentialSampler(test_dataset)
                test_dataloader = DataLoader(test_dataset, sampler=test_sampler,
                                             batch_size=1, collate_fn=self.collate_fn)
                for step, batch in enumerate(test_dataloader):
                    self.model.eval()
                    batch = tuple(t.to(self.device) for t in batch)
                    with torch.no_grad():
                        inputs = {"input_ids": batch[0],
                                  "attention_mask": batch[1],
                                  "token_type_ids": batch[2]}
                        outputs = self.model(**inputs)
                    start_logits, end_logits = outputs[:2]
                    R = self.bert_extract_item(start_logits, end_logits)
                    if R:
                        label_entities = [[self.id2label[x[0]], x[1], x[2]] for x in R]
                    else:
                        label_entities = []
                    for entity in label_entities:
                        start = entity[1]
                        end = entity[2]
                        # Span indices refer to the original character list of this segment.
                        # Use a separate name so the outer loop variable `text` is not clobbered.
                        term = "".join(examples[step].text_a[start:end + 1])
                        ner_json_list.append({"tag": self.label_json[entity[0]], "term": term})
                ner_json_list_list.append(ner_json_list)
            return ner_json_list_list
        except Exception as error:
            log.error("ner qisushu predict error:{}".format(error), exc_info=True)
            return None
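

# Minimal usage sketch (illustrative only, not part of the original module): shows how the
# predictor might be invoked. The sample indictment sentence below is a hypothetical
# placeholder; model paths and tag mappings come from common.config.config_class("起诉书").
if __name__ == "__main__":
    try:
        predictor = qisushu()
        sample_texts = ["某某人民检察院指控，被告人张某某于2020年3月在某地盗窃他人财物。"]  # hypothetical example text
        results = predictor.predict(sample_texts)
        # `results` is a list with one entry per input text; each entry is a list of
        # {"tag": ..., "term": ...} dicts describing the extracted entities.
        log.info("ner qisushu demo results:{}".format(results))
    except Exception as error:
        log.error("ner qisushu demo error:{}".format(error), exc_info=True)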