// granite-speech-webgpu / punctuator.js
/**
* Punctuation and Capitalization using ONNX
* - English: Full punctuation + capitalization (1-800-BAD-CODE model)
* - Other languages (DE, FR, IT, NL, ES, PT): Punctuation only (oliverguhr multilingual model)
*/
// --- Module state ---------------------------------------------------------
// English model (punctuation + capitalization); all three are populated
// lazily by loadEnglishPunctuator() and remain null until first use.
let pcsSession = null;
let pcsVocab = null; // piece -> id, loaded from pcs_vocab.json
let pcsVocabReverse = null; // id -> piece, derived from pcsVocab
// Multilingual model (punctuation only); populated lazily by
// loadMultilingualPunctuator().
let multilingualSession = null;
let multilingualTokenizer = null;
// Label sets and special token ids for the English PCS model.
// preLabels: punctuation inserted BEFORE a token (index 1 = inverted "¿").
// postLabels: punctuation appended AFTER a token; index 1 (<ACRONYM>) causes a
// period after each character of the token instead (e.g. "a.b.c.").
const PCS_CONFIG = {
preLabels: ["<NULL>", "¿"],
postLabels: ["<NULL>", "<ACRONYM>", ".", ",", "?"],
unkId: 0,
bosId: 1,
eosId: 2,
padId: 3,
};
// Class index -> punctuation mark emitted after a token by the multilingual model.
const MULTILINGUAL_LABELS = {
0: "", // No punctuation
1: ".", // Period
2: ",", // Comma
3: "?", // Question mark
4: "-", // Hyphen
5: ":", // Colon
};
// ISO 639-1 codes routed to the multilingual model by applyPunctuation();
// anything else falls back to the English model.
const MULTILINGUAL_LANGS = ['de', 'fr', 'it', 'nl', 'es', 'pt'];
/**
 * Fetch a URL, serving from (and populating) a local Cache Storage bucket.
 *
 * Falls back to a plain network fetch when the Cache API is unavailable
 * (non-secure context, private browsing) or any cache operation fails, so a
 * caching problem never breaks model loading. Only ok responses are cached.
 *
 * @param {string} url - URL of the asset to fetch.
 * @returns {Promise<Response>} The cached or network response (may be non-ok).
 */
async function cachedFetch(url) {
  let cache = null;
  try {
    cache = await caches.open('granite-speech-local-models');
    const cached = await cache.match(url);
    if (cached) return cached;
  } catch (error) {
    // Cache API unavailable or lookup failed: degrade to a plain fetch.
    console.warn('Cache lookup failed, fetching from network:', error);
  }
  const response = await fetch(url);
  if (response.ok && cache) {
    try {
      // Store a clone so the caller can still consume the response body.
      await cache.put(url, response.clone());
    } catch (error) {
      // Best-effort caching: a put failure must not break the download.
      console.warn('Failed to cache response:', error);
    }
  }
  return response;
}
/**
 * Lazily load the English punctuation/capitalization model and its vocab.
 *
 * Populates the module-level pcsSession / pcsVocab / pcsVocabReverse; it is a
 * no-op once the session exists.
 *
 * @throws {Error} If either asset download returns a non-ok HTTP status
 *   (cachedFetch returns failed responses instead of throwing, so without an
 *   explicit check a 404 page would surface as a cryptic JSON/ONNX error).
 * @returns {Promise<void>}
 */
async function loadEnglishPunctuator() {
  if (pcsSession) return;
  console.log('Loading English punctuator model...');
  // Load vocab (piece -> id)
  const vocabResponse = await cachedFetch('./pcs_vocab.json');
  if (!vocabResponse.ok) {
    throw new Error(`Failed to load pcs_vocab.json: HTTP ${vocabResponse.status}`);
  }
  const vocabData = await vocabResponse.json();
  pcsVocab = vocabData.vocab;
  // Build reverse vocab (id -> piece) for decoding model output.
  pcsVocabReverse = {};
  for (const [piece, id] of Object.entries(pcsVocab)) {
    pcsVocabReverse[id] = piece;
  }
  // Load ONNX model weights
  const modelResponse = await cachedFetch('./punct_cap_seg_en.onnx');
  if (!modelResponse.ok) {
    throw new Error(`Failed to load punct_cap_seg_en.onnx: HTTP ${modelResponse.status}`);
  }
  const buffer = await modelResponse.arrayBuffer();
  pcsSession = await ort.InferenceSession.create(buffer, {
    executionProviders: ['wasm'],
  });
  console.log('English punctuator model loaded');
}
/**
 * Lazily load the multilingual punctuator model and its tokenizer.
 *
 * Populates the module-level multilingualSession / multilingualTokenizer;
 * it is a no-op once the session exists.
 *
 * @throws {Error} If the model download returns a non-ok HTTP status
 *   (cachedFetch returns failed responses instead of throwing; feeding a 404
 *   body to ort.InferenceSession.create would fail with a cryptic error).
 * @returns {Promise<void>}
 */
