# depenBERTa_perseus / modeling_depenberta.py
# (source: bowphs/depenBERTa_perseus, commit 6e26ea5 "Update modeling_depenberta.py")
from transformers import RobertaConfig
from transformers.modeling_outputs import TokenClassifierOutput
from transformers.models.roberta.modeling_roberta import RobertaModel
from transformers.models.roberta.modeling_roberta import RobertaPreTrainedModel
import torch.nn as nn
import torch
class DependencyRobertaForTokenClassification(RobertaPreTrainedModel):
    """RoBERTa-based head-selection model for dependency parsing.

    For every position ``i`` in the sequence, an additive (Bahdanau-style)
    attention scores each position as a candidate syntactic head of token
    ``i``; ``log_softmax`` over positions yields per-token head
    log-probabilities, trained with NLLLoss against ``head_labels``.

    Interface notes:
        * ``forward`` returns ``(TokenClassifierOutput, parent_prob_table)``.
        * Gold head indices are passed via ``kwargs["head_labels"]``
          (shape ``(batch, seq_len)``, ``-100`` marks ignored positions);
          ``labels`` only gates whether a loss is computed.
    """

    config_class = RobertaConfig

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        # Hidden size taken from the config instead of hard-coding 768, so
        # the head also works with large/small RoBERTa variants
        # (backward compatible: roberta-base has hidden_size == 768).
        hidden = config.hidden_size
        # Encoder body; no pooler needed for token-level prediction.
        self.roberta = RobertaModel(config, add_pooling_layer=False)
        self.dropout = nn.Dropout(p=0.35)
        # Additive attention parameters: score = v_a_inv(tanh(u_a(source) + w_a(target))).
        self.u_a = nn.Linear(hidden, hidden)
        self.w_a = nn.Linear(hidden, hidden)
        self.v_a_inv = nn.Linear(hidden, 1, bias=False)
        # Inputs to NLLLoss are log-probabilities produced by attention()'s log_softmax.
        self.criterion = nn.NLLLoss()
        # Load and initialize weights.
        self.init_weights()

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
                labels=None, **kwargs):
        """Predict, for each token, the index of its syntactic head.

        Args:
            input_ids / attention_mask / token_type_ids: standard RoBERTa inputs.
            labels: if not None, a loss is computed (gold heads come from
                ``kwargs["head_labels"]``, required in that case).

        Returns:
            Tuple of:
                * ``TokenClassifierOutput`` with ``loss`` (``None`` when
                  ``labels is None``) and ``logits`` = predicted head indices
                  padded to length 512 with ``-100``.
                * ``parent_prob_table`` of shape ``(batch, seq_len, seq_len)``
                  with ``table[b, i, j]`` = P(head of token i is token j).
        """
        sequence_output = self.roberta(input_ids, attention_mask=attention_mask,
                                       token_type_ids=token_type_ids)[0]
        batch_size, seq_len, _ = sequence_output.size()

        loss = 0.0
        parent_prob_table = []
        for i in range(seq_len):
            # Broadcast token i's representation to every position: (batch, seq_len, hidden).
            target = sequence_output[:, i, :].expand(seq_len, batch_size, -1).transpose(0, 1)
            # Positions whose representation matches token i's (i.e. token i itself)
            # are masked out so a token cannot select itself as head.
            # NOTE(review): equality on the first feature only — assumes hidden
            # states of distinct positions never collide there.
            mask = sequence_output.eq(target)[:, :, 0].unsqueeze(2)
            # Log-probabilities over candidate heads: (batch, seq_len, 1).
            p_head = self.attention(sequence_output, target, mask)
            if labels is not None:
                head_labels = kwargs["head_labels"]
                # squeeze(2) (not bare squeeze()) keeps the (batch, seq_len)
                # shape intact when batch_size == 1.
                current_loss = self.criterion(p_head.squeeze(2), head_labels[:, i])
                # Skip columns that are entirely ignore-index: NLLLoss would
                # return NaN for them.
                if not torch.all(head_labels[:, i] == -100):
                    loss += current_loss
            parent_prob_table.append(torch.exp(p_head))

        parent_prob_table = torch.cat(parent_prob_table, dim=2).data.transpose(1, 2)
        _, topi = parent_prob_table.topk(k=1, dim=2)
        # squeeze(2) so preds is always (batch, seq_len); the original bare
        # squeeze() collapsed the batch dim too when batch_size == 1,
        # skipping the padding below.
        preds = topi.squeeze(2)
        # Pad predictions to the fixed length 512 expected downstream.
        preds = nn.ConstantPad1d((0, 512 - preds.shape[1]), -100)(preds)
        # Report no loss (rather than a misleading 0.0) when no labels were given.
        loss = loss / seq_len if labels is not None else None
        output = TokenClassifierOutput(loss=loss, logits=preds)
        return output, parent_prob_table

    def attention(self, source, target, mask=None):
        """Additive attention scores of ``source`` positions against ``target``.

        Returns log-probabilities via log_softmax over the sequence dimension;
        masked positions are filled with -1e4 before normalization so they
        receive (near-)zero probability.
        """
        function_g = \
            self.v_a_inv(torch.tanh(self.u_a(source) + self.w_a(target)))
        if mask is not None:
            function_g.masked_fill_(mask, -1e4)
        return nn.functional.log_softmax(function_g, dim=1)