# aidentaldocker/app.py - author: lukiod, commit 8d7600a ("tt")
# -*- coding: utf-8 -*-
import os
import base64
import io
import logging # For better logging
# Import specific handlers and formatter
from logging.handlers import RotatingFileHandler
import traceback # For detailed exception logging
from flask import Flask, request, jsonify, send_from_directory
from flask_cors import CORS # To handle Cross-Origin requests from your frontend
import torch
import cv2
import numpy as np
import yaml
from torchvision import transforms
from transformers import SegformerForSemanticSegmentation
from omegaconf import OmegaConf # Import OmegaConf itself
import torch.nn.functional as F
from werkzeug.utils import secure_filename # For safer filenames
# --- Configuration ---
# Every path below is anchored at the directory containing this script, so the
# app finds its files regardless of the process's current working directory.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# >>> Point this to your actual config file <<<
# Expected at <BASE_DIR>/config/config.yaml.
CONFIG_PATH = os.path.join(BASE_DIR, "config", "config.yaml")
# >>> Point this to your actual checkpoint file <<<
# Anchored at BASE_DIR like the other paths so loading does not depend on the CWD.
CHECKPOINT_PATH = os.path.join(BASE_DIR, "ckpt_000-vloss_0.4685_vf1_0.6469.ckpt")
UPLOAD_FOLDER = os.path.join(BASE_DIR, 'uploads')  # decoded original uploads are written here
RESULT_FOLDER = os.path.join(BASE_DIR, 'results')  # analysis overlays are written here
LOG_FILE_PATH = os.path.join(BASE_DIR, 'flask_app.log')
# Lower-case file extensions accepted by the upload endpoint.
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'bmp', 'tif', 'tiff'}
# --- Logging Setup ---
log_formatter = logging.Formatter('%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d - %(message)s')

# The root logger: records from every module propagate here.
logger = logging.getLogger()

# Detach AND close handlers left over from an earlier execution of this module
# in the same process, so output is not duplicated. Merely clearing the list
# would leave the old RotatingFileHandler's file descriptor open.
for _stale_handler in list(logger.handlers):
    logger.removeHandler(_stale_handler)
    _stale_handler.close()

# Console output (StreamHandler's default stream, stderr).
console_handler = logging.StreamHandler()
console_handler.setFormatter(log_formatter)

# File output, rotated at 5 MiB with up to 3 backup files kept.
file_handler = RotatingFileHandler(LOG_FILE_PATH, maxBytes=5 * 1024 * 1024, backupCount=3)
file_handler.setFormatter(log_formatter)

logger.setLevel(logging.INFO)  # minimum level handled (INFO and above)
logger.addHandler(console_handler)
logger.addHandler(file_handler)
# --- Ensure upload and result directories exist ---
try:
    os.makedirs(UPLOAD_FOLDER, exist_ok=True)
    os.makedirs(RESULT_FOLDER, exist_ok=True)
except OSError as e:
    logger.error(f"Error creating directories: {e}")
    # `exit()` is a site-module helper meant for the interactive shell (it is
    # missing under `python -S`); raising SystemExit is the portable equivalent.
    raise SystemExit(1)
else:
    logger.info(f"Ensured directories exist: {UPLOAD_FOLDER}, {RESULT_FOLDER}")
# --- Load Config ---
# Parsed OmegaConf config; the process terminates if it cannot be loaded.
config = None
try:
    config = OmegaConf.load(CONFIG_PATH)
    logger.info(f"Configuration loaded successfully from: {CONFIG_PATH}")
    # Touch the keys the model needs so a broken config fails fast at startup.
    logger.info(f"Config check: num_classes={config.data.num_classes}, model_name={config.training.model_name}")
except FileNotFoundError:
    logger.error(f"Configuration file not found: {CONFIG_PATH}")
    raise SystemExit(1)  # portable replacement for the interactive-only exit()
except Exception as e:  # parse errors, missing keys, etc.
    logger.error(f"Error loading or parsing configuration file '{CONFIG_PATH}': {e}")
    logger.error(traceback.format_exc())
    raise SystemExit(1)
