# deberta-v3-base_MultiSpanQA / inference_spanqualifier_hf.py
# Uploaded by ivabojic ("Upload 2 files", commit e30b0b5, verified).
import torch
from torch import nn
from transformers import AutoTokenizer, AutoModel
from huggingface_hub import hf_hub_download
# Run on the first CUDA device when one is available, otherwise on the CPU.
DEVICE = "cpu" if not torch.cuda.is_available() else "cuda:0"
class MLP(nn.Module):
    """Two-layer perceptron: Linear(dim0 -> dim0), ReLU, Linear(dim0 -> dim1)."""

    def __init__(self, dim0, dim1):
        super().__init__()
        # Attribute names become the state_dict keys matched by load_state_dict;
        # keep them stable.
        self.linear1 = nn.Linear(dim0, dim0)
        self.linear2 = nn.Linear(dim0, dim1)
        self.activate = nn.ReLU()

    def forward(self, x):
        hidden = self.linear1(x)
        hidden = self.activate(hidden)
        return self.linear2(hidden)
class SpanInteraction(nn.Module):
    """Mix neighbouring span representations with a 5x5 same-padding convolution.

    Input and output are channels-last 4-D tensors (as used in this file:
    (bs, L, L, dim2)); the convolution treats the two span axes as a 2-D image.
    """

    def __init__(self, dim2):
        super().__init__()
        self.conv = nn.Conv2d(dim2, dim2, kernel_size=(5, 5), padding=(2, 2))

    def forward(self, hM):
        channels_first = hM.movedim(-1, 1)  # (bs, L, L, C) -> (bs, C, L, L)
        mixed = self.conv(channels_first)
        return mixed.movedim(1, -1)  # back to (bs, L, L, C)
class SpanEnumeration(nn.Module):
    """Build a representation for every (start, end) token pair.

    ``forward(B_s, B_e)`` takes two (bs, L, dim1) tensors and returns a
    (bs, L, L, dim2) tensor whose cell [b, i, j] is
    ``LayerNorm(s_mapping(B_s[b, i]) + s_mapping(B_e[b, j]) + pos_embedding(|j - i|))``.

    NOTE(review): ``e_mapping`` is created (so it is part of the state_dict) but
    ``forward`` projects BOTH boundaries with ``s_mapping``. Whether that matches
    how the published checkpoint was trained cannot be told from this file --
    confirm against the training code before switching ends to ``e_mapping``,
    since doing so changes every prediction.
    """

    def __init__(self, dim1, dim2, max_len):
        super().__init__()
        self.s_mapping = nn.Linear(dim1, dim2)
        self.e_mapping = nn.Linear(dim1, dim2)  # currently unused by forward (see NOTE above)
        self.pos_embedding = nn.Embedding(max_len, dim2)
        self.layer_norm = nn.LayerNorm(dim2, eps=1e-12)
        # Flattened max_len x max_len table of relative distances |col - row|.
        distances = [abs(col - row) for row in range(max_len) for col in range(max_len)]
        self.register_buffer("pos_id", torch.tensor(distances, dtype=torch.long))
        self.dim2 = dim2
        self.max_len = max_len

    def forward(self, B_s, B_e):
        bs, seq_len, _ = B_s.size()
        width = self.dim2
        # Relative-distance embeddings for every (i, j) pair, cropped to seq_len.
        rel = self.pos_embedding(self.pos_id).view(self.max_len, self.max_len, width)
        rel = rel[:seq_len, :seq_len, :].unsqueeze(0).expand(bs, seq_len, seq_len, width)
        # [b, i, j] <- projected start i
        starts = self.s_mapping(B_s).unsqueeze(2).expand(bs, seq_len, seq_len, width)
        # [b, i, j] <- projected end j (same projection as starts; see NOTE above)
        ends = self.s_mapping(B_e).unsqueeze(1).expand(bs, seq_len, seq_len, width)
        return self.layer_norm(starts + ends + rel)
class BoundaryEnumeration(nn.Module):
    """Project token states into separate start- and end-boundary spaces.

    ``forward(H_c)`` returns ``(start_states, end_states)``, each produced by its
    own MLP of width ``dim``.
    """

    def __init__(self, dim):
        super().__init__()
        self.s_boundary_enum = MLP(dim, dim)
        self.e_boundary_enum = MLP(dim, dim)

    def forward(self, H_c):
        starts = self.s_boundary_enum(H_c)
        ends = self.e_boundary_enum(H_c)
        return starts, ends
