# histOSM / inference_tab / helpers.py
# Last commit (muk42): "fix edge tiles" (50c28f0)
from ultralytics import SAM
import cv2
from shapely.geometry import shape
from rapidfuzz import process, fuzz
from huggingface_hub import hf_hub_download
from config import OUTPUT_DIR
from pathlib import Path
from PIL import Image
import spaces
import numpy as np
import os
import json
from PIL import Image
def box_inside_global(box, global_box):
    """Return whether ``box`` lies entirely within ``global_box``.

    Both arguments are ``(x1, y1, x2, y2)`` sequences. Edges that coincide
    with the global box count as inside.
    """
    left, top, right, bottom = box
    g_left, g_top, g_right, g_bottom = global_box
    return (
        g_left <= left
        and g_top <= top
        and right <= g_right
        and bottom <= g_bottom
    )
def nms_iou(box1, box2):
    """Intersection-over-union of two ``[x1, y1, x2, y2]`` boxes (used by NMS).

    Returns 0 when the union area is not positive.
    """
    # Corners of the overlap rectangle (may be inverted if the boxes are disjoint).
    ix1, iy1 = max(box1[0], box2[0]), max(box1[1], box2[1])
    ix2, iy2 = min(box1[2], box2[2]), min(box1[3], box2[3])

    overlap_w = max(0, ix2 - ix1)
    overlap_h = max(0, iy2 - iy1)
    inter_area = overlap_w * overlap_h

    area_a = (box1[2] - box1[0]) * (box1[3] - box1[1])
    area_b = (box2[2] - box2[0]) * (box2[3] - box2[1])
    union_area = area_a + area_b - inter_area

    if union_area > 0:
        return inter_area / union_area
    return 0
def non_max_suppression(boxes, scores, iou_threshold=0.5):
    """Greedy non-maximum suppression over ``boxes`` ranked by ``scores``.

    Boxes are visited from highest to lowest score. A box is kept only if its
    IoU with every previously kept box is below ``iou_threshold``.

    Returns:
        list: indices into ``boxes`` of the kept boxes, in descending score order.
    """
    ranked = np.argsort(scores)[::-1]
    keep = []
    for candidate in ranked:
        # Surviving means no higher-scoring kept box suppresses this one.
        if all(
            nms_iou(boxes[kept], boxes[candidate]) < iou_threshold
            for kept in keep
        ):
            keep.append(candidate)
    return keep
def tile_image_with_overlap(image_path, tile_size=1024, overlap=256):
    """Tile an image into overlapping square tiles.

    The image is loaded with ``cv2.imread``, so tiles keep OpenCV's channel
    order (BGR); callers that need RGB must convert. Tile origins advance by
    ``tile_size - overlap``; the last tile in each row/column is shifted back
    so it ends at the image border, and duplicate origins are skipped. If the
    image is smaller than ``tile_size`` along an axis, the tile is simply
    smaller along that axis.

    Args:
        image_path: path of the image file to load.
        tile_size: side length of a tile in pixels.
        overlap: number of pixels shared by neighbouring tiles.

    Returns:
        tuple: ``(tile_list, image_shape)`` where ``tile_list`` contains
        ``(tile, (x_start, y_start))`` pairs and ``image_shape`` is the shape
        of the loaded image array.

    Raises:
        ValueError: if ``overlap`` is not smaller than ``tile_size``.
        OSError: if the image cannot be read.
    """
    step = tile_size - overlap
    if step <= 0:
        # A non-positive step would either crash range() or yield no tiles.
        raise ValueError(
            f"overlap ({overlap}) must be smaller than tile_size ({tile_size})"
        )

    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise OSError(f"Could not read image: {image_path}")

    height, width, _ = image.shape
    tile_list = []
    seen = set()  # origins already emitted, to avoid duplicate border tiles

    for y in range(0, height, step):
        # Pull the last row back inside the image, but never above row 0.
        y_start = max(0, min(y, height - tile_size))
        y_end = y_start + tile_size
        for x in range(0, width, step):
            x_start = max(0, min(x, width - tile_size))
            x_end = x_start + tile_size

            coords = (x_start, y_start)
            if coords in seen:
                continue
            seen.add(coords)

            tile = image[y_start:y_end, x_start:x_end, :]
            tile_list.append((tile, coords))

    return tile_list, image.shape
def compute_iou(box1, box2):
    """Compute Intersection over Union for two ``[x1, y1, x2, y2]`` boxes.

    Returns 0 when the union area is not positive.
    """
    left = max(box1[0], box2[0])
    top = max(box1[1], box2[1])
    right = min(box1[2], box2[2])
    bottom = min(box1[3], box2[3])

    # Disjoint boxes give a negative extent, which is clamped to zero.
    inter_area = max(0, right - left) * max(0, bottom - top)

    first_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
    second_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
    union_area = first_area + second_area - inter_area

    if union_area > 0:
        return inter_area / union_area
    return 0
