| |
| |
| |
| |
| |
|
|
| |
// Lazily-initialized state for the English punctuation/capitalization/
// segmentation (PCS) model. Populated by loadEnglishPunctuator().
let pcsSession = null;      // ort.InferenceSession for the English model
let pcsVocab = null;        // piece string -> token id
let pcsVocabReverse = null; // token id -> piece string (inverse of pcsVocab)

// Lazily-initialized state for the multilingual punctuation model.
// Populated by loadMultilingualPunctuator().
let multilingualSession = null;   // ort.InferenceSession for the multilingual model
let multilingualTokenizer = null; // HuggingFace AutoTokenizer instance
|
|
// Label tables and special token ids for the English PCS model.
// Array indices correspond to the model's predicted class ids.
const PCS_CONFIG = {
  // Punctuation inserted BEFORE a token; class 1 is the inverted
  // question mark (used by applyEnglishPunctuation when pre_preds === 1).
  preLabels: ["<NULL>", "¿"],
  // Punctuation inserted AFTER a token. Class 1 (<ACRONYM>) means a
  // period after every character; classes 2+ attach one trailing mark.
  postLabels: ["<NULL>", "<ACRONYM>", ".", ",", "?"],
  unkId: 0, // unknown-piece token id
  bosId: 1, // beginning-of-sequence token id
  eosId: 2, // end-of-sequence token id
  padId: 3, // padding token id — NOTE(review): not used in the visible code
};
|
|
| |
// Class id -> punctuation mark appended after a token by the
// multilingual model; class 0 (empty string) means "no punctuation".
const MULTILINGUAL_LABELS = {
  0: "",
  1: ".",
  2: ",",
  3: "?",
  4: "-",
  5: ":",
};
|
|
| |
// ISO 639-1 codes routed to the multilingual model by applyPunctuation();
// any other (or unknown) language falls back to the English model.
const MULTILINGUAL_LANGS = ['de', 'fr', 'it', 'nl', 'es', 'pt'];
|
|
| |
/**
 * Fetches a URL, serving from (and populating) the CacheStorage cache
 * when available. Falls back to a plain network fetch when CacheStorage
 * is unavailable (non-secure context, file://) or a cache operation
 * fails, so model loading still works without the cache.
 *
 * @param {string} url - Resource URL (model weights / vocab files).
 * @returns {Promise<Response>} Cached or freshly fetched response.
 */
async function cachedFetch(url) {
  let cache = null;
  try {
    cache = await caches.open('granite-speech-local-models');
    const cached = await cache.match(url);
    if (cached) return cached;
  } catch (err) {
    // CacheStorage missing or unusable — degrade to a direct fetch.
    console.warn('Model cache unavailable, fetching directly:', err);
  }

  const response = await fetch(url);
  // Only cache successful responses; clone because a body is single-use.
  if (response.ok && cache) {
    try {
      await cache.put(url, response.clone());
    } catch (err) {
      // Quota exceeded or unsupported response type — caching is best-effort.
      console.warn('Failed to cache model resource:', err);
    }
  }
  return response;
}
|
|
/**
 * Loads the English PCS (punctuation/capitalization/segmentation) ONNX
 * model and its vocabulary. Idempotent once loaded: subsequent calls
 * return immediately. NOTE(review): concurrent first calls each trigger
 * a load (no in-flight promise memoization); last assignment wins.
 */
async function loadEnglishPunctuator() {
  if (pcsSession !== null) {
    return;
  }

  console.log('Loading English punctuator model...');

  // Vocabulary: piece -> id, plus the inverted id -> piece table used
  // when decoding model output back into text.
  const vocabData = await (await cachedFetch('./pcs_vocab.json')).json();
  pcsVocab = vocabData.vocab;
  pcsVocabReverse = Object.fromEntries(
    Object.entries(pcsVocab).map(([piece, id]) => [id, piece])
  );

  // Model weights, executed on the WASM provider.
  const modelBuffer = await (await cachedFetch('./punct_cap_seg_en.onnx')).arrayBuffer();
  pcsSession = await ort.InferenceSession.create(modelBuffer, {
    executionProviders: ['wasm'],
  });

  console.log('English punctuator model loaded');
}
|
|
| |
/**
 * Loads the multilingual punctuation model: its HuggingFace tokenizer
 * (via transformers.js, dynamically imported from the CDN) and the
 * quantized ONNX weights. Idempotent once loaded. NOTE(review): like
 * the English loader, concurrent first calls each trigger a load.
 */
