import spaces
import numpy as np
from ultralytics import YOLO, SAM
import os
import json
from PIL import Image
import cv2
import torch
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
import rasterio
import rasterio.features
from shapely.geometry import shape
import pandas as pd
import osmnx as ox
from osgeo import gdal, osr
import geopandas as gpd
from rapidfuzz import process, fuzz
from huggingface_hub import hf_hub_download
from config import OUTPUT_DIR
from pathlib import Path
from .helpers import (
    box_inside_global, nms_iou, non_max_suppression, tile_image_with_overlap,
    compute_iou, merge_boxes, box_area, is_contained, merge_boxes_iterative,
    get_corner_points, sample_negative_points_outside_boxes,
    get_inset_corner_points, processYOLOBoxes, prepare_tiles, merge_tile_masks,
    chunkify, img_shape, best_street_match,
)
from pyproj import Transformer
import shutil
import re
from shapely.ops import nearest_points
from geopy.distance import geodesic

# Global cache for the TrOCR model so repeated OCR calls reuse one instance.
_trocr_processor = None
_trocr_model = None
_trocr_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def run_inference(tile_dict, gcp_path, user_crs, city_name, score_th, hist_th, hist_dic):
    """Run the full map-OCR pipeline on one tile and stream progress.

    Generator yielding ``(log_text, payload)`` tuples, where ``payload`` is
    ``None`` for plain progress messages, a mask path when the segmentation
    mask is saved, or a list of result CSV paths at the end.

    Args:
        tile_dict: dict with ``tile_path`` and ``coords`` (x_start, y_start,
            x_end, y_end) of the selected tile, or None if nothing selected.
        gcp_path: CSV of ground control points (mapX/mapY/sourceX/sourceY).
        user_crs: EPSG code (int/str) of the GCP coordinate system.
        city_name: city used to fetch the OSM street network.
        score_th: OSM fuzzy-match score threshold for auto-accepting blobs.
        hist_th: historic-name fuzzy-match threshold.
        hist_dic: path to a historic street-name list CSV, or None to skip.
    """
    IMAGE_FOLDER = os.path.join(OUTPUT_DIR, "blobs")
    CSV_FILE = os.path.join(OUTPUT_DIR, "annotations.csv")
    MASK_FILE = os.path.join(OUTPUT_DIR, "mask.tif")

    # Start from a clean slate: previous blobs, tiles and outputs are removed.
    if os.path.exists(IMAGE_FOLDER):
        shutil.rmtree(IMAGE_FOLDER)
    os.makedirs(IMAGE_FOLDER, exist_ok=True)
    if os.path.exists("tmp"):
        shutil.rmtree("tmp")
    os.makedirs("tmp", exist_ok=True)
    if os.path.exists(CSV_FILE):
        os.remove(CSV_FILE)
    if os.path.exists(MASK_FILE):
        os.remove(MASK_FILE)

    log = ""
    if tile_dict is None:
        yield "No tile selected", None
        return

    image_path = tile_dict["tile_path"]
    coords = tile_dict["coords"]  # (x_start, y_start, x_end, y_end)
    print(f"Tile path: {image_path}; Tile coords: {coords}")

    # ==== TEXT DETECTION ====
    for msg in getBBoxes(image_path):
        log += msg + "\n"
        yield log, None

    # Fix: initialize before the loop — the original assigned stop_pipeline
    # only inside the loop body, raising NameError if getSegments yielded
    # nothing at all.
    stop_pipeline = False
    for msg in getSegments(image_path):
        if msg.endswith(".tif"):
            log += f"Mask saved at {msg}.\n"
            yield log, msg
        else:
            log += msg + "\n"
            yield log, None
        if "No labels detected" in msg:
            stop_pipeline = True
            break
    if stop_pipeline:
        yield log + "Pipeline stopped: no text segments found.\n", None
        return

    for msg in extractSegments(image_path):
        log += msg + "\n"
        yield log, None

    # === TEXT RECOGNITION ===
    for msg in blobsOCR_all():
        log += msg + "\n"
        yield log, None

    # === ADD GEO DATA ===
    for msg in georefTile(coords, gcp_path):
        log += msg + "\n"
        yield log, None
    for msg in extractCentroids(image_path):
        log += msg + "\n"
        yield log, None
    for msg in extractStreetNet(city_name, user_crs):
        log += msg + "\n"
        yield log, None

    # === POST OCR ===
    all_csvs = []
    for msg in fuzzyMatch(score_th, tile_dict):
        if isinstance(msg, list):
            # msg is [street_matches_csv, osm_csv]
            all_csvs.extend(msg)  # append these CSV paths
            log += "Finished! CSVs saved at:\n"
            for f in msg:
                log += f" - {f}\n"
            yield log, None
        else:
            log += msg + "\n"
            yield log, None

    if hist_dic is not None:
        # Run fuzzy match against historic street names
        for msg in fuzzyMatchHist(hist_dic, hist_th, tile_dict):
            if isinstance(msg, list):
                all_csvs.extend(msg)  # append historic CSV as well
                log += "Historic fuzzy matching finished! CSVs saved at:\n"
                for f in msg:
                    log += f" - {f}\n"
                yield log, all_csvs  # now yields all three CSVs together
            else:
                log += msg + "\n"
                yield log, None
    else:
        # If historic matching is skipped, yield the OSM match files
        yield log, all_csvs


