Spaces:
Sleeping
Sleeping
File size: 6,701 Bytes
89138dc 339a69e 89138dc 339a69e 89138dc 339a69e 89138dc 339a69e 89138dc 339a69e 89138dc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 |
# Standard library
import ast
import os
import tempfile

# Third-party
import cv2 as cv
import gradio as gr
import numpy as np
from huggingface_hub import hf_hub_download

# Local
from vittrack import VitTrack
# Download the VitTrack ONNX model from the Hugging Face Hub at startup
# (runs once at import time; cached locally by hf_hub_download).
MODEL_PATH = hf_hub_download(
    repo_id="opencv/object_tracking_vittrack",
    filename="object_tracking_vittrack_2023sep.onnx"
)

# OpenCV DNN execution configuration: plain OpenCV backend on CPU.
backend_id = cv.dnn.DNN_BACKEND_OPENCV
target_id = cv.dnn.DNN_TARGET_CPU

# Bundled example videos shown in the Gradio Examples table.
car_on_road_video = "examples/car.mp4"
car_in_desert_video = "examples/desert_car.mp4"

# Global state shared across Gradio callbacks:
#   points      - up to 4 (x, y) clicks made on the first frame
#   bbox        - (x, y, w, h) rectangle derived from the 4 points
#   video_path  - path of the currently loaded video, if any
#   first_frame - BGR copy of the loaded video's first frame
state = {
    "points": [],
    "bbox": None,
    "video_path": None,
    "first_frame": None
}

# Pre-computed bounding boxes for the example videos, keyed by filename.
# Stored as "(x, y, w, h)" strings and parsed in example_pipeline.
bbox_dict = {
    "car.mp4": "(152, 356, 332, 104)",
    "desert_car.mp4": "(758, 452, 119, 65)",
}
def load_first_frame(video_path):
    """Load a video, grab its first frame, and reset selection state.

    Args:
        video_path: Path to the uploaded video file.

    Returns:
        The first frame converted to RGB for Gradio display, or None if
        the video yields no frames.
    """
    # A new video invalidates any selection made on the previous one;
    # without this, a stale bbox could be tracked on the wrong video.
    state["points"].clear()
    state["bbox"] = None
    state["first_frame"] = None
    state["video_path"] = video_path

    cap = cv.VideoCapture(video_path)
    has_frame, frame = cap.read()
    cap.release()
    if not has_frame:
        return None

    state["first_frame"] = frame.copy()
    # OpenCV decodes to BGR; Gradio expects RGB.
    return cv.cvtColor(frame, cv.COLOR_BGR2RGB)
def select_point(img, evt: gr.SelectData):
    """Record a click on the first frame (up to 4) and redraw the overlay.

    Draws every selected point, an open polygon connecting them, and —
    once exactly 4 points exist — the bounding rectangle that will be
    used to initialize the tracker.
    """
    if state["first_frame"] is None:
        return None

    click = (int(evt.index[0]), int(evt.index[1]))
    points = state["points"]
    # Ignore clicks beyond the fourth; the selection is complete.
    if len(points) < 4:
        points.append(click)

    canvas = state["first_frame"].copy()

    # Mark each selected point.
    for point in points:
        cv.circle(canvas, point, 5, (0, 255, 0), -1)

    # Connect the points with an open polyline.
    if len(points) > 1:
        outline = np.array(points, dtype=np.int32)
        cv.polylines(canvas, [outline], isClosed=False,
                     color=(255, 255, 0), thickness=2)

    # With all 4 points placed, derive and draw the tracking rectangle.
    if len(points) == 4:
        rx, ry, rw, rh = cv.boundingRect(np.array(points, dtype=np.int32))
        state["bbox"] = (rx, ry, rw, rh)
        cv.rectangle(canvas, (rx, ry), (rx + rw, ry + rh), (0, 0, 255), 2)

    return cv.cvtColor(canvas, cv.COLOR_BGR2RGB)
def clear_points():
    """Drop the 4-point selection but keep the loaded video and frame."""
    state["points"].clear()
    state["bbox"] = None

    frame = state["first_frame"]
    if frame is None:
        return None
    # Redisplay the untouched first frame (BGR -> RGB for Gradio).
    return cv.cvtColor(frame, cv.COLOR_BGR2RGB)
def clear_all():
    """Reset every piece of app state: video, frame, points and bbox."""
    state["points"].clear()
    for key in ("bbox", "video_path", "first_frame"):
        state[key] = None
    # One None per wired output: video_in, first_frame, output_video.
    return None, None, None
