# VoXtream TTS demo: a Hugging Face Space running on ZeroGPU
import os
import uuid

# Disable PyTorch dynamo/inductor globally
os.environ["TORCHDYNAMO_DISABLE"] = "1"
os.environ["TORCHINDUCTOR_DISABLE"] = "1"

import torch._dynamo as dynamo

dynamo.config.suppress_errors = True
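# suppress_errors makes torch fall back to eager execution instead of raising
# if dynamo is still triggered anywhere despite being disabled above.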

import json
from pathlib import Path

import nltk
import torch
import spaces
import gradio as gr
import numpy as np
import soundfile as sf

from voxtream.generator import SpeechGenerator, SpeechGeneratorConfig

with open("configs/generator.json") as f:
    config = SpeechGeneratorConfig(**json.load(f))

# Load the speaker encoder via torch.hub
torch.hub.load(
    config.spk_enc_repo,
    config.spk_enc_model,
    model_name=config.spk_enc_model_name,
    train_type=config.spk_enc_train_type,
    dataset=config.spk_enc_dataset,
    trust_repo=True,
    verbose=False,
)
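# The returned model is discarded: this call just warms the torch.hub cache at
# startup; SpeechGenerator presumably reloads the encoder from that cache below.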

# Download NLTK data used for text processing (POS tagger and tokenizer)
nltk.download("averaged_perceptron_tagger_eng", quiet=True, raise_on_error=True)
nltk.download("punkt", quiet=True, raise_on_error=True)

# Initialize the speech generator
speech_generator = SpeechGenerator(config)

FADE_OUT_SEC = 0.10
MIN_CHUNK_SEC = 0.2
CHUNK_SIZE = int(config.mimi_sr * MIN_CHUNK_SEC)
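# Example: assuming Mimi's usual 24 kHz sample rate, CHUNK_SIZE is
# int(24000 * 0.2) = 4800 samples, i.e. 0.2 s of audio per streamed chunk.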

CUSTOM_CSS = """
/* overall width */
.gradio-container {max-width: 1100px !important}

/* stack labels tighter and even heights */
#cols .wrap > .form {gap: 10px}
#left-col, #right-col {gap: 14px}

/* make submit centered + bigger */
#submit {width: 260px; margin: 10px auto 0 auto;}

/* make clear align left and look secondary */
#clear {width: 120px;}

/* give audio a little breathing room */
audio {outline: none;}
"""


def float32_to_int16(audio_float32: np.ndarray) -> np.ndarray:
    """
    Convert float32 audio samples (-1.0 to 1.0) to int16 PCM samples.

    Parameters:
        audio_float32 (np.ndarray): Input float32 audio samples.

    Returns:
        np.ndarray: Output int16 audio samples.
    """
    if audio_float32.dtype != np.float32:
        raise ValueError("Input must be a float32 numpy array")
    # Clip to avoid overflow after scaling
    audio_clipped = np.clip(audio_float32, -1.0, 1.0)
    # Scale and convert
    audio_int16 = (audio_clipped * 32767).astype(np.int16)
    return audio_int16
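
# e.g. float32_to_int16(np.array([0.0, 0.5, -1.0], dtype=np.float32))
#      -> array([     0,  16383, -32767], dtype=int16)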


def _clear_outputs():
    # Clears the player and hides the file (the download button mirrors the file via .change)
    return None, gr.update(value=None, visible=False)


@spaces.GPU
def synthesize_fn(prompt_audio_path, prompt_text, target_text):
    if next(speech_generator.model.parameters()).device.type == "cpu":
        speech_generator.model.to("cuda")
        speech_generator.mimi.to("cuda")
        speech_generator.spk_enc.to("cuda")
        speech_generator.aligner.aligner.to("cuda")
        speech_generator.aligner.device = "cuda"
        speech_generator.device = "cuda"
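        # ZeroGPU attaches a CUDA device only while a @spaces.GPU-decorated call is
        # running, so everything starts on CPU and is moved to the GPU on first use.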

    if not prompt_audio_path or not target_text:
        # This function is a generator, so yield (not return) the cleared outputs
        yield None, gr.update(value=None, visible=False)
        return

    stream = speech_generator.generate_stream(
        prompt_text=prompt_text,
        prompt_audio_path=Path(prompt_audio_path),
        text=target_text,
    )

    buffer = []
    buffer_len = 0
    total_buffer = []
    for frame, _ in stream:
        buffer.append(frame)
        total_buffer.append(frame)
        buffer_len += frame.shape[0]
        if buffer_len >= CHUNK_SIZE:
            audio = np.concatenate(buffer)
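            # Each yield maps onto (output_audio, download_btn); audio chunks are
            # (sample_rate, int16 samples) tuples, which Gradio streams as they arrive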
            yield (config.mimi_sr, float32_to_int16(audio)), None
            # Reset buffer and length
            buffer = []
            buffer_len = 0

    # Handle any remaining audio in the buffer
    if buffer_len > 0:
        final = np.concatenate(buffer)
        nfade = min(int(config.mimi_sr * FADE_OUT_SEC), final.shape[0])
        if nfade > 0:
            # Linear fade-out over the tail to avoid an abrupt cutoff
            fade = np.linspace(1.0, 0.0, nfade, dtype=np.float32)
            final[-nfade:] *= fade
        yield (config.mimi_sr, float32_to_int16(final)), None

    # Save the full audio to a file for download
    if len(total_buffer) > 0:
        full_audio = np.concatenate(total_buffer)
        nfade = min(int(config.mimi_sr * FADE_OUT_SEC), full_audio.shape[0])
        if nfade > 0:
            fade = np.linspace(1.0, 0.0, nfade, dtype=np.float32)
            full_audio[-nfade:] *= fade
        file_path = f"/tmp/voxtream_{uuid.uuid4().hex}.wav"
        sf.write(file_path, float32_to_int16(full_audio), config.mimi_sr)
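        # soundfile infers 16-bit PCM output from the int16 dtype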
        yield None, gr.update(value=file_path, visible=True)
    else:
        yield None, gr.update(value=None, visible=False)


def main():
    with gr.Blocks(css=CUSTOM_CSS, title="VoXtream") as demo:
        gr.Markdown("# VoXtream TTS demo")
        gr.Markdown(
            "⚠️ Initial latency can be high due to deployment on ZeroGPU. "
            "For faster inference, try a local deployment; see the "
            "[VoXtream GitHub repo](https://github.com/herimor/voxtream) for details."
        )

        with gr.Row(equal_height=True, elem_id="cols"):
            with gr.Column(scale=1, elem_id="left-col"):
                prompt_audio = gr.Audio(
                    sources=["microphone", "upload"],
                    type="filepath",
                    label="Prompt audio (3-5 sec of target voice, max 10 sec)",
                )
                prompt_text = gr.Textbox(
                    lines=3,
                    max_length=config.max_prompt_chars,
                    label=f"Prompt transcript (Required, max {config.max_prompt_chars} chars)",
                    placeholder="Text that matches the prompt audio",
                )
            with gr.Column(scale=1, elem_id="right-col"):
                target_text = gr.Textbox(
                    lines=3,
                    max_length=config.max_phone_tokens,
                    label=f"Target text (Required, max {config.max_phone_tokens} chars)",
                    placeholder="What you want the model to say",
                )
                output_audio = gr.Audio(
                    label="Synthesized audio",
                    interactive=False,
                    streaming=True,
                    autoplay=True,
                    show_download_button=False,
                    show_share_button=False,
                )
                # Appears only when the file is ready
                download_btn = gr.DownloadButton(
                    "Download audio",
                    visible=False,
                )

        with gr.Row():
            clear_btn = gr.Button("Clear", elem_id="clear", variant="secondary")
            submit_btn = gr.Button("Submit", elem_id="submit", variant="primary")

        # Message box for validation errors
        validation_msg = gr.Markdown("", visible=False)

        # --- Validation logic ---
        def validate_inputs(audio, ptext, ttext):
            if not audio:
                return gr.update(visible=True, value="⚠️ Please provide prompt audio."), gr.update(interactive=False)
            if not ptext.strip():
                return gr.update(visible=True, value="⚠️ Please provide a prompt transcript."), gr.update(interactive=False)
            if not ttext.strip():
                return gr.update(visible=True, value="⚠️ Please provide target text."), gr.update(interactive=False)
            return gr.update(visible=False, value=""), gr.update(interactive=True)

        # Live validation whenever inputs change
        for inp in [prompt_audio, prompt_text, target_text]:
            inp.change(
                fn=validate_inputs,
                inputs=[prompt_audio, prompt_text, target_text],
                outputs=[validation_msg, submit_btn],
            )

        # Clear outputs first; .then() starts streaming only after the clear completes
        submit_btn.click(
            fn=lambda a, p, t: (None, gr.update(value=None, visible=False)),
            inputs=[prompt_audio, prompt_text, target_text],
            outputs=[output_audio, download_btn],
            show_progress="hidden",
        ).then(
            fn=synthesize_fn,
            inputs=[prompt_audio, prompt_text, target_text],
            outputs=[output_audio, download_btn],
        )

        clear_btn.click(
            fn=lambda: (
                None, "", "",                          # inputs
                None,                                  # output_audio
                gr.update(value=None, visible=False),  # download_btn
                gr.update(visible=False, value=""),    # validation_msg
                gr.update(interactive=False),          # submit_btn
            ),
            inputs=[],
            outputs=[prompt_audio, prompt_text, target_text, output_audio, download_btn, validation_msg, submit_btn],
        )

        # --- Examples ---
        gr.Markdown("### Examples")
        ex = gr.Examples(
            examples=[
                [
                    "assets/app/male.wav",
                    "You could take the easy route or a situation that makes sense which a lot of you do",
                    "Hey, how are you doing? I just uhm want to make sure everything is okay.",
                ],
                [
                    "assets/app/female.wav",
                    "I would certainly anticipate some pushback whereas most people know if you followed my work.",
                    "Hello, hello. Let's have a quick chat, uh, in an hour. I need to share something with you.",
                ],
            ],
            inputs=[prompt_audio, prompt_text, target_text],
            outputs=[output_audio, download_btn],
            fn=synthesize_fn,
            cache_examples=False,
        )

        # Reset the outputs whenever an example row is selected
        ex.dataset.click(
            fn=_clear_outputs,
            inputs=[],
            outputs=[output_audio, download_btn],
            queue=False,
        )

    demo.launch()


if __name__ == "__main__":
    main()
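
# To try this locally, run the script directly with Python (a CUDA GPU is
# assumed); on a Hugging Face Space it is launched automatically as the app
# entry point, with ZeroGPU allocating a GPU per synthesize_fn call.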