def merge_boxes(boxes, iou_threshold=0.8):
    """Merge overlapping boxes based on IoU (single pass).

    Each not-yet-merged box becomes the seed of a group. Every later box whose
    IoU with the *seed* exceeds ``iou_threshold`` joins that group. Each group
    is replaced by its enclosing ``[x1, y1, x2, y2]`` box.
    """
    merged = []
    consumed = set()  # indices already assigned to a group

    for i, seed in enumerate(boxes):
        if i in consumed:
            continue
        consumed.add(i)

        # Later, still-free boxes that overlap the seed strongly enough.
        partners = [
            j
            for j in range(i + 1, len(boxes))
            if j not in consumed and compute_iou(seed, boxes[j]) > iou_threshold
        ]
        consumed.update(partners)
        group = [seed] + [boxes[j] for j in partners]

        # Enclosing box of the whole group.
        merged.append([
            min(b[0] for b in group),
            min(b[1] for b in group),
            max(b[2] for b in group),
            max(b[3] for b in group),
        ])

    return merged
def box_area(box):
    """Area of an ``[x1, y1, x2, y2]`` box; negative extents count as zero."""
    width = max(0, box[2] - box[0])
    height = max(0, box[3] - box[1])
    return width * height
def is_contained(box1, box2, containment_threshold=0.9):
    """Return whether the smaller of the two boxes is mostly covered by the overlap.

    The check is symmetric: the intersection area is divided by the area of
    whichever box is smaller, and compared against ``containment_threshold``.
    Returns False when the smaller box has zero area.
    """
    overlap_w = max(0, min(box1[2], box2[2]) - max(box1[0], box2[0]))
    overlap_h = max(0, min(box1[3], box2[3]) - max(box1[1], box2[1]))

    smaller_area = min(box_area(box1), box_area(box2))
    if smaller_area == 0:
        # Degenerate box: avoid dividing by zero.
        return False

    covered_fraction = (overlap_w * overlap_h) / smaller_area
    return covered_fraction >= containment_threshold
def merge_boxes_iterative(boxes, iou_threshold=0.25, containment_threshold=0.75):
    """Repeatedly merge boxes until a pass no longer reduces their number.

    Within one pass each not-yet-merged box seeds a group. A later box joins
    the group when its IoU with the seed exceeds ``iou_threshold`` or when
    ``is_contained`` reports containment at ``containment_threshold``. Every
    group is replaced by its enclosing ``[x1, y1, x2, y2]`` box.

    Returns:
        list: the merged boxes after the final pass.
    """
    current = boxes.copy()

    while True:
        merged = []
        consumed = set()  # indices already assigned to a group in this pass

        for i, seed in enumerate(current):
            if i in consumed:
                continue
            consumed.add(i)

            partners = [
                j
                for j in range(i + 1, len(current))
                if j not in consumed
                and (
                    compute_iou(seed, current[j]) > iou_threshold
                    or is_contained(seed, current[j], containment_threshold)
                )
            ]
            consumed.update(partners)
            group = [seed] + [current[j] for j in partners]

            # Enclosing box of the whole group.
            merged.append([
                min(b[0] for b in group),
                min(b[1] for b in group),
                max(b[2] for b in group),
                max(b[3] for b in group),
            ])

        shrank = len(merged) < len(current)
        current = merged
        if not shrank:
            # Nothing merged in this pass, so another pass would change nothing.
            return current
def get_corner_points(box):
    """Return the four corners of an ``(x1, y1, x2, y2)`` box as ``[x, y]`` pairs.

    Order: top-left, top-right, bottom-left, bottom-right.
    """
    left, top, right, bottom = box
    # Row by row: both corners at y=top first, then both at y=bottom.
    return [[x, y] for y in (top, bottom) for x in (left, right)]
def sample_negative_points_outside_boxes(mask, num_points):
    """Randomly draw up to ``num_points`` ``[x, y]`` positions where ``mask`` is falsy.

    Sampling uses ``np.random.randint`` and gives up after ``num_points * 20``
    draws, so fewer points (possibly none) may be returned.

    Returns:
        np.ndarray: array built from the list of accepted ``[x, y]`` points.
    """
    attempts_left = num_points * 20  # fail-safe to avoid infinite loops
    picked = []

    while len(picked) < num_points and attempts_left > 0:
        attempts_left -= 1
        col = np.random.randint(0, mask.shape[1])
        row = np.random.randint(0, mask.shape[0])
        if mask[row, col]:
            # Position is flagged in the mask; draw again.
            continue
        picked.append([col, row])

    return np.array(picked)
