Upload 13 files
Browse files- BackgroundEngine.py +432 -0
- FlowFacade.py +1 -23
- ResourceManager.py +0 -6
- TextProcessor.py +0 -7
- VideoEngine_optimized.py +0 -7
- app.py +5 -2
- css_style.py +102 -1
- image_blender.py +1117 -0
- mask_generator.py +648 -0
- requirements.txt +15 -5
- scene_templates.py +428 -0
- ui_manager.py +496 -234
BackgroundEngine.py
ADDED
|
@@ -0,0 +1,432 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
import cv2
|
| 4 |
+
from PIL import Image
|
| 5 |
+
import logging
|
| 6 |
+
import gc
|
| 7 |
+
import time
|
| 8 |
+
import os
|
| 9 |
+
from typing import Optional, Dict, Any, Callable
|
| 10 |
+
import warnings
|
| 11 |
+
warnings.filterwarnings("ignore")
|
| 12 |
+
|
| 13 |
+
from diffusers import StableDiffusionXLPipeline, DPMSolverMultistepScheduler
|
| 14 |
+
import open_clip
|
| 15 |
+
from mask_generator import MaskGenerator
|
| 16 |
+
from image_blender import ImageBlender
|
| 17 |
+
|
| 18 |
+
try:
|
| 19 |
+
import spaces
|
| 20 |
+
SPACES_AVAILABLE = True
|
| 21 |
+
except ImportError:
|
| 22 |
+
SPACES_AVAILABLE = False
|
| 23 |
+
|
| 24 |
+
logger = logging.getLogger(__name__)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class BackgroundEngine:
    """
    Background generation engine for VividFlow.

    Integrates the SDXL text-to-image pipeline, OpenCLIP image analysis,
    mask generation, and advanced image blending.

    Typical usage:
        engine = BackgroundEngine()
        engine.load_models()
        result = engine.generate_and_combine(image, "a sunny beach")
    """

    def __init__(self, device: str = "auto"):
        """
        Args:
            device: "auto", "cpu", "cuda", or "mps". "auto" picks the best
                available backend; on HF Spaces (ZeroGPU) this resolves to
                "cpu" because CUDA only appears after @spaces.GPU allocation.
        """
        self.device = self._setup_device(device)
        self.base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
        self.clip_model_name = "ViT-B-32"
        self.clip_pretrained = "openai"

        # Lazily populated by load_models()
        self.pipeline = None
        self.clip_model = None
        self.clip_preprocess = None
        self.clip_tokenizer = None
        self.is_initialized = False

        # Generation limits / defaults
        self.max_image_size = 1024
        self.default_steps = 25
        self.use_fp16 = True

        self.mask_generator = MaskGenerator(self.max_image_size)
        self.image_blender = ImageBlender()

        logger.info(f"BackgroundEngine initialized on {self.device}")

    def _setup_device(self, device: str) -> str:
        """Setup computation device (ZeroGPU compatible).

        On HF Spaces (SPACE_ID set) always returns "cpu": in ZeroGPU,
        CUDA is only attached inside @spaces.GPU-decorated calls, so the
        actual device is re-detected per call elsewhere.
        """
        if os.getenv('SPACE_ID') is not None:
            return "cpu"

        if device == "auto":
            if torch.cuda.is_available():
                return "cuda"
            elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
                return "mps"
            return "cpu"
        return device

    def _memory_cleanup(self):
        """Aggressively collect garbage and release cached GPU memory.

        gc.collect() is run multiple times so that objects freed by earlier
        passes (reference cycles) are reclaimed in the same cleanup.
        empty_cache() is skipped on Spaces, where CUDA may not be attached.
        """
        for _ in range(3):
            gc.collect()

        is_spaces = os.getenv('SPACE_ID') is not None
        if not is_spaces and torch.cuda.is_available():
            torch.cuda.empty_cache()

    def load_models(self, progress_callback: Optional[Callable] = None):
        """Load SDXL and OpenCLIP models (idempotent).

        Args:
            progress_callback: optional ``fn(message: str, percent: int)``
                invoked at coarse loading milestones.

        Raises:
            RuntimeError: if either model fails to load.
        """
        if self.is_initialized:
            logger.info("Models already loaded")
            return

        logger.info("Loading background generation models...")

        try:
            self._memory_cleanup()

            # Detect actual device (in ZeroGPU, CUDA becomes available after @spaces.GPU allocation)
            actual_device = "cuda" if torch.cuda.is_available() else self.device
            logger.info(f"Loading models to device: {actual_device}")

            if progress_callback:
                progress_callback("Loading OpenCLIP...", 20)

            # Load OpenCLIP
            self.clip_model, _, self.clip_preprocess = open_clip.create_model_and_transforms(
                self.clip_model_name,
                pretrained=self.clip_pretrained,
                device=actual_device
            )
            self.clip_tokenizer = open_clip.get_tokenizer(self.clip_model_name)
            self.clip_model.eval()

            logger.info("OpenCLIP loaded")

            if progress_callback:
                progress_callback("Loading SDXL pipeline...", 60)

            # Load SDXL
            self.pipeline = StableDiffusionXLPipeline.from_pretrained(
                self.base_model_id,
                torch_dtype=torch.float16 if self.use_fp16 else torch.float32,
                use_safetensors=True,
                variant="fp16" if self.use_fp16 else None
            )

            # DPM solver for faster generation
            self.pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
                self.pipeline.scheduler.config
            )

            self.pipeline = self.pipeline.to(actual_device)

            if progress_callback:
                progress_callback("Applying optimizations...", 90)

            # Memory optimizations: prefer xformers, fall back to attention
            # slicing; both are best-effort and never fatal.
            try:
                self.pipeline.enable_xformers_memory_efficient_attention()
                logger.info("xformers enabled")
            except Exception:
                try:
                    self.pipeline.enable_attention_slicing()
                    logger.info("Attention slicing enabled")
                except Exception:
                    pass

            if hasattr(self.pipeline, 'enable_vae_tiling'):
                self.pipeline.enable_vae_tiling()

            if hasattr(self.pipeline, 'enable_vae_slicing'):
                self.pipeline.enable_vae_slicing()

            self.pipeline.unet.eval()
            if hasattr(self.pipeline, 'vae'):
                self.pipeline.vae.eval()

            self.is_initialized = True

            if progress_callback:
                progress_callback("Models loaded!", 100)

            logger.info("Background models loaded successfully")

        except Exception as e:
            logger.error(f"Model loading failed: {e}")
            raise RuntimeError(f"Failed to load models: {str(e)}")

    def analyze_image_with_clip(self, image: Image.Image) -> str:
        """Classify the image subject using OpenCLIP zero-shot matching.

        Returns one of "person", "animal", "object", "nature", "building",
        or "unknown" when the model is unavailable or analysis fails.
        """
        if not self.clip_model:
            # Lowercase for consistency with the failure path and with the
            # lowercase keys used by enhance_prompt's atmosphere_map.
            return "unknown"

        try:
            # Use actual device
            actual_device = "cuda" if torch.cuda.is_available() else self.device

            image_input = self.clip_preprocess(image).unsqueeze(0).to(actual_device)

            categories = [
                "a photo of a person",
                "a photo of an animal",
                "a photo of an object",
                "a photo of nature",
                "a photo of a building"
            ]

            text_inputs = self.clip_tokenizer(categories).to(actual_device)

            with torch.no_grad():
                image_features = self.clip_model.encode_image(image_input)
                text_features = self.clip_model.encode_text(text_inputs)

                # Normalize before cosine-similarity softmax
                image_features /= image_features.norm(dim=-1, keepdim=True)
                text_features /= text_features.norm(dim=-1, keepdim=True)

                similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
                best_match_idx = similarity.argmax().item()

            category = categories[best_match_idx].replace("a photo of ", "")
            return category

        except Exception as e:
            logger.error(f"CLIP analysis failed: {e}")
            return "unknown"

    def enhance_prompt(self, user_prompt: str, foreground_image: Image.Image) -> str:
        """Smart prompt enhancement based on image analysis.

        Appends lighting, atmosphere, and quality descriptors derived from
        the foreground's color temperature, brightness, and CLIP-detected
        subject. Falls back to a generic quality suffix on any failure.
        """
        try:
            img_array = np.array(foreground_image.convert('RGB'))

            # Analyze color temperature via LAB b-channel (>128 leans warm/yellow)
            lab = cv2.cvtColor(img_array, cv2.COLOR_RGB2LAB)
            avg_b = np.mean(lab[:, :, 2])
            is_warm = avg_b > 128

            # Analyze brightness
            gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
            avg_brightness = np.mean(gray)
            is_bright = avg_brightness > 127

            # Get subject type
            clip_analysis = self.analyze_image_with_clip(foreground_image)
            subject_type = clip_analysis

            # Build lighting descriptors
            if is_warm and is_bright:
                lighting = "warm golden hour lighting, soft natural light"
            elif is_warm and not is_bright:
                lighting = "warm ambient lighting, cozy atmosphere"
            elif not is_warm and is_bright:
                lighting = "bright daylight, clear sky lighting"
            else:
                lighting = "soft diffused light, gentle shadows"

            # Build atmosphere based on subject
            atmosphere_map = {
                "person": "professional, elegant composition",
                "animal": "natural, harmonious setting",
                "object": "clean product photography style",
                "nature": "scenic, peaceful atmosphere",
                "building": "architectural, balanced composition"
            }
            atmosphere = atmosphere_map.get(subject_type, "balanced composition")

            quality_modifiers = "high quality, detailed, sharp focus, photorealistic"

            # Avoid conflicts with lighting the user already specified
            user_prompt_lower = user_prompt.lower()
            if "sunset" in user_prompt_lower or "golden" in user_prompt_lower:
                lighting = ""
            if "dark" in user_prompt_lower or "night" in user_prompt_lower:
                lighting = lighting.replace("bright", "").replace("daylight", "")

            # Combine; filter(None, ...) drops the lighting fragment if blanked
            fragments = [user_prompt]
            if lighting:
                fragments.append(lighting)
            fragments.append(atmosphere)
            fragments.append(quality_modifiers)

            enhanced_prompt = ", ".join(filter(None, fragments))

            logger.debug(f"Enhanced: {enhanced_prompt[:80]}...")
            return enhanced_prompt

        except Exception as e:
            logger.warning(f"Prompt enhancement failed: {e}")
            return f"{user_prompt}, high quality, detailed, photorealistic"

    def _prepare_image(self, image: Image.Image) -> Image.Image:
        """Prepare image for processing.

        Converts to RGB, downscales so the longest side fits
        ``max_image_size``, then snaps both dimensions down to multiples
        of 8 as required by the SDXL VAE.
        """
        if image.mode != 'RGB':
            image = image.convert('RGB')

        width, height = image.size
        max_size = self.max_image_size

        if width > max_size or height > max_size:
            ratio = min(max_size/width, max_size/height)
            new_width = int(width * ratio)
            new_height = int(height * ratio)
            image = image.resize((new_width, new_height), Image.LANCZOS)

        width, height = image.size
        new_width = (width // 8) * 8
        new_height = (height // 8) * 8

        if new_width != width or new_height != height:
            image = image.resize((new_width, new_height), Image.LANCZOS)

        return image

    def generate_background(
        self,
        prompt: str,
        width: int,
        height: int,
        negative_prompt: str = "blurry, low quality, distorted",
        num_inference_steps: int = 25,
        guidance_scale: float = 7.5
    ) -> Image.Image:
        """Generate a background image using SDXL.

        Args:
            prompt: positive text prompt.
            width: output width in pixels (should be a multiple of 8).
            height: output height in pixels (should be a multiple of 8).
            negative_prompt: concepts to steer away from.
            num_inference_steps: diffusion steps.
            guidance_scale: classifier-free guidance strength.

        Returns:
            The generated PIL image.

        Raises:
            RuntimeError: if models are not loaded, GPU memory runs out,
                or generation fails.
        """
        if not self.is_initialized:
            raise RuntimeError("Models not loaded")

        logger.info(f"Generating background: {prompt[:50]}...")

        try:
            # Use actual device
            actual_device = "cuda" if torch.cuda.is_available() else self.device

            with torch.inference_mode():
                result = self.pipeline(
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    width=width,
                    height=height,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale,
                    # Fixed seed keeps generation reproducible across runs
                    generator=torch.Generator(device=actual_device).manual_seed(42)
                )

            generated_image = result.images[0]
            logger.info("Background generation completed")
            return generated_image

        except torch.cuda.OutOfMemoryError:
            logger.error("GPU memory exhausted")
            self._memory_cleanup()
            raise RuntimeError("GPU memory insufficient")

        except Exception as e:
            logger.error(f"Generation failed: {e}")
            raise RuntimeError(f"Generation failed: {str(e)}")

    def generate_and_combine(
        self,
        original_image: Image.Image,
        prompt: str,
        combination_mode: str = "center",
        focus_mode: str = "person",
        negative_prompt: str = "blurry, low quality, distorted",
        num_inference_steps: int = 25,
        guidance_scale: float = 7.5,
        progress_callback: Optional[Callable] = None,
        enable_prompt_enhancement: bool = True,
        feather_radius: int = 0
    ) -> Dict[str, Any]:
        """
        Generate background and combine with foreground.

        Args:
            original_image: foreground image to keep.
            prompt: background description.
            combination_mode: mask placement strategy (e.g. "center").
            focus_mode: subject hint for mask generation (e.g. "person").
            negative_prompt: concepts to steer away from.
            num_inference_steps: diffusion steps.
            guidance_scale: classifier-free guidance strength.
            progress_callback: optional ``fn(message, percent)``.
            enable_prompt_enhancement: run enhance_prompt() on the prompt.
            feather_radius: Gaussian blur radius for mask edge softening (0-20, default 0)

        Returns:
            dict with: combined_image, generated_scene, original_image,
            mask, success; plus "error" (and None image values) on failure.

        Raises:
            RuntimeError: if models are not loaded.
        """
        if not self.is_initialized:
            raise RuntimeError("Models not loaded")

        logger.info("Starting background generation and combination...")

        try:
            if progress_callback:
                progress_callback("Analyzing image...", 5)

            # Prepare image
            processed_original = self._prepare_image(original_image)
            target_width, target_height = processed_original.size

            if progress_callback:
                progress_callback("Enhancing prompt...", 15)

            # Enhance prompt
            if enable_prompt_enhancement:
                enhanced_prompt = self.enhance_prompt(prompt, processed_original)
            else:
                enhanced_prompt = f"{prompt}, high quality, detailed, photorealistic"

            enhanced_negative = f"{negative_prompt}, people, characters, cartoons, logos"

            if progress_callback:
                progress_callback("Generating background...", 30)

            # Generate background
            generated_background = self.generate_background(
                prompt=enhanced_prompt,
                width=target_width,
                height=target_height,
                negative_prompt=enhanced_negative,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale
            )

            if progress_callback:
                progress_callback("Creating mask...", 80)

            # Generate mask
            logger.info("Generating mask...")
            combination_mask = self.mask_generator.create_gradient_based_mask(
                processed_original,
                combination_mode,
                focus_mode
            )

            if progress_callback:
                progress_callback("Blending images...", 90)

            # Blend images with feather_radius
            logger.info("Blending images...")
            combined_image = self.image_blender.simple_blend_images(
                processed_original,
                generated_background,
                combination_mask,
                feather_radius=feather_radius
            )

            # Cleanup
            self._memory_cleanup()

            if progress_callback:
                progress_callback("Complete!", 100)

            logger.info("Background generation completed successfully")

            # Build result dict (always include mask for diagnostics)
            return {
                "combined_image": combined_image,
                "generated_scene": generated_background,
                "original_image": processed_original,
                "mask": combination_mask,
                "success": True
            }

        except Exception as e:
            logger.error(f"Generation failed: {e}")
            self._memory_cleanup()
            # Mirror the success-dict keys (as None) so callers that index
            # them unconditionally get None instead of a KeyError.
            return {
                "combined_image": None,
                "generated_scene": None,
                "original_image": None,
                "mask": None,
                "success": False,
                "error": str(e)
            }
|
FlowFacade.py
CHANGED
|
@@ -26,29 +26,7 @@ class FlowFacade:
|
|
| 26 |
self.text_processor = TextProcessor(resource_manager=None)
|
| 27 |
print("✓ DeltaFlow initialized")
|
| 28 |
|
| 29 |
-
|
| 30 |
-
num_inference_steps: int, enable_prompt_expansion: bool, **kwargs) -> int:
|
| 31 |
-
BASE_FRAMES_HEIGHT_WIDTH = 81 * 832 * 624
|
| 32 |
-
BASE_STEP_DURATION = 8
|
| 33 |
-
|
| 34 |
-
resized_image = self.video_engine.resize_image(image)
|
| 35 |
-
width, height = resized_image.width, resized_image.height
|
| 36 |
-
frames = self.video_engine.get_num_frames(duration_seconds)
|
| 37 |
-
|
| 38 |
-
factor = frames * width * height / BASE_FRAMES_HEIGHT_WIDTH
|
| 39 |
-
step_duration = BASE_STEP_DURATION * factor ** 1.5
|
| 40 |
-
total_duration = int(num_inference_steps) * step_duration
|
| 41 |
-
|
| 42 |
-
# Add overhead for first-time model loading
|
| 43 |
-
if not self.video_engine.is_loaded:
|
| 44 |
-
total_duration += 150
|
| 45 |
-
|
| 46 |
-
if enable_prompt_expansion:
|
| 47 |
-
total_duration += 40
|
| 48 |
-
|
| 49 |
-
return max(int(total_duration), 300)
|
| 50 |
-
|
| 51 |
-
@spaces.GPU(duration=_calculate_gpu_duration)
|
| 52 |
def generate_video_from_image(self, image: Image.Image, user_instruction: str,
|
| 53 |
duration_seconds: float = 3.0, num_inference_steps: int = 4,
|
| 54 |
guidance_scale: float = 1.0, guidance_scale_2: float = 1.0,
|
|
|
|
| 26 |
self.text_processor = TextProcessor(resource_manager=None)
|
| 27 |
print("✓ DeltaFlow initialized")
|
| 28 |
|
| 29 |
+
@spaces.GPU(duration=300)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
def generate_video_from_image(self, image: Image.Image, user_instruction: str,
|
| 31 |
duration_seconds: float = 3.0, num_inference_steps: int = 4,
|
| 32 |
guidance_scale: float = 1.0, guidance_scale_2: float = 1.0,
|
ResourceManager.py
CHANGED
|
@@ -1,9 +1,3 @@
|
|
| 1 |
-
# %%writefile RescourceManager.py
|
| 2 |
-
"""
|
| 3 |
-
DeltaFlow - Resource Manager
|
| 4 |
-
Handles GPU memory allocation, deallocation, and cache management
|
| 5 |
-
"""
|
| 6 |
-
|
| 7 |
import gc
|
| 8 |
import torch
|
| 9 |
from typing import Optional
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import gc
|
| 2 |
import torch
|
| 3 |
from typing import Optional
|
TextProcessor.py
CHANGED
|
@@ -1,10 +1,3 @@
|
|
| 1 |
-
# %%writefile text_processor.py
|
| 2 |
-
"""
|
| 3 |
-
DeltaFlow - Text Processor
|
| 4 |
-
Handles semantic expansion using Qwen2.5-0.5B-Instruct
|
| 5 |
-
Converts brief instructions into detailed motion descriptions
|
| 6 |
-
"""
|
| 7 |
-
|
| 8 |
import gc
|
| 9 |
import traceback
|
| 10 |
from typing import Optional
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import gc
|
| 2 |
import traceback
|
| 3 |
from typing import Optional
|
VideoEngine_optimized.py
CHANGED
|
@@ -1,10 +1,3 @@
|
|
| 1 |
-
"""
|
| 2 |
-
DeltaFlow - Video Engine (FP8 Optimized)
|
| 3 |
-
Ultra-fast Image-to-Video generation using Wan2.2-I2V-A14B
|
| 4 |
-
Features: Lightning LoRA + FP8 Quantization
|
| 5 |
-
~70-90s inference (vs 150s baseline)
|
| 6 |
-
"""
|
| 7 |
-
|
| 8 |
import warnings
|
| 9 |
warnings.filterwarnings('ignore', category=FutureWarning)
|
| 10 |
warnings.filterwarnings('ignore', category=DeprecationWarning)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import warnings
|
| 2 |
warnings.filterwarnings('ignore', category=FutureWarning)
|
| 3 |
warnings.filterwarnings('ignore', category=DeprecationWarning)
|
app.py
CHANGED
|
@@ -15,6 +15,7 @@ import ftfy
|
|
| 15 |
import sentencepiece
|
| 16 |
|
| 17 |
from FlowFacade import FlowFacade
|
|
|
|
| 18 |
from ui_manager import UIManager
|
| 19 |
|
| 20 |
|
|
@@ -124,11 +125,13 @@ def main():
|
|
| 124 |
|
| 125 |
try:
|
| 126 |
facade = FlowFacade()
|
| 127 |
-
|
|
|
|
|
|
|
| 128 |
is_colab = 'google.colab' in sys.modules
|
| 129 |
|
| 130 |
print("✓ Ready")
|
| 131 |
-
|
| 132 |
share=is_colab,
|
| 133 |
server_name="0.0.0.0",
|
| 134 |
server_port=None,
|
|
|
|
| 15 |
import sentencepiece
|
| 16 |
|
| 17 |
from FlowFacade import FlowFacade
|
| 18 |
+
from BackgroundEngine import BackgroundEngine
|
| 19 |
from ui_manager import UIManager
|
| 20 |
|
| 21 |
|
|
|
|
| 125 |
|
| 126 |
try:
|
| 127 |
facade = FlowFacade()
|
| 128 |
+
background_engine = BackgroundEngine()
|
| 129 |
+
ui_manager = UIManager(facade, background_engine)
|
| 130 |
+
interface = ui_manager.create_interface()
|
| 131 |
is_colab = 'google.colab' in sys.modules
|
| 132 |
|
| 133 |
print("✓ Ready")
|
| 134 |
+
interface.launch(
|
| 135 |
share=is_colab,
|
| 136 |
server_name="0.0.0.0",
|
| 137 |
server_port=None,
|
css_style.py
CHANGED
|
@@ -1,5 +1,8 @@
|
|
| 1 |
DELTAFLOW_CSS = """
|
| 2 |
-
/*
|
|
|
|
|
|
|
|
|
|
| 3 |
:root {
|
| 4 |
--primary-bg: #f8f9fa;
|
| 5 |
--secondary-bg: #ffffff;
|
|
@@ -11,9 +14,12 @@ DELTAFLOW_CSS = """
|
|
| 11 |
--accent-hover: #4f46e5;
|
| 12 |
--success-color: #10b981;
|
| 13 |
--error-color: #ef4444;
|
|
|
|
| 14 |
--shadow-sm: 0 2px 8px rgba(0, 0, 0, 0.08);
|
| 15 |
--shadow-md: 0 4px 16px rgba(0, 0, 0, 0.12);
|
| 16 |
--shadow-lg: 0 8px 32px rgba(0, 0, 0, 0.16);
|
|
|
|
|
|
|
| 17 |
}
|
| 18 |
|
| 19 |
/* Main Container */
|
|
@@ -276,4 +282,99 @@ video {
|
|
| 276 |
max-width: 1200px !important;
|
| 277 |
margin: 0 auto !important;
|
| 278 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 279 |
"""
|
|
|
|
| 1 |
DELTAFLOW_CSS = """
|
| 2 |
+
/* Import professional fonts */
|
| 3 |
+
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap');
|
| 4 |
+
|
| 5 |
+
/* Global Light Theme - Combined VividFlow & SceneWeaver */
|
| 6 |
:root {
|
| 7 |
--primary-bg: #f8f9fa;
|
| 8 |
--secondary-bg: #ffffff;
|
|
|
|
| 14 |
--accent-hover: #4f46e5;
|
| 15 |
--success-color: #10b981;
|
| 16 |
--error-color: #ef4444;
|
| 17 |
+
--warning-color: #f59e0b;
|
| 18 |
--shadow-sm: 0 2px 8px rgba(0, 0, 0, 0.08);
|
| 19 |
--shadow-md: 0 4px 16px rgba(0, 0, 0, 0.12);
|
| 20 |
--shadow-lg: 0 8px 32px rgba(0, 0, 0, 0.16);
|
| 21 |
+
--radius-md: 8px;
|
| 22 |
+
--radius-lg: 12px;
|
| 23 |
}
|
| 24 |
|
| 25 |
/* Main Container */
|
|
|
|
| 282 |
max-width: 1200px !important;
|
| 283 |
margin: 0 auto !important;
|
| 284 |
}
|
| 285 |
+
|
| 286 |
+
/* ==== SceneWeaver Background Generation Styles ==== */
|
| 287 |
+
|
| 288 |
+
/* Feature Card - Background Generation Tab */
|
| 289 |
+
.feature-card {
|
| 290 |
+
background: var(--card-bg) !important;
|
| 291 |
+
border: 1px solid var(--border-color) !important;
|
| 292 |
+
border-radius: var(--radius-lg) !important;
|
| 293 |
+
padding: 1.5rem !important;
|
| 294 |
+
box-shadow: var(--shadow-md) !important;
|
| 295 |
+
overflow: visible !important;
|
| 296 |
+
transition: all 0.2s ease !important;
|
| 297 |
+
}
|
| 298 |
+
|
| 299 |
+
.feature-card:hover {
|
| 300 |
+
border-color: var(--accent-color) !important;
|
| 301 |
+
box-shadow: var(--shadow-lg) !important;
|
| 302 |
+
}
|
| 303 |
+
|
| 304 |
+
/* Scene Template Dropdown */
|
| 305 |
+
.template-dropdown select,
|
| 306 |
+
.template-dropdown input {
|
| 307 |
+
font-size: 0.95rem !important;
|
| 308 |
+
padding: 10px 14px !important;
|
| 309 |
+
border-radius: var(--radius-md) !important;
|
| 310 |
+
border: 1px solid var(--border-color) !important;
|
| 311 |
+
background: var(--secondary-bg) !important;
|
| 312 |
+
transition: all 0.2s ease !important;
|
| 313 |
+
}
|
| 314 |
+
|
| 315 |
+
.template-dropdown select:focus,
|
| 316 |
+
.template-dropdown input:focus {
|
| 317 |
+
border-color: var(--accent-color) !important;
|
| 318 |
+
box-shadow: 0 0 0 3px rgba(99, 102, 241, 0.15) !important;
|
| 319 |
+
outline: none !important;
|
| 320 |
+
}
|
| 321 |
+
|
| 322 |
+
/* Results Gallery */
|
| 323 |
+
.result-gallery {
|
| 324 |
+
border-radius: var(--radius-lg) !important;
|
| 325 |
+
overflow: hidden !important;
|
| 326 |
+
border: 1px solid var(--border-color) !important;
|
| 327 |
+
box-shadow: var(--shadow-md) !important;
|
| 328 |
+
}
|
| 329 |
+
|
| 330 |
+
/* Secondary Button (Download, Clear) */
|
| 331 |
+
.secondary-button {
|
| 332 |
+
background: var(--secondary-bg) !important;
|
| 333 |
+
color: var(--accent-color) !important;
|
| 334 |
+
border: 1.5px solid var(--accent-color) !important;
|
| 335 |
+
border-radius: var(--radius-md) !important;
|
| 336 |
+
padding: 12px 20px !important;
|
| 337 |
+
font-weight: 500 !important;
|
| 338 |
+
transition: all 0.2s ease !important;
|
| 339 |
+
}
|
| 340 |
+
|
| 341 |
+
.secondary-button:hover {
|
| 342 |
+
background: rgba(99, 102, 241, 0.1) !important;
|
| 343 |
+
}
|
| 344 |
+
|
| 345 |
+
/* Dropdown positioning fix for Gradio 4.x/5.x */
|
| 346 |
+
.gradio-dropdown,
|
| 347 |
+
.gradio-dropdown > div {
|
| 348 |
+
position: relative !important;
|
| 349 |
+
}
|
| 350 |
+
|
| 351 |
+
.gradio-dropdown ul,
|
| 352 |
+
.gradio-dropdown [role="listbox"] {
|
| 353 |
+
position: absolute !important;
|
| 354 |
+
z-index: 9999 !important;
|
| 355 |
+
left: 0 !important;
|
| 356 |
+
top: 100% !important;
|
| 357 |
+
width: 100% !important;
|
| 358 |
+
max-height: 300px !important;
|
| 359 |
+
overflow-y: auto !important;
|
| 360 |
+
background: var(--secondary-bg) !important;
|
| 361 |
+
border: 1px solid var(--border-color) !important;
|
| 362 |
+
border-radius: var(--radius-md) !important;
|
| 363 |
+
box-shadow: var(--shadow-lg) !important;
|
| 364 |
+
margin-top: 4px !important;
|
| 365 |
+
}
|
| 366 |
+
|
| 367 |
+
/* Status Panel */
|
| 368 |
+
.status-panel {
|
| 369 |
+
background: var(--secondary-bg) !important;
|
| 370 |
+
border: 1px solid var(--border-color) !important;
|
| 371 |
+
border-radius: var(--radius-md) !important;
|
| 372 |
+
padding: 12px 16px !important;
|
| 373 |
+
margin: 16px 0 !important;
|
| 374 |
+
}
|
| 375 |
+
|
| 376 |
+
.status-ready {
|
| 377 |
+
color: var(--success-color) !important;
|
| 378 |
+
font-weight: 500 !important;
|
| 379 |
+
}
|
| 380 |
"""
|
image_blender.py
ADDED
|
@@ -0,0 +1,1117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import numpy as np
|
| 3 |
+
import traceback
|
| 4 |
+
from PIL import Image
|
| 5 |
+
import logging
|
| 6 |
+
from typing import Dict, Any, Optional, Tuple
|
| 7 |
+
|
| 8 |
+
logger = logging.getLogger(__name__)
|
| 9 |
+
logger.setLevel(logging.INFO)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class ImageBlender:
    """
    Advanced image blending with aggressive spill suppression and color replacement.

    Supports two primary modes:
    - Background generation: foreground preservation with edge refinement
    - Inpainting: seamless blending with adaptive color correction

    Attributes:
        enable_multi_scale: Whether multi-scale edge refinement is enabled.
    """

    # --- Background-generation tuning constants ---
    EDGE_EROSION_PIXELS = 1  # Pixels eroded from the mask edge to drop contaminated boundary pixels
    ALPHA_BINARIZE_THRESHOLD = 0.5  # Alpha threshold for binarization (not referenced in this chunk — TODO confirm where it is used)
    DARK_LUMINANCE_THRESHOLD = 60  # Mean-RGB luminance below which a pixel counts as "dark foreground"
    FOREGROUND_PROTECTION_THRESHOLD = 140  # Mask value at/above which pixels are fully protected from blending
    BACKGROUND_COLOR_TOLERANCE = 30  # DeltaE tolerance when detecting original-background contamination

    # --- Inpainting-specific parameters (presumably used by code past this chunk — verify) ---
    INPAINT_FEATHER_SCALE = 1.2  # Scale factor for inpainting feathering
    INPAINT_COLOR_BLEND_RADIUS = 10  # Radius (px) of the color adaptation zone
| 34 |
+
def __init__(self, enable_multi_scale: bool = True):
    """Create a blender instance.

    Args:
        enable_multi_scale: Enable multi-scale edge refinement (default True).
    """
    self.enable_multi_scale = enable_multi_scale
    # Scratch state populated during a blending run, kept for debugging.
    self._debug_info = {}
    self._adaptive_strength_map = None
|
| 46 |
+
|
| 47 |
+
def _erode_mask_edges(
    self,
    mask_array: np.ndarray,
    erosion_pixels: int = 2
) -> np.ndarray:
    """Shrink the foreground mask to discard contaminated boundary pixels.

    The outermost mask pixels are where color contamination from the
    original background is most likely, so they are eroded away and the
    resulting edge is lightly smoothed.

    Args:
        mask_array: Input mask (uint8, 0-255).
        erosion_pixels: Number of pixels to erode; <= 0 is a no-op.

    Returns:
        Eroded (and slightly blurred) mask array (uint8).
    """
    if erosion_pixels <= 0:
        return mask_array

    # An elliptical structuring element produces a rounder, more natural
    # edge than a square kernel; never smaller than 2 px.
    kernel_size = max(2, erosion_pixels)
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))

    eroded = cv2.erode(mask_array, kernel, iterations=1)

    # Soften the freshly-cut edge with a light 3x3 Gaussian blur.
    eroded = cv2.GaussianBlur(eroded, (3, 3), 0)

    logger.debug(f"Mask erosion applied: {erosion_pixels}px, kernel size: {kernel_size}")
    return eroded
|
| 83 |
+
|
| 84 |
+
def _binarize_edge_alpha(
    self,
    alpha: np.ndarray,
    mask_array: np.ndarray,
    orig_array: np.ndarray,
    threshold: float = 0.45
) -> np.ndarray:
    """Force semi-transparent edge pixels to fully opaque or fully transparent.

    Semi-transparent edge pixels blend the original (possibly dark)
    foreground with the new background and cause visible contamination;
    snapping them to 0 or 1 eliminates that bleeding.

    Args:
        alpha: Current alpha channel (float32, 0.0-1.0).
        mask_array: Original mask (uint8, 0-255) — accepted for interface
            parity but not read by this implementation.
        orig_array: Original foreground image (uint8, RGB).
        threshold: Base alpha cut-off for the opaque/transparent decision.

    Returns:
        Alpha array with binarized edges (same dtype as input).
    """
    # Pixels that are neither fully transparent nor fully opaque.
    edge_zone = (alpha > 0.05) & (alpha < 0.95)
    if not np.any(edge_zone):
        return alpha

    # Per-pixel threshold: dark foreground gets a higher cut-off so that
    # more of a dark subject survives the binarization.
    gray = np.mean(orig_array, axis=2)
    is_dark = gray < self.DARK_LUMINANCE_THRESHOLD
    adaptive_threshold = np.full_like(alpha, threshold)
    adaptive_threshold[is_dark] = threshold + 0.1

    make_opaque = edge_zone & (alpha > adaptive_threshold)
    make_transparent = edge_zone & (alpha <= adaptive_threshold)

    alpha_binarized = alpha.copy()
    alpha_binarized[make_opaque] = 1.0
    alpha_binarized[make_transparent] = 0.0

    num_opaque = np.sum(make_opaque)
    num_transparent = np.sum(make_transparent)
    logger.info(f"Edge binarization: {num_opaque} pixels -> opaque, {num_transparent} pixels -> transparent")

    return alpha_binarized
|
| 142 |
+
|
| 143 |
+
def _apply_edge_cleanup(
    self,
    result_array: np.ndarray,
    bg_array: np.ndarray,
    alpha: np.ndarray,
    cleanup_width: int = 2
) -> np.ndarray:
    """Final pass that snaps residual semi-transparent edge pixels to a pure state.

    Mostly-transparent residual pixels are replaced with the background
    color; mostly-opaque ones already carry the blended foreground color
    and are left untouched.

    Args:
        result_array: Current blended result (uint8, RGB).
        bg_array: Background image (uint8, RGB).
        alpha: Final alpha channel (float32, 0.0-1.0).
        cleanup_width: Reserved width of the edge zone; currently unused,
            kept for interface compatibility.

    Returns:
        Cleaned result array (uint8).
    """
    # (alpha > 0.01) already excludes exactly-0.0 and (alpha < 0.99)
    # excludes exactly-1.0, so no extra equality tests are needed.
    residual_edge = (alpha > 0.01) & (alpha < 0.99)

    if not np.any(residual_edge):
        return result_array

    result_cleaned = result_array.copy()

    # Snap mostly-transparent residual pixels to the background.
    snap_to_bg = residual_edge & (alpha < 0.5)
    result_cleaned[snap_to_bg] = bg_array[snap_to_bg]

    # Mostly-opaque residual pixels keep the blend's foreground color.
    logger.debug(f"Edge cleanup: {np.sum(residual_edge)} residual pixels cleaned")

    return result_cleaned
|
| 189 |
+
|
| 190 |
+
def _remove_background_color_contamination(
    self,
    image_array: np.ndarray,
    mask_array: np.ndarray,
    orig_bg_color_lab: np.ndarray,
    tolerance: float = 30.0
) -> np.ndarray:
    """Replace foreground pixels that still carry the original background color.

    Foreground pixels whose Lab color is within ``tolerance`` deltaE of
    the estimated original background color are treated as contaminated
    and filled in from their clean surroundings via inpainting.

    Args:
        image_array: Foreground image (uint8, RGB).
        mask_array: Mask (uint8, 0-255); pixels with value > 50 count as foreground.
        orig_bg_color_lab: Original background color in Lab space (3 floats).
        tolerance: DeltaE threshold for flagging a pixel as contaminated.

    Returns:
        Cleaned image array (uint8).
    """
    lab = cv2.cvtColor(image_array, cv2.COLOR_RGB2LAB).astype(np.float32)

    foreground_mask = mask_array > 50
    if not np.any(foreground_mask):
        return image_array

    # Per-pixel deltaE (CIE76) against the estimated background color.
    diff = lab - orig_bg_color_lab.reshape(1, 1, 3)
    delta_e = np.sqrt(np.sum(diff * diff, axis=2))

    contaminated = foreground_mask & (delta_e < tolerance)
    if not np.any(contaminated):
        logger.debug("No background color contamination detected in foreground")
        return image_array

    num_contaminated = np.sum(contaminated)
    logger.info(f"Found {num_contaminated} pixels with background color contamination")

    result = image_array.copy()
    inpaint_mask = contaminated.astype(np.uint8) * 255

    try:
        # Fill contaminated spots with colors propagated from their
        # clean foreground neighborhood.
        result = cv2.inpaint(result, inpaint_mask, inpaintRadius=3, flags=cv2.INPAINT_TELEA)
        logger.info(f"Inpainted {num_contaminated} contaminated pixels")
    except Exception as e:
        logger.warning(f"Inpainting failed: {e}, using median filter fallback")
        # Fallback: patch the flagged pixels from a median-filtered copy.
        median_filtered = cv2.medianBlur(image_array, 5)
        result[contaminated] = median_filtered[contaminated]

    return result
|
| 254 |
+
|
| 255 |
+
def _protect_foreground_core(
    self,
    result_array: np.ndarray,
    orig_array: np.ndarray,
    mask_array: np.ndarray,
    protection_threshold: int = 140
) -> np.ndarray:
    """Copy original pixels back wherever the mask is highly confident.

    High-confidence foreground pixels (faces, bodies) take the untouched
    source color with no blending at all, guaranteeing they are not
    tinted by the new background.

    Args:
        result_array: Current blended result (uint8, RGB).
        orig_array: Original foreground image (uint8, RGB).
        mask_array: Mask (uint8, 0-255).
        protection_threshold: Mask value at/above which a pixel is fully protected.

    Returns:
        Result array with the protected pixels restored (uint8).
    """
    protected_px = mask_array >= protection_threshold

    if not np.any(protected_px):
        return result_array

    # Broadcast the 2-D mask over the color channels: take the original
    # foreground wherever the mask is confident, the blend elsewhere.
    result_protected = np.where(protected_px[..., None], orig_array, result_array)

    num_protected = np.sum(protected_px)
    logger.info(f"Protected {num_protected} core foreground pixels from background influence")

    return result_protected
|
| 291 |
+
|
| 292 |
+
def multi_scale_edge_refinement(
    self,
    original_image: Image.Image,
    background_image: Image.Image,
    mask: Image.Image
) -> Image.Image:
    """
    Refine mask edges by blending masks from an image pyramid.

    The mask is evaluated at full, half, and quarter resolution; local
    gradient "complexity" decides which scale dominates at each pixel
    (high-complexity detail prefers the full-resolution mask, smooth
    regions prefer the smoother downscaled masks).

    Args:
        original_image: Foreground PIL Image (complexity is measured on it)
        background_image: Background PIL Image — accepted but not read by
            this implementation
        mask: Current mask PIL Image

    Returns:
        Refined mask PIL Image ('L' mode); on any failure the input mask
        is returned unchanged.
    """
    logger.info("🔍 Starting multi-scale edge refinement...")

    try:
        # Convert inputs to numpy; mask kept as float32 for weighted blending.
        orig_array = np.array(original_image.convert('RGB'))
        mask_array = np.array(mask).astype(np.float32)
        height, width = mask_array.shape

        # Pyramid levels: original, half, quarter resolution.
        scales = [1.0, 0.5, 0.25]  # Original, half, quarter
        scale_masks = []
        scale_complexities = []

        # Complexity is measured on the grayscale foreground.
        gray = cv2.cvtColor(orig_array, cv2.COLOR_RGB2GRAY)

        for scale in scales:
            if scale == 1.0:
                scaled_gray = gray
                scaled_mask = mask_array
            else:
                new_h = int(height * scale)
                new_w = int(width * scale)
                scaled_gray = cv2.resize(gray, (new_w, new_h), interpolation=cv2.INTER_LANCZOS4)
                scaled_mask = cv2.resize(mask_array, (new_w, new_h), interpolation=cv2.INTER_LANCZOS4)

            # Local complexity = box-averaged Sobel gradient magnitude.
            sobel_x = cv2.Sobel(scaled_gray, cv2.CV_64F, 1, 0, ksize=3)
            sobel_y = cv2.Sobel(scaled_gray, cv2.CV_64F, 0, 1, ksize=3)
            gradient_mag = np.sqrt(sobel_x**2 + sobel_y**2)

            # Average gradient over 5x5 neighborhoods.
            kernel_size = 5
            complexity = cv2.blur(gradient_mag, (kernel_size, kernel_size))

            # Bring downscaled mask/complexity back to full size so all
            # levels can be blended per-pixel.
            if scale != 1.0:
                scaled_mask = cv2.resize(scaled_mask, (width, height), interpolation=cv2.INTER_LANCZOS4)
                complexity = cv2.resize(complexity, (width, height), interpolation=cv2.INTER_LANCZOS4)

            scale_masks.append(scaled_mask)
            scale_complexities.append(complexity)

        # Per-pixel, per-scale weights:
        # high complexity -> trust the high-resolution mask,
        # low complexity -> trust the smoother low-resolution mask.
        weights = np.zeros((len(scales), height, width), dtype=np.float32)

        # Normalize complexities to [0, 1] across all scales (epsilon
        # guards against an all-zero gradient image).
        max_complexity = max(c.max() for c in scale_complexities) + 1e-6
        normalized_complexities = [c / max_complexity for c in scale_complexities]

        # Weight assignment per scale level.
        for i, complexity in enumerate(normalized_complexities):
            if i == 0:  # High resolution - preferred in high-complexity regions
                weights[i] = complexity
            elif i == 1:  # Medium resolution - moderate everywhere
                weights[i] = 0.5 * (1 - complexity) + 0.5 * complexity * 0.5
            else:  # Low resolution - preferred in low-complexity regions
                weights[i] = 1 - complexity

        # Normalize so the weights sum to 1 at every pixel.
        weight_sum = weights.sum(axis=0, keepdims=True) + 1e-6
        weights = weights / weight_sum

        # Weighted blend of the per-scale masks.
        refined_mask = np.zeros((height, width), dtype=np.float32)
        for i, mask_i in enumerate(scale_masks):
            refined_mask += weights[i] * mask_i

        # Clamp to valid 8-bit range and convert back to a PIL mask.
        refined_mask = np.clip(refined_mask, 0, 255).astype(np.uint8)

        logger.info("✅ Multi-scale edge refinement completed")
        return Image.fromarray(refined_mask, mode='L')

    except Exception as e:
        # Best-effort: refinement is optional, so fall back to the input mask.
        logger.error(f"❌ Multi-scale refinement failed: {e}, using original mask")
        return mask
|
| 389 |
+
|
| 390 |
+
def simple_blend_images(
|
| 391 |
+
self,
|
| 392 |
+
original_image: Image.Image,
|
| 393 |
+
background_image: Image.Image,
|
| 394 |
+
combination_mask: Image.Image,
|
| 395 |
+
use_multi_scale: Optional[bool] = None,
|
| 396 |
+
feather_radius: int = 0
|
| 397 |
+
) -> Image.Image:
|
| 398 |
+
"""
|
| 399 |
+
Aggressive spill suppression + color replacement: completely eliminate yellow edge residue, maintain sharp edges
|
| 400 |
+
|
| 401 |
+
Args:
|
| 402 |
+
original_image: Foreground PIL Image
|
| 403 |
+
background_image: Background PIL Image
|
| 404 |
+
combination_mask: Mask PIL Image (L mode)
|
| 405 |
+
use_multi_scale: Override for multi-scale refinement (None = use class default)
|
| 406 |
+
feather_radius: Gaussian blur radius for mask feathering (0 = disabled, default behavior)
|
| 407 |
+
|
| 408 |
+
Returns:
|
| 409 |
+
Blended PIL Image
|
| 410 |
+
"""
|
| 411 |
+
logger.info("🎨 Starting advanced image blending process...")
|
| 412 |
+
|
| 413 |
+
# Apply multi-scale edge refinement if enabled
|
| 414 |
+
should_use_multi_scale = use_multi_scale if use_multi_scale is not None else self.enable_multi_scale
|
| 415 |
+
if should_use_multi_scale:
|
| 416 |
+
combination_mask = self.multi_scale_edge_refinement(
|
| 417 |
+
original_image, background_image, combination_mask
|
| 418 |
+
)
|
| 419 |
+
|
| 420 |
+
# Convert to numpy arrays
|
| 421 |
+
orig_array = np.array(original_image, dtype=np.uint8)
|
| 422 |
+
bg_array = np.array(background_image, dtype=np.uint8)
|
| 423 |
+
mask_array = np.array(combination_mask, dtype=np.uint8)
|
| 424 |
+
|
| 425 |
+
# Apply feathering if requested
|
| 426 |
+
if feather_radius > 0:
|
| 427 |
+
kernel_size = feather_radius * 2 + 1
|
| 428 |
+
mask_array = cv2.GaussianBlur(
|
| 429 |
+
mask_array,
|
| 430 |
+
(kernel_size, kernel_size),
|
| 431 |
+
feather_radius / 2.0
|
| 432 |
+
)
|
| 433 |
+
logger.info(f"📐 Mask feathering applied: radius={feather_radius}, kernel={kernel_size}x{kernel_size}")
|
| 434 |
+
|
| 435 |
+
logger.info(f"📊 Image dimensions - Original: {orig_array.shape}, Background: {bg_array.shape}, Mask: {mask_array.shape}")
|
| 436 |
+
logger.info(f"📊 Mask statistics (before erosion) - Mean: {mask_array.mean():.1f}, Min: {mask_array.min()}, Max: {mask_array.max()}")
|
| 437 |
+
|
| 438 |
+
# === NEW: Apply mask erosion to remove contaminated edge pixels ===
|
| 439 |
+
mask_array = self._erode_mask_edges(mask_array, self.EDGE_EROSION_PIXELS)
|
| 440 |
+
logger.info(f"📊 Mask statistics (after erosion) - Mean: {mask_array.mean():.1f}, Min: {mask_array.min()}, Max: {mask_array.max()}")
|
| 441 |
+
|
| 442 |
+
# Enhanced parameters for better spill suppression
|
| 443 |
+
RING_WIDTH_PX = 4 # Increased ring width for better coverage
|
| 444 |
+
SPILL_STRENGTH = 0.85 # Stronger spill suppression
|
| 445 |
+
L_MATCH_STRENGTH = 0.65 # Stronger luminance matching
|
| 446 |
+
DELTAE_THRESHOLD = 18 # More aggressive contamination detection
|
| 447 |
+
HARD_EDGE_PROTECT = True # Black edge protection
|
| 448 |
+
INPAINT_FALLBACK = True # inpaint fallback repair
|
| 449 |
+
MULTI_PASS_CORRECTION = True # Enable multi-pass correction
|
| 450 |
+
|
| 451 |
+
# Estimate original background color and foreground representative color ===
|
| 452 |
+
height, width = orig_array.shape[:2]
|
| 453 |
+
|
| 454 |
+
# Take 15px from each side to estimate original background color
|
| 455 |
+
edge_width = 15
|
| 456 |
+
border_pixels = []
|
| 457 |
+
|
| 458 |
+
# Collect border pixels (excluding foreground areas)
|
| 459 |
+
border_mask = np.zeros((height, width), dtype=bool)
|
| 460 |
+
border_mask[:edge_width, :] = True # Top edge
|
| 461 |
+
border_mask[-edge_width:, :] = True # Bottom edge
|
| 462 |
+
border_mask[:, :edge_width] = True # Left edge
|
| 463 |
+
border_mask[:, -edge_width:] = True # Right edge
|
| 464 |
+
|
| 465 |
+
# Exclude foreground areas
|
| 466 |
+
fg_binary = mask_array > 50
|
| 467 |
+
border_mask = border_mask & (~fg_binary)
|
| 468 |
+
|
| 469 |
+
if np.any(border_mask):
|
| 470 |
+
border_pixels = orig_array[border_mask].reshape(-1, 3)
|
| 471 |
+
|
| 472 |
+
# Simplified background color estimation (no sklearn dependency)
|
| 473 |
+
try:
|
| 474 |
+
if len(border_pixels) > 100:
|
| 475 |
+
# Use histogram to find mode colors
|
| 476 |
+
# Quantize RGB to coarser grid to find main colors
|
| 477 |
+
quantized = (border_pixels // 32) * 32 # 8-level quantization
|
| 478 |
+
|
| 479 |
+
# Find most frequent color
|
| 480 |
+
unique_colors, counts = np.unique(quantized.reshape(-1, quantized.shape[-1]),
|
| 481 |
+
axis=0, return_counts=True)
|
| 482 |
+
most_common_idx = np.argmax(counts)
|
| 483 |
+
orig_bg_color_rgb = unique_colors[most_common_idx].astype(np.uint8)
|
| 484 |
+
else:
|
| 485 |
+
orig_bg_color_rgb = np.median(border_pixels, axis=0).astype(np.uint8)
|
| 486 |
+
except:
|
| 487 |
+
# Fallback: use four corners average
|
| 488 |
+
corners = np.array([orig_array[0,0], orig_array[0,-1],
|
| 489 |
+
orig_array[-1,0], orig_array[-1,-1]])
|
| 490 |
+
orig_bg_color_rgb = np.mean(corners, axis=0).astype(np.uint8)
|
| 491 |
+
else:
|
| 492 |
+
orig_bg_color_rgb = np.array([200, 180, 120], dtype=np.uint8) # Default yellow
|
| 493 |
+
|
| 494 |
+
# Convert to Lab space
|
| 495 |
+
orig_bg_color_lab = cv2.cvtColor(orig_bg_color_rgb.reshape(1,1,3), cv2.COLOR_RGB2LAB)[0,0].astype(np.float32)
|
| 496 |
+
logger.info(f"🎨 Detected original background color: RGB{tuple(orig_bg_color_rgb)}")
|
| 497 |
+
|
| 498 |
+
# Remove original background color contamination from foreground
|
| 499 |
+
orig_array = self._remove_background_color_contamination(
|
| 500 |
+
orig_array,
|
| 501 |
+
mask_array,
|
| 502 |
+
orig_bg_color_lab,
|
| 503 |
+
tolerance=self.BACKGROUND_COLOR_TOLERANCE
|
| 504 |
+
)
|
| 505 |
+
|
| 506 |
+
# Redefine trimap, optimized for cartoon characters
|
| 507 |
+
try:
|
| 508 |
+
kernel_3x3 = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
|
| 509 |
+
|
| 510 |
+
# FG_CORE: Reduce erosion iterations from 2 to 1 to avoid losing thin limbs
|
| 511 |
+
mask_eroded_once = cv2.erode(mask_array, kernel_3x3, iterations=1)
|
| 512 |
+
fg_core = mask_eroded_once > 127 # Adjustable parameter: erosion iterations
|
| 513 |
+
|
| 514 |
+
# RING: Use morphological gradient to redefine, ensuring only thin edge band
|
| 515 |
+
mask_dilated = cv2.dilate(mask_array, kernel_3x3, iterations=1)
|
| 516 |
+
mask_eroded = cv2.erode(mask_array, kernel_3x3, iterations=1)
|
| 517 |
+
|
| 518 |
+
# Ensure consistent data types to avoid overflow
|
| 519 |
+
morphological_gradient = cv2.subtract(mask_dilated, mask_eroded)
|
| 520 |
+
ring_zone = morphological_gradient > 0 # Areas with morphological gradient > 0 are edge bands
|
| 521 |
+
|
| 522 |
+
# BG: background area
|
| 523 |
+
bg_zone = mask_array < 30
|
| 524 |
+
|
| 525 |
+
logger.info(f"🔍 Trimap regions - FG_CORE: {fg_core.sum()}, RING: {ring_zone.sum()}, BG: {bg_zone.sum()}")
|
| 526 |
+
|
| 527 |
+
except Exception as e:
|
| 528 |
+
logger.error(f"❌ Trimap definition failed: {e}")
|
| 529 |
+
logger.error(f"📍 Traceback: {traceback.format_exc()}")
|
| 530 |
+
print(f"❌ TRIMAP ERROR: {e}")
|
| 531 |
+
print(f"Traceback: {traceback.format_exc()}")
|
| 532 |
+
# Fallback to simple definition
|
| 533 |
+
fg_core = mask_array > 200
|
| 534 |
+
ring_zone = (mask_array > 50) & (mask_array <= 200)
|
| 535 |
+
bg_zone = mask_array <= 50
|
| 536 |
+
|
| 537 |
+
# Foreground representative color: estimated from FG_CORE
|
| 538 |
+
if np.any(fg_core):
|
| 539 |
+
fg_pixels = orig_array[fg_core].reshape(-1, 3)
|
| 540 |
+
fg_rep_color_rgb = np.median(fg_pixels, axis=0).astype(np.uint8)
|
| 541 |
+
else:
|
| 542 |
+
fg_rep_color_rgb = np.array([80, 60, 40], dtype=np.uint8) # Default dark
|
| 543 |
+
|
| 544 |
+
fg_rep_color_lab = cv2.cvtColor(fg_rep_color_rgb.reshape(1,1,3), cv2.COLOR_RGB2LAB)[0,0].astype(np.float32)
|
| 545 |
+
|
| 546 |
+
# Edge band spill suppression and repair
|
| 547 |
+
if np.any(ring_zone):
|
| 548 |
+
# Convert to Lab space
|
| 549 |
+
orig_lab = cv2.cvtColor(orig_array, cv2.COLOR_RGB2LAB).astype(np.float32)
|
| 550 |
+
orig_array_working = orig_array.copy().astype(np.float32)
|
| 551 |
+
|
| 552 |
+
# ΔE detect contaminated pixels
|
| 553 |
+
ring_pixels_lab = orig_lab[ring_zone]
|
| 554 |
+
|
| 555 |
+
# Calculate ΔE with original background color (simplified version)
|
| 556 |
+
delta_l = ring_pixels_lab[:, 0] - orig_bg_color_lab[0]
|
| 557 |
+
delta_a = ring_pixels_lab[:, 1] - orig_bg_color_lab[1]
|
| 558 |
+
delta_b = ring_pixels_lab[:, 2] - orig_bg_color_lab[2]
|
| 559 |
+
delta_e = np.sqrt(delta_l**2 + delta_a**2 + delta_b**2)
|
| 560 |
+
|
| 561 |
+
# Contaminated pixel mask
|
| 562 |
+
contaminated_mask = delta_e < DELTAE_THRESHOLD
|
| 563 |
+
|
| 564 |
+
if np.any(contaminated_mask):
|
| 565 |
+
# Calculate adaptive strength based on delta_e for each pixel
|
| 566 |
+
# Pixels closer to background color get stronger correction
|
| 567 |
+
contaminated_delta_e = delta_e[contaminated_mask]
|
| 568 |
+
|
| 569 |
+
# Adaptive strength formula: inverse relationship with delta_e
|
| 570 |
+
# Pixels very close to bg color (low delta_e) -> strong correction
|
| 571 |
+
# Pixels further from bg color (high delta_e) -> lighter correction
|
| 572 |
+
adaptive_strength = SPILL_STRENGTH * np.maximum(
|
| 573 |
+
0.0,
|
| 574 |
+
1.0 - (contaminated_delta_e / DELTAE_THRESHOLD)
|
| 575 |
+
)
|
| 576 |
+
|
| 577 |
+
# Clamp adaptive strength to reasonable range (30% - 100% of base strength)
|
| 578 |
+
min_strength = SPILL_STRENGTH * 0.3
|
| 579 |
+
adaptive_strength = np.clip(adaptive_strength, min_strength, SPILL_STRENGTH)
|
| 580 |
+
|
| 581 |
+
# Store for debug visualization
|
| 582 |
+
self._adaptive_strength_map = np.zeros_like(delta_e)
|
| 583 |
+
self._adaptive_strength_map[contaminated_mask] = adaptive_strength
|
| 584 |
+
|
| 585 |
+
logger.info(f"📊 Adaptive strength stats - Mean: {adaptive_strength.mean():.3f}, Min: {adaptive_strength.min():.3f}, Max: {adaptive_strength.max():.3f}")
|
| 586 |
+
|
| 587 |
+
# Chroma vector deprojection
|
| 588 |
+
bg_chroma = np.array([orig_bg_color_lab[1], orig_bg_color_lab[2]])
|
| 589 |
+
bg_chroma_norm = bg_chroma / (np.linalg.norm(bg_chroma) + 1e-6)
|
| 590 |
+
|
| 591 |
+
# Color correction for contaminated pixels
|
| 592 |
+
contaminated_pixels = ring_pixels_lab[contaminated_mask]
|
| 593 |
+
|
| 594 |
+
# Remove background chroma component with adaptive strength (per-pixel)
|
| 595 |
+
pixel_chroma = contaminated_pixels[:, 1:3] # a, b channels
|
| 596 |
+
projection = np.dot(pixel_chroma, bg_chroma_norm)[:, np.newaxis] * bg_chroma_norm
|
| 597 |
+
|
| 598 |
+
# Apply adaptive strength per pixel
|
| 599 |
+
adaptive_strength_2d = adaptive_strength[:, np.newaxis]
|
| 600 |
+
corrected_chroma = pixel_chroma - projection * adaptive_strength_2d
|
| 601 |
+
|
| 602 |
+
# Converge toward foreground representative color with adaptive strength
|
| 603 |
+
convergence_factor = adaptive_strength_2d * 0.6
|
| 604 |
+
corrected_chroma = (corrected_chroma * (1 - convergence_factor) +
|
| 605 |
+
fg_rep_color_lab[1:3] * convergence_factor)
|
| 606 |
+
|
| 607 |
+
# Adaptive luminance matching
|
| 608 |
+
adaptive_l_strength = adaptive_strength * (L_MATCH_STRENGTH / SPILL_STRENGTH)
|
| 609 |
+
corrected_l = (contaminated_pixels[:, 0] * (1 - adaptive_l_strength) +
|
| 610 |
+
fg_rep_color_lab[0] * adaptive_l_strength)
|
| 611 |
+
|
| 612 |
+
# Update Lab values
|
| 613 |
+
ring_pixels_lab[contaminated_mask, 0] = corrected_l
|
| 614 |
+
ring_pixels_lab[contaminated_mask, 1:3] = corrected_chroma
|
| 615 |
+
|
| 616 |
+
# Write back to original image
|
| 617 |
+
orig_lab[ring_zone] = ring_pixels_lab
|
| 618 |
+
|
| 619 |
+
# Dark edge protection
|
| 620 |
+
if HARD_EDGE_PROTECT:
|
| 621 |
+
gray = np.mean(orig_array, axis=2)
|
| 622 |
+
# Detect dark and high gradient areas
|
| 623 |
+
sobel_x = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
|
| 624 |
+
sobel_y = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
|
| 625 |
+
gradient_mag = np.sqrt(sobel_x**2 + sobel_y**2)
|
| 626 |
+
|
| 627 |
+
dark_edge_zone = ring_zone & (gray < 60) & (gradient_mag > 20)
|
| 628 |
+
# Protect these areas from excessive modification, copy directly from original
|
| 629 |
+
if np.any(dark_edge_zone):
|
| 630 |
+
orig_lab[dark_edge_zone] = cv2.cvtColor(orig_array, cv2.COLOR_RGB2LAB)[dark_edge_zone]
|
| 631 |
+
|
| 632 |
+
# Multi-pass correction for stubborn spill
|
| 633 |
+
if MULTI_PASS_CORRECTION:
|
| 634 |
+
# Second pass for remaining contamination
|
| 635 |
+
ring_pixels_lab_pass2 = orig_lab[ring_zone]
|
| 636 |
+
delta_l_pass2 = ring_pixels_lab_pass2[:, 0] - orig_bg_color_lab[0]
|
| 637 |
+
delta_a_pass2 = ring_pixels_lab_pass2[:, 1] - orig_bg_color_lab[1]
|
| 638 |
+
delta_b_pass2 = ring_pixels_lab_pass2[:, 2] - orig_bg_color_lab[2]
|
| 639 |
+
delta_e_pass2 = np.sqrt(delta_l_pass2**2 + delta_a_pass2**2 + delta_b_pass2**2)
|
| 640 |
+
|
| 641 |
+
still_contaminated = delta_e_pass2 < (DELTAE_THRESHOLD * 0.8)
|
| 642 |
+
|
| 643 |
+
if np.any(still_contaminated):
|
| 644 |
+
# Apply stronger correction to remaining contaminated pixels
|
| 645 |
+
remaining_pixels = ring_pixels_lab_pass2[still_contaminated]
|
| 646 |
+
|
| 647 |
+
# More aggressive chroma neutralization
|
| 648 |
+
remaining_chroma = remaining_pixels[:, 1:3]
|
| 649 |
+
neutralized_chroma = remaining_chroma * 0.3 + fg_rep_color_lab[1:3] * 0.7
|
| 650 |
+
|
| 651 |
+
# Stronger luminance matching
|
| 652 |
+
neutralized_l = remaining_pixels[:, 0] * 0.4 + fg_rep_color_lab[0] * 0.6
|
| 653 |
+
|
| 654 |
+
ring_pixels_lab_pass2[still_contaminated, 0] = neutralized_l
|
| 655 |
+
ring_pixels_lab_pass2[still_contaminated, 1:3] = neutralized_chroma
|
| 656 |
+
orig_lab[ring_zone] = ring_pixels_lab_pass2
|
| 657 |
+
|
| 658 |
+
# Convert back to RGB
|
| 659 |
+
orig_lab_clipped = np.clip(orig_lab, 0, 255).astype(np.uint8)
|
| 660 |
+
orig_array_corrected = cv2.cvtColor(orig_lab_clipped, cv2.COLOR_LAB2RGB)
|
| 661 |
+
|
| 662 |
+
# inpaint fallback repair
|
| 663 |
+
if INPAINT_FALLBACK:
|
| 664 |
+
# inpaint still contaminated outermost pixels
|
| 665 |
+
final_contaminated = ring_zone.copy()
|
| 666 |
+
|
| 667 |
+
# Check if there's still contamination after repair
|
| 668 |
+
final_lab = cv2.cvtColor(orig_array_corrected, cv2.COLOR_RGB2LAB).astype(np.float32)
|
| 669 |
+
final_ring_lab = final_lab[ring_zone]
|
| 670 |
+
final_delta_l = final_ring_lab[:, 0] - orig_bg_color_lab[0]
|
| 671 |
+
final_delta_a = final_ring_lab[:, 1] - orig_bg_color_lab[1]
|
| 672 |
+
final_delta_b = final_ring_lab[:, 2] - orig_bg_color_lab[2]
|
| 673 |
+
final_delta_e = np.sqrt(final_delta_l**2 + final_delta_a**2 + final_delta_b**2)
|
| 674 |
+
|
| 675 |
+
still_contaminated = final_delta_e < (DELTAE_THRESHOLD * 0.5)
|
| 676 |
+
if np.any(still_contaminated):
|
| 677 |
+
# Create inpaint mask
|
| 678 |
+
inpaint_mask = np.zeros((height, width), dtype=np.uint8)
|
| 679 |
+
ring_coords = np.where(ring_zone)
|
| 680 |
+
inpaint_coords = (ring_coords[0][still_contaminated], ring_coords[1][still_contaminated])
|
| 681 |
+
inpaint_mask[inpaint_coords] = 255
|
| 682 |
+
|
| 683 |
+
# Execute inpaint
|
| 684 |
+
try:
|
| 685 |
+
orig_array_corrected = cv2.inpaint(orig_array_corrected, inpaint_mask, 3, cv2.INPAINT_TELEA)
|
| 686 |
+
except:
|
| 687 |
+
# Fallback: directly cover with foreground representative color
|
| 688 |
+
orig_array_corrected[inpaint_coords] = fg_rep_color_rgb
|
| 689 |
+
|
| 690 |
+
orig_array = orig_array_corrected
|
| 691 |
+
|
| 692 |
+
# === Linear space blending (keep original logic) ===
|
| 693 |
+
def srgb_to_linear(img):
|
| 694 |
+
img_f = img.astype(np.float32) / 255.0
|
| 695 |
+
return np.where(img_f <= 0.04045, img_f / 12.92, np.power((img_f + 0.055) / 1.055, 2.4))
|
| 696 |
+
|
| 697 |
+
def linear_to_srgb(img):
|
| 698 |
+
img_clipped = np.clip(img, 0, 1)
|
| 699 |
+
return np.where(img_clipped <= 0.0031308,
|
| 700 |
+
12.92 * img_clipped,
|
| 701 |
+
1.055 * np.power(img_clipped, 1/2.4) - 0.055)
|
| 702 |
+
|
| 703 |
+
orig_linear = srgb_to_linear(orig_array)
|
| 704 |
+
bg_linear = srgb_to_linear(bg_array)
|
| 705 |
+
|
| 706 |
+
# Cartoon-optimized Alpha calculation
|
| 707 |
+
alpha = mask_array.astype(np.float32) / 255.0
|
| 708 |
+
|
| 709 |
+
# Core foreground region - fully opaque
|
| 710 |
+
alpha[fg_core] = 1.0
|
| 711 |
+
|
| 712 |
+
# Background region - fully transparent
|
| 713 |
+
alpha[bg_zone] = 0.0
|
| 714 |
+
|
| 715 |
+
# [Key Fix] Force pixels with mask≥160 to α=1.0, avoiding white fill areas being limited to 0.9
|
| 716 |
+
high_confidence_pixels = mask_array >= 160
|
| 717 |
+
alpha[high_confidence_pixels] = 1.0
|
| 718 |
+
logger.info(f"💯 High confidence pixels set to full opacity: {high_confidence_pixels.sum()}")
|
| 719 |
+
|
| 720 |
+
# Ring area can be dehaloed, but doesn't affect already set high confidence pixels
|
| 721 |
+
ring_without_high_conf = ring_zone & (~high_confidence_pixels)
|
| 722 |
+
alpha[ring_without_high_conf] = np.clip(alpha[ring_without_high_conf], 0.2, 0.9)
|
| 723 |
+
|
| 724 |
+
# Retain existing black outline/strong edge protection
|
| 725 |
+
orig_gray = np.mean(orig_array, axis=2)
|
| 726 |
+
|
| 727 |
+
# Detect strong edge areas
|
| 728 |
+
sobel_x = cv2.Sobel(orig_gray, cv2.CV_64F, 1, 0, ksize=3)
|
| 729 |
+
sobel_y = cv2.Sobel(orig_gray, cv2.CV_64F, 0, 1, ksize=3)
|
| 730 |
+
gradient_mag = np.sqrt(sobel_x**2 + sobel_y**2)
|
| 731 |
+
|
| 732 |
+
# Black outline/strong edge protection: nearly fully opaque
|
| 733 |
+
black_edge_threshold = 60 # black edge threshold
|
| 734 |
+
gradient_threshold = 25 # gradient threshold
|
| 735 |
+
strong_edges = (orig_gray < black_edge_threshold) & (gradient_mag > gradient_threshold) & (mask_array > 10)
|
| 736 |
+
alpha[strong_edges] = np.maximum(alpha[strong_edges], 0.995) # black edge alpha
|
| 737 |
+
|
| 738 |
+
logger.info(f"🛡️ Protection applied - High conf: {high_confidence_pixels.sum()}, Strong edges: {strong_edges.sum()}")
|
| 739 |
+
|
| 740 |
+
# Apply edge alpha binarization to eliminate semi-transparent artifacts
|
| 741 |
+
alpha = self._binarize_edge_alpha(
|
| 742 |
+
alpha,
|
| 743 |
+
mask_array,
|
| 744 |
+
orig_array,
|
| 745 |
+
threshold=self.ALPHA_BINARIZE_THRESHOLD
|
| 746 |
+
)
|
| 747 |
+
|
| 748 |
+
# Final blending
|
| 749 |
+
alpha_3d = alpha[:, :, np.newaxis]
|
| 750 |
+
result_linear = orig_linear * alpha_3d + bg_linear * (1 - alpha_3d)
|
| 751 |
+
result_srgb = linear_to_srgb(result_linear)
|
| 752 |
+
result_array = (result_srgb * 255).astype(np.uint8)
|
| 753 |
+
|
| 754 |
+
# Final edge cleanup pass
|
| 755 |
+
result_array = self._apply_edge_cleanup(result_array, bg_array, alpha)
|
| 756 |
+
|
| 757 |
+
# Protect core foreground from any background influence
|
| 758 |
+
# This ensures faces and bodies retain original colors
|
| 759 |
+
result_array = self._protect_foreground_core(
|
| 760 |
+
result_array,
|
| 761 |
+
np.array(original_image, dtype=np.uint8), # Use original unprocessed image
|
| 762 |
+
mask_array,
|
| 763 |
+
protection_threshold=self.FOREGROUND_PROTECTION_THRESHOLD
|
| 764 |
+
)
|
| 765 |
+
|
| 766 |
+
# Store debug information (for debug output)
|
| 767 |
+
self._debug_info = {
|
| 768 |
+
'orig_bg_color_rgb': orig_bg_color_rgb,
|
| 769 |
+
'fg_rep_color_rgb': fg_rep_color_rgb,
|
| 770 |
+
'orig_bg_color_lab': orig_bg_color_lab,
|
| 771 |
+
'fg_rep_color_lab': fg_rep_color_lab,
|
| 772 |
+
'ring_zone': ring_zone,
|
| 773 |
+
'fg_core': fg_core,
|
| 774 |
+
'alpha_final': alpha
|
| 775 |
+
}
|
| 776 |
+
|
| 777 |
+
return Image.fromarray(result_array)
|
| 778 |
+
|
| 779 |
+
def create_debug_images(
    self,
    original_image: Image.Image,
    generated_background: Image.Image,
    combination_mask: Image.Image,
    combined_image: Image.Image
) -> Dict[str, Image.Image]:
    """
    Build a dictionary of diagnostic images for the last blend:
    (a) final mask as grayscale, (b) JET-colored alpha heatmap,
    (c) ring-zone overlay on the original, and — when available —
    (d) a viridis heatmap of the per-pixel adaptive correction strength.
    """
    out: Dict[str, Image.Image] = {}

    # (a) Final mask rendered as plain grayscale.
    out["mask_gray"] = combination_mask.convert('L')

    # (b) Alpha heatmap: map mask intensities through the JET colormap.
    gray_mask = np.array(combination_mask.convert('L'))
    jet_bgr = cv2.applyColorMap(gray_mask, cv2.COLORMAP_JET)
    out["alpha_heatmap"] = Image.fromarray(cv2.cvtColor(jet_bgr, cv2.COLOR_BGR2RGB))

    # (c) Ring overlay: tint ring pixels 30% toward red on top of the original.
    if hasattr(self, '_debug_info') and 'ring_zone' in self._debug_info:
        zone = self._debug_info['ring_zone']
        overlay = np.array(original_image).copy()
        overlay[zone] = overlay[zone] * 0.7 + np.array([255, 0, 0]) * 0.3
        out["ring_visualization"] = Image.fromarray(overlay.astype(np.uint8))
    else:
        # No ring data recorded — fall back to the untouched original.
        out["ring_visualization"] = original_image

    # (d) Adaptive strength heatmap, only when a previous blend stored the map.
    if hasattr(self, '_adaptive_strength_map') and self._adaptive_strength_map is not None:
        smap = self._adaptive_strength_map
        # Normalize to 0-255 before applying the colormap.
        if smap.max() > 0:
            scaled = (smap / smap.max() * 255).astype(np.uint8)
        else:
            scaled = np.zeros_like(smap, dtype=np.uint8)
        viridis_bgr = cv2.applyColorMap(scaled, cv2.COLORMAP_VIRIDIS)
        out["adaptive_strength_heatmap"] = Image.fromarray(
            cv2.cvtColor(viridis_bgr, cv2.COLOR_BGR2RGB)
        )

    return out
|
| 828 |
+
|
| 829 |
+
# INPAINTING-SPECIFIC BLENDING METHODS
|
| 830 |
+
def blend_inpainting(
    self,
    original: Image.Image,
    generated: Image.Image,
    mask: Image.Image,
    feather_radius: int = 8,
    apply_color_correction: bool = True
) -> Image.Image:
    """
    Blend inpainted region with original image.

    Specialized blending for inpainting that focuses on seamless integration
    rather than foreground protection. Performs blending in linear color space
    with optional adaptive color correction at boundaries.

    Parameters
    ----------
    original : PIL.Image
        Original image before inpainting
    generated : PIL.Image
        Generated/inpainted result from the model
    mask : PIL.Image
        Inpainting mask (white = inpainted area)
    feather_radius : int
        Feathering radius for smooth transitions
    apply_color_correction : bool
        Whether to apply adaptive color correction at boundaries

    Returns
    -------
    PIL.Image
        Blended result
    """
    logger.info(f"Inpainting blend: feather={feather_radius}, color_correction={apply_color_correction}")

    # Ensure same size: resample generated/mask to match the original.
    if generated.size != original.size:
        generated = generated.resize(original.size, Image.LANCZOS)
    if mask.size != original.size:
        mask = mask.resize(original.size, Image.LANCZOS)

    # Convert to float arrays; mask is normalized to 0-1.
    orig_array = np.array(original.convert('RGB')).astype(np.float32)
    gen_array = np.array(generated.convert('RGB')).astype(np.float32)
    mask_array = np.array(mask.convert('L')).astype(np.float32) / 255.0

    # Apply feathering to mask.
    # NOTE(review): INPAINT_FEATHER_SCALE is a class constant defined outside
    # this view — assumed to be a small positive float scale factor; confirm
    # on the class definition.
    if feather_radius > 0:
        scaled_radius = int(feather_radius * self.INPAINT_FEATHER_SCALE)
        kernel_size = scaled_radius * 2 + 1  # Gaussian kernel must be odd-sized
        mask_array = cv2.GaussianBlur(
            mask_array,
            (kernel_size, kernel_size),
            scaled_radius / 2
        )

    # Optionally shift the generated region's Lab statistics toward the
    # surrounding original context before blending.
    if apply_color_correction:
        gen_array = self._apply_inpaint_color_correction(
            orig_array, gen_array, mask_array
        )

    # sRGB to linear conversion for accurate (physically plausible) blending.
    def srgb_to_linear(img):
        img_norm = img / 255.0
        return np.where(
            img_norm <= 0.04045,
            img_norm / 12.92,
            np.power((img_norm + 0.055) / 1.055, 2.4)
        )

    def linear_to_srgb(img):
        img_clipped = np.clip(img, 0, 1)
        return np.where(
            img_clipped <= 0.0031308,
            12.92 * img_clipped,
            1.055 * np.power(img_clipped, 1/2.4) - 0.055
        )

    # Convert to linear space.
    orig_linear = srgb_to_linear(orig_array)
    gen_linear = srgb_to_linear(gen_array)

    # Alpha blending in linear space: mask=1 takes the generated pixel.
    alpha = mask_array[:, :, np.newaxis]
    result_linear = gen_linear * alpha + orig_linear * (1 - alpha)

    # Convert back to sRGB and quantize to 8-bit.
    result_srgb = linear_to_srgb(result_linear)
    result_array = (result_srgb * 255).astype(np.uint8)

    logger.debug("Inpainting blend completed in linear color space")

    return Image.fromarray(result_array)
|
| 924 |
+
|
| 925 |
+
def _apply_inpaint_color_correction(
    self,
    original: np.ndarray,
    generated: np.ndarray,
    mask: np.ndarray
) -> np.ndarray:
    """
    Apply adaptive color correction to match generated region with surroundings.

    Analyzes the boundary region and adjusts the generated content's
    luminance and color to better match the original context.

    Parameters
    ----------
    original : np.ndarray
        Original image (float32, 0-255)
    generated : np.ndarray
        Generated image (float32, 0-255)
    mask : np.ndarray
        Blend mask (float32, 0-1)

    Returns
    -------
    np.ndarray
        Color-corrected generated image
    """
    # Find boundary region: dilate the binary inpaint mask, then keep only
    # pixels that are near the mask but essentially outside it (mask < 0.3).
    # NOTE(review): INPAINT_COLOR_BLEND_RADIUS is a class constant defined
    # outside this view — assumed to be a small positive int; confirm on the
    # class definition.
    mask_binary = (mask > 0.5).astype(np.uint8)
    kernel = cv2.getStructuringElement(
        cv2.MORPH_ELLIPSE,
        (self.INPAINT_COLOR_BLEND_RADIUS * 2 + 1, self.INPAINT_COLOR_BLEND_RADIUS * 2 + 1)
    )
    dilated = cv2.dilate(mask_binary, kernel, iterations=1)
    boundary_zone = (dilated > 0) & (mask < 0.3)

    # No surrounding context available — nothing to match against.
    if not np.any(boundary_zone):
        return generated

    # Convert to Lab for perceptual color matching (OpenCV 8-bit Lab, 0-255).
    orig_lab = cv2.cvtColor(
        original.astype(np.uint8), cv2.COLOR_RGB2LAB
    ).astype(np.float32)
    gen_lab = cv2.cvtColor(
        generated.astype(np.uint8), cv2.COLOR_RGB2LAB
    ).astype(np.float32)

    # Calculate statistics in boundary zone (original).
    boundary_orig_l = orig_lab[boundary_zone, 0]
    boundary_orig_a = orig_lab[boundary_zone, 1]
    boundary_orig_b = orig_lab[boundary_zone, 2]

    # Medians are used (despite the *_mean_* names) for outlier robustness.
    orig_mean_l = np.median(boundary_orig_l)
    orig_mean_a = np.median(boundary_orig_a)
    orig_mean_b = np.median(boundary_orig_b)

    # Calculate statistics in generated inpaint region.
    inpaint_zone = mask > 0.5
    if not np.any(inpaint_zone):
        # Empty inpaint area — no correction meaningful.
        return generated

    gen_inpaint_l = gen_lab[inpaint_zone, 0]
    gen_inpaint_a = gen_lab[inpaint_zone, 1]
    gen_inpaint_b = gen_lab[inpaint_zone, 2]

    gen_mean_l = np.median(gen_inpaint_l)
    gen_mean_a = np.median(gen_inpaint_a)
    gen_mean_b = np.median(gen_inpaint_b)

    # Calculate correction deltas (boundary context minus generated content).
    delta_l = orig_mean_l - gen_mean_l
    delta_a = orig_mean_a - gen_mean_a
    delta_b = orig_mean_b - gen_mean_b

    # Limit correction to avoid over-adjustment; chroma (a/b) shifts are
    # capped at half the luminance cap.
    max_correction = 15
    delta_l = np.clip(delta_l, -max_correction, max_correction)
    delta_a = np.clip(delta_a, -max_correction * 0.5, max_correction * 0.5)
    delta_b = np.clip(delta_b, -max_correction * 0.5, max_correction * 0.5)

    logger.debug(f"Color correction deltas: L={delta_l:.1f}, a={delta_a:.1f}, b={delta_b:.1f}")

    # Apply correction with spatial falloff from boundary.
    # Create distance map from boundary (distance inside the inpaint mask).
    distance = cv2.distanceTransform(
        mask_binary, cv2.DIST_L2, 5
    )
    max_dist = np.max(distance)
    if max_dist > 0:
        # Correction strength falls off from boundary toward center;
        # it reaches zero at half the max interior depth.
        correction_strength = 1.0 - np.clip(distance / (max_dist * 0.5), 0, 1)
    else:
        correction_strength = np.ones_like(distance)

    # Apply correction to Lab channels.
    # Extra 0.7/0.5 factors further damp the (already clipped) deltas.
    corrected_lab = gen_lab.copy()
    corrected_lab[:, :, 0] += delta_l * correction_strength * 0.7
    corrected_lab[:, :, 1] += delta_a * correction_strength * 0.5
    corrected_lab[:, :, 2] += delta_b * correction_strength * 0.5

    # Clip to valid 8-bit Lab ranges before the uint8 round-trip.
    corrected_lab[:, :, 0] = np.clip(corrected_lab[:, :, 0], 0, 255)
    corrected_lab[:, :, 1] = np.clip(corrected_lab[:, :, 1], 0, 255)
    corrected_lab[:, :, 2] = np.clip(corrected_lab[:, :, 2], 0, 255)

    # Convert back to RGB (float32, 0-255, matching the input convention).
    corrected_rgb = cv2.cvtColor(
        corrected_lab.astype(np.uint8), cv2.COLOR_LAB2RGB
    ).astype(np.float32)

    return corrected_rgb
|
| 1035 |
+
|
| 1036 |
+
def blend_inpainting_with_guided_filter(
    self,
    original: Image.Image,
    generated: Image.Image,
    mask: Image.Image,
    feather_radius: int = 8,
    guide_radius: int = 8,
    guide_eps: float = 0.01
) -> Image.Image:
    """
    Blend an inpainted result into the original using an edge-aware mask.

    The blend mask is first feathered with a Gaussian, then refined with a
    guided filter driven by the original image's luminance, so transitions
    follow real edges. Falls back to the feathered mask when the guided
    filter (opencv-contrib) is unavailable.

    Parameters
    ----------
    original : PIL.Image
        Original image
    generated : PIL.Image
        Generated/inpainted result
    mask : PIL.Image
        Inpainting mask
    feather_radius : int
        Base feathering radius
    guide_radius : int
        Guided filter radius
    guide_eps : float
        Guided filter regularization

    Returns
    -------
    PIL.Image
        Blended result with edge-aware transitions
    """
    logger.info("Applying guided filter inpainting blend")

    # Align all three planes to the original's resolution.
    target_size = original.size
    if generated.size != target_size:
        generated = generated.resize(target_size, Image.LANCZOS)
    if mask.size != target_size:
        mask = mask.resize(target_size, Image.LANCZOS)

    base = np.array(original.convert('RGB')).astype(np.float32)
    patch = np.array(generated.convert('RGB')).astype(np.float32)
    weight = np.array(mask.convert('L')).astype(np.float32) / 255.0

    # Pre-soften the mask boundary before edge-aware refinement.
    if feather_radius > 0:
        k = feather_radius * 2 + 1
        weight_soft = cv2.GaussianBlur(
            weight,
            (k, k),
            feather_radius / 2
        )
    else:
        weight_soft = weight

    # The original image (grayscale, 0-1) drives the guided filter.
    guide = cv2.cvtColor(base.astype(np.uint8), cv2.COLOR_RGB2GRAY)
    guide = guide.astype(np.float32) / 255.0

    # Refine the mask; degrade gracefully when ximgproc is missing.
    try:
        weight_final = cv2.ximgproc.guidedFilter(
            guide=guide,
            src=weight_soft,
            radius=guide_radius,
            eps=guide_eps
        )
        logger.debug("Guided filter applied successfully")
    except Exception as e:
        logger.warning(f"Guided filter failed: {e}, using standard feathering")
        weight_final = weight_soft

    # Straightforward sRGB alpha blend with the refined weights.
    a = weight_final[:, :, np.newaxis]
    blended = patch * a + base * (1 - a)
    return Image.fromarray(np.clip(blended, 0, 255).astype(np.uint8))
|
mask_generator.py
ADDED
|
@@ -0,0 +1,648 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import numpy as np
|
| 3 |
+
import traceback
|
| 4 |
+
from PIL import Image, ImageFilter, ImageDraw
|
| 5 |
+
import logging
|
| 6 |
+
from typing import Optional, Tuple
|
| 7 |
+
from scipy.ndimage import binary_erosion, binary_dilation
|
| 8 |
+
import io
|
| 9 |
+
import gc
|
| 10 |
+
import torch
|
| 11 |
+
from transformers import AutoModelForImageSegmentation
|
| 12 |
+
from torchvision import transforms
|
| 13 |
+
from rembg import remove, new_session
|
| 14 |
+
|
| 15 |
+
logger = logging.getLogger(__name__)
|
| 16 |
+
logger.setLevel(logging.INFO)
|
| 17 |
+
|
| 18 |
+
class MaskGenerator:
|
| 19 |
+
"""
|
| 20 |
+
Intelligent mask generation using deep learning models with traditional fallback.
|
| 21 |
+
Priority: BiRefNet > U²-Net (rembg) > Traditional gradient-based methods
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
def __init__(self, max_image_size: int = 1024, device: str = "auto"):
    """Create a mask generator; BiRefNet weights are loaded lazily."""
    # Upper bound on processed image dimension.
    self.max_image_size = max_image_size
    # Resolve "auto" into a concrete backend ("cuda"/"mps"/"cpu").
    self.device = self._setup_device(device)
    # Model and its preprocessing transform are populated on first use.
    self._birefnet_model, self._birefnet_transform = None, None
    logger.info(f"🎭 MaskGenerator initialized on {self.device}")
|
| 34 |
+
|
| 35 |
+
def _setup_device(self, device: str) -> str:
|
| 36 |
+
"""Setup computation device"""
|
| 37 |
+
if device == "auto":
|
| 38 |
+
if torch.cuda.is_available():
|
| 39 |
+
return "cuda"
|
| 40 |
+
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
|
| 41 |
+
return "mps"
|
| 42 |
+
return "cpu"
|
| 43 |
+
return device
|
| 44 |
+
|
| 45 |
+
def _load_birefnet_model(self) -> bool:
    """
    Lazy load BiRefNet model for memory efficiency.
    Returns True if model loaded successfully, False otherwise.
    """
    # Already resident — nothing to do.
    if self._birefnet_model is not None:
        return True

    try:
        logger.info("📥 Loading BiRefNet model (ZhengPeng7/BiRefNet)...")

        # Load model with fp16 for memory efficiency on GPU
        dtype = torch.float16 if self.device == "cuda" else torch.float32

        # trust_remote_code is required: BiRefNet ships custom modeling code.
        # NOTE(review): `torch_dtype` is the legacy kwarg name in transformers;
        # newer releases prefer `dtype` — confirm against the pinned version.
        self._birefnet_model = AutoModelForImageSegmentation.from_pretrained(
            "ZhengPeng7/BiRefNet",
            trust_remote_code=True,
            torch_dtype=dtype
        )
        self._birefnet_model.to(self.device)
        self._birefnet_model.eval()

        # Define preprocessing transform: 1024x1024 + ImageNet mean/std
        # normalization matches BiRefNet's published inference pipeline.
        self._birefnet_transform = transforms.Compose([
            transforms.Resize((1024, 1024)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        logger.info("✅ BiRefNet model loaded successfully")
        return True

    except Exception as e:
        # Any failure (download, OOM, incompatible weights) degrades
        # gracefully: both slots are reset so callers can fall back to the
        # next segmentation backend.
        logger.error(f"❌ Failed to load BiRefNet: {e}")
        self._birefnet_model = None
        self._birefnet_transform = None
        return False
|
| 82 |
+
|
| 83 |
+
def _unload_birefnet_model(self):
    """Release the BiRefNet model and its transform, reclaiming memory."""
    # Guard clause: nothing loaded, nothing to free.
    if self._birefnet_model is None:
        return

    del self._birefnet_model
    self._birefnet_model = None
    self._birefnet_transform = None

    # Return cached CUDA allocations to the driver, then force a GC pass.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()
    logger.info("🧹 BiRefNet model unloaded")
|
| 94 |
+
|
| 95 |
+
def apply_guided_filter(
    self,
    mask: np.ndarray,
    guide_image: Image.Image,
    radius: int = 8,
    eps: float = 0.01
) -> np.ndarray:
    """
    Apply guided filter to mask for edge-preserving smoothing.
    Falls back to Gaussian blur if guided filter is not available.

    Args:
        mask: Input mask as numpy array (0-255)
        guide_image: Original image to use as guide
        radius: Filter radius (larger = more smoothing)
        eps: Regularization parameter (smaller = more edge-preserving)

    Returns:
        Filtered mask as numpy array (0-255)
    """
    try:
        # Convert guide image to grayscale, both planes normalized to 0-1.
        guide_gray = np.array(guide_image.convert('L')).astype(np.float32) / 255.0
        mask_float = mask.astype(np.float32) / 255.0

        logger.info(f"🔧 Applying guided filter (radius={radius}, eps={eps})")

        # Apply guided filter (cv2.ximgproc lives in opencv-contrib and may
        # be missing, in which case we fall through to the except branch).
        filtered = cv2.ximgproc.guidedFilter(
            guide=guide_gray,
            src=mask_float,
            radius=radius,
            eps=eps
        )

        # Convert back to 0-255 range.
        result = (np.clip(filtered, 0, 1) * 255).astype(np.uint8)
        logger.info("✅ Guided filter applied successfully")
        return result

    except Exception as e:
        # Bug fix: the docstring promises a Gaussian-blur fallback, but the
        # previous implementation returned the mask unfiltered. Apply the
        # documented edge-agnostic smoothing instead.
        logger.error(f"❌ Guided filter failed: {e}, falling back to Gaussian blur")
        try:
            kernel_size = radius * 2 + 1  # Gaussian kernel must be odd-sized
            blurred = cv2.GaussianBlur(
                mask.astype(np.float32),
                (kernel_size, kernel_size),
                radius / 2
            )
            return np.clip(blurred, 0, 255).astype(np.uint8)
        except Exception:
            # Last resort: hand back the unmodified mask.
            logger.error("❌ Gaussian fallback also failed, using original mask")
            return mask
|
| 138 |
+
|
| 139 |
+
def try_birefnet_mask(self, original_image: Image.Image) -> Optional[Image.Image]:
    """
    Generate foreground mask using BiRefNet model.
    BiRefNet provides high-quality segmentation with clean edges.

    The model is lazily loaded via ``_load_birefnet_model`` on first use and
    explicitly unloaded if the GPU runs out of memory, so the caller can fall
    back to cheaper segmentation methods.

    Args:
        original_image: Input PIL Image

    Returns:
        PIL Image (L mode) mask or None if failed (model unavailable, GPU OOM,
        or the produced mask fails the quality checks below)
    """
    try:
        # Lazy load model
        if not self._load_birefnet_model():
            return None

        logger.info("🤖 Starting BiRefNet foreground extraction...")
        original_size = original_image.size

        # Convert to RGB if needed
        if original_image.mode != 'RGB':
            image_rgb = original_image.convert('RGB')
        else:
            image_rgb = original_image

        # Preprocess image (transform presumably resizes to the model's input
        # resolution — output is resized back to original_size below)
        input_tensor = self._birefnet_transform(image_rgb).unsqueeze(0)

        # Move to device with appropriate dtype
        # NOTE(review): fp16 on CUDA assumes the model weights were loaded in
        # half precision by _load_birefnet_model — confirm there.
        if self.device == "cuda":
            input_tensor = input_tensor.to(self.device, dtype=torch.float16)
        else:
            input_tensor = input_tensor.to(self.device)

        # Run inference
        with torch.no_grad():
            outputs = self._birefnet_model(input_tensor)

        # BiRefNet outputs a list, get the final prediction
        if isinstance(outputs, (list, tuple)):
            pred = outputs[-1]
        else:
            pred = outputs

        # Sigmoid to get probability map
        pred = torch.sigmoid(pred)

        # Convert to numpy
        pred_np = pred.squeeze().cpu().numpy()

        # Convert to 0-255 range
        mask_array = (pred_np * 255).astype(np.uint8)

        # Resize back to original size
        mask_pil = Image.fromarray(mask_array, mode='L')
        mask_pil = mask_pil.resize(original_size, Image.LANCZOS)
        mask_array = np.array(mask_pil)

        # Quality check: mean brightness and fraction of pixels above a
        # "definitely foreground" threshold (50/255)
        mean_val = mask_array.mean()
        nonzero_ratio = np.count_nonzero(mask_array > 50) / mask_array.size

        logger.info(f"📊 BiRefNet mask stats - Mean: {mean_val:.1f}, Coverage: {nonzero_ratio:.1%}")

        # Reject near-black masks — segmentation effectively found nothing
        if mean_val < 10:
            logger.warning("⚠️ BiRefNet mask too weak, falling back")
            return None

        # Reject masks covering less than ~3% of the frame
        if nonzero_ratio < 0.03:
            logger.warning("⚠️ BiRefNet foreground coverage too low, falling back")
            return None

        # Light post-processing for edge refinement
        # Use morphological operations to clean up
        kernel_small = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
        mask_array = cv2.morphologyEx(mask_array, cv2.MORPH_CLOSE, kernel_small)

        logger.info("✅ BiRefNet mask generation successful!")
        return Image.fromarray(mask_array, mode='L')

    except torch.cuda.OutOfMemoryError:
        # Free the model so subsequent (smaller) work can still use the GPU
        logger.error("❌ BiRefNet: GPU memory exhausted")
        self._unload_birefnet_model()
        return None

    except Exception as e:
        logger.error(f"❌ BiRefNet mask generation failed: {e}")
        logger.error(f"📍 Traceback: {traceback.format_exc()}")
        return None
|
| 228 |
+
|
| 229 |
+
def try_deep_learning_mask(self, original_image: Image.Image) -> Optional[Image.Image]:
    """
    Intelligent foreground extraction with model priority:
    1. BiRefNet (best quality, clean edges)
    2. U²-Net via rembg (good fallback)
    3. Return None to trigger traditional methods

    Args:
        original_image: Input PIL Image

    Returns:
        PIL Image (L mode) mask or None if all methods failed
    """
    # Priority 1: Try BiRefNet first
    logger.info("🤖 Attempting BiRefNet mask generation...")
    birefnet_mask = self.try_birefnet_mask(original_image)
    if birefnet_mask is not None:
        logger.info("✅ Using BiRefNet generated mask")
        return birefnet_mask

    # Priority 2: Fallback to rembg (U²-Net)
    logger.info("🔄 BiRefNet unavailable/failed, trying rembg...")
    try:
        logger.info("🤖 Starting rembg foreground extraction")

        # Try u2net first (better for cartoons/objects like Snoopy)
        try:
            session = new_session('u2net')
            logger.info("✅ Using u2net model")
        except Exception as e:
            logger.warning(f"u2net failed ({e}), trying u2net_human_seg")
            try:
                session = new_session('u2net_human_seg')
                logger.info("✅ Using u2net_human_seg model")
            except Exception as e2:
                logger.error(f"All rembg models failed: {e2}")
                return None

        # Convert image to bytes for rembg
        img_byte_arr = io.BytesIO()
        original_image.save(img_byte_arr, format='PNG')
        img_byte_arr = img_byte_arr.getvalue()
        logger.info(f"📷 Image size: {len(img_byte_arr)} bytes")

        # Perform background removal; rembg returns PNG bytes with alpha
        result = remove(img_byte_arr, session=session)
        result_img = Image.open(io.BytesIO(result)).convert('RGBA')
        alpha_channel = result_img.split()[-1]
        alpha_array = np.array(alpha_channel)

        logger.info(f"📊 Raw alpha stats - Mean: {alpha_array.mean():.1f}, Min: {alpha_array.min()}, Max: {alpha_array.max()}")

        # Step 1: Light smoothing to reduce noise but preserve edges
        alpha_smoothed = cv2.GaussianBlur(alpha_array, (3, 3), 0.8)

        # Step 2: Contrast stretching to utilize full range
        alpha_stretched = cv2.normalize(alpha_smoothed, None, 0, 255, cv2.NORM_MINMAX)

        # Step 3: More aggressive foreground preservation.
        # Instead of a hard threshold, bucket pixels by confidence band.
        high_confidence = alpha_stretched > 180
        medium_confidence = (alpha_stretched > 60) & (alpha_stretched <= 180)
        low_confidence = (alpha_stretched > 15) & (alpha_stretched <= 60)

        # Create final mask with better extremity handling
        final_alpha = np.zeros_like(alpha_stretched)

        # High confidence areas - keep at full opacity
        final_alpha[high_confidence] = 255

        # Medium confidence - boost significantly
        final_alpha[medium_confidence] = np.clip(alpha_stretched[medium_confidence] * 1.8, 200, 255)

        # Low confidence - moderate boost (catches faint extremities)
        final_alpha[low_confidence] = np.clip(alpha_stretched[low_confidence] * 2.5, 120, 199)

        # Morphological operations to connect disconnected parts (hands, feet, tail)
        # FIX: removed unused kernel_medium (5x5) that was created but never applied
        kernel_small = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))

        # Close small gaps (helps connect separated body parts)
        final_alpha = cv2.morphologyEx(final_alpha, cv2.MORPH_CLOSE, kernel_small, iterations=1)

        # Light dilation to ensure nothing gets cut off
        final_alpha = cv2.dilate(final_alpha, kernel_small, iterations=1)

        logger.info(f"📊 Final alpha stats - Mean: {final_alpha.mean():.1f}, Min: {final_alpha.min()}, Max: {final_alpha.max()}")

        # Quality check - but be more lenient for cartoon characters
        if final_alpha.mean() < 10:
            logger.warning("⚠️ Alpha still too weak, falling back to traditional method")
            return None

        # Enhanced post-processing for cartoon characters
        is_cartoon = self._detect_cartoon_character(original_image, final_alpha)

        if is_cartoon:
            logger.info("🎭 Detected cartoon/character image, applying specialized processing")
            final_alpha = self._enhance_cartoon_mask(original_image, final_alpha)

        # Count non-zero pixels to ensure we have substantial foreground
        foreground_pixels = np.count_nonzero(final_alpha > 50)
        total_pixels = final_alpha.size
        foreground_ratio = foreground_pixels / total_pixels
        logger.info(f"📊 Foreground coverage: {foreground_ratio:.1%} of image")

        if foreground_ratio < 0.05:  # Less than 5% is probably too little
            logger.warning("⚠️ Very low foreground coverage, falling back to traditional method")
            return None

        mask = Image.fromarray(final_alpha.astype(np.uint8), mode='L')
        logger.info("✅ Enhanced rembg mask generation successful!")
        return mask

    except Exception as e:
        logger.error(f"❌ Deep learning mask extraction failed: {e}")
        return None
|
| 348 |
+
|
| 349 |
+
def _detect_cartoon_character(self, original_image: Image.Image, alpha_mask: np.ndarray) -> bool:
    """
    Detect if image is cartoon/line art (heuristic approach)

    Combines three cheap signals: Canny edge density, color-palette
    simplicity on a downscaled copy, and the fraction of near-black pixels
    (a proxy for inked outlines). Returns False on any processing error.
    """
    try:
        rgb = np.array(original_image.convert('RGB'))
        gray = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)

        # Signal 1: edge density — cartoons tend to have crisp, plentiful edges
        edge_map = cv2.Canny(gray, 50, 150)
        edge_density = np.count_nonzero(edge_map) / max(edge_map.size, 1)

        # Signal 2: palette simplicity; downscale large images first to bound
        # the cost of the unique-color count
        height_px, width_px, _ = rgb.shape
        sample = cv2.resize(rgb, (200, 200)) if height_px * width_px > 100000 else rgb
        unique_colors = len(np.unique(sample.reshape(-1, 3), axis=0))
        total_pixels = sample.shape[0] * sample.shape[1]
        color_simplicity = unique_colors < (total_pixels * 0.1)

        # Signal 3: near-black pixel ratio as an outline indicator
        dark_pixels_ratio = np.count_nonzero(gray < 50) / max(gray.size, 1)
        has_black_outline = dark_pixels_ratio > 0.05

        # Verdict: strong edges AND (few colors OR visible black outlines)
        is_cartoon = (edge_density > 0.05) and (color_simplicity or has_black_outline)

        logger.info(f"🔍 Cartoon detection - Edge density: {edge_density:.3f}, Color simplicity: {color_simplicity}, Black outline: {has_black_outline} -> Cartoon: {is_cartoon}")
        return is_cartoon

    except Exception as e:
        logger.error(f"❌ Cartoon detection failed: {e}")
        logger.error(f"📍 Traceback: {traceback.format_exc()}")
        print(f"❌ CARTOON DETECTION ERROR: {e}")
        print(f"Traceback: {traceback.format_exc()}")
        return False
|
| 388 |
+
|
| 389 |
+
def _enhance_cartoon_mask(self, original_image: Image.Image, alpha_mask: np.ndarray) -> np.ndarray:
    """
    Enhanced mask processing for cartoon characters.

    Two-stage refinement of an existing alpha mask:
    1. Force near-black outline pixels (plus a 1px dilation) to full opacity.
    2. Boost the alpha of medium-confidence pixels that lie inside the
       dilated high-confidence region, i.e. white fill areas enclosed by
       the character's outline.

    Args:
        alpha_mask: uint8 alpha mask (0-255) to refine; not modified in place.

    Returns:
        Refined uint8 alpha array, or the untouched input on any error.
    """
    try:
        img_array = np.array(original_image.convert('RGB'))
        gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
        # Work on a copy so the caller's array is never mutated
        enhanced_alpha = alpha_mask.copy()

        # Step 1: Black outline enhancement - find black outlines and enhance their alpha
        th_dark = 80  # Adjustable parameter: black threshold
        black_outline = gray < th_dark

        # Dilate black outline region by 1px
        kernel_dilate = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))  # Adjustable parameter: dilation kernel size
        black_outline_dilated = cv2.dilate(black_outline.astype(np.uint8), kernel_dilate, iterations=1)

        # Set black outline region alpha directly to 255
        enhanced_alpha[black_outline_dilated > 0] = 255
        logger.info(f"🖤 Black outline enhanced: {np.count_nonzero(black_outline_dilated)} pixels")

        # Step 2: Simplified internal enhancement - process white fill areas within outlines
        # Find high confidence regions (alpha ≥ 160)
        high_confidence = enhanced_alpha >= 160

        # Apply close operation on high confidence regions to connect separated parts
        kernel_close = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))  # Adjustable parameter: close kernel size
        high_confidence_closed = cv2.morphologyEx(high_confidence.astype(np.uint8), cv2.MORPH_CLOSE, kernel_close, iterations=1)

        # Simplified approach: directly enhance medium confidence regions without complex flood fill
        # Find medium/low confidence regions surrounded by high confidence regions
        medium_confidence = (enhanced_alpha >= 80) & (enhanced_alpha < 160)

        # Dilate high confidence region to include more internal areas
        kernel_dilate_internal = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
        high_confidence_expanded = cv2.dilate(high_confidence_closed, kernel_dilate_internal, iterations=1)

        # Medium confidence pixels within expanded high confidence areas are considered internal fill
        internal_fill_regions = medium_confidence & (high_confidence_expanded > 0)

        # Enhance alpha of these internal fill regions to at least 220
        # (np.maximum keeps pixels that are already brighter than the floor)
        min_alpha_for_fill = 220  # Adjustable parameter: minimum alpha for internal fill
        enhanced_alpha[internal_fill_regions] = np.maximum(enhanced_alpha[internal_fill_regions], min_alpha_for_fill)

        logger.info(f"🤍 Internal fill regions enhanced: {np.count_nonzero(internal_fill_regions)} pixels")
        logger.info(f"📊 Enhanced alpha stats - Mean: {enhanced_alpha.mean():.1f}, Min: {enhanced_alpha.min()}, Max: {enhanced_alpha.max()}")

        return enhanced_alpha

    except Exception as e:
        logger.error(f"❌ Cartoon mask enhancement failed: {e}")
        logger.error(f"📍 Traceback: {traceback.format_exc()}")
        print(f"❌ CARTOON MASK ENHANCEMENT ERROR: {e}")
        print(f"Traceback: {traceback.format_exc()}")
        return alpha_mask