class SpanScoring(nn.Module):
    """Score span representations and produce a per-example threshold logit.

    forward(M, H_cls, masks):
        M      -- (bs, L, L, dim2) span representations (rows = start, cols = end)
        H_cls  -- [CLS] state fed to ``mlp_cls``
        masks  -- (bs, L) 0/1 token mask
    Returns ``(S, qs)``: ``S`` is the (bs, L, L) span-logit matrix with 10000
    subtracted from every cell that is not (both ends unmasked AND
    0 <= end - start <= max_span_gap); ``qs`` is ``mlp_cls(H_cls)``.
    """

    def __init__(self, dim1, dim2, max_len, max_span_gap):
        super().__init__()
        self.mlp_scoring = MLP(dim2, 1)
        self.mlp_cls = MLP(dim1, 1)
        # allowed[row, col] == 1 iff col is at most max_span_gap tokens after row.
        allowed = [
            1 if 0 <= col - row <= max_span_gap else 0
            for row in range(max_len)
            for col in range(max_len)
        ]
        self.register_buffer(
            "masks_triangle",
            torch.tensor(allowed, dtype=torch.float).view(max_len, max_len),
        )

    def forward(self, M, H_cls, masks):
        seq_len = M.size(1)
        scores = self.mlp_scoring(M).squeeze(-1)  # (bs, L, L)
        threshold = self.mlp_cls(H_cls)
        # [b, i, j] = masks[b, i] * masks[b, j]: both span ends must be real tokens.
        pair_mask = masks.unsqueeze(2) * masks.unsqueeze(1)
        valid = pair_mask * self.masks_triangle[:seq_len, :seq_len]
        scores = scores - 10000.0 * (1 - valid)
        return scores, threshold
class BoundaryAggregation(nn.Module):
    """Refine start/end boundary states with soft attention over candidate spans.

    Two mirrored branches each enumerate all (start, end) pairs, mix them with a
    convolution, score them, and use the softmax of those scores to pool a
    linearly projected copy of that branch's own boundary states.

    Returns ``(B_s, B_e, G_s, G_e, qs_s, qs_e)``: refined start/end states, the
    span-logit matrix of each branch, and each branch's threshold output.
    """

    def __init__(self, dim1, dim2, max_len, max_span_gap):
        super().__init__()
        # Keep this construction order: it determines initialisation order and
        # the order of entries in the state_dict.
        self.span_enum_s = SpanEnumeration(dim1, dim2, max_len)
        self.span_enum_e = SpanEnumeration(dim1, dim2, max_len)
        self.span_scoring_s = SpanScoring(dim1, dim2, max_len, max_span_gap)
        self.span_scoring_e = SpanScoring(dim1, dim2, max_len, max_span_gap)
        self.W2_s = nn.Linear(dim1, dim1)
        self.W2_e = nn.Linear(dim1, dim1)
        self.span_interaction_s = SpanInteraction(dim2)
        self.span_interaction_e = SpanInteraction(dim2)

    def forward(self, hB_s, hB_e, H_cls, masks):
        bs, seq_len, dim = hB_s.size()

        # Start branch: rows of G_s are starts i, columns are ends j, so each
        # start attends over its candidate ends.
        span_states_s = self.span_interaction_s(self.span_enum_s(hB_s, hB_e))
        G_s, qs_s = self.span_scoring_s(span_states_s, H_cls, masks)
        weights_s = G_s.softmax(dim=-1)
        B_s = (weights_s @ self.W2_s(hB_s)).view(bs, seq_len, dim)

        # End branch: transpose so rows are ends j; each end attends over its
        # candidate starts.
        span_states_e = self.span_interaction_e(self.span_enum_e(hB_s, hB_e))
        G_e, qs_e = self.span_scoring_e(span_states_e, H_cls, masks)
        weights_e = G_e.transpose(-2, -1).softmax(dim=-1)
        B_e = (weights_e @ self.W2_e(hB_e)).view(bs, seq_len, dim)

        return B_s, B_e, G_s, G_e, qs_s, qs_e
