/**
 * Granite Speech WebGPU Demo
 * Uses Transformers.js v4 for in-browser speech recognition
 */
import {
  AutoProcessor,
  GraniteSpeechForConditionalGeneration,
  TextStreamer,
  Tensor,
} from 'https://cdn.jsdelivr.net/npm/@huggingface/transformers@4.2.0';

// ---- Model -----------------------------------------------------------------

// Hub repository that both the processor and the ONNX weights are loaded from.
const MODEL_ID = 'onnx-community/granite-speech-4.1-2b-ONNX';

// ---- Audio / generation config ---------------------------------------------

// Sample rate in Hz: handed to the processor as `sampling_rate` and used to
// convert between sample indices and seconds.
const SAMPLE_RATE = 16000;
// Cap on tokens generated per audio segment (`max_new_tokens`).
const MAX_NEW_TOKENS = 256;

// ---- Task prompts ----------------------------------------------------------
// Looked up by `promptSelect.value`. <|audio|> is expanded by the processor's
// chat template. Granite Speech 4.1 2B produces punctuation and capitalization
// natively when asked, so every prompt ends with that request.
const TASK_PROMPTS = (() => {
  const style = 'with proper punctuation and capitalization';
  // [key suffix, target language]; the order here fixes the key order below.
  const translationTargets = [
    ['en', 'English'],
    ['fr', 'French'],
    ['de', 'German'],
    ['es', 'Spanish'],
    ['pt', 'Portuguese'],
    ['ja', 'Japanese'],
  ];
  const prompts = {
    transcribe: `<|audio|>Transcribe the speech to text ${style}`,
  };
  for (const [suffix, language] of translationTargets) {
    prompts[`translate_${suffix}`] = `<|audio|>Translate the speech to ${language} ${style}`;
  }
  return prompts;
})();

// ---- State -----------------------------------------------------------------
let model = null; // assigned by initModels()
let processor = null; // assigned by initModels()
let isModelLoading = false; // re-entry guard for initModels()
let currentAudioData = null; // samples of the clip to be transcribed

// ---- DOM elements ----------------------------------------------------------
const statusDot = document.getElementById('statusDot');
const statusText = document.getElementById('statusText');
const recordBtn = document.getElementById('recordBtn');
const audioFile = document.getElementById('audioFile');
const fileTile = document.querySelector('.file-label');
const inputCard = document.querySelector('.input-card');
const audioPreview = document.getElementById('audioPreview');
const audioPlayer = document.getElementById('audioPlayer');
const playBtn = document.getElementById('playBtn');
const waveformCanvas = document.getElementById('waveformCanvas');
const waveformProgress = document.getElementById('waveformProgress');
const audioTime = document.getElementById('audioTime');
const transcribeSection = document.getElementById('transcribeSection');
const transcribeBtn = document.getElementById('transcribeBtn');
const promptSelect = document.getElementById('promptSelect');
const transcriptCard = document.getElementById('transcriptCard');
const outputText = document.getElementById('outputText');
const copyBtn = document.getElementById('copyBtn');
const downloadBtn = document.getElementById('downloadBtn');
const clearBtn = document.getElementById('clearBtn');
const progressSection = document.getElementById('progressSection');
const progressFill = document.getElementById('progressFill');
const progressText = document.getElementById('progressText');
const vadCheckbox = document.getElementById('vadCheckbox');
const gpuInfo = document.getElementById('gpuInfo');

// Recording state
let mediaRecorder = null;
let audioChunks = [];
let transcriptionAborted = false;

// ---- Utility functions -----------------------------------------------------

/**
 * Update the status indicator dot and its message.
 * @param {string} status - Modifier class appended to `status-dot`
 *   (this file uses 'loading', 'ready', 'processing' and 'error').
 * @param {string} message - Text shown in the status line.
 */
function setStatus(status, message) {
  statusDot.className = `status-dot ${status}`;
  statusText.textContent = message;
}

/**
 * Show or hide the progress section.
 * @param {boolean} show
 */
function showProgress(show) {
  progressSection.style.display = show ? 'block' : 'none';
}

/**
 * Set the progress bar fill and its caption.
 * @param {number} progress - Fill percentage (not clamped; callers pass 0-100).
 * @param {string} text - Caption shown under the bar.
 */
function updateProgress(progress, text) {
  progressFill.style.width = `${progress}%`;
  progressText.textContent = text;
}