|
| 444 |
+
|
| 445 |
+
def _adjust_mask_for_scene_focus(self, mask: Image.Image, original_image: Image.Image) -> Image.Image:
    """
    Adjust mask for scene focus mode to include nearby objects like chairs, furniture.

    Grows the subject mask outward, detects edges inside the newly grown
    ring, solidifies them into object regions, and unions those with the
    original mask. Returns the input mask unchanged on any error.
    """
    try:
        logger.info("🏠 Adjusting mask for scene focus mode...")

        mask_np = np.array(mask)
        rgb_np = np.array(original_image.convert('RGB'))

        # Grow the subject mask so nearby furniture/objects fall inside it
        grow_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15))
        grown = cv2.dilate(mask_np, grow_kernel, iterations=2)

        # Edge detection over the whole frame
        gray = cv2.cvtColor(rgb_np, cv2.COLOR_RGB2GRAY)
        edges = cv2.Canny(gray, 30, 100)

        # Keep only edges in the grown ring (outside the original mask)
        ring = (grown > 0) & (mask_np == 0)
        ring_edges = np.zeros_like(edges)
        ring_edges[ring] = edges[ring]

        # Close gaps so edge fragments become solid object regions
        close_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
        objects_mask = cv2.morphologyEx(ring_edges, cv2.MORPH_CLOSE, close_kernel)
        objects_mask = cv2.dilate(objects_mask, close_kernel, iterations=1)

        # Union of the subject mask and the detected object regions
        combined = np.maximum(mask_np, objects_mask)

        logger.info("✅ Scene focus adjustment completed")
        return Image.fromarray(combined)

    except Exception as e:
        logger.error(f"❌ Scene focus adjustment failed: {e}")
        return mask
