import os

import gradio as gr

# NOTE(review): `spaces` is deliberately imported before the `longstream`
# modules, as in the original file — presumably so it can hook GPU setup
# before any CUDA-using library is loaded; confirm before reordering.
try:
    import spaces

    HAS_SPACES = True
except ImportError:
    HAS_SPACES = False

from longstream.demo import BRANCH_OPTIONS, create_demo_session, load_metadata
from longstream.demo.backend import load_frame_previews
from longstream.demo.export import export_glb
from longstream.demo.viewer import build_interactive_figure

DEFAULT_KEYFRAME_STRIDE = 8
DEFAULT_REFRESH = 3
DEFAULT_WINDOW_SIZE = 48
DEFAULT_CHECKPOINT = os.getenv("LONGSTREAM_CHECKPOINT", "checkpoints/50_longstream.pt")


def _render_scene(
    session_dir,
    branch_label,
    show_cameras,
    mask_sky,
    camera_scale,
    point_size,
    opacity,
    preview_max_points,
    glb_max_points,
):
    """Build the interactive preview figure and export the GLB for a session.

    Shared by the initial run and by the "settings changed" handler so both
    always render with identical arguments. UI values are coerced to the
    numeric/boolean types passed on to the viewer and exporter.

    Returns:
        A ``(figure, glb_path)`` tuple: the result of
        ``build_interactive_figure`` followed by the result of ``export_glb``.
    """
    fig = build_interactive_figure(
        session_dir=session_dir,
        branch=branch_label,
        display_mode="All Frames",
        frame_index=0,
        point_size=float(point_size),
        opacity=float(opacity),
        preview_max_points=int(preview_max_points),
        show_cameras=bool(show_cameras),
        camera_scale=float(camera_scale),
        mask_sky=bool(mask_sky),
    )
    glb_path = export_glb(
        session_dir=session_dir,
        branch=branch_label,
        display_mode="All Frames",
        frame_index=0,
        mask_sky=bool(mask_sky),
        show_cameras=bool(show_cameras),
        camera_scale=float(camera_scale),
        max_points=int(glb_max_points),
    )
    return fig, glb_path


def _sky_status_suffix(meta):
    """Return the ``" | sky_removed=..%"`` status suffix, or ``""``.

    The suffix is only produced when ``meta["has_sky_masks"]`` is truthy; a
    missing or falsy ``sky_removed_ratio`` is reported as 0.0%.
    """
    if not meta.get("has_sky_masks"):
        return ""
    removed = float(meta.get("sky_removed_ratio") or 0.0) * 100.0
    return f" | sky_removed={removed:.1f}%"


def _run_stable_demo_impl(
    image_dir,
    uploaded_files,
    uploaded_video,
    checkpoint,
    device,
    mode,
    streaming_mode,
    refresh,
    window_size,
    compute_sky,
    branch_label,
    show_cameras,
    mask_sky,
    camera_scale,
    point_size,
    opacity,
    preview_max_points,
    glb_max_points,
):
    """Create a demo session from the provided input and render its outputs.

    At least one of ``image_dir``, ``uploaded_files`` or ``uploaded_video``
    must be provided. The remaining parameters are the current values of the
    inference and display controls defined in ``main``.

    Returns:
        An 8-tuple matching the ``outputs`` of the run button: figure, GLB
        path, session directory, RGB preview, depth preview, frame label,
        a ``gr.update`` for the frame slider, and a status string.

    Raises:
        gr.Error: If no input source was provided.
    """
    if not image_dir and not uploaded_files and not uploaded_video:
        raise gr.Error("Provide an image folder, upload images, or upload a video.")

    session_dir = create_demo_session(
        image_dir=image_dir or "",
        uploaded_files=uploaded_files,
        uploaded_video=uploaded_video,
        checkpoint=checkpoint,
        device=device,
        mode=mode,
        streaming_mode=streaming_mode,
        keyframe_stride=DEFAULT_KEYFRAME_STRIDE,
        refresh=int(refresh),
        window_size=int(window_size),
        compute_sky=bool(compute_sky),
    )
    fig, glb_path = _render_scene(
        session_dir,
        branch_label,
        show_cameras,
        mask_sky,
        camera_scale,
        point_size,
        opacity,
        preview_max_points,
        glb_max_points,
    )
    rgb, depth, frame_label = load_frame_previews(session_dir, 0)
    meta = load_metadata(session_dir)

    # Reset the frame slider to the new sequence; it is only made interactive
    # when there is more than one frame to choose from.
    slider = gr.update(
        minimum=0,
        maximum=max(meta["num_frames"] - 1, 0),
        value=0,
        step=1,
        interactive=meta["num_frames"] > 1,
    )
    sky_msg = _sky_status_suffix(meta)
    status = f"Ready: {meta['num_frames']} frames | branch={branch_label}{sky_msg}"
    return (
        fig,
        glb_path,
        session_dir,
        rgb,
        depth,
        frame_label,
        slider,
        status,
    )


# When the optional `spaces` package is importable, run the reconstruction
# through its GPU decorator; otherwise call the implementation directly.
if HAS_SPACES:
    _run_stable_demo = spaces.GPU(_run_stable_demo_impl)
else:
    _run_stable_demo = _run_stable_demo_impl


