Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import torch | |
| from src.augmentations import get_videomae_transform | |
| from src.models import load_model | |
| from src.utils import ( | |
| create_plot, | |
| get_frames, | |
| get_videomae_outputs, | |
| prepare_frames_masks, | |
| ) | |
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = (
    torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)
def get_visualisations(mask_ratio, video_path):
    """Run the VideoMAE checkpoint on a video and plot its reconstructions.

    Args:
        mask_ratio: Fraction of patches to mask, taken from the UI slider
            (0.25-0.95 in this app). Forwarded to ``load_model``.
        video_path: Path of the video supplied by the ``gr.Video`` input.
            Gradio passes an empty value when the input has been cleared.

    Returns:
        The figure built by ``create_plot``, rendered by a ``gr.Plot`` output.

    Raises:
        gr.Error: If no video was provided. The message is shown in the UI.
    """
    # Fail early with a readable UI message instead of an opaque error
    # raised somewhere inside get_frames when the video input is empty.
    if not video_path:
        raise gr.Error("Please upload a video before running the demo.")

    transform = get_videomae_transform()
    frames, ids = get_frames(path=video_path, transform=transform)

    # load_model returns the masks as well, and they depend on mask_ratio.
    # NOTE(review): the checkpoint is reloaded on every click; consider
    # caching the model separately if load_model can be split that way.
    model, masks, patch_size = load_model(
        path="assets/checkpoint.pth",
        mask_ratio=mask_ratio,
        device=device,
    )

    # Inference only: disable autograd to save memory and time.
    with torch.no_grad():
        frames, masks = prepare_frames_masks(frames, masks, device)
        outputs = model(frames, masks)
        visualisations = get_videomae_outputs(
            frames=frames,
            masks=masks,
            outputs=outputs,
            ids=ids,
            patch_size=patch_size,
            device=device,
        )

    return create_plot(visualisations)
# Gradio UI: an intro text, a video input (pre-filled with the bundled
# example), a masking-ratio slider and a Run button that renders the
# VideoMAE reconstructions into a PNG plot below the button.
with gr.Blocks() as app:
    # The link target needs an explicit scheme; without "https://" Markdown
    # treats it as a relative path on the Space's own host.
    gr.Markdown(
        """
# VideoMAE Reconstruction Demo
To read more about the Self-Supervised Learning techniques for video please refer to the [Lightly AI blogpost on Self-Supervised Learning for Videos](https://www.lightly.ai/post/self-supervised-learning-for-videos).
"""  # noqa: E501
    )
    video = gr.Video(
        value="assets/example.mp4",
    )
    mask_ratio_slider = gr.Slider(
        minimum=0.25, maximum=0.95, step=0.05, value=0.75, label="masking ratio"
    )
    btn = gr.Button("Run")
    # The input order must match get_visualisations(mask_ratio, video_path).
    btn.click(
        get_visualisations,
        inputs=[mask_ratio_slider, video],
        outputs=gr.Plot(label="VideoMAE Outputs", format="png"),
    )

if __name__ == "__main__":
    app.launch()