Spaces:
Runtime error
Runtime error
| # -*- coding: utf-8 -*- | |
| import os | |
| import base64 | |
| import io | |
| import logging # For better logging | |
| # Import specific handlers and formatter | |
| from logging.handlers import RotatingFileHandler | |
| import traceback # For detailed exception logging | |
| from flask import Flask, request, jsonify, send_from_directory | |
| from flask_cors import CORS # To handle Cross-Origin requests from your frontend | |
| import torch | |
| import cv2 | |
| import numpy as np | |
| import yaml | |
| from torchvision import transforms | |
| from transformers import SegformerForSemanticSegmentation | |
| from omegaconf import OmegaConf # Import OmegaConf itself | |
| import torch.nn.functional as F | |
| from werkzeug.utils import secure_filename # For safer filenames | |
# --- Configuration ---
# Every path is resolved against the directory containing this script so the app
# behaves the same no matter what the process's current working directory is.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))  # Directory where this script lives
# >>> Point this to your actual config file <<<
CONFIG_PATH = os.path.join(BASE_DIR, "config", "config.yaml")  # config/ subdirectory next to this script
# >>> Point this to your actual checkpoint file <<<
# Anchored to BASE_DIR: a bare relative name would be looked up in the CWD instead.
CHECKPOINT_PATH = os.path.join(BASE_DIR, "ckpt_000-vloss_0.4685_vf1_0.6469.ckpt")
UPLOAD_FOLDER = os.path.join(BASE_DIR, "uploads")  # Original uploaded images
RESULT_FOLDER = os.path.join(BASE_DIR, "results")  # Generated overlay images
LOG_FILE_PATH = os.path.join(BASE_DIR, "flask_app.log")  # Rotating application log
ALLOWED_EXTENSIONS = {"png", "jpg", "jpeg", "bmp", "tif", "tiff"}  # Lower-case, without the dot
# --- Logging Setup ---
# Configure the root logger to write to the console and to a size-rotated log file.
logger = logging.getLogger()
# Detach AND close any handlers already on the root logger. Merely clearing the list
# (as before) left a re-imported module with a still-open handle on the old log file.
for _stale_handler in list(logger.handlers):
    logger.removeHandler(_stale_handler)
    _stale_handler.close()
log_formatter = logging.Formatter('%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d - %(message)s')
# Console output
console_handler = logging.StreamHandler()
console_handler.setFormatter(log_formatter)
# File output: rotate at 5 MiB, keeping 3 backups
file_handler = RotatingFileHandler(LOG_FILE_PATH, maxBytes=5 * 1024 * 1024, backupCount=3)
file_handler.setFormatter(log_formatter)
logger.setLevel(logging.INFO)  # Minimum level emitted by the root logger
logger.addHandler(console_handler)
logger.addHandler(file_handler)
# --- Ensure upload and result directories exist ---
# Both folders are written to at request time, so start-up aborts if they cannot be created.
try:
    os.makedirs(UPLOAD_FOLDER, exist_ok=True)
    os.makedirs(RESULT_FOLDER, exist_ok=True)
    logger.info(f"Ensured directories exist: {UPLOAD_FOLDER}, {RESULT_FOLDER}")
except OSError as e:
    logger.error(f"Error creating directories: {e}")
    # `exit()` is injected by the `site` module for interactive use and is not
    # guaranteed to exist in programs (e.g. under `python -S`); SystemExit always is.
    raise SystemExit(1)
# --- Load Config ---
# Parse the YAML config once at import time; model construction and inference read `config`.
config = None
try:
    # OmegaConf.load reads the file directly; no OmegaConf.create() needed.
    config = OmegaConf.load(CONFIG_PATH)
    logger.info(f"Configuration loaded successfully from: {CONFIG_PATH}")
    # Touch the keys the model needs so a malformed config fails here, at start-up.
    logger.info(f"Config check: num_classes={config.data.num_classes}, model_name={config.training.model_name}")
except FileNotFoundError:
    logger.error(f"Configuration file not found: {CONFIG_PATH}")
    raise SystemExit(1)  # not `exit()`: that helper comes from `site` and may be absent
except Exception as e:  # Broader errors while loading/parsing or reading keys
    logger.error(f"Error loading or parsing configuration file '{CONFIG_PATH}': {e}")
    logger.error(traceback.format_exc())
    raise SystemExit(1)