def load_trocr_model():
    """Load TrOCR processor/model into the module cache (GPU if available)."""
    global _trocr_processor, _trocr_model
    if _trocr_model is None:
        _trocr_processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-str")
        _trocr_model = VisionEncoderDecoderModel.from_pretrained("muk42/trocr_streets")
        _trocr_model.to(_trocr_device).eval()
    return _trocr_processor, _trocr_model


@spaces.GPU
def getBBoxes(image_path, tile_size=256, overlap=0.3, confidence_threshold=0.25):
    """Detect text boxes with a fine-tuned YOLOv9, tiling the full image.

    Yields progress strings; writes all detections (full-image pixel
    coordinates) to ``OUTPUT_DIR/boxes.json``.
    """
    yield f"DEBUG: Received image_path: {image_path}"
    image = cv2.imread(image_path)
    H, W, _ = image.shape
    # fine-tuned on selection of city maps
    yolo_weights = hf_hub_download(
        repo_id="muk42/yolov9_streets", filename="yolov9c_finetuned_v2.pt"
    )
    model = YOLO(yolo_weights)

    step = int(tile_size * (1 - overlap))
    all_detections = []

    # Calculate total tiles for progress reporting
    total_tiles = 0
    for y in range(0, H, step):
        for x in range(0, W, step):
            # Skip small tiles at the edges
            if y + tile_size > H or x + tile_size > W:
                continue
            total_tiles += 1

    processed_tiles = 0
    # Tile the image and run prediction
    for y in range(0, H, step):
        for x in range(0, W, step):
            tile = image[y:y + tile_size, x:x + tile_size]
            if tile.shape[0] < tile_size or tile.shape[1] < tile_size:
                continue
            results = model.predict(
                source=tile, imgsz=tile_size, conf=confidence_threshold,
                verbose=False, iou=0.5,
            )
            for result in results:
                boxes = result.boxes.xyxy.cpu().numpy()
                scores = result.boxes.conf.cpu().numpy()
                classes = result.boxes.cls.cpu().numpy()
                for box, score, cls in zip(boxes, scores, classes):
                    x1, y1, x2, y2 = box
                    # Shift box coordinates relative to full image
                    x1 += x
                    x2 += x
                    y1 += y
                    y2 += y
                    all_detections.append([x1, y1, x2, y2, float(score), int(cls)])
            processed_tiles += 1
            yield f"Processed tile {processed_tiles} of {total_tiles}"

    # After all tiles are processed, save detections to JSON
    boxes_to_save = [
        {
            "bbox": [float(x1), float(y1), float(x2), float(y2)],
            "score": float(conf),
            "class": int(cls),
        }
        for x1, y1, x2, y2, conf, cls in all_detections
    ]
    BOXES_PATH = os.path.join(OUTPUT_DIR, "boxes.json")
    with open(BOXES_PATH, "w") as f:
        json.dump(boxes_to_save, f, indent=4)
    yield f"Inference complete."


