import torch
import torch.nn as nn
import torch.nn.functional as F


class FeedForwardNetwork(nn.Module):
    """A two-layer feed-forward block: Linear -> ReLU -> Dropout -> Linear."""

    def __init__(self, input_size, hidden_size, output_size, dropout_rate=0):
        super(FeedForwardNetwork, self).__init__()
        self.dropout_rate = dropout_rate
        self.linear1 = nn.Linear(input_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        # Dropout is only active in training mode (self.training).
        x_proj = F.dropout(F.relu(self.linear1(x)), p=self.dropout_rate, training=self.training)
        x_proj = self.linear2(x_proj)
        return x_proj
|
|
|
class PoolerStartLogits(nn.Module):
    """Projects hidden states to per-position start logits."""

    def __init__(self, hidden_size, num_classes):
        super(PoolerStartLogits, self).__init__()
        self.dense = nn.Linear(hidden_size, num_classes)

    def forward(self, hidden_states, p_mask=None):
        # p_mask is accepted for interface compatibility but not applied here.
        x = self.dense(hidden_states)
        return x
|
class PoolerEndLogits(nn.Module):
    """Computes end logits from hidden states joined with start representations."""

    def __init__(self, hidden_size, num_classes):
        super(PoolerEndLogits, self).__init__()
        # forward() concatenates hidden_states with start_positions along the
        # last dimension, so the first projection must take 2 * hidden_size inputs.
        self.dense_0 = nn.Linear(hidden_size * 2, hidden_size)
        self.activation = nn.Tanh()
        self.LayerNorm = nn.LayerNorm(hidden_size)
        self.dense_1 = nn.Linear(hidden_size, num_classes)

    def forward(self, hidden_states, start_positions=None, p_mask=None):
        # start_positions holds per-position start representations with the same
        # trailing dimension as hidden_states; p_mask is accepted but not applied here.
        x = self.dense_0(torch.cat([hidden_states, start_positions], dim=-1))
        x = self.activation(x)
        x = self.LayerNorm(x)
        x = self.dense_1(x)
        return x
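

# A minimal usage sketch (not part of the original file): the batch size,
# sequence length, and feature sizes below are illustrative assumptions,
# chosen only to show the expected tensor shapes flowing through each module.
if __name__ == "__main__":
    batch_size, seq_len, hidden_size, num_classes = 2, 16, 768, 5
    hidden_states = torch.randn(batch_size, seq_len, hidden_size)

    ffn = FeedForwardNetwork(hidden_size, hidden_size * 4, hidden_size, dropout_rate=0.1)
    print(ffn(hidden_states).shape)  # torch.Size([2, 16, 768])

    start_pooler = PoolerStartLogits(hidden_size, num_classes)
    print(start_pooler(hidden_states).shape)  # torch.Size([2, 16, 5])

    # PoolerEndLogits expects per-position start representations of size
    # hidden_size; random values stand in for gathered start states here.
    start_states = torch.randn(batch_size, seq_len, hidden_size)
    end_pooler = PoolerEndLogits(hidden_size, num_classes)
    print(end_pooler(hidden_states, start_positions=start_states).shape)  # torch.Size([2, 16, 5])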