| import torch |
| from torch import nn |
| from transformers import AutoTokenizer, AutoModel |
| from huggingface_hub import hf_hub_download |
|
|
# Default compute device: the first CUDA device when torch reports CUDA as
# available, otherwise the CPU.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda:0"
|
|
class MLP(nn.Module):
    """Two-layer perceptron: ``dim0 -> dim0 -> dim1`` with a ReLU in between."""

    def __init__(self, dim0, dim1):
        super().__init__()
        # Attribute names double as state-dict keys, so they stay unchanged.
        self.linear1 = nn.Linear(dim0, dim0)
        self.linear2 = nn.Linear(dim0, dim1)
        self.activate = nn.ReLU()

    def forward(self, x):
        """Apply the hidden projection, the ReLU, then the output projection."""
        hidden = self.linear1(x)
        hidden = self.activate(hidden)
        return self.linear2(hidden)
|
|
class SpanInteraction(nn.Module):
    """5x5 convolution over a span matrix whose feature axis is last."""

    def __init__(self, dim2):
        super().__init__()
        # Same number of input and output channels; padding of 2 keeps the
        # two spatial sizes unchanged for a 5x5 kernel.
        self.conv = nn.Conv2d(dim2, dim2, kernel_size=(5, 5), padding=(2, 2))

    def forward(self, hM):
        """Convolve ``hM`` over its two middle axes.

        ``hM`` is indexed as (batch, row, col, feature). ``nn.Conv2d`` expects
        the channel axis second, so the tensor is permuted to channels-first
        for the convolution and permuted back afterwards.
        """
        channels_first = hM.permute(0, 3, 1, 2)
        mixed = self.conv(channels_first)
        return mixed.permute(0, 2, 3, 1)
|
|
class SpanEnumeration(nn.Module):
    """Combine start/end boundary features into a (start, end) span matrix.

    For every pair ``(i, j)`` the output is
    ``LayerNorm(proj(B_s[i]) + proj(B_e[j]) + pos_embedding(|j - i|))``.

    Args:
        dim1: feature size of the incoming boundary representations.
        dim2: feature size of each span-matrix cell.
        max_len: largest supported sequence length (size of the distance table).
    """

    def __init__(self, dim1, dim2, max_len):
        super().__init__()
        self.s_mapping = nn.Linear(dim1, dim2)
        self.e_mapping = nn.Linear(dim1, dim2)
        self.pos_embedding = nn.Embedding(max_len, dim2)
        self.layer_norm = nn.LayerNorm(dim2, eps=1e-12)
        # |j - i| for every (i, j), computed by broadcasting rather than a
        # max_len**2 Python loop. The buffer stays flattened in row-major
        # order so its shape matches previously saved state dicts.
        idx = torch.arange(max_len, dtype=torch.long)
        pos_id = (idx.unsqueeze(0) - idx.unsqueeze(1)).abs().reshape(-1)
        self.register_buffer("pos_id", pos_id)
        self.dim2 = dim2
        self.max_len = max_len

    def forward(self, B_s, B_e):
        """Return the span matrix of shape ``(bs, seq_len, seq_len, dim2)``.

        Args:
            B_s: start-boundary features, ``(bs, seq_len, dim1)``.
            B_e: end-boundary features, same shape as ``B_s``.
        """
        bs, seq_len, _ = B_s.size()

        # Slice the distance ids down to the active window *before* the
        # embedding lookup so only seq_len**2 rows are gathered, not max_len**2.
        dist = self.pos_id.view(self.max_len, self.max_len)[:seq_len, :seq_len]
        pos_embedding = self.pos_embedding(dist)
        pos_embedding = pos_embedding.unsqueeze(0).expand(bs, seq_len, seq_len, self.dim2)

        B_s = self.s_mapping(B_s)
        # NOTE(review): the end features are projected with s_mapping, and
        # e_mapping is registered but never applied. This looks like a
        # copy-paste slip, but switching it changes the outputs of any
        # checkpoint trained through this code path -- confirm against the
        # training code before changing it.
        B_e = self.s_mapping(B_e)

        # Cell (i, j) receives the start feature of row i and the end feature
        # of column j.
        B_s_ex = B_s.unsqueeze(2).expand(bs, seq_len, seq_len, self.dim2)
        B_e_ex = B_e.unsqueeze(1).expand(bs, seq_len, seq_len, self.dim2)
        return self.layer_norm(B_s_ex + B_e_ex + pos_embedding)