# --- Model Definition ---
class InferenceModel(torch.nn.Module):
    """Inference wrapper around a Hugging Face SegFormer segmentation network.

    The network is stored as ``self.model``, so every entry in this module's
    ``state_dict`` carries a ``model.`` prefix.

    Args:
        model_config: Config exposing ``training.model_name`` (passed to
            ``from_pretrained``) and ``data.num_classes`` (used as ``num_labels``).

    Raises:
        AttributeError: a required config field is missing (logged, re-raised).
        Exception: any failure inside ``from_pretrained`` (logged, re-raised).
    """

    def __init__(self, model_config):
        super().__init__()
        try:
            pretrained_id = model_config.training.model_name
            label_count = model_config.data.num_classes
            logger.info(
                f"Initializing SegformerForSemanticSegmentation with model='{pretrained_id}', num_labels={label_count}"
            )
            # ignore_mismatched_sizes lets the classification head be rebuilt
            # when num_labels differs from the pretrained head's size.
            self.model = SegformerForSemanticSegmentation.from_pretrained(
                pretrained_id,
                num_labels=label_count,
                ignore_mismatched_sizes=True,
            )
            logger.info("Segformer model part initialized.")
        except AttributeError as missing_key_err:
            logger.error(f"Config error during model init: Missing key? {missing_key_err}")
            logger.error(f"Check if 'training.model_name' and 'data.num_classes' exist in {CONFIG_PATH}")
            raise
        except Exception as init_err:
            logger.error(f"Error initializing Segformer model from Hugging Face: {init_err}")
            logger.error(traceback.format_exc())
            raise

    def forward(self, x):
        """Return class logits resized to the spatial size of ``x``.

        ``x`` is fed to SegFormer as ``pixel_values`` (presumably a normalized
        ``(B, 3, H, W)`` batch; TODO confirm against callers). The logits,
        which may come out at a lower resolution, are bilinearly resized to
        ``x.shape[-2:]``.
        """
        raw_logits = self.model(pixel_values=x, return_dict=True).logits
        target_hw = x.shape[-2:]
        return F.interpolate(raw_logits, size=target_hw, mode="bilinear", align_corners=False)
# --- Utility Functions ---
def num_to_rgb(num_arr, color_map_dict):
    """Convert an integer label mask into an RGB colour image.

    Args:
        num_arr: Array of integer class labels. Singleton axes are dropped
            until the mask is 2-D, so ``(H, W)``, ``(1, H, W)`` and
            ``(H, W, 1)`` inputs all work, and a 1-pixel-high/wide mask keeps
            both spatial axes.
        color_map_dict: Plain ``dict`` mapping ``int`` label -> ``[R, G, B]``
            (a list/tuple of three 0-255 values).

    Returns:
        ``float32`` array of shape ``mask.shape + (3,)`` with values in
        ``[0, 1]``. Pixels whose label is absent from the map, or whose colour
        entry is malformed, stay black. If ``color_map_dict`` is not a dict, an
        all-black image is returned.
    """
    labels = np.asarray(num_arr)
    # Squeeze one singleton axis at a time, stopping at 2-D. Sizing the output
    # from the squeezed mask keeps it consistent with the boolean index below,
    # which was not the case for (1, H, W) masks before.
    while labels.ndim > 2 and 1 in labels.shape:
        labels = np.squeeze(labels, axis=labels.shape.index(1))
    output = np.zeros(labels.shape + (3,), dtype=np.uint8)

    if not isinstance(color_map_dict, dict):
        logger.error(f"Invalid color_map provided to num_to_rgb: {type(color_map_dict)}. Expected dict.")
        return np.float32(output) / 255.0  # black float image

    for k in np.unique(labels):
        label_key = int(k)  # numpy scalar -> plain int for the dict lookup
        if label_key not in color_map_dict:
            if label_key != 0:  # 0 is often background and may be left unmapped
                logger.warning(f"Label Key {label_key} found in mask but not in provided color map.")
            continue
        color = color_map_dict[label_key]
        if isinstance(color, (list, tuple)) and len(color) == 3:
            output[labels == k] = color
        else:
            logger.warning(f"Invalid color format for label {label_key} in color map: {color}. Skipping.")
    return np.float32(output) / 255.0