# --- Model Definition ---
class InferenceModel(torch.nn.Module):
    """Wrap a Hugging Face SegFormer so that it returns logits at input resolution.

    The network is built from ``model_config.training.model_name`` with
    ``model_config.data.num_classes`` labels. ``forward`` upsamples the decoder
    logits back to the spatial size of its input.
    """

    def __init__(self, model_config):
        super().__init__()
        try:
            name = model_config.training.model_name
            n_labels = model_config.data.num_classes
            logger.info(f"Initializing SegformerForSemanticSegmentation with model='{name}', num_labels={n_labels}")
            # ignore_mismatched_sizes lets the pretrained classification head be
            # replaced when the fine-tuned label count differs from the hub checkpoint.
            self.model = SegformerForSemanticSegmentation.from_pretrained(
                name,
                num_labels=n_labels,
                ignore_mismatched_sizes=True,
            )
            logger.info("Segformer model part initialized.")
        except AttributeError as missing:
            logger.error(f"Config error during model init: Missing key? {missing}")
            logger.error(f"Check if 'training.model_name' and 'data.num_classes' exist in {CONFIG_PATH}")
            raise  # Propagate so the caller aborts
        except Exception as err:
            logger.error(f"Error initializing Segformer model from Hugging Face: {err}")
            logger.error(traceback.format_exc())
            raise  # Propagate so the caller aborts

    def forward(self, x):
        """Return per-pixel class logits resized to the last two dims (H, W) of ``x``."""
        target_hw = x.shape[-2:]
        raw_logits = self.model(pixel_values=x, return_dict=True).logits
        return F.interpolate(raw_logits, size=target_hw, mode="bilinear", align_corners=False)
# --- Utility Functions ---
def num_to_rgb(num_arr, color_map_dict):
    """Convert an integer label mask into an RGB colour image.

    Args:
        num_arr: Label mask. Either a 2-D ``(H, W)`` array, or one with extra
            singleton axes in front and/or behind (e.g. ``(1, H, W)`` or ``(H, W, 1)``).
        color_map_dict: Plain ``dict`` mapping ``int`` label -> ``[R, G, B]`` (0-255).

    Returns:
        ``float32`` array of shape ``(H, W, 3)`` scaled to ``[0, 1]``. Pixels whose
        label has no valid colour entry stay black.

    Raises:
        ValueError: if ``num_arr`` cannot be reduced to a 2-D mask.
    """
    labels = np.asarray(num_arr)
    # Drop singleton batch/channel axes only. Unlike np.squeeze, this never collapses
    # a genuine spatial axis of length 1 (a 1xW mask stays 2-D).
    while labels.ndim > 2 and labels.shape[0] == 1:
        labels = labels[0]
    while labels.ndim > 2 and labels.shape[-1] == 1:
        labels = labels[..., 0]
    if labels.ndim != 2:
        raise ValueError(f"num_to_rgb expects a 2-D label mask, got shape {np.shape(num_arr)}")
    # Size the output from the reduced mask so the boolean indexing below lines up.
    # Sizing it from num_arr, as before, broke for (1, H, W) inputs.
    output = np.zeros(labels.shape + (3,), dtype=np.uint8)
    if not isinstance(color_map_dict, dict):
        logger.error(f"Invalid color_map provided to num_to_rgb: {type(color_map_dict)}. Expected dict.")
        return np.float32(output) / 255.0  # Black float image
    for k in np.unique(labels):
        label_key = int(k)  # numpy scalar -> int so it matches the dict's int keys
        if label_key not in color_map_dict:
            if label_key != 0:  # 0 is commonly background and may be absent from the map
                logger.warning(f"Label Key {label_key} found in mask but not in provided color map.")
            continue  # Leave these pixels black
        color = color_map_dict[label_key]
        if isinstance(color, (list, tuple)) and len(color) == 3:
            output[labels == k] = color
        else:
            logger.warning(f"Invalid color format for label {label_key} in color map: {color}. Skipping.")
    return np.float32(output) / 255.0  # float32 RGB in [0, 1]