async function loadMultilingualPunctuator() {
  if (multilingualSession !== null) {
    return;
  }

  console.log('Loading multilingual punctuator model...');

  // Tokenizer from transformers.js; fetched on demand so the library is
  // only downloaded when a multilingual language is actually requested.
  const transformers = await import('https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.4.2');
  multilingualTokenizer = await transformers.AutoTokenizer.from_pretrained(
    'oliverguhr/fullstop-punctuation-multilingual-base'
  );

  // Quantized ONNX weights, executed on the WASM provider.
  const modelBuffer = await (await cachedFetch('./punct_multilingual_q8.onnx')).arrayBuffer();
  multilingualSession = await ort.InferenceSession.create(modelBuffer, {
    executionProviders: ['wasm'],
  });

  console.log('Multilingual punctuator model loaded');
}
|
|
| |
/**
 * Greedy longest-match SentencePiece-style tokenization for the English
 * PCS model: lowercases the text, maps spaces to the '\u2581' word-boundary
 * marker, prepends a leading boundary, then repeatedly consumes the
 * longest vocab piece. Characters with no matching piece each emit one
 * <unk> id. Output is wrapped in BOS/EOS ids.
 *
 * Fixes: removed an unused local (`let i = 0`) from the original.
 *
 * @param {string} text - Raw transcript text.
 * @returns {number[]} Token ids, including BOS and EOS.
 */
function tokenizeEnglish(text) {
  // Longest candidate piece tried during matching.
  // NOTE(review): assumed upper bound on vocab piece length — confirm
  // against pcs_vocab.json.
  const MAX_PIECE_LEN = 20;

  const tokens = [PCS_CONFIG.bosId];

  // Leading '\u2581' marks the first word boundary, SentencePiece-style.
  let remaining = '\u2581' + text.toLowerCase().replace(/ /g, '\u2581');

  while (remaining.length > 0) {
    // Greedy: try the longest candidate first.
    let consumed = 0;
    for (let len = Math.min(remaining.length, MAX_PIECE_LEN); len > 0; len--) {
      const id = pcsVocab[remaining.substring(0, len)];
      if (id !== undefined) {
        tokens.push(id);
        consumed = len;
        break;
      }
    }
    if (consumed === 0) {
      // No piece matches: emit <unk> and skip one UTF-16 code unit.
      // NOTE(review): this can split surrogate pairs (emoji etc.).
      tokens.push(PCS_CONFIG.unkId);
      consumed = 1;
    }
    remaining = remaining.substring(consumed);
  }

  tokens.push(PCS_CONFIG.eosId);
  return tokens;
}
|
|
| |
/**
 * Punctuates, capitalizes, and segments English text using the PCS model.
 *
 * The model emits four per-position heads (position 0 is BOS, so token i
 * of the text reads head index i + 1):
 *  - pre_preds:  class for punctuation inserted BEFORE a token (1 = "¿")
 *  - post_preds: class for punctuation AFTER a token; 1 (<ACRONYM>)
 *                places "." after every character, 2+ one trailing mark
 *  - cap_preds:  per-character capitalization flags, 16 slots per token
 *  - seg_preds:  sentence-break flag after the token
 *
 * NOTE(review): the strict comparisons (`=== 1`, `> 1`) assume the pred
 * arrays hold Numbers; if the ONNX outputs are int64 (BigInt64Array)
 * these comparisons would always be false — confirm the output dtypes.
 *
 * @param {string} text - Lowercased, unpunctuated transcript.
 * @returns {Promise<string>} Punctuated text; sentences joined by spaces.
 */
async function applyEnglishPunctuation(text) {
  await loadEnglishPunctuator();

  // BOS + vocab pieces + EOS.
  const tokenIds = tokenizeEnglish(text);

  // ONNX graph expects an int64 [1, seq_len] tensor.
  const inputTensor = new ort.Tensor('int64', BigInt64Array.from(tokenIds.map(BigInt)), [1, tokenIds.length]);
  const outputs = await pcsSession.run({ input_ids: inputTensor });

  const prePreds = outputs.pre_preds.data;
  const postPreds = outputs.post_preds.data;
  const capPreds = outputs.cap_preds.data;
  const segPreds = outputs.seg_preds.data;

  // Iterate over the real pieces, skipping BOS (0) and EOS (last);
  // outputIdx re-adds the BOS offset for indexing the heads.
  const numTokens = tokenIds.length - 2;
  const result = [];
  let currentSentence = [];

  for (let i = 0; i < numTokens; i++) {
    const tokenId = tokenIds[i + 1];
    const token = pcsVocabReverse[tokenId] || '';
    const outputIdx = i + 1;

    // '\u2581' marks a word boundary: separate words with a space.
    if (token.startsWith('\u2581') && currentSentence.length > 0) {
      currentSentence.push(' ');
    }

    // Emit the piece's characters, dropping the leading '\u2581' marker.
    const charStart = token.startsWith('\u2581') ? 1 : 0;
    for (let j = charStart; j < token.length; j++) {
      let char = token[j];

      // Pre-punctuation (class 1 = "¿") attaches before the first char.
      if (j === charStart && prePreds[outputIdx] === 1) {
        currentSentence.push(PCS_CONFIG.preLabels[1]);
      }

      // cap_preds is laid out [seq, 16]: one flag per character slot.
      // NOTE(review): pieces longer than 16 chars would index into the
      // next token's row — the tokenizer allows pieces up to 20 chars;
      // confirm the vocab has no pieces longer than 16.
      const capOffset = outputIdx * 16 + j;
      if (capPreds[capOffset]) {
        char = char.toUpperCase();
      }

      currentSentence.push(char);

      // Post-punctuation: <ACRONYM> puts "." after EVERY character
      // (e.g. "u.s."); other classes attach once, after the last char.
      const postLabel = postPreds[outputIdx];
      if (postLabel === 1) {
        currentSentence.push('.');
      } else if (j === token.length - 1 && postLabel > 1) {
        currentSentence.push(PCS_CONFIG.postLabels[postLabel]);
      }
    }

    // Sentence-break head: flush the accumulated sentence.
    if (segPreds[outputIdx]) {
      result.push(currentSentence.join(''));
      currentSentence = [];
    }
  }

  // Flush any trailing partial sentence.
  if (currentSentence.length > 0) {
    result.push(currentSentence.join(''));
  }

  return result.join(' ');
}
|
|
| |
/**
 * Restores punctuation in non-English text using the multilingual
 * token-classification model: every input token is assigned one of six
 * classes (none, ".", ",", "?", "-", ":"), and the predicted mark is
 * appended right after the token's text during detokenization.
 *
 * @param {string} text - Unpunctuated transcript.
 * @returns {Promise<string>} Punctuated text.
 */
