DDI Risk Explorer · GINE + Co-Attention v2
Edge-aware Graph Isomorphism Network (GINE) + bilinear co-attention pair head, trained for drug-drug interaction (DDI) type classification on the DrugBank DDI dataset (86 classes).
Given two SMILES strings, returns the most likely interaction type (out of 86) and a calibrated softmax over all classes.
Research / educational use only. Not clinically validated. Not for medical decisions.
Files
| File | Purpose |
|---|---|
| `model.pt` | PyTorch checkpoint (state_dict + config + epoch) |
| `config.json` | Model hyperparameters (v2 architecture) |
| `label_map.json` | `{class_id: interaction_text}` for all 86 classes (1-indexed) |
| `results.json` | Held-out test metrics + per-class F1 + full training history |
Architecture
```
SMILES → RDKit Mol → PyG graph
  atoms : 77-dim (atom type, degree, H count, formal charge, hybridization,
                  chirality, radical electrons, aromaticity, ring membership)
  bonds : 14-dim (bond type, conjugation, ring membership, stereo)
    ↓ GINEConv × 4 (256 hidden, ELU, BatchNorm, dropout 0.2)
    ↓ per-atom embeddings (128-d)
Two molecules' per-atom embeddings
    ↓ bilinear co-attention S = X_a W X_b^T (SSI-DDI inspired)
    ↓ pool(X_a ; X_b ; ctx_a ; ctx_b) → MLP → 86 logits
```
The encoder is shared between Drug A and Drug B (weight tying).
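The co-attention step above can be sketched in a few lines of NumPy. This is a shape-level illustration, not the checkpoint's actual code: the softmax direction and the use of raw (unscaled) affinities are assumptions, and `coattention` is a hypothetical helper name.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def coattention(X_a, X_b, W):
    # Bilinear affinity: S[i, j] scores atom i of drug A against atom j of drug B.
    S = X_a @ W @ X_b.T              # (n_a, n_b)
    # Each drug's atoms attend over the other drug's atoms.
    ctx_a = softmax(S, axis=1) @ X_b  # (n_a, d) context for A's atoms
    ctx_b = softmax(S.T, axis=1) @ X_a  # (n_b, d) context for B's atoms
    return S, ctx_a, ctx_b

rng = np.random.default_rng(0)
d = 128  # embedding_dim from the architecture sketch
X_a = rng.standard_normal((12, d))  # 12 atoms in drug A
X_b = rng.standard_normal((9, d))   # 9 atoms in drug B
W = rng.standard_normal((d, d)) * 0.01
S, ctx_a, ctx_b = coattention(X_a, X_b, W)
```

The four matrices `X_a`, `X_b`, `ctx_a`, `ctx_b` are then pooled and concatenated before the MLP head.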
Test metrics (DrugBank DDI test split, 38,362 pairs)
| Metric | Value | vs. baseline |
|---|---|---|
| Top-1 accuracy | 96.47 % | +64.7 pp |
| Top-3 accuracy | 99.74 % | +68.0 pp |
| Macro-F1 | 92.88 % | balanced across 86 classes |
| Majority-class baseline | 31.72 % | — |
Independently computed on the standard PyTDC DrugBank DDI warm split (134,265 train / 19,181 valid / 38,362 test).
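For reference, the top-k accuracies above reduce to a membership check on the k highest logits per pair. A minimal sketch with toy logits (not the real test set):

```python
import numpy as np

def topk_accuracy(logits, targets, k):
    # A pair counts as correct if the true class is among the k highest logits.
    topk = np.argsort(logits, axis=1)[:, -k:]
    return float(np.mean([t in row for t, row in zip(targets, topk)]))

logits = np.array([[0.1, 2.0, 0.5],
                   [1.5, 0.2, 0.9],
                   [0.3, 0.4, 0.2]])
targets = np.array([1, 2, 0])
# top-1 predictions are classes 1, 0, 1 → only the first pair is correct (1/3);
# every target appears in its pair's top-2 set → top-2 accuracy is 1.0.
```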
What changed since v1
| | v1 (GAT) | v2 (this checkpoint) |
|---|---|---|
| Atom features | 62-dim | 77-dim |
| Bond features | not used | 14-dim edge_attr consumed by GINE |
| Encoder | GATConv × 3, mean pool | GINEConv × 4, mean+max+sum pool |
| Pair head | concat(emb_a, emb_b) → MLP | Bilinear co-attention over per-atom embeddings |
| Loss | weighted cross-entropy | cross-entropy + 0.05 label smoothing |
| Optimizer | Adam, fixed lr 1e-3 | AdamW, warmup 5 ep + cosine to 200 ep |
| Top-1 / Top-3 / Macro-F1 | 39.27 / 76.24 / 42.03 | 96.47 / 99.74 / 92.88 |
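The v2 schedule (5-epoch linear warmup, then cosine decay to epoch 200) can be sketched as a plain function; `base_lr=1e-3` is an assumption carried over from the v1 row, not read from the checkpoint, and decay-to-zero is the usual cosine convention rather than a confirmed detail:

```python
import math

def lr_at_epoch(epoch, base_lr=1e-3, warmup=5, total=200):
    """Linear warmup for the first `warmup` epochs, then cosine decay to zero at `total`."""
    if epoch < warmup:
        return base_lr * (epoch + 1) / warmup
    progress = (epoch - warmup) / (total - warmup)  # 0 → 1 over the decay phase
    return base_lr * 0.5 * (1 + math.cos(math.pi * progress))
```

The label-smoothing change, by contrast, is a one-liner in PyTorch: `torch.nn.CrossEntropyLoss(label_smoothing=0.05)`.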
How to use
```python
import json, torch
from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download("kareem-khaled/ddi-risk-explorer-gnn", "model.pt")
labels = json.load(open(hf_hub_download("kareem-khaled/ddi-risk-explorer-gnn", "label_map.json"), encoding="utf-8"))

# Model class lives in the source repo (or copy from the HF Space's model.py):
# github.com/kareemindata/ddi-risk-explorer → src/ddi_risk_explorer/models/ddi_model.py
from ddi_risk_explorer.models.ddi_model import DDIModel
from ddi_risk_explorer.features.graphs import smiles_to_graph
from torch_geometric.data import Batch

ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
cfg = ckpt["model_config"]
model = DDIModel(
    node_features=cfg["node_features"],
    edge_features=cfg["edge_features"],
    hidden_dim=cfg["hidden_dim"],
    embedding_dim=cfg["embedding_dim"],
    num_classes=cfg["num_classes"],
    num_layers=cfg["num_layers"],
    dropout=cfg["dropout"],
    head=cfg["head"],
)
model.load_state_dict(ckpt["model_state_dict"])
model.eval()

g_a = Batch.from_data_list([smiles_to_graph("CC(=O)Oc1ccccc1C(=O)O")])            # Aspirin
g_b = Batch.from_data_list([smiles_to_graph("CC(C)Cc1ccc(cc1)[C@@H](C)C(=O)O")])  # Ibuprofen

with torch.no_grad():
    probs = torch.softmax(model(g_a, g_b), dim=-1)[0]

top3 = probs.topk(3)
for p, idx in zip(top3.values.tolist(), top3.indices.tolist()):
    print(f"{p:.3f}  {labels[str(idx + 1)]}")  # labels are 1-indexed in the source
```
For the easiest end-to-end usage, point a gradio_client at the live Space:
```python
from gradio_client import Client

client = Client("kareem-khaled/ddi-risk-explorer")
client.predict(
    smiles_a="CC(=O)Oc1ccccc1C(=O)O",
    smiles_b="CC(C)Cc1ccc(cc1)[C@@H](C)C(=O)O",
    name_a="Aspirin", name_b="Ibuprofen", top_k=5,
    api_name="/predict",
)
```
Training data
DrugBank DDI via Therapeutics Data Commons — 191,808 pairs across 86 interaction types.
Limitations
- Not clinically validated. Research prototype only.
- Class imbalance. Rare interaction types (some with only a few dozen training examples) sit at the low end of the per-class F1 distribution.
- Warm split only. The reported metrics share drugs across train / val / test; this measures interaction-type generalization, not novel-drug generalization. A cold-drug evaluation is on the roadmap.
- No confirmed-negative pairs. The model cannot distinguish "no interaction" from a rare type.
- Black-box. Co-attention weights are computed but not yet surfaced.
Source
- Live demo: huggingface.co/spaces/kareem-khaled/ddi-risk-explorer
- Training code: github.com/kareemindata/ddi-risk-explorer