class SpanRepresentation(nn.Module):
    """Enumerate span representations and zero out disallowed spans.

    A span (i, j) is kept only when both ends are unmasked and
    0 <= j - i <= max_span_gap; all other cells of the (bs, L, L, dim2) tensor
    are multiplied by 0. Unless ``vanilla`` is set, the result is then mixed
    by ``SpanInteraction``.
    """

    def __init__(self, dim1, dim2, max_len, max_span_gap, vanilla=False):
        super().__init__()
        self.span_enum = SpanEnumeration(dim1, dim2, max_len)
        self.span_interaction = SpanInteraction(dim2)
        self.vanilla = vanilla
        allowed = [
            1 if 0 <= col - row <= max_span_gap else 0
            for row in range(max_len)
            for col in range(max_len)
        ]
        self.register_buffer(
            "masks_triangle",
            torch.tensor(allowed, dtype=torch.float).view(max_len, max_len),
        )

    def forward(self, B_s, B_e, masks):
        M = self.span_enum(B_s, B_e)
        seq_len = B_s.size(1)
        # [b, i, j] = masks[b, i] * masks[b, j] * triangle[i, j]
        keep = masks.unsqueeze(2) * masks.unsqueeze(1) * self.masks_triangle[:seq_len, :seq_len]
        M = M * keep.unsqueeze(3)
        return M if self.vanilla else self.span_interaction(M)
class BoundaryRepresentation(nn.Module):
    """Produce start/end boundary states, optionally refined by BoundaryAggregation.

    Returns ``(B_s, B_e, G_s, G_e, qs_s, qs_e)``. With ``vanilla=True`` only the
    boundary MLPs run and the last four entries are ``None``.
    """

    def __init__(self, dim1, dim2, max_len, max_span_gap, vanilla=False):
        super().__init__()
        self.boundary_enum = BoundaryEnumeration(dim1)
        self.vanilla = vanilla
        self.boundary_aggregation = BoundaryAggregation(dim1, dim2, max_len, max_span_gap)

    def forward(self, H_c, H_cls, masks):
        starts, ends = self.boundary_enum(H_c)
        if self.vanilla:
            return starts, ends, None, None, None, None
        return self.boundary_aggregation(starts, ends, H_cls, masks)
