File size: 6,701 Bytes
89138dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
339a69e
 
 
89138dc
 
 
 
 
 
 
 
339a69e
 
 
 
 
 
89138dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
339a69e
 
 
 
 
 
 
 
 
 
 
 
 
 
89138dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
339a69e
89138dc
 
 
 
 
 
 
 
339a69e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89138dc
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
import ast
import os
import tempfile

import cv2 as cv
import gradio as gr
import numpy as np
from huggingface_hub import hf_hub_download

from vittrack import VitTrack

# Download ONNX model at startup
# (cached by huggingface_hub, so repeated launches reuse the local copy)
MODEL_PATH = hf_hub_download(
    repo_id="opencv/object_tracking_vittrack",
    filename="object_tracking_vittrack_2023sep.onnx"
)

# DNN backend/target for VitTrack inference (CPU via the default OpenCV backend).
backend_id = cv.dnn.DNN_BACKEND_OPENCV
target_id  = cv.dnn.DNN_TARGET_CPU

# Bundled example clips referenced by the gr.Examples table below.
car_on_road_video = "examples/car.mp4"
car_in_desert_video = "examples/desert_car.mp4"

# Global state
# points:      up to 4 (x, y) clicks on the first frame
# bbox:        (x, y, w, h) rect derived from the 4 points (or an example preset)
# video_path:  path of the currently loaded video
# first_frame: BGR copy of the video's first frame (basis for click overlays)
state = {
    "points": [],
    "bbox": None,
    "video_path": None,
    "first_frame": None
}

#Example bounding boxes
# keyed by example file basename; values are "(x, y, w, h)" strings that are
# parsed when an example row is clicked
bbox_dict = {
    "car.mp4": "(152, 356, 332, 104)",
    "desert_car.mp4": "(758, 452, 119, 65)",
}

def load_first_frame(video_path):
    """Open *video_path*, cache its first frame in the global state, and
    return that frame converted to RGB for display (None if unreadable)."""
    state["video_path"] = video_path
    capture = cv.VideoCapture(video_path)
    ok, frame = capture.read()
    capture.release()
    if ok:
        # Keep a pristine BGR copy so click overlays can be redrawn from scratch.
        state["first_frame"] = frame.copy()
        return cv.cvtColor(frame, cv.COLOR_BGR2RGB)
    return None

def select_point(img, evt: gr.SelectData):
    """Record a click on the first frame (up to 4) and redraw the overlay.

    Each click is drawn as a green dot; two or more clicks are joined by a
    yellow polyline. Once exactly 4 points exist, their bounding rectangle
    is stored in state["bbox"] and drawn in red.
    """
    if state["first_frame"] is None:
        return None

    click = (int(evt.index[0]), int(evt.index[1]))
    points = state["points"]
    if len(points) < 4:
        points.append(click)

    canvas = state["first_frame"].copy()
    # Redraw every recorded point on a fresh copy of the first frame.
    for point in points:
        cv.circle(canvas, point, 5, (0, 255, 0), -1)
    # Connect the points in click order.
    if len(points) > 1:
        poly = np.array(points, dtype=np.int32)
        cv.polylines(canvas, [poly], isClosed=False, color=(255, 255, 0), thickness=2)

    # Four points finalize the selection: derive and draw the bounding rect.
    if len(points) == 4:
        rx, ry, rw, rh = cv.boundingRect(np.array(points, dtype=np.int32))
        state["bbox"] = (rx, ry, rw, rh)
        cv.rectangle(canvas, (rx, ry), (rx + rw, ry + rh), (0, 0, 255), 2)

    return cv.cvtColor(canvas, cv.COLOR_BGR2RGB)

def clear_points():
    """Forget the clicked points (and derived bbox); re-display the clean
    first frame, or None when no video has been loaded yet."""
    state["bbox"] = None
    state["points"].clear()
    frame = state["first_frame"]
    return None if frame is None else cv.cvtColor(frame, cv.COLOR_BGR2RGB)

def clear_all():
    """Wipe every piece of session state and blank the three UI components
    (video input, first-frame image, output video)."""
    state["points"].clear()
    state.update(bbox=None, video_path=None, first_frame=None)
    return None, None, None