@spaces.GPU
def run_tile_inference():
    """Run SAM on each prepared tile, saving mask stacks as .npy files.

    Reads tile metadata from ``tmp/tiles_meta.json`` (written by
    prepare_tiles). Yields progress strings; the message containing
    "Stopping inference" is matched by getSegments to abort the pipeline.
    """
    model = SAM("mobile_sam.pt")  # sam2.1_l.pt
    Path("tmp/masks").mkdir(parents=True, exist_ok=True)
    with open("tmp/tiles_meta.json", "r") as f:
        tiles_meta = json.load(f)

    for tile in tiles_meta:
        yield f"Processing {tile['idx']}..."
        tile_path = f"tmp/tiles/tile_{tile['idx']}.png"
        out_path = tile_path.replace("tiles", "masks").replace(".png", ".npy")
        # skip if already processed
        if Path(out_path).exists():
            continue

        local_boxes = tile.get('local_boxes', [])
        point_coords = tile.get('point_coords', [])
        point_labels = tile.get('point_labels', [])
        tile_array = np.array(Image.open(tile_path))

        # If there are no boxes and no labels, stop execution
        if not local_boxes and not point_coords and not point_labels:
            yield f"Tile {tile['idx']} has no boxes or points/labels. Stopping inference."
            return

        results = model(tile_array, bboxes=local_boxes, points=point_coords, labels=point_labels)
        masks_to_save = [r.masks.data.cpu().numpy() for r in results if r.masks is not None]
        if masks_to_save:
            masks_stack = np.concatenate(masks_to_save, axis=0)  # shape (N, H, W)
            np.save(out_path, masks_stack)


def getSegments(image_path, iou=0.5, c_th=0.75, edge_margin=10):
    """Segment the detected text boxes into a labelled mask image.

    Args:
        image_path: path to the map tile image.
        iou: IoU threshold for combining bounding boxes.
        c_th: share of the smaller box contained in the larger box for merge.
        edge_margin: pixel margin for tiles.

    Yields progress strings; the final yield is the mask path (ends with
    ".tif"), which run_inference detects to report the saved mask.
    """
    yield "Load YOLO boxes.."
    BOXES_PATH = os.path.join(OUTPUT_DIR, "boxes.json")
    with open(BOXES_PATH, "r") as f:
        box_data = json.load(f)
    boxes = [b["bbox"] for b in box_data]

    yield "Prepare tiles..."
    H, W = prepare_tiles(
        image_path, boxes, tile_size=1024, overlap=50,
        iou=iou, c_th=c_th, edge_margin=edge_margin,
    )

    yield "Run inference on tiles..."
    for msg in run_tile_inference():
        yield msg
        if "Stopping inference" in msg:
            yield "No labels detected – halting getSegments."
            return

    # Fix: original message read "Marge predicted masks..." (typo).
    yield "Merge predicted masks into image..."
    merge_tile_masks(H, W)
    MASK_PATH = os.path.join(OUTPUT_DIR, "mask.tif")
    yield f"{MASK_PATH}"