|
|
class BoundaryEnumeration(nn.Module):
    """Derive separate start- and end-boundary features from the same input."""

    def __init__(self, dim):
        super().__init__()
        # One independent MLP per boundary type; both map dim -> dim.
        self.s_boundary_enum = MLP(dim, dim)
        self.e_boundary_enum = MLP(dim, dim)

    def forward(self, H_c):
        """Return ``(start_features, end_features)`` computed from ``H_c``."""
        start_features = self.s_boundary_enum(H_c)
        end_features = self.e_boundary_enum(H_c)
        return start_features, end_features
|
|
class SpanScoring(nn.Module):
    """Score every cell of a span matrix and produce a per-example scalar.

    Args:
        dim1: feature size of ``H_cls`` (input of the scalar head).
        dim2: feature size of each span-matrix cell.
        max_len: largest supported sequence length (size of the stored mask).
        max_span_gap: largest allowed ``j - i`` for a span ``(i, j)``.
    """

    def __init__(self, dim1, dim2, max_len, max_span_gap):
        super().__init__()
        self.mlp_scoring = MLP(dim2, 1)
        self.mlp_cls = MLP(dim1, 1)
        # masks_triangle[i, j] is 1.0 when i <= j and j - i <= max_span_gap,
        # else 0.0. Built by broadcasting instead of a max_len**2 Python loop;
        # dtype and shape match the loop-built buffer.
        idx = torch.arange(max_len)
        gap = idx.unsqueeze(0) - idx.unsqueeze(1)  # gap[i, j] == j - i
        allowed = (gap >= 0) & (gap <= max_span_gap)
        self.register_buffer("masks_triangle", allowed.to(torch.float))

    def forward(self, M, H_cls, masks):
        """Return ``(S, qs)``.

        Args:
            M: span matrix, ``(bs, seq_len, seq_len, dim2)``.
            H_cls: per-example vector, ``(bs, dim1)``.
            masks: ``(bs, seq_len)`` with 1 on valid positions and 0 elsewhere.

        Returns:
            S: ``(bs, seq_len, seq_len)`` cell scores; every cell that is not
               allowed (either endpoint masked out, or outside the stored
               triangle/gap mask) has 10000 subtracted from its score.
            qs: output of the scalar head on ``H_cls`` (last dimension 1).
        """
        S = self.mlp_scoring(M).view(M.size(0), M.size(1), M.size(2))
        qs = self.mlp_cls(H_cls)
        seq_len = S.size(1)
        # pair_mask[b, i, j] == masks[b, i] * masks[b, j]
        pair_mask = masks.unsqueeze(1) * masks.unsqueeze(2)
        tri = self.masks_triangle[:seq_len, :seq_len]
        S = S - 10000.0 * (1 - pair_mask * tri)
        return S, qs
