DDI Risk Explorer · GINE + Co-Attention v2

Edge-aware Graph Isomorphism Network (GINE) + bilinear co-attention pair head, trained for drug-drug interaction (DDI) type classification on the DrugBank DDI dataset (86 classes).

Given two SMILES strings, returns the most likely interaction type (out of 86) and a softmax distribution over all classes.

Research / educational use only. Not clinically validated. Not for medical decisions.

Files

| File | Purpose |
|---|---|
| `model.pt` | PyTorch checkpoint (state_dict + config + epoch) |
| `config.json` | Model hyperparameters (v2 architecture) |
| `label_map.json` | `{class_id: interaction_text}` for all 86 classes (1-indexed) |
| `results.json` | Held-out test metrics + per-class F1 + full training history |

Architecture

SMILES → RDKit Mol → PyG graph
   atoms : 77-dim   (atom type, degree, H, formal charge, hybridization,
                     chirality, radical e-, aromatic, in-ring)
   bonds : 14-dim   (bond type, conjugation, in-ring, stereo)

   ↓  GINEConv × 4   (256 hidden, ELU, BatchNorm, dropout 0.2)
   ↓  per-atom embeddings (128-d)

Two molecules' per-atom embeddings
   ↓  Bilinear co-attention   S = X_a W X_b^T   (SSI-DDI inspired)
   ↓  pool(X_a; X_b; ctx_a; ctx_b)   →   MLP   →   86 logits

The encoder is shared between Drug A and Drug B (weight tying).
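The bilinear co-attention step above can be sketched in a few lines of plain PyTorch. This is an illustrative stand-in, not the checkpoint's exact implementation: each atom of Drug A attends over Drug B's atoms (and vice versa) through the learned bilinear map `W`, yielding per-atom context vectors that are later pooled with the raw embeddings. Shapes and the helper name `bilinear_coattention` are assumptions for this sketch.

```python
import torch

def bilinear_coattention(x_a, x_b, W):
    # x_a: [n_a, d], x_b: [n_b, d] per-atom embeddings; W: [d, d] learned
    S = x_a @ W @ x_b.T                     # [n_a, n_b] pairwise atom affinities
    ctx_a = torch.softmax(S, dim=1) @ x_b   # each A-atom attends over B's atoms
    ctx_b = torch.softmax(S, dim=0).T @ x_a # each B-atom attends over A's atoms
    return ctx_a, ctx_b

d = 128                                      # embedding_dim in this card
x_a, x_b = torch.randn(20, d), torch.randn(15, d)
W = torch.randn(d, d)
ctx_a, ctx_b = bilinear_coattention(x_a, x_b, W)
print(ctx_a.shape, ctx_b.shape)  # torch.Size([20, 128]) torch.Size([15, 128])
```

The context vectors keep one row per atom, so the subsequent `pool(X_a; X_b; ctx_a; ctx_b)` can mix attended and raw atom views before the MLP.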

Test metrics (DrugBank DDI test split, 38,362 pairs)

| Metric | Value | vs. baseline |
|---|---|---|
| Top-1 accuracy | 96.47 % | +64.7 pp |
| Top-3 accuracy | 99.74 % | +68.0 pp |
| Macro-F1 | 92.88 % | balanced across 86 classes |
| Majority-class baseline | 31.72 % | |

Independently computed on the standard PyTDC DrugBank DDI warm split (134,265 train / 19,181 valid / 38,362 test).
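Macro-F1 weights every class equally, so with 86 classes the rare interaction types influence the reported 92.88 % as much as the common ones. A minimal self-contained definition (toy labels for illustration, not model output):

```python
def macro_f1(y_true, y_pred, num_classes):
    """Unweighted mean of per-class F1 scores (classes absent from both lists score 0)."""
    f1s = []
    for c in range(num_classes):
        tp = sum(t == c and p == c for t, p in zip(y_true, y_pred))
        fp = sum(t != c and p == c for t, p in zip(y_true, y_pred))
        fn = sum(t == c and p != c for t, p in zip(y_true, y_pred))
        f1s.append(2 * tp / (2 * tp + fp + fn) if (2 * tp + fp + fn) else 0.0)
    return sum(f1s) / num_classes

print(round(macro_f1([0, 0, 1, 2], [0, 1, 1, 2], 3), 4))  # 0.7778
```

This is why a model can post high top-1 accuracy (dominated by frequent classes) while macro-F1 lags behind; the gap between 96.47 % and 92.88 % here is small precisely because rare classes are handled reasonably well.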

What changed since v1

| | v1 (GAT) | v2 (this checkpoint) |
|---|---|---|
| Atom features | 62-dim | 77-dim |
| Bond features | not used | 14-dim edge_attr consumed by GINE |
| Encoder | GATConv × 3, mean pool | GINEConv × 4, mean+max+sum pool |
| Pair head | concat(emb_a, emb_b) → MLP | bilinear co-attention over per-atom embeddings |
| Loss | weighted cross-entropy | cross-entropy + 0.05 label smoothing |
| Optimizer | Adam, fixed lr 1e-3 | AdamW, warmup 5 ep + cosine to 200 ep |
| Top-1 / Top-3 / Macro-F1 | 39.27 / 76.24 / 42.03 | 96.47 / 99.74 / 92.88 |

How to use

import json, torch
from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download("kareem-khaled/ddi-risk-explorer-gnn", "model.pt")
labels    = json.load(open(hf_hub_download("kareem-khaled/ddi-risk-explorer-gnn", "label_map.json"), encoding="utf-8"))

# Model class lives in the source repo (or copy from the HF Space's model.py):
#   github.com/kareemindata/ddi-risk-explorer  →  src/ddi_risk_explorer/models/ddi_model.py
from ddi_risk_explorer.models.ddi_model import DDIModel
from ddi_risk_explorer.features.graphs import smiles_to_graph
from torch_geometric.data import Batch

ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
cfg  = ckpt["model_config"]
model = DDIModel(
    node_features=cfg["node_features"],
    edge_features=cfg["edge_features"],
    hidden_dim=cfg["hidden_dim"],
    embedding_dim=cfg["embedding_dim"],
    num_classes=cfg["num_classes"],
    num_layers=cfg["num_layers"],
    dropout=cfg["dropout"],
    head=cfg["head"],
)
model.load_state_dict(ckpt["model_state_dict"])
model.eval()

g_a = Batch.from_data_list([smiles_to_graph("CC(=O)Oc1ccccc1C(=O)O")])             # Aspirin
g_b = Batch.from_data_list([smiles_to_graph("CC(C)Cc1ccc(cc1)[C@@H](C)C(=O)O")])    # Ibuprofen

with torch.no_grad():
    probs = torch.softmax(model(g_a, g_b), dim=-1)[0]

top3 = probs.topk(3)
for p, idx in zip(top3.values.tolist(), top3.indices.tolist()):
    print(f"{p:.3f}  {labels[str(idx + 1)]}")  # labels are 1-indexed in source

For the easiest end-to-end usage, point a gradio_client at the live Space:

from gradio_client import Client
client = Client("kareem-khaled/ddi-risk-explorer")
client.predict(
    smiles_a="CC(=O)Oc1ccccc1C(=O)O",
    smiles_b="CC(C)Cc1ccc(cc1)[C@@H](C)C(=O)O",
    name_a="Aspirin", name_b="Ibuprofen", top_k=5,
    api_name="/predict",
)

Training data

DrugBank DDI via Therapeutics Data Commons — 191,808 pairs across 86 interaction types.

Limitations

  • Not clinically validated. Research prototype only.
  • Class imbalance. Rare interaction types (some with only a few dozen training examples) sit in the right tail of the per-class F1 distribution.
  • Warm split only. The reported metrics share drugs across train / val / test; this measures interaction-type generalization, not novel-drug generalization. A cold-drug evaluation is on the roadmap.
  • No confirmed-negative pairs. The model cannot distinguish "no interaction" from a rare type.
  • Black-box. Co-attention weights are computed but not yet surfaced.

Source

Live demo: huggingface.co/spaces/kareem-khaled/ddi-risk-explorer
Training code: github.com/kareemindata/ddi-risk-explorer