def get_inset_corner_points(box, margin=5):
    """Return the box corners moved inward by ``margin`` as ``[x, y]`` pairs.

    Each inset coordinate is clamped to the opposite edge of the box, so for
    boxes narrower than twice the margin the inset edges can cross over.

    Order: top-left, top-right, bottom-left, bottom-right.
    """
    x1, y1, x2, y2 = box

    # Inset each edge, never moving past the opposite edge of the box.
    left = min(x1 + margin, x2)
    top = min(y1 + margin, y2)
    right = max(x2 - margin, x1)
    bottom = max(y2 - margin, y1)

    return [[x, y] for y in (top, bottom) for x in (left, right)]
def processYOLOBoxes(iou):
    """Load the YOLO-predicted boxes from ``boxes.json`` and apply NMS.

    The file is read from ``OUTPUT_DIR``; each entry must provide a ``"bbox"``
    and a ``"score"`` key.

    Args:
        iou: IoU threshold passed to ``non_max_suppression``.

    Returns:
        list: the ``"bbox"`` values of the entries that survive NMS, in the
        order returned by ``non_max_suppression``.
    """
    boxes_path = os.path.join(OUTPUT_DIR, "boxes.json")
    with open(boxes_path, "r") as f:
        detections = json.load(f)

    bboxes = np.array([entry["bbox"] for entry in detections])
    confidences = np.array([entry["score"] for entry in detections])
    kept = non_max_suppression(bboxes, confidences, iou)

    # TODO: additionally keep only boxes inside a global bbox (TBD),
    # e.g. via box_inside_global(entry["bbox"], GLOBAL_BOX).
    return [detections[i]["bbox"] for i in kept]  # Format: [x1, y1, x2, y2]
def prepare_tiles(image_path, boxes_full, tile_size=1024, overlap=50, iou=0.5, c_th=0.75, edge_margin=10):
    """
    Tile the image, save every tile as PNG and write per-tile prompt metadata.

    For each tile the boxes from ``boxes_full`` that overlap it are merged
    (``merge_boxes_iterative``), clipped to tile-local coordinates and dropped
    when they come within ``edge_margin`` pixels of any tile edge. Each
    remaining box gets one positive point (its centre, label 1) and four
    negative points (its corners inset by 2 px, label 0).

    Tiles are written to ``tmp/tiles/tile_{idx}.png`` and the metadata list to
    ``tmp/tiles_meta.json``. Tiles without usable boxes get an entry with
    empty lists.

    Returns full image size H, W.
    """
    tiles, (H, W, _) = tile_image_with_overlap(image_path, tile_size, overlap)
    os.makedirs("tmp/tiles", exist_ok=True)

    meta = []
    for idx, (tile_bgr, (x_offset, y_offset)) in enumerate(tiles):
        # Tiles come from cv2 (BGR); PIL expects RGB.
        tile_rgb = cv2.cvtColor(tile_bgr, cv2.COLOR_BGR2RGB)
        Image.fromarray(tile_rgb).save(f"tmp/tiles/tile_{idx}.png")
        tile_h, tile_w, _ = tile_rgb.shape

        # Every tile gets a record; it stays empty unless boxes survive below.
        record = {
            "idx": idx,
            "x_off": x_offset,
            "y_off": y_offset,
            "local_boxes": [],
            "point_coords": [],
            "point_labels": [],
        }
        meta.append(record)

        # Boxes (full-image coordinates) that overlap this tile.
        x_limit = x_offset + tile_w
        y_limit = y_offset + tile_h
        candidate_boxes = [
            [x1, y1, x2, y2]
            for x1, y1, x2, y2 in boxes_full
            if (x2 > x_offset) and (x1 < x_limit) and (y2 > y_offset) and (y1 < y_limit)
        ]
        if not candidate_boxes:
            continue

        merged_boxes = merge_boxes_iterative(
            candidate_boxes, iou_threshold=iou, containment_threshold=c_th
        )

        # Clip to tile-local coordinates, then keep only boxes clear of the edges.
        kept_boxes = []
        for x1, y1, x2, y2 in merged_boxes:
            lx1 = max(0, x1 - x_offset)
            ly1 = max(0, y1 - y_offset)
            lx2 = min(tile_w, x2 - x_offset)
            ly2 = min(tile_h, y2 - y_offset)
            if (
                lx1 > edge_margin
                and ly1 > edge_margin
                and (tile_w - lx2) > edge_margin
                and (tile_h - ly2) > edge_margin
            ):
                kept_boxes.append([lx1, ly1, lx2, ly2])
        if not kept_boxes:
            continue

        # Point prompts: box centre is positive, inset corners are negative.
        for kept in kept_boxes:
            bx1, by1, bx2, by2 = kept
            centroid = ((bx1 + bx2) / 2, (by1 + by2) / 2)
            neg_points = get_inset_corner_points(kept, margin=2)
            if not isinstance(neg_points, list):
                neg_points = neg_points.tolist()
            record["point_coords"].append([centroid] + neg_points)
            record["point_labels"].append([1] + [0] * len(neg_points))
        record["local_boxes"] = kept_boxes

    # Save metadata
    os.makedirs("tmp", exist_ok=True)
    with open("tmp/tiles_meta.json", "w") as f:
        json.dump(meta, f)

    return H, W
