LTAF ECG Rhythm Classifier: RhythmResNet1D + TTA

A from-scratch 1D-ResNet trained on PhysioNet's Long-Term Atrial Fibrillation (LTAF) database for 6-class rhythm classification on two-lead 128 Hz ECG.

Metric                   Single-window   + 7-view TTA (recommended)
Test accuracy            0.636           0.684
Test balanced accuracy   0.740           0.778
Test macro F1            0.614           0.656

For comparison, a frozen Chronos-2 + MLP baseline on the same 6-class subset reaches test macro F1 = 0.299, i.e. this model gains +36 pp absolute (2.2× the baseline F1).

Per-class F1 (TTA-7): NSR 0.76, AFIB 0.62, SBR 0.82, AB 0.77, SVTA 0.15, B 0.82.

Classes

Code   Expansion
NSR    Normal sinus rhythm
AFIB   Atrial fibrillation
SBR    Sinus bradycardia (<60 bpm, sinus origin)
AB     Atrial bigeminy (every other beat is an APC)
SVTA   Supraventricular tachyarrhythmia (≥3 consecutive SV ectopic beats at >100 bpm)
B      Ventricular bigeminy (every other beat is a PVC)

VT, T, and IVR are excluded: their LTAF test supports (31, 26, and 1 windows) are too small for stable F1 estimation.

Quickstart

pip install torch huggingface_hub numpy

import torch
from huggingface_hub import hf_hub_download
from model import RhythmResNet1D, RHYTHM_CLASS_NAMES  # model.py from the model repo must be importable

# Download the best checkpoint from the Hub
ckpt = hf_hub_download("rmxjck/ltaf-ecg-rhythm-classifier", "best_classifier.pt")
model = RhythmResNet1D.load(ckpt, device="cuda")
model.eval()

# Input shape (B, 2, 1280): 10 s @ 128 Hz, 2 leads, per-channel z-scored.
x = torch.randn(1, 2, 1280).cuda()  # replace with real ECG
with torch.no_grad():
    logits = model(x)
    pred_idx = logits.argmax(-1).item()
print(model.class_names[pred_idx])  # or RHYTHM_CLASS_NAMES[pred_idx]

For best results, use the 7-view TTA wrapper in inference.py (it averages the softmax across 7 random window-start offsets, adding ~4 pp test macro F1 at the cost of 7× inference compute).

python inference.py

Architecture

RhythmResNet1D(num_classes=6, n_channels=2, base_channels=64, blocks_per_stage=2):

  • Stem: Conv1d(2, 64, k=15, stride=2) → BN → ReLU → MaxPool(2).
  • 4 ResNet stages × 2 basic blocks each (Conv1d k=7, BN, ReLU, Dropout, + skip); channels 64 → 128 → 256 → 512; time downsamples 2× at the start of each stage after the first (see the sketch below).
  • Head: AdaptiveAvgPool1d → Linear(512 → 128) → ReLU → Dropout(0.2) → Linear(128 → 6).
  • Total parameters: 8,794,246.
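
For orientation, here is a minimal PyTorch sketch of the same stem / stages / head layout. It follows the bullets above, but the padding, dropout placement, skip-projection details, and the class name RhythmResNet1dSketch are assumptions, not the shipped model.py.

import torch
import torch.nn as nn

class BasicBlock1d(nn.Module):
    """Two k=7 convolutions with BN, ReLU, dropout, and a residual skip connection."""
    def __init__(self, in_ch, out_ch, stride=1, dropout=0.1):
        super().__init__()
        self.conv1 = nn.Conv1d(in_ch, out_ch, kernel_size=7, stride=stride, padding=3, bias=False)
        self.bn1 = nn.BatchNorm1d(out_ch)
        self.conv2 = nn.Conv1d(out_ch, out_ch, kernel_size=7, padding=3, bias=False)
        self.bn2 = nn.BatchNorm1d(out_ch)
        self.drop = nn.Dropout(dropout)
        self.relu = nn.ReLU(inplace=True)
        # 1x1 projection on the skip path whenever the shape changes
        self.skip = (nn.Sequential(nn.Conv1d(in_ch, out_ch, 1, stride=stride, bias=False),
                                   nn.BatchNorm1d(out_ch))
                     if (stride != 1 or in_ch != out_ch) else nn.Identity())

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.drop(self.bn2(self.conv2(out)))
        return self.relu(out + self.skip(x))

class RhythmResNet1dSketch(nn.Module):
    def __init__(self, num_classes=6, n_channels=2, base_channels=64, blocks_per_stage=2):
        super().__init__()
        # Stem: Conv1d(2, 64, k=15, stride=2) -> BN -> ReLU -> MaxPool(2)
        self.stem = nn.Sequential(
            nn.Conv1d(n_channels, base_channels, kernel_size=15, stride=2, padding=7, bias=False),
            nn.BatchNorm1d(base_channels), nn.ReLU(inplace=True), nn.MaxPool1d(2))
        widths = [base_channels * m for m in (1, 2, 4, 8)]       # 64 -> 128 -> 256 -> 512
        blocks, in_ch = [], base_channels
        for i, w in enumerate(widths):
            for b in range(blocks_per_stage):
                stride = 2 if (i > 0 and b == 0) else 1           # downsample 2x at the start of stages 2-4
                blocks.append(BasicBlock1d(in_ch, w, stride=stride))
                in_ch = w
        self.stages = nn.Sequential(*blocks)
        # Head: global average pool -> Linear(512, 128) -> ReLU -> Dropout(0.2) -> Linear(128, 6)
        self.head = nn.Sequential(nn.AdaptiveAvgPool1d(1), nn.Flatten(),
                                  nn.Linear(widths[-1], 128), nn.ReLU(inplace=True),
                                  nn.Dropout(0.2), nn.Linear(128, num_classes))

    def forward(self, x):                                         # x: (B, 2, 1280)
        return self.head(self.stages(self.stem(x)))               # logits: (B, 6)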

Input format

  • (B, 2, 1280) float32
  • 2-lead ECG at 128 Hz (LTAF leads ECG1, ECG2)
  • 10 s window
  • Per-channel z-scored: (x - x.mean(axis=-1, keepdims=True)) / x.std(axis=-1, keepdims=True)
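
A minimal sketch of that windowing and normalization step, assuming a raw (2, N) float array at 128 Hz (the helper name make_window is illustrative; the actual preprocessing lives in the repo's scripts):

import numpy as np
import torch

def make_window(sig: np.ndarray, start: int = 0, fs: int = 128, seconds: int = 10) -> torch.Tensor:
    """Cut one 10 s, 2-lead window and z-score each channel independently."""
    n = fs * seconds                                           # 1280 samples
    win = sig[:, start:start + n].astype(np.float32)           # (2, 1280)
    mean = win.mean(axis=-1, keepdims=True)
    std = win.std(axis=-1, keepdims=True) + 1e-8               # guard against flat leads
    return torch.from_numpy((win - mean) / std).unsqueeze(0)   # (1, 2, 1280), ready for the model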

Test-time augmentation (TTA)

Pass a longer signal slice (≥1280 samples) to predict_tta() and it samples 7 random 10 s windows, averages the softmax outputs, then takes the argmax. Why it helps: training uses random window-start sampling within each rhythm bout, so the model learns to be invariant to that shift; averaging over multiple shifts at eval time cancels the position-specific noise. +4.2 pp test macro F1, no retraining.

# (2, 30*128) signal, 30 s long
cls, prob, full_probs = predict_tta(model, long_signal, n_views=7, device="cuda")
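
A hedged sketch of what that wrapper does; the real predict_tta in inference.py may differ in window sampling, normalization, and return signature.

import numpy as np
import torch
import torch.nn.functional as F

@torch.no_grad()
def predict_tta_sketch(model, signal, n_views=7, fs=128, seconds=10, device="cuda"):
    """Average softmax over n_views random 10 s windows of a (2, N) signal, N >= 1280."""
    n = fs * seconds
    starts = np.random.randint(0, signal.shape[-1] - n + 1, size=n_views)
    probs = []
    for s in starts:
        win = signal[:, s:s + n].astype(np.float32)
        win = (win - win.mean(axis=-1, keepdims=True)) / (win.std(axis=-1, keepdims=True) + 1e-8)
        x = torch.from_numpy(win).unsqueeze(0).to(device)
        probs.append(F.softmax(model(x), dim=-1))
    mean_probs = torch.stack(probs).mean(dim=0).squeeze(0)     # (num_classes,)
    pred = int(mean_probs.argmax())
    return pred, float(mean_probs[pred]), mean_probs.cpu().numpy()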

Training recipe

.venv/bin/python scripts/train_ecg_rhythm_scratch.py \
    --arch resnet1d --window-sizes 10 \
    --epochs 30 --batch-size 64 --lr 5e-4 \
    --base-channels 64 \
    --use-val-as-train \
    --classes NSR AFIB SBR AB SVTA B \
    --output-dir results/ecg_classifier/sweep/c6_resnet1d_w10_e30_wide
  • Dataset: LTAF train+val combined (75 records). 8 records held out for early stopping. Test (9 records, 3,716 windows) untouched.
  • Loss: weighted cross-entropy with sqrt-dampened inverse-frequency class weights (cap 10) and label smoothing 0.1 (see the sketch after this list).
  • Cosine LR schedule from 5e-4 → 0 over 30 epochs. AdamW (weight decay 1e-4).
  • Best checkpoint by held-out macro F1.
  • Training time on a single H100 80GB: ~6 minutes.
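
A sketch of the loss/optimizer setup implied by these bullets. The exact dampening, capping, and normalization in the training script may differ; make_class_weights and the commented usage lines are illustrative.

import numpy as np
import torch
import torch.nn as nn

def make_class_weights(train_labels: np.ndarray, num_classes: int = 6, cap: float = 10.0) -> torch.Tensor:
    """sqrt-dampened inverse-frequency class weights, capped at `cap`."""
    counts = np.bincount(train_labels, minlength=num_classes).astype(np.float64)
    weights = np.sqrt(counts.sum() / np.maximum(counts, 1))   # sqrt of inverse frequency
    weights = np.minimum(weights / weights.min(), cap)         # rarest class at most cap x the commonest
    return torch.tensor(weights / weights.mean(), dtype=torch.float32)

# criterion = nn.CrossEntropyLoss(weight=make_class_weights(train_labels).to(device), label_smoothing=0.1)
# optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4, weight_decay=1e-4)
# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30)  # 5e-4 -> 0 over 30 epochs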

Source repo: scripts/train_ecg_rhythm_scratch.py and src/models/ts_llm/ecg_rhythm_scratch.py in rmxjck/TSLM-Arena.

Test set details

LTAF held-out split (deterministic seed 42, record-level): 9 records (100, 104, 105, 11, 200, 32, 48, 49, 68), 3,716 windows.

Confusion matrix (rows = true, cols = pred), with TTA:

        NSR  AFIB  SBR   AB  SVTA    B
NSR    1109   286   95  114   185   35
AFIB    189   628   25   29   294   14
SBR      26     0  279    0     0    0
AB        9    13    0  225     3    0
SVTA      9    14    0    3    34    0
B         4     0    0    1     3   90

Per-class supports: NSR 1824, AFIB 1179, SBR 305, AB 250, SVTA 60, B 98.
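
Balanced accuracy (mean per-class recall) and macro F1 in the tables above are derived from a confusion matrix like this one; a short numpy sketch of the formulas:

import numpy as np

def summarize_confusion(cm: np.ndarray) -> dict:
    """Aggregate metrics from a (C, C) confusion matrix with rows = true, cols = predicted."""
    tp = np.diag(cm).astype(float)
    precision = tp / cm.sum(axis=0)            # per-class precision
    recall = tp / cm.sum(axis=1)               # per-class recall; its mean is balanced accuracy
    f1 = 2 * precision * recall / (precision + recall)
    return {"accuracy": tp.sum() / cm.sum(),
            "balanced_accuracy": recall.mean(),
            "macro_f1": np.nanmean(f1),
            "per_class_f1": f1}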

What was tried and didn't help

This model was the best of 30+ experiments. What did not improve over this baseline:

  • HRV side-channel input (8-dim RR-derived features fused with CNN trunk): hurts F1 by 3-8 pp because the CNN already extracts equivalent information from raw QRS timing.
  • Cross-corpus augmentation (MIT-BIH AFDB added to training): hurts AFIB F1 by 14 pp because AFDB's clean AFIB blocks bias the model toward over-calling AFIB on LTAF's paroxysmal transitions.
  • Wider models (96-channel, 12 M params): overfits.
  • Longer training (50 epochs): overfits.
  • Multi-model soft-voting ensembles: members make correlated errors.
  • Focal loss: matches CE within noise.
  • Multi-scale training (5 / 10 / 30 s windows): underperforms 10 s alone.
  • Bigger external models (torchecg ResNet-50 51.9 M, Stanford 27 M): underperform a 2.2 M home-rolled ResNet1D at 12 epochs.

Not for clinical use

Research artifact only. Not FDA-cleared. Not suitable for triage, diagnosis, or any patient-facing application. Uses the LTAF benchmark which has known label noise from its original PhysioNet curation.

Citation

@misc{petrutiu2008ltafdb,
  title         = {Abrupt Changes in Fibrillatory Wave Characteristics at the Termination of Paroxysmal Atrial Fibrillation in Humans},
  author        = {Petrutiu, Simona and Sahakian, Alan V. and Swiryn, Steven},
  year          = {2008},
  howpublished  = {PhysioNet},
  url           = {https://physionet.org/content/ltafdb/}
}