def denormalize(tensor, mean, std):
    """Undo a per-channel ``Normalize`` on one image and return it channels-last.

    Args:
        tensor: Normalised image tensor of shape ``(1, C, H, W)``. A ``(C, H, W)``
            tensor is also accepted (with logged messages) and gets a batch axis.
        mean: Per-channel means as a list/tuple of length ``C``.
        std: Per-channel standard deviations as a list/tuple of length ``C``.

    Returns:
        A numpy array ``(H, W, C)`` clamped to ``[0, 1]`` (dtype follows the input
        tensor), or ``None`` when the arguments are invalid; the reason is logged.
    """
    stats_are_sequences = isinstance(mean, (list, tuple)) and isinstance(std, (list, tuple))
    if not stats_are_sequences:
        logger.error(f"Mean ({type(mean)}) or std ({type(std)}) are not lists/tuples in denormalize.")
        return None
    rank = tensor.dim()
    if rank != 4:
        logger.error(f"Unexpected tensor dimension {rank} in denormalize. Expected 4 (BCHW).")
        if rank != 3:
            return None  # Only CHW can be promoted to BCHW
        logger.warning("Denormalize received 3D tensor, adding batch dimension.")
        tensor = tensor.unsqueeze(0)
    channels = tensor.shape[1]
    if channels != len(mean) or channels != len(std):
        logger.error(f"Mean/std length ({len(mean)}/{len(std)}) mismatch with tensor channels ({channels})")
        return None
    # Operate on a CPU copy so the caller's (possibly GPU-resident) tensor is untouched.
    restored = tensor.clone().cpu()
    for ch, (m, s) in enumerate(zip(mean, std)):
        restored[:, ch, :, :] = restored[:, ch, :, :] * s + m
    # One image per call: drop the batch axis, clamp, and move channels last.
    channels_last = torch.clamp(restored.squeeze(0), 0, 1).permute(1, 2, 0)
    return channels_last.numpy()
# --- Load Model (Corrected Version) ---
def _extract_state_dict(checkpoint):
    """Return the weights mapping stored in a dict ``checkpoint``.

    Uses ``checkpoint["state_dict"]`` when that key is present; otherwise the dict
    itself is treated as the state dict.
    """
    if "state_dict" in checkpoint:
        logger.info("Using 'state_dict' key from checkpoint.")
        return checkpoint["state_dict"]
    logger.info("Using checkpoint dictionary directly as state_dict (no 'state_dict' key found).")
    return checkpoint


def _align_key_prefix(state_dict, target_key_order, prefix="model."):
    """Add or strip ``prefix`` on checkpoint keys so they match the model's keys.

    The direction is decided from the FIRST key, in insertion order, of the
    checkpoint and of the model. Dict order is deterministic, whereas the previous
    ``next(iter(set(...)))`` picked an arbitrary key that changed between runs with
    string-hash randomisation. Keys whose rewritten form the model does not expect
    are kept unchanged (and logged).

    Args:
        state_dict: Checkpoint weights, key -> tensor.
        target_key_order: The model's own state-dict keys, in registration order.
        prefix: Prefix to add or strip.

    Returns:
        A dict ready for ``load_state_dict`` (possibly ``state_dict`` itself).
    """
    target_keys = set(target_key_order)
    first_loaded_key = next(iter(state_dict), None)
    first_target_key = next(iter(target_key_order), None)
    both_known = bool(first_loaded_key) and bool(first_target_key)

    if both_known and not first_loaded_key.startswith(prefix) and first_target_key.startswith(prefix):
        logger.warning(f"Checkpoint keys missing '{prefix}' prefix. Attempting to add it.")
        aligned, kept = {}, []
        for key, value in state_dict.items():
            candidate = f"{prefix}{key}"
            if candidate in target_keys:
                aligned[candidate] = value
            else:
                kept.append(key)
                aligned[key] = value  # Keep original if the prefixed version is not wanted
        if kept:
            logger.warning(f"Keys kept without prefix (target doesn't expect): {kept}")
        logger.info("Finished attempting prefix addition.")
        return aligned

    if both_known and first_loaded_key.startswith(prefix) and not first_target_key.startswith(prefix):
        logger.warning(f"Checkpoint keys HAVE '{prefix}' prefix, but target model DOES NOT. Attempting to remove it.")
        aligned, kept = {}, []
        for key, value in state_dict.items():
            if key.startswith(prefix):
                stripped = key[len(prefix):]
                if stripped in target_keys:
                    aligned[stripped] = value
                    continue
                kept.append(key)  # Keep original if the stripped version is not wanted
            aligned[key] = value
        if kept:
            logger.warning(f"Keys kept with prefix (target doesn't expect stripped): {kept}")
        logger.info("Finished attempting prefix removal.")
        return aligned

    logger.info("State dict keys seem to have correct prefix structure (or other mismatch). Using as is.")
    return state_dict


