zhiyuw committed on
Commit 68de276 · verified · 1 Parent(s): ae5f882

Trim ICLR mention, drop std from tables, add full how-to-use walkthrough

  1. README.md +127 -51
README.md CHANGED
@@ -28,10 +28,9 @@ embeddings.
28
 
29
  This repository hosts the trained **PLASMA** heads for every (task, backbone)
30
  combination from the paper, plus instructions for the parameter-free
31
- **PLASMA-PF** baseline (which has no learned weights). PLASMA was published at
32
- **ICLR 2026**.
33
 
34
- - **Paper:** <https://arxiv.org/abs/2510.11752> (ICLR 2026)
35
  - **Code:** <https://github.com/ZW471/PLASMA-Protein-Local-Alignment>
36
  - **License:** MIT
37
 
@@ -61,10 +60,20 @@ All heads share the same architecture: a small `LRL` non-linearity
61
  parameter-free Sinkhorn iteration (`temperature=0.1`, `n_iters=20`). The
62
  checkpoint files are ~3 MB each.
63
 
64
- ## Quickstart
65
 
66
- Install the PLASMA package from source (the model class is shipped with the
67
- GitHub repo):
68
 
69
  ```bash
70
  git clone https://github.com/ZW471/PLASMA-Protein-Local-Alignment
@@ -72,83 +81,150 @@ cd PLASMA-Protein-Local-Alignment
72
  uv sync
73
  ```
74
 
75
- Then load any trained head with the high-level helper:
 
 
 
76
 
77
  ```python
78
- import torch
79
  from alignment import load_plasma
80
 
81
- model = load_plasma(task="active_site", backbone="prot_bert")
 
 
82
  model.eval()
83
 
84
- # Feed pre-computed AA-level embeddings from the matching backbone.
85
- # H_q / H_c are residue-level embeddings; batch_q / batch_c assign each
86
- # residue to a sample (use zeros if you only have one pair).
87
- H_q = torch.randn(120, 1024) # query: 120 residues, ProtBERT dim
88
- H_c = torch.randn(180, 1024) # candidate: 180 residues
89
- batch_q = torch.zeros(120, dtype=torch.long)
90
- batch_c = torch.zeros(180, dtype=torch.long)
91
 
92
  with torch.no_grad():
93
-     alignment_matrix = model(H_q, H_c, batch_q, batch_c)  # (120, 180)
 
 
 
 
94
  ```
95
 
96
- The output is a doubly-stochastic transport plan describing the residue-level
97
- correspondence between the two substructures. To reduce it to a similarity
98
- score, reuse `utils.alignment_score` from the GitHub repo (it applies the
99
- diagonal convolution + threshold described in the paper).
100
 
101
  ## PLASMA-PF (parameter-free)
102
 
103
- PLASMA-PF is a hinge / Sinkhorn baseline with **no learned weights**. There is
104
- nothing to download just instantiate it from the same `Alignment` class:
 
105
 
106
  ```python
107
  from alignment import load_plasma_pf
108
 
109
- model = load_plasma_pf() # Alignment(eta='hinge', omega='sinkhorn', ...)
 
 
 
110
  ```
111
 
112
- It accepts the same forward signature as the trained heads above.
 
113
 
114
  ## Available variants & evaluation results
115
 
116
- Numbers below are 3-seed averages (mean ± std) reported in the paper. The seven
117
- backbone columns correspond to the seven subfolders under each task.
 
118
 
119
  ### Interpolation (in-distribution test split)
120
 
121
  | Task | Metric | Ankh | ESM-2 | ProstT5 | ProtBERT | ProtSSN | ProtT5 | TM-Vec |
122
  | --- | --- | ---: | ---: | ---: | ---: | ---: | ---: | ---: |
123
- | **Motif** | ROC-AUC | .925 ± .002 | .933 ± .005 | .954 ± .002 | .854 ± .003 | .922 ± .002 | **.972 ± .001** | .910 ± .003 |
124
- | | F1-Max | .885 ± .002 | .877 ± .005 | .885 ± .003 | .784 ± .002 | .866 ± .002 | **.918 ± .003** | .853 ± .003 |
125
- | | PR-AUC | .921 ± .002 | .931 ± .004 | .953 ± .003 | .872 ± .003 | .920 ± .002 | **.971 ± .002** | .914 ± .003 |
126
- | | Label Match Score | .921 ± .004 | .890 ± .008 | .929 ± .001 | .746 ± .007 | .767 ± .008 | **.937 ± .001** | .792 ± .008 |
127
- | **Binding Site** | ROC-AUC | **.995 ± .000** | .992 ± .000 | .993 ± .001 | .981 ± .001 | .992 ± .001 | .993 ± .000 | .980 ± .001 |
128
- | | F1-Max | .987 ± .001 | .986 ± .001 | .983 ± .001 | .948 ± .002 | .982 ± .001 | **.988 ± .001** | .970 ± .001 |
129
- | | PR-AUC | **.996 ± .001** | .994 ± .001 | .995 ± .001 | .985 ± .001 | .993 ± .001 | .995 ± .000 | .984 ± .001 |
130
- | | Label Match Score | **.951 ± .002** | .950 ± .002 | **.951 ± .002** | .880 ± .008 | .872 ± .005 | **.951 ± .001** | .900 ± .004 |
131
- | **Active Site** | ROC-AUC | **.994 ± .001** | .991 ± .001 | .993 ± .001 | .986 ± .001 | .992 ± .001 | **.994 ± .001** | .991 ± .001 |
132
- | | F1-Max | **.989 ± .001** | .985 ± .001 | .987 ± .001 | .967 ± .001 | .987 ± .001 | .987 ± .001 | .982 ± .001 |
133
- | | PR-AUC | **.994 ± .001** | .992 ± .001 | **.994 ± .001** | .988 ± .001 | **.994 ± .001** | **.994 ± .001** | .992 ± .001 |
134
- | | Label Match Score | **.975 ± .001** | .969 ± .002 | **.975 ± .001** | .904 ± .003 | .885 ± .013 | .972 ± .001 | .938 ± .001 |
135
 
136
  ### Extrapolation (held-out hard test split)
137
 
138
  | Task | Metric | Ankh | ESM-2 | ProstT5 | ProtBERT | ProtSSN | ProtT5 | TM-Vec |
139
  | --- | --- | ---: | ---: | ---: | ---: | ---: | ---: | ---: |
140
- | **Motif** | ROC-AUC | .960 ± .011 | .972 ± .010 | **.975 ± .009** | .870 ± .030 | .949 ± .013 | .968 ± .012 | .954 ± .013 |
141
- | | F1-Max | .915 ± .021 | **.931 ± .016** | .926 ± .020 | .799 ± .039 | .896 ± .023 | .922 ± .023 | .903 ± .026 |
142
- | | PR-AUC | .948 ± .020 | **.970 ± .010** | .969 ± .016 | .873 ± .036 | .940 ± .020 | .962 ± .018 | .944 ± .022 |
143
- | | Label Match Score | **.842 ± .025** | .786 ± .032 | .801 ± .022 | .541 ± .060 | .537 ± .025 | .738 ± .028 | .704 ± .020 |
144
- | **Binding Site** | ROC-AUC | .995 ± .005 | **.999 ± .001** | .993 ± .005 | .951 ± .014 | **.999 ± .001** | **.999 ± .001** | .990 ± .008 |
145
- | | F1-Max | .992 ± .005 | .991 ± .005 | .985 ± .009 | .896 ± .019 | .988 ± .006 | **.996 ± .003** | .983 ± .011 |
146
- | | PR-AUC | .997 ± .003 | **.999 ± .001** | .995 ± .003 | .958 ± .012 | .998 ± .001 | **.999 ± .000** | .992 ± .006 |
147
- | | Label Match Score | .894 ± .026 | .851 ± .031 | .891 ± .029 | .603 ± .041 | .753 ± .041 | **.902 ± .019** | .824 ± .031 |
148
- | **Active Site** | ROC-AUC | .995 ± .002 | .996 ± .003 | .996 ± .003 | .980 ± .004 | .997 ± .001 | **.999 ± .000** | .995 ± .002 |
149
- | | F1-Max | **.992 ± .002** | .986 ± .004 | .991 ± .004 | .950 ± .005 | .991 ± .003 | .991 ± .002 | .985 ± .003 |
150
- | | PR-AUC | .995 ± .003 | .997 ± .002 | .997 ± .002 | .984 ± .003 | .998 ± .001 | **.999 ± .000** | .996 ± .002 |
151
- | | Label Match Score | **.938 ± .014** | .882 ± .027 | .931 ± .026 | .697 ± .019 | .737 ± .011 | .893 ± .017 | .880 ± .023 |
152
 
153
  Each subfolder also contains a `metadata.json` with the full hyperparameter
154
  config in machine-readable form.
 
28
 
29
  This repository hosts the trained **PLASMA** heads for every (task, backbone)
30
  combination from the paper, plus instructions for the parameter-free
31
+ **PLASMA-PF** baseline (which has no learned weights).
 
32
 
33
+ - **Paper:** <https://arxiv.org/abs/2510.11752>
34
  - **Code:** <https://github.com/ZW471/PLASMA-Protein-Local-Alignment>
35
  - **License:** MIT
36
 
 
60
  parameter-free Sinkhorn iteration (`temperature=0.1`, `n_iters=20`). The
61
  checkpoint files are ~3 MB each.
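
For readers unfamiliar with the Sinkhorn step named above, here is a minimal pure-Python sketch of temperature-scaled Sinkhorn normalization. It is not the PLASMA implementation: the function name `sinkhorn`, the toy score matrix, and the choice to end each iteration on a column step are assumptions made for illustration. Only `temperature=0.1` and `n_iters=20` come from the README.

```python
import math


def sinkhorn(scores, temperature=0.1, n_iters=20):
    """Alternately normalize rows and columns of exp(scores / temperature)."""
    plan = [[math.exp(s / temperature) for s in row] for row in scores]
    n_rows, n_cols = len(plan), len(plan[0])
    for _ in range(n_iters):
        # Row step: rescale every row so that it sums to 1.
        for i in range(n_rows):
            total = sum(plan[i])
            plan[i] = [v / total for v in plan[i]]
        # Column step: rescale every column so that it sums to 1.
        for j in range(n_cols):
            total = sum(plan[i][j] for i in range(n_rows))
            for i in range(n_rows):
                plan[i][j] /= total
    return plan


# Hypothetical 3x3 similarity scores with a strong diagonal.
scores = [
    [1.0, 0.2, 0.1],
    [0.3, 0.9, 0.2],
    [0.1, 0.4, 0.8],
]
plan = sinkhorn(scores)
row_sums = [sum(row) for row in plan]
col_sums = [sum(row[j] for row in plan) for j in range(3)]
```

After the final column step the columns sum to 1 up to floating-point error, while the rows are only approximately normalized. With a fixed iteration budget the result is therefore near, not exactly, doubly stochastic.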
62
 
63
+ ## How to use
64
 
65
+ PLASMA is a *head*: it consumes per-residue embeddings from a frozen protein
66
+ language model and returns a soft alignment matrix between two
67
+ sub-structures. The end-to-end pipeline is therefore three steps:
68
+
69
+ 1. Embed each protein with the backbone the head was trained on (one of the
70
+ seven listed above).
71
+ 2. Run the PLASMA head on the (residue × residue) embeddings to get a soft
72
+ alignment matrix `M ∈ [0, 1]^{n_q × n_c}`.
73
+ 3. Optionally reduce `M` to a scalar similarity score with
74
+ `utils.alignment_score`.
75
+
76
+ ### 1. Install
77
 
78
  ```bash
79
  git clone https://github.com/ZW471/PLASMA-Protein-Local-Alignment
 
81
  uv sync
82
  ```
83
 
84
+ The `Alignment` class and the `load_plasma` helper live in the `alignment`
85
+ package shipped by that repo.
86
+
87
+ ### 2. Load a trained head
88
 
89
  ```python
 
90
  from alignment import load_plasma
91
 
92
+ # task ∈ {"active_site", "binding_site", "motif"}
93
+ # backbone is the PLM whose embeddings the head was trained on
94
+ model = load_plasma(task="active_site", backbone="esm2_t33_650M_UR50D")
95
  model.eval()
96
+ ```
97
+
98
+ `load_plasma` downloads the matching `config.json` + `model.safetensors` from
99
+ this repo via `huggingface_hub` and rebuilds the `Alignment` module.
100
+
101
+ ### 3. Compute embeddings with the matching backbone
102
+
103
+ PLASMA does not embed sequences itself. The example below shows how to do it
104
+ with **ESM-2** via `transformers`; the same pattern works for any other
105
+ backbone (`Ankh`, `ProstT5`, `ProtBERT`, `ProtT5`, `TM-Vec`, `ProtSSN` —
106
+ their loaders are documented in `embed.py` in the GitHub repo).
107
+
108
+ ```python
109
+ import torch
110
+ from transformers import AutoTokenizer, AutoModel
111
+
112
+ device = "cuda" if torch.cuda.is_available() else "cpu"
113
+
114
+ tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
115
+ backbone = AutoModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device).eval()
116
+
117
+ @torch.no_grad()
118
+ def embed(sequence: str) -> torch.Tensor:
119
+     """Return per-residue embeddings of shape (L, 1280), without special tokens."""
120
+     tokens = tokenizer(sequence, return_tensors="pt", add_special_tokens=True).to(device)
121
+     h = backbone(**tokens).last_hidden_state[0]  # (L+2, 1280): <cls> ... <eos>
122
+     return h[1:-1].cpu()  # drop <cls> and <eos>
123
 
124
+ seq_q = "MKTAYIAKQRQISFVKSHFSRQDILDLWIYHTQGYFP"
125
+ seq_c = "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNI"
126
+
127
+ H_q = embed(seq_q) # (n_q, 1280)
128
+ H_c = embed(seq_c) # (n_c, 1280)
129
+ ```
130
+
131
+ ### 4. Run PLASMA and read the alignment matrix
132
+
133
+ ```python
134
+ # `batch_q` / `batch_c` assign each residue to a sample. Use zeros for a
135
+ # single pair; use [0, 0, ..., 1, 1, ...] to score multiple pairs in one batch.
136
+ batch_q = torch.zeros(H_q.size(0), dtype=torch.long)
137
+ batch_c = torch.zeros(H_c.size(0), dtype=torch.long)
138
 
139
  with torch.no_grad():
140
+     M = model(H_q, H_c, batch_q, batch_c)  # (n_q, n_c) in [0, 1]
141
+
142
+ # Hard residue-residue assignment (top of column / row in the transport plan)
143
+ q_to_c = M.argmax(dim=1) # for each query residue, the best candidate residue
144
+ c_to_q = M.argmax(dim=0) # for each candidate residue, the best query residue
145
  ```
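
The batch vectors can be derived from per-protein residue counts. The sketch below assumes the concatenate-and-index convention implied by the `[0, 0, ..., 1, 1, ...]` comment, where the residues of several proteins are stacked along the first dimension. The helper name `make_batch_index` and the lengths are hypothetical.

```python
def make_batch_index(lengths):
    """Map per-protein residue counts to a flat sample index, e.g. [2, 3] -> [0, 0, 1, 1, 1]."""
    return [sample for sample, n in enumerate(lengths) for _ in range(n)]


# Two hypothetical query proteins with 3 and 2 residues.
batch_q_list = make_batch_index([3, 2])
# Wrap with torch.tensor(batch_q_list, dtype=torch.long) before passing it to the model.
```

The embeddings of the individual proteins would then be concatenated in the same order, so that row `k` of the stacked embedding tensor belongs to sample `batch_q_list[k]`.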
146
 
147
+ `M` is a (near-)doubly-stochastic transport plan: rows and columns each sum
148
+ to ~1, so `M[i, j]` is the soft probability that query residue `i` aligns to
149
+ candidate residue `j`. Thresholding at `0.5` gives a sparse local alignment;
150
+ plotting `M` as a heatmap gives the canonical PLASMA visualisation (the
151
+ diagonal stripe in the visual abstract above).
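
As a minimal sketch of the thresholding step, the pure-Python example below extracts hard residue pairs from a toy matrix. The matrix values and the helper name `hard_pairs` are made up for illustration, and this is not the scoring routine from the repository.

```python
# Toy 3x4 soft alignment matrix standing in for the PLASMA output M
# (values are invented; each row sums to 1).
M_toy = [
    [0.90, 0.05, 0.03, 0.02],
    [0.10, 0.70, 0.15, 0.05],
    [0.05, 0.30, 0.35, 0.30],
]


def hard_pairs(matrix, threshold=0.5):
    """Return (query_idx, candidate_idx) pairs whose best score exceeds threshold."""
    pairs = []
    for i, row in enumerate(matrix):
        j = max(range(len(row)), key=row.__getitem__)  # argmax over candidates
        if row[j] > threshold:
            pairs.append((i, j))
    return pairs


aligned = hard_pairs(M_toy)  # the third query residue has no confident match
```

When no row sums to more than 1, at most one entry per row can exceed `0.5`, so taking the row argmax before thresholding loses nothing.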
152
+
153
+ ### 5. Reduce to a similarity score
154
+
155
+ To collapse the alignment matrix into a single number per protein pair (the
156
+ quantity used to compute ROC-AUC / F1-Max in the tables above), use
157
+ `utils.alignment_score` from the GitHub repo. It applies the diagonal
158
+ convolution + thresholding described in the paper:
159
+
160
+ ```python
161
+ from utils.alignment_utils import alignment_score
162
+
163
+ score = alignment_score(
164
+     H_q, H_c, M, batch_c,
165
+     threshold=0.5,  # gating on max-row / max-col residues
166
+     K=10,  # diagonal-convolution window
167
+ ) # -> shape (num_pairs_in_batch,), here (1,)
168
+ print(float(score))
169
+ ```
170
 
171
  ## PLASMA-PF (parameter-free)
172
 
173
+ PLASMA-PF is a hinge / Sinkhorn baseline with **no learned weights**. Use it
174
+ when you want a strong zero-training baseline on top of any backbone — there
175
+ is nothing to download:
176
 
177
  ```python
178
  from alignment import load_plasma_pf
179
 
180
+ model = load_plasma_pf() # Alignment(eta='hinge', omega='sinkhorn', ...)
181
+
182
+ with torch.no_grad():
183
+     M_pf = model(H_q, H_c, batch_q, batch_c)
184
  ```
185
 
186
+ It accepts the same forward signature as the trained heads above and pairs
187
+ with any of the seven supported backbones.
188
 
189
  ## Available variants & evaluation results
190
 
191
+ Numbers below are 3-seed averages reported in the paper. The seven backbone
192
+ columns correspond to the seven subfolders under each task. **Bold** marks the
193
+ best backbone for each row.
194
 
195
  ### Interpolation (in-distribution test split)
196
 
197
  | Task | Metric | Ankh | ESM-2 | ProstT5 | ProtBERT | ProtSSN | ProtT5 | TM-Vec |
198
  | --- | --- | ---: | ---: | ---: | ---: | ---: | ---: | ---: |
199
+ | **Motif** | ROC-AUC | .925 | .933 | .954 | .854 | .922 | **.972** | .910 |
200
+ | | F1-Max | .885 | .877 | .885 | .784 | .866 | **.918** | .853 |
201
+ | | PR-AUC | .921 | .931 | .953 | .872 | .920 | **.971** | .914 |
202
+ | | Label Match Score | .921 | .890 | .929 | .746 | .767 | **.937** | .792 |
203
+ | **Binding Site** | ROC-AUC | **.995** | .992 | .993 | .981 | .992 | .993 | .980 |
204
+ | | F1-Max | .987 | .986 | .983 | .948 | .982 | **.988** | .970 |
205
+ | | PR-AUC | **.996** | .994 | .995 | .985 | .993 | .995 | .984 |
206
+ | | Label Match Score | **.951** | .950 | **.951** | .880 | .872 | **.951** | .900 |
207
+ | **Active Site** | ROC-AUC | **.994** | .991 | .993 | .986 | .992 | **.994** | .991 |
208
+ | | F1-Max | **.989** | .985 | .987 | .967 | .987 | .987 | .982 |
209
+ | | PR-AUC | **.994** | .992 | **.994** | .988 | **.994** | **.994** | .992 |
210
+ | | Label Match Score | **.975** | .969 | **.975** | .904 | .885 | .972 | .938 |
211
 
212
  ### Extrapolation (held-out hard test split)
213
 
214
  | Task | Metric | Ankh | ESM-2 | ProstT5 | ProtBERT | ProtSSN | ProtT5 | TM-Vec |
215
  | --- | --- | ---: | ---: | ---: | ---: | ---: | ---: | ---: |
216
+ | **Motif** | ROC-AUC | .960 | .972 | **.975** | .870 | .949 | .968 | .954 |
217
+ | | F1-Max | .915 | **.931** | .926 | .799 | .896 | .922 | .903 |
218
+ | | PR-AUC | .948 | **.970** | .969 | .873 | .940 | .962 | .944 |
219
+ | | Label Match Score | **.842** | .786 | .801 | .541 | .537 | .738 | .704 |
220
+ | **Binding Site** | ROC-AUC | .995 | **.999** | .993 | .951 | **.999** | **.999** | .990 |
221
+ | | F1-Max | .992 | .991 | .985 | .896 | .988 | **.996** | .983 |
222
+ | | PR-AUC | .997 | **.999** | .995 | .958 | .998 | **.999** | .992 |
223
+ | | Label Match Score | .894 | .851 | .891 | .603 | .753 | **.902** | .824 |
224
+ | **Active Site** | ROC-AUC | .995 | .996 | .996 | .980 | .997 | **.999** | .995 |
225
+ | | F1-Max | **.992** | .986 | .991 | .950 | .991 | .991 | .985 |
226
+ | | PR-AUC | .995 | .997 | .997 | .984 | .998 | **.999** | .996 |
227
+ | | Label Match Score | **.938** | .882 | .931 | .697 | .737 | .893 | .880 |
228
 
229
  Each subfolder also contains a `metadata.json` with the full hyperparameter
230
  config in machine-readable form.