import torch
from torch import nn
from transformers import AutoTokenizer, AutoModel
from huggingface_hub import hf_hub_download

# Run on the first CUDA device when available, otherwise on CPU.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"


class MLP(nn.Module):
    """Two-layer feed-forward head: ``Linear(dim0, dim0) -> ReLU -> Linear(dim0, dim1)``."""

    def __init__(self, dim0, dim1):
        super().__init__()
        self.linear1 = nn.Linear(dim0, dim0)
        self.linear2 = nn.Linear(dim0, dim1)
        self.activate = nn.ReLU()

    def forward(self, x):
        return self.linear2(self.activate(self.linear1(x)))


class SpanInteraction(nn.Module):
    """Mix neighbouring span cells with a 5x5 convolution.

    ``hM`` is channels-last with ``dim2`` as its last dimension (the span
    grids produced by ``SpanEnumeration`` are ``(bs, L, L, dim2)``). It is
    permuted to channels-first for ``nn.Conv2d`` and back again; padding of 2
    keeps the grid size unchanged.
    """

    def __init__(self, dim2):
        super().__init__()
        self.conv = nn.Conv2d(dim2, dim2, kernel_size=(5, 5), padding=(2, 2))

    def forward(self, hM):
        hM = hM.permute(0, 3, 1, 2)
        hM = self.conv(hM)
        return hM.permute(0, 2, 3, 1)


class SpanEnumeration(nn.Module):
    """Build a grid of span features from start and end token features.

    Output cell ``(b, i, j)`` is
    ``LayerNorm(proj(B_s)[b, i] + proj(B_e)[b, j] + pos_embedding(|j - i|))``.

    Args:
        dim1: size of the incoming token features.
        dim2: size of each span cell.
        max_len: largest supported sequence length (side of the distance table).
    """

    def __init__(self, dim1, dim2, max_len):
        super().__init__()
        self.s_mapping = nn.Linear(dim1, dim2)
        # Declared as the projection for end features; see the NOTE in forward.
        self.e_mapping = nn.Linear(dim1, dim2)
        self.pos_embedding = nn.Embedding(max_len, dim2)
        self.layer_norm = nn.LayerNorm(dim2, eps=1e-12)
        # Flattened row-major table with pos_id[i * max_len + j] == |j - i|.
        # It stays 1-D because it is a persistent buffer whose shape must
        # match what checkpoints store.
        idx = torch.arange(max_len)
        self.register_buffer("pos_id", (idx.unsqueeze(0) - idx.unsqueeze(1)).abs().reshape(-1))
        self.dim2 = dim2
        self.max_len = max_len

    def forward(self, B_s, B_e):
        """Return the ``(bs, seq_len, seq_len, dim2)`` span grid.

        ``B_s`` and ``B_e`` are 3-D ``(bs, seq_len, dim1)`` token features.

        Raises:
            ValueError: if ``seq_len`` exceeds ``max_len``.
        """
        bs, seq_len, _ = B_s.size()
        if seq_len > self.max_len:
            raise ValueError(f"sequence length {seq_len} exceeds max_len {self.max_len}")
        # Embed only the seq_len x seq_len corner of the distance table instead
        # of all max_len x max_len entries on every call (identical values).
        dist = self.pos_id.view(self.max_len, self.max_len)[:seq_len, :seq_len]
        pos_embedding = self.pos_embedding(dist).unsqueeze(0).expand(bs, seq_len, seq_len, self.dim2)

        B_s = self.s_mapping(B_s)
        # NOTE(review): end features also go through s_mapping, so e_mapping is
        # never used here even though it is declared as the end projection.
        # Left as-is on purpose: the pretrained checkpoint was fit against the
        # training code's behaviour. Confirm against that code before switching
        # this to self.e_mapping.
        B_e = self.s_mapping(B_e)
        B_s_ex = B_s.unsqueeze(2).expand(bs, seq_len, seq_len, self.dim2)  # [b, i, j] = B_s[b, i]
        B_e_ex = B_e.unsqueeze(1).expand(bs, seq_len, seq_len, self.dim2)  # [b, i, j] = B_e[b, j]
        N = B_s_ex + B_e_ex + pos_embedding
        return self.layer_norm(N)