def extractSegments(image_path, min_size=500, margin=100):
    """Crop each labelled blob from the mask into per-blob PNGs.

    For every blob id in OUTPUT_DIR/mask.tif (0 = background) writes:
      - ``blobs/<id>.png``        tight crop of the original image, and
      - ``blobs/<id>_margin.png`` crop with *margin* px context and the blob
        shaded for visual review.
    Blobs smaller than *min_size* pixels are skipped.
    """
    image = cv2.imread(image_path)
    MASK_PATH = os.path.join(OUTPUT_DIR, "mask.tif")
    mask = cv2.imread(MASK_PATH, cv2.IMREAD_UNCHANGED)
    height, width = mask.shape[:2]

    # Get unique labels (excluding background label 0)
    blob_ids = np.unique(mask)
    blob_ids = blob_ids[blob_ids != 0]
    yield f"Found {len(blob_ids)} blobs"

    for blob_id in blob_ids:
        yield f"Processing blob {blob_id}..."
        # Create a binary mask for the current blob
        blob_mask = (mask == blob_id).astype(np.uint8)

        # Skip small blobs
        if np.sum(blob_mask) < min_size:
            continue

        # Find bounding box of the blob (tight box without margins)
        ys, xs = np.where(blob_mask)
        y_min, y_max = ys.min(), ys.max() + 1
        x_min, x_max = xs.min(), xs.max() + 1

        # ---- ORIGINAL (no mask, no margin) ----
        cropped_image_orig = image[y_min:y_max, x_min:x_max]
        BLOB_PATH_ORIG = os.path.join(OUTPUT_DIR, "blobs", f"{blob_id}.png")
        cv2.imwrite(BLOB_PATH_ORIG, cropped_image_orig)

        # ---- MARGINALIZED (with mask/shading) ----
        x_min_m = max(0, x_min - margin)
        y_min_m = max(0, y_min - margin)
        x_max_m = min(width, x_max + margin)
        y_max_m = min(height, y_max + margin)

        cropped_image_margin = image[y_min_m:y_max_m, x_min_m:x_max_m]
        cropped_mask_margin = blob_mask[y_min_m:y_max_m, x_min_m:x_max_m]

        shaded_margin = cropped_image_margin.copy()
        overlay_margin = cropped_image_margin.copy()
        overlay_margin[cropped_mask_margin == 1] = (255, 200, 100)
        # 35% tint over the blob area so annotators can see which text is meant.
        shaded_margin = cv2.addWeighted(overlay_margin, 0.35, shaded_margin, 0.65, 0)

        BLOB_PATH_MARGIN = os.path.join(OUTPUT_DIR, "blobs", f"{blob_id}_margin.png")
        cv2.imwrite(BLOB_PATH_MARGIN, shaded_margin)

    yield f"Done."


@spaces.GPU
def blobsOCR_chunk(image_paths):
    """Run OCR on a list of images (one chunk).

    Returns a list of ``(blob_name, recognized_text)`` tuples, where the
    name is the file's basename without extension.
    """
    processor, model = load_trocr_model()
    results = []

    # Load all images in the chunk
    images = [Image.open(path).convert("RGB") for path in image_paths]
    # Convert to pixel_values tensor
    pixel_values = processor(
        images=images, return_tensors="pt", padding=True
    ).pixel_values.to(_trocr_device)

    # Generate text for the whole batch at once
    generated_ids = model.generate(pixel_values)
    texts = processor.batch_decode(generated_ids, skip_special_tokens=True)

    for path, text in zip(image_paths, texts):
        name = os.path.splitext(os.path.basename(path))[0]
        results.append((name, text))
    return results


def blobsOCR_all():
    """OCR every extracted blob (excluding *_margin previews) in batches.

    Writes ``OUTPUT_DIR/ocr.csv`` with rows ``blob_id,"text"`` and yields
    progress strings.
    """
    image_folder = os.path.join(OUTPUT_DIR, "blobs")
    all_files = [
        os.path.join(image_folder, f)
        for f in os.listdir(image_folder)
        if f.endswith(".png") and '_margin' not in f
    ]
    OCR_PATH = os.path.join(OUTPUT_DIR, "ocr.csv")
    with open(OCR_PATH, "w", encoding="utf-8") as f_out:
        for chunk in chunkify(all_files, n=16):  # adjust batch size
            yield f"Processing {len(chunk)} images..."
            results = blobsOCR_chunk(chunk)
            for name, text in results:
                f_out.write(f'{name},"{text}"\n')
                yield f"{name} → {text}"


