---
language:
- sa
license: apache-2.0
base_model: openai/whisper-large-v2
tags:
- automatic-speech-recognition
- whisper
- sanskrit
- asr
- deepspeed
- fsdp
- fine-tuned
datasets:
- pavanmantha/sanskrit_asr
metrics:
- wer
pipeline_tag: automatic-speech-recognition
model-index:
- name: whisper-medium-sa
  results:
  - task:
      type: automatic-speech-recognition
      name: Automatic Speech Recognition
    dataset:
      name: pavanmantha/sanskrit_asr
      type: pavanmantha/sanskrit_asr
      split: validation
    metrics:
    - type: wer
      value: TBD
      name: WER
---

# Whisper Large-v2: Sanskrit ASR (Fine-Tuned)

This is a fine-tuned version of [`openai/whisper-large-v2`](https://huggingface.co/openai/whisper-large-v2) for Sanskrit (`sa`) automatic speech recognition. It was trained on the [`pavanmantha/sanskrit_asr`](https://huggingface.co/datasets/pavanmantha/sanskrit_asr) dataset with the HuggingFace `Seq2SeqTrainer` and **DeepSpeed ZeRO-3**, launched via Accelerate on 5× NVIDIA A10G GPUs. The checkpoint is published under the repository id `pavanmantha/whisper-medium-sa`. Despite that name, the base model is `whisper-large-v2`.

---

## Model Details

| Property | Value |
|---|---|
| **Base model** | `openai/whisper-large-v2` |
| **Language** | Sanskrit (`sa`), Devanagari script |
| **Task** | Automatic Speech Recognition (`transcribe`) |
| **Fine-tuning framework** | HuggingFace Transformers + DeepSpeed ZeRO-3 |
| **Precision** | bf16 (native on Ampere / sm_86) |
| **Parameters** | ~1.5B |
| **License** | Apache 2.0 |

---

## Training Details

### Dataset

| Property | Value |
|---|---|
| **Dataset** | [`pavanmantha/sanskrit_asr`](https://huggingface.co/datasets/pavanmantha/sanskrit_asr) |
| **Train split** | ~95% of the full dataset (5% held out for validation, `seed=42`) |
| **Validation split** | ~5% of the full dataset |
| **Audio column** | `audio` (resampled to 16 kHz) |
| **Text column** | `sentence` |

### Hardware & Infrastructure

| Property | Value |
|---|---|
| **GPUs** | 5× NVIDIA A10G (22.5 GB VRAM each) |
| **Instance type** | AWS G5 (or equivalent) |
| **Distributed strategy** | DeepSpeed ZeRO Stage 3 via Accelerate |
| **ZeRO-3 flags** | `stage3_gather_16bit_weights_on_model_save: true`, `overlap_comm: true`, `contiguous_gradients: true` |

### Hyperparameters

| Parameter | Value |
|---|---|
| **Epochs** | 3 |
| **Per-device batch size** | 8 |
| **Gradient accumulation steps** | 4 |
| **Effective batch size** | 160 (8 per device × 4 accumulation steps × 5 GPUs) |
| **Learning rate** | 5e-6 |
| **LR scheduler** | Linear with warmup |
| **Warmup steps** | 200 |
| **Weight decay** | 0.01 |
| **Max grad norm** | 1.0 |
| **Eval beam size** | 5 |
| **Generation max length** | 225 tokens |
| **Eval & save frequency** | Every 500 steps |
| **Best model metric** | WER (lower is better) |

### Text Normalization (WER)

WER is computed after applying a custom Sanskrit/Devanagari normalizer:

- **NFC** Unicode recomposition, which unifies different encodings of the same character
- Remove danda (`।`), double danda (`॥`), and common ASCII punctuation
- Remove nukta (`U+093C`), which normalizes ज़→ज and फ़→फ
- Map chandrabindu to anusvara (`U+0901` → `U+0902`)
- Collapse whitespace

### Key Training Flags

- `gradient_checkpointing=True` with `use_reentrant=False` (required for ZeRO-3 compatibility)
- `forced_decoder_ids=None` and `suppress_tokens=[]`, so that Devanagari character tokens are not blocked during generation
- `predict_with_generate=True` with beam search (`num_beams=5`) for evaluation

---

## Usage

### Quick Inference

```python
from transformers import pipeline

asr = pipeline(
    "automatic-speech-recognition",
    model="pavanmantha/whisper-medium-sa",
    generate_kwargs={"language": "sanskrit", "task": "transcribe"},
)

result = asr("path/to/audio.mp3")
print(result["text"])
```

### Manual Inference

```python
import librosa  # used here only to load and resample the audio
import torch
from transformers import WhisperForConditionalGeneration, WhisperProcessor

model_id = "pavanmantha/whisper-medium-sa"
processor = WhisperProcessor.from_pretrained(model_id)
model = WhisperForConditionalGeneration.from_pretrained(model_id)
model.eval()

# The model expects 16 kHz mono float32 audio; any loader that produces
# an np.ndarray of shape (N,) with that sampling rate works here.
audio_array, _ = librosa.load("path/to/audio.mp3", sr=16000, mono=True)

inputs = processor(audio_array, sampling_rate=16000, return_tensors="pt")

with torch.no_grad():
    predicted_ids = model.generate(
        inputs["input_features"],
        language="sa",
        task="transcribe",
        num_beams=5,
    )

transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
print(transcription[0])
```

---

## Evaluation

| Split | WER |
|---|---|
| Validation | TBD |

> WER values are computed using the custom Devanagari normalizer described above.

---

## Files

| File | Description |
|---|---|
| `model.safetensors` | Fine-tuned model weights |
| `config.json` | Model architecture config |
| `generation_config.json` | Generation defaults (`language=sa`, `task=transcribe`) |
| `tokenizer.json` | Whisper tokenizer |
| `tokenizer_config.json` | Tokenizer configuration |
| `processor_config.json` | Processor configuration |
| `vocab.json` | Vocabulary file |
| `merges.txt` | BPE merge rules |

---

## Limitations

- Optimized specifically for Sanskrit (`sa`) in Devanagari script; performance on other languages or scripts is not guaranteed.
- Audio inputs longer than 30 seconds are truncated to the first 30 seconds by the Whisper feature extractor.
- Best results on clean, single-speaker audio sampled at 16 kHz.

---

## Citation

If you use this model, please cite the original Whisper paper:

```bibtex
@misc{radford2022whisper,
  title={Robust Speech Recognition via Large-Scale Weak Supervision},
  author={Radford, Alec and Kim, Jong Wook and Xu, Tao and Brockman, Greg and McLeavey, Christine and Sutskever, Ilya},
  year={2022},
  eprint={2212.04356},
  archivePrefix={arXiv}
}
```

---

## Author

Fine-tuned by [Pavan Kumar Mantha](https://huggingface.co/pavanmantha) · [GitHub](https://github.com/pavanjava)
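
---

## Appendix: WER Normalizer Sketch

This card does not publish the normalizer's source. The code below is a minimal Python sketch of the steps listed under Text Normalization (WER), not the exact evaluation code. It rests on several assumptions:

- The steps run in the listed order.
- "Common ASCII punctuation" means Python's `string.punctuation`.
- Removed punctuation is replaced by a space rather than deleted.
- The normalizer is applied to both references and predictions.

`normalize_devanagari` and `word_error_rate` are hypothetical helper names. The plain word-level edit-distance WER stands in for whatever metric library was used during training.

```python
import re
import string
import unicodedata

DANDA = "\u0964"         # ।
DOUBLE_DANDA = "\u0965"  # ॥
NUKTA = "\u093C"
CHANDRABINDU = "\u0901"
ANUSVARA = "\u0902"

# Assumption: "common ASCII punctuation" is taken to mean string.punctuation.
_PUNCT_RE = re.compile(
    "[" + re.escape(DANDA + DOUBLE_DANDA + string.punctuation) + "]"
)


def normalize_devanagari(text: str) -> str:
    """Hypothetical helper applying the card's normalization steps in order."""
    # NFC recomposition. Precomposed nukta letters such as U+095B (ज़) are
    # Unicode composition exclusions, so NFC yields base letter + U+093C.
    text = unicodedata.normalize("NFC", text)
    # Danda, double danda and ASCII punctuation become spaces (assumption:
    # replaced rather than deleted, so adjacent words are not glued together).
    text = _PUNCT_RE.sub(" ", text)
    # Drop the nukta so that ज़ -> ज and फ़ -> फ.
    text = text.replace(NUKTA, "")
    # Map chandrabindu to anusvara.
    text = text.replace(CHANDRABINDU, ANUSVARA)
    # Collapse runs of whitespace and trim both ends.
    return " ".join(text.split())


def word_error_rate(references: list[str], hypotheses: list[str]) -> float:
    """Hypothetical corpus-level WER: total word edits / total reference words."""
    if len(references) != len(hypotheses):
        raise ValueError("references and hypotheses must have the same length")
    total_edits = 0
    total_ref_words = 0
    for ref, hyp in zip(references, hypotheses):
        r = normalize_devanagari(ref).split()
        h = normalize_devanagari(hyp).split()
        # Word-level Levenshtein distance, keeping one DP row at a time.
        prev = list(range(len(h) + 1))
        for i, r_word in enumerate(r, start=1):
            cur = [i] + [0] * len(h)
            for j, h_word in enumerate(h, start=1):
                cur[j] = min(
                    prev[j] + 1,                       # deletion
                    cur[j - 1] + 1,                    # insertion
                    prev[j - 1] + (r_word != h_word),  # substitution or match
                )
            prev = cur
        total_edits += prev[-1]
        total_ref_words += len(r)
    return total_edits / total_ref_words if total_ref_words else 0.0
```

`word_error_rate` normalizes both sides before scoring. As a result, a prediction that differs from its reference only by a trailing danda, a nukta, or chandrabindu versus anusvara scores 0.0 under this sketch.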