| | from transformers import RobertaConfig |
| | from transformers.modeling_outputs import TokenClassifierOutput |
| | from transformers.models.roberta.modeling_roberta import RobertaModel |
| | from transformers.models.roberta.modeling_roberta import RobertaPreTrainedModel |
| | import torch.nn as nn |
| | import torch |
| |
|
| |
|
class DependencyRobertaForTokenClassification(RobertaPreTrainedModel):
    """RoBERTa with an additive-attention head for dependency-head prediction.

    For every token position ``i`` the model scores all positions in the
    sequence as candidate syntactic heads of token ``i`` (Bahdanau-style
    additive attention over the encoder outputs) and returns, per token,
    the argmax head index.

    Targets are expected in ``kwargs["head_labels"]`` as a
    ``(batch, seq_len)`` LongTensor of head indices, with ``-100`` marking
    ignored positions (matching ``nn.NLLLoss``'s default ``ignore_index``).
    The ``labels`` argument is only used as a flag that training targets are
    present; the actual targets come from ``head_labels``.
    """

    config_class = RobertaConfig

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.roberta = RobertaModel(config, add_pooling_layer=False)
        # NOTE: dropout is defined but not applied in forward(); kept so the
        # module structure (and any saved state dicts) stay unchanged.
        self.dropout = nn.Dropout(p=0.35)
        # Fix: attention projections previously hard-coded 768 and broke for
        # any checkpoint whose hidden size differs (e.g. roberta-large).
        hidden_size = config.hidden_size
        self.u_a = nn.Linear(hidden_size, hidden_size)
        self.w_a = nn.Linear(hidden_size, hidden_size)
        self.v_a_inv = nn.Linear(hidden_size, 1, bias=False)
        # attention() returns log-probabilities, so NLLLoss is the right pair.
        self.criterion = nn.NLLLoss()

        self.init_weights()

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
                labels=None, **kwargs):
        """Predict a head position for every token.

        Args:
            input_ids / attention_mask / token_type_ids: standard RoBERTa
                encoder inputs.
            labels: any non-``None`` value enables loss computation; the real
                targets must be supplied as ``kwargs["head_labels"]``.

        Returns:
            ``(TokenClassifierOutput, parent_prob_table)`` where ``logits``
            holds per-token predicted head indices padded to length 512 with
            ``-100``, and ``parent_prob_table`` is the
            ``(batch, seq_len, seq_len)`` head-probability matrix.
        """
        sequence_output = self.roberta(
            input_ids, attention_mask=attention_mask,
            token_type_ids=token_type_ids)[0]
        batch_size, seq_len, _ = sequence_output.size()

        loss = 0.0
        parent_prob_table = []
        logits = []
        for i in range(seq_len):
            # Broadcast token i's vector to every position: (batch, seq, hid).
            target = sequence_output[:, i, :].expand(
                seq_len, batch_size, -1).transpose(0, 1)
            # Positions whose first hidden feature equals token i's own are
            # masked — in practice this masks token i itself, forbidding a
            # token from being its own head.
            mask = sequence_output.eq(target)[:, :, 0].unsqueeze(2)
            p_head = self.attention(sequence_output, target, mask)
            logits.append(p_head)
            if labels is not None:
                # Fix: squeeze(2) instead of squeeze() so a batch of size 1
                # keeps its batch dimension (squeeze() collapsed it and broke
                # NLLLoss's (N, C) / (N,) shape contract).
                current_loss = self.criterion(
                    p_head.squeeze(2), kwargs["head_labels"][:, i])
                # NLLLoss averages over the non-ignored targets; a column that
                # is entirely -100 would yield NaN, so skip it.
                if not torch.all(kwargs["head_labels"][:, i] == -100):
                    loss += current_loss
            parent_prob_table.append(torch.exp(p_head))

        # (batch, seq, seq): row i = head distribution for token i.
        parent_prob_table = torch.cat(
            parent_prob_table, dim=2).data.transpose(1, 2)
        _, topi = parent_prob_table.topk(k=1, dim=2)
        # Fix: squeeze(2) keeps preds at (batch, seq_len) for batch size 1 as
        # well, so padding is now applied unconditionally (previously a
        # 1-sentence batch silently skipped padding).
        preds = topi.squeeze(2)
        # Pad predictions out to the fixed 512-token export length with -100.
        preds = nn.ConstantPad1d((0, 512 - preds.shape[1]), -100)(preds)
        # Fix: report loss=None when no labels were given (HF convention),
        # instead of a meaningless 0.0.
        loss = loss / seq_len if labels is not None else None
        output = TokenClassifierOutput(loss=loss, logits=preds)
        return output, parent_prob_table

    def attention(self, source, target, mask=None):
        """Additive attention scores of ``target`` against every ``source``
        position, returned as log-probabilities over the sequence axis.

        Args:
            source: encoder outputs, (batch, seq_len, hidden).
            target: one token's vector broadcast to (batch, seq_len, hidden).
            mask: optional bool tensor (batch, seq_len, 1); masked positions
                get a large negative score before normalization.

        Returns:
            (batch, seq_len, 1) log-softmax scores along dim 1.
        """
        function_g = \
            self.v_a_inv(torch.tanh(self.u_a(source) + self.w_a(target)))
        if mask is not None:
            # -1e4 rather than -inf keeps fp16 finite while still zeroing the
            # masked probability for practical purposes.
            function_g.masked_fill_(mask, -1e4)
        return nn.functional.log_softmax(function_g, dim=1)