def denormalize(tensor, mean, std):
    """Undo per-channel normalisation and return one image as an HWC array.

    Computes ``x * std + mean`` per channel, clamps to ``[0, 1]`` and converts
    the single image to an ``(H, W, C)`` NumPy array. The input tensor is not
    modified.

    Args:
        tensor: ``(1, C, H, W)`` tensor. A ``(C, H, W)`` tensor is accepted and
            given a batch axis (with a warning).
        mean: list/tuple of ``C`` per-channel means.
        std: list/tuple of ``C`` per-channel standard deviations.

    Returns:
        ``numpy.ndarray`` of shape ``(H, W, C)`` with values in ``[0, 1]``, or
        ``None`` if the arguments are invalid (the reason is logged).
    """
    if not isinstance(mean, (list, tuple)) or not isinstance(std, (list, tuple)):
        logger.error(f"Mean ({type(mean)}) or std ({type(std)}) are not lists/tuples in denormalize.")
        return None

    if tensor.dim() == 3:
        logger.warning("Denormalize received 3D tensor, adding batch dimension.")
        tensor = tensor.unsqueeze(0)
    elif tensor.dim() != 4:
        logger.error(f"Unexpected tensor dimension {tensor.dim()} in denormalize. Expected 4 (BCHW).")
        return None

    # Only one image can be returned; fail clearly instead of crashing in permute().
    if tensor.shape[0] != 1:
        logger.error(f"denormalize expects a single image (batch size 1), got batch size {tensor.shape[0]}.")
        return None

    num_channels = tensor.shape[1]
    if len(mean) != num_channels or len(std) != num_channels:
        logger.error(f"Mean/std length ({len(mean)}/{len(std)}) mismatch with tensor channels ({num_channels})")
        return None

    # detach() allows tensors that require grad to be converted to NumPy; all
    # arithmetic below is out-of-place, so the caller's tensor is untouched.
    image = tensor.detach().cpu()[0]  # (C, H, W)
    if not image.is_floating_point():
        image = image.float()
    mean_t = torch.tensor(mean, dtype=image.dtype).view(-1, 1, 1)
    std_t = torch.tensor(std, dtype=image.dtype).view(-1, 1, 1)
    restored = torch.clamp(image * std_t + mean_t, 0, 1)
    return restored.permute(1, 2, 0).numpy()  # (H, W, C)
# --- Load Model (Corrected Version) ---
def _extract_state_dict(checkpoint):
    """Pull the parameter mapping out of an object returned by ``torch.load``.

    Supports a dict holding a ``"state_dict"`` entry (e.g. PyTorch Lightning
    ``.ckpt`` files), a dict that *is* the state dict, and a whole pickled
    ``torch.nn.Module``. Returns ``None`` (after logging) for anything else.
    """
    if isinstance(checkpoint, dict):
        if "state_dict" in checkpoint:
            logger.info("Using 'state_dict' key from checkpoint.")
            return checkpoint["state_dict"]
        logger.info("Using checkpoint dictionary directly as state_dict (no 'state_dict' key found).")
        return checkpoint
    if isinstance(checkpoint, torch.nn.Module):
        # load_state_dict() needs a mapping, not a module, so take the pickled
        # module's own state dict and run it through the normal key alignment.
        logger.warning(f"Checkpoint is a pickled {type(checkpoint).__name__} module; using its state_dict().")
        return checkpoint.state_dict()
    logger.error(f"Unsupported checkpoint object type: {type(checkpoint)}")
    return None


def _align_state_dict_keys(state_dict, target_keys, prefix="model."):
    """Rename checkpoint keys onto the names the target model expects.

    Each key is resolved independently, so the result does not depend on set
    iteration order. A key is kept if ``target_keys`` contains it. Otherwise
    ``prefix`` is prepended, or stripped once, when that yields an expected
    name. Keys that match in no form are kept verbatim so the caller's
    unexpected-key check still reports them. A renamed key whose new name
    already exists verbatim in ``state_dict`` is dropped (with a warning)
    instead of overwriting that exact entry.
    """
    aligned = {}
    renamed_count = 0
    for key, value in state_dict.items():
        if key in target_keys:
            new_key = key
        elif prefix + key in target_keys:
            new_key = prefix + key
        elif key.startswith(prefix) and key[len(prefix):] in target_keys:
            new_key = key[len(prefix):]
        else:
            new_key = key
        if new_key != key:
            if new_key in state_dict:
                logger.warning(f"Dropping checkpoint key '{key}': its aligned name '{new_key}' is already present.")
                continue
            renamed_count += 1
        aligned[new_key] = value
    if renamed_count:
        logger.warning(f"Adjusted the '{prefix}' prefix on {renamed_count} checkpoint key(s) to match the model.")
    else:
        logger.info("State dict keys seem to have correct prefix structure (or other mismatch). Using as is.")
    return aligned


