CLD: Convex Language-Detection Head for facebook/mms-1b-all (5 languages)
Convex Low-resource Accent-Robust Language Detection (CLD) head for multilingual speech recognition: a lightweight convex ReLU-MLP spoken-language classifier trained on frozen, pooled encoder embeddings from facebook/mms-1b-all. At inference time it detects which of the 5 supported languages is spoken and selects the matching MMS language adapter before decoding, which improves downstream transcription accuracy.
Model description
CLD attaches a small 5-way spoken-language-detection head to the frozen encoder of a pre-trained ASR model and uses it to pick the language before decoding. Instead of a conventionally trained neural-network head, the head is a two-layer ReLU MLP obtained by solving a convex reformulation of the training problem with CRONOS, an ADMM solver implemented in JAX that uses preconditioned conjugate gradient with Nyström preconditioning. This yields a large training speedup over a standard neural-network head while matching or exceeding its accuracy, and the approach is especially strong in low-resource and accent-diverse settings.
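The convex reformulation mentioned above is, as far as we can tell, the one introduced by Pilanci and Ergen for two-layer ReLU networks: sample P_S hyperplanes g_i, fix the activation patterns D_i = diag(1[X g_i ≥ 0]), and fit linear weights u_i, v_i for the model Σ_i D_i X (u_i − v_i) under cone constraints. The NumPy sketch below illustrates only the model side (pattern sampling, prediction, and mapping back to 2·P_S ordinary ReLU units). The function names and the Gaussian hyperplane sampling are illustrative assumptions; this is not the jaxcld or CRONOS code, and the ADMM solve of the regularized objective is omitted.

```python
import numpy as np

def sample_patterns(X, n_patterns, seed=0):
    """Sample activation patterns from random Gaussian hyperplanes g_i (illustrative).

    Returns D of shape (n, P), where column i is the diagonal of
    D_i = diag(1[X g_i >= 0]), and the hyperplanes G of shape (d, P).
    """
    rng = np.random.default_rng(seed)
    G = rng.standard_normal((X.shape[1], n_patterns))
    D = (X @ G >= 0).astype(X.dtype)
    return D, G

def convex_predict(X, D, U, V):
    """Convex model output sum_i D_i X (u_i - v_i); column i of U / V is u_i / v_i."""
    return np.sum(D * (X @ (U - V)), axis=1)

def to_relu_weights(U, V):
    """Map feasible convex weights to an equivalent ReLU net with 2P hidden units.

    Unit for u_i: input weights u_i / sqrt(||u_i||), output weight +sqrt(||u_i||).
    Unit for v_i: input weights v_i / sqrt(||v_i||), output weight -sqrt(||v_i||).
    """
    nu = np.linalg.norm(U, axis=0)
    nv = np.linalg.norm(V, axis=0)
    W1 = np.concatenate([U / np.sqrt(np.where(nu > 0, nu, 1.0)),
                         V / np.sqrt(np.where(nv > 0, nv, 1.0))], axis=1)  # (d, 2P)
    w2 = np.concatenate([np.sqrt(nu), -np.sqrt(nv)])                       # (2P,)
    return W1, w2

# Toy check with feasible weights: u_i and v_i are nonnegative multiples of g_i,
# so D_i X u_i = relu(X u_i) and the two models agree.
X = np.random.default_rng(1).standard_normal((50, 8))
D, G = sample_patterns(X, n_patterns=5, seed=2)
coef = np.random.default_rng(3).uniform(0.0, 1.0, size=(2, 5))
U = G * coef[0]
V = G * coef[1]
W1, w2 = to_relu_weights(U, V)
print(np.allclose(convex_predict(X, D, U, V), np.maximum(X @ W1, 0.0) @ w2))  # True
```

With P_S = 64 this mapping produces 128 ReLU units, consistent with the 64 → 128 relation in the Architecture section. The convex model and the ReLU network coincide only for weights satisfying the cone constraints (2D_i − I) X u_i ≥ 0 and (2D_i − I) X v_i ≥ 0, which is why the toy check builds u_i and v_i from the sampled hyperplanes.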
- Backbone (frozen): facebook/mms-1b-all
- Task: spoken language detection (5-way) → ASR language selection
- Languages (5): en (English), hi (Hindi), id (Indonesian), ms (Malay), zh (Mandarin Chinese)
- Class index → language: 0 → en, 1 → hi, 2 → id, 3 → ms, 4 → zh (labels are the sorted ISO 639-1 codes)
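Since the labels are the sorted ISO 639-1 codes, the index mapping can be rebuilt from the language list alone, for example to decode argmax indices produced outside the jaxcld pipeline. A small sketch (the variable names are illustrative):

```python
languages = ["zh", "en", "ms", "hi", "id"]    # input order does not matter
idx_to_lang = dict(enumerate(sorted(languages)))
lang_to_idx = {lang: i for i, lang in idx_to_lang.items()}

pred_indices = [4, 0, 2]                        # e.g. argmax outputs of the head
print([idx_to_lang[i] for i in pred_indices])   # ['zh', 'en', 'id']
```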
How to use
The head is loaded and run through the jaxcld package. Loading the artifact requires JAX because the weights are stored as JAX arrays:
```bash
pip install jaxcld jax
```
```python
import numpy as np
from huggingface_hub import hf_hub_download
from cld import ASRModel, CVXNNLangDetectHead

languages = ["en", "hi", "id", "ms", "zh"]

# 1) Load the frozen base ASR model
asr = ASRModel.from_pretrained("facebook/mms-1b-all", config={"languages": languages})

# 2) Download + load this convex language-detection head
head_path = hf_hub_download("williamhtan/cld-mms-1b-5lang", "model.pkl")
head = CVXNNLangDetectHead.load(head_path, asr)

# 3) Attach and run
asr.set_lang_detect_head(head)
audio_16k_mono: np.ndarray = ...  # placeholder: replace with your waveform, shape (T,), 16 kHz mono
pred_langs, pred_texts = asr.predict(audio_16k_mono)
print(pred_langs[0], pred_texts[0])  # predicted language and transcription for the first item
```
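The example leaves the waveform as a placeholder. Audio that is stereo or recorded at another sample rate has to be converted to 16 kHz mono first. The hypothetical helper below (not part of jaxcld) does this with NumPy only, using linear interpolation for resampling; it assumes a (T, channels) layout for multichannel input, and since linear interpolation applies no anti-aliasing filter, a proper band-limited resampler is preferable in practice.

```python
import numpy as np

TARGET_SR = 16_000

def to_16k_mono(audio: np.ndarray, sr: int) -> np.ndarray:
    """Downmix to mono and linearly resample to 16 kHz (crude: no anti-aliasing filter)."""
    audio = np.asarray(audio, dtype=np.float32)
    if audio.ndim == 2:                       # assumed layout: (T, channels)
        audio = audio.mean(axis=1)
    if sr == TARGET_SR:
        return audio
    n_out = int(round(len(audio) * TARGET_SR / sr))
    t_in = np.arange(len(audio)) / sr         # input sample times in seconds
    t_out = np.arange(n_out) / TARGET_SR      # output sample times in seconds
    return np.interp(t_out, t_in, audio).astype(np.float32)

print(to_16k_mono(np.ones((48_000, 2)), 48_000).shape)  # (16000,)
```

For example, with the soundfile package (whose `soundfile.read` returns a (frames, channels) array and the sample rate): `raw, sr = soundfile.read(path)`, then `audio_16k_mono = to_16k_mono(raw, sr)`.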
Pair the head with the matching frozen base encoder: embeddings are backbone-specific, so a head is not transferable across backbones or to other languages.
Architecture
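The head consumes mean-pooled encoder hidden states $X \in \mathbb{R}^{B \times 1280}$ (pooled over time; $B$ is the batch size) and computes, for each class $c$, the logit relu(X @ W1) @ W2 using that class's weights (W1 = theta1[c], W2 = theta2[c]), then takes the argmax over the 5 class logits. The pickle stores a CVX_ReLU_MLP object whose key tensors are: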
| Tensor | Shape | Role |
|---|---|---|
| theta1 | (5, 1280, 128) | first-layer (ReLU) weights, per class |
| theta2 | (5, 128) | output-layer weights, per class |
Configuration: n_classes = 5, P_S = 64 sampled hyperplanes (giving 128 ReLU units, the last dimension of theta1 and theta2), input dimension 1280.
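The forward pass described above can be sketched in NumPy as follows. This assumes theta1 and theta2 have been pulled out of the loaded object as arrays with the shapes in the table (how to access them on CVX_ReLU_MLP is not documented here; `np.asarray` converts JAX arrays) and that the class dimension comes first. Masking padded frames during pooling and the function names are illustrative assumptions; the actual jaxcld code may differ.

```python
import numpy as np

def mean_pool(hidden, mask):
    """Average encoder states over time, skipping padded frames.

    hidden: (B, T, 1280) encoder outputs; mask: (B, T), 1 for real frames, 0 for padding.
    """
    m = mask[..., None].astype(hidden.dtype)                           # (B, T, 1)
    return (hidden * m).sum(axis=1) / np.maximum(m.sum(axis=1), 1.0)   # (B, 1280)

def head_logits(X, theta1, theta2):
    """logits[:, c] = relu(X @ theta1[c]) @ theta2[c] for every class c."""
    h = np.maximum(np.einsum("bd,cdk->bck", X, theta1), 0.0)          # (B, 5, 128)
    return np.einsum("bck,ck->bc", h, theta2)                          # (B, 5)

def predict_lang(X, theta1, theta2, languages=("en", "hi", "id", "ms", "zh")):
    """Argmax over class logits, mapped to ISO 639-1 codes."""
    idx = head_logits(X, theta1, theta2).argmax(axis=1)
    return [languages[i] for i in idx]

# Shape check with random stand-ins for the real weights
rng = np.random.default_rng(0)
theta1 = rng.standard_normal((5, 1280, 128))
theta2 = rng.standard_normal((5, 128))
hidden = rng.standard_normal((2, 50, 1280))
mask = np.ones((2, 50))
mask[1, 30:] = 0                      # second clip is padded after 30 frames
X = mean_pool(hidden, mask)
print(head_logits(X, theta1, theta2).shape)  # (2, 5)
```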
Training
Trained with train_cvxnn.py (CRONOS / ADMM in JAX) on mean-pooled, frozen facebook/mms-1b-all encoder embeddings of Mozilla Common Voice audio, accent-stratified across the five languages (~12,368 train / ~1,546 validation pooled embeddings).
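Accent stratification means each (language, accent) group is split in the same train/validation proportion, so that rare accents still appear in validation. The exact procedure used for this head is not documented here; the following is a hypothetical NumPy sketch of one common way to do it, with illustrative names and defaults.

```python
import numpy as np
from collections import defaultdict

def stratified_split(strata, val_frac=0.1, seed=0):
    """Split item indices so every stratum (e.g. a (language, accent) pair) keeps val_frac in validation."""
    rng = np.random.default_rng(seed)
    groups = defaultdict(list)
    for i, key in enumerate(strata):
        groups[key].append(i)
    train, val = [], []
    for key in sorted(groups):                # deterministic stratum order
        idx = np.array(groups[key])
        rng.shuffle(idx)
        n_val = int(round(len(idx) * val_frac))
        val.extend(idx[:n_val].tolist())
        train.extend(idx[n_val:].tolist())
    return sorted(train), sorted(val)

strata = [("en", "us")] * 20 + [("hi", "in")] * 20
train, val = stratified_split(strata, val_frac=0.25)
print(len(train), len(val))  # 30 10
```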
| rank | neurons | beta | rho | gamma_ratio | admm_iters | pcg_iters | opt_seed | data_seed |
|---|---|---|---|---|---|---|---|---|
| 20 | 64 | 0.001 | 0.1 | 1 | 6 | 32 | 1024 | 8 |
Training time: 739.2 s; estimated compute: 170,744 TFLOPs.
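CRONOS is described as an ADMM solver whose inner linear systems are handled by preconditioned conjugate gradient with Nyström preconditioning. As an illustration of that inner step only, here is a NumPy sketch of Nyström-preconditioned CG (in the style of Frangella, Tropp and Udell) for a generic system (A + μI)x = b with A symmetric positive semidefinite. The specific system CRONOS solves is not reproduced, the function names are hypothetical, and the defaults rank=20 and max_iters=32 merely assume that the rank and pcg_iters columns of the table above mean the sketch rank and the CG iteration cap.

```python
import numpy as np

def nystrom(A, rank, rng):
    """Randomized Nystrom approximation A ~= U diag(lam) U^T for symmetric PSD A."""
    n = A.shape[0]
    Omega, _ = np.linalg.qr(rng.standard_normal((n, rank)))   # orthonormal test matrix
    Y = A @ Omega
    shift = np.sqrt(n) * np.finfo(Y.dtype).eps * np.linalg.norm(Y, 2)
    Y_shift = Y + shift * Omega                                 # small shift for stability
    M = Omega.T @ Y_shift
    C = np.linalg.cholesky(0.5 * (M + M.T))                     # lower-triangular factor
    B = np.linalg.solve(C, Y_shift.T).T                         # B = Y_shift C^{-T}
    U, s, _ = np.linalg.svd(B, full_matrices=False)             # s sorted descending
    return U, np.maximum(s**2 - shift, 0.0)

def nystrom_pcg(A, b, mu, rank=20, max_iters=32, tol=1e-10, seed=0):
    """Solve (A + mu I) x = b by conjugate gradient with a Nystrom preconditioner."""
    U, lam = nystrom(A, rank, np.random.default_rng(seed))

    def apply_pinv(r):
        # P^{-1} r = (lam_r + mu) U (Lam + mu I)^{-1} U^T r + (I - U U^T) r
        Ur = U.T @ r
        return (lam[-1] + mu) * (U @ (Ur / (lam + mu))) + r - U @ Ur

    x = np.zeros_like(b)
    r = b.copy()
    z = apply_pinv(r)
    p = z.copy()
    rz = r @ z
    for _ in range(max_iters):
        Ap = A @ p + mu * p
        alpha = rz / (p @ Ap)
        x += alpha * p
        r -= alpha * Ap
        if np.linalg.norm(r) <= tol * np.linalg.norm(b):
            break
        z = apply_pinv(r)
        rz_new = r @ z
        p = z + (rz_new / rz) * p
        rz = rz_new
    return x
```

The preconditioner flattens the part of the spectrum captured by the sketch, so the CG iteration count depends mostly on the remaining eigenvalues relative to μ; this is why a small, fixed iteration budget can suffice when the spectrum decays quickly.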
Evaluation
Measured on the held-out 5-language test split (n=1546):
| Metric | Value |
|---|---|
| Detection accuracy | 0.96 |
| WER (%, ↓) | 48.10 |
| CER (%, ↓) | 23.47 |
Best validation accuracy during training: 0.9825.
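WER and CER are the word- and character-level edit distances (substitutions + deletions + insertions) divided by the number of reference tokens; character-level scoring is the usual choice for Mandarin, which is written without spaces between words. The sketch below computes both metrics and detection accuracy for readers reproducing the table. It is not the evaluation script used for this card: errors are pooled over the corpus, CER counts spaces as characters, and text normalization (casing, punctuation) is omitted even though it can affect the scores.

```python
def edit_distance(ref, hyp):
    """Levenshtein distance between two token sequences."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        cur = [i] + [0] * len(hyp)
        for j, h in enumerate(hyp, start=1):
            cur[j] = min(prev[j] + 1,              # deletion
                         cur[j - 1] + 1,           # insertion
                         prev[j - 1] + (r != h))   # substitution or match
        prev = cur
    return prev[-1]

def error_rate(refs, hyps, unit="word"):
    """Corpus-level WER (unit='word') or CER (unit='char'), in percent."""
    split = str.split if unit == "word" else list
    errors = sum(edit_distance(split(r), split(h)) for r, h in zip(refs, hyps))
    total = sum(len(split(r)) for r in refs)
    return 100.0 * errors / total

def detection_accuracy(true_langs, pred_langs):
    """Fraction of utterances whose predicted language matches the reference."""
    return sum(t == p for t, p in zip(true_langs, pred_langs)) / len(true_langs)

print(round(error_rate(["the cat sat"], ["the cat sit"]), 2))  # 33.33
```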
Citation
```bibtex
@inproceedings{feng2026cld,
  title = {Convex Low-resource Accent-Robust Language Detection in Speech Recognition},
  author = {Feng, Miria and Tan, William and Pilanci, Mert},
  booktitle = {Proceedings of the 43rd International Conference on Machine Learning},
  year = {2026},
  series = {Proceedings of Machine Learning Research},
  publisher = {PMLR},
  url = {https://icml.cc/virtual/2026/poster/64615}
}
```
License
MIT; see the CLD repository.