class BoundaryEnumeration(nn.Module):
    """Project token features into separate start- and end-boundary features."""

    def __init__(self, dim):
        super().__init__()
        self.s_boundary_enum = MLP(dim, dim)
        self.e_boundary_enum = MLP(dim, dim)

    def forward(self, H_c):
        """Return ``(start_features, end_features)``, both shaped like ``H_c``."""
        return self.s_boundary_enum(H_c), self.e_boundary_enum(H_c)
def _band_mask(max_len, max_span_gap):
    """Return a float ``(max_len, max_len)`` mask that is 1 where ``i <= j <= i + max_span_gap``."""
    idx = torch.arange(max_len)
    gap = idx.unsqueeze(0) - idx.unsqueeze(1)  # gap[i, j] = j - i
    return ((gap >= 0) & (gap <= max_span_gap)).float()


def _valid_span_mask(masks, band):
    """Combine token masks with a band mask into a ``(bs, L, L)`` span-validity mask.

    Cell ``(b, i, j)`` is 1 only when tokens ``i`` and ``j`` both have mask 1
    and ``(i, j)`` lies inside ``band`` (cropped to ``L``).
    """
    seq_len = masks.size(1)
    pair = masks.unsqueeze(1) * masks.unsqueeze(2)  # pair[b, i, j] = m[b, j] * m[b, i]
    return pair * band[:seq_len, :seq_len].unsqueeze(0)


class SpanScoring(nn.Module):
    """Score every cell of a span grid and produce a per-item threshold.

    Invalid cells (padding tokens, ``j < i``, or ``j - i > max_span_gap``)
    are pushed down by 10000 so they lose any softmax/argmax comparison.
    """

    def __init__(self, dim1, dim2, max_len, max_span_gap):
        super().__init__()
        self.mlp_scoring = MLP(dim2, 1)
        self.mlp_cls = MLP(dim1, 1)
        # Persistent buffer: a checkpoint that stores it will overwrite this band.
        self.register_buffer("masks_triangle", _band_mask(max_len, max_span_gap))

    def forward(self, M, H_cls, masks):
        """Return ``(S, qs)``: masked span scores ``(bs, L, L)`` and threshold ``(bs, 1)``."""
        S = self.mlp_scoring(M).view(M.size(0), M.size(1), M.size(2))
        qs = self.mlp_cls(H_cls)  # (bs, 1)
        valid = _valid_span_mask(masks, self.masks_triangle)
        S = S - 10000.0 * (1 - valid)
        return S, qs


class BoundaryAggregation(nn.Module):
    """Refine start/end boundary features by attending over span grids.

    Start features: ``B_s[b, i] = sum_j softmax_j(G_s[b, i, :]) * W2_s(hB_s)[b, j]``.
    End features:   ``B_e[b, j] = sum_i softmax_i(G_e[b, :, j]) * W2_e(hB_e)[b, i]``.
    """

    def __init__(self, dim1, dim2, max_len, max_span_gap):
        super().__init__()
        self.span_enum_s = SpanEnumeration(dim1, dim2, max_len)
        self.span_enum_e = SpanEnumeration(dim1, dim2, max_len)
        self.span_scoring_s = SpanScoring(dim1, dim2, max_len, max_span_gap)
        self.span_scoring_e = SpanScoring(dim1, dim2, max_len, max_span_gap)
        self.W2_s = nn.Linear(dim1, dim1)
        self.W2_e = nn.Linear(dim1, dim1)
        self.span_interaction_s = SpanInteraction(dim2)
        self.span_interaction_e = SpanInteraction(dim2)

    def forward(self, hB_s, hB_e, H_cls, masks):
        """Return ``(B_s, B_e, G_s, G_e, qs_s, qs_e)``."""
        bs, seq_len, dim = hB_s.size()

        M_s = self.span_interaction_s(self.span_enum_s(hB_s, hB_e))
        G_s, qs_s = self.span_scoring_s(M_s, H_cls, masks)
        G_s_soft = torch.softmax(G_s, dim=-1)  # normalise over ends j for each start i
        B_s = torch.matmul(G_s_soft, self.W2_s(hB_s)).view(bs, seq_len, dim)

        M_e = self.span_interaction_e(self.span_enum_e(hB_s, hB_e))
        G_e, qs_e = self.span_scoring_e(M_e, H_cls, masks)
        G_e_soft = torch.softmax(G_e.transpose(-2, -1), dim=-1)  # normalise over starts i for each end j
        B_e = torch.matmul(G_e_soft, self.W2_e(hB_e)).view(bs, seq_len, dim)

        return B_s, B_e, G_s, G_e, qs_s, qs_e