def _load_result_is_acceptable(missing_keys, unexpected_keys):
    """Decide whether a ``load_state_dict(strict=False)`` result is usable.

    Any missing key is fatal. Unexpected keys are tolerated only if they end in
    ``num_batches_tracked`` (a bookkeeping buffer registered by BatchNorm
    layers). Any other unexpected key is fatal. All findings are logged.
    """
    issues = []
    acceptable = True
    if missing_keys:
        issues.append(f"MISSING keys in checkpoint: {missing_keys}")
        logger.error("CRITICAL FAILURE: Model is missing required keys.")
        acceptable = False
    if unexpected_keys:
        issues.append(f"UNEXPECTED keys in checkpoint (exist in file but not in model): {unexpected_keys}")
        tolerated = [k for k in unexpected_keys if k.endswith('num_batches_tracked')]
        rejected = [k for k in unexpected_keys if not k.endswith('num_batches_tracked')]
        if rejected:
            logger.error(f"CRITICAL FAILURE: Model received unacceptable unexpected keys: {rejected}")
            acceptable = False
        elif tolerated:
            logger.warning(f"Ignoring acceptable unexpected keys: {tolerated}")
    if acceptable:
        logger.info(f"State dictionary loaded successfully. Issues (if any): {issues if issues else 'None'}")
    else:
        logger.error(f"State dict loading failed. Issues: {'; '.join(issues)}")
    return acceptable


def load_trained_model(checkpoint_path, model_config, map_location="cpu"):
    """Build an ``InferenceModel`` and load trained weights into it.

    Args:
        checkpoint_path: Path to a file readable by ``torch.load``.
        model_config: Config passed to ``InferenceModel``.
        map_location: Where ``torch.load`` puts the loaded tensors. Defaults to
            ``"cpu"``, where the freshly built model lives; move the returned
            model with ``.to(device)`` afterwards.

    Returns:
        The model in eval mode, or ``None`` on any failure (the reason is
        logged; errors are not propagated to the caller).
    """
    try:
        model_instance = InferenceModel(model_config)
        logger.info(f"Attempting to load checkpoint from: {checkpoint_path}")
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(f"Checkpoint file not found at specified path: {checkpoint_path}")

        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code, so only load trusted checkpoints. Newer PyTorch
        # releases default to weights_only=True, which may reject checkpoints
        # that pickle extra objects; confirm against the deployed torch version.
        checkpoint = torch.load(checkpoint_path, map_location=map_location)
        logger.info(f"Checkpoint loaded into memory. Type: {type(checkpoint)}")

        state_dict = _extract_state_dict(checkpoint)
        if state_dict is None:
            return None
        if not state_dict:
            logger.warning("Loaded state_dict is empty!")
            return None

        target_keys = set(model_instance.state_dict().keys())
        aligned_state_dict = _align_state_dict_keys(state_dict, target_keys)

        logger.info("Attempting to load state_dict with strict=False for checking...")
        missing_keys, unexpected_keys = model_instance.load_state_dict(aligned_state_dict, strict=False)
        if not _load_result_is_acceptable(missing_keys, unexpected_keys):
            return None

        model_instance.eval()
        logger.info(f"Model loading process complete for {checkpoint_path}")
        return model_instance
    except FileNotFoundError as fnf_error:
        logger.error(f"{fnf_error}")
        return None
    except Exception as e:
        logger.error(f"Unexpected error during model loading: {e}")
        logger.error(traceback.format_exc())
        return None
# --- Determine device & Load Model Globally ---
# Use the default CUDA device when one is available, otherwise the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Using device: {device}")
# Built once at import time and shared by every request handler.
model = load_trained_model(CHECKPOINT_PATH, config)
if model is None:
    logger.critical("CRITICAL: Failed to load model. Application cannot continue.")
    raise SystemExit(1)  # portable replacement for the interactive-only exit()
