# LTAF ECG Rhythm Classifier: RhythmResNet1D + TTA
A from-scratch 1D-ResNet trained on PhysioNet's Long-Term Atrial Fibrillation (LTAF) database for 6-class rhythm classification on two-lead 128 Hz ECG.
| Metric | Single-window | + 7-view TTA (recommended) |
|---|---|---|
| Test accuracy | 0.636 | 0.684 |
| Test balanced accuracy | 0.740 | 0.778 |
| Test macro F1 | 0.614 | 0.656 |
vs. the frozen Chronos-2 + MLP baseline on the same 6-class subset (test macro F1 = 0.299): +36 pp, i.e. 2.2× the F1.
Per-class F1 (TTA-7): NSR 0.76, AFIB 0.62, SBR 0.82, AB 0.77, SVTA 0.15, B 0.82.
## Classes
| Code | Expansion |
|---|---|
| NSR | Normal sinus rhythm |
| AFIB | Atrial fibrillation |
| SBR | Sinus bradycardia (<60 bpm, sinus origin) |
| AB | Atrial bigeminy (every other beat is an APC) |
| SVTA | Supraventricular tachyarrhythmia (≥3 consecutive SV ectopics at >100 bpm) |
| B | Ventricular bigeminy (every other beat is a PVC) |
VT, T, and IVR are excluded; their LTAF test supports (31, 26, 1) are too small for stable F1 estimation.
## Quickstart

```bash
pip install torch huggingface_hub numpy
```

```python
import torch
from huggingface_hub import hf_hub_download

from model import RhythmResNet1D, RHYTHM_CLASS_NAMES

# Download the checkpoint from the HF Hub (model.py ships in the same repo)
ckpt = hf_hub_download("rmxjck/ltaf-ecg-rhythm-classifier", "best_classifier.pt")
model = RhythmResNet1D.load(ckpt, device="cuda")
model.eval()

# Input: (B, 2, 1280), i.e. 10 s @ 128 Hz, 2 leads, per-channel z-scored.
x = torch.randn(1, 2, 1280).cuda()  # replace with real ECG
with torch.no_grad():
    logits = model(x)
pred_idx = logits.argmax(-1).item()
print(model.class_names[pred_idx])
```
For best results, use the 7-view TTA wrapper in `inference.py`
(it averages softmax outputs across 7 random window-start offsets, adding ~4 pp macro F1
at the cost of 7× inference compute):

```bash
python inference.py
```
## Architecture

`RhythmResNet1D(num_classes=6, n_channels=2, base_channels=64, blocks_per_stage=2)`:

- Stem: Conv1d(2, 64, k=15, stride=2) → BN → ReLU → MaxPool(2).
- 4 ResNet stages × 2 basic blocks (Conv1d k=7, BN, ReLU, Dropout, + skip). Channels: 64 → 128 → 256 → 512. Time downsamples 2× at the start of each stage after the first.
- Head: AdaptiveAvgPool1d → Linear(512 → 128) → ReLU → Dropout(0.2) → Linear(128 → 6).
- Total parameters: 8,794,246.
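A quick sanity check on the temporal geometry: five strided ops halve the window length (the stem conv, the max-pool, and the first block of stages 2-4). Assuming padding that otherwise preserves length (the card does not state the padding), a 1280-sample window shrinks like this:

```python
def trunk_lengths(T=1280):
    """Trace the temporal length of a window through the trunk.

    Assumes 'same'-style padding, so only strided ops change length.
    """
    T //= 2              # stem Conv1d, stride 2
    T //= 2              # MaxPool(2)
    lens = [T]           # stage 1: no downsampling
    for _ in range(3):   # stages 2-4 each halve time at their first block
        T //= 2
        lens.append(T)
    return lens

print(trunk_lengths())  # per-stage output lengths: [320, 160, 80, 40]
```

So the global average pool at the head sees 40 time steps of 512-channel features.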
## Input format

- Shape `(B, 2, 1280)`, `float32`
- 2-lead ECG at 128 Hz (LTAF leads `ECG1`, `ECG2`)
- 10 s window
- Per-channel z-scored: `(x - x.mean(axis=-1, keepdims=True)) / x.std(axis=-1, keepdims=True)`
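A minimal preprocessing sketch in numpy; the epsilon guard is an assumption (the card does not say how flat segments are handled):

```python
import numpy as np

def zscore_per_channel(x: np.ndarray) -> np.ndarray:
    """Z-score each lead independently along the time axis."""
    mu = x.mean(axis=-1, keepdims=True)
    sd = x.std(axis=-1, keepdims=True)
    return (x - mu) / (sd + 1e-8)  # eps (assumed) avoids division by zero on flat leads

window = np.random.randn(2, 1280).astype(np.float32)  # one 10 s, 2-lead window
z = zscore_per_channel(window)
```

`keepdims=True` matters: without it, `(2, 1280) - (2,)` fails to broadcast.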
## Test-time augmentation (TTA)

Pass a longer signal slice (≥1280 samples) to predict_tta() and it
samples 7 random 10 s windows, averages the softmax outputs, then
argmaxes. Why it helps: training samples a random window start within
each rhythm bout, so the model learns to be invariant to that shift;
averaging predictions over several shifts at eval time cancels the
position-specific noise. Net effect: +4.2 pp test macro F1, no retraining.
```python
# long_signal: (2, 30 * 128) array, i.e. a 30 s two-lead slice
cls, prob, full_probs = predict_tta(model, long_signal, n_views=7, device="cuda")
```
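`predict_tta` itself lives in `inference.py`; the idea can be sketched model-agnostically. The function below is illustrative only, not the repo's implementation, and assumes any callable that maps a z-scored `(2, 1280)` window to a logit vector:

```python
import numpy as np

WIN = 1280  # 10 s @ 128 Hz

def tta_predict_sketch(predict_logits, signal, n_views=7, seed=None):
    """Average softmax over n_views random 10 s windows of a (2, T) signal, T >= WIN."""
    rng = np.random.default_rng(seed)
    T = signal.shape[-1]
    mean_probs = None
    for _ in range(n_views):
        start = rng.integers(0, T - WIN + 1)  # random window-start offset
        w = signal[:, start:start + WIN]
        w = (w - w.mean(-1, keepdims=True)) / (w.std(-1, keepdims=True) + 1e-8)
        logits = predict_logits(w)
        e = np.exp(logits - logits.max())     # numerically stable softmax
        p = e / e.sum()
        mean_probs = p if mean_probs is None else mean_probs + p
    mean_probs /= n_views
    return int(mean_probs.argmax()), mean_probs
```

Averaging probabilities (not logits, not hard votes) is what the card describes; it keeps a well-calibrated ensemble of the 7 shifted views.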
## Training recipe

```bash
.venv/bin/python scripts/train_ecg_rhythm_scratch.py \
    --arch resnet1d --window-sizes 10 \
    --epochs 30 --batch-size 64 --lr 5e-4 \
    --base-channels 64 \
    --use-val-as-train \
    --classes NSR AFIB SBR AB SVTA B \
    --output-dir results/ecg_classifier/sweep/c6_resnet1d_w10_e30_wide
```
- Dataset: LTAF train+val combined (75 records). 8 records held out for early stopping. Test (9 records, 3,716 windows) untouched.
- Loss: weighted cross-entropy with sqrt-dampened inverse-frequency class weights (cap 10), label smoothing 0.1.
- Cosine LR schedule from 5e-4 → 0 over 30 epochs. AdamW (wd 1e-4).
- Best checkpoint by held-out macro F1.
- Training time on a single H100 80GB: ~6 minutes.
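One plausible reading of the class-weighting step; only "sqrt-dampened inverse-frequency, cap 10" is stated on the card, so the exact formula and the mean-1 renormalization below are assumptions:

```python
import numpy as np

def class_weights(counts, cap=10.0):
    """Sqrt-dampened inverse-frequency class weights, clipped at `cap`.

    The mean-1 renormalization is an assumption for illustration.
    """
    counts = np.asarray(counts, dtype=np.float64)
    w = np.sqrt(counts.sum() / counts)  # inverse frequency, sqrt-dampened
    w = np.minimum(w, cap)              # cap extreme weights for rare classes
    return w / w.mean()

# Training supports are not listed on the card; the test-set supports
# (NSR, AFIB, SBR, AB, SVTA, B) just illustrate the shape of the weighting.
w = class_weights([1824, 1179, 305, 250, 60, 98])
```

The square root dampens the weight ratio between frequent and rare classes, and the cap keeps a very rare class (like SVTA) from dominating the gradient.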
Source repo: `scripts/train_ecg_rhythm_scratch.py` and `src/models/ts_llm/ecg_rhythm_scratch.py` in `rmxjck/TSLM-Arena`.
## Test set details

LTAF held-out split (deterministic seed 42, record-level): 9 records
(100, 104, 105, 11, 200, 32, 48, 49, 68), 3,716 windows.
Confusion matrix (rows = true, cols = pred), single-window; its derived accuracy (0.636), balanced accuracy (0.740), and macro F1 (0.614) match the single-window column above:
| | NSR | AFIB | SBR | AB | SVTA | B |
|---|---|---|---|---|---|---|
| NSR | 1109 | 286 | 95 | 114 | 185 | 35 |
| AFIB | 189 | 628 | 25 | 29 | 294 | 14 |
| SBR | 26 | 0 | 279 | 0 | 0 | 0 |
| AB | 9 | 13 | 0 | 225 | 3 | 0 |
| SVTA | 9 | 14 | 0 | 3 | 34 | 0 |
| B | 4 | 0 | 0 | 1 | 3 | 90 |
Per-class supports: NSR 1824, AFIB 1179, SBR 305, AB 250, SVTA 60, B 98.
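The aggregate metrics can be recomputed from this matrix with plain numpy; the rounded results match the single-window column of the summary table:

```python
import numpy as np

# Confusion matrix above: rows = true, cols = predicted (NSR, AFIB, SBR, AB, SVTA, B)
cm = np.array([
    [1109, 286,  95, 114, 185, 35],
    [ 189, 628,  25,  29, 294, 14],
    [  26,   0, 279,   0,   0,  0],
    [   9,  13,   0, 225,   3,  0],
    [   9,  14,   0,   3,  34,  0],
    [   4,   0,   0,   1,   3, 90],
])

recall = cm.diagonal() / cm.sum(axis=1)     # per-class sensitivity
precision = cm.diagonal() / cm.sum(axis=0)
f1 = 2 * precision * recall / (precision + recall)

accuracy = cm.diagonal().sum() / cm.sum()   # 0.636
balanced_accuracy = recall.mean()           # 0.740
macro_f1 = f1.mean()                        # 0.614
```

The gap between accuracy (0.636) and balanced accuracy (0.740) reflects the skewed supports: the small classes (SBR, AB, B) have high recall while NSR and AFIB carry most of the errors.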
## What was tried and didn't help

This model was the best of 30+ experiments. What did not improve over this baseline:
- HRV side-channel input (8-dim RR-derived features fused with CNN trunk): hurts F1 by 3-8 pp because the CNN already extracts equivalent information from raw QRS timing.
- Cross-corpus augmentation (MIT-BIH AFDB added to training): hurts AFIB F1 by 14 pp because AFDB's clean AFIB blocks bias the model toward over-calling AFIB on LTAF's paroxysmal transitions.
- Wider models (96-channel, 12 M params): overfits.
- Longer training (50 epochs): overfits.
- Multi-model soft-voting ensembles: members make correlated errors.
- Focal loss: matches CE within noise.
- Multi-scale training (5 / 10 / 30 s windows): underperforms 10 s alone.
- Bigger external models (torchecg ResNet-50 51.9 M, Stanford 27 M): underperform a 2.2 M home-rolled ResNet1D at 12 epochs.
## Not for clinical use
Research artifact only. Not FDA-cleared. Not suitable for triage, diagnosis, or any patient-facing application. Uses the LTAF benchmark which has known label noise from its original PhysioNet curation.
## Citation

```bibtex
@misc{petrutiu2008ltafdb,
  title = {Abrupt Changes in Fibrillatory Wave Characteristics at the Termination of Paroxysmal Atrial Fibrillation in Humans},
  author = {Petrutiu, Simona and Sahakian, Alan V. and Swiryn, Steven},
  year = {2008},
  howpublished = {PhysioNet},
  url = {https://physionet.org/content/ltafdb/}
}
```