import os import sys import cv2 import json import glob import argparse import subprocess from typing import List, Tuple, Dict, Any import numpy as np from tqdm import tqdm # ----------------- Args ----------------- def parse_args(): ap = argparse.ArgumentParser("OWLv2 detection on JPG folders (Top-K per image), multi-GPU.") ap.add_argument("--input_dir", type=str, required=True, help="Root that contains subfolders of JPGs; if JPGs are directly under input_dir, it will be treated as a single set.") ap.add_argument("--startswith", type=str, default="", help="Filter folder name prefix (or input_dir basename if no subfolders).") ap.add_argument("--output_dir", type=str, required=True) ap.add_argument("--frame_stride", type=int, default=1, help="Sample every N-th image within a folder.") ap.add_argument("--top_k", type=int, default=5) ap.add_argument("--max_frames", type=int, default=0, help="Max processed images per folder; 0 means no limit.") ap.add_argument("--num_workers", type=int, default=1, help="#GPUs/#workers") ap.add_argument("--worker_idx", type=int, default=-1, help="internal; >=0 means child worker") ap.add_argument("--shard_file", type=str, default="", help="internal; JSON with folder paths for this worker") ap.add_argument("--scenic_root", type=str, default="/home/ubuntu/rs/JiT/VisionModels/Scenic_OWLv2/big_vision") return ap.parse_args() # ----------------- Utils ----------------- def _has_jpgs(path: str) -> bool: exts = ("*.jpg", "*.jpeg", "*.JPG", "*.JPEG") for pat in exts: if glob.glob(os.path.join(path, pat)): return True return False def iter_image_dirs(input_dir: str, startswith: str) -> List[str]: """ Returns a list of directories to process. - If input_dir contains subfolders: return subfolders that contain JPGs and match startswith. - Else if input_dir itself contains JPGs and its basename matches startswith: return [input_dir]. """ input_dir = os.path.abspath(input_dir) subs = sorted([p for p in glob.glob(os.path.join(input_dir, "*")) if os.path.isdir(p)]) # Prefer subfolders if any exist and contain jpgs dirs = [d for d in subs if os.path.basename(d).startswith(startswith) and _has_jpgs(d)] if dirs: return dirs # Fallback: treat input_dir itself as one set if it has jpgs base_ok = os.path.basename(os.path.normpath(input_dir)).startswith(startswith) if base_ok and _has_jpgs(input_dir): return [input_dir] return [] def ensure_dir(p: str): os.makedirs(p, exist_ok=True) def draw_single_box(frame_bgr: np.ndarray, box: List[float], color=(0, 255, 0), thickness=2) -> np.ndarray: x1, y1, x2, y2 = map(int, box) out = frame_bgr.copy() cv2.rectangle(out, (x1, y1), (x2, y2), color, thickness) return out def list_images_sorted(folder: str) -> List[str]: pats = ["*.jpg", "*.jpeg", "*.JPG", "*.JPEG"] files = [] for pat in pats: files.extend(glob.glob(os.path.join(folder, pat))) # Sort by natural file name order return sorted(files) # ----------------- Worker logic (imports JAX/Scenic inside) ----------------- def worker_run(args, dir_paths: List[str]): import sys as _sys if args.scenic_root not in _sys.path: _sys.path.append(args.scenic_root) # Free TF GPU to JAX in this process (why: avoid TF reserving VRAM) import tensorflow as tf tf.config.experimental.set_visible_devices([], "GPU") from scenic.projects.owl_vit import configs from scenic.projects.owl_vit import models import jax import functools import owlv2_helper as helper # must be available in PYTHONPATH class OWLv2Objectness: def __init__(self, top_k: int = 5): self.top_k = top_k self.config = configs.owl_v2_clip_b16.get_config(init_mode="canonical_checkpoint") self.module = models.TextZeroShotDetectionModule( body_configs=self.config.model.body, objectness_head_configs=self.config.model.objectness_head, normalize=self.config.model.normalize, box_bias=self.config.model.box_bias, ) self.variables = self.module.load_variables(self.config.init_from.checkpoint_path) self.image_embedder = jax.jit( functools.partial(self.module.apply, self.variables, train=False, method=self.module.image_embedder) ) self.objectness_predictor = jax.jit( functools.partial(self.module.apply, self.variables, method=self.module.objectness_predictor) ) self.box_predictor = jax.jit( functools.partial(self.module.apply, self.variables, method=self.module.box_predictor) ) def detect(self, image_bgr: np.ndarray) -> List[Tuple[List[float], float]]: image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB) processed = helper.preprocess_images([image_rgb], self.config.dataset_configs.input_size)[0] feature_map = self.image_embedder(processed[None, ...]) b, h, w, d = feature_map.shape image_features = feature_map.reshape(b, h * w, d) obj_logits = self.objectness_predictor(image_features)["objectness_logits"] raw_boxes = self.box_predictor(image_features=image_features, feature_map=feature_map)["pred_boxes"] obj = np.array(obj_logits[0], dtype=np.float32) raw_boxes = np.array(raw_boxes[0], dtype=np.float32) boxes = helper.rescale_detection_box(raw_boxes, image_rgb) if len(obj) == 0: return [] k = min(self.top_k, len(obj)) thresh = np.partition(obj, -k)[-k] filtered: List[Tuple[List[float], float]] = [] H, W = image_rgb.shape[:2] for box, score in zip(boxes, obj): if score < thresh: continue if helper.too_small(box) or helper.too_large(box, image_rgb): continue x1, y1, x2, y2 = box x1 = max(0, min(float(x1), W - 1)) y1 = max(0, min(float(y1), H - 1)) x2 = max(0, min(float(x2), W - 1)) y2 = max(0, min(float(y2), H - 1)) filtered.append(([x1, y1, x2, y2], float(score))) kept_boxes = helper.remove_overlapping_bboxes([b for b, _ in filtered]) def _match_score(bb: List[float]) -> float: arr = np.array([b for b, _ in filtered], dtype=np.float32) idx = int(np.argmin(np.abs(arr - np.array(bb, dtype=np.float32)).sum(axis=1))) return filtered[idx][1] return [(bb, _match_score(bb)) for bb in kept_boxes] detector = OWLv2Objectness(top_k=args.top_k) for dpath in tqdm(dir_paths, desc=f"Worker{args.worker_idx}", unit="set"): stem = os.path.basename(os.path.normpath(dpath)) images = list_images_sorted(dpath) if not images: print(f"[WARN][w{args.worker_idx}] No JPGs under: {dpath}") continue saved_cnt = 0 pbar = tqdm(total=len(images), desc=f"{stem}[w{args.worker_idx}]", unit="img", leave=False) for idx, ipath in enumerate(images): pbar.update(1) if args.frame_stride > 1 and (idx % args.frame_stride) != 0: continue frame = cv2.imread(ipath, cv2.IMREAD_COLOR) if frame is None: print(f"[WARN][w{args.worker_idx}] Cannot read: {ipath}") continue boxes_scores = detector.detect(frame) if boxes_scores: boxes_scores = sorted(boxes_scores, key=lambda x: x[1], reverse=True)[:args.top_k] fname = os.path.basename(ipath) for i, (box, score) in enumerate(boxes_scores): out_dir = os.path.join(args.output_dir, stem, f"object_{i}") ensure_dir(out_dir) vis = draw_single_box(frame, box, color=(0, 255, 0), thickness=2) cv2.imwrite(os.path.join(out_dir, fname), vis) saved_cnt += 1 if args.max_frames and saved_cnt >= args.max_frames: break pbar.close() # ----------------- Master ----------------- def main(): args = parse_args() # Child worker path if args.worker_idx >= 0: if not args.shard_file or not os.path.exists(args.shard_file): raise RuntimeError("Worker requires --shard_file with JSON list of folder paths.") with open(args.shard_file, "r", encoding="utf-8") as f: dir_paths = json.load(f) worker_run(args, dir_paths) return # Master path dir_paths = iter_image_dirs(args.input_dir, args.startswith) if not dir_paths: print(f"[INFO] No JPG folders (or JPGs) startwith '{args.startswith}' under {args.input_dir}") return num_workers = max(1, int(args.num_workers)) shards: List[List[str]] = [[] for _ in range(num_workers)] for i, d in enumerate(dir_paths): shards[i % num_workers].append(d) procs = [] tmp_dir = os.path.join(args.output_dir, "_shards_tmp") ensure_dir(tmp_dir) for w in range(num_workers): shard_path = os.path.join(tmp_dir, f"shard_{w}.json") with open(shard_path, "w", encoding="utf-8") as f: json.dump(shards[w], f, ensure_ascii=False, indent=0) # Bind GPU: cycle through available GPU ids [0..num_workers-1] env = os.environ.copy() env["CUDA_VISIBLE_DEVICES"] = str(w) # one GPU per worker cmd = [ sys.executable, __file__, "--input_dir", args.input_dir, "--startswith", args.startswith, "--output_dir", args.output_dir, "--frame_stride", str(args.frame_stride), "--top_k", str(args.top_k), "--max_frames", str(args.max_frames), "--num_workers", str(num_workers), "--worker_idx", str(w), "--shard_file", shard_path, "--scenic_root", args.scenic_root, ] print(f"[Master] Launch worker {w}: GPU={env['CUDA_VISIBLE_DEVICES']} folders={len(shards[w])}") procs.append(subprocess.Popen(cmd, env=env)) # wait rc = 0 for p in procs: p.wait() rc |= p.returncode if rc != 0: print("[Master] Some workers failed. Return code:", rc) else: print("[Master] All workers done. Output:", args.output_dir) if __name__ == "__main__": main()