class SpanRepresentation(nn.Module):
    """Build the final span grid, zero invalid cells, and optionally convolve it.

    With ``vanilla=True`` the ``SpanInteraction`` convolution is skipped.
    """

    def __init__(self, dim1, dim2, max_len, max_span_gap, vanilla=False):
        super().__init__()
        self.span_enum = SpanEnumeration(dim1, dim2, max_len)
        self.span_interaction = SpanInteraction(dim2)
        self.vanilla = vanilla
        # Persistent buffer: a checkpoint that stores it will overwrite this band.
        self.register_buffer("masks_triangle", _band_mask(max_len, max_span_gap))

    def forward(self, B_s, B_e, masks):
        """Return the ``(bs, L, L, dim2)`` span grid with invalid cells zeroed."""
        M = self.span_enum(B_s, B_e)
        valid = _valid_span_mask(masks, self.masks_triangle)
        M = M * valid.unsqueeze(3)
        if not self.vanilla:
            M = self.span_interaction(M)
        return M


class BoundaryRepresentation(nn.Module):
    """Compute start/end boundary features, refined by aggregation unless ``vanilla``.

    With ``vanilla=True`` the aggregation step is skipped and the ``G_*`` /
    ``qs_*`` outputs are ``None``.
    """

    def __init__(self, dim1, dim2, max_len, max_span_gap, vanilla=False):
        super().__init__()
        self.boundary_enum = BoundaryEnumeration(dim1)
        self.vanilla = vanilla
        self.boundary_aggregation = BoundaryAggregation(dim1, dim2, max_len, max_span_gap)

    def forward(self, H_c, H_cls, masks):
        B_s, B_e = self.boundary_enum(H_c)
        G_s = G_e = qs_s = qs_e = None
        if not self.vanilla:
            B_s, B_e, G_s, G_e, qs_s, qs_e = self.boundary_aggregation(B_s, B_e, H_cls, masks)
        return B_s, B_e, G_s, G_e, qs_s, qs_e


