Ruggero1912 committed
Commit 72775f2 · 1 Parent(s): 2f13aa0

added app.py file

Files changed (1)
  1. app.py +772 -0
app.py ADDED
@@ -0,0 +1,772 @@
+#!/usr/bin/env python3
+"""
+Gradio demo app for the Patchioner model - trace-based image captioning.
+
+This demo allows users to:
+1. Upload or select an image
+2. Draw traces on the image using Gradio's ImageEditor
+3. Generate captions for the traced regions using a pre-trained Patchioner model
+
+Author: Generated for decap-dino project
+"""
+
+import os
+
+import gradio as gr
+
+# Alias the component factory so it does not clash with the local variable below
+from gradio_image_annotation import image_annotator as ImageAnnotator
+
+import torch
+import yaml
+import traceback
+from pathlib import Path
+from PIL import Image
+import numpy as np
+from typing import List, Dict
+
+# Import the Patchioner model from the src directory
+from src.model import Patchioner
+
+# Globals holding the currently loaded model and its config path
+loaded_model = None
+model_config_path = None
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+# Default model configuration
+DEFAULT_MODEL_CONFIG = "mlp.k.yaml"
+
+# Example images and model config directories
+current_dir = os.path.dirname(__file__)
+EXAMPLE_IMAGES_DIR = Path(os.path.join(current_dir, 'example-images')).resolve()
+CONFIGS_DIR = Path(os.path.join(current_dir, '../Patch-ioner/configs')).resolve()
+
+
+def initialize_default_model() -> str:
+    """Initialize the default model at startup."""
+    global loaded_model, model_config_path
+
+    try:
+        # Look for the default config file
+        default_config_path = CONFIGS_DIR / DEFAULT_MODEL_CONFIG
+
+        if not default_config_path.exists():
+            return f"❌ Default config file not found: {default_config_path}"
+
+        print(f"Loading default model: {DEFAULT_MODEL_CONFIG}")
+
+        # Load and parse the config
+        with open(default_config_path, 'r') as f:
+            config = yaml.safe_load(f)
+
+        # Load the model using the from_config class method
+        model = Patchioner.from_config(config, device=device)
+        model.eval()
+        model.to(device)
+
+        # Store the model globally
+        loaded_model = model
+        model_config_path = str(default_config_path)
+
+        return f"✅ Default model loaded: {DEFAULT_MODEL_CONFIG} on {device}"
+
+    except Exception as e:
+        error_msg = f"❌ Error loading default model: {str(e)}"
+        print(error_msg)
+        print(traceback.format_exc())
+        return error_msg
+
+
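+# At startup, create_gradio_interface() calls initialize_default_model() and displays
+# the returned status string in the UI (e.g. "✅ Default model loaded: mlp.k.yaml on cuda").
+
+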
+def get_example_images(limit=None) -> List[str]:
+    """Get the list of example images for the demo."""
+    example_images = []
+    if EXAMPLE_IMAGES_DIR.exists():
+        for ext in ['*.jpg', '*.jpeg', '*.png']:
+            example_images.extend(str(p) for p in EXAMPLE_IMAGES_DIR.glob(ext))
+    # Sort so the gallery order is deterministic across calls
+    example_images = sorted(example_images)
+    if limit is not None:
+        example_images = example_images[:limit]
+    return example_images
+
+
+def get_example_configs() -> List[str]:
+    """Get the list of example config files."""
+    example_configs = []
+    if CONFIGS_DIR.exists():
+        example_configs = [str(p) for p in CONFIGS_DIR.glob("*.yaml")]
+    else:
+        print(f"Warning: Configs directory {CONFIGS_DIR} does not exist.")
+    return sorted(example_configs)
+
+
+def load_model_from_config(config_file_path: str) -> str:
+    """
+    Load the Patchioner model from a config file.
+
+    Args:
+        config_file_path: Path to the YAML configuration file
+
+    Returns:
+        Status message about model loading
+    """
+    global loaded_model, model_config_path
+
+    try:
+        if not config_file_path or not os.path.exists(config_file_path):
+            return "❌ Error: Config file path is empty or file does not exist."
+
+        print(f"Loading model from config: {config_file_path}")
+
+        # Load and parse the config
+        with open(config_file_path, 'r') as f:
+            config = yaml.safe_load(f)
+
+        # Load the model using the from_config class method
+        model = Patchioner.from_config(config, device=device)
+        model.eval()
+        model.to(device)
+
+        # Store the model globally
+        loaded_model = model
+        model_config_path = config_file_path
+
+        return f"✅ Model loaded successfully from {os.path.basename(config_file_path)} on {device}"
+
+    except Exception as e:
+        error_msg = f"❌ Error loading model: {str(e)}"
+        print(error_msg)
+        print(traceback.format_exc())
+        return error_msg
+
+
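+# Illustrative usage (not executed at import time):
+#   status = load_model_from_config(str(CONFIGS_DIR / DEFAULT_MODEL_CONFIG))
+
+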
+def process_image_trace_to_coordinates(image_editor_data) -> List[List[Dict[str, float]]]:
+    """
+    Convert Gradio ImageEditor trace data to the coordinate format expected by the model.
+
+    The expected format is: [[{"x": float, "y": float, "t": float}, ...], ...]
+    where coordinates are normalized to [0, 1] and t is a timestamp.
+
+    Args:
+        image_editor_data: Data from the Gradio ImageEditor component
+
+    Returns:
+        List of traces in the expected format
+    """
+    try:
+        print("[DEBUG] process_image_trace_to_coordinates called")
+        print(f"[DEBUG] image_editor_data type: {type(image_editor_data)}")
+
+        if image_editor_data is None:
+            print("[DEBUG] image_editor_data is None")
+            return []
+
+        # Check the different possible structures of the editor payload
+        layers = None
+        if isinstance(image_editor_data, dict):
+            print(f"[DEBUG] Available keys in image_editor_data: {list(image_editor_data.keys())}")
+            if 'layers' in image_editor_data:
+                layers = image_editor_data['layers']
+            elif 'composite' in image_editor_data:
+                # Sometimes Gradio stores the drawing data under 'composite'
+                composite = image_editor_data['composite']
+                if isinstance(composite, dict) and 'layers' in composite:
+                    layers = composite['layers']
+
+        if not layers:
+            print("[DEBUG] No layers found in image_editor_data")
+            return []
+
+        traces = []
+        print(f"[DEBUG] Processing {len(layers)} layers")
+
+        # Process each drawing layer - the layers are PIL Images, not coordinate data
+        for i, layer in enumerate(layers):
+            print(f"[DEBUG] Processing layer {i}: {layer}")
+
+            # Skip if the layer is not a PIL Image
+            if not isinstance(layer, Image.Image):
+                print(f"[DEBUG] Layer {i} is not a PIL Image")
+                continue
+
+            # Convert the layer to a numpy array to find drawn pixels
+            layer_array = np.array(layer)
+
+            # Find drawn pixels: alpha > 0 for RGBA, any non-black pixel otherwise
+            if layer_array.ndim == 3 and layer_array.shape[2] == 4:  # RGBA
+                non_transparent = layer_array[:, :, 3] > 0
+            elif layer_array.ndim == 3:  # RGB - assume any non-black pixel is drawn
+                non_transparent = np.any(layer_array > 0, axis=2)
+            else:  # Grayscale
+                non_transparent = layer_array > 0
+
+            # Get coordinates of drawn pixels
+            y_coords, x_coords = np.where(non_transparent)
+
+            if len(x_coords) == 0:
+                print(f"[DEBUG] Layer {i} has no drawn pixels")
+                continue
+
+            print(f"[DEBUG] Layer {i} has {len(x_coords)} drawn pixels")
+
+            # Convert pixel coordinates to trace format
+            trace_points = []
+            img_height, img_width = layer_array.shape[:2]
+
+            # Sample a subset of the drawn pixels to avoid too many points
+            num_points = min(len(x_coords), 100)  # Limit to 100 points max
+            if num_points > 0:
+                # Sample evenly spaced indices
+                indices = np.linspace(0, len(x_coords) - 1, num_points, dtype=int)
+                sampled_x = x_coords[indices]
+                sampled_y = y_coords[indices]
+
+                # Convert to normalized coordinates and create trace points
+                for idx, (x, y) in enumerate(zip(sampled_x, sampled_y)):
+                    # Normalize coordinates to [0, 1]
+                    x_norm = float(x) / img_width if img_width > 0 else 0
+                    y_norm = float(y) / img_height if img_height > 0 else 0
+
+                    # Clamp to [0, 1] range
+                    x_norm = max(0, min(1, x_norm))
+                    y_norm = max(0, min(1, y_norm))
+
+                    # Add a timestamp (arbitrary progression)
+                    t = idx * 0.1
+
+                    trace_points.append({
+                        "x": x_norm,
+                        "y": y_norm,
+                        "t": t
+                    })
+
+            if trace_points:
+                traces.append(trace_points)
+
+        return traces
+
+    except Exception as e:
+        print(f"Error processing image trace: {e}")
+        print(traceback.format_exc())
+        return []
+
+
+def process_bounding_box_coordinates(annotator_data) -> List[List[float]]:
+    """
+    Convert Gradio image_annotator data to the bounding box format expected by the model.
+
+    Args:
+        annotator_data: Data from the Gradio image_annotator component
+
+    Returns:
+        List of bounding boxes in [x, y, width, height] format
+    """
+    try:
+        print("[DEBUG] process_bounding_box_coordinates called")
+        print(f"[DEBUG] annotator_data type: {type(annotator_data)}")
+        #print(f"[DEBUG] annotator_data content: {annotator_data}")
+
+        if annotator_data is None:
+            print("[DEBUG] annotator_data is None")
+            return []
+
+        boxes = []
+
+        # Handle the dictionary format from image_annotator
+        if isinstance(annotator_data, dict):
+            print(f"[DEBUG] Available keys in annotator_data: {list(annotator_data.keys())}")
+
+            # Extract boxes from the 'boxes' key
+            if 'boxes' in annotator_data and annotator_data['boxes']:
+                for box in annotator_data['boxes']:
+                    if isinstance(box, dict):
+                        # Based on image_annotator.py, boxes have the format:
+                        # {"xmin": x, "ymin": y, "xmax": x2, "ymax": y2, "label": ..., "color": ...}
+                        xmin = box.get('xmin', 0)
+                        ymin = box.get('ymin', 0)
+                        xmax = box.get('xmax', 0)
+                        ymax = box.get('ymax', 0)
+
+                        width = xmax - xmin
+                        height = ymax - ymin
+
+                        # Convert to [x, y, width, height] format
+                        boxes.append([xmin, ymin, width, height])
+            else:
+                print("[DEBUG] No 'boxes' key found or boxes list is empty")
+
+        print(f"[DEBUG] Found {len(boxes)} bounding boxes: {boxes}")
+        return boxes
+
+    except Exception as e:
+        print(f"Error processing bounding box: {e}")
+        print(traceback.format_exc())
+        return []
+
+
+def generate_caption(mode, image_data) -> str:
+    """
+    Generate a caption for the image and traces/bboxes using the loaded model.
+
+    Args:
+        mode: Either "trace" or "bbox"
+        image_data: Data from the Gradio ImageEditor or image_annotator component
+
+    Returns:
+        Generated caption or error message
+    """
+    global loaded_model
+
+    try:
+        print(f"[DEBUG] generate_caption called with mode: {mode}")
+        print(f"[DEBUG] image_data type: {type(image_data)}")
+        print(f"[DEBUG] image_data content: {image_data}")
+
+        if loaded_model is None:
+            return "❌ Error: No model loaded. Please load a model first using the config file."
+
+        # Handle different input formats from Gradio components
+        image = None
+        if image_data is None:
+            return "❌ Error: No image data provided."
+
+        # Check if it's a PIL Image directly
+        if isinstance(image_data, Image.Image):
+            print("[DEBUG] Received PIL Image directly")
+            image = image_data
+        # Check if it's a dict (from the image_annotator component)
+        elif isinstance(image_data, dict):
+            print(f"[DEBUG] Received dict with keys: {list(image_data.keys())}")
+            if 'image' in image_data:
+                image_array = image_data['image']
+                # Convert numpy array to PIL Image if needed
+                if hasattr(image_array, 'shape') and len(image_array.shape) == 3:
+                    print("[DEBUG] Converting numpy array to PIL Image")
+                    image = Image.fromarray(image_array)
+                else:
+                    image = image_array
+            elif 'background' in image_data:
+                image_array = image_data['background']
+                # Convert numpy array to PIL Image if needed
+                if hasattr(image_array, 'shape') and len(image_array.shape) == 3:
+                    print("[DEBUG] Converting numpy array to PIL Image")
+                    image = Image.fromarray(image_array)
+                else:
+                    image = image_array
+            else:
+                return f"❌ Error: No image found in data. Available keys: {list(image_data.keys())}"
+        # Check for tuple/list format (from the ImageEditor component)
+        elif isinstance(image_data, (tuple, list)) and len(image_data) >= 1:
+            print(f"[DEBUG] Received tuple/list with {len(image_data)} elements")
+            image = image_data[0]  # First element should be the image
+            if not isinstance(image, Image.Image):
+                # Sometimes the structure might be different, search for a PIL Image
+                for item in image_data:
+                    if isinstance(item, Image.Image):
+                        image = item
+                        break
+        else:
+            return f"❌ Error: Unexpected data type: {type(image_data)}"
+
+        if image is None:
+            return "❌ Error: Image is None."
+
+        # Validate that we ended up with a PIL Image
+        if not isinstance(image, Image.Image):
+            return "❌ Error: Invalid image format."
+
+        # Convert the image to RGB if needed
+        if image.mode != 'RGB':
+            image = image.convert('RGB')
+
+        if mode == "trace":
+            return generate_trace_caption(image_data, image)
+        elif mode == "bbox":
+            return generate_bbox_caption(image_data, image)
+        else:
+            return f"❌ Error: Unknown mode: {mode}"
+
+    except Exception as e:
+        error_msg = f"❌ Error generating caption: {str(e)}"
+        print(error_msg)
+        print(traceback.format_exc())
+        return error_msg
+
+
+def generate_trace_caption(image_data, image) -> str:
+    """Generate a caption from the drawn traces."""
+    global loaded_model
+
+    try:
+        # Process traces
+        print("[DEBUG] Processing traces...")
+        traces = process_image_trace_to_coordinates(image_data)
+        print(f"[DEBUG] Found {len(traces)} traces")
+
+        if not traces:
+            # For debugging, generate a simple image caption instead of failing
+            print("[DEBUG] No traces found, generating image caption instead")
+            image_tensor = loaded_model.image_transforms(image).unsqueeze(0).to(device)
+
+            with torch.no_grad():
+                outputs = loaded_model(
+                    image_tensor,
+                    get_cls_capt=True,  # Get the class caption as a fallback
+                    get_patch_capts=False,
+                    get_avg_patch_capt=False
+                )
+
+            if 'cls_capt' in outputs:
+                return f"🔍 No traces drawn. Image caption: {outputs['cls_capt']}"
+            else:
+                return "❌ Error: No traces detected. Please draw some traces on the image."
+
+        print(f"Processing {len(traces)} traces")
+
+        # Prepare the image tensor
+        image_tensor = loaded_model.image_transforms(image).unsqueeze(0).to(device)
+
+        # Generate the caption using the model
+        with torch.no_grad():
+            outputs = loaded_model(
+                image_tensor,
+                traces=traces,
+                get_cls_capt=False,  # We want trace captions, not class captions
+                get_patch_capts=False,
+                get_avg_patch_capt=False
+            )
+
+        # Extract the trace captions
+        if 'trace_capts' in outputs:
+            captions = outputs['trace_capts']
+            if isinstance(captions, list) and captions:
+                captions = [cap.replace("<|startoftext|>", "").replace("<|endoftext|>", "") for cap in captions]
+                # Join multiple captions if there are multiple traces
+                if len(captions) == 1:
+                    return f"Generated Caption: {captions[0]}"
+                else:
+                    formatted_captions = []
+                    for i, caption in enumerate(captions, 1):
+                        formatted_captions.append(f"Trace {i}: {caption}")
+                    return "Generated Captions:\n" + "\n".join(formatted_captions)
+            elif isinstance(captions, str):
+                return f"Generated Caption: {captions}"
+            else:
+                return "❌ Error: No captions generated."
+        else:
+            return "❌ Error: Model did not return trace captions."
+
+    except Exception as e:
+        error_msg = f"❌ Error generating trace caption: {str(e)}"
+        print(error_msg)
+        print(traceback.format_exc())
+        return error_msg
+
+
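+# Note: the Patchioner forward pass returns a dict of outputs; this demo reads
+# 'cls_capt', 'trace_capts', or 'bbox_capts' from it depending on the request
+# (key names as used in this file; exact shapes depend on the model implementation).
+
+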
+def generate_bbox_caption(image_data, image) -> str:
+    """Generate a caption from the drawn bounding boxes."""
+    global loaded_model
+
+    try:
+        # Process bounding boxes
+        print("[DEBUG] Processing bounding boxes...")
+        bboxes = process_bounding_box_coordinates(image_data)
+        print(f"[DEBUG] Found {len(bboxes)} bounding boxes")
+
+        if not bboxes:
+            # For debugging, generate a simple image caption instead of failing
+            print("[DEBUG] No bounding boxes found, generating image caption instead")
+            image_tensor = loaded_model.image_transforms(image).unsqueeze(0).to(device)
+
+            with torch.no_grad():
+                outputs = loaded_model(
+                    image_tensor,
+                    get_cls_capt=True,  # Get the class caption as a fallback
+                    get_patch_capts=False,
+                    get_avg_patch_capt=False
+                )
+
+            if 'cls_capt' in outputs:
+                return f"🔍 No bounding boxes drawn. Image caption: {outputs['cls_capt']}"
+            else:
+                return "❌ Error: No bounding boxes detected. Please draw some bounding boxes on the image."
+
+        print(f"Processing {len(bboxes)} bounding boxes")
+
+        # Generate the caption using the caption_bboxes method (as in eval_densecap.py)
+        try:
+            captions = loaded_model.caption_bboxes([image], [bboxes], crop_boxes=True)
+
+            if isinstance(captions, list) and captions:
+                if isinstance(captions[0], list):
+                    captions = captions[0]  # Unwrap nested list if needed
+                captions = [cap.replace("<|startoftext|>", "").replace("<|endoftext|>", "") for cap in captions]
+                # Join multiple captions if there are multiple bboxes
+                if len(captions) == 1:
+                    return f"Generated Caption: {captions[0]}"
+                else:
+                    formatted_captions = []
+                    for i, caption in enumerate(captions, 1):
+                        formatted_captions.append(f"BBox {i}: {caption}")
+                    return "Generated Captions:\n" + "\n".join(formatted_captions)
+            elif isinstance(captions, str):
+                return f"Generated Caption: {captions}"
+            else:
+                return "❌ Error: No captions generated."
+
+        except Exception as e:
+            print(f"Error using caption_bboxes method: {e}")
+            # Fall back to the regular forward method with bboxes
+            image_tensor = loaded_model.image_transforms(image).unsqueeze(0).to(device)
+            bbox_tensor = torch.tensor([bboxes]).to(device)
+
+            with torch.no_grad():
+                outputs = loaded_model(
+                    image_tensor,
+                    bboxes=bbox_tensor,
+                    get_cls_capt=False,
+                    get_patch_capts=False,
+                    get_avg_patch_capt=False
+                )
+
+            if 'bbox_capts' in outputs:
+                captions = outputs['bbox_capts']
+                if isinstance(captions, list) and captions:
+                    if isinstance(captions[0], list):
+                        captions = captions[0]  # Unwrap nested list if needed
+                    captions = [cap.replace("<|startoftext|>", "").replace("<|endoftext|>", "") for cap in captions]
+                    if len(captions) == 1:
+                        return f"Generated Caption: {captions[0]}"
+                    else:
+                        formatted_captions = []
+                        for i, caption in enumerate(captions, 1):
+                            formatted_captions.append(f"BBox {i}: {caption}")
+                        return "Generated Captions:\n" + "\n".join(formatted_captions)
+                elif isinstance(captions, str):
+                    return f"Generated Caption: {captions}"
+                else:
+                    return "❌ Error: No captions generated."
+            else:
+                return "❌ Error: Model did not return bbox captions."
+
+    except Exception as e:
+        error_msg = f"❌ Error generating bbox caption: {str(e)}"
+        print(error_msg)
+        print(traceback.format_exc())
+        return error_msg
+
+
+def create_gradio_interface():
+    """Create and configure the Gradio interface."""
+
+    # Get example files
+    example_images = get_example_images()
+    example_configs = get_example_configs()
+
+    # Custom JS to hide unused editor tools (currently not injected; see the
+    # commented-out gr.HTML call below)
+    custom_js = """
+    <script>
+    window.addEventListener("load", () => {
+        // Hide Crop, Erase, and Color buttons
+        const cropBtn = document.querySelector('.image-editor__tool[title="Crop"]');
+        const eraseBtn = document.querySelector('.image-editor__tool[title="Erase"]');
+        const colorBtn = document.querySelector('.image-editor__tool[title="Color"]');
+
+        [cropBtn, eraseBtn, colorBtn].forEach(btn => {
+            console.log("Going to disable display for ", btn);
+            if (btn) btn.style.display = "none";
+        });
+
+        // Optionally, select the Brush/Draw tool right away
+        const brushBtn = document.querySelector('.image-editor__tool[title="Draw"]');
+        console.log("Selecting brushBtn: ", brushBtn);
+        if (brushBtn) brushBtn.click();
+    });
+    </script>
+    """
+
+    with gr.Blocks(
+        title="Patchioner Trace Captioning Demo",
+        theme=gr.themes.Soft(),
+        css="""
+        .gradio-container {
+            max-width: 1200px !important;
+        }
+        """
+    ) as demo:
+        #gr.HTML(custom_js)  # inject custom JS
+
+        gr.Markdown("""
+        # 🎯 Patchioner Trace Captioning Demo
+
+        This demo allows you to:
+        1. **Select a captioning mode** (trace or bounding box)
+        2. **Upload or select an image** from examples
+        3. **Draw traces or bounding boxes** on the image
+        4. **Generate captions** describing the marked areas
+
+        ## Instructions:
+        1. Choose between Trace or BBox mode
+        2. Upload an image or use one of the provided examples
+        3. Use the appropriate tool to mark areas of interest in the image
+        4. Click "Generate Caption" to get AI-generated descriptions
+
+        **Model:** Using the `mlp.k.yaml` configuration (automatically loaded)
+        """)
+
+        # Initialize the model and report its status
+        model_initialization_status = initialize_default_model()
+
+        with gr.Row():
+            gr.Markdown(f"**Model Status:** {model_initialization_status}")
+
+        with gr.Row():
+            mode_selector = gr.Radio(
+                choices=["trace", "bbox"],
+                value="trace",
+                label="📋 Captioning Mode",
+                info="Choose between trace-based or bounding box-based captioning",
+                visible=False
+            )
+
+        with gr.Row():
+            with gr.Column():
+                gr.Markdown("### 🖼️ Image Editor")
+
+                # Image editor for drawing traces (default)
+                image_editor = gr.ImageEditor(
+                    label="Upload image and draw traces",
+                    type="pil",
+                    crop_size=None,
+                    brush=gr.Brush(default_size=3, colors=["red", "blue", "green", "yellow", "purple"]),
+                    visible=True,
+                    #tools=["brush"],
+                    height=600
+                )
+
+                # Image annotator for bounding boxes (hidden by default)
+                image_annotator = ImageAnnotator(  #gr.Image(
+                    label="Upload image and draw bounding boxes",
+                    visible=False,
+                    #classes=["object"],
+                    #type="bbox"
+                    #tool="select"
+                )
+
+            with gr.Column():
+                if example_images:
+                    gr.Markdown("#### 📷 Or select from example images:")
+                    example_gallery = gr.Gallery(
+                        value=example_images,
+                        label="Example Images",
+                        show_label=True,
+                        elem_id="gallery",
+                        columns=3,
+                        rows=2,
+                        object_fit="contain",
+                        height="auto"
+                    )
+
+        with gr.Row():
+            generate_button = gr.Button("✨ Generate Caption", variant="primary", size="lg")
+
+        with gr.Row():
+            output_text = gr.Textbox(
+                label="Generated Caption",
+                placeholder="Generated caption will appear here...",
+                lines=5,
+                max_lines=10,
+                interactive=False
+            )
+
+        # Event handlers
+        def toggle_input_components(mode):
+            """Toggle between image editor and annotator based on mode."""
+            if mode == "trace":
+                return gr.update(visible=True), gr.update(visible=False)
+            else:  # bbox mode
+                return gr.update(visible=False), gr.update(visible=True)
+
+        def load_example_image_to_both(evt: gr.SelectData):
+            """Load the selected example image into both components."""
+            try:
+                example_images = get_example_images()
+                if evt.index < len(example_images):
+                    selected_image_path = example_images[evt.index]
+                    img = Image.open(selected_image_path)
+                    # For ImageEditor, return the PIL image directly.
+                    # For image_annotator, return the dict format expected by the component.
+                    annotated_format = {
+                        "image": img,
+                        "boxes": [],
+                        "orientation": 0
+                    }
+                    return img, annotated_format
+                return None, {"image": None, "boxes": [], "orientation": 0}
+            except Exception as e:
+                print(f"Error loading example image: {e}")
+                return None, {"image": None, "boxes": [], "orientation": 0}
+
+        def generate_caption_wrapper(mode, image_editor_data, image_annotator_data):
+            """Wrapper to call generate_caption with the appropriate data based on mode."""
+            if mode == "trace":
+                return generate_caption(mode, image_editor_data)
+            else:  # bbox mode
+                return generate_caption(mode, image_annotator_data)
+
+        # Connect event handlers
+        mode_selector.change(
+            fn=toggle_input_components,
+            inputs=mode_selector,
+            outputs=[image_editor, image_annotator]
+        )
+
+        generate_button.click(
+            fn=generate_caption_wrapper,
+            inputs=[mode_selector, image_editor, image_annotator],
+            outputs=output_text
+        )
+
+        if example_images:
+            example_gallery.select(
+                fn=load_example_image_to_both,
+                outputs=[image_editor, image_annotator]
+            )
+
+        gr.Markdown("""
+        ### 💡 Tips:
+        - **Mode Selection**: Switch between trace and bounding box modes based on your needs
+        - **Trace Mode**: Draw continuous lines over the areas you want to describe
+        - **BBox Mode**: Draw rectangular bounding boxes around objects of interest
+        - **Multiple Areas**: Create multiple traces/boxes for different objects to get individual captions
+        - **Model Performance**: The first load may take some time as weights are downloaded
+
+        ### 🔧 Technical Details:
+        - **Trace Mode**: Converts drawings to normalized (x, y) coordinates with timestamps
+        - **BBox Mode**: Uses bounding box coordinates for region-specific captioning
+        - **Model Architecture**: Uses the `mlp.k.yaml` configuration with CLIP and ViT components
+        - **Processing**: Each trace/bbox is processed separately to generate its own caption
+        """)
+
+    return demo
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser(description="Patchioner Trace Captioning Demo")
+    parser.add_argument("--port", type=int, default=4141, help="Port to run the Gradio app on")
+    args = parser.parse_args()
+
+    print("Starting Patchioner Trace Captioning Demo...")
+    print(f"Using device: {device}")
+    print(f"Default model: {DEFAULT_MODEL_CONFIG}")
+    print(f"Example images directory: {EXAMPLE_IMAGES_DIR}")
+    print(f"Configs directory: {CONFIGS_DIR}")
+
+    demo = create_gradio_interface()
+    demo.launch(
+        server_name="0.0.0.0",
+        server_port=args.port,
+        share=True,
+        debug=True
+    )
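+
+# Run the demo from the project root, e.g.:
+#   python app.py --port 4141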