def merge_tile_masks(H, W):
    """
    Merge predicted tile masks into a full-size instance-label image.

    Reads ``tmp/tiles_meta.json`` and, for every tile that has a mask file at
    ``tmp/masks/tile_{idx}.npy``, pastes each mask at the tile offset. Each
    mask gets its own consecutive label starting at 1. Pixels that already
    carry a label are left untouched, so earlier masks win where masks
    overlap. Tiles without a mask file are skipped. The result is also saved
    as ``mask.tif`` in ``OUTPUT_DIR``.

    Args:
        H (int): full image height
        W (int): full image width
    Returns:
        full_mask (np.ndarray): merged ``(H, W)`` uint16 label array
    """
    # NOTE(review): labels are stored as uint16, so more than 65535 instances
    # would not fit in this array.
    full_mask = np.zeros((H, W), dtype=np.uint16)
    instance_id = 1

    # Load tile metadata
    with open("tmp/tiles_meta.json", "r") as f:
        tiles_meta = json.load(f)

    for tile in tiles_meta:
        x_off = tile["x_off"]
        y_off = tile["y_off"]

        mask_path = Path(f"tmp/masks/tile_{tile['idx']}.npy")
        if not mask_path.exists():
            continue

        # Load tile masks (expected shape = (N, h, w))
        tile_masks = np.load(mask_path)
        if tile_masks.ndim == 2:  # single mask saved as (h, w)
            tile_masks = tile_masks[None, :, :]  # make it (1, h, w)

        for mask in tile_masks:
            mask = mask.astype(bool)

            # No zero-padding to a fixed tile size: padded pixels are False and
            # never write a label, and a hard-coded size made np.pad fail when
            # a mask exceeded it along one axis only.
            h_end = min(y_off + mask.shape[0], H)
            w_end = min(x_off + mask.shape[1], W)

            # Crop the mask to the part that lies inside the full image.
            region = full_mask[y_off:h_end, x_off:w_end]
            mask = mask[:h_end - y_off, :w_end - x_off]

            # region is a view, so this writes into full_mask; only unlabelled
            # pixels receive the new instance id.
            region[mask & (region == 0)] = instance_id
            instance_id += 1

    # Save as TIFF
    final_mask = Image.fromarray(full_mask)
    MASK_PATH = os.path.join(OUTPUT_DIR, "mask.tif")
    final_mask.save(MASK_PATH)

    return full_mask
def chunkify(lst, n):
    """Lazily yield consecutive slices of ``lst`` holding at most ``n`` items each."""
    starts = range(0, len(lst), n)
    yield from (lst[start:start + n] for start in starts)
def img_shape(image_path):
    """Return the array shape of the image at ``image_path`` as loaded by ``cv2.imread``.

    Raises:
        OSError: if the image cannot be read.
    """
    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise OSError(f"Could not read image: {image_path}")
    return img.shape
def best_street_match(point, query_name, edges_gdf, max_distance=100):
    """Fuzzy-match ``query_name`` against the names of edges near ``point``.

    Edges are candidates when they intersect ``point.buffer(max_distance)``.
    Their ``'name'`` values are compared with ``fuzz.ratio``.

    Returns:
        ``(None, 0)`` when no edge is nearby; otherwise the result of
        ``process.extractOne`` (name, score, index).

    NOTE(review): the two branches return tuples of different length (2 vs.
    3) — confirm callers handle both.
    """
    search_area = point.buffer(max_distance)
    in_range = edges_gdf.intersects(search_area)
    nearby_edges = edges_gdf[in_range]

    if nearby_edges.empty:
        return None, 0

    candidate_names = nearby_edges['name'].tolist()
    return process.extractOne(query_name, candidate_names, scorer=fuzz.ratio)