class SpanQualifier(nn.Module):
    """SpanQualifier multi-span extractor on top of a Hugging Face encoder.

    Pipeline in ``forward``: encoder -> context packing -> boundary
    representation -> span representation -> span scoring -> span decoding.
    """

    def __init__(self, base_model_name, max_span_gap=15, dim2=64, max_len=512, vanilla=False, force_answer=False):
        """
        Args:
            base_model_name: name/path passed to ``AutoModel.from_pretrained``.
            max_span_gap: largest allowed ``end - start`` (in tokens) for a span.
            dim2: width of each per-span representation.
            max_len: size of the pairwise position/mask tables; presumably must be
                >= the encoded sequence length, since those tables are sliced to it.
            vanilla: skip boundary aggregation and span interaction.
            force_answer: when no span clears the threshold, fall back to the
                highest-scoring candidate span.
        """
        super().__init__()
        self.token_representation = AutoModel.from_pretrained(base_model_name)
        dim1 = self.token_representation.config.hidden_size
        self.boundary_representation = BoundaryRepresentation(dim1, dim2, max_len, max_span_gap, vanilla)
        self.span_representation = SpanRepresentation(dim1, dim2, max_len, max_span_gap, vanilla)
        self.span_scoring = SpanScoring(dim1, dim2, max_len, max_span_gap)
        self.force_answer = force_answer

    def forward(self, input_ids, type_ids, mask_ids, context_ranges):
        """Run the model and return decoded spans.

        Args:
            input_ids, type_ids, mask_ids: encoder inputs (ids, token types,
                attention mask).
            context_ranges: one ``(c_beg, c_end)`` pair per example, both
                inclusive token indices of the context inside ``input_ids``.

        Returns:
            A list (one per example) of ``[start, end]`` token-index pairs,
            relative to the start of the packed context (i.e. offset by c_beg).
        """
        outputs = self.token_representation(
            input_ids=input_ids, attention_mask=mask_ids, token_type_ids=type_ids,
            output_hidden_states=True, return_dict=True
        )
        sequence_output = outputs.hidden_states[-1]  # (bs, L, H)
        H_cls = sequence_output[:, 0, :]
        H_c, masks = split_sequence_like(sequence_output, context_ranges)
        B_s, B_e, G_s, G_e, qs_s, qs_e = self.boundary_representation(H_c, H_cls, masks)
        M = self.span_representation(B_s, B_e, masks)
        S, qs_ext = self.span_scoring(M, H_cls, masks)
        spans, _ = self.decoding_span_matrix(S, qs_ext)
        return spans

    def decoding_span_matrix(self, logits_matrix, threshold_p, spans_matrix_mask=None):
        """Decode answer spans from a (bs, L, L) span-logit matrix.

        A cell (i, j) is a candidate when i <= j, j is the best end for start i
        (argmax over row i) AND i is the best start for end j (argmax over
        column j). Candidates whose logit exceeds the example's threshold are
        returned; with ``force_answer`` an example with none gets its
        highest-logit candidate instead (or [0, 0] if no candidate beats -10000).

        Args:
            logits_matrix: (bs, L, L) span logits, rows = start, cols = end.
            threshold_p: per-example thresholds, reshaped to (bs,).
            spans_matrix_mask: optional (bs, L, L) tensor; 10000 * mask is
                subtracted from the logits before decoding.

        Returns:
            ``(spans, None)`` where ``spans[b]`` is a list of ``[i, j]`` pairs in
            increasing start order.
        """
        bs, seq_len, _ = logits_matrix.size()
        if spans_matrix_mask is not None:
            logits_matrix = logits_matrix - 10000.0 * spans_matrix_mask
        # softmax + torch.max (not argmax on raw logits) so tie-breaking matches
        # the reference decoder exactly.
        _, best_end = torch.max(torch.softmax(logits_matrix, dim=2), dim=2)  # (bs, L): per start i
        _, best_beg = torch.max(torch.softmax(logits_matrix, dim=1), dim=1)  # (bs, L): per end j
        starts = torch.arange(seq_len, device=logits_matrix.device).unsqueeze(0).expand(bs, seq_len)
        # Each start row has exactly one row-wise best end, so there is at most one
        # candidate per row: O(L) work instead of scanning all L*L cells in Python.
        is_candidate = (starts <= best_end) & (torch.gather(best_beg, 1, best_end) == starts)
        cand_logits = torch.gather(logits_matrix, 2, best_end.unsqueeze(2)).squeeze(2)

        best_end = best_end.cpu().tolist()
        is_candidate = is_candidate.cpu().tolist()
        cand_logits = cand_logits.cpu().tolist()
        thresholds = threshold_p.view(bs).cpu().tolist()

        spans = []
        for ends_b, cand_b, logits_b, t_p in zip(best_end, is_candidate, cand_logits, thresholds):
            spans_item = []
            max_logit, max_i, max_j = -10000, 0, 0
            for i, (j, ok, logit) in enumerate(zip(ends_b, cand_b, logits_b)):
                if not ok:
                    continue
                if logit > t_p:
                    spans_item.append([i, j])
                if logit > max_logit:
                    max_logit, max_i, max_j = logit, i, j
            if len(spans_item) == 0 and self.force_answer:
                spans_item.append([max_i, max_j])
            spans.append(spans_item)
        return spans, None
def split_sequence_like(sequence_output, context_ranges):
    """Pack each example's context tokens to the front of a zero-padded tensor.

    Args:
        sequence_output: encoder output of shape (bs, L, H).
        context_ranges: one ``(c_beg, c_end)`` pair per example; both ends are
            inclusive token indices into ``sequence_output``.

    Returns:
        ``(H_c, masks)``:
        - ``H_c`` is (bs, L, H) and holds the context tokens at positions
          ``0 .. c_end - c_beg``, with zeros after them.
        - ``masks`` is (bs, L), 1 on those positions and 0 elsewhere.
        Both share ``sequence_output``'s dtype and device, so half-precision
        inference does not mix dtypes downstream.
    """
    bs, L, H = sequence_output.size()
    H_c = sequence_output.new_zeros(bs, L, H)
    masks = sequence_output.new_zeros(bs, L)
    for b in range(bs):
        c_beg, c_end = context_ranges[b]
        ctx = sequence_output[b, c_beg:c_end + 1, :]
        n_ctx = ctx.size(0)
        H_c[b, :n_ctx] = ctx
        masks[b, :n_ctx] = 1
    return H_c, masks