async function loadMultilingualPunctuator() {
  if (multilingualSession) return;
  console.log('Loading multilingual punctuator model...');
  // Tokenizer comes from transformers.js, loaded on demand from the CDN.
  const { AutoTokenizer } = await import('https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.4.2');
  multilingualTokenizer = await AutoTokenizer.from_pretrained('oliverguhr/fullstop-punctuation-multilingual-base');
  // Load quantized ONNX model weights
  const modelResponse = await cachedFetch('./punct_multilingual_q8.onnx');
  if (!modelResponse.ok) {
    throw new Error(`Failed to load punct_multilingual_q8.onnx: HTTP ${modelResponse.status}`);
  }
  const buffer = await modelResponse.arrayBuffer();
  multilingualSession = await ort.InferenceSession.create(buffer, {
    executionProviders: ['wasm'],
  });
  console.log('Multilingual punctuator model loaded');
}
/**
 * Greedy longest-match Unigram tokenization for the English PCS model.
 *
 * Lowercases the input, maps spaces to the SentencePiece word-boundary marker
 * (▁), prepends ▁ for the first word, and wraps the resulting piece ids in
 * BOS/EOS. Characters matching no vocab piece become UNK and are skipped one
 * at a time.
 *
 * @param {string} text - Raw input text.
 * @returns {number[]} Token ids: [bosId, ...pieces, eosId].
 */
function tokenizeEnglish(text) {
  const MAX_PIECE_LEN = 20; // vocab pieces are assumed to be at most this long
  let rest = '▁' + text.toLowerCase().replace(/ /g, '▁');
  const ids = [PCS_CONFIG.bosId];
  while (rest.length > 0) {
    let matchedLen = 0;
    // Greedy: try the longest candidate piece first.
    for (let len = Math.min(rest.length, MAX_PIECE_LEN); len > 0; len--) {
      const candidate = rest.substring(0, len);
      if (pcsVocab[candidate] !== undefined) {
        ids.push(pcsVocab[candidate]);
        matchedLen = len;
        break;
      }
    }
    if (matchedLen === 0) {
      // No vocab piece matches here: emit UNK and advance one character.
      ids.push(PCS_CONFIG.unkId);
      matchedLen = 1;
    }
    rest = rest.substring(matchedLen);
  }
  ids.push(PCS_CONFIG.eosId);
  return ids;
}
/**
 * Apply punctuation and capitalization to English text.
 *
 * Runs the PCS ONNX model over the tokenized text and decodes four parallel
 * per-token prediction streams:
 *   pre_preds  - punctuation before the token (index 1 = "¿")
 *   post_preds - punctuation after the token (see PCS_CONFIG.postLabels)
 *   cap_preds  - per-character capitalization flags, 16 slots per token
 *   seg_preds  - sentence-boundary flags
 *
 * @param {string} text - Unpunctuated transcript (case is ignored; the
 *   tokenizer lowercases it).
 * @returns {Promise<string>} Punctuated, capitalized text; detected sentences
 *   are joined with single spaces.
 */
async function applyEnglishPunctuation(text) {
await loadEnglishPunctuator();
// Tokenize (BOS + pieces + EOS)
const tokenIds = tokenizeEnglish(text);
// Run inference
const inputTensor = new ort.Tensor('int64', BigInt64Array.from(tokenIds.map(BigInt)), [1, tokenIds.length]);
const outputs = await pcsSession.run({ input_ids: inputTensor });
const prePreds = outputs.pre_preds.data;
const postPreds = outputs.post_preds.data;
const capPreds = outputs.cap_preds.data;
const segPreds = outputs.seg_preds.data;
// Decode: skip BOS (index 0) and EOS (last index)
const numTokens = tokenIds.length - 2;
const result = [];
let currentSentence = [];
for (let i = 0; i < numTokens; i++) {
const tokenId = tokenIds[i + 1];
const token = pcsVocabReverse[tokenId] || '';
// Model outputs are aligned with the BOS-prefixed input, hence the +1.
const outputIdx = i + 1;
// ▁ marks a word start: insert a space between words.
if (token.startsWith('▁') && currentSentence.length > 0) {
currentSentence.push(' ');
}
// Process each character in the token, skipping the leading ▁ marker.
const charStart = token.startsWith('▁') ? 1 : 0;
for (let j = charStart; j < token.length; j++) {
let char = token[j];
// Pre-punctuation (inverted question mark) goes before the first character.
if (j === charStart && prePreds[outputIdx] === 1) {
currentSentence.push(PCS_CONFIG.preLabels[1]);
}
// Capitalization - capPreds is [batch, seq, 16]: one flag per character slot.
// NOTE(review): j keeps the ▁ offset, so for ▁-prefixed tokens the first
// visible character reads slot 1 rather than slot 0 - confirm this matches
// the model's export convention.
const capOffset = outputIdx * 16 + j;
if (capPreds[capOffset]) {
char = char.toUpperCase();
}
currentSentence.push(char);
// Post-punctuation
const postLabel = postPreds[outputIdx];
if (postLabel === 1) { // ACRONYM: a period after every character ("a.b.c.")
currentSentence.push('.');
} else if (j === token.length - 1 && postLabel > 1) {
// Other marks (".", ",", "?") only after the token's final character.
currentSentence.push(PCS_CONFIG.postLabels[postLabel]);
}
}
// Sentence boundary predicted after this token: flush the sentence.
if (segPreds[outputIdx]) {
result.push(currentSentence.join(''));
currentSentence = [];
}
}
// Flush any trailing partial sentence.
if (currentSentence.length > 0) {
result.push(currentSentence.join(''));
}
return result.join(' ');
}
/**
 * Apply punctuation (no capitalization) to text in a supported non-English
 * language using the multilingual token-classification model.
 *
 * Each input token gets one of six labels (see MULTILINGUAL_LABELS); the
 * predicted mark, if any, is appended directly after that token's text.
 *
 * @param {string} text - Unpunctuated transcript.
 * @returns {Promise<string>} Text with predicted punctuation inserted.
 */