class SpanQualifier(nn.Module):
    """Multi-span extractive QA model on top of a Hugging Face encoder.

    Args:
        base_model_name: name/path passed to ``AutoModel.from_pretrained``.
        max_span_gap: largest allowed ``end - start`` token distance.
        dim2: size of span-grid cells.
        max_len: largest supported sequence length.
        vanilla: skip boundary aggregation and span convolution.
        force_answer: always return at least one span per item.
    """

    def __init__(self, base_model_name, max_span_gap=15, dim2=64, max_len=512, vanilla=False, force_answer=False):
        super().__init__()
        self.token_representation = AutoModel.from_pretrained(base_model_name)
        dim1 = self.token_representation.config.hidden_size
        self.boundary_representation = BoundaryRepresentation(dim1, dim2, max_len, max_span_gap, vanilla)
        self.span_representation = SpanRepresentation(dim1, dim2, max_len, max_span_gap, vanilla)
        self.span_scoring = SpanScoring(dim1, dim2, max_len, max_span_gap)
        self.force_answer = force_answer

    def forward(self, input_ids, type_ids, mask_ids, context_ranges):
        """Return predicted spans for each batch item.

        ``context_ranges[b]`` is the inclusive ``(start, end)`` token range of
        the context in item ``b``. Each returned span is an inclusive
        ``[start, end]`` pair of token indices relative to that context start.
        """
        outputs = self.token_representation(
            input_ids=input_ids,
            attention_mask=mask_ids,
            token_type_ids=type_ids,
            output_hidden_states=True,
            return_dict=True,
        )
        sequence_output = outputs.hidden_states[-1]  # (bs, L, H)
        H_cls = sequence_output[:, 0, :]
        H_c, masks = split_sequence_like(sequence_output, context_ranges)
        B_s, B_e, G_s, G_e, qs_s, qs_e = self.boundary_representation(H_c, H_cls, masks)
        M = self.span_representation(B_s, B_e, masks)
        S, qs_ext = self.span_scoring(M, H_cls, masks)
        spans, _ = self.decoding_span_matrix(S, qs_ext)
        return spans

    def decoding_span_matrix(self, logits_matrix, threshold_p, spans_matrix_mask=None):
        """Decode spans from a ``(bs, L, L)`` score matrix.

        A cell ``(i, j)`` is a candidate when ``i <= j``, ``j`` is the argmax of
        row ``i``, and ``i`` is the argmax of column ``j``. A candidate is kept
        when its logit exceeds the item's threshold from ``threshold_p``
        (viewed as ``(bs,)``). With ``force_answer`` and nothing kept, the
        highest-logit candidate is returned, or ``[0, 0]`` when no candidate
        scores above -10000.

        Returns:
            ``(spans, None)``, where ``spans[b]`` is a list of ``[i, j]`` pairs
            in ascending ``i`` order.
        """
        bs = logits_matrix.size(0)
        if spans_matrix_mask is not None:
            logits_matrix = logits_matrix - 10000.0 * spans_matrix_mask
        idx_best_end = torch.max(torch.softmax(logits_matrix, dim=2), dim=2)[1]  # (bs, L)
        idx_best_beg = torch.max(torch.softmax(logits_matrix, dim=1), dim=1)[1]  # (bs, L)
        # Each row i has exactly one argmax column, so at most one candidate per
        # row. Gather only those logits instead of copying the whole matrix.
        row_best = logits_matrix.gather(2, idx_best_end.unsqueeze(2)).squeeze(2)  # (bs, L)

        thresholds = threshold_p.view(bs).cpu().tolist()
        idx_best_end = idx_best_end.cpu().tolist()
        idx_best_beg = idx_best_beg.cpu().tolist()
        row_best = row_best.cpu().tolist()

        spans = []
        for ends, begs, logits, t_p in zip(idx_best_end, idx_best_beg, row_best, thresholds):
            spans_item = []
            max_logit, max_i, max_j = -10000, 0, 0
            for i, (j, logit) in enumerate(zip(ends, logits)):
                if i > j or begs[j] != i:
                    continue
                if logit > t_p:
                    spans_item.append([i, j])
                if logit > max_logit:
                    max_logit, max_i, max_j = logit, i, j
            if not spans_item and self.force_answer:
                spans_item.append([max_i, max_j])
            spans.append(spans_item)
        return spans, None


def split_sequence_like(sequence_output, context_ranges):
    """Pack each item's context tokens to the front of a zero-padded tensor.

    Args:
        sequence_output: encoder output of shape ``(bs, L, H)``.
        context_ranges: per-item inclusive ``(c_beg, c_end)`` token indices.

    Returns:
        ``(H_c, masks)``: ``H_c`` is ``(bs, L, H)`` with the context in rows
        ``0..Lc-1`` and zeros after; ``masks`` is ``(bs, L)`` with 1 on those rows.
        Both use ``sequence_output``'s dtype and device.
    """
    bs, L, H = sequence_output.size()
    H_c_batch, masks_batch = [], []
    for b in range(bs):
        c_beg, c_end = context_ranges[b]
        ctx = sequence_output[b, c_beg:c_end + 1, :]
        Lc = ctx.size(0)
        # new_zeros keeps dtype/device consistent with the encoder output
        # (e.g. half precision) instead of hard-coding float32.
        H_c_pad = sequence_output.new_zeros(L, H)
        H_c_pad[:Lc] = ctx
        mask = sequence_output.new_zeros(L)
        mask[:Lc] = 1
        H_c_batch.append(H_c_pad)
        masks_batch.append(mask)
    return torch.stack(H_c_batch, 0), torch.stack(masks_batch, 0)