def _check_incompatible_keys(missing_keys, unexpected_keys):
    """Judge the result of ``load_state_dict(strict=False)``; return ``(ok, messages)``.

    Any missing key is fatal. Unexpected keys are tolerated only when every one
    ends with ``num_batches_tracked``.
    """
    messages = []
    ok = True
    if missing_keys:
        messages.append(f"MISSING keys in checkpoint: {missing_keys}")
        logger.error("CRITICAL FAILURE: Model is missing required keys.")
        ok = False
    if unexpected_keys:
        messages.append(f"UNEXPECTED keys in checkpoint (exist in file but not in model): {unexpected_keys}")
        acceptable = [k for k in unexpected_keys if k.endswith('num_batches_tracked')]
        unacceptable = [k for k in unexpected_keys if not k.endswith('num_batches_tracked')]
        if unacceptable:
            logger.error(f"CRITICAL FAILURE: Model received unacceptable unexpected keys: {unacceptable}")
            ok = False
        elif acceptable:
            logger.warning(f"Ignoring acceptable unexpected keys: {acceptable}")
    return ok, messages


def load_trained_model(checkpoint_path, model_config, map_location=None):
    """Build an ``InferenceModel`` and load trained weights from ``checkpoint_path``.

    Args:
        checkpoint_path: Path to a ``torch.save``-d checkpoint. It may be a dict holding
            weights under ``"state_dict"``, a bare state dict, or some other object
            whose loading is attempted directly.
        model_config: Config passed to ``InferenceModel``.
        map_location: ``torch.load`` target device. Defaults to the module-level
            ``device``, which must be assigned before this is called.

    Returns:
        The model in eval mode, or ``None`` on any failure (details are logged).
    """
    try:
        model_instance = InferenceModel(model_config)  # Create model structure
        logger.info(f"Attempting to load checkpoint from: {checkpoint_path}")
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(f"Checkpoint file not found at specified path: {checkpoint_path}")
        if map_location is None:
            map_location = torch.device(device)
        # NOTE(security): torch.load unpickles the file and can run arbitrary code
        # unless weights_only=True; only load trusted checkpoints.
        checkpoint = torch.load(checkpoint_path, map_location=map_location)
        logger.info(f"Checkpoint loaded into memory. Type: {type(checkpoint)}")

        if not isinstance(checkpoint, dict):
            logger.warning(f"Checkpoint is not a dictionary. Attempting to load directly into model (less common). Type was: {type(checkpoint)}")
            try:
                model_instance.load_state_dict(checkpoint)
                logger.info("Loaded state_dict directly from checkpoint object.")
                model_instance.eval()
                return model_instance
            except Exception as e:
                logger.error(f"Failed to load state_dict directly from checkpoint object: {e}")
                return None

        state_dict = _extract_state_dict(checkpoint)
        if not state_dict:
            logger.warning("Loaded state_dict is empty!")
            return None

        # state_dict() is ordered by registration, so its first key is stable.
        corrected_state_dict = _align_key_prefix(state_dict, list(model_instance.state_dict()))

        logger.info("Attempting to load state_dict with strict=False for checking...")
        missing_keys, unexpected_keys = model_instance.load_state_dict(corrected_state_dict, strict=False)
        ok, issues = _check_incompatible_keys(missing_keys, unexpected_keys)
        if not ok:
            logger.error(f"State dict loading failed. Issues: {'; '.join(issues)}")
            return None
        logger.info(f"State dictionary loaded successfully. Issues (if any): {issues if issues else 'None'}")
        model_instance.eval()  # Inference mode (dropout off, norm layers use running stats)
        logger.info(f"Model loading process complete for {checkpoint_path}")
        return model_instance
    except FileNotFoundError as fnf_error:
        logger.error(f"{fnf_error}")
        return None
    except Exception as e:
        logger.error(f"Unexpected error during model loading: {e}")
        logger.error(traceback.format_exc())
        return None
# --- Determine device & Load Model Globally ---
# `device` must be bound before load_trained_model() runs: the loader reads this
# module-level name as torch.load's map_location.
device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Using device: {device}")
model = load_trained_model(CHECKPOINT_PATH, config)
if model is None:
    logger.critical("CRITICAL: Failed to load model. Application cannot continue.")
    raise SystemExit(1)  # not `exit()`: that helper comes from `site` and may be absent