def _update_stable_scene(
    session_dir,
    branch_label,
    show_cameras,
    mask_sky,
    camera_scale,
    point_size,
    opacity,
    preview_max_points,
    glb_max_points,
):
    """Re-render the figure and GLB of an existing session with new settings.

    Returns:
        ``(figure, glb_path, status)``. When ``session_dir`` is empty or is
        not an existing directory, returns ``(None, None, <hint message>)``
        without rendering anything.
    """
    if not session_dir or not os.path.isdir(session_dir):
        return None, None, "Run reconstruction first."

    fig, glb_path = _render_scene(
        session_dir,
        branch_label,
        show_cameras,
        mask_sky,
        camera_scale,
        point_size,
        opacity,
        preview_max_points,
        glb_max_points,
    )
    sky_msg = _sky_status_suffix(load_metadata(session_dir))
    return fig, glb_path, f"Updated preview: {branch_label}{sky_msg}"


def _update_frame_preview(session_dir, frame_index):
    """Load the RGB/depth previews and label for one frame of a session.

    Returns ``(None, None, "")`` when ``session_dir`` is empty or is not an
    existing directory.
    """
    if not session_dir or not os.path.isdir(session_dir):
        return None, None, ""
    rgb, depth, label = load_frame_previews(session_dir, int(frame_index))
    return rgb, depth, label


def main():
    """Build the Gradio UI, wire up its event handlers, and launch it."""
    with gr.Blocks(title="LongStream Demo") as demo:
        # Hidden textbox carrying the current session directory between events.
        session_dir = gr.Textbox(visible=False)
        gr.Markdown("# LongStream Demo")

        with gr.Row():
            image_dir = gr.Textbox(
                label="Image Folder", placeholder="/path/to/sequence"
            )
            uploaded_files = gr.File(
                label="Upload Images", file_count="multiple", file_types=["image"]
            )
            uploaded_video = gr.File(
                label="Upload Video", file_count="single", file_types=["video"]
            )

        with gr.Row():
            checkpoint = gr.Textbox(label="Checkpoint", value=DEFAULT_CHECKPOINT)
            device = gr.Dropdown(label="Device", choices=["cuda", "cpu"], value="cuda")

        with gr.Accordion("Inference", open=False):
            with gr.Row():
                mode = gr.Dropdown(
                    label="Mode",
                    choices=["streaming_refresh", "batch_refresh"],
                    value="batch_refresh",
                )
                streaming_mode = gr.Dropdown(
                    label="Streaming Mode", choices=["causal", "window"], value="causal"
                )
            with gr.Row():
                refresh = gr.Slider(
                    label="Refresh", minimum=2, maximum=9, step=1, value=DEFAULT_REFRESH
                )
                window_size = gr.Slider(
                    label="Window Size",
                    minimum=1,
                    maximum=64,
                    step=1,
                    value=DEFAULT_WINDOW_SIZE,
                )
            compute_sky = gr.Checkbox(label="Compute Sky Masks", value=True)

        with gr.Accordion("GLB Settings", open=True):
            with gr.Row():
                branch_label = gr.Dropdown(
                    label="Point Cloud Branch",
                    choices=BRANCH_OPTIONS,
                    value="Point Head + Pose",
                )
                show_cameras = gr.Checkbox(label="Show Cameras", value=True)
                mask_sky = gr.Checkbox(label="Mask Sky", value=True)
            with gr.Row():
                point_size = gr.Slider(
                    label="Point Size",
                    minimum=0.05,
                    maximum=2.0,
                    step=0.05,
                    value=0.3,
                )
                opacity = gr.Slider(
                    label="Opacity",
                    minimum=0.1,
                    maximum=1.0,
                    step=0.05,
                    value=0.75,
                )
                preview_max_points = gr.Slider(
                    label="Preview Max Points",
                    minimum=5000,
                    maximum=1000000,
                    step=10000,
                    value=100000,
                )
            with gr.Row():
                camera_scale = gr.Slider(
                    label="Camera Scale",
                    minimum=0.001,
                    maximum=0.05,
                    step=0.001,
                    value=0.01,
                )
                glb_max_points = gr.Slider(
                    label="GLB Max Points",
                    minimum=20000,
                    maximum=1000000,
                    step=10000,
                    value=400000,
                )

        run_btn = gr.Button("Run Stable Demo", variant="primary")
        status = gr.Markdown("Provide input images, then run reconstruction.")
        plot = gr.Plot(label="Scene Preview")
        glb_file = gr.File(label="Download GLB")

        with gr.Row():
            frame_slider = gr.Slider(
                label="Preview Frame",
                minimum=0,
                maximum=0,
                step=1,
                value=0,
                interactive=False,
            )
            frame_label = gr.Textbox(label="Frame")

        with gr.Row():
            rgb_preview = gr.Image(label="RGB", type="numpy")
            depth_preview = gr.Image(label="Depth Plasma", type="numpy")

        # Display controls, in the positional order expected by both
        # _run_stable_demo (trailing parameters) and _update_stable_scene
        # (parameters after session_dir). Keeping one list keeps them in sync.
        scene_controls = [
            branch_label,
            show_cameras,
            mask_sky,
            camera_scale,
            point_size,
            opacity,
            preview_max_points,
            glb_max_points,
        ]

        run_btn.click(
            _run_stable_demo,
            inputs=[
                image_dir,
                uploaded_files,
                uploaded_video,
                checkpoint,
                device,
                mode,
                streaming_mode,
                refresh,
                window_size,
                compute_sky,
                *scene_controls,
            ],
            outputs=[
                plot,
                glb_file,
                session_dir,
                rgb_preview,
                depth_preview,
                frame_label,
                frame_slider,
                status,
            ],
        )

        # Any change to a display control re-renders the existing session.
        for component in scene_controls:
            component.change(
                _update_stable_scene,
                inputs=[session_dir, *scene_controls],
                outputs=[plot, glb_file, status],
            )

        frame_slider.change(
            _update_frame_preview,
            inputs=[session_dir, frame_slider],
            outputs=[rgb_preview, depth_preview, frame_label],
        )

    demo.launch()


if __name__ == "__main__":
    main()