def build_single_features(tokenizer, question, context, max_len=512, device=DEVICE):
    """Tokenize one (question, context) pair for ``SpanQualifier``.

    Requires a fast tokenizer (offset mapping and ``sequence_ids``).

    Returns:
        ``(input_ids, type_ids, attn_mask, (ctx_start, ctx_end), offsets)``:
        the first three are ``(1, L)`` tensors on ``device``. ``(ctx_start,
        ctx_end)`` is the inclusive token range of the context. ``offsets`` is
        the per-token ``[char_start, char_end]`` list.

    Raises:
        ValueError: if no context token survives truncation.
    """
    enc = tokenizer(question, context, return_tensors="pt", return_offsets_mapping=True,
                    truncation=True, max_length=max_len)
    input_ids = enc["input_ids"].to(device)
    attn_mask = enc["attention_mask"].to(device)
    type_ids = enc.get("token_type_ids")
    if type_ids is None:
        type_ids = torch.zeros_like(input_ids)
    type_ids = type_ids.to(device)

    # Reuse the existing encoding; no need to tokenize a second time.
    seq_ids = enc.sequence_ids(0)
    try:
        ctx_start = seq_ids.index(1)  # first context token
    except ValueError:
        raise ValueError("no context tokens remain after tokenization/truncation") from None
    # Last context token: just before the first special token after ctx_start.
    ctx_end = next(
        (i - 1 for i in range(ctx_start, len(seq_ids)) if seq_ids[i] is None),
        len(seq_ids) - 1,
    )
    return input_ids, type_ids, attn_mask, (ctx_start, ctx_end), enc["offset_mapping"][0].tolist()


if __name__ == "__main__":
    REPO_ID = "ivabojic/deberta-v3-base_MultiSpanQA"
    BASE = "microsoft/deberta-v3-base"
    TOKENIZER = AutoTokenizer.from_pretrained(BASE)

    model = SpanQualifier(
        base_model_name=BASE,
        # NOTE: the masks_triangle buffers built from max_span_gap are persistent;
        # if the checkpoint stores them, load_state_dict below replaces them, so
        # the checkpoint's band would take precedence over this value.
        max_span_gap=8,  # adjust per domain
        dim2=64,
        max_len=512,
        vanilla=False,
        force_answer=False,
    ).to(DEVICE)
    model.eval()

    # --- Download from the Hub ---
    ckpt_path = hf_hub_download(repo_id=REPO_ID, filename="pytorch_model.bin")
    # SECURITY: torch.load unpickles the file, which can execute code. Only load
    # checkpoints from sources you trust (or use weights_only=True if the
    # checkpoint contents allow it).
    state = torch.load(ckpt_path, map_location="cpu")
    state = state.get("model_state_dict", state)  # handle both formats
    # strict=False tolerates key mismatches; report them rather than silently
    # running with freshly initialised layers.
    missing, unexpected = model.load_state_dict(state, strict=False)
    if missing:
        print(f"[warn] {len(missing)} model keys missing from checkpoint, e.g. {missing[:5]}")
    if unexpected:
        print(f"[warn] {len(unexpected)} unexpected checkpoint keys, e.g. {unexpected[:5]}")

    # --- Example Q & context ---
    question = "Who sang it's my party and i'll cry if i want to in the eighties?"
    context = (
        "In 1981, a remake by British artists Dave Stewart and Barbara Gaskin "
        "was a UK number one hit single for four weeks and was also a major hit "
        "in Austria (#3), Germany (#3), the Netherlands (#20), New Zealand (#1), "
        "South Africa (#3) and Switzerland (#6). The track reached #72 in the US. "
        "This was the first version of the song to reach #1 in the UK. The video "
        "for the Stewart/Gaskin version contained a cameo by Thomas Dolby as Johnny, "
        "Judy being played by Gaskin in a blond wig."
    )

    # --- Build features & run ---
    input_ids, type_ids, attn_mask, ctx_range, offsets = build_single_features(
        TOKENIZER, question, context, max_len=512, device=DEVICE
    )
    with torch.no_grad():
        spans_rel = model(
            input_ids=input_ids,
            type_ids=type_ids,
            mask_ids=attn_mask,
            context_ranges=[ctx_range],
        )[0]

    # Map context-relative token spans back to text via absolute offsets.
    ctx_start, ctx_end = ctx_range
    answers = []
    for beg_rel, end_rel in spans_rel:
        beg_abs = ctx_start + beg_rel
        end_abs = ctx_start + end_rel
        s_char = offsets[beg_abs][0]
        e_char = offsets[end_abs][1]
        answers.append(context[s_char:e_char].strip())

    print("\nPredicted spans:")
    for i, a in enumerate(answers, 1):
        print(f"{i}. {a}")