async function applyMultilingualPunctuation(text) {
  await loadMultilingualPunctuator();
  // Encode with the transformers.js tokenizer; plain arrays, capped at 512 tokens.
  const encoded = await multilingualTokenizer(text, {
    return_tensors: false,
    padding: false,
    truncation: true,
    max_length: 512,
  });
  const ids = encoded.input_ids;
  const mask = encoded.attention_mask;
  // Build int64 [1, seq] tensors for ONNX Runtime.
  const toTensor = (arr) =>
    new ort.Tensor('int64', BigInt64Array.from(arr.map(BigInt)), [1, arr.length]);
  const outputs = await multilingualSession.run({
    input_ids: toTensor(ids),
    attention_mask: toTensor(mask),
  });
  const logits = outputs.logits.data;
  const NUM_LABELS = 6;
  // Argmax over the label dimension for every token position.
  const predictions = ids.map((_, tokenIdx) => {
    const base = tokenIdx * NUM_LABELS;
    let best = 0;
    for (let labelIdx = 1; labelIdx < NUM_LABELS; labelIdx++) {
      if (logits[base + labelIdx] > logits[base + best]) {
        best = labelIdx;
      }
    }
    return best;
  });
  // Reassemble the text, appending the predicted mark after each token.
  const pieces = multilingualTokenizer.model.convert_ids_to_tokens(ids);
  const parts = [];
  pieces.forEach((piece, idx) => {
    // Skip special tokens.
    if (piece === '<s>' || piece === '</s>' || piece === '<pad>') {
      return;
    }
    if (piece.startsWith('▁')) {
      // ▁ marks a word start: separate words with a space (except at the front).
      if (parts.length > 0) {
        parts.push(' ');
      }
      parts.push(piece.substring(1));
    } else {
      parts.push(piece);
    }
    const mark = MULTILINGUAL_LABELS[predictions[idx]];
    if (mark) {
      parts.push(mark);
    }
  });
  return parts.join('');
}
/**
 * Punctuate `text`, routing to the appropriate model by language.
 *
 * @param {string} text - Raw (unpunctuated) transcript.
 * @param {?string} lang - ISO 639-1 code; codes in MULTILINGUAL_LANGS use the
 *   multilingual model, anything else (including null) uses the English model.
 * @returns {Promise<string>} Punctuated text, or the original input when it is
 *   empty/blank or when punctuation fails (errors are logged, never thrown).
 */
async function applyPunctuation(text, lang = null) {
  // Nothing to do for empty or whitespace-only input.
  if (!text || text.trim().length === 0) {
    return text;
  }
  const useMultilingual = Boolean(lang) && MULTILINGUAL_LANGS.includes(lang);
  if (useMultilingual) {
    try {
      return await applyMultilingualPunctuation(text);
    } catch (error) {
      // Best-effort: fall back to the unpunctuated text on any failure.
      console.warn('Multilingual punctuation failed, returning original:', error);
      return text;
    }
  }
  try {
    return await applyEnglishPunctuation(text);
  } catch (error) {
    // Best-effort: fall back to the unpunctuated text on any failure.
    console.warn('English punctuation failed, returning original:', error);
    return text;
  }
}
/**
 * Preload the English punctuator model (called during app initialization).
 * @returns {Promise<void>} Resolves once the model and vocab are loaded.
 */
async function loadPunctuator() {
  return loadEnglishPunctuator();
}
// Export for use in app.js. This file is loaded as a classic script (no ES
// module exports), so the public API is attached to window.
window.applyPunctuation = applyPunctuation;
window.loadPunctuator = loadPunctuator;
window.MULTILINGUAL_PUNCT_LANGS = MULTILINGUAL_LANGS;