|
|
class BoundaryAggregation(nn.Module):
    """Re-weight boundary features with attention derived from span scores.

    Two parallel branches (start and end) each build a span matrix, run the
    convolutional interaction on it, score it, and use the softmaxed scores to
    mix linearly projected boundary features.
    """

    def __init__(self, dim1, dim2, max_len, max_span_gap):
        super().__init__()
        self.span_enum_s = SpanEnumeration(dim1, dim2, max_len)
        self.span_enum_e = SpanEnumeration(dim1, dim2, max_len)
        self.span_scoring_s = SpanScoring(dim1, dim2, max_len, max_span_gap)
        self.span_scoring_e = SpanScoring(dim1, dim2, max_len, max_span_gap)
        self.W2_s = nn.Linear(dim1, dim1)
        self.W2_e = nn.Linear(dim1, dim1)
        self.span_interaction_s = SpanInteraction(dim2)
        self.span_interaction_e = SpanInteraction(dim2)

    def forward(self, hB_s, hB_e, H_cls, masks):
        """Return ``(B_s, B_e, G_s, G_e, qs_s, qs_e)``.

        ``B_s`` / ``B_e`` are the aggregated boundary features (same shape as
        ``hB_s``); ``G_s`` / ``G_e`` are the raw (pre-softmax) span scores and
        ``qs_s`` / ``qs_e`` the scalar outputs of the two scoring heads.
        """
        out_shape = hB_s.size()  # (bs, seq_len, dim)

        # Start branch: attention is normalised over the last axis of G_s.
        start_matrix = self.span_interaction_s(self.span_enum_s(hB_s, hB_e))
        G_s, qs_s = self.span_scoring_s(start_matrix, H_cls, masks)
        start_attn = G_s.softmax(dim=-1)
        B_s = (start_attn @ self.W2_s(hB_s)).view(*out_shape)

        # End branch: the score matrix is transposed first, so the softmax
        # runs over what was the second-to-last axis of G_e.
        end_matrix = self.span_interaction_e(self.span_enum_e(hB_s, hB_e))
        G_e, qs_e = self.span_scoring_e(end_matrix, H_cls, masks)
        end_attn = G_e.transpose(-2, -1).softmax(dim=-1)
        B_e = (end_attn @ self.W2_e(hB_e)).view(*out_shape)

        return B_s, B_e, G_s, G_e, qs_s, qs_e
|
|
class SpanRepresentation(nn.Module):
    """Build the masked span matrix, optionally followed by span interaction.

    Args:
        dim1: feature size of the boundary representations.
        dim2: feature size of each span-matrix cell.
        max_len: largest supported sequence length (size of the stored mask).
        max_span_gap: largest allowed ``j - i`` for a span ``(i, j)``.
        vanilla: when true, the convolutional interaction step is skipped.
    """

    def __init__(self, dim1, dim2, max_len, max_span_gap, vanilla=False):
        super().__init__()
        self.span_enum = SpanEnumeration(dim1, dim2, max_len)
        self.span_interaction = SpanInteraction(dim2)
        self.vanilla = vanilla
        # masks_triangle[i, j] is 1.0 when i <= j and j - i <= max_span_gap,
        # else 0.0. Built by broadcasting instead of a max_len**2 Python loop;
        # dtype and shape match the loop-built buffer.
        idx = torch.arange(max_len)
        gap = idx.unsqueeze(0) - idx.unsqueeze(1)  # gap[i, j] == j - i
        allowed = (gap >= 0) & (gap <= max_span_gap)
        self.register_buffer("masks_triangle", allowed.to(torch.float))

    def forward(self, B_s, B_e, masks):
        """Return the span matrix ``(bs, seq_len, seq_len, dim2)``.

        Cells whose start or end position is masked out, or that fall outside
        the stored triangle/gap mask, are zeroed before the optional
        interaction step.
        """
        M = self.span_enum(B_s, B_e)
        seq_len = B_s.size(1)
        # pair_mask[b, i, j] == masks[b, i] * masks[b, j]
        pair_mask = masks.unsqueeze(1) * masks.unsqueeze(2)
        tri = self.masks_triangle[:seq_len, :seq_len]
        M = M * (pair_mask * tri).unsqueeze(3)
        if not self.vanilla:
            M = self.span_interaction(M)
        return M
|
|
class BoundaryRepresentation(nn.Module):
    """Produce start/end boundary features, optionally refined by aggregation."""

    def __init__(self, dim1, dim2, max_len, max_span_gap, vanilla=False):
        super().__init__()
        self.boundary_enum = BoundaryEnumeration(dim1)
        self.vanilla = vanilla
        # Constructed even in vanilla mode, so the parameter set (and hence
        # the state-dict layout) does not depend on the flag.
        self.boundary_aggregation = BoundaryAggregation(dim1, dim2, max_len, max_span_gap)

    def forward(self, H_c, H_cls, masks):
        """Return ``(B_s, B_e, G_s, G_e, qs_s, qs_e)``.

        In vanilla mode only the enumerated boundaries are computed and the
        last four entries are ``None``.
        """
        B_s, B_e = self.boundary_enum(H_c)
        if self.vanilla:
            return B_s, B_e, None, None, None, None

        B_s, B_e, G_s, G_e, qs_s, qs_e = self.boundary_aggregation(B_s, B_e, H_cls, masks)
        return B_s, B_e, G_s, G_e, qs_s, qs_e