def track_video():
    """Run VitTrack over the loaded video and render an annotated copy.

    Requires a loaded video and a 4-point bounding box in ``state``.

    Returns:
        Path to the rendered output video in the temp directory, or
        None if prerequisites are missing or the video yields no frames.
    """
    if state["video_path"] is None or state["bbox"] is None:
        return None

    # Instantiate the tracker with the downloaded ONNX model.
    model = VitTrack(
        model_path=MODEL_PATH,
        backend_id=backend_id,
        target_id=target_id
    )

    cap = cv.VideoCapture(state["video_path"])
    # Some containers report 0 FPS; fall back so the writer still
    # produces a playable file instead of a broken one.
    fps = cap.get(cv.CAP_PROP_FPS) or 30.0
    w = int(cap.get(cv.CAP_PROP_FRAME_WIDTH))
    h = int(cap.get(cv.CAP_PROP_FRAME_HEIGHT))

    # Prepare temporary output file.
    out_path = os.path.join(tempfile.gettempdir(), "vittrack_output.mp4")
    writer = cv.VideoWriter(
        out_path,
        cv.VideoWriter_fourcc(*"mp4v"),
        fps,
        (w, h)
    )

    try:
        # Initialize the tracker on the first frame; the original code
        # ignored a failed read here and crashed inside model.init.
        has_frame, first_frame = cap.read()
        if not has_frame:
            return None
        model.init(first_frame, state["bbox"])

        tm = cv.TickMeter()
        while True:
            has_frame, frame = cap.read()
            if not has_frame:
                break

            tm.start()
            isLocated, bbox, score = model.infer(frame)
            tm.stop()

            vis = frame.copy()
            # Overlay the per-frame inference FPS.
            cv.putText(vis, f"FPS:{tm.getFPS():.2f}", (w // 4, 30),
                       cv.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)

            # Draw the tracking box when confident, else a loss message.
            if isLocated and score >= 0.3:
                x, y, w_, h_ = bbox
                cv.rectangle(vis, (x, y), (x + w_, y + h_), (0, 255, 0), 2)
                cv.putText(vis, f"{score:.2f}", (x, y - 10),
                           cv.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
            else:
                cv.putText(vis, "Target lost!",
                           (w // 2, h // 4),
                           cv.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 3)

            writer.write(vis)
            tm.reset()
    finally:
        # Release OS video handles even if inference raises mid-loop.
        cap.release()
        writer.release()

    return out_path
def example_pipeline(video_path):
    """Run tracking on a bundled example video with its preset bbox.

    Args:
        video_path: Path to one of the example videos listed in
            ``bbox_dict`` (keyed by filename).

    Returns:
        Path to the tracked output video (see ``track_video``).
    """
    clear_all()
    # basename is portable across path separators, unlike split('/').
    filename = os.path.basename(video_path)
    state["video_path"] = video_path
    # Parse the "(x, y, w, h)" string safely: literal_eval accepts only
    # Python literals, unlike eval which would execute arbitrary code.
    state["bbox"] = ast.literal_eval(bbox_dict[filename])
    return track_video()
# --- Gradio UI -----------------------------------------------------------
# CSS targets the "Click any row..." hint rendered with the `example` class.
with gr.Blocks(css='''.example * {
    font-style: italic;
    font-size: 18px !important;
    color: #0ea5e9 !important;
}''') as demo:
    gr.Markdown("## VitTrack: Interactive Video Object Tracking")
    gr.Markdown(
        """
        **How to use this tool:**

        1. **Upload a video** file (e.g., `.mp4` or `.avi`).
        2. The **first frame** of the video will appear.
        3. **Click exactly 4 points** on the object you want to track. These points should outline the object as closely as possible.
        4. A **bounding box** will be drawn around the selected region automatically.
        5. Click the **Track** button to start object tracking across the entire video.
        6. The output video with tracking overlay will appear below.

        You can also use:
        - 🧹 **Clear Points** to reset the 4-point selection on the first frame.
        - 🔄 **Clear All** to reset the uploaded video, frame, and selections.
        """
    )

    with gr.Row():
        video_in = gr.Video(label="Upload Video")
        first_frame = gr.Image(label="First Frame", interactive=True)
        output_video = gr.Video(label="Tracking Result")

    with gr.Row():
        track_btn = gr.Button("Track", variant="primary")
        clear_pts_btn = gr.Button("Clear Points")
        clear_all_btn = gr.Button("Clear All")

    gr.Markdown("Click any row to load an example.", elem_classes=["example"])
    examples = [
        [car_on_road_video],
        [car_in_desert_video],
    ]
    # run_on_click makes a row click execute example_pipeline directly.
    gr.Examples(
        examples=examples,
        inputs=[video_in],
        outputs=[output_video],
        fn=example_pipeline,
        cache_examples=False,
        run_on_click=True
    )
    gr.Markdown("Example videos credit: https://pixabay.com/")

    # Event wiring: upload -> first frame; click -> point selection;
    # buttons -> reset / full tracking run.
    video_in.change(fn=load_first_frame, inputs=video_in, outputs=first_frame)
    first_frame.select(fn=select_point, inputs=first_frame, outputs=first_frame)
    clear_pts_btn.click(fn=clear_points, outputs=first_frame)
    clear_all_btn.click(fn=clear_all, outputs=[video_in, first_frame, output_video])
    track_btn.click(fn=track_video, outputs=output_video)

if __name__ == "__main__":
    demo.launch()
|