model.to(device)
# --- Inference Pipeline (Corrected Config Handling) ---
def _parse_inference_config(model_config):
    """Read and validate the preprocessing settings used by inference.

    Args:
        model_config: OmegaConf config for the model.

    Returns:
        ``(img_size, mean, std, id2color_map, num_classes)`` as plain Python
        values: ``img_size`` an ``(H, W)`` tuple, ``mean``/``std`` 3-element
        lists, ``id2color_map`` a dict keyed by ``int`` label. Returns ``None``
        if a field is missing or malformed (the reason is logged).
    """
    # Presence checks: falsy values (e.g. an empty list or map) count as missing.
    for key in ('image_size', 'mean', 'std', 'num_classes'):
        if not OmegaConf.select(model_config, f'data.{key}', default=None):
            logger.error(f"Config missing required data field: data.{key}")
            return None
    for field in ('id2color', 'training.model_name'):
        if not OmegaConf.select(model_config, field, default=None):
            logger.error(f"Config missing required field: {field}")
            return None

    try:
        raw_size = model_config.data.image_size
        # Accept a single int as shorthand for a square (size, size) input;
        # OmegaConf.to_container() rejects plain scalars.
        if isinstance(raw_size, int):
            img_size = (raw_size, raw_size)
        else:
            img_size = tuple(OmegaConf.to_container(raw_size, resolve=True))
        mean = list(OmegaConf.to_container(model_config.data.mean, resolve=True))
        std = list(OmegaConf.to_container(model_config.data.std, resolve=True))
        # Keys may arrive as strings; the mask lookup in num_to_rgb needs ints.
        id2color_map = {int(k): v for k, v in OmegaConf.to_container(model_config.id2color, resolve=True).items()}
        num_classes = int(model_config.data.num_classes)
        logger.debug(f"Converted config values: size={img_size}, mean={mean}, std={std}, id2color keys={list(id2color_map.keys())}, num_classes={num_classes}")
        if len(mean) != 3 or len(std) != 3:  # one value per RGB channel
            raise ValueError("Incorrect mean/std length. Expected 3.")
        if len(img_size) != 2:
            raise ValueError("Incorrect image_size length. Expected 2 (H, W).")
    except Exception as e:
        logger.error(f"Error processing/converting configuration values: {e}")
        logger.error(traceback.format_exc())
        return None
    return img_size, mean, std, id2color_map, num_classes