async function applyMultilingualPunctuation(text) {
  await loadMultilingualPunctuator();

  // Encode with the HF tokenizer; truncate to the model's 512-token window.
  const encoded = await multilingualTokenizer(text, {
    return_tensors: false,
    padding: false,
    truncation: true,
    max_length: 512,
  });
  const { input_ids: inputIds, attention_mask: attentionMask } = encoded;

  // The ONNX graph expects int64 [1, seq_len] tensors.
  const toInt64Tensor = (values) =>
    new ort.Tensor('int64', BigInt64Array.from(values.map(BigInt)), [1, values.length]);

  const outputs = await multilingualSession.run({
    input_ids: toInt64Tensor(inputIds),
    attention_mask: toInt64Tensor(attentionMask),
  });

  // Argmax over the 6 classes at every token position (ties -> lower id).
  const logits = outputs.logits.data;
  const numLabels = 6;
  const predictions = inputIds.map((_, tokenIdx) => {
    const base = tokenIdx * numLabels;
    let best = 0;
    for (let cls = 1; cls < numLabels; cls++) {
      if (logits[base + cls] > logits[base + best]) {
        best = cls;
      }
    }
    return best;
  });

  // Detokenize, inserting predicted punctuation after each token.
  const tokens = multilingualTokenizer.model.convert_ids_to_tokens(inputIds);
  const pieces = [];
  tokens.forEach((token, idx) => {
    // Skip the tokenizer's special markers entirely.
    if (token === '<s>' || token === '</s>' || token === '<pad>') {
      return;
    }

    // '\u2581' marks a word start: space-separate words, drop the marker.
    if (token.startsWith('\u2581')) {
      if (pieces.length > 0) {
        pieces.push(' ');
      }
      pieces.push(token.substring(1));
    } else {
      pieces.push(token);
    }

    // Class 0 maps to '' (falsy), so nothing is appended for "none".
    const punct = MULTILINGUAL_LABELS[predictions[idx]];
    if (punct) {
      pieces.push(punct);
    }
  });

  return pieces.join('');
}
|
|
| |
/**
 * Public entry point: restores punctuation in transcript text.
 * Routes to the multilingual model for supported language codes and to
 * the English model otherwise. Any model failure is logged and the
 * original text is returned unchanged (best-effort behavior).
 *
 * @param {string} text - Transcript text (may be empty/blank).
 * @param {?string} [lang=null] - ISO 639-1 language code, if known.
 * @returns {Promise<string>} Punctuated text, or the input on failure.
 */
async function applyPunctuation(text, lang = null) {
  // Nothing to do for empty or whitespace-only input.
  if (!text || text.trim().length === 0) {
    return text;
  }

  const useMultilingual = Boolean(lang) && MULTILINGUAL_LANGS.includes(lang);

  try {
    if (useMultilingual) {
      return await applyMultilingualPunctuation(text);
    }
    return await applyEnglishPunctuation(text);
  } catch (error) {
    // Best-effort: fall back to the raw transcript rather than failing.
    const which = useMultilingual ? 'Multilingual' : 'English';
    console.warn(`${which} punctuation failed, returning original:`, error);
    return text;
  }
}
|
|
| |
/**
 * Pre-warms the default (English) punctuation model so the first
 * transcription does not pay the download/initialization cost.
 *
 * @returns {Promise<void>} Resolves once the English model is loaded.
 */
async function loadPunctuator() {
  return loadEnglishPunctuator();
}
|
|
| |
// Expose the public API on window for other scripts — NOTE(review):
// assumes this file is loaded as a classic <script>, not an ES module.
window.applyPunctuation = applyPunctuation;
window.loadPunctuator = loadPunctuator;
window.MULTILINGUAL_PUNCT_LANGS = MULTILINGUAL_LANGS;
|
|