|
|
class SpanQualifier(nn.Module):
    """Multi-span extractor: encoder -> boundaries -> span matrix -> decoding.

    Args:
        base_model_name: identifier passed to ``AutoModel.from_pretrained``.
        max_span_gap: largest allowed ``end - start`` for a span.
        dim2: feature size of each span-matrix cell.
        max_len: largest supported sequence length.
        vanilla: forwarded to the boundary and span representation modules,
            where it disables the aggregation / interaction steps.
        force_answer: when true, decoding returns one span per example even
            if no candidate clears the threshold.
    """

    def __init__(self, base_model_name, max_span_gap=15, dim2=64, max_len=512, vanilla=False, force_answer=False):
        super().__init__()
        self.token_representation = AutoModel.from_pretrained(base_model_name)
        dim1 = self.token_representation.config.hidden_size
        self.boundary_representation = BoundaryRepresentation(dim1, dim2, max_len, max_span_gap, vanilla)
        self.span_representation = SpanRepresentation(dim1, dim2, max_len, max_span_gap, vanilla)
        self.span_scoring = SpanScoring(dim1, dim2, max_len, max_span_gap)
        self.force_answer = force_answer

    def forward(self, input_ids, type_ids, mask_ids, context_ranges):
        """Encode the inputs and return the decoded spans.

        Args:
            input_ids: token ids, passed to the encoder as ``input_ids``.
            type_ids: passed to the encoder as ``token_type_ids``.
            mask_ids: passed to the encoder as ``attention_mask``.
            context_ranges: one ``(start, end)`` pair per example, handed to
                ``split_sequence_like`` to select the context tokens.

        Returns:
            One list per example of ``[start, end]`` index pairs. Indices
            refer to positions in the tensor returned by
            ``split_sequence_like``, not to positions in ``input_ids``.
        """
        outputs = self.token_representation(
            input_ids=input_ids,
            attention_mask=mask_ids,
            token_type_ids=type_ids,
            output_hidden_states=True,
            return_dict=True,
        )
        sequence_output = outputs.hidden_states[-1]
        H_cls = sequence_output[:, 0, :]  # vector at position 0 of each sequence
        H_c, masks = split_sequence_like(sequence_output, context_ranges)
        # The auxiliary score outputs of the boundary module are not used here.
        B_s, B_e, *_ = self.boundary_representation(H_c, H_cls, masks)
        M = self.span_representation(B_s, B_e, masks)
        S, qs_ext = self.span_scoring(M, H_cls, masks)
        spans, _ = self.decoding_span_matrix(S, qs_ext)
        return spans

    def decoding_span_matrix(self, logits_matrix, threshold_p, spans_matrix_mask=None):
        """Pick spans that are mutual best matches and clear the threshold.

        A cell ``(i, j)`` with ``i <= j`` is a candidate when ``j`` is the
        argmax of row ``i`` and ``i`` is the argmax of column ``j``. A
        candidate is emitted when its logit is strictly greater than the
        example's threshold.

        Args:
            logits_matrix: ``(bs, seq_len, seq_len)`` span scores.
            threshold_p: tensor with ``bs`` elements, one threshold per example.
            spans_matrix_mask: optional tensor broadcastable to
                ``logits_matrix``; cells where it is 1 are pushed down by 10000.

        Returns:
            ``(spans, None)`` where ``spans[b]`` is a list of ``[i, j]`` pairs
            ordered by ``i``. With ``force_answer`` set and no emitted span,
            the highest-logit candidate is returned instead (``[0, 0]`` when
            no candidate has a logit above -10000).
        """
        bs, _, _ = logits_matrix.size()
        if spans_matrix_mask is not None:
            logits_matrix = logits_matrix - 10000.0 * spans_matrix_mask

        _, idx_best_end = torch.max(torch.softmax(logits_matrix, dim=2), dim=2)
        _, idx_best_beg = torch.max(torch.softmax(logits_matrix, dim=1), dim=1)

        # Row i can only ever contribute the cell (i, idx_best_end[i]), so
        # fetch just those logits instead of copying the full matrix to
        # Python and scanning every (i, j) pair.
        cand_logits = logits_matrix.gather(2, idx_best_end.unsqueeze(2)).squeeze(2)

        idx_best_end = idx_best_end.cpu().tolist()
        idx_best_beg = idx_best_beg.cpu().tolist()
        cand_logits = cand_logits.cpu().tolist()
        thresholds = threshold_p.view(bs).cpu().tolist()

        spans = []
        for ends, begs, row_logits, t_p in zip(idx_best_end, idx_best_beg, cand_logits, thresholds):
            spans_item = []
            max_logit, max_i, max_j = -10000, 0, 0
            for i, j in enumerate(ends):
                if i > j or begs[j] != i:
                    continue
                logit = row_logits[i]
                if logit > t_p:
                    spans_item.append([i, j])
                if logit > max_logit:
                    max_logit, max_i, max_j = logit, i, j
            if not spans_item and self.force_answer:
                spans_item.append([max_i, max_j])
            spans.append(spans_item)
        return spans, None
