You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
40 lines
1.4 KiB
40 lines
1.4 KiB
1 year ago
|
import torch
|
||
|
import torch.nn as nn
|
||
|
import torch.nn.functional as F
|
||
|
|
||
|
class FeedForwardNetwork(nn.Module):
    """Two-layer MLP: Linear -> ReLU -> dropout -> Linear.

    Args:
        input_size: size of the last dimension of the input tensor.
        hidden_size: width of the intermediate (post-``linear1``) layer.
        output_size: size of the last dimension of the output tensor.
        dropout_rate: dropout probability applied after the ReLU. Dropout is
            only active while the module is in training mode.
    """

    def __init__(self, input_size, hidden_size, output_size, dropout_rate=0):
        super().__init__()
        self.dropout_rate = dropout_rate
        # Registration order (linear1, then linear2) fixes both the
        # parameter initialisation sequence and the state_dict key order.
        self.linear1 = nn.Linear(input_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Apply the MLP to ``x`` along its last dimension."""
        hidden = F.relu(self.linear1(x))
        # Passing training=self.training makes dropout a no-op after .eval().
        hidden = F.dropout(hidden, p=self.dropout_rate, training=self.training)
        return self.linear2(hidden)
|
||
|
|
||
|
class PoolerStartLogits(nn.Module):
    """Linear projection from each position's hidden state to start logits.

    Args:
        hidden_size: size of the last dimension of ``hidden_states``.
        num_classes: number of logits produced per position.
    """

    def __init__(self, hidden_size, num_classes):
        super().__init__()
        self.dense = nn.Linear(hidden_size, num_classes)

    def forward(self, hidden_states, p_mask=None):
        """Return ``self.dense(hidden_states)``.

        ``p_mask`` is accepted for signature compatibility but is ignored:
        no masking is applied to the returned logits.
        """
        return self.dense(hidden_states)
|
||
|
|
||
|
class PoolerEndLogits(nn.Module):
    """Compute end logits conditioned on start-position features.

    ``hidden_states`` and ``start_positions`` are concatenated along the last
    dimension, then passed through Linear -> Tanh -> LayerNorm -> Linear.

    Args:
        hidden_size: input width of ``dense_0`` (and the width it keeps).
            Because ``forward`` feeds ``dense_0`` the concatenation of
            ``hidden_states`` and ``start_positions``, this must equal the sum
            of their last-dimension sizes.
        num_classes: number of logits produced per position.
    """

    def __init__(self, hidden_size, num_classes):
        super(PoolerEndLogits, self).__init__()
        self.dense_0 = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.Tanh()
        self.LayerNorm = nn.LayerNorm(hidden_size)
        self.dense_1 = nn.Linear(hidden_size, num_classes)

    def forward(self, hidden_states, start_positions=None, p_mask=None):
        """Return end logits with ``num_classes`` entries in the last dimension.

        Args:
            hidden_states: tensor concatenated with ``start_positions`` along
                ``dim=-1``; all other dimensions must match.
            start_positions: tensor concatenated onto ``hidden_states``. It is
                required despite the ``None`` default, which is kept only for
                signature compatibility. NOTE(review): presumably start logits
                or one-hot start labels; confirm against the caller.
            p_mask: accepted for signature compatibility; not used.

        Raises:
            TypeError: if ``start_positions`` is None (previously this surfaced
                as an opaque error from ``torch.cat``).
        """
        if start_positions is None:
            raise TypeError(
                "PoolerEndLogits.forward requires start_positions: it is "
                "concatenated with hidden_states along the last dimension"
            )
        x = self.dense_0(torch.cat([hidden_states, start_positions], dim=-1))
        x = self.activation(x)
        x = self.LayerNorm(x)
        x = self.dense_1(x)
        return x
|