ivabojic commited on
Commit
e30b0b5
·
verified ·
1 Parent(s): 200f8f5

Upload 2 files

Browse files
Files changed (2) hide show
  1. inference_spanqualifier_hf.py +289 -0
  2. pytorch_model.bin +3 -0
inference_spanqualifier_hf.py ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from transformers import AutoTokenizer, AutoModel
4
+ from huggingface_hub import hf_hub_download
5
+
6
+ DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
7
+
8
+ class MLP(nn.Module):
9
+ def __init__(self, dim0, dim1):
10
+ super().__init__()
11
+ self.linear1 = nn.Linear(dim0, dim0)
12
+ self.linear2 = nn.Linear(dim0, dim1)
13
+ self.activate = nn.ReLU()
14
+ def forward(self, x):
15
+ return self.linear2(self.activate(self.linear1(x)))
16
+
17
+ class SpanInteraction(nn.Module):
18
+ def __init__(self, dim2):
19
+ super().__init__()
20
+ self.conv = nn.Conv2d(dim2, dim2, kernel_size=(5,5), padding=(2,2))
21
+ def forward(self, hM):
22
+ hM = hM.permute(0,3,1,2)
23
+ hM = self.conv(hM)
24
+ return hM.permute(0,2,3,1)
25
+
26
+ class SpanEnumeration(nn.Module):
27
+ def __init__(self, dim1, dim2, max_len):
28
+ super().__init__()
29
+ self.s_mapping = nn.Linear(dim1, dim2)
30
+ self.e_mapping = nn.Linear(dim1, dim2) # <-- correct mapping for ends
31
+ self.pos_embedding = nn.Embedding(max_len, dim2)
32
+ self.layer_norm = nn.LayerNorm(dim2, eps=1e-12)
33
+ pos_id = []
34
+ for i in range(max_len):
35
+ for j in range(max_len):
36
+ pos_id.append(abs(j - i))
37
+ self.register_buffer("pos_id", torch.tensor(pos_id, dtype=torch.long))
38
+ self.dim2 = dim2
39
+ self.max_len = max_len
40
+ def forward(self, B_s, B_e):
41
+ bs, seq_len, _ = B_s.size()
42
+ pos_embedding = self.pos_embedding(self.pos_id).view(self.max_len, self.max_len, self.dim2)
43
+ pos_embedding = pos_embedding[:seq_len, :seq_len, :].reshape(seq_len, seq_len, self.dim2)
44
+ pos_embedding = pos_embedding.unsqueeze(0).expand(bs, seq_len, seq_len, self.dim2)
45
+ B_s = self.s_mapping(B_s)
46
+ B_e = self.s_mapping(B_e)
47
+ B_s_ex = B_s.unsqueeze(2).expand(bs, seq_len, seq_len, self.dim2)
48
+ B_e_ex = B_e.unsqueeze(2).expand(bs, seq_len, seq_len, self.dim2).transpose(1, 2)
49
+ N = B_s_ex + B_e_ex + pos_embedding
50
+ return self.layer_norm(N)
51
+
52
+ class BoundaryEnumeration(nn.Module):
53
+ def __init__(self, dim):
54
+ super().__init__()
55
+ self.s_boundary_enum = MLP(dim, dim)
56
+ self.e_boundary_enum = MLP(dim, dim)
57
+ def forward(self, H_c):
58
+ return self.s_boundary_enum(H_c), self.e_boundary_enum(H_c)
59
+
60
+ class SpanScoring(nn.Module):
61
+ def __init__(self, dim1, dim2, max_len, max_span_gap):
62
+ super().__init__()
63
+ self.mlp_scoring = MLP(dim2, 1)
64
+ self.mlp_cls = MLP(dim1, 1)
65
+ tri = []
66
+ for i in range(max_len):
67
+ for j in range(max_len):
68
+ tri.append(1 if (i <= j and (j - i) <= max_span_gap) else 0)
69
+ self.register_buffer("masks_triangle", torch.tensor(tri, dtype=torch.float).view(max_len, max_len))
70
+ def forward(self, M, H_cls, masks):
71
+ S = self.mlp_scoring(M).view(M.size(0), M.size(1), M.size(2))
72
+ qs = self.mlp_cls(H_cls) # (bs, 1)
73
+ bs, seq_len = S.size(0), S.size(1)
74
+ masks_ex = masks.unsqueeze(1).expand(bs, seq_len, seq_len)
75
+ masks_ex_t = masks_ex.transpose(1, 2)
76
+ masks_ex = masks_ex * masks_ex_t
77
+ tri = self.masks_triangle[:seq_len, :seq_len].unsqueeze(0).expand(bs, seq_len, seq_len)
78
+ S = S - 10000.0 * (1 - masks_ex * tri)
79
+ return S, qs
80
+
81
+ class BoundaryAggregation(nn.Module):
82
+ def __init__(self, dim1, dim2, max_len, max_span_gap):
83
+ super().__init__()
84
+ self.span_enum_s = SpanEnumeration(dim1, dim2, max_len)
85
+ self.span_enum_e = SpanEnumeration(dim1, dim2, max_len)
86
+ self.span_scoring_s = SpanScoring(dim1, dim2, max_len, max_span_gap)
87
+ self.span_scoring_e = SpanScoring(dim1, dim2, max_len, max_span_gap)
88
+ self.W2_s = nn.Linear(dim1, dim1)
89
+ self.W2_e = nn.Linear(dim1, dim1)
90
+ self.span_interaction_s = SpanInteraction(dim2)
91
+ self.span_interaction_e = SpanInteraction(dim2)
92
+ def forward(self, hB_s, hB_e, H_cls, masks):
93
+ bs, seq_len, dim = hB_s.size()
94
+ M_s = self.span_enum_s(hB_s, hB_e)
95
+ M_s = self.span_interaction_s(M_s)
96
+ G_s, qs_s = self.span_scoring_s(M_s, H_cls, masks)
97
+ G_s_soft = torch.softmax(G_s, dim=-1)
98
+ B_s = torch.matmul(G_s_soft, self.W2_s(hB_s)).view(bs, seq_len, dim)
99
+ M_e = self.span_enum_e(hB_s, hB_e)
100
+ M_e = self.span_interaction_e(M_e)
101
+ G_e, qs_e = self.span_scoring_e(M_e, H_cls, masks)
102
+ G_e_soft = torch.softmax(G_e.transpose(-2, -1), dim=-1)
103
+ B_e = torch.matmul(G_e_soft, self.W2_e(hB_e)).view(bs, seq_len, dim)
104
+ return B_s, B_e, G_s, G_e, qs_s, qs_e
105
+
106
+ class SpanRepresentation(nn.Module):
107
+ def __init__(self, dim1, dim2, max_len, max_span_gap, vanilla=False):
108
+ super().__init__()
109
+ self.span_enum = SpanEnumeration(dim1, dim2, max_len)
110
+ self.span_interaction = SpanInteraction(dim2)
111
+ self.vanilla = vanilla
112
+ tri = []
113
+ for i in range(max_len):
114
+ for j in range(max_len):
115
+ tri.append(1 if (i <= j and (j - i) <= max_span_gap) else 0)
116
+ self.register_buffer("masks_triangle", torch.tensor(tri, dtype=torch.float).view(max_len, max_len))
117
+ def forward(self, B_s, B_e, masks):
118
+ M = self.span_enum(B_s, B_e)
119
+ bs, seq_len, _ = B_s.size()
120
+ masks_c_ex = masks.unsqueeze(1).expand(bs, seq_len, seq_len)
121
+ masks_c_ex_t = masks_c_ex.transpose(1, 2)
122
+ masks_c_ex = masks_c_ex * masks_c_ex_t
123
+ tri = self.masks_triangle[:seq_len, :seq_len].unsqueeze(0).expand(bs, seq_len, seq_len)
124
+ M = M * (masks_c_ex * tri).unsqueeze(3)
125
+ if not self.vanilla:
126
+ M = self.span_interaction(M)
127
+ return M
128
+
129
+ class BoundaryRepresentation(nn.Module):
130
+ def __init__(self, dim1, dim2, max_len, max_span_gap, vanilla=False):
131
+ super().__init__()
132
+ self.boundary_enum = BoundaryEnumeration(dim1)
133
+ self.vanilla = vanilla
134
+ self.boundary_aggregation = BoundaryAggregation(dim1, dim2, max_len, max_span_gap)
135
+ def forward(self, H_c, H_cls, masks):
136
+ B_s, B_e = self.boundary_enum(H_c)
137
+ G_s = G_e = qs_s = qs_e = None
138
+ if not self.vanilla:
139
+ B_s, B_e, G_s, G_e, qs_s, qs_e = self.boundary_aggregation(B_s, B_e, H_cls, masks)
140
+ return B_s, B_e, G_s, G_e, qs_s, qs_e
141
+
142
+ class SpanQualifier(nn.Module):
143
+ def __init__(self, base_model_name, max_span_gap=15, dim2=64, max_len=512, vanilla=False, force_answer=False):
144
+ super().__init__()
145
+ self.token_representation = AutoModel.from_pretrained(base_model_name)
146
+ dim1 = self.token_representation.config.hidden_size
147
+ self.boundary_representation = BoundaryRepresentation(dim1, dim2, max_len, max_span_gap, vanilla)
148
+ self.span_representation = SpanRepresentation(dim1, dim2, max_len, max_span_gap, vanilla)
149
+ self.span_scoring = SpanScoring(dim1, dim2, max_len, max_span_gap)
150
+ self.force_answer = force_answer
151
+ def forward(self, input_ids, type_ids, mask_ids, context_ranges):
152
+ outputs = self.token_representation(
153
+ input_ids=input_ids, attention_mask=mask_ids, token_type_ids=type_ids,
154
+ output_hidden_states=True, return_dict=True
155
+ )
156
+ sequence_output = outputs.hidden_states[-1] # (bs, L, H)
157
+ H_cls = sequence_output[:, 0, :]
158
+ H_c, masks = split_sequence_like(sequence_output, context_ranges)
159
+ B_s, B_e, G_s, G_e, qs_s, qs_e = self.boundary_representation(H_c, H_cls, masks)
160
+ M = self.span_representation(B_s, B_e, masks)
161
+ S, qs_ext = self.span_scoring(M, H_cls, masks)
162
+ spans, _ = self.decoding_span_matrix(S, qs_ext)
163
+ return spans
164
+ def decoding_span_matrix(self, logits_matrix, threshold_p, spans_matrix_mask=None):
165
+ bs, seq_len, _ = logits_matrix.size()
166
+ if spans_matrix_mask is not None:
167
+ logits_matrix = logits_matrix - 10000.0 * spans_matrix_mask
168
+ logits_end = torch.softmax(logits_matrix, dim=2)
169
+ _, idx_best_end = torch.max(logits_end, dim=2)
170
+ idx_best_end = idx_best_end.cpu().tolist()
171
+ threshold_p = threshold_p.view(bs).cpu().tolist()
172
+ logits_beg = torch.softmax(logits_matrix, dim=1)
173
+ _, idx_best_beg = torch.max(logits_beg, dim=1)
174
+ idx_best_beg = idx_best_beg.cpu().tolist()
175
+ logits_matrix = logits_matrix.cpu().tolist()
176
+ spans = []
177
+ for b_i, (matrix, t_p) in enumerate(zip(logits_matrix, threshold_p)):
178
+ spans_item = []
179
+ max_logit, max_i, max_j = -10000, 0, 0
180
+ for i, row in enumerate(matrix):
181
+ for j, logit in enumerate(row):
182
+ if i <= j and idx_best_end[b_i][i] == j and idx_best_beg[b_i][j] == i:
183
+ if logit > t_p:
184
+ spans_item.append([i, j])
185
+ if logit > max_logit:
186
+ max_logit, max_i, max_j = logit, i, j
187
+ if len(spans_item) == 0 and self.force_answer:
188
+ spans_item.append([max_i, max_j])
189
+ spans.append(spans_item)
190
+ return spans, None
191
+
192
+ def split_sequence_like(sequence_output, context_ranges):
193
+ """Packs context tokens to the front (like your split_sequence with useSep=False),
194
+ returns padded H_c and context masks."""
195
+ bs, L, H = sequence_output.size()
196
+ H_c_batch, masks_batch = [], []
197
+ for b in range(bs):
198
+ c_beg, c_end = context_ranges[b]
199
+ ctx = sequence_output[b, c_beg:c_end+1, :]
200
+ Lc = ctx.size(0)
201
+ H_c_pad = torch.zeros(L, H, device=sequence_output.device)
202
+ H_c_pad[:Lc] = ctx
203
+ mask = torch.zeros(L, device=sequence_output.device)
204
+ mask[:Lc] = 1
205
+ H_c_batch.append(H_c_pad)
206
+ masks_batch.append(mask)
207
+ return torch.stack(H_c_batch, 0), torch.stack(masks_batch, 0)
208
+
209
+ def build_single_features(tokenizer, question, context, max_len=512, device=DEVICE):
210
+ enc = tokenizer(question, context, return_tensors="pt",
211
+ return_offsets_mapping=True, truncation=True, max_length=max_len)
212
+ input_ids = enc["input_ids"].to(device)
213
+ attn_mask = enc["attention_mask"].to(device)
214
+ type_ids = enc.get("token_type_ids")
215
+ if type_ids is None:
216
+ type_ids = torch.zeros_like(input_ids)
217
+ type_ids = type_ids.to(device)
218
+
219
+ # get context token range using sequence_ids
220
+ seq_ids = tokenizer(question, context, return_offsets_mapping=True,
221
+ truncation=True, max_length=max_len).sequence_ids(0)
222
+ ctx_start = seq_ids.index(1) # first context token
223
+ ctx_end = len(seq_ids) - 1
224
+ for i in range(ctx_start, len(seq_ids)):
225
+ if seq_ids[i] is None:
226
+ ctx_end = i - 1
227
+ break
228
+ return input_ids, type_ids, attn_mask, (ctx_start, ctx_end), enc["offset_mapping"][0].tolist()
229
+
230
+ if __name__ == "__main__":
231
+
232
+ REPO_ID = "ivabojic/deberta-v3-base_MultiSpanQA"
233
+ BASE = "microsoft/deberta-v3-base"
234
+ TOKENIZER = AutoTokenizer.from_pretrained(BASE)
235
+
236
+ model = SpanQualifier(
237
+ base_model_name=BASE,
238
+ max_span_gap=8, # adjust per domain
239
+ dim2=64,
240
+ max_len=512,
241
+ vanilla=False,
242
+ force_answer=False,
243
+ ).to(DEVICE)
244
+ model.eval()
245
+
246
+ # --- Download from the Hub ---
247
+ ckpt_path = hf_hub_download(repo_id=REPO_ID, filename="pytorch_model.bin")
248
+ state = torch.load(ckpt_path, map_location="cpu")
249
+ state = state.get("model_state_dict", state) # handle both formats
250
+ missing, unexpected = model.load_state_dict(state, strict=False)
251
+
252
+ # --- Example Q & context ---
253
+ question = "Who sang it's my party and i'll cry if i want to in the eighties?"
254
+ context = (
255
+ "In 1981, a remake by British artists Dave Stewart and Barbara Gaskin "
256
+ "was a UK number one hit single for four weeks and was also a major hit "
257
+ "in Austria (#3), Germany (#3), the Netherlands (#20), New Zealand (#1), "
258
+ "South Africa (#3) and Switzerland (#6). The track reached #72 in the US. "
259
+ "This was the first version of the song to reach #1 in the UK. The video "
260
+ "for the Stewart/Gaskin version contained a cameo by Thomas Dolby as Johnny, "
261
+ "Judy being played by Gaskin in a blond wig."
262
+ )
263
+
264
+ # --- Build features & run ---
265
+ input_ids, type_ids, attn_mask, ctx_range, offsets = build_single_features(
266
+ TOKENIZER, question, context, max_len=512, device=DEVICE
267
+ )
268
+
269
+ with torch.no_grad():
270
+ spans_rel = model(
271
+ input_ids=input_ids,
272
+ type_ids=type_ids,
273
+ mask_ids=attn_mask,
274
+ context_ranges=[ctx_range],
275
+ )[0]
276
+
277
+ # map relative spans back to text using absolute offsets
278
+ ctx_start, ctx_end = ctx_range
279
+ answers = []
280
+ for beg_rel, end_rel in spans_rel:
281
+ beg_abs = ctx_start + beg_rel
282
+ end_abs = ctx_start + end_rel
283
+ s_char = offsets[beg_abs][0]
284
+ e_char = offsets[end_abs][1]
285
+ answers.append(context[s_char:e_char].strip())
286
+
287
+ print("\nPredicted spans:")
288
+ for i, a in enumerate(answers, 1):
289
+ print(f"{i}. {a}")
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8f1ba2f8500f0591e4364a32eaebb1363064b4aaf9877c2bb51d97e922ff1ce9
3
+ size 2277348140