Upload 2 files
Browse files- inference_spanqualifier_hf.py +289 -0
- pytorch_model.bin +3 -0
inference_spanqualifier_hf.py
ADDED
|
@@ -0,0 +1,289 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
from transformers import AutoTokenizer, AutoModel
|
| 4 |
+
from huggingface_hub import hf_hub_download
|
| 5 |
+
|
| 6 |
+
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
|
| 7 |
+
|
| 8 |
+
class MLP(nn.Module):
|
| 9 |
+
def __init__(self, dim0, dim1):
|
| 10 |
+
super().__init__()
|
| 11 |
+
self.linear1 = nn.Linear(dim0, dim0)
|
| 12 |
+
self.linear2 = nn.Linear(dim0, dim1)
|
| 13 |
+
self.activate = nn.ReLU()
|
| 14 |
+
def forward(self, x):
|
| 15 |
+
return self.linear2(self.activate(self.linear1(x)))
|
| 16 |
+
|
| 17 |
+
class SpanInteraction(nn.Module):
|
| 18 |
+
def __init__(self, dim2):
|
| 19 |
+
super().__init__()
|
| 20 |
+
self.conv = nn.Conv2d(dim2, dim2, kernel_size=(5,5), padding=(2,2))
|
| 21 |
+
def forward(self, hM):
|
| 22 |
+
hM = hM.permute(0,3,1,2)
|
| 23 |
+
hM = self.conv(hM)
|
| 24 |
+
return hM.permute(0,2,3,1)
|
| 25 |
+
|
| 26 |
+
class SpanEnumeration(nn.Module):
|
| 27 |
+
def __init__(self, dim1, dim2, max_len):
|
| 28 |
+
super().__init__()
|
| 29 |
+
self.s_mapping = nn.Linear(dim1, dim2)
|
| 30 |
+
self.e_mapping = nn.Linear(dim1, dim2) # <-- correct mapping for ends
|
| 31 |
+
self.pos_embedding = nn.Embedding(max_len, dim2)
|
| 32 |
+
self.layer_norm = nn.LayerNorm(dim2, eps=1e-12)
|
| 33 |
+
pos_id = []
|
| 34 |
+
for i in range(max_len):
|
| 35 |
+
for j in range(max_len):
|
| 36 |
+
pos_id.append(abs(j - i))
|
| 37 |
+
self.register_buffer("pos_id", torch.tensor(pos_id, dtype=torch.long))
|
| 38 |
+
self.dim2 = dim2
|
| 39 |
+
self.max_len = max_len
|
| 40 |
+
def forward(self, B_s, B_e):
|
| 41 |
+
bs, seq_len, _ = B_s.size()
|
| 42 |
+
pos_embedding = self.pos_embedding(self.pos_id).view(self.max_len, self.max_len, self.dim2)
|
| 43 |
+
pos_embedding = pos_embedding[:seq_len, :seq_len, :].reshape(seq_len, seq_len, self.dim2)
|
| 44 |
+
pos_embedding = pos_embedding.unsqueeze(0).expand(bs, seq_len, seq_len, self.dim2)
|
| 45 |
+
B_s = self.s_mapping(B_s)
|
| 46 |
+
B_e = self.s_mapping(B_e)
|
| 47 |
+
B_s_ex = B_s.unsqueeze(2).expand(bs, seq_len, seq_len, self.dim2)
|
| 48 |
+
B_e_ex = B_e.unsqueeze(2).expand(bs, seq_len, seq_len, self.dim2).transpose(1, 2)
|
| 49 |
+
N = B_s_ex + B_e_ex + pos_embedding
|
| 50 |
+
return self.layer_norm(N)
|
| 51 |
+
|
| 52 |
+
class BoundaryEnumeration(nn.Module):
|
| 53 |
+
def __init__(self, dim):
|
| 54 |
+
super().__init__()
|
| 55 |
+
self.s_boundary_enum = MLP(dim, dim)
|
| 56 |
+
self.e_boundary_enum = MLP(dim, dim)
|
| 57 |
+
def forward(self, H_c):
|
| 58 |
+
return self.s_boundary_enum(H_c), self.e_boundary_enum(H_c)
|
| 59 |
+
|
| 60 |
+
class SpanScoring(nn.Module):
|
| 61 |
+
def __init__(self, dim1, dim2, max_len, max_span_gap):
|
| 62 |
+
super().__init__()
|
| 63 |
+
self.mlp_scoring = MLP(dim2, 1)
|
| 64 |
+
self.mlp_cls = MLP(dim1, 1)
|
| 65 |
+
tri = []
|
| 66 |
+
for i in range(max_len):
|
| 67 |
+
for j in range(max_len):
|
| 68 |
+
tri.append(1 if (i <= j and (j - i) <= max_span_gap) else 0)
|
| 69 |
+
self.register_buffer("masks_triangle", torch.tensor(tri, dtype=torch.float).view(max_len, max_len))
|
| 70 |
+
def forward(self, M, H_cls, masks):
|
| 71 |
+
S = self.mlp_scoring(M).view(M.size(0), M.size(1), M.size(2))
|
| 72 |
+
qs = self.mlp_cls(H_cls) # (bs, 1)
|
| 73 |
+
bs, seq_len = S.size(0), S.size(1)
|
| 74 |
+
masks_ex = masks.unsqueeze(1).expand(bs, seq_len, seq_len)
|
| 75 |
+
masks_ex_t = masks_ex.transpose(1, 2)
|
| 76 |
+
masks_ex = masks_ex * masks_ex_t
|
| 77 |
+
tri = self.masks_triangle[:seq_len, :seq_len].unsqueeze(0).expand(bs, seq_len, seq_len)
|
| 78 |
+
S = S - 10000.0 * (1 - masks_ex * tri)
|
| 79 |
+
return S, qs
|
| 80 |
+
|
| 81 |
+
class BoundaryAggregation(nn.Module):
|
| 82 |
+
def __init__(self, dim1, dim2, max_len, max_span_gap):
|
| 83 |
+
super().__init__()
|
| 84 |
+
self.span_enum_s = SpanEnumeration(dim1, dim2, max_len)
|
| 85 |
+
self.span_enum_e = SpanEnumeration(dim1, dim2, max_len)
|
| 86 |
+
self.span_scoring_s = SpanScoring(dim1, dim2, max_len, max_span_gap)
|
| 87 |
+
self.span_scoring_e = SpanScoring(dim1, dim2, max_len, max_span_gap)
|
| 88 |
+
self.W2_s = nn.Linear(dim1, dim1)
|
| 89 |
+
self.W2_e = nn.Linear(dim1, dim1)
|
| 90 |
+
self.span_interaction_s = SpanInteraction(dim2)
|
| 91 |
+
self.span_interaction_e = SpanInteraction(dim2)
|
| 92 |
+
def forward(self, hB_s, hB_e, H_cls, masks):
|
| 93 |
+
bs, seq_len, dim = hB_s.size()
|
| 94 |
+
M_s = self.span_enum_s(hB_s, hB_e)
|
| 95 |
+
M_s = self.span_interaction_s(M_s)
|
| 96 |
+
G_s, qs_s = self.span_scoring_s(M_s, H_cls, masks)
|
| 97 |
+
G_s_soft = torch.softmax(G_s, dim=-1)
|
| 98 |
+
B_s = torch.matmul(G_s_soft, self.W2_s(hB_s)).view(bs, seq_len, dim)
|
| 99 |
+
M_e = self.span_enum_e(hB_s, hB_e)
|
| 100 |
+
M_e = self.span_interaction_e(M_e)
|
| 101 |
+
G_e, qs_e = self.span_scoring_e(M_e, H_cls, masks)
|
| 102 |
+
G_e_soft = torch.softmax(G_e.transpose(-2, -1), dim=-1)
|
| 103 |
+
B_e = torch.matmul(G_e_soft, self.W2_e(hB_e)).view(bs, seq_len, dim)
|
| 104 |
+
return B_s, B_e, G_s, G_e, qs_s, qs_e
|
| 105 |
+
|
| 106 |
+
class SpanRepresentation(nn.Module):
|
| 107 |
+
def __init__(self, dim1, dim2, max_len, max_span_gap, vanilla=False):
|
| 108 |
+
super().__init__()
|
| 109 |
+
self.span_enum = SpanEnumeration(dim1, dim2, max_len)
|
| 110 |
+
self.span_interaction = SpanInteraction(dim2)
|
| 111 |
+
self.vanilla = vanilla
|
| 112 |
+
tri = []
|
| 113 |
+
for i in range(max_len):
|
| 114 |
+
for j in range(max_len):
|
| 115 |
+
tri.append(1 if (i <= j and (j - i) <= max_span_gap) else 0)
|
| 116 |
+
self.register_buffer("masks_triangle", torch.tensor(tri, dtype=torch.float).view(max_len, max_len))
|
| 117 |
+
def forward(self, B_s, B_e, masks):
|
| 118 |
+
M = self.span_enum(B_s, B_e)
|
| 119 |
+
bs, seq_len, _ = B_s.size()
|
| 120 |
+
masks_c_ex = masks.unsqueeze(1).expand(bs, seq_len, seq_len)
|
| 121 |
+
masks_c_ex_t = masks_c_ex.transpose(1, 2)
|
| 122 |
+
masks_c_ex = masks_c_ex * masks_c_ex_t
|
| 123 |
+
tri = self.masks_triangle[:seq_len, :seq_len].unsqueeze(0).expand(bs, seq_len, seq_len)
|
| 124 |
+
M = M * (masks_c_ex * tri).unsqueeze(3)
|
| 125 |
+
if not self.vanilla:
|
| 126 |
+
M = self.span_interaction(M)
|
| 127 |
+
return M
|
| 128 |
+
|
| 129 |
+
class BoundaryRepresentation(nn.Module):
|
| 130 |
+
def __init__(self, dim1, dim2, max_len, max_span_gap, vanilla=False):
|
| 131 |
+
super().__init__()
|
| 132 |
+
self.boundary_enum = BoundaryEnumeration(dim1)
|
| 133 |
+
self.vanilla = vanilla
|
| 134 |
+
self.boundary_aggregation = BoundaryAggregation(dim1, dim2, max_len, max_span_gap)
|
| 135 |
+
def forward(self, H_c, H_cls, masks):
|
| 136 |
+
B_s, B_e = self.boundary_enum(H_c)
|
| 137 |
+
G_s = G_e = qs_s = qs_e = None
|
| 138 |
+
if not self.vanilla:
|
| 139 |
+
B_s, B_e, G_s, G_e, qs_s, qs_e = self.boundary_aggregation(B_s, B_e, H_cls, masks)
|
| 140 |
+
return B_s, B_e, G_s, G_e, qs_s, qs_e
|
| 141 |
+
|
| 142 |
+
class SpanQualifier(nn.Module):
|
| 143 |
+
def __init__(self, base_model_name, max_span_gap=15, dim2=64, max_len=512, vanilla=False, force_answer=False):
|
| 144 |
+
super().__init__()
|
| 145 |
+
self.token_representation = AutoModel.from_pretrained(base_model_name)
|
| 146 |
+
dim1 = self.token_representation.config.hidden_size
|
| 147 |
+
self.boundary_representation = BoundaryRepresentation(dim1, dim2, max_len, max_span_gap, vanilla)
|
| 148 |
+
self.span_representation = SpanRepresentation(dim1, dim2, max_len, max_span_gap, vanilla)
|
| 149 |
+
self.span_scoring = SpanScoring(dim1, dim2, max_len, max_span_gap)
|
| 150 |
+
self.force_answer = force_answer
|
| 151 |
+
def forward(self, input_ids, type_ids, mask_ids, context_ranges):
|
| 152 |
+
outputs = self.token_representation(
|
| 153 |
+
input_ids=input_ids, attention_mask=mask_ids, token_type_ids=type_ids,
|
| 154 |
+
output_hidden_states=True, return_dict=True
|
| 155 |
+
)
|
| 156 |
+
sequence_output = outputs.hidden_states[-1] # (bs, L, H)
|
| 157 |
+
H_cls = sequence_output[:, 0, :]
|
| 158 |
+
H_c, masks = split_sequence_like(sequence_output, context_ranges)
|
| 159 |
+
B_s, B_e, G_s, G_e, qs_s, qs_e = self.boundary_representation(H_c, H_cls, masks)
|
| 160 |
+
M = self.span_representation(B_s, B_e, masks)
|
| 161 |
+
S, qs_ext = self.span_scoring(M, H_cls, masks)
|
| 162 |
+
spans, _ = self.decoding_span_matrix(S, qs_ext)
|
| 163 |
+
return spans
|
| 164 |
+
def decoding_span_matrix(self, logits_matrix, threshold_p, spans_matrix_mask=None):
|
| 165 |
+
bs, seq_len, _ = logits_matrix.size()
|
| 166 |
+
if spans_matrix_mask is not None:
|
| 167 |
+
logits_matrix = logits_matrix - 10000.0 * spans_matrix_mask
|
| 168 |
+
logits_end = torch.softmax(logits_matrix, dim=2)
|
| 169 |
+
_, idx_best_end = torch.max(logits_end, dim=2)
|
| 170 |
+
idx_best_end = idx_best_end.cpu().tolist()
|
| 171 |
+
threshold_p = threshold_p.view(bs).cpu().tolist()
|
| 172 |
+
logits_beg = torch.softmax(logits_matrix, dim=1)
|
| 173 |
+
_, idx_best_beg = torch.max(logits_beg, dim=1)
|
| 174 |
+
idx_best_beg = idx_best_beg.cpu().tolist()
|
| 175 |
+
logits_matrix = logits_matrix.cpu().tolist()
|
| 176 |
+
spans = []
|
| 177 |
+
for b_i, (matrix, t_p) in enumerate(zip(logits_matrix, threshold_p)):
|
| 178 |
+
spans_item = []
|
| 179 |
+
max_logit, max_i, max_j = -10000, 0, 0
|
| 180 |
+
for i, row in enumerate(matrix):
|
| 181 |
+
for j, logit in enumerate(row):
|
| 182 |
+
if i <= j and idx_best_end[b_i][i] == j and idx_best_beg[b_i][j] == i:
|
| 183 |
+
if logit > t_p:
|
| 184 |
+
spans_item.append([i, j])
|
| 185 |
+
if logit > max_logit:
|
| 186 |
+
max_logit, max_i, max_j = logit, i, j
|
| 187 |
+
if len(spans_item) == 0 and self.force_answer:
|
| 188 |
+
spans_item.append([max_i, max_j])
|
| 189 |
+
spans.append(spans_item)
|
| 190 |
+
return spans, None
|
| 191 |
+
|
| 192 |
+
def split_sequence_like(sequence_output, context_ranges):
|
| 193 |
+
"""Packs context tokens to the front (like your split_sequence with useSep=False),
|
| 194 |
+
returns padded H_c and context masks."""
|
| 195 |
+
bs, L, H = sequence_output.size()
|
| 196 |
+
H_c_batch, masks_batch = [], []
|
| 197 |
+
for b in range(bs):
|
| 198 |
+
c_beg, c_end = context_ranges[b]
|
| 199 |
+
ctx = sequence_output[b, c_beg:c_end+1, :]
|
| 200 |
+
Lc = ctx.size(0)
|
| 201 |
+
H_c_pad = torch.zeros(L, H, device=sequence_output.device)
|
| 202 |
+
H_c_pad[:Lc] = ctx
|
| 203 |
+
mask = torch.zeros(L, device=sequence_output.device)
|
| 204 |
+
mask[:Lc] = 1
|
| 205 |
+
H_c_batch.append(H_c_pad)
|
| 206 |
+
masks_batch.append(mask)
|
| 207 |
+
return torch.stack(H_c_batch, 0), torch.stack(masks_batch, 0)
|
| 208 |
+
|
| 209 |
+
def build_single_features(tokenizer, question, context, max_len=512, device=DEVICE):
|
| 210 |
+
enc = tokenizer(question, context, return_tensors="pt",
|
| 211 |
+
return_offsets_mapping=True, truncation=True, max_length=max_len)
|
| 212 |
+
input_ids = enc["input_ids"].to(device)
|
| 213 |
+
attn_mask = enc["attention_mask"].to(device)
|
| 214 |
+
type_ids = enc.get("token_type_ids")
|
| 215 |
+
if type_ids is None:
|
| 216 |
+
type_ids = torch.zeros_like(input_ids)
|
| 217 |
+
type_ids = type_ids.to(device)
|
| 218 |
+
|
| 219 |
+
# get context token range using sequence_ids
|
| 220 |
+
seq_ids = tokenizer(question, context, return_offsets_mapping=True,
|
| 221 |
+
truncation=True, max_length=max_len).sequence_ids(0)
|
| 222 |
+
ctx_start = seq_ids.index(1) # first context token
|
| 223 |
+
ctx_end = len(seq_ids) - 1
|
| 224 |
+
for i in range(ctx_start, len(seq_ids)):
|
| 225 |
+
if seq_ids[i] is None:
|
| 226 |
+
ctx_end = i - 1
|
| 227 |
+
break
|
| 228 |
+
return input_ids, type_ids, attn_mask, (ctx_start, ctx_end), enc["offset_mapping"][0].tolist()
|
| 229 |
+
|
| 230 |
+
if __name__ == "__main__":
|
| 231 |
+
|
| 232 |
+
REPO_ID = "ivabojic/deberta-v3-base_MultiSpanQA"
|
| 233 |
+
BASE = "microsoft/deberta-v3-base"
|
| 234 |
+
TOKENIZER = AutoTokenizer.from_pretrained(BASE)
|
| 235 |
+
|
| 236 |
+
model = SpanQualifier(
|
| 237 |
+
base_model_name=BASE,
|
| 238 |
+
max_span_gap=8, # adjust per domain
|
| 239 |
+
dim2=64,
|
| 240 |
+
max_len=512,
|
| 241 |
+
vanilla=False,
|
| 242 |
+
force_answer=False,
|
| 243 |
+
).to(DEVICE)
|
| 244 |
+
model.eval()
|
| 245 |
+
|
| 246 |
+
# --- Download from the Hub ---
|
| 247 |
+
ckpt_path = hf_hub_download(repo_id=REPO_ID, filename="pytorch_model.bin")
|
| 248 |
+
state = torch.load(ckpt_path, map_location="cpu")
|
| 249 |
+
state = state.get("model_state_dict", state) # handle both formats
|
| 250 |
+
missing, unexpected = model.load_state_dict(state, strict=False)
|
| 251 |
+
|
| 252 |
+
# --- Example Q & context ---
|
| 253 |
+
question = "Who sang it's my party and i'll cry if i want to in the eighties?"
|
| 254 |
+
context = (
|
| 255 |
+
"In 1981, a remake by British artists Dave Stewart and Barbara Gaskin "
|
| 256 |
+
"was a UK number one hit single for four weeks and was also a major hit "
|
| 257 |
+
"in Austria (#3), Germany (#3), the Netherlands (#20), New Zealand (#1), "
|
| 258 |
+
"South Africa (#3) and Switzerland (#6). The track reached #72 in the US. "
|
| 259 |
+
"This was the first version of the song to reach #1 in the UK. The video "
|
| 260 |
+
"for the Stewart/Gaskin version contained a cameo by Thomas Dolby as Johnny, "
|
| 261 |
+
"Judy being played by Gaskin in a blond wig."
|
| 262 |
+
)
|
| 263 |
+
|
| 264 |
+
# --- Build features & run ---
|
| 265 |
+
input_ids, type_ids, attn_mask, ctx_range, offsets = build_single_features(
|
| 266 |
+
TOKENIZER, question, context, max_len=512, device=DEVICE
|
| 267 |
+
)
|
| 268 |
+
|
| 269 |
+
with torch.no_grad():
|
| 270 |
+
spans_rel = model(
|
| 271 |
+
input_ids=input_ids,
|
| 272 |
+
type_ids=type_ids,
|
| 273 |
+
mask_ids=attn_mask,
|
| 274 |
+
context_ranges=[ctx_range],
|
| 275 |
+
)[0]
|
| 276 |
+
|
| 277 |
+
# map relative spans back to text using absolute offsets
|
| 278 |
+
ctx_start, ctx_end = ctx_range
|
| 279 |
+
answers = []
|
| 280 |
+
for beg_rel, end_rel in spans_rel:
|
| 281 |
+
beg_abs = ctx_start + beg_rel
|
| 282 |
+
end_abs = ctx_start + end_rel
|
| 283 |
+
s_char = offsets[beg_abs][0]
|
| 284 |
+
e_char = offsets[end_abs][1]
|
| 285 |
+
answers.append(context[s_char:e_char].strip())
|
| 286 |
+
|
| 287 |
+
print("\nPredicted spans:")
|
| 288 |
+
for i, a in enumerate(answers, 1):
|
| 289 |
+
print(f"{i}. {a}")
|
pytorch_model.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8f1ba2f8500f0591e4364a32eaebb1363064b4aaf9877c2bb51d97e922ff1ce9
|
| 3 |
+
size 2277348140
|