model.to(device)  # Move parameters to the inference device
# --- Inference Pipeline (Corrected Config Handling) ---
def _read_inference_settings(model_config):
    """Validate ``model_config`` and convert the fields inference needs to plain Python.

    Returns:
        A tuple ``(img_size, mean, std, id2color_map, num_classes)``:
        ``img_size`` is an ``(H, W)`` tuple, ``mean``/``std`` are 3-element lists,
        ``id2color_map`` is a dict keyed by ``int`` label, and ``num_classes`` is an
        ``int``. Returns ``None`` if a required field is missing or malformed; the
        reason is logged.
    """
    for key in ('image_size', 'mean', 'std', 'num_classes'):
        if not OmegaConf.select(model_config, f'data.{key}', default=None):
            logger.error(f"Config missing required data field: data.{key}")
            return None
    for path in ('id2color', 'training.model_name'):
        if not OmegaConf.select(model_config, path, default=None):
            logger.error(f"Config missing required field: {path}")
            return None
    try:
        raw_size = model_config.data.image_size
        if isinstance(raw_size, int):
            # A scalar means a square input. Passed on as an int, torchvision's
            # Resize would treat it as the *shorter edge* and keep the aspect ratio.
            img_size = (raw_size, raw_size)
        else:
            # resolve=True expands ${...} interpolations before conversion.
            img_size = tuple(OmegaConf.to_container(raw_size, resolve=True))
        mean = list(OmegaConf.to_container(model_config.data.mean, resolve=True))
        std = list(OmegaConf.to_container(model_config.data.std, resolve=True))
        # num_to_rgb looks colours up by int label, so normalise the keys to int.
        id2color_map = {int(k): v for k, v in OmegaConf.to_container(model_config.id2color, resolve=True).items()}
        num_classes = int(model_config.data.num_classes)
        logger.debug(f"Converted config values: size={img_size}, mean={mean}, std={std}, id2color keys={list(id2color_map.keys())}, num_classes={num_classes}")
        if not isinstance(mean, list) or not isinstance(std, list) or not isinstance(id2color_map, dict):
            raise TypeError("Config values did not convert to list/dict.")
        if len(mean) != 3 or len(std) != 3:  # Three colour channels
            raise ValueError("Incorrect mean/std length. Expected 3.")
        if len(img_size) != 2:
            raise ValueError("Incorrect image_size length. Expected 2 (H, W).")
    except Exception as e:
        logger.error(f"Error processing/converting configuration values: {e}")
        logger.error(traceback.format_exc())
        return None
    return img_size, mean, std, id2color_map, num_classes