def extractCentroids(image_path):
    """Compute the georeferenced centroid of each blob in the mask.

    Reads OUTPUT_DIR/mask_georef.tif, polygonizes each label, merges
    multi-part blobs, and writes blob_id/x/y to OUTPUT_DIR/centroids.csv
    (coordinates in the raster's CRS).
    """
    GEO_PATH = os.path.join(OUTPUT_DIR, "mask_georef.tif")
    with rasterio.open(GEO_PATH) as src:
        mask = src.read(1)
        transform = src.transform

    labels = np.unique(mask)
    labels = labels[labels != 0]
    data = []

    # Generate polygons and their values
    shapes_gen = rasterio.features.shapes(mask, mask=(mask != 0), transform=transform)

    # Create a dict to collect polygons by label
    polygons_by_label = {}
    for geom, val in shapes_gen:
        if val == 0:
            continue
        polygons_by_label.setdefault(val, []).append(shape(geom))

    # For each label, merge polygons and get centroid
    for idx, label in enumerate(labels):
        yield f"Processing {idx + 1} out of {len(labels)}"
        polygons = polygons_by_label.get(label)
        if not polygons:
            continue
        # Merge polygons of the same label (if multiple parts)
        multi_poly = polygons[0]
        for poly in polygons[1:]:
            multi_poly = multi_poly.union(poly)
        centroid = multi_poly.centroid
        data.append({"blob_id": label, "x": centroid.x, "y": centroid.y})

    df = pd.DataFrame(data)
    COORD_PATH = os.path.join(OUTPUT_DIR, "centroids.csv")
    df.to_csv(COORD_PATH, index=False)
    yield f"Saved centroid coordinates of {len(labels)} blobs."


def georefTile(tile_coords, gcp_path):
    """Georeference the SAM mask tile into EPSG:3857 using shifted GCPs.

    The GCPs were digitized on the full map image, so their pixel positions
    are shifted by the tile's offset before applying them to the tile mask.

    Args:
        tile_coords: (xmin, ymin, xmax, ymax) pixel extent of the tile in
            the full image.
        gcp_path: CSV with mapX/mapY/sourceX/sourceY columns.
    """
    yield "Georeferencing SAM image.."
    MASK_TILE = os.path.join(OUTPUT_DIR, "mask.tif")
    TMP_TILE = os.path.join(OUTPUT_DIR, "mask_tmp.tif")
    MASK_TILE_GEO = os.path.join(OUTPUT_DIR, "mask_georef.tif")
    for f in [TMP_TILE, MASK_TILE_GEO]:
        if os.path.exists(f):
            os.remove(f)

    df = pd.read_csv(gcp_path)
    xmin, ymin, xmax, ymax = tile_coords
    xoff, yoff = xmin, ymin

    shifted_gcps = []
    for _, r in df.iterrows():
        # abs() on sourceY: GCP files store image rows as negative values
        # — presumably a QGIS-style convention; TODO confirm with the
        # GCP-export tool.
        shifted_gcps.append(
            gdal.GCP(
                float(r['mapX']), float(r['mapY']), 0,
                float(r['sourceX']) - xoff,
                abs(float(r['sourceY'])) - yoff,
            )
        )

    gdal.Translate(
        TMP_TILE, MASK_TILE,
        format="GTiff",
        GCPs=shifted_gcps,
        outputSRS="EPSG:3857",
    )
    gdal.Warp(
        MASK_TILE_GEO, TMP_TILE,
        dstSRS="EPSG:3857",
        resampleAlg="near",
        polynomialOrder=1,
        creationOptions=["COMPRESS=LZW"],
    )
    yield "Done."