|
|
def split_sequence_like(sequence_output, context_ranges):
    """Pack each example's context tokens to the front of a zero-padded tensor.

    Args:
        sequence_output: ``(bs, L, H)`` tensor of per-token features.
        context_ranges: per example, an inclusive ``(c_beg, c_end)`` pair of
            token positions delimiting the context inside ``sequence_output``.

    Returns:
        H_c: ``(bs, L, H)`` tensor; row ``b`` holds the selected context
            tokens in positions ``0..Lc-1`` and zeros afterwards.
        masks: ``(bs, L)`` tensor with 1 on the first ``Lc`` positions and 0
            on the padding.

    Both outputs take the dtype and device of ``sequence_output``, so a
    half- or double-precision input is not silently converted to float32
    (which would then mismatch the downstream modules).
    """
    bs, L, H = sequence_output.size()
    H_c = torch.zeros_like(sequence_output)
    masks = sequence_output.new_zeros(bs, L)
    for b in range(bs):
        c_beg, c_end = context_ranges[b]
        ctx = sequence_output[b, c_beg:c_end + 1, :]  # c_end is inclusive
        Lc = ctx.size(0)
        H_c[b, :Lc] = ctx
        masks[b, :Lc] = 1
    return H_c, masks
|
|
def build_single_features(tokenizer, question, context, max_len=512, device=DEVICE):
    """Tokenise one (question, context) pair and locate the context tokens.

    Args:
        tokenizer: tokenizer that supports ``return_offsets_mapping`` and
            ``sequence_ids`` on its output.
        question: question text (first sequence of the pair).
        context: context text (second sequence of the pair).
        max_len: ``max_length`` used for truncation.
        device: device the returned tensors are moved to.

    Returns:
        ``(input_ids, type_ids, attn_mask, (ctx_start, ctx_end), offsets)``
        where the first three are tensors on ``device``, ``ctx_start`` /
        ``ctx_end`` are the inclusive token positions of the context, and
        ``offsets`` is the offset mapping of the encoded pair as a list.

    Raises:
        ValueError: if the encoded pair contains no token of the context.
    """
    enc = tokenizer(
        question,
        context,
        return_tensors="pt",
        return_offsets_mapping=True,
        truncation=True,
        max_length=max_len,
    )
    input_ids = enc["input_ids"].to(device)
    attn_mask = enc["attention_mask"].to(device)
    type_ids = enc.get("token_type_ids")
    if type_ids is None:
        # Tokenizer did not return token type ids: fall back to all zeros.
        type_ids = torch.zeros_like(input_ids)
    type_ids = type_ids.to(device)

    # Sequence ids are available on the encoding above, so the pair does not
    # need to be tokenised a second time just to read them.
    seq_ids = enc.sequence_ids(0)
    try:
        ctx_start = seq_ids.index(1)  # first token of the second sequence
    except ValueError:
        raise ValueError("encoded input contains no context tokens") from None

    # The context ends right before the first position after ctx_start whose
    # sequence id is None, or at the last token when there is none.
    ctx_end = next(
        (i - 1 for i in range(ctx_start, len(seq_ids)) if seq_ids[i] is None),
        len(seq_ids) - 1,
    )
    return input_ids, type_ids, attn_mask, (ctx_start, ctx_end), enc["offset_mapping"][0].tolist()