/**
 * Check whether WebGPU is usable, reporting any problem in #gpuInfo.
 * @returns {Promise<boolean>} true when `navigator.gpu.requestAdapter()`
 *   yields an adapter; false otherwise (including when it throws).
 */
async function checkWebGPU() {
  if (!navigator.gpu) {
    gpuInfo.textContent = 'WebGPU not supported. Use Chrome 113+ or Edge 113+';
    gpuInfo.style.color = '#e74c3c';
    return false;
  }
  try {
    const adapter = await navigator.gpu.requestAdapter();
    if (!adapter) {
      gpuInfo.textContent = 'No WebGPU adapter available';
      gpuInfo.style.color = '#f39c12';
      return false;
    }
    return true;
  } catch (e) {
    console.error('WebGPU error:', e);
    gpuInfo.textContent = `WebGPU error: ${e.message || e}`;
    gpuInfo.style.color = '#e74c3c';
    return false;
  }
}

/**
 * Build a `progress_callback` for `from_pretrained` that aggregates per-file
 * byte counts into a single progress bar and status line.
 * DOM updates are throttled to at most one every 100 ms.
 * @returns {(progress: object) => void}
 */
function createDownloadProgressCallback() {
  const fileProgress = new Map(); // file -> { loaded, total }
  let lastProgressUpdate = 0;

  return (progress) => {
    // Only byte-level 'progress' events with a known total contribute.
    if (progress.status !== 'progress' || !progress.total) return;
    fileProgress.set(progress.file, { loaded: progress.loaded, total: progress.total });

    const now = performance.now();
    if (now - lastProgressUpdate < 100) return;
    lastProgressUpdate = now;

    let totalLoaded = 0;
    let totalSize = 0;
    for (const { loaded, total } of fileProgress.values()) {
      totalLoaded += loaded;
      totalSize += total;
    }
    const pct = totalSize > 0 ? (totalLoaded / totalSize) * 100 : 0;
    progressFill.style.width = `${pct}%`;
    const mb = (totalLoaded / 1e6).toFixed(0);
    const totalMb = (totalSize / 1e6).toFixed(0);
    setStatus('loading', `Downloading models... ${mb} / ${totalMb} MB`);
  };
}

/**
 * Resize a row-major [nFeatures, hidden] buffer to [nTokens, hidden].
 * Surplus rows are dropped. Missing rows are filled by repeating the last real
 * row, or left zero-filled when there are no real rows.
 * @param {ArrayLike<number>} src - Typed array holding nFeatures * hidden values.
 * @param {number} nFeatures - Number of rows in `src`.
 * @param {number} nTokens - Number of rows wanted.
 * @param {number} hidden - Row width.
 * @returns {ArrayLike<number>} A new typed array of the same kind as `src`
 *   with nTokens * hidden values.
 */
function fitAudioFeatureRows(src, nFeatures, nTokens, hidden) {
  if (nFeatures >= nTokens) {
    return src.slice(0, nTokens * hidden);
  }
  const dst = new src.constructor(nTokens * hidden);
  if (nFeatures === 0) return dst;
  dst.set(src.subarray(0, nFeatures * hidden));
  const lastRow = src.subarray((nFeatures - 1) * hidden, nFeatures * hidden);
  for (let row = nFeatures; row < nTokens; row++) {
    dst.set(lastRow, row * hidden);
  }
  return dst;
}

/**
 * Wrap `_merge_input_ids_with_audio_features` so that the encoder's feature
 * rows are truncated/padded to the number of audio placeholder tokens in
 * `input_ids` before the original merge runs.
 *
 * The wrapper is installed on the instance and on its direct prototype (as the
 * original code did). It delegates via `this` rather than a bound instance. It
 * is tagged so a second call (e.g. a retry of initModels() after a late
 * failure) does not stack another wrapper on top.
 * @param {object} targetModel - The loaded GraniteSpeechForConditionalGeneration.
 */
