Spaces: Build error

import spaces
import gc
import gradio as gr
import numpy as np
import os
from pathlib import Path
from diffusers import GGUFQuantizationConfig, HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
from diffusers.utils import export_to_video
from huggingface_hub import snapshot_download
import torch
from PIL import Image

# Configuration
gc.collect()
torch.cuda.empty_cache()
torch.set_grad_enabled(False)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Load base model
model_id = "hunyuanvideo-community/HunyuanVideo"
base_path = f"/home/user/app/{model_id}"
os.makedirs(base_path, exist_ok=True)
snapshot_download(repo_id=model_id, local_dir=base_path)
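# The base repo is mirrored into the app's working directory so the pipeline below
# can be built from local files; this assumes the Space has enough disk for the full
# hunyuanvideo-community/HunyuanVideo snapshot.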

# Load transformer
ckp_path = Path(base_path)
gguf_filename = "hunyuan-video-t2v-720p-Q4_0.gguf"
transformer_path = f"https://huggingface.co/city96/HunyuanVideo-gguf/blob/main/{gguf_filename}"
transformer = HunyuanVideoTransformer3DModel.from_single_file(
    transformer_path,
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
).to("cuda")
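# Note: from_single_file is pointed at a Hub "blob" URL; recent diffusers releases
# resolve such URLs to the underlying file, and GGUFQuantizationConfig keeps the Q4_0
# weights quantized while computing in bfloat16. If URL resolution is ever an issue,
# downloading the .gguf locally first is the safer path.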

# Initialize pipeline
pipe = HunyuanVideoPipeline.from_pretrained(
    ckp_path,
    transformer=transformer,
    torch_dtype=torch.float16
).to("cuda")

# Configure VAE
pipe.vae.enable_tiling()
pipe.vae.enable_slicing()
pipe.vae.eval()
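# Tiling and slicing make the VAE decode the video in chunks, trading a little speed
# for a much lower peak-VRAM footprint during the final decode step.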

# Available LoRAs with display names
LORA_CHOICES = [
    ("stripe_v2.safetensors", "Stripe Style"),
    ("Top_Off.safetensors", "Top Off Effect"),
    ("huanyan_helper.safetensors", "Hunyuan Helper"),
    ("huanyan_helper_alpha.safetensors", "Hunyuan Alpha"),
    ("hunyuan-t-solo-v1.0.safetensors", "Solo Animation")
]
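# Each entry pairs a weight file in the Sergidev/TTV4ME repo with the label shown in
# the UI; the checkbox group and weight sliders further down are built from this list.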

# Load all LoRAs, registering each under a normalized adapter name
for weight_name, display_name in LORA_CHOICES:
    pipe.load_lora_weights(
        "Sergidev/TTV4ME",
        weight_name=weight_name,
        adapter_name=display_name.replace(" ", "_").lower(),
        token=os.environ.get("HF_TOKEN")
    )
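# The adapter names created here ("stripe_style", "top_off_effect", ...) must match the
# names that generate() derives from the display labels before calling set_adapters().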

# Memory cleanup
gc.collect()
torch.cuda.empty_cache()

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024


# ZeroGPU allocates a GPU per call via this decorator; the 300-second budget is an
# assumed value, adjust to the expected generation time.
@spaces.GPU(duration=300)
def generate(
    prompt,
    image_input,
    height,
    width,
    num_frames,
    num_inference_steps,
    seed_value,
    fps,
    selected_loras,
    lora_w1,
    lora_w2,
    lora_w3,
    lora_w4,
    lora_w5,
    progress=gr.Progress(track_tqdm=True)
):
    # Gradio passes each weight slider as its own positional argument, one per LORA_CHOICES entry
    lora_weights = [lora_w1, lora_w2, lora_w3, lora_w4, lora_w5]
    # Validate image resolution
    if image_input is not None:
        img = Image.open(image_input)
        if img.size != (width, height):
            raise gr.Error(f"Image resolution {img.size} must match video resolution ({width}x{height})")

    # Configure LoRAs: activate only the selected adapters, each with its own slider weight
    active_adapters = [display.replace(" ", "_").lower()
                       for _, display in LORA_CHOICES if display in selected_loras]
    weights = [float(lora_weights[i])
               for i, (_, display) in enumerate(LORA_CHOICES) if display in selected_loras]
    pipe.set_adapters(active_adapters, weights)

    with torch.cuda.device(0):
        if seed_value == -1:
            seed_value = torch.randint(0, MAX_SEED, (1,)).item()
        generator = torch.Generator('cuda').manual_seed(seed_value)

        with torch.autocast('cuda', dtype=torch.bfloat16), torch.inference_mode():
            # Use the image input if provided, otherwise the text prompt.
            # Note: the image branch assumes the pipeline accepts an `image` argument;
            # the stock text-to-video HunyuanVideoPipeline does not, so image
            # conditioning may need the image-to-video pipeline variant instead.
            if image_input is not None:
                output = pipe(
                    image=Image.open(image_input).convert("RGB"),
                    height=height,
                    width=width,
                    num_frames=num_frames,
                    num_inference_steps=num_inference_steps,
                    generator=generator,
                ).frames[0]
            else:
                output = pipe(
                    prompt=prompt,
                    height=height,
                    width=width,
                    num_frames=num_frames,
                    num_inference_steps=num_inference_steps,
                    generator=generator,
                ).frames[0]

    output_path = "output.mp4"
    export_to_video(output, output_path, fps=fps)

    torch.cuda.empty_cache()
    gc.collect()

    return output_path
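
# Illustrative direct call (not part of the app flow, assuming a local CUDA setup):
#   generate("a cat surfing a wave", None, 608, 448, 24, 29, -1, 12,
#            ["Stripe Style"], 0.9, 0.8, 0.8, 0.8, 0.8)
# i.e. text-only input, the default resolution, a random seed, and one active LoRA;
# the five trailing numbers are the per-LoRA weights in LORA_CHOICES order.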


def apply_preset(preset_name, *current_values):
    # Returns [height, width, num_frames, num_inference_steps, fps]
    if preset_name == "Higher Resolution":
        return [608, 448, 24, 29, 12]
    elif preset_name == "More Frames":
        return [512, 320, 42, 27, 14]
    return current_values
| css = """ | |
| #col-container { | |
| margin: 0 auto; | |
| max-width: 850px; | |
| } | |
| .dark-theme { | |
| background-color: #1f1f1f; | |
| color: #ffffff; | |
| } | |
| .container { | |
| margin: 0 auto; | |
| padding: 20px; | |
| border-radius: 10px; | |
| background-color: #2d2d2d; | |
| box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1); | |
| } | |
| .title { | |
| text-align: center; | |
| margin-bottom: 1em; | |
| color: #ffffff; | |
| } | |
| .description { | |
| text-align: center; | |
| margin-bottom: 2em; | |
| color: #cccccc; | |
| font-size: 0.95em; | |
| line-height: 1.5; | |
| } | |
| .prompt-container { | |
| background-color: #363636; | |
| padding: 15px; | |
| border-radius: 8px; | |
| margin-bottom: 1em; | |
| width: 100%; | |
| } | |
| .prompt-textbox { | |
| min-height: 80px !important; | |
| } | |
| .preset-buttons { | |
| display: flex; | |
| gap: 10px; | |
| justify-content: center; | |
| margin-bottom: 1em; | |
| } | |
| .support-text { | |
| text-align: center; | |
| margin-top: 1em; | |
| color: #cccccc; | |
| font-size: 0.9em; | |
| } | |
| a { | |
| color: #00a7e1; | |
| text-decoration: none; | |
| } | |
| a:hover { | |
| text-decoration: underline; | |
| } | |
| .lora-sliders { | |
| margin-top: 15px; | |
| border-top: 1px solid #444; | |
| padding-top: 15px; | |
| } | |
| """ | |

with gr.Blocks(css=css, theme="dark") as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# 🎬 Hunyuan Studio", elem_classes=["title"])
        gr.Markdown(
            "Generate videos from text or images using multiple LoRA adapters. "
            "Requires matching resolution between input image and output settings.",
            elem_classes=["description"]
        )

        with gr.Column(elem_classes=["prompt-container"]):
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="Enter text prompt or upload image below",
                show_label=False,
                elem_classes=["prompt-textbox"],
                lines=3
            )
            image_input = gr.Image(type="filepath", label="Upload Image (Optional)")

        with gr.Row():
            run_button = gr.Button("🎨 Generate", variant="primary", size="lg")

        with gr.Row(elem_classes=["preset-buttons"]):
            preset_high_res = gr.Button("📺 Higher Resolution Preset")
            preset_more_frames = gr.Button("🎞️ More Frames Preset")

        with gr.Row():
            result = gr.Video(label="Generated Video")

        with gr.Accordion("⚙️ Advanced Settings", open=False):
            seed = gr.Slider(
                label="Seed (-1 for random)",
                minimum=-1,
                maximum=MAX_SEED,
                step=1,
                value=-1,
            )

            with gr.Row():
                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=16,
                    value=608,
                )
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=16,
                    value=448,
                )

            with gr.Row():
                num_frames = gr.Slider(
                    label="Number of frames",
                    minimum=1,
                    maximum=257,
                    step=1,
                    value=24,
                )
                num_inference_steps = gr.Slider(
                    label="Inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=29,
                )
                fps = gr.Slider(
                    label="Frames per second",
                    minimum=1,
                    maximum=60,
                    step=1,
                    value=12,
                )

            with gr.Column(elem_classes=["lora-sliders"]):
                gr.Markdown("### LoRA Adapters")
                lora_checkboxes = gr.CheckboxGroup(
                    label="Select LoRAs",
                    choices=[display for (_, display) in LORA_CHOICES],
                    value=["Stripe Style", "Top Off Effect"]
                )
                lora_weight_sliders = []
                for _, display_name in LORA_CHOICES:
                    lora_weight_sliders.append(
                        gr.Slider(
                            label=f"{display_name} Weight",
                            minimum=0.0,
                            maximum=1.0,
                            value=0.9 if "Stripe" in display_name else 0.8,
                            # Show the sliders for the LoRAs that are checked by default
                            visible=display_name in ("Stripe Style", "Top Off Effect")
                        )
                    )

    # Event handling: the weight sliders are appended individually because Gradio
    # expects a flat list of input components, not a nested list
    run_button.click(
        fn=generate,
        inputs=[prompt, image_input, height, width, num_frames,
                num_inference_steps, seed, fps, lora_checkboxes] + lora_weight_sliders,
        outputs=[result],
    )

    # Preset button handlers
    preset_high_res.click(
        fn=lambda: apply_preset("Higher Resolution"),
        outputs=[height, width, num_frames, num_inference_steps, fps]
    )
    preset_more_frames.click(
        fn=lambda: apply_preset("More Frames"),
        outputs=[height, width, num_frames, num_inference_steps, fps]
    )

    # Show/hide LoRA weight sliders based on checkbox selection
    def toggle_lora_sliders(selected_loras):
        return [gr.update(visible=display in selected_loras) for _, display in LORA_CHOICES]

    lora_checkboxes.change(
        fn=toggle_lora_sliders,
        inputs=lora_checkboxes,
        outputs=lora_weight_sliders
    )
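
# Launch the Gradio interface (the standard entry point when the Space executes app.py)
demo.launch()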