|
|
if __name__ == "__main__":
    # Demo: load the fine-tuned checkpoint and extract answer spans for one
    # question/context pair.
    import sys

    REPO_ID = "ivabojic/deberta-v3-base_MultiSpanQA"
    BASE = "microsoft/deberta-v3-base"
    TOKENIZER = AutoTokenizer.from_pretrained(BASE)

    model = SpanQualifier(
        base_model_name=BASE,
        max_span_gap=8,
        dim2=64,
        max_len=512,
        vanilla=False,
        force_answer=False,
    ).to(DEVICE)
    model.eval()

    # Fetch the fine-tuned weights and load them over the base initialisation.
    # NOTE(security): torch.load unpickles the downloaded file, which can run
    # arbitrary code; only use checkpoints from a trusted repository (or pass
    # weights_only=True where the installed torch version supports it).
    ckpt_path = hf_hub_download(repo_id=REPO_ID, filename="pytorch_model.bin")
    state = torch.load(ckpt_path, map_location="cpu")
    state = state.get("model_state_dict", state)
    missing, unexpected = model.load_state_dict(state, strict=False)
    # strict=False never raises on key mismatches, so report them: missing
    # keys mean those parameters keep their initial values.
    if missing or unexpected:
        print(
            f"load_state_dict: {len(missing)} missing key(s) {list(missing)}, "
            f"{len(unexpected)} unexpected key(s) {list(unexpected)}",
            file=sys.stderr,
        )

    question = "Who sang it's my party and i'll cry if i want to in the eighties?"
    context = (
        "In 1981, a remake by British artists Dave Stewart and Barbara Gaskin "
        "was a UK number one hit single for four weeks and was also a major hit "
        "in Austria (#3), Germany (#3), the Netherlands (#20), New Zealand (#1), "
        "South Africa (#3) and Switzerland (#6). The track reached #72 in the US. "
        "This was the first version of the song to reach #1 in the UK. The video "
        "for the Stewart/Gaskin version contained a cameo by Thomas Dolby as Johnny, "
        "Judy being played by Gaskin in a blond wig."
    )

    input_ids, type_ids, attn_mask, ctx_range, offsets = build_single_features(
        TOKENIZER, question, context, max_len=512, device=DEVICE
    )

    with torch.no_grad():
        spans_rel = model(
            input_ids=input_ids,
            type_ids=type_ids,
            mask_ids=attn_mask,
            context_ranges=[ctx_range],
        )[0]

    # Predicted indices are relative to the first context token; shift them
    # back to positions in the encoded pair and map to character offsets.
    ctx_start, ctx_end = ctx_range
    answers = []
    for beg_rel, end_rel in spans_rel:
        beg_abs = ctx_start + beg_rel
        end_abs = ctx_start + end_rel
        if end_abs > ctx_end:
            # Span reaches into the padded region past the context; there is
            # no context text to map it to.
            continue
        s_char = offsets[beg_abs][0]
        e_char = offsets[end_abs][1]
        answers.append(context[s_char:e_char].strip())

    print("\nPredicted spans:")
    for i, a in enumerate(answers, 1):
        print(f"{i}. {a}")
|
|