/**
 * Punctuation and Capitalization using ONNX
 * - English: Full punctuation + capitalization (1-800-BAD-CODE model)
 * - Other languages (DE, FR, IT, NL, ES, PT): Punctuation only (oliverguhr multilingual model)
 */

// English model (punctuation + capitalization)
let pcsSession = null;
let pcsVocab = null;
let pcsVocabReverse = null;

// Multilingual model (punctuation only)
let multilingualSession = null;
let multilingualTokenizer = null;

const PCS_CONFIG = {
  preLabels: ["", "¿"],
  postLabels: ["", "", ".", ",", "?"], // index 1 = acronym (period after each char, handled in decode)
  unkId: 0,
  bosId: 1,
  eosId: 2,
  padId: 3,
};

// Multilingual model label mapping
const MULTILINGUAL_LABELS = {
  0: "",  // No punctuation
  1: ".", // Period
  2: ",", // Comma
  3: "?", // Question mark
  4: "-", // Hyphen
  5: ":", // Colon
};

// Languages supported by multilingual model
const MULTILINGUAL_LANGS = ['de', 'fr', 'it', 'nl', 'es', 'pt'];

// Fetch with a Cache Storage layer so model files are downloaded only once
async function cachedFetch(url) {
  const cache = await caches.open('granite-speech-local-models');
  const cached = await cache.match(url);
  if (cached) return cached;
  const response = await fetch(url);
  if (response.ok) await cache.put(url, response.clone());
  return response;
}

// Load the English punctuator model and vocab
async function loadEnglishPunctuator() {
  if (pcsSession) return;
  console.log('Loading English punctuator model...');

  // Load vocab
  const vocabResponse = await cachedFetch('./pcs_vocab.json');
  const vocabData = await vocabResponse.json();
  pcsVocab = vocabData.vocab;

  // Create reverse vocab (id -> piece)
  pcsVocabReverse = {};
  for (const [piece, id] of Object.entries(pcsVocab)) {
    pcsVocabReverse[id] = piece;
  }

  // Load ONNX model
  const modelResponse = await cachedFetch('./punct_cap_seg_en.onnx');
  const buffer = await modelResponse.arrayBuffer();
  pcsSession = await ort.InferenceSession.create(buffer, {
    executionProviders: ['wasm'],
  });
  console.log('English punctuator model loaded');
}

// Load the multilingual punctuator model
async function loadMultilingualPunctuator() {
  if (multilingualSession) return;
  console.log('Loading multilingual punctuator model...');

  // Load tokenizer from transformers.js
  const { AutoTokenizer } = await import('https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.4.2');
  multilingualTokenizer = await AutoTokenizer.from_pretrained('oliverguhr/fullstop-punctuation-multilingual-base');

  // Load ONNX model
  const modelResponse = await cachedFetch('./punct_multilingual_q8.onnx');
  const buffer = await modelResponse.arrayBuffer();
  multilingualSession = await ort.InferenceSession.create(buffer, {
    executionProviders: ['wasm'],
  });
  console.log('Multilingual punctuator model loaded');
}

// Simple Unigram tokenizer for English model (greedy longest match)
function tokenizeEnglish(text) {
  const normalized = text.toLowerCase().replace(/ /g, '▁');
  const tokens = [];

  // Add BOS
  tokens.push(PCS_CONFIG.bosId);

  // Prepend ▁ for first word
  let remaining = '▁' + normalized;
  while (remaining.length > 0) {
    let found = false;
    // Try longest match first
    for (let len = Math.min(remaining.length, 20); len > 0; len--) {
      const piece = remaining.substring(0, len);
      if (pcsVocab[piece] !== undefined) {
        tokens.push(pcsVocab[piece]);
        remaining = remaining.substring(len);
        found = true;
        break;
      }
    }
    if (!found) {
      // Unknown character, use UNK and skip
      tokens.push(PCS_CONFIG.unkId);
      remaining = remaining.substring(1);
    }
  }

  // Add EOS
  tokens.push(PCS_CONFIG.eosId);
  return tokens;
}

// Apply punctuation and capitalization for English
async function applyEnglishPunctuation(text) {
  await loadEnglishPunctuator();
  // Tokenize
  const tokenIds = tokenizeEnglish(text);

  // Run inference
  const inputTensor = new ort.Tensor('int64', BigInt64Array.from(tokenIds.map(BigInt)), [1, tokenIds.length]);
  const outputs = await pcsSession.run({ input_ids: inputTensor });

  const prePreds = outputs.pre_preds.data;
  const postPreds = outputs.post_preds.data;
  const capPreds = outputs.cap_preds.data;
  const segPreds = outputs.seg_preds.data;

  // Decode: skip BOS (index 0) and EOS (last index)
  const numTokens = tokenIds.length - 2;
  const result = [];
  let currentSentence = [];

  for (let i = 0; i < numTokens; i++) {
    const tokenId = tokenIds[i + 1];
    const token = pcsVocabReverse[tokenId] || '';
    const outputIdx = i + 1;

    // Handle word boundary
    if (token.startsWith('▁') && currentSentence.length > 0) {
      currentSentence.push(' ');
    }

    // Process each character in token
    const charStart = token.startsWith('▁') ? 1 : 0;
    for (let j = charStart; j < token.length; j++) {
      let char = token[j];

      // Pre-punctuation (e.g., inverted question mark)
      if (j === charStart && prePreds[outputIdx] === 1) {
        currentSentence.push(PCS_CONFIG.preLabels[1]);
      }

      // Capitalization - capPreds is [batch, seq, 16]
      const capOffset = outputIdx * 16 + j;
      if (capPreds[capOffset]) {
        char = char.toUpperCase();
      }
      currentSentence.push(char);

      // Post-punctuation
      const postLabel = postPreds[outputIdx];
      if (postLabel === 1) {
        // ACRONYM
        currentSentence.push('.');
      } else if (j === token.length - 1 && postLabel > 1) {
        currentSentence.push(PCS_CONFIG.postLabels[postLabel]);
      }
    }

    // Sentence boundary
    if (segPreds[outputIdx]) {
      result.push(currentSentence.join(''));
      currentSentence = [];
    }
  }

  if (currentSentence.length > 0) {
    result.push(currentSentence.join(''));
  }

  return result.join(' ');
}

// Apply punctuation only for other languages (multilingual model)
async function applyMultilingualPunctuation(text) {
  await loadMultilingualPunctuator();

  // Tokenize using transformers.js tokenizer
  // (transformers.js uses the singular `return_tensor` option; false returns plain arrays)
  const encoded = await multilingualTokenizer(text, {
    return_tensor: false,
    padding: false,
    truncation: true,
    max_length: 512,
  });
  const inputIds = encoded.input_ids;
  const attentionMask = encoded.attention_mask;

  // Run inference
  const inputIdsTensor = new ort.Tensor('int64', BigInt64Array.from(inputIds.map(BigInt)), [1, inputIds.length]);
  const attentionMaskTensor = new ort.Tensor('int64', BigInt64Array.from(attentionMask.map(BigInt)), [1, attentionMask.length]);
  const outputs = await multilingualSession.run({
    input_ids: inputIdsTensor,
    attention_mask: attentionMaskTensor,
  });

  const logits = outputs.logits.data;
  const numLabels = 6;

  // Get predictions (argmax over logits)
  const predictions = [];
  for (let i = 0; i < inputIds.length; i++) {
    let maxIdx = 0;
    let maxVal = logits[i * numLabels];
    for (let j = 1; j < numLabels; j++) {
      if (logits[i * numLabels + j] > maxVal) {
        maxVal = logits[i * numLabels + j];
        maxIdx = j;
      }
    }
    predictions.push(maxIdx);
  }

  // Decode tokens back to text with punctuation
  const tokens = multilingualTokenizer.model.convert_ids_to_tokens(inputIds);
  const result = [];

  for (let i = 0; i < tokens.length; i++) {
    const token = tokens[i];

    // Skip special tokens (XLM-RoBERTa's <s>, </s>, <pad>)
    if (token === '<s>' || token === '</s>' || token === '<pad>') {
      continue;
    }

    // Handle subword tokens (▁ prefix indicates start of new word)
    if (token.startsWith('▁')) {
      if (result.length > 0) {
        result.push(' ');
      }
      result.push(token.substring(1));
    } else {
      result.push(token);
    }

    // Add punctuation after token
    const punct = MULTILINGUAL_LABELS[predictions[i]];
    if (punct) {
      result.push(punct);
    }
  }

  return result.join('');
}
// Main entry point - routes to the appropriate model based on language
async function applyPunctuation(text, lang = null) {
  if (!text || text.trim().length === 0) return text;

  // If language specified and supported by multilingual model, use it
  if (lang && MULTILINGUAL_LANGS.includes(lang)) {
    try {
      return await applyMultilingualPunctuation(text);
    } catch (error) {
      console.warn('Multilingual punctuation failed, returning original:', error);
      return text;
    }
  }

  // Default to English model
  try {
    return await applyEnglishPunctuation(text);
  } catch (error) {
    console.warn('English punctuation failed, returning original:', error);
    return text;
  }
}

// Preload English model (called during init)
async function loadPunctuator() {
  await loadEnglishPunctuator();
}

// Export for use in app.js
window.applyPunctuation = applyPunctuation;
window.loadPunctuator = loadPunctuator;
window.MULTILINGUAL_PUNCT_LANGS = MULTILINGUAL_LANGS;
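
// Example usage from app.js (a minimal sketch, not part of this module; the
// `transcriptText` and `detectedLang` names are hypothetical placeholders for
// whatever the ASR step produces):
//
//   await window.loadPunctuator();  // optional warm-up during init
//   const punctuated = await window.applyPunctuation(transcriptText, detectedLang);
//
// English text gets punctuation + capitalization; de/fr/it/nl/es/pt get punctuation
// only, and any other (or missing) language code falls back to the English model.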