|
| 483 |
+
|
| 484 |
+
def create_gradient_based_mask(self, original_image: Image.Image, mode: str = "center", focus_mode: str = "person") -> Image.Image:
    """
    Intelligent foreground extraction: prioritize deep learning models, fallback to traditional methods.

    Args:
        original_image: Input PIL Image.
        mode: Mask layout — "center" (subject extraction), "left_half",
            "right_half", or "full" (small feathered central spot).
        focus_mode: 'person' for tight crop around person, 'scene' for
            including nearby objects (only affects "center" mode).

    Returns:
        PIL Image (L mode) mask, 255 = keep foreground.
    """
    width, height = original_image.size
    logger.info(f"🎯 Creating mask for {width}x{height} image, mode: {mode}, focus: {focus_mode}")

    if mode == "center":
        # Try using deep learning models for intelligent foreground extraction
        logger.info("🤖 Attempting deep learning mask generation...")
        dl_mask = self.try_deep_learning_mask(original_image)
        if dl_mask is not None:
            logger.info("✅ Using deep learning generated mask")
            # Apply focus mode adjustments to deep learning mask
            if focus_mode == "scene":
                dl_mask = self._adjust_mask_for_scene_focus(dl_mask, original_image)
            return dl_mask

        # Fallback to traditional method
        logger.info("🔄 Deep learning failed, using traditional gradient-based method")
        img_array = np.array(original_image.convert('RGB'))
        gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)

        # First-order derivatives: use Sobel operator for edge detection
        grad_x = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
        grad_y = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
        gradient_magnitude = np.sqrt(grad_x**2 + grad_y**2)

        # Second-order derivatives: use Laplacian operator for texture change detection
        laplacian = cv2.Laplacian(gray, cv2.CV_64F, ksize=3)
        laplacian_abs = np.abs(laplacian)

        # Combine first and second order derivatives
        combined_edges = gradient_magnitude * 0.7 + laplacian_abs * 0.3
        # FIX: guard against division by zero on a perfectly flat image
        # (np.max == 0 would previously produce NaNs and a corrupt mask)
        edge_peak = np.max(combined_edges)
        if edge_peak > 0:
            combined_edges = (combined_edges / edge_peak) * 255
        else:
            combined_edges = np.zeros_like(combined_edges)

        # Threshold processing to find strong edges
        _, edge_binary = cv2.threshold(combined_edges.astype(np.uint8), 20, 255, cv2.THRESH_BINARY)

        # Morphological operations to connect edges
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
        edge_binary = cv2.morphologyEx(edge_binary, cv2.MORPH_CLOSE, kernel)

        # Find contours and create mask
        contours, _ = cv2.findContours(edge_binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        if contours:
            # Find largest contour (main subject)
            largest_contour = max(contours, key=cv2.contourArea)
            contour_mask = np.zeros((height, width), dtype=np.uint8)
            cv2.fillPoly(contour_mask, [largest_contour], 255)

            # Create foreground enhancement mask: specially protect dark regions
            dark_mask = (gray < 90).astype(np.uint8) * 255
            morph_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
            dark_mask = cv2.morphologyEx(dark_mask, cv2.MORPH_CLOSE, morph_kernel, iterations=1)
            dark_mask = cv2.dilate(dark_mask, morph_kernel, iterations=2)
            contour_mask = cv2.bitwise_or(contour_mask, dark_mask)

            # Get core foreground: clean holes and fill gaps
            close_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
            core_mask = cv2.morphologyEx(contour_mask, cv2.MORPH_CLOSE, close_kernel, iterations=1)

            open_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
            core_mask = cv2.morphologyEx(core_mask, cv2.MORPH_OPEN, open_kernel, iterations=1)

            # Convert to binary core (0/255)
            _, core_binary = cv2.threshold(core_mask, 127, 255, cv2.THRESH_BINARY)

            # Keep only slight dilation to avoid foreground being eaten
            dilate_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
            core_binary = cv2.dilate(core_binary, dilate_kernel, iterations=1)

            # Distance transform feathering: shrink feathering range for sharp edges
            FEATHER_PX = 4

            # Calculate distance (in px) from each background pixel to the core
            core_float = core_binary.astype(np.float32) / 255.0
            distances = cv2.distanceTransform((1 - core_float).astype(np.uint8), cv2.DIST_L2, 5)

            # Create feathering mask: 0→FEATHER_PX linear mapping to 1→0
            feather_mask = np.ones_like(distances)
            edge_region = (distances > 0) & (distances <= FEATHER_PX)
            feather_mask[edge_region] = 1.0 - (distances[edge_region] / FEATHER_PX)
            feather_mask[distances > FEATHER_PX] = 0.0

            # Apply double-smoothstep curve: make transition steeper, reduce semi-transparent halos
            def double_smoothstep(t):
                t = np.clip(t, 0, 1)
                s1 = t * t * (3 - 2 * t)
                return s1 * s1 * (3 - 2 * s1)  # Equivalent to t^3 (10 - 15t + 6t^2)

            # Combine core with feathering: core area keeps 255, edges use double_smoothstep feathering
            final_alpha = np.zeros_like(distances)
            final_alpha[core_binary > 127] = 1.0  # Core area
            final_alpha[edge_region] = double_smoothstep(feather_mask[edge_region])  # Feathering area

            # Convert to 0-255 range
            final_mask = (final_alpha * 255).astype(np.uint8)

            # Apply guided filter for edge-preserving smoothing
            final_mask = self.apply_guided_filter(final_mask, original_image, radius=8, eps=0.01)

            mask = Image.fromarray(final_mask)
        else:
            # Backup plan: use large ellipse
            mask = Image.new('L', (width, height), 0)
            draw = ImageDraw.Draw(mask)
            center_x, center_y = width // 2, height // 2
            width_radius = int(width * 0.45)
            # NOTE(review): height_radius is derived from width, not height —
            # looks like a copy-paste slip, but may be intentional to keep the
            # ellipse near-circular; confirm before changing.
            height_radius = int(width * 0.48)
            draw.ellipse([
                center_x - width_radius, center_y - height_radius,
                center_x + width_radius, center_y + height_radius
            ], fill=255)
            # Apply guided filter instead of Gaussian blur
            mask_array = np.array(mask)
            mask_array = self.apply_guided_filter(mask_array, original_image, radius=10, eps=0.02)
            mask = Image.fromarray(mask_array)

    elif mode == "left_half":
        # Keep original logic unchanged - ensure Snoopy and other functions work normally
        mask = Image.new('L', (width, height), 0)
        mask_array = np.array(mask)
        mask_array[:, :width//2] = 255

        # Linear fade over a 10%-of-width transition band to the right of center
        transition_zone = width // 10
        for i in range(transition_zone):
            x_pos = width//2 + i
            if x_pos < width:
                alpha = 255 * (1 - i / transition_zone)
                mask_array[:, x_pos] = int(alpha)

        mask = Image.fromarray(mask_array)

    elif mode == "right_half":
        # Keep original logic unchanged - ensure Snoopy and other functions work normally
        mask = Image.new('L', (width, height), 0)
        mask_array = np.array(mask)
        mask_array[:, width//2:] = 255

        # Linear fade over a 10%-of-width transition band to the left of center
        transition_zone = width // 10
        for i in range(transition_zone):
            x_pos = width//2 - i - 1
            if x_pos >= 0:
                alpha = 255 * (1 - i / transition_zone)
                mask_array[:, x_pos] = int(alpha)

        mask = Image.fromarray(mask_array)

    elif mode == "full":
        # Small blurred spot in the image center
        mask = Image.new('L', (width, height), 0)
        draw = ImageDraw.Draw(mask)
        center_x, center_y = width // 2, height // 2
        radius = min(width, height) // 8

        draw.ellipse([
            center_x - radius, center_y - radius,
            center_x + radius, center_y + radius
        ], fill=255)

        mask = mask.filter(ImageFilter.GaussianBlur(radius=5))

    return mask
|
requirements.txt
CHANGED
|
@@ -1,9 +1,8 @@
|
|
| 1 |
-
|
| 2 |
git+https://github.com/linoytsaban/diffusers.git@wan22-loras
|
| 3 |
-
|
| 4 |
gradio
|
| 5 |
-
transformers
|
| 6 |
-
accelerate
|
| 7 |
safetensors
|
| 8 |
sentencepiece
|
| 9 |
peft
|
|
@@ -12,4 +11,15 @@ imageio-ffmpeg
|
|
| 12 |
opencv-python
|
| 13 |
pillow
|
| 14 |
spaces
|
| 15 |
-
torchao
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# VividFlow I2V Dependencies
|
| 2 |
git+https://github.com/linoytsaban/diffusers.git@wan22-loras
|
|
|
|
| 3 |
gradio
|
| 4 |
+
transformers>=4.46.0
|
| 5 |
+
accelerate>=1.1.1
|
| 6 |
safetensors
|
| 7 |
sentencepiece
|
| 8 |
peft
|
|
|
|
| 11 |
opencv-python
|
| 12 |
pillow
|
| 13 |
spaces
|
| 14 |
+
torchao
|
| 15 |
+
|
| 16 |
+
# Background Generation Dependencies (SceneWeaver)
|
| 17 |
+
open_clip_torch
|
| 18 |
+
sentence-transformers
|
| 19 |
+
rembg[gpu]
|
| 20 |
+
scipy
|
| 21 |
+
opencv-contrib-python
|
| 22 |
+
|
| 23 |
+
# Core Dependencies
|
| 24 |
+
torch>=2.5.0
|
| 25 |
+
numpy
|
scene_templates.py
ADDED
|
@@ -0,0 +1,428 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from typing import Dict, List, Optional
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
|
| 5 |
+
logger = logging.getLogger(__name__)
|
| 6 |
+
|
| 7 |
+
@dataclass
class SceneTemplate:
    """Data class representing a scene template"""
    # Unique identifier; also used as the lookup key in SceneTemplateManager.TEMPLATES
    key: str
    # Human-readable display name
    name: str
    # Positive prompt text describing the background scene
    prompt: str
    # Extra terms appended to the negative prompt for this scene
    negative_extra: str
    # Grouping category (e.g. "Professional") used to organize templates
    category: str
    # Emoji displayed alongside the template name
    icon: str
    # Recommended guidance scale for this scene — presumably passed to the
    # diffusion generator; confirm against the consumer of this field
    guidance_scale: float = 7.5
|
| 18 |
+
|
| 19 |
+
class SceneTemplateManager:
    """
    Manages curated scene templates for background generation.
    Provides categorized presets that users can select with one click.
    """

    # Scene template definitions
    # NOTE: insertion order matters — the gallery renders templates in the
    # order they appear here (within each category), so do not reorder casually.
    TEMPLATES: Dict[str, SceneTemplate] = {
        # Professional Category
        "office_modern": SceneTemplate(
            key="office_modern",
            name="Modern Office",
            prompt="modern minimalist office interior, clean white desk, large floor-to-ceiling windows, natural daylight, professional corporate environment, soft shadows, contemporary furniture",
            negative_extra="messy, cluttered, dark, old",
            category="Professional",
            icon="🏢",
            guidance_scale=7.5
        ),
        "office_executive": SceneTemplate(
            key="office_executive",
            name="Executive Suite",
            prompt="luxurious executive office, mahogany desk, leather chair, city skyline view through windows, warm ambient lighting, bookshelf, elegant professional setting",
            negative_extra="cheap, cramped, messy",
            category="Professional",
            icon="👔",
            guidance_scale=7.5
        ),
        "studio_white": SceneTemplate(
            key="studio_white",
            name="White Studio",
            prompt="clean white photography studio background, professional lighting setup, seamless white backdrop, soft diffused light, minimal shadows",
            negative_extra="colored, textured, dirty",
            category="Professional",
            icon="📷",
            guidance_scale=8.0
        ),
        "coworking": SceneTemplate(
            key="coworking",
            name="Coworking Space",
            prompt="modern coworking space, open plan office, plants, exposed brick, industrial chic design, natural light, collaborative environment",
            negative_extra="empty, dark, boring",
            category="Professional",
            icon="💼",
            guidance_scale=7.0
        ),
        "conference": SceneTemplate(
            key="conference",
            name="Conference Room",
            prompt="modern conference room, large meeting table, glass walls, professional presentation screen, bright corporate lighting, clean minimal design",
            negative_extra="small, cramped, outdated",
            category="Professional",
            icon="🤝",
            guidance_scale=7.5
        ),

        # Nature Category
        "beach_sunset": SceneTemplate(
            key="beach_sunset",
            name="Sunset Beach",
            prompt="beautiful tropical beach at golden hour sunset, palm trees silhouette, calm turquoise ocean waves, warm orange and pink sky, soft sand, paradise vacation vibes",
            negative_extra="storm, rain, crowded, trash",
            category="Nature",
            icon="🏖️",
            guidance_scale=7.0
        ),
        "forest_enchanted": SceneTemplate(
            key="forest_enchanted",
            name="Enchanted Forest",
            prompt="magical enchanted forest, sunlight streaming through tall trees, lush green foliage, mystical atmosphere, morning mist, fairy tale woodland",
            negative_extra="dead trees, dark, scary, barren",
            category="Nature",
            icon="🌲",
            guidance_scale=7.0
        ),
        "mountain_scenic": SceneTemplate(
            key="mountain_scenic",
            name="Mountain Vista",
            prompt="breathtaking mountain landscape, snow-capped peaks, alpine meadow, clear blue sky, majestic scenic view, pristine nature, peaceful atmosphere",
            negative_extra="industrial, polluted, crowded",
            category="Nature",
            icon="🏔️",
            guidance_scale=7.5
        ),
        "garden_spring": SceneTemplate(
            key="garden_spring",
            name="Spring Garden",
            prompt="beautiful spring flower garden, colorful blooming flowers, roses and tulips, manicured hedges, sunny day, botanical paradise, fresh and vibrant",
            negative_extra="dead, winter, wilted, dry",
            category="Nature",
            icon="🌸",
            guidance_scale=7.0
        ),
        "lake_serene": SceneTemplate(
            key="lake_serene",
            name="Serene Lake",
            prompt="peaceful serene lake at dawn, mirror-like water reflection, surrounding mountains, soft morning light, tranquil atmosphere, pristine natural beauty",
            negative_extra="stormy, polluted, industrial",
            category="Nature",
            icon="🏞️",
            guidance_scale=7.0
        ),
        # NOTE(review): shares the 🌸 icon with "garden_spring"; display lookup
        # by "icon name" still works because the names differ — confirm intended.
        "cherry_blossom": SceneTemplate(
            key="cherry_blossom",
            name="Cherry Blossom",
            prompt="stunning cherry blossom trees in full bloom, pink sakura petals falling gently, Japanese garden aesthetic, soft spring sunlight, romantic atmosphere",
            negative_extra="winter, dead, brown, wilted",
            category="Nature",
            icon="🌸",
            guidance_scale=7.0
        ),

        # Urban Category
        "city_skyline": SceneTemplate(
            key="city_skyline",
            name="City Skyline",
            prompt="modern city skyline at blue hour, impressive skyscrapers, glass buildings reflecting sunset, urban metropolitan view, cinematic atmosphere",
            negative_extra="slums, dirty, abandoned, ruins",
            category="Urban",
            icon="🌆",
            guidance_scale=7.5
        ),
        "cafe_cozy": SceneTemplate(
            key="cafe_cozy",
            name="Cozy Cafe",
            prompt="warm cozy coffee shop interior, wooden furniture, ambient lighting, exposed brick walls, plants, comfortable atmosphere, artisan cafe vibes",
            negative_extra="fast food, plastic, harsh lighting",
            category="Urban",
            icon="☕",
            guidance_scale=7.0
        ),
        "street_european": SceneTemplate(
            key="street_european",
            name="European Street",
            prompt="charming European cobblestone street, historic buildings, outdoor cafe, flowers on balconies, warm afternoon light, romantic Paris or Rome vibes",
            negative_extra="modern, industrial, ugly, dirty",
            category="Urban",
            icon="🏛️",
            guidance_scale=7.0
        ),
        "night_neon": SceneTemplate(
            key="night_neon",
            name="Neon Nightlife",
            prompt="vibrant city nightlife scene, neon lights and signs, urban night atmosphere, colorful reflections on wet street, cyberpunk aesthetic, electric energy",
            negative_extra="daytime, boring, plain",
            category="Urban",
            icon="🌃",
            guidance_scale=6.5
        ),
        "rooftop_view": SceneTemplate(
            key="rooftop_view",
            name="Rooftop Terrace",
            prompt="luxury rooftop terrace, city panoramic view, modern outdoor furniture, string lights, sunset golden hour, sophisticated urban oasis",
            negative_extra="cheap, dirty, crowded",
            category="Urban",
            icon="🏙️",
            guidance_scale=7.5
        ),

        # Artistic Category
        "gradient_soft": SceneTemplate(
            key="gradient_soft",
            name="Soft Gradient",
            prompt="smooth soft gradient background, pastel colors blending beautifully, pink to blue to purple transition, dreamy aesthetic, professional portrait backdrop",
            negative_extra="harsh, noisy, textured, busy",
            category="Artistic",
            icon="🎨",
            guidance_scale=8.0
        ),
        "abstract_modern": SceneTemplate(
            key="abstract_modern",
            name="Modern Abstract",
            prompt="modern abstract art background, geometric shapes, bold colors, contemporary design, artistic composition, museum gallery aesthetic",
            negative_extra="realistic, plain, boring",
            category="Artistic",
            icon="🖼️",
            guidance_scale=6.5
        ),
        "vintage_retro": SceneTemplate(
            key="vintage_retro",
            name="Vintage Retro",
            prompt="vintage retro aesthetic background, warm sepia tones, nostalgic 70s vibes, film grain texture, classic photography style, timeless elegance",
            negative_extra="modern, digital, cold, harsh",
            category="Artistic",
            icon="📻",
            guidance_scale=7.0
        ),
        "watercolor_dream": SceneTemplate(
            key="watercolor_dream",
            name="Watercolor Dream",
            prompt="beautiful watercolor painting background, soft flowing colors, artistic brush strokes, dreamy ethereal atmosphere, delicate artistic aesthetic",
            negative_extra="digital, sharp, photorealistic",
            category="Artistic",
            icon="🖌️",
            guidance_scale=6.5
        ),

        # Seasonal Category
        "autumn_foliage": SceneTemplate(
            key="autumn_foliage",
            name="Autumn Foliage",
            prompt="beautiful autumn scenery, vibrant fall foliage, orange red and golden leaves, maple trees, warm sunlight filtering through, cozy seasonal atmosphere",
            negative_extra="spring, summer, green, snow",
            category="Seasonal",
            icon="🍂",
            guidance_scale=7.0
        ),
        "winter_snow": SceneTemplate(
            key="winter_snow",
            name="Winter Wonderland",
            prompt="magical winter wonderland, fresh white snow covering everything, snow-laden pine trees, soft snowfall, peaceful cold atmosphere, holiday season vibes",
            negative_extra="summer, green, rain, mud",
            category="Seasonal",
            icon="❄️",
            guidance_scale=7.0
        ),
        "summer_tropical": SceneTemplate(
            key="summer_tropical",
            name="Tropical Summer",
            prompt="vibrant tropical summer scene, lush palm trees, bright sunny day, exotic flowers, paradise vacation destination, warm and inviting atmosphere",
            negative_extra="winter, cold, snow, gray",
            category="Seasonal",
            icon="🌴",
            guidance_scale=7.0
        ),
        "spring_meadow": SceneTemplate(
            key="spring_meadow",
            name="Spring Meadow",
            prompt="beautiful spring meadow, wildflowers blooming, fresh green grass, butterflies, soft warm sunlight, renewal and new beginnings, pastoral beauty",
            negative_extra="winter, autumn, dead, dry",
            category="Seasonal",
            icon="🌷",
            guidance_scale=7.0
        ),
    }

    # Category display order
    CATEGORIES = ["Professional", "Nature", "Urban", "Artistic", "Seasonal"]
|
| 256 |
+
|
| 257 |
+
def __init__(self):
|
| 258 |
+
"""Initialize the scene template manager"""
|
| 259 |
+
logger.info(f"SceneTemplateManager initialized with {len(self.TEMPLATES)} templates")
|
| 260 |
+
|
| 261 |
+
    def get_all_templates(self) -> Dict[str, SceneTemplate]:
        """Get all available templates"""
        # Returns the shared class-level dict (not a copy); callers must not mutate it.
        return self.TEMPLATES
|
| 264 |
+
|
| 265 |
+
def get_template(self, key: str) -> Optional[SceneTemplate]:
|
| 266 |
+
"""Get a specific template by key"""
|
| 267 |
+
return self.TEMPLATES.get(key)
|
| 268 |
+
|
| 269 |
+
def get_templates_by_category(self, category: str) -> List[SceneTemplate]:
|
| 270 |
+
"""Get all templates in a specific category"""
|
| 271 |
+
return [t for t in self.TEMPLATES.values() if t.category == category]
|
| 272 |
+
|
| 273 |
+
    def get_categories(self) -> List[str]:
        """Get list of all categories in display order"""
        # CATEGORIES is the canonical rendering order for the gallery.
        return self.CATEGORIES
|
| 276 |
+
|
| 277 |
+
def get_template_choices_sorted(self) -> List[str]:
|
| 278 |
+
"""
|
| 279 |
+
Get template choices formatted for Gradio dropdown.
|
| 280 |
+
Returns list of display strings sorted A-Z: "🏢 Modern Office"
|
| 281 |
+
"""
|
| 282 |
+
display_list = []
|
| 283 |
+
for key, template in self.TEMPLATES.items():
|
| 284 |
+
display_name = f"{template.icon} {template.name}"
|
| 285 |
+
display_list.append(display_name)
|
| 286 |
+
|
| 287 |
+
# Sort alphabetically by name (ignoring emoji)
|
| 288 |
+
display_list.sort(key=lambda x: x.split(' ', 1)[1] if ' ' in x else x)
|
| 289 |
+
return display_list
|
| 290 |
+
|
| 291 |
+
def get_template_key_from_display(self, display_name: str) -> Optional[str]:
|
| 292 |
+
"""
|
| 293 |
+
Get template key from display name.
|
| 294 |
+
Example: "🏢 Modern Office" -> "office_modern"
|
| 295 |
+
"""
|
| 296 |
+
if not display_name:
|
| 297 |
+
return None
|
| 298 |
+
|
| 299 |
+
for key, template in self.TEMPLATES.items():
|
| 300 |
+
if f"{template.icon} {template.name}" == display_name:
|
| 301 |
+
return key
|
| 302 |
+
return None
|
| 303 |
+
|
| 304 |
+
def get_prompt_for_template(self, key: str) -> Optional[str]:
|
| 305 |
+
"""Get the prompt string for a template"""
|
| 306 |
+
template = self.get_template(key)
|
| 307 |
+
return template.prompt if template else None
|
| 308 |
+
|
| 309 |
+
def get_negative_prompt_for_template(
|
| 310 |
+
self,
|
| 311 |
+
key: str,
|
| 312 |
+
base_negative: str = "blurry, low quality, distorted, people, characters"
|
| 313 |
+
) -> str:
|
| 314 |
+
"""Get combined negative prompt for a template"""
|
| 315 |
+
template = self.get_template(key)
|
| 316 |
+
if template and template.negative_extra:
|
| 317 |
+
return f"{base_negative}, {template.negative_extra}"
|
| 318 |
+
return base_negative
|
| 319 |
+
|
| 320 |
+
def get_guidance_scale_for_template(self, key: str) -> float:
|
| 321 |
+
"""Get the recommended guidance scale for a template"""
|
| 322 |
+
template = self.get_template(key)
|
| 323 |
+
return template.guidance_scale if template else 7.5
|
| 324 |
+
|
| 325 |
+
    def build_gallery_html(self) -> str:
        """
        Build HTML for the scene template gallery.
        Returns HTML string for display in Gradio.
        """
        # NOTE(review): template fields are interpolated without HTML escaping;
        # safe only because TEMPLATES is a trusted, hard-coded constant.
        html_parts = ['<div class="scene-gallery">']

        for category in self.CATEGORIES:
            templates = self.get_templates_by_category(category)
            # Skip categories with no templates so no empty section is rendered.
            if not templates:
                continue

            html_parts.append(f'''
            <div class="scene-category">
                <h4 class="scene-category-title">{category}</h4>
                <div class="scene-grid">
            ''')

            for template in templates:
                # Each card carries its key in data-template and wires a JS
                # click handler; selectTemplate must be defined on the page.
                html_parts.append(f'''
                <button class="scene-card" data-template="{template.key}" onclick="selectTemplate('{template.key}')">
                    <span class="scene-icon">{template.icon}</span>
                    <span class="scene-name">{template.name}</span>
                </button>
                ''')

            html_parts.append('</div></div>')

        html_parts.append('</div>')
        return ''.join(html_parts)
|
| 355 |
+
|
| 356 |
+
    def get_gallery_css(self) -> str:
        """Get CSS styles for the scene gallery"""
        # Constant stylesheet for the HTML emitted by build_gallery_html;
        # class names here must stay in sync with that method.
        return """
        /* Scene Gallery Styles */
        .scene-gallery {
            margin: 16px 0;
        }

        .scene-category {
            margin-bottom: 20px;
        }

        .scene-category-title {
            font-size: 0.9rem;
            font-weight: 600;
            color: #475569;
            margin-bottom: 12px;
            padding-bottom: 8px;
            border-bottom: 1px solid #e2e8f0;
        }

        .scene-grid {
            display: grid;
            grid-template-columns: repeat(auto-fill, minmax(100px, 1fr));
            gap: 8px;
        }

        .scene-card {
            display: flex;
            flex-direction: column;
            align-items: center;
            justify-content: center;
            padding: 12px 8px;
            background: #f8fafc;
            border: 1px solid #e2e8f0;
            border-radius: 8px;
            cursor: pointer;
            transition: all 0.2s ease;
            min-height: 70px;
        }

        .scene-card:hover {
            background: #dbeafe;
            border-color: #3b82f6;
            transform: translateY(-2px);
            box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1);
        }

        .scene-card.selected {
            background: #dbeafe;
            border-color: #3b82f6;
            box-shadow: 0 0 0 2px rgba(59, 130, 246, 0.3);
        }

        .scene-icon {
            font-size: 1.5rem;
            margin-bottom: 4px;
        }

        .scene-name {
            font-size: 0.75rem;
            font-weight: 500;
            color: #1e293b;
            text-align: center;
            line-height: 1.2;
        }

        @media (max-width: 768px) {
            .scene-grid {
                grid-template-columns: repeat(3, 1fr);
            }
        }
        """
|
ui_manager.py
CHANGED
|
@@ -1,20 +1,35 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
from PIL import Image
|
| 3 |
-
from typing import Tuple
|
|
|
|
|
|
|
|
|
|
| 4 |
from FlowFacade import FlowFacade
|
|
|
|
|
|
|
| 5 |
from css_style import DELTAFLOW_CSS
|
| 6 |
from prompt_examples import PROMPT_EXAMPLES
|
| 7 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
class UIManager:
|
| 10 |
-
def __init__(self, facade: FlowFacade):
|
| 11 |
self.facade = facade
|
|
|
|
|
|
|
| 12 |
|
| 13 |
def create_interface(self) -> gr.Blocks:
|
| 14 |
with gr.Blocks(
|
| 15 |
theme=gr.themes.Soft(),
|
| 16 |
css=DELTAFLOW_CSS,
|
| 17 |
-
title="VividFlow -
|
| 18 |
) as interface:
|
| 19 |
|
| 20 |
# Header
|
|
@@ -22,276 +37,523 @@ class UIManager:
|
|
| 22 |
<div class="header-container">
|
| 23 |
<h1 class="header-title">🌊 VividFlow</h1>
|
| 24 |
<p class="header-subtitle">
|
| 25 |
-
|
| 26 |
-
Transform
|
| 27 |
</p>
|
| 28 |
</div>
|
| 29 |
""")
|
| 30 |
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
)
|
| 42 |
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
|
|
|
|
|
|
| 47 |
)
|
| 48 |
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
|
|
|
| 54 |
)
|
| 55 |
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
label="💡 Quick Prompt Category",
|
| 60 |
-
value="💃 Fashion / Beauty (Facial Only)",
|
| 61 |
-
interactive=True
|
| 62 |
)
|
| 63 |
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
value=None,
|
| 68 |
-
interactive=True
|
| 69 |
)
|
| 70 |
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
|
| 81 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
gr.HTML("""
|
| 83 |
-
<div
|
| 84 |
-
<strong
|
| 85 |
-
|
| 86 |
-
|
| 87 |
</div>
|
| 88 |
""")
|
| 89 |
|
| 90 |
-
|
| 91 |
-
"
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
|
|
|
|
|
|
| 95 |
)
|
| 96 |
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
value=3.0,
|
| 104 |
-
label="Duration (seconds)",
|
| 105 |
-
info="3.0s = 49 frames, 5.0s = 81 frames (16fps)"
|
| 106 |
-
)
|
| 107 |
-
|
| 108 |
-
steps_slider = gr.Slider(
|
| 109 |
-
minimum=4,
|
| 110 |
-
maximum=12,
|
| 111 |
-
step=1,
|
| 112 |
-
value=4,
|
| 113 |
-
label="Inference Steps",
|
| 114 |
-
info="4-6 recommended • Higher steps = longer generation time"
|
| 115 |
-
)
|
| 116 |
-
|
| 117 |
-
with gr.Row():
|
| 118 |
-
guidance_scale = gr.Slider(
|
| 119 |
-
minimum=0.0,
|
| 120 |
-
maximum=5.0,
|
| 121 |
-
step=0.5,
|
| 122 |
-
value=1.0,
|
| 123 |
-
label="Guidance Scale (high noise)"
|
| 124 |
-
)
|
| 125 |
-
|
| 126 |
-
guidance_scale_2 = gr.Slider(
|
| 127 |
-
minimum=0.0,
|
| 128 |
-
maximum=5.0,
|
| 129 |
-
step=0.5,
|
| 130 |
-
value=1.0,
|
| 131 |
-
label="Guidance Scale (low noise)"
|
| 132 |
-
)
|
| 133 |
-
|
| 134 |
-
with gr.Row():
|
| 135 |
-
seed_input = gr.Number(
|
| 136 |
-
label="Seed",
|
| 137 |
-
value=42,
|
| 138 |
-
precision=0,
|
| 139 |
-
minimum=0,
|
| 140 |
-
maximum=2147483647,
|
| 141 |
-
info="Use same seed for reproducible results"
|
| 142 |
-
)
|
| 143 |
-
|
| 144 |
-
randomize_seed = gr.Checkbox(
|
| 145 |
-
label="Randomize Seed",
|
| 146 |
-
value=True,
|
| 147 |
-
info="Generate different results each time"
|
| 148 |
-
)
|
| 149 |
-
|
| 150 |
-
enable_ai_prompt = gr.Checkbox(
|
| 151 |
-
label="🤖 Enable AI Prompt Expansion (Qwen2.5)",
|
| 152 |
-
value=False,
|
| 153 |
-
info="Use AI to enhance your prompt (adds ~30s)"
|
| 154 |
-
)
|
| 155 |
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 159 |
|
| 160 |
-
|
| 161 |
-
label="
|
| 162 |
-
|
| 163 |
-
|
|
|
|
|
|
|
|
|
|
| 164 |
)
|
| 165 |
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
)
|
| 173 |
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
interactive=False,
|
| 178 |
-
scale=1
|
| 179 |
-
)
|
| 180 |
|
| 181 |
-
# Info section
|
| 182 |
-
with gr.Row():
|
| 183 |
gr.HTML("""
|
| 184 |
-
<div class="
|
| 185 |
-
<strong
|
| 186 |
-
|
| 187 |
-
• <strong>Works with ANY image:</strong> Fashion portraits, anime, landscapes, products, abstract art, etc.<br>
|
| 188 |
-
• <strong>For dramatic effects:</strong> Choose prompts with words like "explosive", "dramatic", "swirls", "transforms"<br>
|
| 189 |
-
• <strong>Image quality matters:</strong> Higher resolution and clear subjects produce better results
|
| 190 |
</div>
|
| 191 |
""")
|
| 192 |
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 207 |
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
outputs=[example_dropdown])
|
| 225 |
-
example_dropdown.change(fn=fill_prompt, inputs=[example_dropdown],
|
| 226 |
-
outputs=[prompt_input])
|
| 227 |
-
image_input.change(fn=show_resolution_info, inputs=[image_input],
|
| 228 |
-
outputs=[resolution_info])
|
| 229 |
-
|
| 230 |
-
generate_btn.click(
|
| 231 |
-
fn=self._handle_generation,
|
| 232 |
-
inputs=[
|
| 233 |
-
image_input,
|
| 234 |
-
prompt_input,
|
| 235 |
-
duration_slider,
|
| 236 |
-
steps_slider,
|
| 237 |
-
guidance_scale,
|
| 238 |
-
guidance_scale_2,
|
| 239 |
-
seed_input,
|
| 240 |
-
randomize_seed,
|
| 241 |
-
enable_ai_prompt
|
| 242 |
-
],
|
| 243 |
-
outputs=[video_output, prompt_output, seed_output],
|
| 244 |
-
show_progress=True
|
| 245 |
-
)
|
| 246 |
|
| 247 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 248 |
|
| 249 |
-
def _handle_generation(self, image: Image.Image, prompt: str, duration: float,
|
| 250 |
-
steps: int, guidance_1: float, guidance_2: float, seed: int,
|
| 251 |
-
randomize: bool, enable_ai: bool,
|
| 252 |
-
progress=gr.Progress()) -> Tuple[str, str, int]:
|
| 253 |
try:
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
raise gr.Error("❌ Image dimensions invalid (256-4096px)")
|
| 260 |
|
| 261 |
-
|
| 262 |
-
image
|
| 263 |
-
|
| 264 |
-
duration_seconds=duration,
|
| 265 |
-
num_inference_steps=steps,
|
| 266 |
-
guidance_scale=guidance_1,
|
| 267 |
-
guidance_scale_2=guidance_2,
|
| 268 |
-
seed=int(seed),
|
| 269 |
-
randomize_seed=randomize,
|
| 270 |
-
enable_prompt_expansion=enable_ai,
|
| 271 |
-
progress=progress
|
| 272 |
)
|
| 273 |
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 278 |
|
| 279 |
except Exception as e:
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 292 |
|
| 293 |
-
def launch(self, share: bool = False, server_name: str = "0.0.0.0",
|
| 294 |
-
server_port: int = None, **kwargs) -> None:
|
| 295 |
-
interface = self.create_interface()
|
| 296 |
-
interface.launch(share=share, server_name=server_name,
|
| 297 |
-
server_port=server_port, **kwargs)
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
from PIL import Image
|
| 3 |
+
from typing import Tuple, Optional, Dict, Any
|
| 4 |
+
import os
|
| 5 |
+
import logging
|
| 6 |
+
|
| 7 |
from FlowFacade import FlowFacade
|
| 8 |
+
from BackgroundEngine import BackgroundEngine
|
| 9 |
+
from scene_templates import SceneTemplateManager
|
| 10 |
from css_style import DELTAFLOW_CSS
|
| 11 |
from prompt_examples import PROMPT_EXAMPLES
|
| 12 |
|
| 13 |
+
# The Hugging Face Spaces GPU helper is optional: import it when running on
# Spaces, otherwise fall back gracefully so local runs still work.
try:
    import spaces
    SPACES_AVAILABLE = True
except ImportError:
    SPACES_AVAILABLE = False

logger = logging.getLogger(__name__)
|
| 20 |
+
|
| 21 |
|
| 22 |
class UIManager:
|
| 23 |
+
    def __init__(self, facade: FlowFacade, background_engine: BackgroundEngine):
        """Wire the UI to the video-generation facade and the background engine."""
        self.facade = facade
        self.background_engine = background_engine
        # Curated scene presets used by the Background Generation tab.
        self.template_manager = SceneTemplateManager()
|
| 27 |
|
| 28 |
    def create_interface(self) -> gr.Blocks:
        """Assemble the full Gradio app: header, the two feature tabs, footer."""
        with gr.Blocks(
            theme=gr.themes.Soft(),
            css=DELTAFLOW_CSS,
            title="VividFlow - AI Image Enhancement & Video Generation"
        ) as interface:

            # Header
            gr.HTML("""
                <div class="header-container">
                    <h1 class="header-title">🌊 VividFlow</h1>
                    <p class="header-subtitle">
                        AI-Powered Image Enhancement & Video Generation<br>
                        Transform images with background replacement, then bring them to life with AI
                    </p>
                </div>
            """)

            # Main Tabs
            with gr.Tabs() as main_tabs:

                # Tab 1: Image to Video (Original Functionality)
                with gr.Tab("🎬 Image to Video"):
                    self._create_i2v_tab()

                # Tab 2: Background Generation (New Feature)
                with gr.Tab("🎨 Background Generation"):
                    self._create_background_tab()

            # Footer
            gr.HTML("""
                <div class="footer">
                    <p>Powered by Wan2.2-I2V-A14B, SDXL, and OpenCLIP | Built with Gradio</p>
                </div>
            """)

        return interface
|
| 65 |
|
| 66 |
+
    def _create_i2v_tab(self):
        """Create Image to Video tab (original VividFlow functionality)"""
        with gr.Row():
            # Left Panel: Input
            with gr.Column(scale=1, elem_classes="input-card"):
                gr.Markdown("### 📤 Input")

                image_input = gr.Image(
                    label="Upload Image (any type: photo, art, cartoon, etc.)",
                    type="pil",
                    elem_classes="image-upload",
                    height=320
                )

                # Hidden until an image is uploaded; then shows in/out resolution.
                resolution_info = gr.Markdown(
                    value="",
                    visible=False,
                    elem_classes="info-text"
                )

                prompt_input = gr.Textbox(
                    label="Motion Instruction",
                    placeholder="Describe camera movements and subject actions...",
                    lines=3,
                    max_lines=6
                )

                category_dropdown = gr.Dropdown(
                    choices=list(PROMPT_EXAMPLES.keys()),
                    label="💡 Quick Prompt Category",
                    value="💃 Fashion / Beauty (Facial Only)",
                    interactive=True
                )

                # Populated with the examples of the default category above;
                # refreshed by category_changed when the category switches.
                example_dropdown = gr.Dropdown(
                    choices=PROMPT_EXAMPLES["💃 Fashion / Beauty (Facial Only)"],
                    label="Example Prompts (click to use)",
                    value=None,
                    interactive=True
                )

                gr.HTML("""
                    <div class="quality-banner">
                        <strong>💡 Choose the Right Prompt Category:</strong><br>
                        • <strong>💃 Facial Only:</strong> Safe for headshots without visible hands<br>
                        • <strong>🙌 Hands Visible Required:</strong> Only use if hands are fully visible<br>
                        • <strong>🌄 Scenery/Objects:</strong> For landscapes, products, abstract content
                    </div>
                """)

                gr.HTML("""
                    <div class="patience-banner">
                        <strong>⏱️ First-time loading may take a moment!</strong><br>
                        Subsequent runs will be much faster.
                    </div>
                """)

                generate_btn = gr.Button(
                    "🎬 Generate Video",
                    variant="primary",
                    elem_classes="primary-button",
                    size="lg"
                )

                with gr.Accordion("⚙️ Advanced Settings", open=False):
                    duration_slider = gr.Slider(
                        minimum=0.5,
                        maximum=5.0,
                        value=3.0,
                        step=0.5,
                        label="Video Duration (seconds)"
                    )

                    steps_slider = gr.Slider(
                        minimum=4,
                        maximum=25,
                        value=4,
                        step=1,
                        label="Quality Steps (4=Lightning Fast, 8-25=Higher Quality)"
                    )

                    fps_slider = gr.Slider(
                        minimum=8,
                        maximum=24,
                        value=16,
                        step=1,
                        label="Frames Per Second"
                    )

                    expand_prompt = gr.Checkbox(
                        label="AI Prompt Expansion (experimental)",
                        value=False
                    )

                    randomize_seed = gr.Checkbox(
                        label="Randomize Seed",
                        value=True
                    )

                    seed_input = gr.Number(
                        label="Manual Seed (if not randomized)",
                        value=42,
                        precision=0
                    )

            # Right Panel: Output
            with gr.Column(scale=1, elem_classes="output-card"):
                gr.Markdown("### 🎥 Output")

                video_output = gr.Video(
                    label="Generated Video",
                    elem_classes="video-player"
                )

                final_prompt_output = gr.Textbox(
                    label="Final Prompt Used",
                    interactive=False,
                    lines=2
                )

                seed_output = gr.Number(
                    label="Seed Used",
                    interactive=False,
                    precision=0
                )

        # Event handlers for I2V tab
        def update_resolution_display(img):
            # Show the model's effective output size (dimensions floored to
            # multiples of 16); hide the banner when no image is present.
            if img is None:
                return gr.update(visible=False)
            w, h = img.size
            new_w = (w // 16) * 16
            new_h = (h // 16) * 16
            return gr.update(
                value=f"📐 **Resolution:** Input: {w}×{h} → Output: {new_w}×{new_h}",
                visible=True
            )

        def category_changed(category):
            # Swap the example list to the newly chosen category; clear selection.
            if category in PROMPT_EXAMPLES:
                return gr.update(choices=PROMPT_EXAMPLES[category], value=None)
            return gr.update()

        def example_selected(example):
            # Copy the picked example into the prompt box ("" when deselected).
            return example if example else ""

        image_input.change(
            fn=update_resolution_display,
            inputs=[image_input],
            outputs=[resolution_info]
        )

        category_dropdown.change(
            fn=category_changed,
            inputs=[category_dropdown],
            outputs=[example_dropdown]
        )

        example_dropdown.change(
            fn=example_selected,
            inputs=[example_dropdown],
            outputs=[prompt_input]
        )

        generate_btn.click(
            fn=self._generate_video_handler,
            inputs=[
                image_input, prompt_input, duration_slider,
                steps_slider, fps_slider, expand_prompt,
                randomize_seed, seed_input
            ],
            outputs=[video_output, final_prompt_output, seed_output]
        )
|
| 239 |
+
|
| 240 |
+
def _generate_video_handler(
|
| 241 |
+
self,
|
| 242 |
+
image: Image.Image,
|
| 243 |
+
prompt: str,
|
| 244 |
+
duration: float,
|
| 245 |
+
steps: int,
|
| 246 |
+
fps: int,
|
| 247 |
+
expand_prompt: bool,
|
| 248 |
+
randomize_seed: bool,
|
| 249 |
+
seed: int
|
| 250 |
+
) -> Tuple[str, str, int]:
|
| 251 |
+
"""Handler for video generation"""
|
| 252 |
+
if image is None:
|
| 253 |
+
return None, "Please upload an image", 0
|
| 254 |
+
|
| 255 |
+
if not prompt.strip():
|
| 256 |
+
return None, "Please provide a motion prompt", 0
|
| 257 |
+
|
| 258 |
+
try:
|
| 259 |
+
video_path, final_prompt, seed_used = self.facade.generate_video_from_image(
|
| 260 |
+
image=image,
|
| 261 |
+
user_instruction=prompt,
|
| 262 |
+
duration_seconds=duration,
|
| 263 |
+
num_inference_steps=steps,
|
| 264 |
+
enable_prompt_expansion=expand_prompt,
|
| 265 |
+
randomize_seed=randomize_seed,
|
| 266 |
+
seed=seed
|
| 267 |
+
)
|
| 268 |
+
return video_path, final_prompt, seed_used
|
| 269 |
+
|
| 270 |
+
except Exception as e:
|
| 271 |
+
logger.error(f"Video generation failed: {e}")
|
| 272 |
+
return None, f"Error: {str(e)}", 0
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
    def _create_background_tab(self):
        """Create Background Generation tab (SceneWeaver functionality).

        Builds a two-column layout: the left column holds the upload widget,
        scene-template selector, prompts and generation parameters; the right
        column holds the tabbed results gallery, status line and utility
        buttons. Event wiring (template selection, generate, clear, memory
        cleanup) is attached at the end of the method.
        """
        with gr.Row():
            # Left Panel: Input — image upload, prompts, and tuning controls.
            with gr.Column(scale=1, elem_classes="feature-card"):
                gr.Markdown("### 📸 Upload & Configure")

                # Static tips banner (HTML only, no interactivity).
                gr.HTML("""
                <div class="quality-banner">
                    <strong>💡 Best Results Tips:</strong><br>
                    • Clean portrait photos with simple backgrounds work best<br>
                    • Complex scenes (e.g., pets with grass) may need parameter adjustments<br>
                    • Use Advanced Options below to fine-tune edge blending
                </div>
                """)

                # Source image; PIL type so handlers receive Image.Image.
                bg_image_input = gr.Image(
                    label="Upload Your Image",
                    type="pil",
                    height=280
                )

                # Scene Template Selector — optional preset; empty string means
                # "no template selected" (see apply_template below).
                template_dropdown = gr.Dropdown(
                    label="Scene Templates (24 curated scenes A-Z)",
                    choices=[""] + self.template_manager.get_template_choices_sorted(),
                    value="",
                    info="Optional: Select a preset or describe your own",
                    elem_classes=["template-dropdown"]
                )

                # Free-text scene description; templates overwrite this field.
                bg_prompt_input = gr.Textbox(
                    label="Background Scene Description",
                    placeholder="Select a template above or describe your own scene...",
                    lines=3
                )

                combination_mode = gr.Dropdown(
                    label="Composition Mode",
                    choices=["center", "left_half", "right_half", "full"],
                    value="center",
                    info="center=Smart Center | full=Full Image"
                )

                focus_mode = gr.Dropdown(
                    label="Focus Mode",
                    choices=["person", "scene"],
                    value="person",
                    info="person=Tight Crop | scene=Include Surrounding"
                )

                # Advanced tuning parameters, collapsed by default.
                with gr.Accordion("Advanced Options", open=False):
                    gr.HTML("""
                    <div style="padding: 8px; background: #f0f4ff; border-radius: 6px; margin-bottom: 12px; font-size: 13px;">
                        <strong>💡 When to Adjust:</strong><br>
                        • <strong>Feather Radius:</strong> Use 5-10 for complex scenes with fine details (hair, fur, foliage). 0 = sharp edges for clean portraits.<br>
                        • <strong>Mask Preview:</strong> Check the "Mask Preview" tab after generation. White = kept, Black = replaced. Helps diagnose edge issues.
                    </div>
                    """)

                    # 0 = hard mask edges; larger values soften the blend.
                    feather_radius_slider = gr.Slider(
                        label="Feather Radius (Edge Softness)",
                        minimum=0,
                        maximum=20,
                        value=0,
                        step=1,
                        info="Softens mask edges. Try 5-10 if edges look harsh."
                    )

                    bg_negative_prompt = gr.Textbox(
                        label="Negative Prompt",
                        value="blurry, low quality, distorted, people, characters",
                        lines=2,
                        info="Prevents unwanted elements in background"
                    )

                    bg_steps_slider = gr.Slider(
                        label="Quality Steps",
                        minimum=15,
                        maximum=50,
                        value=25,
                        step=5,
                        info="Higher = better quality but slower"
                    )

                    bg_guidance_slider = gr.Slider(
                        label="Guidance Scale",
                        minimum=5.0,
                        maximum=15.0,
                        value=7.5,
                        step=0.5,
                        info="How strictly to follow prompt"
                    )

                generate_bg_btn = gr.Button(
                    "🎨 Generate Background",
                    variant="primary",
                    elem_classes="primary-button",
                    size="lg"
                )

            # Right Panel: Output — result tabs, status text, utility buttons.
            with gr.Column(scale=2, elem_classes="feature-card"):
                gr.Markdown("### 🎭 Results Gallery")

                gr.HTML("""
                <div class="patience-banner">
                    <strong>⏱️ First-time users:</strong> Initial model loading takes 1-2 minutes.
                    Subsequent generations are much faster (~30s).
                </div>
                """)

                # One tab per artifact produced by the generation handler
                # (order matches the handler's 5-tuple output, minus status).
                with gr.Tabs():
                    with gr.TabItem("Final Result"):
                        bg_combined_output = gr.Image(
                            label="Your Generated Image",
                            elem_classes=["result-gallery"]
                        )
                    with gr.TabItem("Background"):
                        bg_generated_output = gr.Image(
                            label="Generated Background",
                            elem_classes=["result-gallery"]
                        )
                    with gr.TabItem("Original"):
                        bg_original_output = gr.Image(
                            label="Processed Original",
                            elem_classes=["result-gallery"]
                        )
                    with gr.TabItem("Mask Preview"):
                        gr.HTML("""
                        <div style="padding: 8px; background: #f0f4ff; border-radius: 6px; margin-bottom: 8px; font-size: 13px;">
                            <strong>📐 How to Read:</strong> White = Original kept | Black = Background replaced<br>
                            Use this to diagnose edge quality. If edges are too harsh, increase Feather Radius.
                        </div>
                        """)
                        bg_mask_output = gr.Image(
                            label="Blending Mask",
                            elem_classes=["result-gallery"]
                        )

                bg_status_output = gr.Textbox(
                    label="Status",
                    value="Ready to create! Upload an image and describe your vision.",
                    interactive=False,
                    elem_classes=["status-panel"]
                )

                with gr.Row():
                    clear_bg_btn = gr.Button(
                        "Clear All",
                        elem_classes=["secondary-button"]
                    )
                    memory_btn = gr.Button(
                        "Clean Memory",
                        elem_classes=["secondary-button"]
                    )

        # Event handlers for Background Generation tab
        def apply_template(display_name: str, current_negative: str) -> Tuple[str, str, float]:
            # Map the dropdown's display string to (prompt, negative, guidance).
            # Empty selection or an unknown key resets prompt/guidance and
            # leaves the user's current negative prompt untouched.
            if not display_name:
                return "", current_negative, 7.5

            template_key = self.template_manager.get_template_key_from_display(display_name)
            if not template_key:
                return "", current_negative, 7.5

            template = self.template_manager.get_template(template_key)
            if template:
                prompt = template.prompt
                negative = self.template_manager.get_negative_prompt_for_template(
                    template_key, current_negative
                )
                guidance = template.guidance_scale
                return prompt, negative, guidance

            return "", current_negative, 7.5

        # Selecting a template populates prompt, negative prompt and guidance.
        template_dropdown.change(
            fn=apply_template,
            inputs=[template_dropdown, bg_negative_prompt],
            outputs=[bg_prompt_input, bg_negative_prompt, bg_guidance_slider]
        )

        generate_bg_btn.click(
            fn=self._generate_background_handler,
            inputs=[
                bg_image_input, bg_prompt_input, combination_mode,
                focus_mode, bg_negative_prompt, bg_steps_slider, bg_guidance_slider,
                feather_radius_slider
            ],
            outputs=[
                bg_combined_output, bg_generated_output,
                bg_original_output, bg_mask_output, bg_status_output
            ]
        )

        # Reset all four result images and the status line.
        clear_bg_btn.click(
            fn=lambda: (None, None, None, None, "Ready to create!"),
            outputs=[
                bg_combined_output, bg_generated_output,
                bg_original_output, bg_mask_output, bg_status_output
            ]
        )

        # `_memory_cleanup()` presumably returns None, so `or` yields the
        # status string — NOTE(review): confirm against BackgroundEngine.
        memory_btn.click(
            fn=lambda: self.background_engine._memory_cleanup() or "Memory cleaned!",
            outputs=[bg_status_output]
        )
|
| 483 |
+
|
| 484 |
+
def _generate_background_handler(
|
| 485 |
+
self,
|
| 486 |
+
image: Image.Image,
|
| 487 |
+
prompt: str,
|
| 488 |
+
combination_mode: str,
|
| 489 |
+
focus_mode: str,
|
| 490 |
+
negative_prompt: str,
|
| 491 |
+
steps: int,
|
| 492 |
+
guidance: float,
|
| 493 |
+
feather_radius: int
|
| 494 |
+
) -> Tuple[Optional[Image.Image], Optional[Image.Image], Optional[Image.Image], Optional[Image.Image], str]:
|
| 495 |
+
"""Handler for background generation"""
|
| 496 |
+
if image is None:
|
| 497 |
+
return None, None, None, None, "Please upload an image to get started!"
|
| 498 |
+
|
| 499 |
+
if not prompt.strip():
|
| 500 |
+
return None, None, None, None, "Please describe the background scene you'd like!"
|
| 501 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 502 |
try:
|
| 503 |
+
# Apply ZeroGPU decorator if available
|
| 504 |
+
if SPACES_AVAILABLE:
|
| 505 |
+
generate_fn = spaces.GPU(duration=60)(self._background_generate_core)
|
| 506 |
+
else:
|
| 507 |
+
generate_fn = self._background_generate_core
|
|
|
|
| 508 |
|
| 509 |
+
result = generate_fn(
|
| 510 |
+
image, prompt, combination_mode, focus_mode,
|
| 511 |
+
negative_prompt, steps, guidance, feather_radius
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 512 |
)
|
| 513 |
|
| 514 |
+
if result["success"]:
|
| 515 |
+
return (
|
| 516 |
+
result["combined_image"],
|
| 517 |
+
result["generated_scene"],
|
| 518 |
+
result["original_image"],
|
| 519 |
+
result["mask"],
|
| 520 |
+
"Image created successfully!"
|
| 521 |
+
)
|
| 522 |
+
else:
|
| 523 |
+
error_msg = result.get("error", "Something went wrong")
|
| 524 |
+
return None, None, None, None, f"Error: {error_msg}"
|
| 525 |
|
| 526 |
except Exception as e:
|
| 527 |
+
logger.error(f"Background generation failed: {e}")
|
| 528 |
+
return None, None, None, None, f"Error: {str(e)}"
|
| 529 |
+
|
| 530 |
+
def _background_generate_core(
|
| 531 |
+
self,
|
| 532 |
+
image: Image.Image,
|
| 533 |
+
prompt: str,
|
| 534 |
+
combination_mode: str,
|
| 535 |
+
focus_mode: str,
|
| 536 |
+
negative_prompt: str,
|
| 537 |
+
steps: int,
|
| 538 |
+
guidance: float,
|
| 539 |
+
feather_radius: int
|
| 540 |
+
) -> Dict[str, Any]:
|
| 541 |
+
"""Core background generation with models"""
|
| 542 |
+
if not self.background_engine.is_initialized:
|
| 543 |
+
logger.info("Loading background generation models...")
|
| 544 |
+
self.background_engine.load_models()
|
| 545 |
+
|
| 546 |
+
result = self.background_engine.generate_and_combine(
|
| 547 |
+
original_image=image,
|
| 548 |
+
prompt=prompt,
|
| 549 |
+
combination_mode=combination_mode,
|
| 550 |
+
focus_mode=focus_mode,
|
| 551 |
+
negative_prompt=negative_prompt,
|
| 552 |
+
num_inference_steps=int(steps),
|
| 553 |
+
guidance_scale=float(guidance),
|
| 554 |
+
enable_prompt_enhancement=True,
|
| 555 |
+
feather_radius=int(feather_radius)
|
| 556 |
+
)
|
| 557 |
+
|
| 558 |
+
return result
|
| 559 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|