def georefImg(image_path, gcp_path, user_crs):
    """Georeference the full map image into EPSG:3857 via a VRT.

    GCP map coordinates are given in *user_crs* and reprojected to Web
    Mercator before being attached to a VRT copy of the image, which is
    then warped to OUTPUT_DIR/georeferenced.tif.
    """
    TMP_FILE = os.path.join(OUTPUT_DIR, "tmp.tif")
    GEO_FILE = os.path.join(OUTPUT_DIR, "georeferenced.tif")
    VRT_FILE = os.path.join(OUTPUT_DIR, "vrt_file.vrt")
    for f in [TMP_FILE, GEO_FILE]:
        if os.path.exists(f):
            os.remove(f)

    yield "Read GCP points..."
    df = pd.read_csv(gcp_path)
    H, W, _ = img_shape(image_path)

    # Build GCPs as (mapX, mapY, sourceX, sourceY) tuples; sourceY is stored
    # negative in the CSV (see georefTile), hence abs().
    gcps = []
    for _, r in df.iterrows():
        gcps.append((
            float(r['mapX']),
            float(r['mapY']),
            float(r['sourceX']),
            abs(float(r['sourceY'])),
        ))

    yield "Transform GCP to user specified CRS..."
    # Transform GCP from user provided CRS to Web Mercator 3857
    transformer = Transformer.from_crs(f"epsg:{user_crs}", "epsg:3857", always_xy=True)
    gcps3857 = []
    for px, py, x, y in gcps:
        x3857, y3857 = transformer.transform(px, py)
        gcps3857.append(gdal.GCP(x3857, y3857, 0, x, y))

    yield "Apply GCP to the image..."
    # Apply GCP to the image via a VRT so the source file stays untouched.
    src_ds = gdal.Open(image_path)
    drv = gdal.GetDriverByName('VRT')
    vrt_ds = drv.CreateCopy(VRT_FILE, src_ds, 0)

    # Set the GCPs and spatial reference system
    srs3857 = osr.SpatialReference()
    srs3857.ImportFromEPSG(3857)
    vrt_ds.SetGCPs(gcps3857, srs3857.ExportToWkt())
    vrt_ds = None  # close vrt to save changes

    gdal.Warp(
        GEO_FILE, VRT_FILE,
        dstSRS="EPSG:3857",
        resampleAlg="near",
        polynomialOrder=1,
        creationOptions=["COMPRESS=LZW"],
        format='GTiff',
    )
    yield "The map is georeferenced."


def extractStreetNet(city_name, user_crs):
    """Download the OSM street network covering the georeferenced tile.

    Derives a bounding box (plus 100 m buffer) from mask_georef.tif,
    converts it to WGS84, fetches the OSM graph, and writes named street
    edges (reprojected to *user_crs*) to OUTPUT_DIR/osm_extract.geojson.
    """
    yield f"Extract OSM street network for {city_name}"
    MASK_TILE_GEO = os.path.join(OUTPUT_DIR, "mask_georef.tif")
    ds = gdal.Open(MASK_TILE_GEO)
    gt = ds.GetGeoTransform()
    width = ds.RasterXSize
    height = ds.RasterYSize
    minx = gt[0]
    maxy = gt[3]
    maxx = gt[0] + width * gt[1] + height * gt[2]
    miny = gt[3] + width * gt[4] + height * gt[5]

    # Add 100 meters buffer in all directions
    minx -= 100  # west
    maxx += 100  # east
    miny -= 100  # south
    maxy += 100  # north

    bbox = (maxy, miny, maxx, minx)
    # NOTE(review): mask_georef.tif is warped to EPSG:3857 by georefTile, but
    # this transformer assumes the raster extent is in user_crs — confirm that
    # user_crs is 3857 here, otherwise the bbox is reprojected incorrectly.
    transformer = Transformer.from_crs(f"EPSG:{user_crs}", "EPSG:4326", always_xy=True)
    north, south = transformer.transform(bbox[2], bbox[0])[1], transformer.transform(bbox[3], bbox[1])[1]
    east, west = transformer.transform(bbox[2], bbox[0])[0], transformer.transform(bbox[3], bbox[1])[0]
    bbox = (west, south, east, north)

    G = ox.graph_from_bbox(bbox, network_type='all')
    G_proj = ox.project_graph(G)
    edges = ox.graph_to_gdfs(G_proj, nodes=False, edges=True, fill_edge_geometry=True)
    edges_proj = edges.to_crs(epsg=user_crs)
    edges_proj = edges_proj[['osmid', 'name', 'geometry']]
    edges_proj = edges_proj[edges_proj['name'].notnull()]
    # OSM 'name' can be a list of alternatives; keep the first one.
    edges_proj['name'] = edges_proj['name'].apply(
        lambda x: x[0] if isinstance(x, list) and len(x) > 0 else x
    )

    OSM_PATH = os.path.join(OUTPUT_DIR, "osm_extract.geojson")
    edges_proj.to_file(OSM_PATH, driver="GeoJSON")
    yield "Done OSM extraction."
