"""Gradio demo for the ReLeM (FoodSeg103) semantic segmentation model.

Loads an MMSegmentation segmentor once at start-up and serves a single-image
upload -> segmentation-mask interface.
"""
import contextlib
import os
import tempfile

import gradio as gr
import numpy as np
import torch
from PIL import Image

# --- 1. Load Custom Model Utilities ---
# NOTE: These imports MUST match the files you copied from the GitHub repo.
# Example imports - adjust these if the model files are deeper in subfolders!
_MMSEG_IMPORT_ERROR = None
try:
    from mmseg.apis import init_segmentor, inference_segmentor  # Core MMSeg functions
    from mmseg.datasets import build_dataloader, build_dataset  # Utilities
    # You might also need to copy config files, e.g., to 'configs/relem/'
except ImportError as err:
    # Remember the failure so the loader can report it clearly instead of
    # failing later with a NameError on init_segmentor.
    _MMSEG_IMPORT_ERROR = err
    init_segmentor = inference_segmentor = None
    print(f"MMSegmentation utilities not found. Ensure files were copied correctly. ({err})")

# --- 2. CONFIGURATION ---
# Define the paths for the files you placed in the repository
WEIGHTS_PATH = "R50_ReLeM.pth"
CONFIG_FILE = "configs/foodnet/SETR_Naive_768x768_80k_base_RM.py"  # Replace with actual config file from the repo


# --- 3. Model Loading Function ---
@torch.no_grad()
def load_relem_model():
    """Initialize the segmentor from CONFIG_FILE and load WEIGHTS_PATH.

    Uses the first CUDA device when available, otherwise the CPU.

    Returns:
        The model in eval mode, or ``None`` if mmseg is unavailable or
        loading raised. The failure reason is printed to the logs.
    """
    if _MMSEG_IMPORT_ERROR is not None:
        print(f"Error loading model: MMSegmentation import failed ({_MMSEG_IMPORT_ERROR})")
        return None
    try:
        model = init_segmentor(
            CONFIG_FILE,
            checkpoint=WEIGHTS_PATH,
            device="cuda:0" if torch.cuda.is_available() else "cpu",
        )
        model.eval()
        print("ReLeM Model loaded successfully!")
        return model
    except Exception as e:  # start-up boundary: log and let the UI report it
        print(f"Error loading model: {e}")
        return None


# Load the model once when the Space starts
RELEM_MODEL = load_relem_model()


# --- 4. Inference Function for Gradio ---
def _colorize_mask(seg_mask, palette=None):
    """Convert a 2-D array of class IDs into a displayable PIL image.

    Args:
        seg_mask: Array of per-pixel class IDs.
        palette: Optional sequence of RGB triples indexed by class ID.
            ``None`` or empty falls back to the raw grayscale ID map.

    Returns:
        An RGB image when a palette is given (IDs outside the palette are
        drawn black), otherwise an "L" image of the IDs cast to uint8.
    """
    if palette is None or len(palette) == 0:
        return Image.fromarray(seg_mask.astype(np.uint8)).convert("L")

    colors = np.asarray(palette, dtype=np.uint8).reshape(-1, 3)
    ids = seg_mask.astype(np.int64)
    color_mask = np.zeros((*ids.shape, 3), dtype=np.uint8)
    # Guard against IDs the palette does not cover (e.g. an ignore label).
    valid = (ids >= 0) & (ids < len(colors))
    color_mask[valid] = colors[ids[valid]]
    return Image.fromarray(color_mask)


def segment_food(input_image: Image.Image) -> Image.Image:
    """Segment an uploaded food image and return its mask as a PIL image.

    Errors are raised as ``gr.Error`` so Gradio shows the message to the
    user. Returning a string to an Image output would instead be treated as
    a file path and fail with an unrelated error.

    Raises:
        gr.Error: if the model did not load, no image was given, or
            inference / post-processing failed.
    """
    if RELEM_MODEL is None:
        raise gr.Error("Error: Model failed to load. Check logs for details.")
    if input_image is None:
        raise gr.Error("Please upload an image first.")

    # mmseg's inference pipeline reads the image from a file path. A unique
    # per-request file keeps concurrent requests from overwriting each
    # other's input, unlike a single fixed /tmp path.
    fd, temp_path = tempfile.mkstemp(suffix=".png")
    os.close(fd)
    try:
        input_image.convert("RGB").save(temp_path)
        result = inference_segmentor(RELEM_MODEL, temp_path)
        # The first entry is treated as this image's map of class IDs.
        seg_mask_array = np.asarray(result[0])
        # NOTE(review): assumes the model may expose a PALETTE attribute of RGB
        # triples; when absent, the plain grayscale ID map is returned.
        palette = getattr(RELEM_MODEL, "PALETTE", None)
        return _colorize_mask(seg_mask_array, palette)
    except Exception as e:
        raise gr.Error(f"Inference failed: {e}") from e
    finally:
        with contextlib.suppress(OSError):
            os.remove(temp_path)


# --- 5. GRADIO INTERFACE ---
demo = gr.Interface(
    fn=segment_food,
    inputs=gr.Image(type="pil", label="Upload Food Image"),
    outputs=gr.Image(type="pil", label="ReLeM Segmentation Mask"),
    title="ReLeM (FoodSeg103) Segmentation Demo",
    description="Custom deployment of the ReLeM PyTorch model. **NOTE:** Model loading requires the full code/config structure from the GitHub repo.",
    allow_flagging="never",
)
demo.launch()