def run_inference_on_bytes(image_bytes, inference_model, model_config, device):
    """Segment an encoded image and build its visualisations.

    Args:
        image_bytes: Encoded image file contents (any format cv2.imdecode accepts).
        inference_model: Callable that maps the ``(1, 3, H, W)`` input tensor to
            ``(1, num_classes, H, W)`` logits.
        model_config: Config with ``data.{image_size,mean,std,num_classes}``,
            ``id2color`` and ``training.model_name``.
        device: Device the input tensor is moved to.

    Returns:
        ``(denorm_img, color_mask, overlay)``: float RGB HWC arrays in ``[0, 1]`` at
        the configured ``image_size``. Returns ``(None, None, None)`` on any failure.
    """
    try:
        img = cv2.imdecode(np.frombuffer(image_bytes, np.uint8), cv2.IMREAD_COLOR)
        if img is None:
            logger.error("Failed cv2.imdecode.")
            return None, None, None
        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV decodes to BGR
        logger.debug("Image decoded and converted to RGB.")

        # --- Preprocessing ---
        settings = _read_inference_settings(model_config)
        if settings is None:
            return None, None, None
        img_size, mean, std, id2color_map, num_classes = settings

        transform = transforms.Compose([
            transforms.ToTensor(),  # HWC uint8 [0,255] -> CHW float [0,1]
            transforms.Resize(img_size, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.Normalize(mean=mean, std=std),
        ])
        logger.debug(f"Image transform applied for size {img_size}.")
        input_tensor = transform(img_rgb).unsqueeze(0).to(device)  # Add batch dim (B=1)
        logger.debug(f"Input tensor created with shape: {input_tensor.shape}")

        # --- Run Prediction ---
        with torch.no_grad():
            logits = inference_model(input_tensor)
        logger.debug(f"Logits received with shape: {logits.shape}")
        if logits.dim() != 4 or logits.shape[1] != num_classes:
            logger.error(f"Unexpected final logits shape or class number: {logits.shape}. Expected B x {num_classes} x H x W.")
            return None, None, None
        # uint8 holds labels 0-255 only; a larger label set would silently wrap around.
        mask_dtype = np.uint8 if num_classes <= 256 else np.int32
        pred_mask = logits.argmax(1).squeeze(0).cpu().numpy().astype(mask_dtype)  # (H, W)
        logger.debug(f"Prediction mask generated with shape: {pred_mask.shape}")

        # --- Post-processing ---
        color_mask = num_to_rgb(pred_mask, id2color_map)
        if color_mask is None:
            logger.error("num_to_rgb failed.")
            return None, None, None
        logger.debug("Color mask generated.")
        # Denormalise the resized model input so the overlay aligns with the mask.
        denorm_img = denormalize(input_tensor, mean, std)
        if denorm_img is None:
            logger.error("denormalize failed.")
            return None, None, None
        logger.debug("Input tensor denormalized for overlay.")

        # --- Create Overlay ---
        if denorm_img.shape[:2] != color_mask.shape[:2]:
            logger.warning(f"Denorm img shape {denorm_img.shape[:2]} != Color mask shape {color_mask.shape[:2]}. Resizing color mask using INTER_NEAREST.")
            # INTER_NEAREST keeps mask colours crisp (no blending across labels).
            color_mask = cv2.resize(color_mask, (denorm_img.shape[1], denorm_img.shape[0]), interpolation=cv2.INTER_NEAREST)
        overlay = cv2.addWeighted(denorm_img, 0.7, color_mask, 0.3, 0)  # 70% image, 30% mask
        logger.debug("Overlay created using cv2.addWeighted.")
        return denorm_img, color_mask, overlay
    except Exception as e:
        logger.error(f"Exception during inference pipeline for image: {e}")
        logger.error(traceback.format_exc())
        return None, None, None
# --- Flask App ---
app = Flask(__name__)
# Default flask_cors settings: cross-origin requests are accepted from any origin on
# every route, so a separately hosted frontend can call the API and load result images.
# To narrow this, pass resources={r"/api/*": {...}, r"/Result/*": {...}} to CORS().
CORS(app)
logger.info("Flask app created and CORS enabled.")
def allowed_file(filename):
    """Return True if ``filename`` has an extension listed in ALLOWED_EXTENSIONS.

    The comparison is case-insensitive and uses the text after the last dot.
    A name without any dot is rejected.
    """
    _stem, dot, extension = filename.rpartition(".")
    return bool(dot) and extension.lower() in ALLOWED_EXTENSIONS
# --- API Endpoints ---
# NOTE(review): the route was unregistered in this file; path taken from the log prefix below.
@app.route("/api/analyze", methods=["POST"])
def analyze_image():
    """Receive a Base64 image, run segmentation, and save the original and the overlay.

    Expects a JSON body ``{"image": "<base64 or data URL>", "filename": "<name.ext>"}``.
    The decoded original is written to UPLOAD_FOLDER under its sanitised name. The
    overlay is written to RESULT_FOLDER as ``analyzed_<name><ext>``, the same name the
    toggle endpoint looks up.

    Returns:
        JSON ``{"success", "message", ...}``. On success it also has ``paths`` and
        ``overlay_filename``, with status 200. Bad input gives 400; server-side
        failures give 500.
    """
    endpoint_log_prefix = "[POST /api/analyze]"
    logger.info(f"{endpoint_log_prefix} Received request.")

    # --- Basic Checks ---
    if model is None:
        logger.error(f"{endpoint_log_prefix} Model not loaded.")
        return jsonify({"success": False, "message": "Model not loaded"}), 500
    if not request.is_json:
        logger.warning(f"{endpoint_log_prefix} Not JSON.")
        return jsonify({"success": False, "message": "Request must be JSON"}), 400
    # silent=True: malformed JSON yields None and gets the JSON 400 reply below,
    # instead of Werkzeug aborting with its own HTML error page.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or 'image' not in data or 'filename' not in data:
        # Log only the keys: the body carries the whole Base64 image.
        received = sorted(data) if isinstance(data, dict) else type(data).__name__
        logger.warning(f"{endpoint_log_prefix} Missing image/filename in JSON body. Keys received: {received}")
        return jsonify({"success": False, "message": "Missing 'image' (base64) or 'filename' in JSON body"}), 400
    base64_image_data = data['image']
    original_filename = data['filename']
    if not isinstance(base64_image_data, str) or not isinstance(original_filename, str):
        logger.warning(f"{endpoint_log_prefix} 'image' and 'filename' must both be strings.")
        return jsonify({"success": False, "message": "'image' and 'filename' must be strings"}), 400
    logger.info(f"{endpoint_log_prefix} Original filename from request: '{original_filename}'")
    safe_original_filename = secure_filename(original_filename)  # Blocks path traversal
    if not safe_original_filename or not allowed_file(safe_original_filename):
        logger.warning(f"{endpoint_log_prefix} Invalid/disallowed filename after sanitization: '{safe_original_filename}' from '{original_filename}'")
        return jsonify({"success": False, "message": "Invalid or disallowed filename/extension"}), 400
    logger.info(f"{endpoint_log_prefix} Sanitized filename for saving/processing: '{safe_original_filename}'")

    # --- Decode Base64 ---
    # Accept raw Base64 or a data URL ("data:image/png;base64,<payload>").
    _header, comma, payload = base64_image_data.partition(',')
    encoded = payload if comma else base64_image_data
    try:
        image_bytes = base64.b64decode(encoded)
    except ValueError as e:  # binascii.Error (bad padding/characters) subclasses ValueError
        logger.error(f"{endpoint_log_prefix} Invalid Base64 data received: {e}")
        return jsonify({"success": False, "message": "Invalid Base64 image data received"}), 400
    logger.info(f"{endpoint_log_prefix} Base64 image decoded ({len(image_bytes)} bytes).")

    try:
        # --- Save Original Image ---
        original_path = os.path.join(UPLOAD_FOLDER, safe_original_filename)
        try:
            with open(original_path, "wb") as f:
                f.write(image_bytes)
            logger.info(f"{endpoint_log_prefix} Original image saved to: '{original_path}'")
        except Exception as e:
            logger.error(f"{endpoint_log_prefix} Failed to save original image to '{original_path}': {e}")
            return jsonify({"success": False, "message": "Failed to save uploaded image on server"}), 500

        # --- Run Inference ---
        logger.info(f"{endpoint_log_prefix} Starting inference for '{safe_original_filename}'...")
        _denorm_img, _color_mask, overlay = run_inference_on_bytes(image_bytes, model, config, device)
        if overlay is None:
            logger.error(f"{endpoint_log_prefix} Inference pipeline returned None for '{safe_original_filename}'.")
            return jsonify({"success": False, "message": "Inference process failed on server"}), 500
        logger.info(f"{endpoint_log_prefix} Inference completed successfully for '{safe_original_filename}'.")

        # --- Save Overlay Image ---
        # The name must stay in sync with get_analysis_path's lookup.
        name_part, ext = os.path.splitext(safe_original_filename)
        overlay_filename = f"analyzed_{name_part}{ext}"
        overlay_path = os.path.join(RESULT_FOLDER, overlay_filename)
        logger.info(f"{endpoint_log_prefix} Determined overlay filename: '{overlay_filename}' -> path: '{overlay_path}'")
        try:
            # overlay is float RGB in [0, 1]; cv2.imwrite expects uint8 BGR.
            overlay_to_save_uint8 = (overlay * 255).astype(np.uint8)
            overlay_to_save_bgr = cv2.cvtColor(overlay_to_save_uint8, cv2.COLOR_RGB2BGR)
            if not cv2.imwrite(overlay_path, overlay_to_save_bgr):
                raise IOError(f"cv2.imwrite failed to save the overlay image to {overlay_path}")
            logger.info(f"{endpoint_log_prefix} Overlay image saved successfully to: '{overlay_path}'")
        except Exception as e:
            logger.error(f"{endpoint_log_prefix} Failed to convert or save overlay image to '{overlay_path}': {e}")
            logger.error(traceback.format_exc())
            return jsonify({"success": False, "message": "Failed to save analysis result image"}), 500

        # --- Success Response ---
        logger.info(f"{endpoint_log_prefix} Analysis successful for '{safe_original_filename}'. Returning success.")
        return jsonify({
            "success": True,
            "message": "Analysis complete",
            "paths": {"original": os.path.relpath(original_path, BASE_DIR), "overlay": os.path.relpath(overlay_path, BASE_DIR)},
            "overlay_filename": overlay_filename,  # Exact filename saved in RESULT_FOLDER
        }), 200
    except Exception as e:
        logger.error(f"{endpoint_log_prefix} Unexpected error during analysis request processing: {e}")
        logger.error(traceback.format_exc())
        return jsonify({"success": False, "message": "Internal server error during analysis processing"}), 500
# NOTE(review): the route was unregistered in this file; path taken from the log prefix below.
@app.route("/api/toggle-image", methods=["GET"])
def get_analysis_path():
    """Report whether the overlay for ``?filename=<original name>`` has been generated.

    Returns:
        ``{"filepath": "analyzed_<name><ext>"}`` with status 200 when that file exists in
        RESULT_FOLDER; the client builds the ``/Result/`` URL from it. Status 400 for a
        missing or invalid filename, and 404 when no result has been saved.
    """
    endpoint_log_prefix = "[GET /api/toggle-image]"
    logger.info(f"{endpoint_log_prefix} Received request.")
    logger.info(f"{endpoint_log_prefix} Full request URL: {request.url}")
    logger.info(f"{endpoint_log_prefix} Request Query Args: {request.args}")
    original_filename = request.args.get('filename')
    if not original_filename:
        logger.warning(f"{endpoint_log_prefix} Missing 'filename' query parameter.")
        return jsonify({"message": "Missing 'filename' query parameter"}), 400
    logger.info(f"{endpoint_log_prefix} Original filename received from query: '{original_filename}'")
    safe_original_filename = secure_filename(original_filename)  # Same sanitising as /api/analyze
    if not safe_original_filename:
        logger.warning(f"{endpoint_log_prefix} Invalid filename after sanitization: '{safe_original_filename}' from '{original_filename}'")
        return jsonify({"message": "Invalid filename format"}), 400
    logger.info(f"{endpoint_log_prefix} Sanitized filename for lookup: '{safe_original_filename}'")

    # --- Construct Expected Overlay Path (MUST match analyze_image's naming) ---
    name_part, ext = os.path.splitext(safe_original_filename)
    expected_overlay_filename = f"analyzed_{name_part}{ext}"
    expected_overlay_path = os.path.join(RESULT_FOLDER, expected_overlay_filename)
    logger.info(f"{endpoint_log_prefix} Expecting overlay file at: '{expected_overlay_path}'")

    # isfile (not exists): a directory with that name is not an analysis result.
    if os.path.isfile(expected_overlay_path):
        logger.info(f"{endpoint_log_prefix} Found analysis result file: '{expected_overlay_filename}'")
        return jsonify({"filepath": expected_overlay_filename}), 200
    logger.warning(f"{endpoint_log_prefix} Analysis result file NOT FOUND at checked path: '{expected_overlay_path}'")
    return jsonify({"message": f"Analysis result not found for '{original_filename}'"}), 404
# NOTE(review): the route was unregistered in this file; path taken from the log prefix / CORS note.
@app.route("/Result/<filename>")
def serve_result_image(filename):
    """Serve an overlay image stored in RESULT_FOLDER.

    Args:
        filename: Requested file name from the URL; it is sanitised before use.

    Returns:
        The file response. Otherwise a JSON error: 400 for an unusable name, 404 when
        the file is absent, 500 for other failures.
    """
    # Flask's send_from_directory signals a missing file with werkzeug's NotFound,
    # not FileNotFoundError; catching only the latter turned 404s into 500s.
    from werkzeug.exceptions import NotFound

    endpoint_log_prefix = "[GET /Result]"
    safe_filename = secure_filename(filename)
    if not safe_filename:
        logger.warning(f"{endpoint_log_prefix} Requested filename '{filename}' is empty after sanitization.")
        return jsonify({"message": "Invalid filename format"}), 400
    if safe_filename != filename:
        logger.warning(f"{endpoint_log_prefix} Requested filename '{filename}' was sanitized to '{safe_filename}'. Serving sanitized version.")
    logger.info(f"{endpoint_log_prefix} Attempting to serve file: '{safe_filename}' from directory: '{RESULT_FOLDER}'")
    try:
        # send_from_directory refuses paths escaping RESULT_FOLDER; inline display.
        return send_from_directory(RESULT_FOLDER, safe_filename, as_attachment=False)
    except (NotFound, FileNotFoundError):
        logger.error(f"{endpoint_log_prefix} Requested file not found in result folder: '{safe_filename}'")
        return jsonify({"message": "Requested analysis image not found"}), 404
    except Exception as e:  # e.g. permission problems
        logger.error(f"{endpoint_log_prefix} Error serving file '{safe_filename}': {e}")
        logger.error(traceback.format_exc())
        return jsonify({"message": "Error serving analysis image"}), 500
# --- Main Execution ---
if __name__ == '__main__':
    if model is not None:
        logger.info("Model loaded successfully. Starting Flask development server...")
        # debug=True together with host 0.0.0.0 exposes Werkzeug's interactive debugger
        # (arbitrary code execution) to the network. Its reloader also restarts the
        # script in a child process, which re-runs the module-level model load.
        # Debug mode is therefore opt-in via FLASK_DEBUG=1, for local development only.
        debug_mode = os.environ.get("FLASK_DEBUG", "0").strip().lower() in ("1", "true", "yes")
        port = int(os.environ.get("PORT", "7860"))  # 7860 stays the default port
        app.run(host='0.0.0.0', port=port, debug=debug_mode)  # 0.0.0.0: reachable from the network
    else:
        logger.critical("APPLICATION FAILED TO START: MODEL COULD NOT BE LOADED.")
        raise SystemExit(1)  # not `exit()`: that helper comes from `site` and may be absent