Trim ICLR mention, drop std from tables, add full how-to-use walkthrough
README.md
CHANGED
|
@@ -28,10 +28,9 @@ embeddings.
|
|
| 28 |
|
| 29 |
This repository hosts the trained **PLASMA** heads for every (task, backbone)
|
| 30 |
combination from the paper, plus instructions for the parameter-free
|
| 31 |
-
**PLASMA-PF** baseline (which has no learned weights).
|
| 32 |
-
**ICLR 2026**.
|
| 33 |
|
| 34 |
-
- **Paper:** <https://arxiv.org/abs/2510.11752>
|
| 35 |
- **Code:** <https://github.com/ZW471/PLASMA-Protein-Local-Alignment>
|
| 36 |
- **License:** MIT
|
| 37 |
|
|
@@ -61,10 +60,20 @@ All heads share the same architecture: a small `LRL` non-linearity
|
|
| 61 |
parameter-free Sinkhorn iteration (`temperature=0.1`, `n_iters=20`). The
|
| 62 |
checkpoint files are ~3 MB each.
|
| 63 |
|
| 64 |
-
##
|
| 65 |
|
| 66 |
-
|
| 67 |
-
|
| 68 |
|
| 69 |
```bash
|
| 70 |
git clone https://github.com/ZW471/PLASMA-Protein-Local-Alignment
|
|
@@ -72,83 +81,150 @@ cd PLASMA-Protein-Local-Alignment
|
|
| 72 |
uv sync
|
| 73 |
```
|
| 74 |
|
| 75 |
-
|
| 76 |
|
| 77 |
```python
|
| 78 |
-
import torch
|
| 79 |
from alignment import load_plasma
|
| 80 |
|
| 81 |
-
|
| 82 |
model.eval()
|
| 83 |
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
H_q =
|
| 88 |
-
H_c =
|
| 89 |
-
|
| 90 |
-
|
| 91 |
|
| 92 |
with torch.no_grad():
|
| 93 |
-
|
| 94 |
```
|
| 95 |
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
|
| 101 |
## PLASMA-PF (parameter-free)
|
| 102 |
|
| 103 |
-
PLASMA-PF is a hinge / Sinkhorn baseline with **no learned weights**.
|
| 104 |
-
|
|
|
|
| 105 |
|
| 106 |
```python
|
| 107 |
from alignment import load_plasma_pf
|
| 108 |
|
| 109 |
-
model = load_plasma_pf()
|
| 110 |
```
|
| 111 |
|
| 112 |
-
It accepts the same forward signature as the trained heads above
|
|
|
|
| 113 |
|
| 114 |
## Available variants & evaluation results
|
| 115 |
|
| 116 |
-
Numbers below are 3-seed averages
|
| 117 |
-
|
|
|
|
| 118 |
|
| 119 |
### Interpolation (in-distribution test split)
|
| 120 |
|
| 121 |
| Task | Metric | Ankh | ESM-2 | ProstT5 | ProtBERT | ProtSSN | ProtT5 | TM-Vec |
|
| 122 |
| --- | --- | ---: | ---: | ---: | ---: | ---: | ---: | ---: |
|
| 123 |
-
| **Motif** | ROC-AUC | .925
|
| 124 |
-
| | F1-Max | .885
|
| 125 |
-
| | PR-AUC | .921
|
| 126 |
-
| | Label Match Score | .921
|
| 127 |
-
| **Binding Site** | ROC-AUC | **.995
|
| 128 |
-
| | F1-Max | .987
|
| 129 |
-
| | PR-AUC | **.996
|
| 130 |
-
| | Label Match Score | **.951
|
| 131 |
-
| **Active Site** | ROC-AUC | **.994
|
| 132 |
-
| | F1-Max | **.989
|
| 133 |
-
| | PR-AUC | **.994
|
| 134 |
-
| | Label Match Score | **.975
|
| 135 |
|
| 136 |
### Extrapolation (held-out hard test split)
|
| 137 |
|
| 138 |
| Task | Metric | Ankh | ESM-2 | ProstT5 | ProtBERT | ProtSSN | ProtT5 | TM-Vec |
|
| 139 |
| --- | --- | ---: | ---: | ---: | ---: | ---: | ---: | ---: |
|
| 140 |
-
| **Motif** | ROC-AUC | .960
|
| 141 |
-
| | F1-Max | .915
|
| 142 |
-
| | PR-AUC | .948
|
| 143 |
-
| | Label Match Score | **.842
|
| 144 |
-
| **Binding Site** | ROC-AUC | .995
|
| 145 |
-
| | F1-Max | .992
|
| 146 |
-
| | PR-AUC | .997
|
| 147 |
-
| | Label Match Score | .894
|
| 148 |
-
| **Active Site** | ROC-AUC | .995
|
| 149 |
-
| | F1-Max | **.992
|
| 150 |
-
| | PR-AUC | .995
|
| 151 |
-
| | Label Match Score | **.938
|
| 152 |
|
| 153 |
Each subfolder also contains a `metadata.json` with the full hyperparameter
|
| 154 |
config in machine-readable form.
|
|
|
|
| 28 |
|
| 29 |
This repository hosts the trained **PLASMA** heads for every (task, backbone)
|
| 30 |
combination from the paper, plus instructions for the parameter-free
|
| 31 |
+
**PLASMA-PF** baseline (which has no learned weights).
|
|
|
|
| 32 |
|
| 33 |
+
- **Paper:** <https://arxiv.org/abs/2510.11752>
|
| 34 |
- **Code:** <https://github.com/ZW471/PLASMA-Protein-Local-Alignment>
|
| 35 |
- **License:** MIT
|
| 36 |
|
|
|
|
| 60 |
parameter-free Sinkhorn iteration (`temperature=0.1`, `n_iters=20`). The
|
| 61 |
checkpoint files are ~3 MB each.
|
| 62 |
|
| 63 |
+
## How to use
|
| 64 |
|
| 65 |
+
PLASMA is a *head*: it consumes per-residue embeddings from a frozen protein
|
| 66 |
+
language model and returns a soft alignment matrix between two
|
| 67 |
+
sub-structures. The end-to-end pipeline is therefore three steps:
|
| 68 |
+
|
| 69 |
+
1. Embed each protein with the backbone the head was trained on (one of the
|
| 70 |
+
seven listed above).
|
| 71 |
+
2. Run the PLASMA head on the (residue × residue) embeddings to get a soft
|
| 72 |
+
alignment matrix `M ∈ [0, 1]^{n_q × n_c}`.
|
| 73 |
+
3. Optionally reduce `M` to a scalar similarity score with
|
| 74 |
+
`alignment_score` (in `utils.alignment_utils`).
|
| 75 |
+
|
| 76 |
+
### 1. Install
|
| 77 |
|
| 78 |
```bash
|
| 79 |
git clone https://github.com/ZW471/PLASMA-Protein-Local-Alignment
|
|
|
|
| 81 |
uv sync
|
| 82 |
```
|
| 83 |
|
| 84 |
+
The `Alignment` class and the `load_plasma` helper live in the `alignment`
|
| 85 |
+
package shipped by that repo.
|
| 86 |
+
|
| 87 |
+
### 2. Load a trained head
|
| 88 |
|
| 89 |
```python
|
|
|
|
| 90 |
from alignment import load_plasma
|
| 91 |
|
| 92 |
+
# task ∈ {"active_site", "binding_site", "motif"}
|
| 93 |
+
# backbone is the PLM whose embeddings the head was trained on
|
| 94 |
+
model = load_plasma(task="active_site", backbone="esm2_t33_650M_UR50D")
|
| 95 |
model.eval()
|
| 96 |
+
```
|
| 97 |
+
|
| 98 |
+
`load_plasma` downloads the matching `config.json` + `model.safetensors` from
|
| 99 |
+
this repo via `huggingface_hub` and rebuilds the `Alignment` module.
|
| 100 |
+
|
| 101 |
+
### 3. Compute embeddings with the matching backbone
|
| 102 |
+
|
| 103 |
+
PLASMA does not embed sequences itself. The example below shows how to do it
|
| 104 |
+
with **ESM-2** via `transformers`; the same pattern works for any other
|
| 105 |
+
backbone (`Ankh`, `ProstT5`, `ProtBERT`, `ProtT5`, `TM-Vec`, `ProtSSN` —
|
| 106 |
+
their loaders are documented in `embed.py` in the GitHub repo).
|
| 107 |
+
|
| 108 |
+
```python
|
| 109 |
+
import torch
|
| 110 |
+
from transformers import AutoTokenizer, AutoModel
|
| 111 |
+
|
| 112 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 113 |
+
|
| 114 |
+
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
|
| 115 |
+
backbone = AutoModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device).eval()
|
| 116 |
+
|
| 117 |
+
@torch.no_grad()
|
| 118 |
+
def embed(sequence: str) -> torch.Tensor:
|
| 119 |
+
"""Return per-residue embeddings of shape (L, 1280) — no special tokens."""
|
| 120 |
+
tokens = tokenizer(sequence, return_tensors="pt", add_special_tokens=True).to(device)
|
| 121 |
+
h = backbone(**tokens).last_hidden_state[0] # (L+2, 1280): <cls> ... <eos>
|
| 122 |
+
return h[1:-1].cpu() # drop <cls> and <eos>
|
| 123 |
|
| 124 |
+
seq_q = "MKTAYIAKQRQISFVKSHFSRQDILDLWIYHTQGYFP"
|
| 125 |
+
seq_c = "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNI"
|
| 126 |
+
|
| 127 |
+
H_q = embed(seq_q) # (n_q, 1280)
|
| 128 |
+
H_c = embed(seq_c) # (n_c, 1280)
|
| 129 |
+
```
|
| 130 |
+
|
| 131 |
+
### 4. Run PLASMA and read the alignment matrix
|
| 132 |
+
|
| 133 |
+
```python
|
| 134 |
+
# `batch_q` / `batch_c` assign each residue to a sample. Use zeros for a
|
| 135 |
+
# single pair; use [0, 0, ..., 1, 1, ...] to score multiple pairs in one batch.
|
| 136 |
+
batch_q = torch.zeros(H_q.size(0), dtype=torch.long)
|
| 137 |
+
batch_c = torch.zeros(H_c.size(0), dtype=torch.long)
|
| 138 |
|
| 139 |
with torch.no_grad():
|
| 140 |
+
M = model(H_q, H_c, batch_q, batch_c) # (n_q, n_c) in [0, 1]
|
| 141 |
+
|
| 142 |
+
# Hard residue-residue assignment (top of column / row in the transport plan)
|
| 143 |
+
q_to_c = M.argmax(dim=1) # for each query residue, the best candidate residue
|
| 144 |
+
c_to_q = M.argmax(dim=0) # for each candidate residue, the best query residue
|
| 145 |
```
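The batch vectors described in the comment above can be built for several pairs at once by concatenating the per-residue embeddings and repeating each pair index once per residue. The sketch below uses random tensors in place of real embeddings. The helper name `pack` is hypothetical and not part of the PLASMA repo, and it assumes the head accepts concatenated inputs in the way that comment suggests.

```python
import torch

def pack(embeddings):
    """Concatenate per-protein embeddings and build the matching batch index.

    `embeddings` is a list of (L_i, d) tensors; returns a (sum L_i, d) tensor
    and a (sum L_i,) long tensor mapping every residue to its sample index.
    """
    lengths = torch.tensor([h.size(0) for h in embeddings])
    batch = torch.repeat_interleave(torch.arange(len(embeddings)), lengths)
    return torch.cat(embeddings, dim=0), batch

# Stand-ins for real backbone outputs: two query/candidate pairs, d = 1280.
queries = [torch.randn(5, 1280), torch.randn(3, 1280)]
candidates = [torch.randn(7, 1280), torch.randn(4, 1280)]

H_q_all, batch_q_all = pack(queries)
H_c_all, batch_c_all = pack(candidates)

print(batch_q_all.tolist())  # [0, 0, 0, 0, 0, 1, 1, 1]
```

Pair `k` is then made of the residues where `batch_q_all == k` and `batch_c_all == k`.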
|
| 146 |
|
| 147 |
+
`M` is a (near-)doubly-stochastic transport plan: rows and columns each sum
|
| 148 |
+
to ~1, so `M[i, j]` is the soft probability that query residue `i` aligns to
|
| 149 |
+
candidate residue `j`. Thresholding at `0.5` gives a sparse local alignment;
|
| 150 |
+
plotting `M` as a heatmap gives the canonical PLASMA visualisation (the
|
| 151 |
+
diagonal stripe in the visual abstract above).
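As a small illustration of the thresholding mentioned above, the sketch below extracts aligned residue pairs from a hand-made matrix. The values in `M_demo` are invented for the example and do not come from a PLASMA head.

```python
import torch

# Toy alignment matrix (3 query residues x 4 candidate residues); values are made up.
M_demo = torch.tensor([
    [0.90, 0.05, 0.03, 0.02],
    [0.04, 0.80, 0.10, 0.06],
    [0.02, 0.10, 0.45, 0.43],
])

# Keep only confident cells; each row of `pairs` is (query index, candidate index).
pairs = torch.nonzero(M_demo > 0.5)
print(pairs.tolist())  # [[0, 0], [1, 1]]

# Row and column sums show how much mass each residue places on the other protein.
row_mass = M_demo.sum(dim=1)
col_mass = M_demo.sum(dim=0)
```

Query residue 2 has no cell above `0.5`, so it stays unaligned in the sparse view even though its row still carries mass.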
|
| 152 |
+
|
| 153 |
+
### 5. Reduce to a similarity score
|
| 154 |
+
|
| 155 |
+
To collapse the alignment matrix into a single number per protein pair (the
|
| 156 |
+
quantity used to compute ROC-AUC / F1-Max in the tables above), use
|
| 157 |
+
`alignment_score` from `utils.alignment_utils` in the GitHub repo. It applies the diagonal
|
| 158 |
+
convolution + thresholding described in the paper:
|
| 159 |
+
|
| 160 |
+
```python
|
| 161 |
+
from utils.alignment_utils import alignment_score
|
| 162 |
+
|
| 163 |
+
score = alignment_score(
|
| 164 |
+
H_q, H_c, M, batch_c,
|
| 165 |
+
threshold=0.5, # gating on max-row / max-col residues
|
| 166 |
+
K=10, # diagonal-convolution window
|
| 167 |
+
) # -> shape (num_pairs_in_batch,), here (1,)
|
| 168 |
+
print(float(score))
|
| 169 |
+
```
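The repository's `alignment_score` remains the reference implementation. Purely to illustrate the idea of a diagonal convolution followed by a threshold, the sketch below slides a `K × K` identity kernel over a toy matrix so that runs of consecutive matches along a diagonal score highly. The function `diagonal_score`, its gating rule, and its normalisation are assumptions made for this example, not the paper's formula, and `K=3` is chosen only to fit the tiny inputs.

```python
import torch
import torch.nn.functional as F

def diagonal_score(M, K=3, threshold=0.5):
    """Toy score: strongest K-long diagonal run among cells at or above `threshold`.

    Assumes both dimensions of `M` are at least K.
    """
    gated = torch.where(M >= threshold, M, torch.zeros_like(M))
    kernel = torch.eye(K).view(1, 1, K, K) / K   # averages K cells along a diagonal
    response = F.conv2d(gated.view(1, 1, *M.shape), kernel)
    return response.max().item()

aligned = torch.eye(4) * 0.9          # a clean diagonal stripe
scattered = torch.zeros(4, 4)
scattered[0, 3] = scattered[2, 0] = scattered[3, 2] = 0.9   # isolated matches

print(diagonal_score(aligned) > diagonal_score(scattered))  # True
```

The stripe scores higher than the scattered matches even though individual cell values are the same, which is the behaviour a diagonal window is meant to reward.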
|
| 170 |
|
| 171 |
## PLASMA-PF (parameter-free)
|
| 172 |
|
| 173 |
+
PLASMA-PF is a hinge / Sinkhorn baseline with **no learned weights**. Use it
|
| 174 |
+
when you want a strong zero-training baseline on top of any backbone — there
|
| 175 |
+
is nothing to download:
|
| 176 |
|
| 177 |
```python
|
| 178 |
from alignment import load_plasma_pf
|
| 179 |
|
| 180 |
+
model = load_plasma_pf() # Alignment(eta='hinge', omega='sinkhorn', ...)
|
| 181 |
+
|
| 182 |
+
with torch.no_grad():
|
| 183 |
+
M_pf = model(H_q, H_c, batch_q, batch_c)
|
| 184 |
```
|
| 185 |
|
| 186 |
+
It accepts the same forward signature as the trained heads above and pairs
|
| 187 |
+
with any of the seven supported backbones.
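To give a feel for what a parameter-free hinge / Sinkhorn alignment can look like, here is a minimal sketch on random embeddings. It assumes cosine similarity as the residue-level score, a ReLU as the hinge, and alternating row/column normalisation in log space. The actual `Alignment` module may differ in all three, and `sinkhorn_alignment` is a name made up for this example. Only `temperature=0.1` and `n_iters=20` are taken from the description above.

```python
import torch
import torch.nn.functional as F

def sinkhorn_alignment(H_q, H_c, temperature=0.1, n_iters=20):
    """Hypothetical parameter-free alignment: hinge on cosine similarity, then Sinkhorn."""
    sim = F.normalize(H_q, dim=-1) @ F.normalize(H_c, dim=-1).T      # (n_q, n_c) cosine
    log_p = torch.relu(sim) / temperature                             # hinge: drop negatives
    for _ in range(n_iters):
        log_p = log_p - torch.logsumexp(log_p, dim=1, keepdim=True)   # rows sum to 1
        log_p = log_p - torch.logsumexp(log_p, dim=0, keepdim=True)   # columns sum to 1
    return log_p.exp()

torch.manual_seed(0)
H_a, H_b = torch.randn(6, 32), torch.randn(6, 32)   # stand-ins for real embeddings
P = sinkhorn_alignment(H_a, H_b)
print(P.shape)   # torch.Size([6, 6])
```

In this sketch the columns sum to one exactly because they are normalised last, while the rows only approach one. When the two proteins have different lengths, rows and columns cannot both sum to one, so a plain Sinkhorn loop like this satisfies only the last-normalised axis.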
|
| 188 |
|
| 189 |
## Available variants & evaluation results
|
| 190 |
|
| 191 |
+
Numbers below are 3-seed averages reported in the paper. The seven backbone
|
| 192 |
+
columns correspond to the seven subfolders under each task. **Bold** marks the
|
| 193 |
+
best backbone for each row.
|
| 194 |
|
| 195 |
### Interpolation (in-distribution test split)
|
| 196 |
|
| 197 |
| Task | Metric | Ankh | ESM-2 | ProstT5 | ProtBERT | ProtSSN | ProtT5 | TM-Vec |
|
| 198 |
| --- | --- | ---: | ---: | ---: | ---: | ---: | ---: | ---: |
|
| 199 |
+
| **Motif** | ROC-AUC | .925 | .933 | .954 | .854 | .922 | **.972** | .910 |
|
| 200 |
+
| | F1-Max | .885 | .877 | .885 | .784 | .866 | **.918** | .853 |
|
| 201 |
+
| | PR-AUC | .921 | .931 | .953 | .872 | .920 | **.971** | .914 |
|
| 202 |
+
| | Label Match Score | .921 | .890 | .929 | .746 | .767 | **.937** | .792 |
|
| 203 |
+
| **Binding Site** | ROC-AUC | **.995** | .992 | .993 | .981 | .992 | .993 | .980 |
|
| 204 |
+
| | F1-Max | .987 | .986 | .983 | .948 | .982 | **.988** | .970 |
|
| 205 |
+
| | PR-AUC | **.996** | .994 | .995 | .985 | .993 | .995 | .984 |
|
| 206 |
+
| | Label Match Score | **.951** | .950 | **.951** | .880 | .872 | **.951** | .900 |
|
| 207 |
+
| **Active Site** | ROC-AUC | **.994** | .991 | .993 | .986 | .992 | **.994** | .991 |
|
| 208 |
+
| | F1-Max | **.989** | .985 | .987 | .967 | .987 | .987 | .982 |
|
| 209 |
+
| | PR-AUC | **.994** | .992 | **.994** | .988 | **.994** | **.994** | .992 |
|
| 210 |
+
| | Label Match Score | **.975** | .969 | **.975** | .904 | .885 | .972 | .938 |
|
| 211 |
|
| 212 |
### Extrapolation (held-out hard test split)
|
| 213 |
|
| 214 |
| Task | Metric | Ankh | ESM-2 | ProstT5 | ProtBERT | ProtSSN | ProtT5 | TM-Vec |
|
| 215 |
| --- | --- | ---: | ---: | ---: | ---: | ---: | ---: | ---: |
|
| 216 |
+
| **Motif** | ROC-AUC | .960 | .972 | **.975** | .870 | .949 | .968 | .954 |
|
| 217 |
+
| | F1-Max | .915 | **.931** | .926 | .799 | .896 | .922 | .903 |
|
| 218 |
+
| | PR-AUC | .948 | **.970** | .969 | .873 | .940 | .962 | .944 |
|
| 219 |
+
| | Label Match Score | **.842** | .786 | .801 | .541 | .537 | .738 | .704 |
|
| 220 |
+
| **Binding Site** | ROC-AUC | .995 | **.999** | .993 | .951 | **.999** | **.999** | .990 |
|
| 221 |
+
| | F1-Max | .992 | .991 | .985 | .896 | .988 | **.996** | .983 |
|
| 222 |
+
| | PR-AUC | .997 | **.999** | .995 | .958 | .998 | **.999** | .992 |
|
| 223 |
+
| | Label Match Score | .894 | .851 | .891 | .603 | .753 | **.902** | .824 |
|
| 224 |
+
| **Active Site** | ROC-AUC | .995 | .996 | .996 | .980 | .997 | **.999** | .995 |
|
| 225 |
+
| | F1-Max | **.992** | .986 | .991 | .950 | .991 | .991 | .985 |
|
| 226 |
+
| | PR-AUC | .995 | .997 | .997 | .984 | .998 | **.999** | .996 |
|
| 227 |
+
| | Label Match Score | **.938** | .882 | .931 | .697 | .737 | .893 | .880 |
|
| 228 |
|
| 229 |
Each subfolder also contains a `metadata.json` with the full hyperparameter
|
| 230 |
config in machine-readable form.
|