function installAudioFeatureShim(targetModel) {
  const originalMerge = targetModel._merge_input_ids_with_audio_features;
  if (originalMerge.isAudioFeatureShim) return;

  const patchedMerge = function (kwargs) {
    if (!kwargs?.audio_features) return originalMerge.call(this, kwargs);

    // NOTE(review): `ignore_index` is consulted first, mirroring the lookup this
    // shim was written against; verify it is not a loss-ignore value for this config.
    const audioTokenId =
      this.config.ignore_index ?? this.config.audio_token_id ?? this.config.audio_token_index;
    const ids = kwargs.input_ids.tolist().flat();
    // Number() normalises BigInt ids so the comparison works for int64 tensors.
    const nTokens = ids.filter((x) => Number(x) === Number(audioTokenId)).length;

    const { dims } = kwargs.audio_features;
    const hidden = dims.at(-1);
    const nFeatures = dims.slice(0, -1).reduce((acc, d) => acc * d, 1);
    if (nFeatures === nTokens) return originalMerge.call(this, kwargs);

    console.warn(`[shim] audio features/tokens mismatch: features=${nFeatures}, tokens=${nTokens}, adjusting`);
    const flat = kwargs.audio_features.view(-1, hidden);
    const fitted = fitAudioFeatureRows(flat.data, nFeatures, nTokens, hidden);
    return originalMerge.call(this, {
      ...kwargs,
      audio_features: new Tensor(flat.type, fitted, [nTokens, hidden]),
    });
  };
  patchedMerge.isAudioFeatureShim = true;

  targetModel._merge_input_ids_with_audio_features = patchedMerge;
  Object.getPrototypeOf(targetModel)._merge_input_ids_with_audio_features = patchedMerge;
}

/**
 * Load the processor, the Granite Speech model (on WebGPU) and the VAD model,
 * then enable the record/upload controls.
 *
 * `isModelLoading` is set on entry and cleared only when loading fails (or
 * WebGPU is unavailable). After a successful load, further calls return
 * immediately.
 */
async function initModels() {
  if (isModelLoading) return;
  isModelLoading = true;
  setStatus('loading', 'Loading processor...');

  try {
    // The model is created with device: 'webgpu', so without an adapter it is
    // expected to fail anyway, presumably only after a large download. Stop
    // early instead; checkWebGPU() has already written the reason into #gpuInfo.
    const hasWebGPU = await checkWebGPU();
    if (!hasWebGPU) {
      setStatus('error', 'Error: WebGPU is not available');
      isModelLoading = false;
      return;
    }

    processor = await AutoProcessor.from_pretrained(MODEL_ID);

    setStatus('loading', 'Downloading models...');
    progressFill.style.width = '0%';
    model = await GraniteSpeechForConditionalGeneration.from_pretrained(MODEL_ID, {
      dtype: {
        audio_encoder: 'q4',
        embed_tokens: 'q4f16',
        decoder_model_merged: 'q4f16',
      },
      device: 'webgpu',
      progress_callback: createDownloadProgressCallback(),
    });

    // TEMP: tolerate token/feature count mismatches from the onnx-community
    // 4.1-2b encoder export by truncating/padding encoder features to match
    // the reserved token slots. Remove once upstream export + PR #1685 ship.
    installAudioFeatureShim(model);

    setStatus('loading', 'Loading VAD model...');
    await loadVAD();

    progressFill.style.width = '0%';
    setStatus('ready', 'Ready - Record or upload audio');
    enableControls(true);
  } catch (error) {
    console.error('Model loading failed:', error);
    console.error('Error stack:', error?.stack);
    const errorMsg = error?.message || error?.toString() || 'Unknown error';
    setStatus('error', `Error: ${errorMsg}`);
    progressFill.style.width = '0%';
    isModelLoading = false;
  }
}

/**
 * Enable or disable the record button and the file input together.
 * @param {boolean} enabled
 */
function enableControls(enabled) {
  recordBtn.disabled = !enabled;
  audioFile.disabled = !enabled;
}
//
Transcribe a single audio segment and return the text
/**
 * Generate text for one audio segment using the task prompt currently chosen
 * in `promptSelect`. Falls back to TASK_PROMPTS['transcribe'] when the
 * selected value has no prompt.
 *
 * Reads the module-level `processor` and `model`, so initModels() must have
 * completed first.
 *
 * @param {*} audioSegment - Audio samples for this segment, passed to
 *   `processor(...)` with `sampling_rate: SAMPLE_RATE`. NOTE(review): presumably
 *   a slice of currentAudioData (see transcribe()); confirm.
 * @param {(text: string) => void} [onPartialResult] - Optional. Called with the
 *   whole text accumulated so far each time the streamer emits a chunk.
 * @returns {Promise<string>} The text accumulated by the streamer once
 *   `model.generate` resolves, with every `"` character removed.
 */