def track_video():
    """Run VitTrack over the stored video using the user-selected bbox.

    Initializes the tracker on the first frame, annotates every subsequent
    frame with the tracked box (green, with confidence) or a "Target lost!"
    banner, and writes the result to a temp-directory MP4.

    Returns:
        Path of the annotated output video, or None when no video/bbox has
        been set or the video cannot be read.
    """
    if state["video_path"] is None or state["bbox"] is None:
        return None

    # instantiate VitTrack
    model = VitTrack(
        model_path=MODEL_PATH,
        backend_id=backend_id,
        target_id=target_id
    )

    cap = cv.VideoCapture(state["video_path"])
    try:
        # Some containers report 0 FPS, which would break VideoWriter.
        fps = cap.get(cv.CAP_PROP_FPS) or 30.0
        w = int(cap.get(cv.CAP_PROP_FRAME_WIDTH))
        h = int(cap.get(cv.CAP_PROP_FRAME_HEIGHT))

        # Bug fix: the original ignored the read result, so an unreadable or
        # empty video crashed inside model.init(None, ...).
        has_frame, first_frame = cap.read()
        if not has_frame:
            return None
        model.init(first_frame, state["bbox"])

        # prepare temporary output file
        out_path = os.path.join(tempfile.gettempdir(), "vittrack_output.mp4")
        writer = cv.VideoWriter(
            out_path,
            cv.VideoWriter_fourcc(*"mp4v"),
            fps,
            (w, h)
        )
        try:
            tm = cv.TickMeter()
            while True:
                has_frame, frame = cap.read()
                if not has_frame:
                    break
                tm.start()
                isLocated, bbox, score = model.infer(frame)
                tm.stop()

                vis = frame.copy()
                # overlay per-frame inference FPS
                cv.putText(vis, f"FPS:{tm.getFPS():.2f}", (w // 4, 30),
                           cv.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
                # draw tracking box or loss message (0.3 = confidence floor)
                if isLocated and score >= 0.3:
                    x, y, w_, h_ = bbox
                    cv.rectangle(vis, (x, y), (x + w_, y + h_), (0, 255, 0), 2)
                    cv.putText(vis, f"{score:.2f}", (x, y - 10),
                               cv.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
                else:
                    cv.putText(vis, "Target lost!",
                               (w // 2, h // 4),
                               cv.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 3)

                writer.write(vis)
                tm.reset()
        finally:
            # Release even if inference raises, so the partial MP4 is finalized.
            writer.release()
    finally:
        cap.release()
    return out_path

def example_pipeline(video_path):
    """Run tracking on an example video using its pre-defined bounding box.

    Resets all session state, looks up the preset bbox by file basename,
    and returns the annotated output video path from track_video().
    """
    clear_all()

    # os.path.basename handles both '/' and OS-native separators.
    filename = os.path.basename(video_path)
    state["video_path"] = video_path
    # Bug fix: ast.literal_eval safely parses the "(x, y, w, h)" literal;
    # eval() would execute arbitrary code if bbox_dict were ever tainted.
    state["bbox"] = ast.literal_eval(bbox_dict[filename])

    return track_video()

# Build the Gradio UI. The CSS class "example" styles the example-table hint.
with gr.Blocks(css='''.example * {
    font-style: italic;
    font-size: 18px !important;
    color: #0ea5e9 !important;
    }''') as demo:

    gr.Markdown("## VitTrack: Interactive Video Object Tracking")
    gr.Markdown(
        """
        **How to use this tool:**

        1. **Upload a video** file (e.g., `.mp4` or `.avi`).
        2. The **first frame** of the video will appear.
        3. **Click exactly 4 points** on the object you want to track. These points should outline the object as closely as possible.
        4. A **bounding box** will be drawn around the selected region automatically.
        5. Click the **Track** button to start object tracking across the entire video.
        6. The output video with tracking overlay will appear below.

        You can also use:
        - ๐Ÿงน **Clear Points** to reset the 4-point selection on the first frame.
        - ๐Ÿ”„ **Clear All** to reset the uploaded video, frame, and selections.
        """
    )

    # Main layout: upload | clickable first frame | tracking result.
    with gr.Row():
        video_in     = gr.Video(label="Upload Video")
        first_frame  = gr.Image(label="First Frame", interactive=True)
        output_video = gr.Video(label="Tracking Result")

    with gr.Row():
        track_btn     = gr.Button("Track", variant="primary")
        clear_pts_btn = gr.Button("Clear Points")
        clear_all_btn = gr.Button("Clear All")

    gr.Markdown("Click any row to load an example.", elem_classes=["example"])

    # Each row carries just the example video path; the preset bbox is looked
    # up inside example_pipeline.
    examples = [
        [car_on_road_video],
        [car_in_desert_video],
    ]

    gr.Examples(
        examples=examples,
        inputs=[video_in],
        outputs=[output_video],
        fn=example_pipeline,
        cache_examples=False,
        run_on_click=True 
    )

    gr.Markdown("Example videos credit: https://pixabay.com/")

    # Event wiring: upload -> first frame; click -> point selection;
    # buttons -> reset helpers / full tracking run.
    video_in.change(fn=load_first_frame, inputs=video_in, outputs=first_frame)
    first_frame.select(fn=select_point, inputs=first_frame, outputs=first_frame)
    clear_pts_btn.click(fn=clear_points, outputs=first_frame)
    clear_all_btn.click(fn=clear_all, outputs=[video_in, first_frame, output_video])
    track_btn.click(fn=track_video, outputs=output_video)

if __name__ == "__main__":
    demo.launch()