from transformers import BertTokenizer, BertConfig, BertModel, BertPreTrainedModel
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from ner_1.linears import PoolerEndLogits, PoolerStartLogits
from ner_1.label_smoothing import LabelSmoothingCrossEntropy
from ner_1.focal_loss import FocalLoss

class BertSpanForNer(BertPreTrainedModel):
    """Span-based NER head on top of BERT: a start-position classifier over the
    sequence output, plus an end-position classifier conditioned on the start labels."""

    def __init__(self, config):
        super(BertSpanForNer, self).__init__(config)
        self.soft_label = config.soft_label
        self.num_labels = config.num_labels
        self.loss_type = config.loss_type
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.start_fc = PoolerStartLogits(config.hidden_size, self.num_labels)
        if self.soft_label:
            # End classifier receives the full start-label distribution (one-hot during training).
            self.end_fc = PoolerEndLogits(config.hidden_size + self.num_labels, self.num_labels)
        else:
            # End classifier receives a single scalar start-label id per token.
            self.end_fc = PoolerEndLogits(config.hidden_size + 1, self.num_labels)
        self.init_weights()

    def forward(self, input_ids, token_type_ids=None, attention_mask=None,
                start_positions=None, end_positions=None):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)
        start_logits = self.start_fc(sequence_output)
        if start_positions is not None and self.training:
            # Teacher forcing: feed the gold start labels to the end classifier.
            if self.soft_label:
                batch_size = input_ids.size(0)
                seq_len = input_ids.size(1)
                label_logits = torch.FloatTensor(batch_size, seq_len, self.num_labels)
                label_logits.zero_()
                label_logits = label_logits.to(input_ids.device)
                label_logits.scatter_(2, start_positions.unsqueeze(2), 1)
            else:
                label_logits = start_positions.unsqueeze(2).float()
        else:
            # Inference: condition the end classifier on the predicted start distribution.
            label_logits = F.softmax(start_logits, -1)
            if not self.soft_label:
                label_logits = torch.argmax(label_logits, -1).unsqueeze(2).float()
        end_logits = self.end_fc(sequence_output, label_logits)
        outputs = (start_logits, end_logits,) + outputs[2:]

        if start_positions is not None and end_positions is not None:
            assert self.loss_type in ['lsr', 'focal', 'ce']
            if self.loss_type == 'lsr':
                loss_fct = LabelSmoothingCrossEntropy()
            elif self.loss_type == 'focal':
                loss_fct = FocalLoss()
            else:
                loss_fct = CrossEntropyLoss()
            start_logits = start_logits.view(-1, self.num_labels)
            end_logits = end_logits.view(-1, self.num_labels)
            # Only score tokens that are not padding.
            active_loss = attention_mask.view(-1) == 1
            active_start_logits = start_logits[active_loss]
            active_end_logits = end_logits[active_loss]

            active_start_labels = start_positions.view(-1)[active_loss]
            active_end_labels = end_positions.view(-1)[active_loss]

            start_loss = loss_fct(active_start_logits, active_start_labels)
            end_loss = loss_fct(active_end_logits, active_end_labels)
            total_loss = (start_loss + end_loss) / 2
            outputs = (total_loss,) + outputs
        return outputs
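

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). It runs the model
# with randomly initialised weights and dummy tensors just to show the expected
# inputs and outputs. It assumes the extra fields `soft_label`, `loss_type`
# and `num_labels` are attached to the BertConfig, as the constructor above
# expects; the shapes and label values below are placeholders, not real data.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = BertConfig(num_labels=5, soft_label=True, loss_type="ce")
    model = BertSpanForNer(config)
    model.train()

    batch_size, seq_len = 2, 16
    input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
    attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)
    token_type_ids = torch.zeros(batch_size, seq_len, dtype=torch.long)
    # Per-token start/end label ids in [0, num_labels); 0 here stands for "no entity".
    start_positions = torch.zeros(batch_size, seq_len, dtype=torch.long)
    end_positions = torch.zeros(batch_size, seq_len, dtype=torch.long)

    outputs = model(input_ids,
                    token_type_ids=token_type_ids,
                    attention_mask=attention_mask,
                    start_positions=start_positions,
                    end_positions=end_positions)
    total_loss, start_logits, end_logits = outputs[0], outputs[1], outputs[2]
    print(total_loss.item(), start_logits.shape, end_logits.shape)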