def build_single_features(tokenizer, question, context, max_len=512, device=DEVICE):
    """Tokenize one (question, context) pair and locate the context tokens.

    Args:
        tokenizer: a Hugging Face tokenizer that supports
            ``return_offsets_mapping`` and ``sequence_ids`` (i.e. a fast tokenizer).
        question, context: the two texts, encoded as a pair (question first).
        max_len: truncation length passed to the tokenizer.
        device: device the returned tensors are moved to.

    Returns:
        ``(input_ids, type_ids, attn_mask, (ctx_start, ctx_end), offsets)``.
        ``ctx_start``/``ctx_end`` are the inclusive token indices of the context.
        ``offsets`` is the per-token ``(char_start, char_end)`` list from the
        tokenizer's offset mapping.

    Raises:
        ValueError: if truncation left no context tokens in the encoding.
    """
    enc = tokenizer(question, context, return_tensors="pt",
                    return_offsets_mapping=True, truncation=True, max_length=max_len)
    input_ids = enc["input_ids"].to(device)
    attn_mask = enc["attention_mask"].to(device)
    type_ids = enc.get("token_type_ids")
    if type_ids is None:
        type_ids = torch.zeros_like(input_ids)
    type_ids = type_ids.to(device)
    # Reuse the encoding we already have instead of tokenizing the pair again:
    # sequence id 1 marks tokens that come from the second text (the context).
    seq_ids = enc.sequence_ids(0)
    ctx_positions = [i for i, sid in enumerate(seq_ids) if sid == 1]
    if not ctx_positions:
        raise ValueError(
            "no context tokens left after truncation; increase max_len or shorten the question"
        )
    ctx_start, ctx_end = ctx_positions[0], ctx_positions[-1]
    return input_ids, type_ids, attn_mask, (ctx_start, ctx_end), enc["offset_mapping"][0].tolist()
if __name__ == "__main__":
    import sys

    REPO_ID = "ivabojic/deberta-v3-base_MultiSpanQA"
    BASE = "microsoft/deberta-v3-base"
    TOKENIZER = AutoTokenizer.from_pretrained(BASE)
    model = SpanQualifier(
        base_model_name=BASE,
        max_span_gap=8,  # adjust per domain
        dim2=64,
        max_len=512,
        vanilla=False,
        force_answer=False,
    ).to(DEVICE)
    model.eval()

    # --- Download from the Hub ---
    ckpt_path = hf_hub_download(repo_id=REPO_ID, filename="pytorch_model.bin")
    # NOTE(security): torch.load unpickles the file, which can execute arbitrary
    # code. Only point this at repositories you trust; if the checkpoint holds
    # nothing but tensors, weights_only=True is the safer option (not verified here).
    state = torch.load(ckpt_path, map_location="cpu")
    state = state.get("model_state_dict", state)  # handle both formats
    # strict=False tolerates naming differences, but layers left at random
    # initialisation would silently yield meaningless spans -- report mismatches.
    missing, unexpected = model.load_state_dict(state, strict=False)
    if missing:
        print(f"Warning: {len(missing)} model key(s) missing from checkpoint, e.g. {missing[:10]}",
              file=sys.stderr)
    if unexpected:
        print(f"Warning: {len(unexpected)} unexpected checkpoint key(s), e.g. {unexpected[:10]}",
              file=sys.stderr)

    # --- Example Q & context ---
    question = "Who sang it's my party and i'll cry if i want to in the eighties?"
    context = (
        "In 1981, a remake by British artists Dave Stewart and Barbara Gaskin "
        "was a UK number one hit single for four weeks and was also a major hit "
        "in Austria (#3), Germany (#3), the Netherlands (#20), New Zealand (#1), "
        "South Africa (#3) and Switzerland (#6). The track reached #72 in the US. "
        "This was the first version of the song to reach #1 in the UK. The video "
        "for the Stewart/Gaskin version contained a cameo by Thomas Dolby as Johnny, "
        "Judy being played by Gaskin in a blond wig."
    )

    # --- Build features & run ---
    input_ids, type_ids, attn_mask, ctx_range, offsets = build_single_features(
        TOKENIZER, question, context, max_len=512, device=DEVICE
    )
    with torch.no_grad():
        spans_rel = model(
            input_ids=input_ids,
            type_ids=type_ids,
            mask_ids=attn_mask,
            context_ranges=[ctx_range],
        )[0]

    # map relative spans back to text using absolute offsets
    ctx_start, ctx_end = ctx_range
    ctx_len = ctx_end - ctx_start + 1
    answers = []
    for beg_rel, end_rel in spans_rel:
        # Spans index the packed context; one reaching into the zero-padded tail
        # beyond the context has no characters to map back to, so skip it.
        if end_rel >= ctx_len:
            continue
        beg_abs = ctx_start + beg_rel
        end_abs = ctx_start + end_rel
        s_char = offsets[beg_abs][0]
        e_char = offsets[end_abs][1]
        answers.append(context[s_char:e_char].strip())

    print("\nPredicted spans:")
    for i, a in enumerate(answers, 1):
        print(f"{i}. {a}")