def fuzzyMatchHist(hist_dic, hist_th, tile_dict):
    """Fuzzy-match OCR results against a historic street-name list.

    Writes ``historic_matches_tile<N>.csv`` with the best historic match per
    blob, deletes blob images whose match score reaches *hist_th* (they need
    no manual annotation), then yields the result CSV path as a one-element
    list (run_inference treats list yields as file payloads).

    Args:
        hist_dic: CSV (no header) with one historic street name per line.
        hist_th: minimum token_sort_ratio score to auto-accept a match.
        tile_dict: dict with ``tile_path``; the first number in the path is
            used as the tile id in output filenames.
    """
    # Convert threshold to numeric
    hist_th = int(hist_th)

    # === Load Data ===
    hist_df = pd.read_csv(hist_dic, header=None, names=["hist_name"])
    OCR_PATH = os.path.join(OUTPUT_DIR, "ocr.csv")
    names_df = pd.read_csv(
        OCR_PATH,
        names=['blob_id', 'pred_text'],
        dtype={"blob_id": "int64", "pred_text": "string"},
    )
    historic_names = hist_df["hist_name"].dropna().astype(str).tolist()

    # === Fuzzy Match ===
    results = []
    for _, row in names_df.iterrows():
        ocr_name = row["pred_text"]
        if pd.isna(ocr_name):
            continue
        # Fix: extractOne returns None when historic_names is empty; the
        # original unpacked unconditionally and would raise TypeError.
        best = process.extractOne(ocr_name, historic_names, scorer=fuzz.token_sort_ratio)
        if best is None:
            continue
        best_match, best_score, _ = best
        results.append({
            "blob_id": row["blob_id"],
            "ocr_name": ocr_name,
            "best_hist_match": best_match,
            "match_score": best_score,
        })

    results_df = pd.DataFrame(results)

    # NOTE(review): this picks the FIRST digit run anywhere in the path, so a
    # numbered directory would shadow the tile number — confirm tile paths.
    tile = tile_dict["tile_path"]
    match = re.search(r'\d+', tile)
    tile_number = int(match.group())

    # === Save all results ===
    all_results_path = os.path.join(OUTPUT_DIR, f"historic_matches_tile{tile_number}.csv")
    results_df.to_csv(all_results_path, index=False)

    # === Filter for manual annotation ===
    # Confident matches don't need manual review: remove their blob images.
    manual_df = results_df[results_df["match_score"] >= hist_th]
    for blob_id in manual_df['blob_id']:
        # original blob
        orig_path = os.path.join(OUTPUT_DIR, "blobs", f"{blob_id}.png")
        if os.path.exists(orig_path):
            os.remove(orig_path)
        # marginalized blob
        margin_path = os.path.join(OUTPUT_DIR, "blobs", f"{blob_id}_margin.png")
        if os.path.exists(margin_path):
            os.remove(margin_path)

    yield "Historic fuzzy matching complete."
    yield [all_results_path]