async function transcribeSegment(audioSegment, onPartialResult) {
  // Build prompt using chat template
  const taskKey = promptSelect.value;
  const content = TASK_PROMPTS[taskKey] || TASK_PROMPTS['transcribe'];
  const messages = [{ role: 'user', content }];
  // tokenize: false presumably yields the rendered prompt as a string, which
  // processor(...) below combines with the audio.
  const text = processor.tokenizer.apply_chat_template(messages, {
    add_generation_prompt: true,
    tokenize: false,
  });
  // Process text + audio into model inputs
  const inputs = await processor(text, audioSegment, { sampling_rate: SAMPLE_RATE });
  // Streaming via TextStreamer
  let accumulated = '';
  const streamer = new TextStreamer(processor.tokenizer, {
    skip_prompt: true,
    skip_special_tokens: true,
    callback_function: (chunk) => {
      // Every double quote is dropped from the streamed text.
      // NOTE(review): this also strips quotation marks the model emits on
      // purpose (e.g. quoted speech); confirm this is intended.
      accumulated += chunk.replace(/"/g, '');
      if (onPartialResult) {
        onPartialResult(accumulated);
      }
    },
  });
  // Generate. The value returned by generate() is not used; the result is
  // whatever the streamer accumulated.
  await model.generate({
    ...inputs,
    max_new_tokens: MAX_NEW_TOKENS,
    streamer,
  });
  return accumulated;
}
// Wait until audio playback reaches a specific time
/**
 * Resolve once `audioPlayer` is paused or its `currentTime` has reached
 * `targetTime`, polling once per animation frame via requestAnimationFrame.
 * NOTE(review): rAF callbacks are usually suspended in background tabs, so this
 * may not resolve until the tab becomes visible again; confirm acceptable.
 * @param {number} targetTime - Playback position, in the units of `currentTime` (seconds).
 * @returns {Promise<void>}
 */
function waitForPlaybackTime(targetTime) {
  return new Promise((resolve) => {
    const check = () => {
      if (audioPlayer.paused || audioPlayer.currentTime >= targetTime) {
        resolve();
      } else {
        requestAnimationFrame(check);
      }
    };
    check();
  });
}
// Run inference with segmentation and audio sync
async function transcribe() {
  if (!model || !processor || !currentAudioData) {
    setStatus('error', 'Model or audio not ready');
    return;
  }
  setStatus('processing', 'Processing audio...');
  transcribeBtn.disabled = true;
  transcriptionAborted = false;
  outputText.textContent = '';
  transcriptCard.style.display = 'block';
  showProgress(true);
  try {
    // Get speech segments using VAD, or treat entire audio as one segment
    let segments;
    if (vadCheckbox.checked) {
      updateProgress(5, 'Detecting speech segments...');
      segments = await getSpeechSegments(currentAudioData, SAMPLE_RATE);
      console.log(`VAD found ${segments.length} segment(s)`);
    } else {
      segments = [{ start: 0, end: currentAudioData.length / SAMPLE_RATE }];
    }
// Start audio playback immediately audioPlayer.currentTime = 0; audioPlayer.play(); playBtn.querySelector('.play-icon').style.display = 'none'; playBtn.querySelector('.pause-icon').style.display = 'block'; const playbackStartTime = performance.now() / 1000; // Process and display segments in sync with audio const displayedResults = []; const totalSegments = segments.length; for (let segIdx = 0; segIdx < totalSegments; segIdx++) { if (transcriptionAborted) break; const seg = segments[segIdx]; // Update progress bar const segProgress = ((segIdx + 1) / totalSegments) * 100; updateProgress(segProgress, ''); // Wait for audio to reach this segment's start time const elapsed = (performance.now() / 1000) - playbackStartTime; const waitTime = seg.start - elapsed; if (waitTime > 0) { await new Promise(resolve => setTimeout(resolve, waitTime * 1000)); } setStatus('processing', `Segment ${segIdx + 1}/${totalSegments}`); // Extract and transcribe this segment const startSample = Math.floor(seg.start * SAMPLE_RATE); const endSample = Math.floor(seg.end * SAMPLE_RATE); const audioSegment = currentAudioData.slice(startSample, endSample); const timestamp = formatTimestamp(seg.start); const makeRow = (ts, text) => `