def run_inference_on_bytes(image_bytes, inference_model, model_config, device):
    """Segment an encoded image and build display images from the prediction.

    Args:
        image_bytes: Encoded image file contents (any format ``cv2.imdecode``
            can read).
        inference_model: Module mapping a ``(1, 3, H, W)`` normalized batch to
            ``(1, num_classes, H, W)`` logits (e.g. ``InferenceModel``).
        model_config: OmegaConf config providing ``data.image_size``,
            ``data.mean``, ``data.std``, ``data.num_classes``, ``id2color`` and
            ``training.model_name``.
        device: Device for the input tensor; must match the model's device.

    Returns:
        ``(denorm_img, color_mask, overlay)``: three float32 ``(H, W, 3)`` RGB
        arrays in ``[0, 1]`` at the configured ``image_size``. Returns
        ``(None, None, None)`` on any failure (details are logged).
    """
    failed = (None, None, None)
    try:
        encoded = np.frombuffer(image_bytes, np.uint8)
        img = cv2.imdecode(encoded, cv2.IMREAD_COLOR)
        if img is None:
            logger.error("Failed cv2.imdecode.")
            return failed
        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV decodes to BGR
        logger.debug("Image decoded and converted to RGB.")

        parsed = _parse_inference_config(model_config)
        if parsed is None:
            return failed
        img_size, mean, std, id2color_map, num_classes = parsed

        # --- Preprocessing ---
        transform = transforms.Compose([
            transforms.ToTensor(),  # HWC uint8 [0, 255] numpy -> CHW float [0, 1] tensor
            transforms.Resize(img_size, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.Normalize(mean=mean, std=std),
        ])
        logger.debug(f"Image transform applied for size {img_size}.")
        input_tensor = transform(img_rgb).unsqueeze(0).to(device)  # (1, 3, H, W)
        logger.debug(f"Input tensor created with shape: {input_tensor.shape}")

        # --- Run Prediction ---
        with torch.no_grad():
            logits = inference_model(input_tensor)
        logger.debug(f"Logits received with shape: {logits.shape}")
        if logits.dim() != 4 or logits.shape[1] != num_classes:
            logger.error(f"Unexpected final logits shape or class number: {logits.shape}. Expected B x {num_classes} x H x W.")
            return failed

        # argmax labels lie in [0, num_classes); uint8 can only hold them when
        # num_classes <= 256, so use a wider dtype for larger label sets.
        mask_dtype = np.uint8 if num_classes <= 256 else np.int32
        pred_mask = logits.argmax(1).squeeze(0).cpu().numpy().astype(mask_dtype)  # (H, W)
        logger.debug(f"Prediction mask generated with shape: {pred_mask.shape}")

        # --- Post-processing ---
        color_mask = num_to_rgb(pred_mask, id2color_map)
        if color_mask is None:
            logger.error("num_to_rgb failed.")
            return failed
        logger.debug("Color mask generated.")

        # Denormalize the *model input* so the overlay matches what was segmented.
        denorm_img = denormalize(input_tensor, mean, std)
        if denorm_img is None:
            logger.error("denormalize failed.")
            return failed
        logger.debug("Input tensor denormalized for overlay.")

        # --- Create Overlay ---
        if denorm_img.shape[:2] != color_mask.shape[:2]:
            logger.warning(f"Denorm img shape {denorm_img.shape[:2]} != Color mask shape {color_mask.shape[:2]}. Resizing color mask using INTER_NEAREST.")
            # Nearest-neighbour keeps the mask's flat class colours un-blended.
            color_mask = cv2.resize(color_mask, (denorm_img.shape[1], denorm_img.shape[0]), interpolation=cv2.INTER_NEAREST)
        overlay = cv2.addWeighted(denorm_img, 0.7, color_mask, 0.3, 0)
        logger.debug("Overlay created using cv2.addWeighted.")
        return denorm_img, color_mask, overlay
    except Exception as e:
        logger.error(f"Exception during inference pipeline for image: {e}")
        logger.error(traceback.format_exc())
        return failed
# --- Flask App ---
app = Flask(__name__)
# Called without options, flask-cors adds permissive CORS headers (any origin)
# to every route, including /api/* and /Result/*.
CORS(app)
logger.info("Flask app created and CORS enabled.")


def allowed_file(filename):
    """Return True if ``filename`` ends in an extension from ALLOWED_EXTENSIONS.

    Only the text after the last '.' is compared, case-insensitively. A name
    containing no '.' at all is rejected.
    """
    _, dot, extension = filename.rpartition('.')
    return bool(dot) and extension.lower() in ALLOWED_EXTENSIONS
# --- API Endpoints ---
def _save_overlay_image(overlay, overlay_path):
    """Write a float RGB image with values in [0, 1] to disk as 8-bit BGR.

    Raises:
        IOError: if ``cv2.imwrite`` reports that the write failed.
    """
    # Scale, round to nearest and clip before the uint8 cast; a bare astype()
    # truncates (darkening the image) and wraps any out-of-range value.
    overlay_uint8 = np.clip(np.rint(overlay * 255.0), 0, 255).astype(np.uint8)
    overlay_bgr = cv2.cvtColor(overlay_uint8, cv2.COLOR_RGB2BGR)  # imwrite expects BGR
    if not cv2.imwrite(overlay_path, overlay_bgr):
        raise IOError(f"cv2.imwrite failed to save the overlay image to {overlay_path}")


@app.route('/api/analyze', methods=['POST'])
def analyze_image():
    """Analyze a Base64-encoded image sent as JSON.

    Expects ``{"image": <base64 str, optionally with a data-URL header>,
    "filename": <str>}``. It saves the decoded original to ``UPLOAD_FOLDER``
    under its sanitized name, runs segmentation, and writes the overlay to
    ``RESULT_FOLDER`` as ``analyzed_<name><ext>``.

    Returns:
        JSON with ``success`` and ``message``. On success it also includes
        ``paths`` and the exact ``overlay_filename``, which ``/api/toggle-image``
        and ``/Result/<filename>`` rely on. Status is 400 for client errors and
        500 for server-side failures.
    """
    endpoint_log_prefix = "[POST /api/analyze]"
    logger.info(f"{endpoint_log_prefix} Received request.")

    # --- Basic Checks ---
    if model is None:
        logger.error(f"{endpoint_log_prefix} Model not loaded.")
        return jsonify({"success": False, "message": "Model not loaded"}), 500
    if not request.is_json:
        logger.warning(f"{endpoint_log_prefix} Not JSON.")
        return jsonify({"success": False, "message": "Request must be JSON"}), 400

    # silent=True turns malformed JSON into None, so it gets the JSON 400 below
    # instead of Flask's default HTML error page.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or 'image' not in data or 'filename' not in data:
        # Log only the keys, never the body: it carries the entire Base64 image.
        received = sorted(data) if isinstance(data, dict) else type(data).__name__
        logger.warning(f"{endpoint_log_prefix} Missing image/filename in JSON body. Keys received: {received}")
        return jsonify({"success": False, "message": "Missing 'image' (base64) or 'filename' in JSON body"}), 400

    base64_image_data = data['image']
    original_filename = data['filename']
    if not isinstance(base64_image_data, str) or not isinstance(original_filename, str):
        logger.warning(f"{endpoint_log_prefix} 'image' and 'filename' must both be strings.")
        return jsonify({"success": False, "message": "'image' and 'filename' must be strings"}), 400
    logger.info(f"{endpoint_log_prefix} Original filename from request: '{original_filename}'")

    safe_original_filename = secure_filename(original_filename)
    if not safe_original_filename or not allowed_file(safe_original_filename):
        logger.warning(f"{endpoint_log_prefix} Invalid/disallowed filename after sanitization: '{safe_original_filename}' from '{original_filename}'")
        return jsonify({"success": False, "message": "Invalid or disallowed filename/extension"}), 400
    logger.info(f"{endpoint_log_prefix} Sanitized filename for saving/processing: '{safe_original_filename}'")

    # --- Decode Base64 (an optional "data:...;base64," header is stripped) ---
    _header, separator, payload = base64_image_data.partition(',')
    encoded = payload if separator else base64_image_data
    try:
        image_bytes = base64.b64decode(encoded)
    except ValueError as e:  # binascii.Error is a subclass of ValueError
        logger.error(f"{endpoint_log_prefix} Invalid Base64 data received: {e}")
        return jsonify({"success": False, "message": "Invalid Base64 image data received"}), 400
    logger.info(f"{endpoint_log_prefix} Base64 image decoded ({len(image_bytes)} bytes).")

    try:
        # --- Save Original Image ---
        original_path = os.path.join(UPLOAD_FOLDER, safe_original_filename)
        try:
            with open(original_path, "wb") as f:
                f.write(image_bytes)
        except OSError as e:
            logger.error(f"{endpoint_log_prefix} Failed to save original image to '{original_path}': {e}")
            return jsonify({"success": False, "message": "Failed to save uploaded image on server"}), 500
        logger.info(f"{endpoint_log_prefix} Original image saved to: '{original_path}'")

        # --- Run Inference ---
        logger.info(f"{endpoint_log_prefix} Starting inference for '{safe_original_filename}'...")
        denorm_img, color_mask, overlay = run_inference_on_bytes(image_bytes, model, config, device)
        if overlay is None:
            logger.error(f"{endpoint_log_prefix} Inference pipeline returned None for '{safe_original_filename}'.")
            return jsonify({"success": False, "message": "Inference process failed on server"}), 500
        logger.info(f"{endpoint_log_prefix} Inference completed successfully for '{safe_original_filename}'.")

        # --- Save Overlay Image ---
        # This naming scheme must stay in sync with /api/toggle-image.
        name_part, ext = os.path.splitext(safe_original_filename)
        overlay_filename = f"analyzed_{name_part}{ext}"
        overlay_path = os.path.join(RESULT_FOLDER, overlay_filename)
        logger.info(f"{endpoint_log_prefix} Determined overlay filename: '{overlay_filename}' -> path: '{overlay_path}'")
        try:
            _save_overlay_image(overlay, overlay_path)
        except Exception as e:
            logger.error(f"{endpoint_log_prefix} Failed to convert or save overlay image to '{overlay_path}': {e}")
            logger.error(traceback.format_exc())
            return jsonify({"success": False, "message": "Failed to save analysis result image"}), 500
        logger.info(f"{endpoint_log_prefix} Overlay image saved successfully to: '{overlay_path}'")

        # --- Success Response ---
        logger.info(f"{endpoint_log_prefix} Analysis successful for '{safe_original_filename}'. Returning success.")
        return jsonify({
            "success": True,
            "message": "Analysis complete",
            "paths": {"original": os.path.relpath(original_path, BASE_DIR), "overlay": os.path.relpath(overlay_path, BASE_DIR)},
            "overlay_filename": overlay_filename,  # the exact filename saved
        }), 200
    except Exception as e:
        logger.error(f"{endpoint_log_prefix} Unexpected error during analysis request processing: {e}")
        logger.error(traceback.format_exc())
        return jsonify({"success": False, "message": "Internal server error during analysis processing"}), 500
@app.route('/api/toggle-image', methods=['GET'])
def get_analysis_path():
    """Report whether an analysis overlay exists for an uploaded image.

    Query parameters:
        filename: the original image filename, as sent to ``/api/analyze``.

    The name is sanitized with ``secure_filename`` and mapped to
    ``analyzed_<sanitized name>`` inside ``RESULT_FOLDER``, which is the naming
    used by ``/api/analyze``.

    Returns:
        - 200 ``{"filepath": <overlay filename>}`` when the file exists; the
          client builds the ``/Result/<filename>`` URL itself.
        - 400 for a missing or invalid ``filename``.
        - 404 when no overlay exists.
    """
    log_tag = "[GET /api/toggle-image]"
    logger.info(f"{log_tag} Received request.")
    logger.info(f"{log_tag} Full request URL: {request.url}")
    logger.info(f"{log_tag} Request Query Args: {request.args}")

    requested_name = request.args.get('filename')
    if not requested_name:
        logger.warning(f"{log_tag} Missing 'filename' query parameter.")
        return jsonify({"message": "Missing 'filename' query parameter"}), 400
    logger.info(f"{log_tag} Original filename received from query: '{requested_name}'")

    sanitized = secure_filename(requested_name)
    if not sanitized:
        logger.warning(f"{log_tag} Invalid filename after sanitization: '{sanitized}' from '{requested_name}'")
        return jsonify({"message": "Invalid filename format"}), 400
    logger.info(f"{log_tag} Sanitized filename for lookup: '{sanitized}'")

    # splitext()'s root + ext always reassembles the full name, so prefixing the
    # whole sanitized name is identical to f"analyzed_{name_part}{ext}".
    overlay_name = f"analyzed_{sanitized}"
    overlay_path = os.path.join(RESULT_FOLDER, overlay_name)
    logger.info(f"{log_tag} Expecting overlay file at: '{overlay_path}'")

    if not os.path.exists(overlay_path):
        logger.warning(f"{log_tag} Analysis result file NOT FOUND at checked path: '{overlay_path}'")
        return jsonify({"message": f"Analysis result not found for '{requested_name}'"}), 404

    logger.info(f"{log_tag} Found analysis result file: '{overlay_name}'")
    return jsonify({"filepath": overlay_name}), 200
@app.route('/Result/<filename>')
def serve_result_image(filename):
    """Serve a generated overlay image from ``RESULT_FOLDER``.

    The URL segment is passed through ``secure_filename`` before lookup.

    Returns:
        The file response on success, a 404 JSON body if the (sanitized) file
        does not exist, or a 500 JSON body on any other error.
    """
    # Local import: werkzeug is already a dependency of this app (Flask and
    # secure_filename), so no module-level import change is needed.
    from werkzeug.exceptions import NotFound

    endpoint_log_prefix = "[GET /Result]"
    safe_filename = secure_filename(filename)
    if safe_filename != filename:
        logger.warning(f"{endpoint_log_prefix} Requested filename '{filename}' was sanitized to '{safe_filename}'. Serving sanitized version.")
    logger.info(f"{endpoint_log_prefix} Attempting to serve file: '{safe_filename}' from directory: '{RESULT_FOLDER}'")
    try:
        # send_from_directory refuses paths that escape RESULT_FOLDER.
        return send_from_directory(RESULT_FOLDER, safe_filename, as_attachment=False)
    except (NotFound, FileNotFoundError):
        # A missing file is signalled by werkzeug's NotFound (an HTTPException),
        # not FileNotFoundError. Without catching it here, the generic handler
        # below would turn every miss into a 500.
        logger.error(f"{endpoint_log_prefix} Requested file not found in result folder: '{safe_filename}'")
        return jsonify({"message": "Requested analysis image not found"}), 404
    except Exception as e:
        logger.error(f"{endpoint_log_prefix} Error serving file '{safe_filename}': {e}")
        logger.error(traceback.format_exc())
        return jsonify({"message": "Error serving analysis image"}), 500
# --- Main Execution ---
if __name__ == '__main__':
    if model is not None:
        logger.info("Model loaded successfully. Starting Flask development server...")
        # Debug mode is opt-in via FLASK_DEBUG=1 rather than hard-coded on, for
        # two reasons. First, the Werkzeug debugger lets anyone who can reach
        # the server run arbitrary Python, and host='0.0.0.0' exposes it on
        # every interface. Second, its reloader re-runs this script in a second
        # process, loading the model twice.
        debug_mode = os.environ.get("FLASK_DEBUG", "").strip().lower() in {"1", "true", "yes", "on"}
        app.run(host='0.0.0.0', port=7860, debug=debug_mode)
    else:
        logger.critical("APPLICATION FAILED TO START: MODEL COULD NOT BE LOADED.")
        raise SystemExit(1)  # non-zero exit status signals the failure