def fuzzyMatch(score_th, tile_dict):
    """Match OCR'd blob names to nearby OSM streets.

    Joins blob centroids with OCR text, finds for each blob both the best
    Levenshtein match within 100 m (via best_street_match) and the
    geometrically closest OSM street, writes ``street_matches_tile<N>.csv``
    and an OSM WKT export CSV, deletes blob images whose match score reaches
    *score_th*, and finally yields ``[matches_csv, osm_csv]``.
    """
    COORD_PATH = os.path.join(OUTPUT_DIR, "centroids.csv")
    OCR_PATH = os.path.join(OUTPUT_DIR, "ocr.csv")

    coords_df = pd.read_csv(COORD_PATH)
    names_df = pd.read_csv(
        OCR_PATH,
        names=['blob_id', 'pred_text'],
        dtype={"blob_id": "int64", "pred_text": "string"},
    )
    merged_df = coords_df.merge(names_df, on="blob_id")
    gdf = gpd.GeoDataFrame(
        merged_df,
        geometry=gpd.points_from_xy(merged_df.x, merged_df.y),
        crs="EPSG:3857",
    )

    # Add lat lon to the blobs
    # Reproject temporarily to WGS84 for coordinates
    gdf_ll = gdf.to_crs(epsg=4326)
    # Add longitude/latitude columns to the original gdf
    gdf['lon'] = gdf_ll.geometry.x
    gdf['lat'] = gdf_ll.geometry.y

    OSM_PATH = os.path.join(OUTPUT_DIR, "osm_extract.geojson")
    osm_gdf = gpd.read_file(OSM_PATH, dtype={"name": "str"})
    # Drop the generic "strasse" suffix so fuzzy scores focus on the stem.
    osm_gdf["name"] = osm_gdf["name"].str.replace("strasse", "", case=False, regex=False)

    # Build spatial index for fast nearest lookup
    osm_sindex = osm_gdf.sindex
    yield "Process OSM candidates..."

    results = []
    for _, row in gdf.iterrows():
        geom = row.geometry
        if isinstance(geom, gpd.GeoSeries):
            geom = geom.iloc[0]

        # Levenshtein-based fuzzy matching
        match = best_street_match(geom, row['pred_text'], osm_gdf, max_distance=100)

        # Closest OSM street geometrically
        nearest_idx, nearest_dist = osm_sindex.nearest(geom, return_all=False, return_distance=True)
        # nearest_idx row 1 holds the tree (OSM) indices.
        closest_name = osm_gdf.name.iloc[nearest_idx[1]].values[0]

        results.append({
            "blob_id": row.blob_id,
            "lon": row.lon,
            "lat": row.lat,
            "blob_name": row.pred_text,
            "best_osm_match": match[0] if match else None,
            "osm_match_score": match[1] if match else 0,
            "closest_osm_street": closest_name,
            "closest_osm_distance_m": nearest_dist[0],
        })

    results_df = pd.DataFrame(results)

    # Save results
    # NOTE(review): first digit run anywhere in the path — see fuzzyMatchHist.
    tile = tile_dict["tile_path"]
    match = re.search(r'\d+', tile)
    tile_number = int(match.group())
    RES_PATH = os.path.join(OUTPUT_DIR, f"street_matches_tile{tile_number}.csv")
    results_df.to_csv(RES_PATH, index=False)

    # Export OSM layer as CSV (WGS84, geometry as WKT)
    osm_gdf = osm_gdf.to_crs(epsg=4326)
    OSM_CSV_PATH = os.path.join(OUTPUT_DIR, f"osm_extract_tile{tile_number}.csv")
    osm_export_df = osm_gdf[["name", "geometry"]].copy()
    osm_export_df["geometry"] = osm_export_df["geometry"].apply(lambda g: g.wkt)
    osm_export_df.to_csv(OSM_CSV_PATH, index=False)

    # Remove blobs above score threshold — they are confidently matched and
    # need no manual annotation.
    manual_df = results_df[results_df['osm_match_score'] >= int(score_th)]
    for blob_id in manual_df['blob_id']:
        orig_path = os.path.join(OUTPUT_DIR, "blobs", f"{blob_id}.png")
        if os.path.exists(orig_path):
            os.remove(orig_path)
        margin_path = os.path.join(OUTPUT_DIR, "blobs", f"{blob_id}_margin.png")
        if os.path.exists(margin_path):
            os.remove(margin_path)

    yield [RES_PATH, OSM_CSV_PATH]