DawnC committed on
Commit
6a2169d
·
verified ·
1 Parent(s): 84016c5

Upload 13 files

Browse files
BackgroundEngine.py ADDED
@@ -0,0 +1,432 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import cv2
4
+ from PIL import Image
5
+ import logging
6
+ import gc
7
+ import time
8
+ import os
9
+ from typing import Optional, Dict, Any, Callable
10
+ import warnings
11
+ warnings.filterwarnings("ignore")
12
+
13
+ from diffusers import StableDiffusionXLPipeline, DPMSolverMultistepScheduler
14
+ import open_clip
15
+ from mask_generator import MaskGenerator
16
+ from image_blender import ImageBlender
17
+
18
+ try:
19
+ import spaces
20
+ SPACES_AVAILABLE = True
21
+ except ImportError:
22
+ SPACES_AVAILABLE = False
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+
27
class BackgroundEngine:
    """
    Background generation engine for VividFlow.

    Integrates SDXL pipeline, OpenCLIP analysis, mask generation,
    and advanced image blending.
    """

    def __init__(self, device: str = "auto"):
        # Resolve the computation device first; everything else keys off it.
        self.device = self._setup_device(device)

        # Model identifiers used by load_models().
        self.base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
        self.clip_model_name = "ViT-B-32"
        self.clip_pretrained = "openai"

        # Lazily-loaded model handles, populated by load_models().
        self.pipeline = None
        self.clip_model = None
        self.clip_preprocess = None
        self.clip_tokenizer = None
        self.is_initialized = False

        # Generation limits / defaults.
        self.max_image_size = 1024
        self.default_steps = 25
        self.use_fp16 = True

        # Collaborators for mask creation and compositing.
        self.mask_generator = MaskGenerator(self.max_image_size)
        self.image_blender = ImageBlender()

        logger.info(f"BackgroundEngine initialized on {self.device}")
55
+
56
+ def _setup_device(self, device: str) -> str:
57
+ """Setup computation device (ZeroGPU compatible)"""
58
+ if os.getenv('SPACE_ID') is not None:
59
+ return "cpu"
60
+
61
+ if device == "auto":
62
+ if torch.cuda.is_available():
63
+ return "cuda"
64
+ elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
65
+ return "mps"
66
+ return "cpu"
67
+ return device
68
+
69
+ def _memory_cleanup(self):
70
+ """Memory cleanup"""
71
+ for _ in range(3):
72
+ gc.collect()
73
+
74
+ is_spaces = os.getenv('SPACE_ID') is not None
75
+ if not is_spaces and torch.cuda.is_available():
76
+ torch.cuda.empty_cache()
77
+
78
+ def load_models(self, progress_callback: Optional[Callable] = None):
79
+ """Load SDXL and OpenCLIP models"""
80
+ if self.is_initialized:
81
+ logger.info("Models already loaded")
82
+ return
83
+
84
+ logger.info("Loading background generation models...")
85
+
86
+ try:
87
+ self._memory_cleanup()
88
+
89
+ # Detect actual device (in ZeroGPU, CUDA becomes available after @spaces.GPU allocation)
90
+ actual_device = "cuda" if torch.cuda.is_available() else self.device
91
+ logger.info(f"Loading models to device: {actual_device}")
92
+
93
+ if progress_callback:
94
+ progress_callback("Loading OpenCLIP...", 20)
95
+
96
+ # Load OpenCLIP
97
+ self.clip_model, _, self.clip_preprocess = open_clip.create_model_and_transforms(
98
+ self.clip_model_name,
99
+ pretrained=self.clip_pretrained,
100
+ device=actual_device
101
+ )
102
+ self.clip_tokenizer = open_clip.get_tokenizer(self.clip_model_name)
103
+ self.clip_model.eval()
104
+
105
+ logger.info("OpenCLIP loaded")
106
+
107
+ if progress_callback:
108
+ progress_callback("Loading SDXL pipeline...", 60)
109
+
110
+ # Load SDXL
111
+ self.pipeline = StableDiffusionXLPipeline.from_pretrained(
112
+ self.base_model_id,
113
+ torch_dtype=torch.float16 if self.use_fp16 else torch.float32,
114
+ use_safetensors=True,
115
+ variant="fp16" if self.use_fp16 else None
116
+ )
117
+
118
+ # DPM solver for faster generation
119
+ self.pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
120
+ self.pipeline.scheduler.config
121
+ )
122
+
123
+ self.pipeline = self.pipeline.to(actual_device)
124
+
125
+ if progress_callback:
126
+ progress_callback("Applying optimizations...", 90)
127
+
128
+ # Memory optimizations
129
+ try:
130
+ self.pipeline.enable_xformers_memory_efficient_attention()
131
+ logger.info("xformers enabled")
132
+ except Exception:
133
+ try:
134
+ self.pipeline.enable_attention_slicing()
135
+ logger.info("Attention slicing enabled")
136
+ except Exception:
137
+ pass
138
+
139
+ if hasattr(self.pipeline, 'enable_vae_tiling'):
140
+ self.pipeline.enable_vae_tiling()
141
+
142
+ if hasattr(self.pipeline, 'enable_vae_slicing'):
143
+ self.pipeline.enable_vae_slicing()
144
+
145
+ self.pipeline.unet.eval()
146
+ if hasattr(self.pipeline, 'vae'):
147
+ self.pipeline.vae.eval()
148
+
149
+ self.is_initialized = True
150
+
151
+ if progress_callback:
152
+ progress_callback("Models loaded!", 100)
153
+
154
+ logger.info("Background models loaded successfully")
155
+
156
+ except Exception as e:
157
+ logger.error(f"Model loading failed: {e}")
158
+ raise RuntimeError(f"Failed to load models: {str(e)}")
159
+
160
+ def analyze_image_with_clip(self, image: Image.Image) -> str:
161
+ """Analyze image using OpenCLIP"""
162
+ if not self.clip_model:
163
+ return "Unknown"
164
+
165
+ try:
166
+ # Use actual device
167
+ actual_device = "cuda" if torch.cuda.is_available() else self.device
168
+
169
+ image_input = self.clip_preprocess(image).unsqueeze(0).to(actual_device)
170
+
171
+ categories = [
172
+ "a photo of a person",
173
+ "a photo of an animal",
174
+ "a photo of an object",
175
+ "a photo of nature",
176
+ "a photo of a building"
177
+ ]
178
+
179
+ text_inputs = self.clip_tokenizer(categories).to(actual_device)
180
+
181
+ with torch.no_grad():
182
+ image_features = self.clip_model.encode_image(image_input)
183
+ text_features = self.clip_model.encode_text(text_inputs)
184
+
185
+ image_features /= image_features.norm(dim=-1, keepdim=True)
186
+ text_features /= text_features.norm(dim=-1, keepdim=True)
187
+
188
+ similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
189
+ best_match_idx = similarity.argmax().item()
190
+
191
+ category = categories[best_match_idx].replace("a photo of ", "")
192
+ return category
193
+
194
+ except Exception as e:
195
+ logger.error(f"CLIP analysis failed: {e}")
196
+ return "unknown"
197
+
198
+ def enhance_prompt(self, user_prompt: str, foreground_image: Image.Image) -> str:
199
+ """Smart prompt enhancement based on image analysis"""
200
+ try:
201
+ img_array = np.array(foreground_image.convert('RGB'))
202
+
203
+ # Analyze color temperature
204
+ lab = cv2.cvtColor(img_array, cv2.COLOR_RGB2LAB)
205
+ avg_b = np.mean(lab[:, :, 2])
206
+ is_warm = avg_b > 128
207
+
208
+ # Analyze brightness
209
+ gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
210
+ avg_brightness = np.mean(gray)
211
+ is_bright = avg_brightness > 127
212
+
213
+ # Get subject type
214
+ clip_analysis = self.analyze_image_with_clip(foreground_image)
215
+ subject_type = clip_analysis
216
+
217
+ # Build lighting descriptors
218
+ if is_warm and is_bright:
219
+ lighting = "warm golden hour lighting, soft natural light"
220
+ elif is_warm and not is_bright:
221
+ lighting = "warm ambient lighting, cozy atmosphere"
222
+ elif not is_warm and is_bright:
223
+ lighting = "bright daylight, clear sky lighting"
224
+ else:
225
+ lighting = "soft diffused light, gentle shadows"
226
+
227
+ # Build atmosphere based on subject
228
+ atmosphere_map = {
229
+ "person": "professional, elegant composition",
230
+ "animal": "natural, harmonious setting",
231
+ "object": "clean product photography style",
232
+ "nature": "scenic, peaceful atmosphere",
233
+ "building": "architectural, balanced composition"
234
+ }
235
+ atmosphere = atmosphere_map.get(subject_type, "balanced composition")
236
+
237
+ quality_modifiers = "high quality, detailed, sharp focus, photorealistic"
238
+
239
+ # Avoid conflicts
240
+ user_prompt_lower = user_prompt.lower()
241
+ if "sunset" in user_prompt_lower or "golden" in user_prompt_lower:
242
+ lighting = ""
243
+ if "dark" in user_prompt_lower or "night" in user_prompt_lower:
244
+ lighting = lighting.replace("bright", "").replace("daylight", "")
245
+
246
+ # Combine
247
+ fragments = [user_prompt]
248
+ if lighting:
249
+ fragments.append(lighting)
250
+ fragments.append(atmosphere)
251
+ fragments.append(quality_modifiers)
252
+
253
+ enhanced_prompt = ", ".join(filter(None, fragments))
254
+
255
+ logger.debug(f"Enhanced: {enhanced_prompt[:80]}...")
256
+ return enhanced_prompt
257
+
258
+ except Exception as e:
259
+ logger.warning(f"Prompt enhancement failed: {e}")
260
+ return f"{user_prompt}, high quality, detailed, photorealistic"
261
+
262
+ def _prepare_image(self, image: Image.Image) -> Image.Image:
263
+ """Prepare image for processing"""
264
+ if image.mode != 'RGB':
265
+ image = image.convert('RGB')
266
+
267
+ width, height = image.size
268
+ max_size = self.max_image_size
269
+
270
+ if width > max_size or height > max_size:
271
+ ratio = min(max_size/width, max_size/height)
272
+ new_width = int(width * ratio)
273
+ new_height = int(height * ratio)
274
+ image = image.resize((new_width, new_height), Image.LANCZOS)
275
+
276
+ width, height = image.size
277
+ new_width = (width // 8) * 8
278
+ new_height = (height // 8) * 8
279
+
280
+ if new_width != width or new_height != height:
281
+ image = image.resize((new_width, new_height), Image.LANCZOS)
282
+
283
+ return image
284
+
285
+ def generate_background(
286
+ self,
287
+ prompt: str,
288
+ width: int,
289
+ height: int,
290
+ negative_prompt: str = "blurry, low quality, distorted",
291
+ num_inference_steps: int = 25,
292
+ guidance_scale: float = 7.5
293
+ ) -> Image.Image:
294
+ """Generate background using SDXL"""
295
+ if not self.is_initialized:
296
+ raise RuntimeError("Models not loaded")
297
+
298
+ logger.info(f"Generating background: {prompt[:50]}...")
299
+
300
+ try:
301
+ # Use actual device
302
+ actual_device = "cuda" if torch.cuda.is_available() else self.device
303
+
304
+ with torch.inference_mode():
305
+ result = self.pipeline(
306
+ prompt=prompt,
307
+ negative_prompt=negative_prompt,
308
+ width=width,
309
+ height=height,
310
+ num_inference_steps=num_inference_steps,
311
+ guidance_scale=guidance_scale,
312
+ generator=torch.Generator(device=actual_device).manual_seed(42)
313
+ )
314
+
315
+ generated_image = result.images[0]
316
+ logger.info("Background generation completed")
317
+ return generated_image
318
+
319
+ except torch.cuda.OutOfMemoryError:
320
+ logger.error("GPU memory exhausted")
321
+ self._memory_cleanup()
322
+ raise RuntimeError("GPU memory insufficient")
323
+
324
+ except Exception as e:
325
+ logger.error(f"Generation failed: {e}")
326
+ raise RuntimeError(f"Generation failed: {str(e)}")
327
+
328
+ def generate_and_combine(
329
+ self,
330
+ original_image: Image.Image,
331
+ prompt: str,
332
+ combination_mode: str = "center",
333
+ focus_mode: str = "person",
334
+ negative_prompt: str = "blurry, low quality, distorted",
335
+ num_inference_steps: int = 25,
336
+ guidance_scale: float = 7.5,
337
+ progress_callback: Optional[Callable] = None,
338
+ enable_prompt_enhancement: bool = True,
339
+ feather_radius: int = 0
340
+ ) -> Dict[str, Any]:
341
+ """
342
+ Generate background and combine with foreground.
343
+
344
+ Args:
345
+ feather_radius: Gaussian blur radius for mask edge softening (0-20, default 0)
346
+
347
+ Returns dict with: combined_image, generated_scene, original_image, mask, success
348
+ """
349
+ if not self.is_initialized:
350
+ raise RuntimeError("Models not loaded")
351
+
352
+ logger.info("Starting background generation and combination...")
353
+
354
+ try:
355
+ if progress_callback:
356
+ progress_callback("Analyzing image...", 5)
357
+
358
+ # Prepare image
359
+ processed_original = self._prepare_image(original_image)
360
+ target_width, target_height = processed_original.size
361
+
362
+ if progress_callback:
363
+ progress_callback("Enhancing prompt...", 15)
364
+
365
+ # Enhance prompt
366
+ if enable_prompt_enhancement:
367
+ enhanced_prompt = self.enhance_prompt(prompt, processed_original)
368
+ else:
369
+ enhanced_prompt = f"{prompt}, high quality, detailed, photorealistic"
370
+
371
+ enhanced_negative = f"{negative_prompt}, people, characters, cartoons, logos"
372
+
373
+ if progress_callback:
374
+ progress_callback("Generating background...", 30)
375
+
376
+ # Generate background
377
+ generated_background = self.generate_background(
378
+ prompt=enhanced_prompt,
379
+ width=target_width,
380
+ height=target_height,
381
+ negative_prompt=enhanced_negative,
382
+ num_inference_steps=num_inference_steps,
383
+ guidance_scale=guidance_scale
384
+ )
385
+
386
+ if progress_callback:
387
+ progress_callback("Creating mask...", 80)
388
+
389
+ # Generate mask
390
+ logger.info("Generating mask...")
391
+ combination_mask = self.mask_generator.create_gradient_based_mask(
392
+ processed_original,
393
+ combination_mode,
394
+ focus_mode
395
+ )
396
+
397
+ if progress_callback:
398
+ progress_callback("Blending images...", 90)
399
+
400
+ # Blend images with feather_radius
401
+ logger.info("Blending images...")
402
+ combined_image = self.image_blender.simple_blend_images(
403
+ processed_original,
404
+ generated_background,
405
+ combination_mask,
406
+ feather_radius=feather_radius
407
+ )
408
+
409
+ # Cleanup
410
+ self._memory_cleanup()
411
+
412
+ if progress_callback:
413
+ progress_callback("Complete!", 100)
414
+
415
+ logger.info("Background generation completed successfully")
416
+
417
+ # Build result dict (always include mask for diagnostics)
418
+ return {
419
+ "combined_image": combined_image,
420
+ "generated_scene": generated_background,
421
+ "original_image": processed_original,
422
+ "mask": combination_mask,
423
+ "success": True
424
+ }
425
+
426
+ except Exception as e:
427
+ logger.error(f"Generation failed: {e}")
428
+ self._memory_cleanup()
429
+ return {
430
+ "success": False,
431
+ "error": str(e)
432
+ }
FlowFacade.py CHANGED
@@ -26,29 +26,7 @@ class FlowFacade:
26
  self.text_processor = TextProcessor(resource_manager=None)
27
  print("✓ DeltaFlow initialized")
28
 
29
- def _calculate_gpu_duration(self, image: Image.Image, duration_seconds: float,
30
- num_inference_steps: int, enable_prompt_expansion: bool, **kwargs) -> int:
31
- BASE_FRAMES_HEIGHT_WIDTH = 81 * 832 * 624
32
- BASE_STEP_DURATION = 8
33
-
34
- resized_image = self.video_engine.resize_image(image)
35
- width, height = resized_image.width, resized_image.height
36
- frames = self.video_engine.get_num_frames(duration_seconds)
37
-
38
- factor = frames * width * height / BASE_FRAMES_HEIGHT_WIDTH
39
- step_duration = BASE_STEP_DURATION * factor ** 1.5
40
- total_duration = int(num_inference_steps) * step_duration
41
-
42
- # Add overhead for first-time model loading
43
- if not self.video_engine.is_loaded:
44
- total_duration += 150
45
-
46
- if enable_prompt_expansion:
47
- total_duration += 40
48
-
49
- return max(int(total_duration), 300)
50
-
51
- @spaces.GPU(duration=_calculate_gpu_duration)
52
  def generate_video_from_image(self, image: Image.Image, user_instruction: str,
53
  duration_seconds: float = 3.0, num_inference_steps: int = 4,
54
  guidance_scale: float = 1.0, guidance_scale_2: float = 1.0,
 
26
  self.text_processor = TextProcessor(resource_manager=None)
27
  print("✓ DeltaFlow initialized")
28
 
29
+ @spaces.GPU(duration=300)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  def generate_video_from_image(self, image: Image.Image, user_instruction: str,
31
  duration_seconds: float = 3.0, num_inference_steps: int = 4,
32
  guidance_scale: float = 1.0, guidance_scale_2: float = 1.0,
ResourceManager.py CHANGED
@@ -1,9 +1,3 @@
1
- # %%writefile RescourceManager.py
2
- """
3
- DeltaFlow - Resource Manager
4
- Handles GPU memory allocation, deallocation, and cache management
5
- """
6
-
7
  import gc
8
  import torch
9
  from typing import Optional
 
 
 
 
 
 
 
1
  import gc
2
  import torch
3
  from typing import Optional
TextProcessor.py CHANGED
@@ -1,10 +1,3 @@
1
- # %%writefile text_processor.py
2
- """
3
- DeltaFlow - Text Processor
4
- Handles semantic expansion using Qwen2.5-0.5B-Instruct
5
- Converts brief instructions into detailed motion descriptions
6
- """
7
-
8
  import gc
9
  import traceback
10
  from typing import Optional
 
 
 
 
 
 
 
 
1
  import gc
2
  import traceback
3
  from typing import Optional
VideoEngine_optimized.py CHANGED
@@ -1,10 +1,3 @@
1
- """
2
- DeltaFlow - Video Engine (FP8 Optimized)
3
- Ultra-fast Image-to-Video generation using Wan2.2-I2V-A14B
4
- Features: Lightning LoRA + FP8 Quantization
5
- ~70-90s inference (vs 150s baseline)
6
- """
7
-
8
  import warnings
9
  warnings.filterwarnings('ignore', category=FutureWarning)
10
  warnings.filterwarnings('ignore', category=DeprecationWarning)
 
 
 
 
 
 
 
 
1
  import warnings
2
  warnings.filterwarnings('ignore', category=FutureWarning)
3
  warnings.filterwarnings('ignore', category=DeprecationWarning)
app.py CHANGED
@@ -15,6 +15,7 @@ import ftfy
15
  import sentencepiece
16
 
17
  from FlowFacade import FlowFacade
 
18
  from ui_manager import UIManager
19
 
20
 
@@ -124,11 +125,13 @@ def main():
124
 
125
  try:
126
  facade = FlowFacade()
127
- ui = UIManager(facade)
 
 
128
  is_colab = 'google.colab' in sys.modules
129
 
130
  print("✓ Ready")
131
- ui.launch(
132
  share=is_colab,
133
  server_name="0.0.0.0",
134
  server_port=None,
 
15
  import sentencepiece
16
 
17
  from FlowFacade import FlowFacade
18
+ from BackgroundEngine import BackgroundEngine
19
  from ui_manager import UIManager
20
 
21
 
 
125
 
126
  try:
127
  facade = FlowFacade()
128
+ background_engine = BackgroundEngine()
129
+ ui_manager = UIManager(facade, background_engine)
130
+ interface = ui_manager.create_interface()
131
  is_colab = 'google.colab' in sys.modules
132
 
133
  print("✓ Ready")
134
+ interface.launch(
135
  share=is_colab,
136
  server_name="0.0.0.0",
137
  server_port=None,
css_style.py CHANGED
@@ -1,5 +1,8 @@
1
  DELTAFLOW_CSS = """
2
- /* Global Light Theme */
 
 
 
3
  :root {
4
  --primary-bg: #f8f9fa;
5
  --secondary-bg: #ffffff;
@@ -11,9 +14,12 @@ DELTAFLOW_CSS = """
11
  --accent-hover: #4f46e5;
12
  --success-color: #10b981;
13
  --error-color: #ef4444;
 
14
  --shadow-sm: 0 2px 8px rgba(0, 0, 0, 0.08);
15
  --shadow-md: 0 4px 16px rgba(0, 0, 0, 0.12);
16
  --shadow-lg: 0 8px 32px rgba(0, 0, 0, 0.16);
 
 
17
  }
18
 
19
  /* Main Container */
@@ -276,4 +282,99 @@ video {
276
  max-width: 1200px !important;
277
  margin: 0 auto !important;
278
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
279
  """
 
1
  DELTAFLOW_CSS = """
2
+ /* Import professional fonts */
3
+ @import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap');
4
+
5
+ /* Global Light Theme - Combined VividFlow & SceneWeaver */
6
  :root {
7
  --primary-bg: #f8f9fa;
8
  --secondary-bg: #ffffff;
 
14
  --accent-hover: #4f46e5;
15
  --success-color: #10b981;
16
  --error-color: #ef4444;
17
+ --warning-color: #f59e0b;
18
  --shadow-sm: 0 2px 8px rgba(0, 0, 0, 0.08);
19
  --shadow-md: 0 4px 16px rgba(0, 0, 0, 0.12);
20
  --shadow-lg: 0 8px 32px rgba(0, 0, 0, 0.16);
21
+ --radius-md: 8px;
22
+ --radius-lg: 12px;
23
  }
24
 
25
  /* Main Container */
 
282
  max-width: 1200px !important;
283
  margin: 0 auto !important;
284
  }
285
+
286
+ /* ==== SceneWeaver Background Generation Styles ==== */
287
+
288
+ /* Feature Card - Background Generation Tab */
289
+ .feature-card {
290
+ background: var(--card-bg) !important;
291
+ border: 1px solid var(--border-color) !important;
292
+ border-radius: var(--radius-lg) !important;
293
+ padding: 1.5rem !important;
294
+ box-shadow: var(--shadow-md) !important;
295
+ overflow: visible !important;
296
+ transition: all 0.2s ease !important;
297
+ }
298
+
299
+ .feature-card:hover {
300
+ border-color: var(--accent-color) !important;
301
+ box-shadow: var(--shadow-lg) !important;
302
+ }
303
+
304
+ /* Scene Template Dropdown */
305
+ .template-dropdown select,
306
+ .template-dropdown input {
307
+ font-size: 0.95rem !important;
308
+ padding: 10px 14px !important;
309
+ border-radius: var(--radius-md) !important;
310
+ border: 1px solid var(--border-color) !important;
311
+ background: var(--secondary-bg) !important;
312
+ transition: all 0.2s ease !important;
313
+ }
314
+
315
+ .template-dropdown select:focus,
316
+ .template-dropdown input:focus {
317
+ border-color: var(--accent-color) !important;
318
+ box-shadow: 0 0 0 3px rgba(99, 102, 241, 0.15) !important;
319
+ outline: none !important;
320
+ }
321
+
322
+ /* Results Gallery */
323
+ .result-gallery {
324
+ border-radius: var(--radius-lg) !important;
325
+ overflow: hidden !important;
326
+ border: 1px solid var(--border-color) !important;
327
+ box-shadow: var(--shadow-md) !important;
328
+ }
329
+
330
+ /* Secondary Button (Download, Clear) */
331
+ .secondary-button {
332
+ background: var(--secondary-bg) !important;
333
+ color: var(--accent-color) !important;
334
+ border: 1.5px solid var(--accent-color) !important;
335
+ border-radius: var(--radius-md) !important;
336
+ padding: 12px 20px !important;
337
+ font-weight: 500 !important;
338
+ transition: all 0.2s ease !important;
339
+ }
340
+
341
+ .secondary-button:hover {
342
+ background: rgba(99, 102, 241, 0.1) !important;
343
+ }
344
+
345
+ /* Dropdown positioning fix for Gradio 4.x/5.x */
346
+ .gradio-dropdown,
347
+ .gradio-dropdown > div {
348
+ position: relative !important;
349
+ }
350
+
351
+ .gradio-dropdown ul,
352
+ .gradio-dropdown [role="listbox"] {
353
+ position: absolute !important;
354
+ z-index: 9999 !important;
355
+ left: 0 !important;
356
+ top: 100% !important;
357
+ width: 100% !important;
358
+ max-height: 300px !important;
359
+ overflow-y: auto !important;
360
+ background: var(--secondary-bg) !important;
361
+ border: 1px solid var(--border-color) !important;
362
+ border-radius: var(--radius-md) !important;
363
+ box-shadow: var(--shadow-lg) !important;
364
+ margin-top: 4px !important;
365
+ }
366
+
367
+ /* Status Panel */
368
+ .status-panel {
369
+ background: var(--secondary-bg) !important;
370
+ border: 1px solid var(--border-color) !important;
371
+ border-radius: var(--radius-md) !important;
372
+ padding: 12px 16px !important;
373
+ margin: 16px 0 !important;
374
+ }
375
+
376
+ .status-ready {
377
+ color: var(--success-color) !important;
378
+ font-weight: 500 !important;
379
+ }
380
  """
image_blender.py ADDED
@@ -0,0 +1,1117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ import traceback
4
+ from PIL import Image
5
+ import logging
6
+ from typing import Dict, Any, Optional, Tuple
7
+
8
+ logger = logging.getLogger(__name__)
9
+ logger.setLevel(logging.INFO)
10
+
11
+
12
class ImageBlender:
    """
    Advanced image blending with aggressive spill suppression and color
    replacement.

    Two primary modes are supported:
      * Background generation: foreground preservation with edge refinement
      * Inpainting: seamless blending with adaptive color correction

    Attributes:
        enable_multi_scale: Whether multi-scale edge refinement is enabled
    """

    # Edge handling
    EDGE_EROSION_PIXELS = 1                # Pixels to erode from mask edge
    ALPHA_BINARIZE_THRESHOLD = 0.5         # Alpha threshold for binarization
    DARK_LUMINANCE_THRESHOLD = 60          # Luminance threshold for dark foreground
    FOREGROUND_PROTECTION_THRESHOLD = 140  # Mask value for strong protection
    BACKGROUND_COLOR_TOLERANCE = 30        # DeltaE tolerance for background detection

    # Inpainting-specific parameters
    INPAINT_FEATHER_SCALE = 1.2            # Scale factor for inpainting feathering
    INPAINT_COLOR_BLEND_RADIUS = 10        # Radius for color adaptation zone

    def __init__(self, enable_multi_scale: bool = True):
        """
        Create a blender.

        Parameters
        ----------
        enable_multi_scale : bool
            Whether to enable multi-scale edge refinement (default True)
        """
        self.enable_multi_scale = enable_multi_scale
        self._debug_info = {}
        self._adaptive_strength_map = None
46
+
47
+ def _erode_mask_edges(
48
+ self,
49
+ mask_array: np.ndarray,
50
+ erosion_pixels: int = 2
51
+ ) -> np.ndarray:
52
+ """
53
+ Erode mask edges to remove contaminated boundary pixels.
54
+
55
+ This removes the outermost pixels of the foreground mask where
56
+ color contamination from the original background is most likely.
57
+
58
+ Args:
59
+ mask_array: Input mask as numpy array (uint8, 0-255)
60
+ erosion_pixels: Number of pixels to erode (default 2)
61
+
62
+ Returns:
63
+ Eroded mask array (uint8)
64
+ """
65
+ if erosion_pixels <= 0:
66
+ return mask_array
67
+
68
+ # Use elliptical kernel for natural-looking erosion
69
+ kernel_size = max(2, erosion_pixels)
70
+ kernel = cv2.getStructuringElement(
71
+ cv2.MORPH_ELLIPSE,
72
+ (kernel_size, kernel_size)
73
+ )
74
+
75
+ # Apply erosion
76
+ eroded = cv2.erode(mask_array, kernel, iterations=1)
77
+
78
+ # Slight blur to smooth the eroded edges
79
+ eroded = cv2.GaussianBlur(eroded, (3, 3), 0)
80
+
81
+ logger.debug(f"Mask erosion applied: {erosion_pixels}px, kernel size: {kernel_size}")
82
+ return eroded
83
+
84
+ def _binarize_edge_alpha(
85
+ self,
86
+ alpha: np.ndarray,
87
+ mask_array: np.ndarray,
88
+ orig_array: np.ndarray,
89
+ threshold: float = 0.45
90
+ ) -> np.ndarray:
91
+ """
92
+ Binarize semi-transparent edge pixels to eliminate color bleeding.
93
+
94
+ Semi-transparent pixels at edges cause visible contamination because
95
+ they blend the original (potentially dark) foreground with the new
96
+ background. This method forces edge pixels to be either fully opaque
97
+ or fully transparent.
98
+
99
+ Args:
100
+ alpha: Current alpha channel (float32, 0.0-1.0)
101
+ mask_array: Original mask array (uint8, 0-255)
102
+ orig_array: Original foreground image array (uint8, RGB)
103
+ threshold: Alpha threshold for binarization decision (default 0.45)
104
+
105
+ Returns:
106
+ Modified alpha array with binarized edges (float32)
107
+ """
108
+ # Identify semi-transparent edge zone (not fully opaque, not fully transparent)
109
+ edge_zone = (alpha > 0.05) & (alpha < 0.95)
110
+
111
+ if not np.any(edge_zone):
112
+ return alpha
113
+
114
+ # Calculate local foreground luminance for adaptive thresholding
115
+ gray = np.mean(orig_array, axis=2)
116
+
117
+ # For dark foreground pixels, use slightly higher threshold
118
+ # to preserve more of the dark subject
119
+ is_dark = gray < self.DARK_LUMINANCE_THRESHOLD
120
+
121
+ # Create adaptive threshold map
122
+ adaptive_threshold = np.full_like(alpha, threshold)
123
+ adaptive_threshold[is_dark] = threshold + 0.1 # Keep more dark pixels
124
+
125
+ # Binarize: above threshold -> opaque, below -> transparent
126
+ alpha_binarized = alpha.copy()
127
+
128
+ # Pixels above threshold become fully opaque
129
+ make_opaque = edge_zone & (alpha > adaptive_threshold)
130
+ alpha_binarized[make_opaque] = 1.0
131
+
132
+ # Pixels below threshold become fully transparent
133
+ make_transparent = edge_zone & (alpha <= adaptive_threshold)
134
+ alpha_binarized[make_transparent] = 0.0
135
+
136
+ # Log statistics
137
+ num_opaque = np.sum(make_opaque)
138
+ num_transparent = np.sum(make_transparent)
139
+ logger.info(f"Edge binarization: {num_opaque} pixels -> opaque, {num_transparent} pixels -> transparent")
140
+
141
+ return alpha_binarized
142
+
143
+ def _apply_edge_cleanup(
144
+ self,
145
+ result_array: np.ndarray,
146
+ bg_array: np.ndarray,
147
+ alpha: np.ndarray,
148
+ cleanup_width: int = 2
149
+ ) -> np.ndarray:
150
+ """
151
+ Final cleanup pass to remove any remaining edge artifacts.
152
+
153
+ Detects remaining semi-transparent edges and replaces them with
154
+ either pure foreground or pure background colors.
155
+
156
+ Args:
157
+ result_array: Current blended result (uint8, RGB)
158
+ bg_array: Background image array (uint8, RGB)
159
+ alpha: Final alpha channel (float32, 0.0-1.0)
160
+ cleanup_width: Width of edge zone to clean (default 2)
161
+
162
+ Returns:
163
+ Cleaned result array (uint8)
164
+ """
165
+ # Find edge pixels that might still have artifacts
166
+ # These are pixels with alpha close to but not exactly 0 or 1
167
+ residual_edge = (alpha > 0.01) & (alpha < 0.99) & (alpha != 0.0) & (alpha != 1.0)
168
+
169
+ if not np.any(residual_edge):
170
+ return result_array
171
+
172
+ result_cleaned = result_array.copy()
173
+
174
+ # For residual edge pixels, snap to nearest pure state
175
+ snap_to_bg = residual_edge & (alpha < 0.5)
176
+ snap_to_fg = residual_edge & (alpha >= 0.5)
177
+
178
+ # Replace with background
179
+ result_cleaned[snap_to_bg] = bg_array[snap_to_bg]
180
+
181
+ # For foreground, keep original but ensure no blending artifacts
182
+ # (already handled by the blend, so no action needed for snap_to_fg)
183
+
184
+ num_cleaned = np.sum(residual_edge)
185
+ if num_cleaned > 0:
186
+ logger.debug(f"Edge cleanup: {num_cleaned} residual pixels cleaned")
187
+
188
+ return result_cleaned
189
+
190
+ def _remove_background_color_contamination(
191
+ self,
192
+ image_array: np.ndarray,
193
+ mask_array: np.ndarray,
194
+ orig_bg_color_lab: np.ndarray,
195
+ tolerance: float = 30.0
196
+ ) -> np.ndarray:
197
+ """
198
+ Remove original background color contamination from foreground pixels.
199
+
200
+ Scans the foreground area for pixels that match the original background
201
+ color and replaces them with nearby clean foreground colors.
202
+
203
+ Args:
204
+ image_array: Foreground image array (uint8, RGB)
205
+ mask_array: Mask array (uint8, 0-255)
206
+ orig_bg_color_lab: Original background color in Lab space
207
+ tolerance: DeltaE tolerance for detecting contaminated pixels
208
+
209
+ Returns:
210
+ Cleaned image array (uint8)
211
+ """
212
+ # Convert to Lab for color comparison
213
+ image_lab = cv2.cvtColor(image_array, cv2.COLOR_RGB2LAB).astype(np.float32)
214
+
215
+ # Only process foreground pixels (mask > 50)
216
+ foreground_mask = mask_array > 50
217
+
218
+ if not np.any(foreground_mask):
219
+ return image_array
220
+
221
+ # Calculate deltaE from original background color for all pixels
222
+ delta_l = image_lab[:, :, 0] - orig_bg_color_lab[0]
223
+ delta_a = image_lab[:, :, 1] - orig_bg_color_lab[1]
224
+ delta_b = image_lab[:, :, 2] - orig_bg_color_lab[2]
225
+ delta_e = np.sqrt(delta_l**2 + delta_a**2 + delta_b**2)
226
+
227
+ # Find contaminated pixels: in foreground but color similar to original background
228
+ contaminated = foreground_mask & (delta_e < tolerance)
229
+
230
+ if not np.any(contaminated):
231
+ logger.debug("No background color contamination detected in foreground")
232
+ return image_array
233
+
234
+ num_contaminated = np.sum(contaminated)
235
+ logger.info(f"Found {num_contaminated} pixels with background color contamination")
236
+
237
+ # Create output array
238
+ result = image_array.copy()
239
+
240
+ # For contaminated pixels, use inpainting to replace with surrounding colors
241
+ inpaint_mask = contaminated.astype(np.uint8) * 255
242
+
243
+ try:
244
+ # Use inpainting to fill contaminated areas with surrounding foreground colors
245
+ result = cv2.inpaint(result, inpaint_mask, inpaintRadius=3, flags=cv2.INPAINT_TELEA)
246
+ logger.info(f"Inpainted {num_contaminated} contaminated pixels")
247
+ except Exception as e:
248
+ logger.warning(f"Inpainting failed: {e}, using median filter fallback")
249
+ # Fallback: apply median filter to contaminated areas
250
+ median_filtered = cv2.medianBlur(image_array, 5)
251
+ result[contaminated] = median_filtered[contaminated]
252
+
253
+ return result
254
+
255
+ def _protect_foreground_core(
256
+ self,
257
+ result_array: np.ndarray,
258
+ orig_array: np.ndarray,
259
+ mask_array: np.ndarray,
260
+ protection_threshold: int = 140
261
+ ) -> np.ndarray:
262
+ """
263
+ Strongly protect core foreground pixels from any background influence.
264
+
265
+ For pixels with high mask confidence, directly use the original foreground
266
+ color without any blending, ensuring faces and bodies are not affected.
267
+
268
+ Args:
269
+ result_array: Current blended result (uint8, RGB)
270
+ orig_array: Original foreground image (uint8, RGB)
271
+ mask_array: Mask array (uint8, 0-255)
272
+ protection_threshold: Mask value above which pixels are fully protected
273
+
274
+ Returns:
275
+ Protected result array (uint8)
276
+ """
277
+ # Identify strongly protected foreground pixels
278
+ strong_foreground = mask_array >= protection_threshold
279
+
280
+ if not np.any(strong_foreground):
281
+ return result_array
282
+
283
+ # For these pixels, use original foreground color directly
284
+ result_protected = result_array.copy()
285
+ result_protected[strong_foreground] = orig_array[strong_foreground]
286
+
287
+ num_protected = np.sum(strong_foreground)
288
+ logger.info(f"Protected {num_protected} core foreground pixels from background influence")
289
+
290
+ return result_protected
291
+
292
    def multi_scale_edge_refinement(
        self,
        original_image: Image.Image,
        background_image: Image.Image,
        mask: Image.Image
    ) -> Image.Image:
        """
        Multi-scale edge refinement for better edge quality.
        Uses image pyramid to handle edges at different scales.

        A three-level pyramid (1x, 0.5x, 0.25x) of the mask is built; each
        level's local gradient-energy ("complexity") map decides how much
        that level contributes per pixel: detailed regions favor the
        full-resolution mask, smooth regions favor the downscaled
        (smoother) masks.

        Args:
            original_image: Foreground PIL Image
            background_image: Background PIL Image
                (NOTE(review): not referenced anywhere in this body — kept
                for interface compatibility; confirm before removing)
            mask: Current mask PIL Image

        Returns:
            Refined mask PIL Image (mode 'L'); on any failure the input
            mask is returned unchanged.
        """
        logger.info("🔍 Starting multi-scale edge refinement...")

        try:
            # Convert to numpy arrays
            orig_array = np.array(original_image.convert('RGB'))
            mask_array = np.array(mask).astype(np.float32)
            height, width = mask_array.shape

            # Define scales for pyramid
            scales = [1.0, 0.5, 0.25]  # Original, half, quarter
            scale_masks = []           # per-scale masks, upsampled back to full size
            scale_complexities = []    # per-scale gradient-energy maps, full size

            # Convert to grayscale for edge detection
            gray = cv2.cvtColor(orig_array, cv2.COLOR_RGB2GRAY)

            for scale in scales:
                if scale == 1.0:
                    scaled_gray = gray
                    scaled_mask = mask_array
                else:
                    new_h = int(height * scale)
                    new_w = int(width * scale)
                    scaled_gray = cv2.resize(gray, (new_w, new_h), interpolation=cv2.INTER_LANCZOS4)
                    scaled_mask = cv2.resize(mask_array, (new_w, new_h), interpolation=cv2.INTER_LANCZOS4)

                # Compute local complexity using gradient standard deviation
                sobel_x = cv2.Sobel(scaled_gray, cv2.CV_64F, 1, 0, ksize=3)
                sobel_y = cv2.Sobel(scaled_gray, cv2.CV_64F, 0, 1, ksize=3)
                gradient_mag = np.sqrt(sobel_x**2 + sobel_y**2)

                # Calculate local complexity in 5x5 regions
                kernel_size = 5
                complexity = cv2.blur(gradient_mag, (kernel_size, kernel_size))

                # Resize back to original size
                if scale != 1.0:
                    scaled_mask = cv2.resize(scaled_mask, (width, height), interpolation=cv2.INTER_LANCZOS4)
                    complexity = cv2.resize(complexity, (width, height), interpolation=cv2.INTER_LANCZOS4)

                scale_masks.append(scaled_mask)
                scale_complexities.append(complexity)

            # Compute weights based on complexity
            # High complexity -> use high resolution mask
            # Low complexity -> use low resolution mask (smoother)
            weights = np.zeros((len(scales), height, width), dtype=np.float32)

            # Normalize complexities (epsilon avoids division by zero on flat images)
            max_complexity = max(c.max() for c in scale_complexities) + 1e-6
            normalized_complexities = [c / max_complexity for c in scale_complexities]

            # Weight assignment: higher complexity at each scale means that scale is more reliable
            for i, complexity in enumerate(normalized_complexities):
                if i == 0:  # High resolution - prefer for high complexity regions
                    weights[i] = complexity
                elif i == 1:  # Medium resolution - moderate complexity
                    weights[i] = 0.5 * (1 - complexity) + 0.5 * complexity * 0.5
                else:  # Low resolution - prefer for low complexity regions
                    weights[i] = 1 - complexity

            # Normalize weights so they sum to 1 at each pixel
            weight_sum = weights.sum(axis=0, keepdims=True) + 1e-6
            weights = weights / weight_sum

            # Weighted blend of masks from different scales
            refined_mask = np.zeros((height, width), dtype=np.float32)
            for i, mask_i in enumerate(scale_masks):
                refined_mask += weights[i] * mask_i

            # Clip and convert to uint8
            refined_mask = np.clip(refined_mask, 0, 255).astype(np.uint8)

            logger.info("✅ Multi-scale edge refinement completed")
            return Image.fromarray(refined_mask, mode='L')

        except Exception as e:
            # Fail-soft: refinement is an enhancement, never a hard dependency.
            logger.error(f"❌ Multi-scale refinement failed: {e}, using original mask")
            return mask
389
+
390
+ def simple_blend_images(
391
+ self,
392
+ original_image: Image.Image,
393
+ background_image: Image.Image,
394
+ combination_mask: Image.Image,
395
+ use_multi_scale: Optional[bool] = None,
396
+ feather_radius: int = 0
397
+ ) -> Image.Image:
398
+ """
399
+ Aggressive spill suppression + color replacement: completely eliminate yellow edge residue, maintain sharp edges
400
+
401
+ Args:
402
+ original_image: Foreground PIL Image
403
+ background_image: Background PIL Image
404
+ combination_mask: Mask PIL Image (L mode)
405
+ use_multi_scale: Override for multi-scale refinement (None = use class default)
406
+ feather_radius: Gaussian blur radius for mask feathering (0 = disabled, default behavior)
407
+
408
+ Returns:
409
+ Blended PIL Image
410
+ """
411
+ logger.info("🎨 Starting advanced image blending process...")
412
+
413
+ # Apply multi-scale edge refinement if enabled
414
+ should_use_multi_scale = use_multi_scale if use_multi_scale is not None else self.enable_multi_scale
415
+ if should_use_multi_scale:
416
+ combination_mask = self.multi_scale_edge_refinement(
417
+ original_image, background_image, combination_mask
418
+ )
419
+
420
+ # Convert to numpy arrays
421
+ orig_array = np.array(original_image, dtype=np.uint8)
422
+ bg_array = np.array(background_image, dtype=np.uint8)
423
+ mask_array = np.array(combination_mask, dtype=np.uint8)
424
+
425
+ # Apply feathering if requested
426
+ if feather_radius > 0:
427
+ kernel_size = feather_radius * 2 + 1
428
+ mask_array = cv2.GaussianBlur(
429
+ mask_array,
430
+ (kernel_size, kernel_size),
431
+ feather_radius / 2.0
432
+ )
433
+ logger.info(f"📐 Mask feathering applied: radius={feather_radius}, kernel={kernel_size}x{kernel_size}")
434
+
435
+ logger.info(f"📊 Image dimensions - Original: {orig_array.shape}, Background: {bg_array.shape}, Mask: {mask_array.shape}")
436
+ logger.info(f"📊 Mask statistics (before erosion) - Mean: {mask_array.mean():.1f}, Min: {mask_array.min()}, Max: {mask_array.max()}")
437
+
438
+ # === NEW: Apply mask erosion to remove contaminated edge pixels ===
439
+ mask_array = self._erode_mask_edges(mask_array, self.EDGE_EROSION_PIXELS)
440
+ logger.info(f"📊 Mask statistics (after erosion) - Mean: {mask_array.mean():.1f}, Min: {mask_array.min()}, Max: {mask_array.max()}")
441
+
442
+ # Enhanced parameters for better spill suppression
443
+ RING_WIDTH_PX = 4 # Increased ring width for better coverage
444
+ SPILL_STRENGTH = 0.85 # Stronger spill suppression
445
+ L_MATCH_STRENGTH = 0.65 # Stronger luminance matching
446
+ DELTAE_THRESHOLD = 18 # More aggressive contamination detection
447
+ HARD_EDGE_PROTECT = True # Black edge protection
448
+ INPAINT_FALLBACK = True # inpaint fallback repair
449
+ MULTI_PASS_CORRECTION = True # Enable multi-pass correction
450
+
451
+ # Estimate original background color and foreground representative color ===
452
+ height, width = orig_array.shape[:2]
453
+
454
+ # Take 15px from each side to estimate original background color
455
+ edge_width = 15
456
+ border_pixels = []
457
+
458
+ # Collect border pixels (excluding foreground areas)
459
+ border_mask = np.zeros((height, width), dtype=bool)
460
+ border_mask[:edge_width, :] = True # Top edge
461
+ border_mask[-edge_width:, :] = True # Bottom edge
462
+ border_mask[:, :edge_width] = True # Left edge
463
+ border_mask[:, -edge_width:] = True # Right edge
464
+
465
+ # Exclude foreground areas
466
+ fg_binary = mask_array > 50
467
+ border_mask = border_mask & (~fg_binary)
468
+
469
+ if np.any(border_mask):
470
+ border_pixels = orig_array[border_mask].reshape(-1, 3)
471
+
472
+ # Simplified background color estimation (no sklearn dependency)
473
+ try:
474
+ if len(border_pixels) > 100:
475
+ # Use histogram to find mode colors
476
+ # Quantize RGB to coarser grid to find main colors
477
+ quantized = (border_pixels // 32) * 32 # 8-level quantization
478
+
479
+ # Find most frequent color
480
+ unique_colors, counts = np.unique(quantized.reshape(-1, quantized.shape[-1]),
481
+ axis=0, return_counts=True)
482
+ most_common_idx = np.argmax(counts)
483
+ orig_bg_color_rgb = unique_colors[most_common_idx].astype(np.uint8)
484
+ else:
485
+ orig_bg_color_rgb = np.median(border_pixels, axis=0).astype(np.uint8)
486
+ except:
487
+ # Fallback: use four corners average
488
+ corners = np.array([orig_array[0,0], orig_array[0,-1],
489
+ orig_array[-1,0], orig_array[-1,-1]])
490
+ orig_bg_color_rgb = np.mean(corners, axis=0).astype(np.uint8)
491
+ else:
492
+ orig_bg_color_rgb = np.array([200, 180, 120], dtype=np.uint8) # Default yellow
493
+
494
+ # Convert to Lab space
495
+ orig_bg_color_lab = cv2.cvtColor(orig_bg_color_rgb.reshape(1,1,3), cv2.COLOR_RGB2LAB)[0,0].astype(np.float32)
496
+ logger.info(f"🎨 Detected original background color: RGB{tuple(orig_bg_color_rgb)}")
497
+
498
+ # Remove original background color contamination from foreground
499
+ orig_array = self._remove_background_color_contamination(
500
+ orig_array,
501
+ mask_array,
502
+ orig_bg_color_lab,
503
+ tolerance=self.BACKGROUND_COLOR_TOLERANCE
504
+ )
505
+
506
+ # Redefine trimap, optimized for cartoon characters
507
+ try:
508
+ kernel_3x3 = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
509
+
510
+ # FG_CORE: Reduce erosion iterations from 2 to 1 to avoid losing thin limbs
511
+ mask_eroded_once = cv2.erode(mask_array, kernel_3x3, iterations=1)
512
+ fg_core = mask_eroded_once > 127 # Adjustable parameter: erosion iterations
513
+
514
+ # RING: Use morphological gradient to redefine, ensuring only thin edge band
515
+ mask_dilated = cv2.dilate(mask_array, kernel_3x3, iterations=1)
516
+ mask_eroded = cv2.erode(mask_array, kernel_3x3, iterations=1)
517
+
518
+ # Ensure consistent data types to avoid overflow
519
+ morphological_gradient = cv2.subtract(mask_dilated, mask_eroded)
520
+ ring_zone = morphological_gradient > 0 # Areas with morphological gradient > 0 are edge bands
521
+
522
+ # BG: background area
523
+ bg_zone = mask_array < 30
524
+
525
+ logger.info(f"🔍 Trimap regions - FG_CORE: {fg_core.sum()}, RING: {ring_zone.sum()}, BG: {bg_zone.sum()}")
526
+
527
+ except Exception as e:
528
+ logger.error(f"❌ Trimap definition failed: {e}")
529
+ logger.error(f"📍 Traceback: {traceback.format_exc()}")
530
+ print(f"❌ TRIMAP ERROR: {e}")
531
+ print(f"Traceback: {traceback.format_exc()}")
532
+ # Fallback to simple definition
533
+ fg_core = mask_array > 200
534
+ ring_zone = (mask_array > 50) & (mask_array <= 200)
535
+ bg_zone = mask_array <= 50
536
+
537
+ # Foreground representative color: estimated from FG_CORE
538
+ if np.any(fg_core):
539
+ fg_pixels = orig_array[fg_core].reshape(-1, 3)
540
+ fg_rep_color_rgb = np.median(fg_pixels, axis=0).astype(np.uint8)
541
+ else:
542
+ fg_rep_color_rgb = np.array([80, 60, 40], dtype=np.uint8) # Default dark
543
+
544
+ fg_rep_color_lab = cv2.cvtColor(fg_rep_color_rgb.reshape(1,1,3), cv2.COLOR_RGB2LAB)[0,0].astype(np.float32)
545
+
546
+ # Edge band spill suppression and repair
547
+ if np.any(ring_zone):
548
+ # Convert to Lab space
549
+ orig_lab = cv2.cvtColor(orig_array, cv2.COLOR_RGB2LAB).astype(np.float32)
550
+ orig_array_working = orig_array.copy().astype(np.float32)
551
+
552
+ # ΔE detect contaminated pixels
553
+ ring_pixels_lab = orig_lab[ring_zone]
554
+
555
+ # Calculate ΔE with original background color (simplified version)
556
+ delta_l = ring_pixels_lab[:, 0] - orig_bg_color_lab[0]
557
+ delta_a = ring_pixels_lab[:, 1] - orig_bg_color_lab[1]
558
+ delta_b = ring_pixels_lab[:, 2] - orig_bg_color_lab[2]
559
+ delta_e = np.sqrt(delta_l**2 + delta_a**2 + delta_b**2)
560
+
561
+ # Contaminated pixel mask
562
+ contaminated_mask = delta_e < DELTAE_THRESHOLD
563
+
564
+ if np.any(contaminated_mask):
565
+ # Calculate adaptive strength based on delta_e for each pixel
566
+ # Pixels closer to background color get stronger correction
567
+ contaminated_delta_e = delta_e[contaminated_mask]
568
+
569
+ # Adaptive strength formula: inverse relationship with delta_e
570
+ # Pixels very close to bg color (low delta_e) -> strong correction
571
+ # Pixels further from bg color (high delta_e) -> lighter correction
572
+ adaptive_strength = SPILL_STRENGTH * np.maximum(
573
+ 0.0,
574
+ 1.0 - (contaminated_delta_e / DELTAE_THRESHOLD)
575
+ )
576
+
577
+ # Clamp adaptive strength to reasonable range (30% - 100% of base strength)
578
+ min_strength = SPILL_STRENGTH * 0.3
579
+ adaptive_strength = np.clip(adaptive_strength, min_strength, SPILL_STRENGTH)
580
+
581
+ # Store for debug visualization
582
+ self._adaptive_strength_map = np.zeros_like(delta_e)
583
+ self._adaptive_strength_map[contaminated_mask] = adaptive_strength
584
+
585
+ logger.info(f"📊 Adaptive strength stats - Mean: {adaptive_strength.mean():.3f}, Min: {adaptive_strength.min():.3f}, Max: {adaptive_strength.max():.3f}")
586
+
587
+ # Chroma vector deprojection
588
+ bg_chroma = np.array([orig_bg_color_lab[1], orig_bg_color_lab[2]])
589
+ bg_chroma_norm = bg_chroma / (np.linalg.norm(bg_chroma) + 1e-6)
590
+
591
+ # Color correction for contaminated pixels
592
+ contaminated_pixels = ring_pixels_lab[contaminated_mask]
593
+
594
+ # Remove background chroma component with adaptive strength (per-pixel)
595
+ pixel_chroma = contaminated_pixels[:, 1:3] # a, b channels
596
+ projection = np.dot(pixel_chroma, bg_chroma_norm)[:, np.newaxis] * bg_chroma_norm
597
+
598
+ # Apply adaptive strength per pixel
599
+ adaptive_strength_2d = adaptive_strength[:, np.newaxis]
600
+ corrected_chroma = pixel_chroma - projection * adaptive_strength_2d
601
+
602
+ # Converge toward foreground representative color with adaptive strength
603
+ convergence_factor = adaptive_strength_2d * 0.6
604
+ corrected_chroma = (corrected_chroma * (1 - convergence_factor) +
605
+ fg_rep_color_lab[1:3] * convergence_factor)
606
+
607
+ # Adaptive luminance matching
608
+ adaptive_l_strength = adaptive_strength * (L_MATCH_STRENGTH / SPILL_STRENGTH)
609
+ corrected_l = (contaminated_pixels[:, 0] * (1 - adaptive_l_strength) +
610
+ fg_rep_color_lab[0] * adaptive_l_strength)
611
+
612
+ # Update Lab values
613
+ ring_pixels_lab[contaminated_mask, 0] = corrected_l
614
+ ring_pixels_lab[contaminated_mask, 1:3] = corrected_chroma
615
+
616
+ # Write back to original image
617
+ orig_lab[ring_zone] = ring_pixels_lab
618
+
619
+ # Dark edge protection
620
+ if HARD_EDGE_PROTECT:
621
+ gray = np.mean(orig_array, axis=2)
622
+ # Detect dark and high gradient areas
623
+ sobel_x = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
624
+ sobel_y = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
625
+ gradient_mag = np.sqrt(sobel_x**2 + sobel_y**2)
626
+
627
+ dark_edge_zone = ring_zone & (gray < 60) & (gradient_mag > 20)
628
+ # Protect these areas from excessive modification, copy directly from original
629
+ if np.any(dark_edge_zone):
630
+ orig_lab[dark_edge_zone] = cv2.cvtColor(orig_array, cv2.COLOR_RGB2LAB)[dark_edge_zone]
631
+
632
+ # Multi-pass correction for stubborn spill
633
+ if MULTI_PASS_CORRECTION:
634
+ # Second pass for remaining contamination
635
+ ring_pixels_lab_pass2 = orig_lab[ring_zone]
636
+ delta_l_pass2 = ring_pixels_lab_pass2[:, 0] - orig_bg_color_lab[0]
637
+ delta_a_pass2 = ring_pixels_lab_pass2[:, 1] - orig_bg_color_lab[1]
638
+ delta_b_pass2 = ring_pixels_lab_pass2[:, 2] - orig_bg_color_lab[2]
639
+ delta_e_pass2 = np.sqrt(delta_l_pass2**2 + delta_a_pass2**2 + delta_b_pass2**2)
640
+
641
+ still_contaminated = delta_e_pass2 < (DELTAE_THRESHOLD * 0.8)
642
+
643
+ if np.any(still_contaminated):
644
+ # Apply stronger correction to remaining contaminated pixels
645
+ remaining_pixels = ring_pixels_lab_pass2[still_contaminated]
646
+
647
+ # More aggressive chroma neutralization
648
+ remaining_chroma = remaining_pixels[:, 1:3]
649
+ neutralized_chroma = remaining_chroma * 0.3 + fg_rep_color_lab[1:3] * 0.7
650
+
651
+ # Stronger luminance matching
652
+ neutralized_l = remaining_pixels[:, 0] * 0.4 + fg_rep_color_lab[0] * 0.6
653
+
654
+ ring_pixels_lab_pass2[still_contaminated, 0] = neutralized_l
655
+ ring_pixels_lab_pass2[still_contaminated, 1:3] = neutralized_chroma
656
+ orig_lab[ring_zone] = ring_pixels_lab_pass2
657
+
658
+ # Convert back to RGB
659
+ orig_lab_clipped = np.clip(orig_lab, 0, 255).astype(np.uint8)
660
+ orig_array_corrected = cv2.cvtColor(orig_lab_clipped, cv2.COLOR_LAB2RGB)
661
+
662
+ # inpaint fallback repair
663
+ if INPAINT_FALLBACK:
664
+ # inpaint still contaminated outermost pixels
665
+ final_contaminated = ring_zone.copy()
666
+
667
+ # Check if there's still contamination after repair
668
+ final_lab = cv2.cvtColor(orig_array_corrected, cv2.COLOR_RGB2LAB).astype(np.float32)
669
+ final_ring_lab = final_lab[ring_zone]
670
+ final_delta_l = final_ring_lab[:, 0] - orig_bg_color_lab[0]
671
+ final_delta_a = final_ring_lab[:, 1] - orig_bg_color_lab[1]
672
+ final_delta_b = final_ring_lab[:, 2] - orig_bg_color_lab[2]
673
+ final_delta_e = np.sqrt(final_delta_l**2 + final_delta_a**2 + final_delta_b**2)
674
+
675
+ still_contaminated = final_delta_e < (DELTAE_THRESHOLD * 0.5)
676
+ if np.any(still_contaminated):
677
+ # Create inpaint mask
678
+ inpaint_mask = np.zeros((height, width), dtype=np.uint8)
679
+ ring_coords = np.where(ring_zone)
680
+ inpaint_coords = (ring_coords[0][still_contaminated], ring_coords[1][still_contaminated])
681
+ inpaint_mask[inpaint_coords] = 255
682
+
683
+ # Execute inpaint
684
+ try:
685
+ orig_array_corrected = cv2.inpaint(orig_array_corrected, inpaint_mask, 3, cv2.INPAINT_TELEA)
686
+ except:
687
+ # Fallback: directly cover with foreground representative color
688
+ orig_array_corrected[inpaint_coords] = fg_rep_color_rgb
689
+
690
+ orig_array = orig_array_corrected
691
+
692
+ # === Linear space blending (keep original logic) ===
693
+ def srgb_to_linear(img):
694
+ img_f = img.astype(np.float32) / 255.0
695
+ return np.where(img_f <= 0.04045, img_f / 12.92, np.power((img_f + 0.055) / 1.055, 2.4))
696
+
697
+ def linear_to_srgb(img):
698
+ img_clipped = np.clip(img, 0, 1)
699
+ return np.where(img_clipped <= 0.0031308,
700
+ 12.92 * img_clipped,
701
+ 1.055 * np.power(img_clipped, 1/2.4) - 0.055)
702
+
703
+ orig_linear = srgb_to_linear(orig_array)
704
+ bg_linear = srgb_to_linear(bg_array)
705
+
706
+ # Cartoon-optimized Alpha calculation
707
+ alpha = mask_array.astype(np.float32) / 255.0
708
+
709
+ # Core foreground region - fully opaque
710
+ alpha[fg_core] = 1.0
711
+
712
+ # Background region - fully transparent
713
+ alpha[bg_zone] = 0.0
714
+
715
+ # [Key Fix] Force pixels with mask≥160 to α=1.0, avoiding white fill areas being limited to 0.9
716
+ high_confidence_pixels = mask_array >= 160
717
+ alpha[high_confidence_pixels] = 1.0
718
+ logger.info(f"💯 High confidence pixels set to full opacity: {high_confidence_pixels.sum()}")
719
+
720
+ # Ring area can be dehaloed, but doesn't affect already set high confidence pixels
721
+ ring_without_high_conf = ring_zone & (~high_confidence_pixels)
722
+ alpha[ring_without_high_conf] = np.clip(alpha[ring_without_high_conf], 0.2, 0.9)
723
+
724
+ # Retain existing black outline/strong edge protection
725
+ orig_gray = np.mean(orig_array, axis=2)
726
+
727
+ # Detect strong edge areas
728
+ sobel_x = cv2.Sobel(orig_gray, cv2.CV_64F, 1, 0, ksize=3)
729
+ sobel_y = cv2.Sobel(orig_gray, cv2.CV_64F, 0, 1, ksize=3)
730
+ gradient_mag = np.sqrt(sobel_x**2 + sobel_y**2)
731
+
732
+ # Black outline/strong edge protection: nearly fully opaque
733
+ black_edge_threshold = 60 # black edge threshold
734
+ gradient_threshold = 25 # gradient threshold
735
+ strong_edges = (orig_gray < black_edge_threshold) & (gradient_mag > gradient_threshold) & (mask_array > 10)
736
+ alpha[strong_edges] = np.maximum(alpha[strong_edges], 0.995) # black edge alpha
737
+
738
+ logger.info(f"🛡️ Protection applied - High conf: {high_confidence_pixels.sum()}, Strong edges: {strong_edges.sum()}")
739
+
740
+ # Apply edge alpha binarization to eliminate semi-transparent artifacts
741
+ alpha = self._binarize_edge_alpha(
742
+ alpha,
743
+ mask_array,
744
+ orig_array,
745
+ threshold=self.ALPHA_BINARIZE_THRESHOLD
746
+ )
747
+
748
+ # Final blending
749
+ alpha_3d = alpha[:, :, np.newaxis]
750
+ result_linear = orig_linear * alpha_3d + bg_linear * (1 - alpha_3d)
751
+ result_srgb = linear_to_srgb(result_linear)
752
+ result_array = (result_srgb * 255).astype(np.uint8)
753
+
754
+ # Final edge cleanup pass
755
+ result_array = self._apply_edge_cleanup(result_array, bg_array, alpha)
756
+
757
+ # Protect core foreground from any background influence
758
+ # This ensures faces and bodies retain original colors
759
+ result_array = self._protect_foreground_core(
760
+ result_array,
761
+ np.array(original_image, dtype=np.uint8), # Use original unprocessed image
762
+ mask_array,
763
+ protection_threshold=self.FOREGROUND_PROTECTION_THRESHOLD
764
+ )
765
+
766
+ # Store debug information (for debug output)
767
+ self._debug_info = {
768
+ 'orig_bg_color_rgb': orig_bg_color_rgb,
769
+ 'fg_rep_color_rgb': fg_rep_color_rgb,
770
+ 'orig_bg_color_lab': orig_bg_color_lab,
771
+ 'fg_rep_color_lab': fg_rep_color_lab,
772
+ 'ring_zone': ring_zone,
773
+ 'fg_core': fg_core,
774
+ 'alpha_final': alpha
775
+ }
776
+
777
+ return Image.fromarray(result_array)
778
+
779
+ def create_debug_images(
780
+ self,
781
+ original_image: Image.Image,
782
+ generated_background: Image.Image,
783
+ combination_mask: Image.Image,
784
+ combined_image: Image.Image
785
+ ) -> Dict[str, Image.Image]:
786
+ """
787
+ Generate debug images: (a) Final mask grayscale (b) Alpha heatmap (c) Ring visualization overlay
788
+ """
789
+ debug_images = {}
790
+
791
+ # Final Mask grayscale
792
+ debug_images["mask_gray"] = combination_mask.convert('L')
793
+
794
+ # Alpha Heatmap
795
+ mask_array = np.array(combination_mask.convert('L'))
796
+ heatmap_colored = cv2.applyColorMap(mask_array, cv2.COLORMAP_JET)
797
+ heatmap_rgb = cv2.cvtColor(heatmap_colored, cv2.COLOR_BGR2RGB)
798
+ debug_images["alpha_heatmap"] = Image.fromarray(heatmap_rgb)
799
+
800
+ # Ring visualization overlay - show ring areas on original image
801
+ if hasattr(self, '_debug_info') and 'ring_zone' in self._debug_info:
802
+ ring_zone = self._debug_info['ring_zone']
803
+ orig_array = np.array(original_image)
804
+ ring_overlay = orig_array.copy()
805
+
806
+ # Mark ring areas with red semi-transparent overlay
807
+ ring_overlay[ring_zone] = ring_overlay[ring_zone] * 0.7 + np.array([255, 0, 0]) * 0.3
808
+ debug_images["ring_visualization"] = Image.fromarray(ring_overlay.astype(np.uint8))
809
+ else:
810
+ # If no ring information, use original image
811
+ debug_images["ring_visualization"] = original_image
812
+
813
+ # Adaptive strength heatmap - visualize per-pixel correction strength
814
+ if hasattr(self, '_adaptive_strength_map') and self._adaptive_strength_map is not None:
815
+ # Normalize adaptive strength to 0-255 for visualization
816
+ strength_map = self._adaptive_strength_map
817
+ if strength_map.max() > 0:
818
+ normalized_strength = (strength_map / strength_map.max() * 255).astype(np.uint8)
819
+ else:
820
+ normalized_strength = np.zeros_like(strength_map, dtype=np.uint8)
821
+
822
+ # Apply colormap
823
+ strength_heatmap = cv2.applyColorMap(normalized_strength, cv2.COLORMAP_VIRIDIS)
824
+ strength_heatmap_rgb = cv2.cvtColor(strength_heatmap, cv2.COLOR_BGR2RGB)
825
+ debug_images["adaptive_strength_heatmap"] = Image.fromarray(strength_heatmap_rgb)
826
+
827
+ return debug_images
828
+
829
+ # INPAINTING-SPECIFIC BLENDING METHODS
830
+ def blend_inpainting(
831
+ self,
832
+ original: Image.Image,
833
+ generated: Image.Image,
834
+ mask: Image.Image,
835
+ feather_radius: int = 8,
836
+ apply_color_correction: bool = True
837
+ ) -> Image.Image:
838
+ """
839
+ Blend inpainted region with original image.
840
+
841
+ Specialized blending for inpainting that focuses on seamless integration
842
+ rather than foreground protection. Performs blending in linear color space
843
+ with optional adaptive color correction at boundaries.
844
+
845
+ Parameters
846
+ ----------
847
+ original : PIL.Image
848
+ Original image before inpainting
849
+ generated : PIL.Image
850
+ Generated/inpainted result from the model
851
+ mask : PIL.Image
852
+ Inpainting mask (white = inpainted area)
853
+ feather_radius : int
854
+ Feathering radius for smooth transitions
855
+ apply_color_correction : bool
856
+ Whether to apply adaptive color correction at boundaries
857
+
858
+ Returns
859
+ -------
860
+ PIL.Image
861
+ Blended result
862
+ """
863
+ logger.info(f"Inpainting blend: feather={feather_radius}, color_correction={apply_color_correction}")
864
+
865
+ # Ensure same size
866
+ if generated.size != original.size:
867
+ generated = generated.resize(original.size, Image.LANCZOS)
868
+ if mask.size != original.size:
869
+ mask = mask.resize(original.size, Image.LANCZOS)
870
+
871
+ # Convert to arrays
872
+ orig_array = np.array(original.convert('RGB')).astype(np.float32)
873
+ gen_array = np.array(generated.convert('RGB')).astype(np.float32)
874
+ mask_array = np.array(mask.convert('L')).astype(np.float32) / 255.0
875
+
876
+ # Apply feathering to mask
877
+ if feather_radius > 0:
878
+ scaled_radius = int(feather_radius * self.INPAINT_FEATHER_SCALE)
879
+ kernel_size = scaled_radius * 2 + 1
880
+ mask_array = cv2.GaussianBlur(
881
+ mask_array,
882
+ (kernel_size, kernel_size),
883
+ scaled_radius / 2
884
+ )
885
+
886
+ # Apply adaptive color correction if enabled
887
+ if apply_color_correction:
888
+ gen_array = self._apply_inpaint_color_correction(
889
+ orig_array, gen_array, mask_array
890
+ )
891
+
892
+ # sRGB to linear conversion for accurate blending
893
+ def srgb_to_linear(img):
894
+ img_norm = img / 255.0
895
+ return np.where(
896
+ img_norm <= 0.04045,
897
+ img_norm / 12.92,
898
+ np.power((img_norm + 0.055) / 1.055, 2.4)
899
+ )
900
+
901
+ def linear_to_srgb(img):
902
+ img_clipped = np.clip(img, 0, 1)
903
+ return np.where(
904
+ img_clipped <= 0.0031308,
905
+ 12.92 * img_clipped,
906
+ 1.055 * np.power(img_clipped, 1/2.4) - 0.055
907
+ )
908
+
909
+ # Convert to linear space
910
+ orig_linear = srgb_to_linear(orig_array)
911
+ gen_linear = srgb_to_linear(gen_array)
912
+
913
+ # Alpha blending in linear space
914
+ alpha = mask_array[:, :, np.newaxis]
915
+ result_linear = gen_linear * alpha + orig_linear * (1 - alpha)
916
+
917
+ # Convert back to sRGB
918
+ result_srgb = linear_to_srgb(result_linear)
919
+ result_array = (result_srgb * 255).astype(np.uint8)
920
+
921
+ logger.debug("Inpainting blend completed in linear color space")
922
+
923
+ return Image.fromarray(result_array)
924
+
925
    def _apply_inpaint_color_correction(
        self,
        original: np.ndarray,
        generated: np.ndarray,
        mask: np.ndarray
    ) -> np.ndarray:
        """
        Apply adaptive color correction to match generated region with surroundings.

        Analyzes the boundary region and adjusts the generated content's
        luminance and color to better match the original context.

        Parameters
        ----------
        original : np.ndarray
            Original image (float32, 0-255)
        generated : np.ndarray
            Generated image (float32, 0-255)
        mask : np.ndarray
            Blend mask (float32, 0-1)

        Returns
        -------
        np.ndarray
            Color-corrected generated image
        """
        # Find boundary region: dilate the binary mask, then keep only pixels
        # that lie outside the inpainted area (mask < 0.3) — i.e. the ring of
        # original content just around the inpaint region.
        mask_binary = (mask > 0.5).astype(np.uint8)
        kernel = cv2.getStructuringElement(
            cv2.MORPH_ELLIPSE,
            (self.INPAINT_COLOR_BLEND_RADIUS * 2 + 1, self.INPAINT_COLOR_BLEND_RADIUS * 2 + 1)
        )
        dilated = cv2.dilate(mask_binary, kernel, iterations=1)
        boundary_zone = (dilated > 0) & (mask < 0.3)

        # No surrounding context to match against — nothing to correct.
        if not np.any(boundary_zone):
            return generated

        # Convert to Lab for perceptual color matching
        orig_lab = cv2.cvtColor(
            original.astype(np.uint8), cv2.COLOR_RGB2LAB
        ).astype(np.float32)
        gen_lab = cv2.cvtColor(
            generated.astype(np.uint8), cv2.COLOR_RGB2LAB
        ).astype(np.float32)

        # Calculate statistics in boundary zone (original).
        # Median (not mean) is used so outlier pixels don't skew the target.
        boundary_orig_l = orig_lab[boundary_zone, 0]
        boundary_orig_a = orig_lab[boundary_zone, 1]
        boundary_orig_b = orig_lab[boundary_zone, 2]

        orig_mean_l = np.median(boundary_orig_l)
        orig_mean_a = np.median(boundary_orig_a)
        orig_mean_b = np.median(boundary_orig_b)

        # Calculate statistics in generated inpaint region
        inpaint_zone = mask > 0.5
        if not np.any(inpaint_zone):
            return generated

        gen_inpaint_l = gen_lab[inpaint_zone, 0]
        gen_inpaint_a = gen_lab[inpaint_zone, 1]
        gen_inpaint_b = gen_lab[inpaint_zone, 2]

        gen_mean_l = np.median(gen_inpaint_l)
        gen_mean_a = np.median(gen_inpaint_a)
        gen_mean_b = np.median(gen_inpaint_b)

        # Calculate correction deltas (how far the generated region's medians
        # are from the surrounding original content)
        delta_l = orig_mean_l - gen_mean_l
        delta_a = orig_mean_a - gen_mean_a
        delta_b = orig_mean_b - gen_mean_b

        # Limit correction to avoid over-adjustment; chroma channels (a, b)
        # are clamped to half the luminance budget to avoid visible color casts.
        max_correction = 15
        delta_l = np.clip(delta_l, -max_correction, max_correction)
        delta_a = np.clip(delta_a, -max_correction * 0.5, max_correction * 0.5)
        delta_b = np.clip(delta_b, -max_correction * 0.5, max_correction * 0.5)

        logger.debug(f"Color correction deltas: L={delta_l:.1f}, a={delta_a:.1f}, b={delta_b:.1f}")

        # Apply correction with spatial falloff from boundary:
        # distance transform gives each inpaint pixel its distance from the
        # mask edge, so correction is strongest near the boundary and fades
        # toward the center of the inpainted region.
        distance = cv2.distanceTransform(
            mask_binary, cv2.DIST_L2, 5
        )
        max_dist = np.max(distance)
        if max_dist > 0:
            # Correction strength falls off from boundary toward center
            correction_strength = 1.0 - np.clip(distance / (max_dist * 0.5), 0, 1)
        else:
            correction_strength = np.ones_like(distance)

        # Apply correction to Lab channels; the 0.7 / 0.5 factors further
        # soften the adjustment (full delta is never applied).
        corrected_lab = gen_lab.copy()
        corrected_lab[:, :, 0] += delta_l * correction_strength * 0.7
        corrected_lab[:, :, 1] += delta_a * correction_strength * 0.5
        corrected_lab[:, :, 2] += delta_b * correction_strength * 0.5

        # Clip to valid Lab ranges (OpenCV 8-bit Lab uses 0-255 per channel)
        corrected_lab[:, :, 0] = np.clip(corrected_lab[:, :, 0], 0, 255)
        corrected_lab[:, :, 1] = np.clip(corrected_lab[:, :, 1], 0, 255)
        corrected_lab[:, :, 2] = np.clip(corrected_lab[:, :, 2], 0, 255)

        # Convert back to RGB
        corrected_rgb = cv2.cvtColor(
            corrected_lab.astype(np.uint8), cv2.COLOR_LAB2RGB
        ).astype(np.float32)

        return corrected_rgb
1035
+
1036
+ def blend_inpainting_with_guided_filter(
1037
+ self,
1038
+ original: Image.Image,
1039
+ generated: Image.Image,
1040
+ mask: Image.Image,
1041
+ feather_radius: int = 8,
1042
+ guide_radius: int = 8,
1043
+ guide_eps: float = 0.01
1044
+ ) -> Image.Image:
1045
+ """
1046
+ Blend inpainted region using guided filter for edge-aware transitions.
1047
+
1048
+ Combines standard alpha blending with guided filtering to preserve
1049
+ edges in the original image while seamlessly integrating new content.
1050
+
1051
+ Parameters
1052
+ ----------
1053
+ original : PIL.Image
1054
+ Original image
1055
+ generated : PIL.Image
1056
+ Generated/inpainted result
1057
+ mask : PIL.Image
1058
+ Inpainting mask
1059
+ feather_radius : int
1060
+ Base feathering radius
1061
+ guide_radius : int
1062
+ Guided filter radius
1063
+ guide_eps : float
1064
+ Guided filter regularization
1065
+
1066
+ Returns
1067
+ -------
1068
+ PIL.Image
1069
+ Blended result with edge-aware transitions
1070
+ """
1071
+ logger.info("Applying guided filter inpainting blend")
1072
+
1073
+ # Ensure same size
1074
+ if generated.size != original.size:
1075
+ generated = generated.resize(original.size, Image.LANCZOS)
1076
+ if mask.size != original.size:
1077
+ mask = mask.resize(original.size, Image.LANCZOS)
1078
+
1079
+ # Convert to arrays
1080
+ orig_array = np.array(original.convert('RGB')).astype(np.float32)
1081
+ gen_array = np.array(generated.convert('RGB')).astype(np.float32)
1082
+ mask_array = np.array(mask.convert('L')).astype(np.float32) / 255.0
1083
+
1084
+ # Apply base feathering
1085
+ if feather_radius > 0:
1086
+ kernel_size = feather_radius * 2 + 1
1087
+ mask_feathered = cv2.GaussianBlur(
1088
+ mask_array,
1089
+ (kernel_size, kernel_size),
1090
+ feather_radius / 2
1091
+ )
1092
+ else:
1093
+ mask_feathered = mask_array
1094
+
1095
+ # Use original image as guide for the filter
1096
+ guide = cv2.cvtColor(orig_array.astype(np.uint8), cv2.COLOR_RGB2GRAY)
1097
+ guide = guide.astype(np.float32) / 255.0
1098
+
1099
+ # Apply guided filter to the mask
1100
+ try:
1101
+ mask_guided = cv2.ximgproc.guidedFilter(
1102
+ guide=guide,
1103
+ src=mask_feathered,
1104
+ radius=guide_radius,
1105
+ eps=guide_eps
1106
+ )
1107
+ logger.debug("Guided filter applied successfully")
1108
+ except Exception as e:
1109
+ logger.warning(f"Guided filter failed: {e}, using standard feathering")
1110
+ mask_guided = mask_feathered
1111
+
1112
+ # Alpha blending
1113
+ alpha = mask_guided[:, :, np.newaxis]
1114
+ result = gen_array * alpha + orig_array * (1 - alpha)
1115
+ result = np.clip(result, 0, 255).astype(np.uint8)
1116
+
1117
+ return Image.fromarray(result)
mask_generator.py ADDED
@@ -0,0 +1,648 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ import traceback
4
+ from PIL import Image, ImageFilter, ImageDraw
5
+ import logging
6
+ from typing import Optional, Tuple
7
+ from scipy.ndimage import binary_erosion, binary_dilation
8
+ import io
9
+ import gc
10
+ import torch
11
+ from transformers import AutoModelForImageSegmentation
12
+ from torchvision import transforms
13
+ from rembg import remove, new_session
14
+
15
+ logger = logging.getLogger(__name__)
16
+ logger.setLevel(logging.INFO)
17
+
18
class MaskGenerator:
    """
    Intelligent mask generation using deep learning models with traditional fallback.
    Priority: BiRefNet > U²-Net (rembg) > Traditional gradient-based methods
    """

    def __init__(self, max_image_size: int = 1024, device: str = "auto"):
        """
        Initialize the mask generator.

        Parameters
        ----------
        max_image_size : int
            Upper bound on image edge length.
            NOTE(review): stored but not referenced by the methods in this
            class — confirm whether callers rely on it.
        device : str
            Torch device name, or "auto" to pick cuda/mps/cpu automatically.
        """
        self.max_image_size = max_image_size
        self.device = self._setup_device(device)

        # BiRefNet model (lazy loading) — populated on first use by
        # _load_birefnet_model(), released by _unload_birefnet_model().
        self._birefnet_model = None
        self._birefnet_transform = None

        # Log initialization
        logger.info(f"🎭 MaskGenerator initialized on {self.device}")
+
35
+ def _setup_device(self, device: str) -> str:
36
+ """Setup computation device"""
37
+ if device == "auto":
38
+ if torch.cuda.is_available():
39
+ return "cuda"
40
+ elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
41
+ return "mps"
42
+ return "cpu"
43
+ return device
44
+
45
+ def _load_birefnet_model(self) -> bool:
46
+ """
47
+ Lazy load BiRefNet model for memory efficiency.
48
+ Returns True if model loaded successfully, False otherwise.
49
+ """
50
+ if self._birefnet_model is not None:
51
+ return True
52
+
53
+ try:
54
+ logger.info("📥 Loading BiRefNet model (ZhengPeng7/BiRefNet)...")
55
+
56
+ # Load model with fp16 for memory efficiency on GPU
57
+ dtype = torch.float16 if self.device == "cuda" else torch.float32
58
+
59
+ self._birefnet_model = AutoModelForImageSegmentation.from_pretrained(
60
+ "ZhengPeng7/BiRefNet",
61
+ trust_remote_code=True,
62
+ torch_dtype=dtype
63
+ )
64
+ self._birefnet_model.to(self.device)
65
+ self._birefnet_model.eval()
66
+
67
+ # Define preprocessing transform
68
+ self._birefnet_transform = transforms.Compose([
69
+ transforms.Resize((1024, 1024)),
70
+ transforms.ToTensor(),
71
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
72
+ ])
73
+
74
+ logger.info("✅ BiRefNet model loaded successfully")
75
+ return True
76
+
77
+ except Exception as e:
78
+ logger.error(f"❌ Failed to load BiRefNet: {e}")
79
+ self._birefnet_model = None
80
+ self._birefnet_transform = None
81
+ return False
82
+
83
+ def _unload_birefnet_model(self):
84
+ """Unload BiRefNet model to free memory"""
85
+ if self._birefnet_model is not None:
86
+ del self._birefnet_model
87
+ self._birefnet_model = None
88
+ self._birefnet_transform = None
89
+
90
+ if torch.cuda.is_available():
91
+ torch.cuda.empty_cache()
92
+ gc.collect()
93
+ logger.info("🧹 BiRefNet model unloaded")
94
+
95
    def apply_guided_filter(
        self,
        mask: np.ndarray,
        guide_image: Image.Image,
        radius: int = 8,
        eps: float = 0.01
    ) -> np.ndarray:
        """
        Apply guided filter to mask for edge-preserving smoothing.
        Falls back to returning the input mask unchanged if the guided
        filter is unavailable or fails (cv2.ximgproc requires the
        opencv-contrib build).

        Args:
            mask: Input mask as numpy array (0-255)
            guide_image: Original image to use as guide
            radius: Filter radius (larger = more smoothing)
            eps: Regularization parameter (smaller = more edge-preserving)

        Returns:
            Filtered mask as numpy array (0-255), or the original mask on failure
        """
        try:
            # Convert guide image to grayscale; both guide and mask are
            # normalized to [0, 1] floats as the filter expects.
            guide_gray = np.array(guide_image.convert('L')).astype(np.float32) / 255.0
            mask_float = mask.astype(np.float32) / 255.0

            logger.info(f"🔧 Applying guided filter (radius={radius}, eps={eps})")

            # Apply guided filter
            filtered = cv2.ximgproc.guidedFilter(
                guide=guide_gray,
                src=mask_float,
                radius=radius,
                eps=eps
            )

            # Convert back to 0-255 range
            result = (np.clip(filtered, 0, 1) * 255).astype(np.uint8)
            logger.info("✅ Guided filter applied successfully")
            return result

        except Exception as e:
            # Best-effort: smoothing is an enhancement, not a requirement.
            logger.error(f"❌ Guided filter failed: {e}, using original mask")
            return mask
+
139
    def try_birefnet_mask(self, original_image: Image.Image) -> Optional[Image.Image]:
        """
        Generate foreground mask using BiRefNet model.
        BiRefNet provides high-quality segmentation with clean edges.

        Args:
            original_image: Input PIL Image

        Returns:
            PIL Image (L mode) mask or None if failed (caller falls back
            to rembg / traditional methods)
        """
        try:
            # Lazy load model; None return triggers the fallback chain.
            if not self._load_birefnet_model():
                return None

            logger.info("🤖 Starting BiRefNet foreground extraction...")
            original_size = original_image.size

            # Convert to RGB if needed
            if original_image.mode != 'RGB':
                image_rgb = original_image.convert('RGB')
            else:
                image_rgb = original_image

            # Preprocess image (resize to 1024x1024 + ImageNet normalization)
            input_tensor = self._birefnet_transform(image_rgb).unsqueeze(0)

            # Move to device; dtype must match the fp16 weights on CUDA.
            if self.device == "cuda":
                input_tensor = input_tensor.to(self.device, dtype=torch.float16)
            else:
                input_tensor = input_tensor.to(self.device)

            # Run inference
            with torch.no_grad():
                outputs = self._birefnet_model(input_tensor)

            # BiRefNet outputs a list, get the final prediction
            if isinstance(outputs, (list, tuple)):
                pred = outputs[-1]
            else:
                pred = outputs

            # Sigmoid to get probability map
            pred = torch.sigmoid(pred)

            # Convert to numpy
            pred_np = pred.squeeze().cpu().numpy()

            # Convert to 0-255 range
            mask_array = (pred_np * 255).astype(np.uint8)

            # Resize back to original size
            mask_pil = Image.fromarray(mask_array, mode='L')
            mask_pil = mask_pil.resize(original_size, Image.LANCZOS)
            mask_array = np.array(mask_pil)

            # Quality check: reject near-empty or tiny-coverage masks so the
            # fallback extractor gets a chance instead.
            mean_val = mask_array.mean()
            nonzero_ratio = np.count_nonzero(mask_array > 50) / mask_array.size

            logger.info(f"📊 BiRefNet mask stats - Mean: {mean_val:.1f}, Coverage: {nonzero_ratio:.1%}")

            if mean_val < 10:
                logger.warning("⚠️ BiRefNet mask too weak, falling back")
                return None

            if nonzero_ratio < 0.03:
                logger.warning("⚠️ BiRefNet foreground coverage too low, falling back")
                return None

            # Light post-processing for edge refinement:
            # morphological close fills pinholes without fattening edges.
            kernel_small = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
            mask_array = cv2.morphologyEx(mask_array, cv2.MORPH_CLOSE, kernel_small)

            logger.info("✅ BiRefNet mask generation successful!")
            return Image.fromarray(mask_array, mode='L')

        except torch.cuda.OutOfMemoryError:
            # Free the model so subsequent operations have VRAM to work with.
            logger.error("❌ BiRefNet: GPU memory exhausted")
            self._unload_birefnet_model()
            return None

        except Exception as e:
            logger.error(f"❌ BiRefNet mask generation failed: {e}")
            logger.error(f"📍 Traceback: {traceback.format_exc()}")
            return None
+
229
    def try_deep_learning_mask(self, original_image: Image.Image) -> Optional[Image.Image]:
        """
        Intelligent foreground extraction with model priority:
        1. BiRefNet (best quality, clean edges)
        2. U²-Net via rembg (good fallback)
        3. Return None to trigger traditional methods

        Args:
            original_image: Input PIL Image

        Returns:
            PIL Image (L mode) mask or None if all methods failed
        """
        # Priority 1: Try BiRefNet first
        logger.info("🤖 Attempting BiRefNet mask generation...")
        birefnet_mask = self.try_birefnet_mask(original_image)
        if birefnet_mask is not None:
            logger.info("✅ Using BiRefNet generated mask")
            return birefnet_mask

        # Priority 2: Fallback to rembg (U²-Net)
        logger.info("🔄 BiRefNet unavailable/failed, trying rembg...")
        try:
            logger.info("🤖 Starting rembg foreground extraction")

            # Try u2net first (better for cartoons/objects like Snoopy)
            try:
                session = new_session('u2net')
                logger.info("✅ Using u2net model")
            except Exception as e:
                logger.warning(f"u2net failed ({e}), trying u2net_human_seg")
                try:
                    session = new_session('u2net_human_seg')
                    logger.info("✅ Using u2net_human_seg model")
                except Exception as e2:
                    logger.error(f"All rembg models failed: {e2}")
                    return None

            # Convert image to bytes for rembg (its API takes encoded bytes)
            img_byte_arr = io.BytesIO()
            original_image.save(img_byte_arr, format='PNG')
            img_byte_arr = img_byte_arr.getvalue()
            logger.info(f"📷 Image size: {len(img_byte_arr)} bytes")

            # Perform background removal; the alpha channel of the RGBA
            # result is the raw foreground confidence map.
            result = remove(img_byte_arr, session=session)
            result_img = Image.open(io.BytesIO(result)).convert('RGBA')
            alpha_channel = result_img.split()[-1]
            alpha_array = np.array(alpha_channel)

            logger.info(f"📊 Raw alpha stats - Mean: {alpha_array.mean():.1f}, Min: {alpha_array.min()}, Max: {alpha_array.max()}")

            # Step 1: Light smoothing to reduce noise but preserve edges
            alpha_smoothed = cv2.GaussianBlur(alpha_array, (3, 3), 0.8)

            # Step 2: Contrast stretching to utilize full range
            alpha_stretched = cv2.normalize(alpha_smoothed, None, 0, 255, cv2.NORM_MINMAX)

            # Step 3: CRITICAL FIX - More aggressive foreground preservation
            # Instead of hard threshold, use adaptive approach

            # Find the main subject area (high confidence regions)
            high_confidence = alpha_stretched > 180
            medium_confidence = (alpha_stretched > 60) & (alpha_stretched <= 180)
            low_confidence = (alpha_stretched > 15) & (alpha_stretched <= 60)

            # Create final mask with better extremity handling
            final_alpha = np.zeros_like(alpha_stretched)

            # High confidence areas - keep at full opacity
            final_alpha[high_confidence] = 255

            # Medium confidence - boost significantly
            final_alpha[medium_confidence] = np.clip(alpha_stretched[medium_confidence] * 1.8, 200, 255)

            # Low confidence - moderate boost (catches faint extremities)
            final_alpha[low_confidence] = np.clip(alpha_stretched[low_confidence] * 2.5, 120, 199)

            # Morphological operations to connect disconnected parts (hands, feet, tail)
            kernel_small = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
            # NOTE(review): kernel_medium is created but never used below.
            kernel_medium = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))

            # Close small gaps (helps connect separated body parts)
            final_alpha = cv2.morphologyEx(final_alpha, cv2.MORPH_CLOSE, kernel_small, iterations=1)

            # Light dilation to ensure nothing gets cut off
            final_alpha = cv2.dilate(final_alpha, kernel_small, iterations=1)

            logger.info(f"📊 Final alpha stats - Mean: {final_alpha.mean():.1f}, Min: {final_alpha.min()}, Max: {final_alpha.max()}")

            # Quality check - but be more lenient for cartoon characters
            if final_alpha.mean() < 10:
                logger.warning("⚠️ Alpha still too weak, falling back to traditional method")
                return None

            # Enhanced post-processing for cartoon characters
            is_cartoon = self._detect_cartoon_character(original_image, final_alpha)

            if is_cartoon:
                logger.info("🎭 Detected cartoon/character image, applying specialized processing")
                final_alpha = self._enhance_cartoon_mask(original_image, final_alpha)

            # Count non-zero pixels to ensure we have substantial foreground
            foreground_pixels = np.count_nonzero(final_alpha > 50)
            total_pixels = final_alpha.size
            foreground_ratio = foreground_pixels / total_pixels
            logger.info(f"📊 Foreground coverage: {foreground_ratio:.1%} of image")

            if foreground_ratio < 0.05:  # Less than 5% is probably too little
                logger.warning("⚠️ Very low foreground coverage, falling back to traditional method")
                return None

            mask = Image.fromarray(final_alpha.astype(np.uint8), mode='L')
            logger.info("✅ Enhanced rembg mask generation successful!")
            return mask

        except Exception as e:
            logger.error(f"❌ Deep learning mask extraction failed: {e}")
            return None
+
349
+ def _detect_cartoon_character(self, original_image: Image.Image, alpha_mask: np.ndarray) -> bool:
350
+ """
351
+ Detect if image is cartoon/line art (heuristic approach)
352
+ """
353
+ try:
354
+ img_array = np.array(original_image.convert('RGB'))
355
+ gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
356
+
357
+ # Calculate edge density (cartoons usually have more clear edges)
358
+ edges = cv2.Canny(gray, 50, 150)
359
+ edge_density = np.count_nonzero(edges) / max(edges.size, 1) # Avoid division by zero
360
+
361
+ # Calculate color complexity (cartoons usually have fewer colors) - optimize memory usage
362
+ h, w, c = img_array.shape
363
+ if h * w > 100000: # If image is too large, resize for processing
364
+ small_img = cv2.resize(img_array, (200, 200))
365
+ else:
366
+ small_img = img_array
367
+
368
+ unique_colors = len(np.unique(small_img.reshape(-1, 3), axis=0))
369
+ total_pixels = small_img.shape[0] * small_img.shape[1]
370
+ color_simplicity = unique_colors < (total_pixels * 0.1)
371
+
372
+ # Check for obvious black outlines
373
+ dark_pixels_ratio = np.count_nonzero(gray < 50) / max(gray.size, 1) # Avoid division by zero
374
+ has_black_outline = dark_pixels_ratio > 0.05
375
+
376
+ # Comprehensive judgment: high edge density + color simplicity + black outline = likely cartoon
377
+ is_cartoon = (edge_density > 0.05) and (color_simplicity or has_black_outline)
378
+
379
+ logger.info(f"🔍 Cartoon detection - Edge density: {edge_density:.3f}, Color simplicity: {color_simplicity}, Black outline: {has_black_outline} -> Cartoon: {is_cartoon}")
380
+ return is_cartoon
381
+
382
+ except Exception as e:
383
+ logger.error(f"❌ Cartoon detection failed: {e}")
384
+ logger.error(f"📍 Traceback: {traceback.format_exc()}")
385
+ print(f"❌ CARTOON DETECTION ERROR: {e}")
386
+ print(f"Traceback: {traceback.format_exc()}")
387
+ return False
388
+
389
+ def _enhance_cartoon_mask(self, original_image: Image.Image, alpha_mask: np.ndarray) -> np.ndarray:
390
+ """
391
+ Enhanced mask processing for cartoon characters
392
+ """
393
+ try:
394
+ img_array = np.array(original_image.convert('RGB'))
395
+ gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
396
+ enhanced_alpha = alpha_mask.copy()
397
+
398
+ # Step 1: Black outline enhancement - find black outlines and enhance their alpha
399
+ th_dark = 80 # Adjustable parameter: black threshold
400
+ black_outline = gray < th_dark
401
+
402
+ # Dilate black outline region by 1px
403
+ kernel_dilate = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3)) # Adjustable parameter: dilation kernel size
404
+ black_outline_dilated = cv2.dilate(black_outline.astype(np.uint8), kernel_dilate, iterations=1)
405
+
406
+ # Set black outline region alpha directly to 255
407
+ enhanced_alpha[black_outline_dilated > 0] = 255
408
+ logger.info(f"🖤 Black outline enhanced: {np.count_nonzero(black_outline_dilated)} pixels")
409
+
410
+ # Step 2: Simplified internal enhancement - process white fill areas within outlines
411
+ # Find high confidence regions (alpha ≥ 160)
412
+ high_confidence = enhanced_alpha >= 160
413
+
414
+ # Apply close operation on high confidence regions to connect separated parts
415
+ kernel_close = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5)) # Adjustable parameter: close kernel size
416
+ high_confidence_closed = cv2.morphologyEx(high_confidence.astype(np.uint8), cv2.MORPH_CLOSE, kernel_close, iterations=1)
417
+
418
+ # Simplified approach: directly enhance medium confidence regions without complex flood fill
419
+ # Find medium/low confidence regions surrounded by high confidence regions
420
+ medium_confidence = (enhanced_alpha >= 80) & (enhanced_alpha < 160)
421
+
422
+ # Dilate high confidence region to include more internal areas
423
+ kernel_dilate_internal = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
424
+ high_confidence_expanded = cv2.dilate(high_confidence_closed, kernel_dilate_internal, iterations=1)
425
+
426
+ # Medium confidence pixels within expanded high confidence areas are considered internal fill
427
+ internal_fill_regions = medium_confidence & (high_confidence_expanded > 0)
428
+
429
+ # Enhance alpha of these internal fill regions to at least 220
430
+ min_alpha_for_fill = 220 # Adjustable parameter: minimum alpha for internal fill
431
+ enhanced_alpha[internal_fill_regions] = np.maximum(enhanced_alpha[internal_fill_regions], min_alpha_for_fill)
432
+
433
+ logger.info(f"🤍 Internal fill regions enhanced: {np.count_nonzero(internal_fill_regions)} pixels")
434
+ logger.info(f"📊 Enhanced alpha stats - Mean: {enhanced_alpha.mean():.1f}, Min: {enhanced_alpha.min()}, Max: {enhanced_alpha.max()}")
435
+
436
+ return enhanced_alpha
437
+
438
+ except Exception as e:
439
+ logger.error(f"❌ Cartoon mask enhancement failed: {e}")
440
+ logger.error(f"📍 Traceback: {traceback.format_exc()}")
441
+ print(f"❌ CARTOON MASK ENHANCEMENT ERROR: {e}")
442
+ print(f"Traceback: {traceback.format_exc()}")
443
+ return alpha_mask
444
+
445
    def _adjust_mask_for_scene_focus(self, mask: Image.Image, original_image: Image.Image) -> Image.Image:
        """
        Adjust mask for scene focus mode to include nearby objects like chairs, furniture.

        Expands the subject mask outward, detects edges in the newly covered
        ring, and merges any edge-defined objects there into the mask.

        Args:
            mask: Foreground mask (L mode)
            original_image: Source image used for edge detection

        Returns:
            Expanded mask; the input mask is returned unchanged on error.
        """
        try:
            logger.info("🏠 Adjusting mask for scene focus mode...")

            mask_array = np.array(mask)
            img_array = np.array(original_image.convert('RGB'))

            # Expand mask to include nearby objects
            # Use larger dilation kernel to include furniture/objects
            kernel_large = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15))
            expanded_mask = cv2.dilate(mask_array, kernel_large, iterations=2)

            # Find contours in the expanded area to detect objects
            gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
            edges = cv2.Canny(gray, 30, 100)

            # Apply edge detection only in the expanded ring (newly covered
            # area that was not part of the original subject mask).
            expanded_region = (expanded_mask > 0) & (mask_array == 0)
            object_edges = np.zeros_like(edges)
            object_edges[expanded_region] = edges[expanded_region]

            # Close gaps to form complete objects
            kernel_close = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
            object_mask = cv2.morphologyEx(object_edges, cv2.MORPH_CLOSE, kernel_close)
            object_mask = cv2.dilate(object_mask, kernel_close, iterations=1)

            # Combine with original mask (per-pixel max keeps both regions)
            final_mask = np.maximum(mask_array, object_mask)

            logger.info("✅ Scene focus adjustment completed")
            return Image.fromarray(final_mask)

        except Exception as e:
            # Best-effort: on failure the unmodified subject mask is still usable.
            logger.error(f"❌ Scene focus adjustment failed: {e}")
            return mask
+
484
    def create_gradient_based_mask(self, original_image: Image.Image, mode: str = "center", focus_mode: str = "person") -> Image.Image:
        """
        Intelligent foreground extraction: prioritize deep learning models, fallback to traditional methods.

        Args:
            original_image: Input PIL Image
            mode: Mask layout — "center" (subject extraction with DL +
                gradient fallback), "left_half" / "right_half" (hard split
                with a feathered transition), or "full" (small center ellipse).
            focus_mode: 'person' for tight crop around person, 'scene' for
                including nearby objects.

        Returns:
            PIL Image (L mode) mask.
        """
        width, height = original_image.size
        logger.info(f"🎯 Creating mask for {width}x{height} image, mode: {mode}, focus: {focus_mode}")

        if mode == "center":
            # Try using deep learning models for intelligent foreground extraction
            logger.info("🤖 Attempting deep learning mask generation...")
            dl_mask = self.try_deep_learning_mask(original_image)
            if dl_mask is not None:
                logger.info("✅ Using deep learning generated mask")
                # Apply focus mode adjustments to deep learning mask
                if focus_mode == "scene":
                    dl_mask = self._adjust_mask_for_scene_focus(dl_mask, original_image)
                return dl_mask

            # Fallback to traditional method
            logger.info("🔄 Deep learning failed, using traditional gradient-based method")
            img_array = np.array(original_image.convert('RGB'))
            gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)

            # First-order derivatives: use Sobel operator for edge detection
            grad_x = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
            grad_y = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
            gradient_magnitude = np.sqrt(grad_x**2 + grad_y**2)

            # Second-order derivatives: use Laplacian operator for texture change detection
            laplacian = cv2.Laplacian(gray, cv2.CV_64F, ksize=3)
            laplacian_abs = np.abs(laplacian)

            # Combine first and second order derivatives
            combined_edges = gradient_magnitude * 0.7 + laplacian_abs * 0.3
            # NOTE(review): np.max(combined_edges) is 0 for a perfectly flat
            # image, which would divide by zero here — confirm inputs can
            # never be uniform, or guard this.
            combined_edges = (combined_edges / np.max(combined_edges)) * 255

            # Threshold processing to find strong edges
            _, edge_binary = cv2.threshold(combined_edges.astype(np.uint8), 20, 255, cv2.THRESH_BINARY)

            # Morphological operations to connect edges
            kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
            edge_binary = cv2.morphologyEx(edge_binary, cv2.MORPH_CLOSE, kernel)

            # Find contours and create mask
            contours, _ = cv2.findContours(edge_binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

            if contours:
                # Find largest contour (main subject)
                largest_contour = max(contours, key=cv2.contourArea)
                contour_mask = np.zeros((height, width), dtype=np.uint8)
                cv2.fillPoly(contour_mask, [largest_contour], 255)

                # Create foreground enhancement mask: specially protect dark regions
                dark_mask = (gray < 90).astype(np.uint8) * 255
                morph_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
                dark_mask = cv2.morphologyEx(dark_mask, cv2.MORPH_CLOSE, morph_kernel, iterations=1)
                dark_mask = cv2.dilate(dark_mask, morph_kernel, iterations=2)
                contour_mask = cv2.bitwise_or(contour_mask, dark_mask)

                # Get core foreground: clean holes and fill gaps
                close_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
                core_mask = cv2.morphologyEx(contour_mask, cv2.MORPH_CLOSE, close_kernel, iterations=1)

                open_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
                core_mask = cv2.morphologyEx(core_mask, cv2.MORPH_OPEN, open_kernel, iterations=1)

                # Convert to binary core (0/255)
                _, core_binary = cv2.threshold(core_mask, 127, 255, cv2.THRESH_BINARY)

                # Keep only slight dilation to avoid foreground being eaten
                dilate_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
                core_binary = cv2.dilate(core_binary, dilate_kernel, iterations=1)

                # Distance transform feathering: shrink feathering range for sharp edges
                FEATHER_PX = 4

                # Calculate distance transform (distance of each background
                # pixel from the core foreground)
                core_float = core_binary.astype(np.float32) / 255.0
                distances = cv2.distanceTransform((1 - core_float).astype(np.uint8), cv2.DIST_L2, 5)

                # Create feathering mask: 0→FEATHER_PX linear mapping to 1→0
                feather_mask = np.ones_like(distances)
                edge_region = (distances > 0) & (distances <= FEATHER_PX)
                feather_mask[edge_region] = 1.0 - (distances[edge_region] / FEATHER_PX)
                feather_mask[distances > FEATHER_PX] = 0.0

                # Apply double-smoothstep curve: make transition steeper, reduce semi-transparent halos
                def double_smoothstep(t):
                    t = np.clip(t, 0, 1)
                    s1 = t * t * (3 - 2 * t)
                    return s1 * s1 * (3 - 2 * s1)  # Equivalent to t^3 (10 - 15t + 6t^2)

                # Combine core with feathering: core area keeps 255, edges use double_smoothstep feathering
                final_alpha = np.zeros_like(distances)
                final_alpha[core_binary > 127] = 1.0  # Core area
                final_alpha[edge_region] = double_smoothstep(feather_mask[edge_region])  # Feathering area

                # Convert to 0-255 range
                final_mask = (final_alpha * 255).astype(np.uint8)

                # Apply guided filter for edge-preserving smoothing
                final_mask = self.apply_guided_filter(final_mask, original_image, radius=8, eps=0.01)

                mask = Image.fromarray(final_mask)
            else:
                # Backup plan: use large ellipse
                mask = Image.new('L', (width, height), 0)
                draw = ImageDraw.Draw(mask)
                center_x, center_y = width // 2, height // 2
                width_radius = int(width * 0.45)
                # NOTE(review): height_radius is derived from width, not
                # height — looks like a typo (width * 0.48 vs height * 0.48);
                # confirm whether the near-circular ellipse is intentional.
                height_radius = int(width * 0.48)
                draw.ellipse([
                    center_x - width_radius, center_y - height_radius,
                    center_x + width_radius, center_y + height_radius
                ], fill=255)
                # Apply guided filter instead of Gaussian blur
                mask_array = np.array(mask)
                mask_array = self.apply_guided_filter(mask_array, original_image, radius=10, eps=0.02)
                mask = Image.fromarray(mask_array)

        elif mode == "left_half":
            # Keep original logic unchanged - ensure Snoopy and other functions work normally
            mask = Image.new('L', (width, height), 0)
            mask_array = np.array(mask)
            mask_array[:, :width//2] = 255

            # Linear alpha ramp over a 10%-of-width transition zone
            transition_zone = width // 10
            for i in range(transition_zone):
                x_pos = width//2 + i
                if x_pos < width:
                    alpha = 255 * (1 - i / transition_zone)
                    mask_array[:, x_pos] = int(alpha)

            mask = Image.fromarray(mask_array)

        elif mode == "right_half":
            # Keep original logic unchanged - ensure Snoopy and other functions work normally
            mask = Image.new('L', (width, height), 0)
            mask_array = np.array(mask)
            mask_array[:, width//2:] = 255

            # Linear alpha ramp over a 10%-of-width transition zone
            transition_zone = width // 10
            for i in range(transition_zone):
                x_pos = width//2 - i - 1
                if x_pos >= 0:
                    alpha = 255 * (1 - i / transition_zone)
                    mask_array[:, x_pos] = int(alpha)

            mask = Image.fromarray(mask_array)

        elif mode == "full":
            # Small blurred ellipse at the image center
            mask = Image.new('L', (width, height), 0)
            draw = ImageDraw.Draw(mask)
            center_x, center_y = width // 2, height // 2
            radius = min(width, height) // 8

            draw.ellipse([
                center_x - radius, center_y - radius,
                center_x + radius, center_y + radius
            ], fill=255)

            mask = mask.filter(ImageFilter.GaussianBlur(radius=5))

        # NOTE(review): an unrecognized mode raises UnboundLocalError on
        # `mask` — confirm callers only pass the four supported modes.
        return mask
requirements.txt CHANGED
@@ -1,9 +1,8 @@
1
- #(Apache 2.0 授權,包含 Wan2.2 LoRA 修復)
2
  git+https://github.com/linoytsaban/diffusers.git@wan22-loras
3
-
4
  gradio
5
- transformers
6
- accelerate
7
  safetensors
8
  sentencepiece
9
  peft
@@ -12,4 +11,15 @@ imageio-ffmpeg
12
  opencv-python
13
  pillow
14
  spaces
15
- torchao
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # VividFlow I2V Dependencies
2
  git+https://github.com/linoytsaban/diffusers.git@wan22-loras
 
3
  gradio
4
+ transformers>=4.46.0
5
+ accelerate>=1.1.1
6
  safetensors
7
  sentencepiece
8
  peft
 
11
  opencv-python
12
  pillow
13
  spaces
14
+ torchao
15
+
16
+ # Background Generation Dependencies (SceneWeaver)
17
+ open_clip_torch
18
+ sentence-transformers
19
+ rembg[gpu]
20
+ scipy
21
+ opencv-contrib-python
22
+
23
+ # Core Dependencies
24
+ torch>=2.5.0
25
+ numpy
scene_templates.py ADDED
@@ -0,0 +1,428 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import Dict, List, Optional
3
+ from dataclasses import dataclass
4
+
5
+ logger = logging.getLogger(__name__)
6
+
7
@dataclass
class SceneTemplate:
    """Data class representing a scene template"""
    key: str                      # stable registry identifier, e.g. "office_modern"
    name: str                     # human-readable display name shown in the UI
    prompt: str                   # positive text-to-image prompt describing the scene
    negative_extra: str           # template-specific terms appended to the base negative prompt
    category: str                 # gallery grouping, e.g. "Professional", "Nature"
    icon: str                     # emoji rendered next to the name in dropdowns/cards
    guidance_scale: float = 7.5   # recommended guidance for this scene (7.5 = balanced default)
17
+
18
+
19
+ class SceneTemplateManager:
20
+ """
21
+ Manages curated scene templates for background generation.
22
+ Provides categorized presets that users can select with one click.
23
+ """
24
+
25
+ # Scene template definitions
26
+ TEMPLATES: Dict[str, SceneTemplate] = {
27
+ # Professional Category
28
+ "office_modern": SceneTemplate(
29
+ key="office_modern",
30
+ name="Modern Office",
31
+ prompt="modern minimalist office interior, clean white desk, large floor-to-ceiling windows, natural daylight, professional corporate environment, soft shadows, contemporary furniture",
32
+ negative_extra="messy, cluttered, dark, old",
33
+ category="Professional",
34
+ icon="🏢",
35
+ guidance_scale=7.5
36
+ ),
37
+ "office_executive": SceneTemplate(
38
+ key="office_executive",
39
+ name="Executive Suite",
40
+ prompt="luxurious executive office, mahogany desk, leather chair, city skyline view through windows, warm ambient lighting, bookshelf, elegant professional setting",
41
+ negative_extra="cheap, cramped, messy",
42
+ category="Professional",
43
+ icon="👔",
44
+ guidance_scale=7.5
45
+ ),
46
+ "studio_white": SceneTemplate(
47
+ key="studio_white",
48
+ name="White Studio",
49
+ prompt="clean white photography studio background, professional lighting setup, seamless white backdrop, soft diffused light, minimal shadows",
50
+ negative_extra="colored, textured, dirty",
51
+ category="Professional",
52
+ icon="📷",
53
+ guidance_scale=8.0
54
+ ),
55
+ "coworking": SceneTemplate(
56
+ key="coworking",
57
+ name="Coworking Space",
58
+ prompt="modern coworking space, open plan office, plants, exposed brick, industrial chic design, natural light, collaborative environment",
59
+ negative_extra="empty, dark, boring",
60
+ category="Professional",
61
+ icon="💼",
62
+ guidance_scale=7.0
63
+ ),
64
+ "conference": SceneTemplate(
65
+ key="conference",
66
+ name="Conference Room",
67
+ prompt="modern conference room, large meeting table, glass walls, professional presentation screen, bright corporate lighting, clean minimal design",
68
+ negative_extra="small, cramped, outdated",
69
+ category="Professional",
70
+ icon="🤝",
71
+ guidance_scale=7.5
72
+ ),
73
+
74
+ # Nature Category
75
+ "beach_sunset": SceneTemplate(
76
+ key="beach_sunset",
77
+ name="Sunset Beach",
78
+ prompt="beautiful tropical beach at golden hour sunset, palm trees silhouette, calm turquoise ocean waves, warm orange and pink sky, soft sand, paradise vacation vibes",
79
+ negative_extra="storm, rain, crowded, trash",
80
+ category="Nature",
81
+ icon="🏖️",
82
+ guidance_scale=7.0
83
+ ),
84
+ "forest_enchanted": SceneTemplate(
85
+ key="forest_enchanted",
86
+ name="Enchanted Forest",
87
+ prompt="magical enchanted forest, sunlight streaming through tall trees, lush green foliage, mystical atmosphere, morning mist, fairy tale woodland",
88
+ negative_extra="dead trees, dark, scary, barren",
89
+ category="Nature",
90
+ icon="🌲",
91
+ guidance_scale=7.0
92
+ ),
93
+ "mountain_scenic": SceneTemplate(
94
+ key="mountain_scenic",
95
+ name="Mountain Vista",
96
+ prompt="breathtaking mountain landscape, snow-capped peaks, alpine meadow, clear blue sky, majestic scenic view, pristine nature, peaceful atmosphere",
97
+ negative_extra="industrial, polluted, crowded",
98
+ category="Nature",
99
+ icon="🏔️",
100
+ guidance_scale=7.5
101
+ ),
102
+ "garden_spring": SceneTemplate(
103
+ key="garden_spring",
104
+ name="Spring Garden",
105
+ prompt="beautiful spring flower garden, colorful blooming flowers, roses and tulips, manicured hedges, sunny day, botanical paradise, fresh and vibrant",
106
+ negative_extra="dead, winter, wilted, dry",
107
+ category="Nature",
108
+ icon="🌸",
109
+ guidance_scale=7.0
110
+ ),
111
+ "lake_serene": SceneTemplate(
112
+ key="lake_serene",
113
+ name="Serene Lake",
114
+ prompt="peaceful serene lake at dawn, mirror-like water reflection, surrounding mountains, soft morning light, tranquil atmosphere, pristine natural beauty",
115
+ negative_extra="stormy, polluted, industrial",
116
+ category="Nature",
117
+ icon="🏞️",
118
+ guidance_scale=7.0
119
+ ),
120
+ "cherry_blossom": SceneTemplate(
121
+ key="cherry_blossom",
122
+ name="Cherry Blossom",
123
+ prompt="stunning cherry blossom trees in full bloom, pink sakura petals falling gently, Japanese garden aesthetic, soft spring sunlight, romantic atmosphere",
124
+ negative_extra="winter, dead, brown, wilted",
125
+ category="Nature",
126
+ icon="🌸",
127
+ guidance_scale=7.0
128
+ ),
129
+
130
+ # Urban Category
131
+ "city_skyline": SceneTemplate(
132
+ key="city_skyline",
133
+ name="City Skyline",
134
+ prompt="modern city skyline at blue hour, impressive skyscrapers, glass buildings reflecting sunset, urban metropolitan view, cinematic atmosphere",
135
+ negative_extra="slums, dirty, abandoned, ruins",
136
+ category="Urban",
137
+ icon="🌆",
138
+ guidance_scale=7.5
139
+ ),
140
+ "cafe_cozy": SceneTemplate(
141
+ key="cafe_cozy",
142
+ name="Cozy Cafe",
143
+ prompt="warm cozy coffee shop interior, wooden furniture, ambient lighting, exposed brick walls, plants, comfortable atmosphere, artisan cafe vibes",
144
+ negative_extra="fast food, plastic, harsh lighting",
145
+ category="Urban",
146
+ icon="☕",
147
+ guidance_scale=7.0
148
+ ),
149
+ "street_european": SceneTemplate(
150
+ key="street_european",
151
+ name="European Street",
152
+ prompt="charming European cobblestone street, historic buildings, outdoor cafe, flowers on balconies, warm afternoon light, romantic Paris or Rome vibes",
153
+ negative_extra="modern, industrial, ugly, dirty",
154
+ category="Urban",
155
+ icon="🏛️",
156
+ guidance_scale=7.0
157
+ ),
158
+ "night_neon": SceneTemplate(
159
+ key="night_neon",
160
+ name="Neon Nightlife",
161
+ prompt="vibrant city nightlife scene, neon lights and signs, urban night atmosphere, colorful reflections on wet street, cyberpunk aesthetic, electric energy",
162
+ negative_extra="daytime, boring, plain",
163
+ category="Urban",
164
+ icon="🌃",
165
+ guidance_scale=6.5
166
+ ),
167
+ "rooftop_view": SceneTemplate(
168
+ key="rooftop_view",
169
+ name="Rooftop Terrace",
170
+ prompt="luxury rooftop terrace, city panoramic view, modern outdoor furniture, string lights, sunset golden hour, sophisticated urban oasis",
171
+ negative_extra="cheap, dirty, crowded",
172
+ category="Urban",
173
+ icon="🏙️",
174
+ guidance_scale=7.5
175
+ ),
176
+
177
+ # Artistic Category
178
+ "gradient_soft": SceneTemplate(
179
+ key="gradient_soft",
180
+ name="Soft Gradient",
181
+ prompt="smooth soft gradient background, pastel colors blending beautifully, pink to blue to purple transition, dreamy aesthetic, professional portrait backdrop",
182
+ negative_extra="harsh, noisy, textured, busy",
183
+ category="Artistic",
184
+ icon="🎨",
185
+ guidance_scale=8.0
186
+ ),
187
+ "abstract_modern": SceneTemplate(
188
+ key="abstract_modern",
189
+ name="Modern Abstract",
190
+ prompt="modern abstract art background, geometric shapes, bold colors, contemporary design, artistic composition, museum gallery aesthetic",
191
+ negative_extra="realistic, plain, boring",
192
+ category="Artistic",
193
+ icon="🖼️",
194
+ guidance_scale=6.5
195
+ ),
196
+ "vintage_retro": SceneTemplate(
197
+ key="vintage_retro",
198
+ name="Vintage Retro",
199
+ prompt="vintage retro aesthetic background, warm sepia tones, nostalgic 70s vibes, film grain texture, classic photography style, timeless elegance",
200
+ negative_extra="modern, digital, cold, harsh",
201
+ category="Artistic",
202
+ icon="📻",
203
+ guidance_scale=7.0
204
+ ),
205
+ "watercolor_dream": SceneTemplate(
206
+ key="watercolor_dream",
207
+ name="Watercolor Dream",
208
+ prompt="beautiful watercolor painting background, soft flowing colors, artistic brush strokes, dreamy ethereal atmosphere, delicate artistic aesthetic",
209
+ negative_extra="digital, sharp, photorealistic",
210
+ category="Artistic",
211
+ icon="🖌️",
212
+ guidance_scale=6.5
213
+ ),
214
+
215
+ # Seasonal Category
216
+ "autumn_foliage": SceneTemplate(
217
+ key="autumn_foliage",
218
+ name="Autumn Foliage",
219
+ prompt="beautiful autumn scenery, vibrant fall foliage, orange red and golden leaves, maple trees, warm sunlight filtering through, cozy seasonal atmosphere",
220
+ negative_extra="spring, summer, green, snow",
221
+ category="Seasonal",
222
+ icon="🍂",
223
+ guidance_scale=7.0
224
+ ),
225
+ "winter_snow": SceneTemplate(
226
+ key="winter_snow",
227
+ name="Winter Wonderland",
228
+ prompt="magical winter wonderland, fresh white snow covering everything, snow-laden pine trees, soft snowfall, peaceful cold atmosphere, holiday season vibes",
229
+ negative_extra="summer, green, rain, mud",
230
+ category="Seasonal",
231
+ icon="❄️",
232
+ guidance_scale=7.0
233
+ ),
234
+ "summer_tropical": SceneTemplate(
235
+ key="summer_tropical",
236
+ name="Tropical Summer",
237
+ prompt="vibrant tropical summer scene, lush palm trees, bright sunny day, exotic flowers, paradise vacation destination, warm and inviting atmosphere",
238
+ negative_extra="winter, cold, snow, gray",
239
+ category="Seasonal",
240
+ icon="🌴",
241
+ guidance_scale=7.0
242
+ ),
243
+ "spring_meadow": SceneTemplate(
244
+ key="spring_meadow",
245
+ name="Spring Meadow",
246
+ prompt="beautiful spring meadow, wildflowers blooming, fresh green grass, butterflies, soft warm sunlight, renewal and new beginnings, pastoral beauty",
247
+ negative_extra="winter, autumn, dead, dry",
248
+ category="Seasonal",
249
+ icon="🌷",
250
+ guidance_scale=7.0
251
+ ),
252
+ }
253
+
254
+ # Category display order
255
+ CATEGORIES = ["Professional", "Nature", "Urban", "Artistic", "Seasonal"]
256
+
257
    def __init__(self):
        """Initialize the scene template manager"""
        # Templates and categories live on the class itself; the constructor
        # only records how many presets are registered.
        logger.info(f"SceneTemplateManager initialized with {len(self.TEMPLATES)} templates")
260
+
261
    def get_all_templates(self) -> Dict[str, SceneTemplate]:
        """Return the full template registry, keyed by template key.

        NOTE(review): this returns the shared class-level dict, not a copy —
        callers must treat it as read-only.
        """
        return self.TEMPLATES
264
+
265
+ def get_template(self, key: str) -> Optional[SceneTemplate]:
266
+ """Get a specific template by key"""
267
+ return self.TEMPLATES.get(key)
268
+
269
+ def get_templates_by_category(self, category: str) -> List[SceneTemplate]:
270
+ """Get all templates in a specific category"""
271
+ return [t for t in self.TEMPLATES.values() if t.category == category]
272
+
273
    def get_categories(self) -> List[str]:
        """Get list of all categories in display order.

        NOTE(review): returns the shared class-level list — callers must not
        mutate it.
        """
        return self.CATEGORIES
276
+
277
+ def get_template_choices_sorted(self) -> List[str]:
278
+ """
279
+ Get template choices formatted for Gradio dropdown.
280
+ Returns list of display strings sorted A-Z: "🏢 Modern Office"
281
+ """
282
+ display_list = []
283
+ for key, template in self.TEMPLATES.items():
284
+ display_name = f"{template.icon} {template.name}"
285
+ display_list.append(display_name)
286
+
287
+ # Sort alphabetically by name (ignoring emoji)
288
+ display_list.sort(key=lambda x: x.split(' ', 1)[1] if ' ' in x else x)
289
+ return display_list
290
+
291
+ def get_template_key_from_display(self, display_name: str) -> Optional[str]:
292
+ """
293
+ Get template key from display name.
294
+ Example: "🏢 Modern Office" -> "office_modern"
295
+ """
296
+ if not display_name:
297
+ return None
298
+
299
+ for key, template in self.TEMPLATES.items():
300
+ if f"{template.icon} {template.name}" == display_name:
301
+ return key
302
+ return None
303
+
304
+ def get_prompt_for_template(self, key: str) -> Optional[str]:
305
+ """Get the prompt string for a template"""
306
+ template = self.get_template(key)
307
+ return template.prompt if template else None
308
+
309
+ def get_negative_prompt_for_template(
310
+ self,
311
+ key: str,
312
+ base_negative: str = "blurry, low quality, distorted, people, characters"
313
+ ) -> str:
314
+ """Get combined negative prompt for a template"""
315
+ template = self.get_template(key)
316
+ if template and template.negative_extra:
317
+ return f"{base_negative}, {template.negative_extra}"
318
+ return base_negative
319
+
320
+ def get_guidance_scale_for_template(self, key: str) -> float:
321
+ """Get the recommended guidance scale for a template"""
322
+ template = self.get_template(key)
323
+ return template.guidance_scale if template else 7.5
324
+
325
    def build_gallery_html(self) -> str:
        """
        Build HTML for the scene template gallery.
        Returns HTML string for display in Gradio.

        Categories render in CATEGORIES order; empty categories are skipped.
        NOTE(review): each card's onclick calls a global JS `selectTemplate` —
        confirm that function is injected somewhere in the page, otherwise the
        cards do nothing when clicked.
        """
        html_parts = ['<div class="scene-gallery">']

        for category in self.CATEGORIES:
            templates = self.get_templates_by_category(category)
            if not templates:
                continue  # nothing registered under this category

            # Open the category section and its card grid.
            html_parts.append(f'''
            <div class="scene-category">
                <h4 class="scene-category-title">{category}</h4>
                <div class="scene-grid">
            ''')

            # One clickable card per template; the key is carried both in a
            # data attribute and in the onclick argument.
            for template in templates:
                html_parts.append(f'''
                <button class="scene-card" data-template="{template.key}" onclick="selectTemplate('{template.key}')">
                    <span class="scene-icon">{template.icon}</span>
                    <span class="scene-name">{template.name}</span>
                </button>
                ''')

            html_parts.append('</div></div>')

        html_parts.append('</div>')
        return ''.join(html_parts)
355
+
356
    def get_gallery_css(self) -> str:
        """Get CSS styles for the scene gallery.

        Returned verbatim for injection into the Gradio page; class names
        match those emitted by build_gallery_html.
        """
        return """
        /* Scene Gallery Styles */
        .scene-gallery {
            margin: 16px 0;
        }

        .scene-category {
            margin-bottom: 20px;
        }

        .scene-category-title {
            font-size: 0.9rem;
            font-weight: 600;
            color: #475569;
            margin-bottom: 12px;
            padding-bottom: 8px;
            border-bottom: 1px solid #e2e8f0;
        }

        .scene-grid {
            display: grid;
            grid-template-columns: repeat(auto-fill, minmax(100px, 1fr));
            gap: 8px;
        }

        .scene-card {
            display: flex;
            flex-direction: column;
            align-items: center;
            justify-content: center;
            padding: 12px 8px;
            background: #f8fafc;
            border: 1px solid #e2e8f0;
            border-radius: 8px;
            cursor: pointer;
            transition: all 0.2s ease;
            min-height: 70px;
        }

        .scene-card:hover {
            background: #dbeafe;
            border-color: #3b82f6;
            transform: translateY(-2px);
            box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1);
        }

        .scene-card.selected {
            background: #dbeafe;
            border-color: #3b82f6;
            box-shadow: 0 0 0 2px rgba(59, 130, 246, 0.3);
        }

        .scene-icon {
            font-size: 1.5rem;
            margin-bottom: 4px;
        }

        .scene-name {
            font-size: 0.75rem;
            font-weight: 500;
            color: #1e293b;
            text-align: center;
            line-height: 1.2;
        }

        @media (max-width: 768px) {
            .scene-grid {
                grid-template-columns: repeat(3, 1fr);
            }
        }
        """
ui_manager.py CHANGED
@@ -1,20 +1,35 @@
1
  import gradio as gr
2
  from PIL import Image
3
- from typing import Tuple
 
 
 
4
  from FlowFacade import FlowFacade
 
 
5
  from css_style import DELTAFLOW_CSS
6
  from prompt_examples import PROMPT_EXAMPLES
7
 
 
 
 
 
 
 
 
 
8
 
9
  class UIManager:
10
- def __init__(self, facade: FlowFacade):
11
  self.facade = facade
 
 
12
 
13
  def create_interface(self) -> gr.Blocks:
14
  with gr.Blocks(
15
  theme=gr.themes.Soft(),
16
  css=DELTAFLOW_CSS,
17
- title="VividFlow - Fast AI Image to Video"
18
  ) as interface:
19
 
20
  # Header
@@ -22,276 +37,523 @@ class UIManager:
22
  <div class="header-container">
23
  <h1 class="header-title">🌊 VividFlow</h1>
24
  <p class="header-subtitle">
25
- Bring Your Images to Life with AI Magic ✨<br>
26
- Transform any still image into dynamic, cinematic videos
27
  </p>
28
  </div>
29
  """)
30
 
31
- with gr.Row():
32
- # Left Panel: Input
33
- with gr.Column(scale=1, elem_classes="input-card"):
34
- gr.Markdown("### 📤 Input")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
- image_input = gr.Image(
37
- label="Upload Image (any type: photo, art, cartoon, etc.)",
38
- type="pil",
39
- elem_classes="image-upload",
40
- height=320
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  )
42
 
43
- resolution_info = gr.Markdown(
44
- value="",
45
- visible=False,
46
- elem_classes="info-text"
 
 
47
  )
48
 
49
- prompt_input = gr.Textbox(
50
- label="Motion Instruction",
51
- placeholder="Describe camera movements (zoom, pan, orbit) and subject actions (head turn, hair flow, expression change). Be specific and cinematic! Example: 'Camera slowly zooms in, subject's eyes sparkle, hair flows gently in wind'",
52
- lines=3,
53
- max_lines=6
 
54
  )
55
 
56
- # Quick preset selector
57
- category_dropdown = gr.Dropdown(
58
- choices=list(PROMPT_EXAMPLES.keys()),
59
- label="💡 Quick Prompt Category",
60
- value="💃 Fashion / Beauty (Facial Only)",
61
- interactive=True
62
  )
63
 
64
- example_dropdown = gr.Dropdown(
65
- choices=PROMPT_EXAMPLES["💃 Fashion / Beauty (Facial Only)"],
66
- label="Example Prompts (click to use)",
67
- value=None,
68
- interactive=True
69
  )
70
 
71
- # Quality tips banner (blue)
72
- gr.HTML("""
73
- <div class="quality-banner">
74
- <strong>💡 Choose the Right Prompt Category:</strong><br>
75
- • <strong>💃 Facial Only:</strong> Safe for headshots and portraits without visible hands<br>
76
- • <strong>🙌 Hands Visible Required:</strong> Only use if hands are fully visible in your image (prevents artifacts)<br>
77
- <strong>🌄 Scenery/Objects:</strong> For landscapes, products, and abstract content
78
- </div>
79
- """)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
- # Generate button with patience banner
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  gr.HTML("""
83
- <div class="patience-banner">
84
- <strong>⏱️ Models are Initializing!</strong><br>
85
- This first-time generation may take a moment while high-fidelity assets load into memory.<br>
86
- Grab a coffee ☕, and watch the magic happen! Subsequent runs will be significantly faster.
87
  </div>
88
  """)
89
 
90
- generate_btn = gr.Button(
91
- "🎬 Generate Video",
92
- variant="primary",
93
- elem_classes="primary-button",
94
- size="lg"
 
 
95
  )
96
 
97
- # Advanced settings
98
- with gr.Accordion("⚙️ Advanced Settings", open=False):
99
- duration_slider = gr.Slider(
100
- minimum=0.5,
101
- maximum=5.0,
102
- step=0.5,
103
- value=3.0,
104
- label="Duration (seconds)",
105
- info="3.0s = 49 frames, 5.0s = 81 frames (16fps)"
106
- )
107
-
108
- steps_slider = gr.Slider(
109
- minimum=4,
110
- maximum=12,
111
- step=1,
112
- value=4,
113
- label="Inference Steps",
114
- info="4-6 recommended • Higher steps = longer generation time"
115
- )
116
-
117
- with gr.Row():
118
- guidance_scale = gr.Slider(
119
- minimum=0.0,
120
- maximum=5.0,
121
- step=0.5,
122
- value=1.0,
123
- label="Guidance Scale (high noise)"
124
- )
125
-
126
- guidance_scale_2 = gr.Slider(
127
- minimum=0.0,
128
- maximum=5.0,
129
- step=0.5,
130
- value=1.0,
131
- label="Guidance Scale (low noise)"
132
- )
133
-
134
- with gr.Row():
135
- seed_input = gr.Number(
136
- label="Seed",
137
- value=42,
138
- precision=0,
139
- minimum=0,
140
- maximum=2147483647,
141
- info="Use same seed for reproducible results"
142
- )
143
-
144
- randomize_seed = gr.Checkbox(
145
- label="Randomize Seed",
146
- value=True,
147
- info="Generate different results each time"
148
- )
149
-
150
- enable_ai_prompt = gr.Checkbox(
151
- label="🤖 Enable AI Prompt Expansion (Qwen2.5)",
152
- value=False,
153
- info="Use AI to enhance your prompt (adds ~30s)"
154
- )
155
 
156
- # Right Panel: Output
157
- with gr.Column(scale=1, elem_classes="output-card"):
158
- gr.Markdown("### 🎥 Output")
 
 
 
 
 
159
 
160
- video_output = gr.Video(
161
- label="Generated Video",
162
- height=400,
163
- autoplay=True
 
 
 
164
  )
165
 
166
- with gr.Row():
167
- prompt_output = gr.Textbox(
168
- label="Final Prompt Used",
169
- lines=3,
170
- interactive=False,
171
- scale=3
172
- )
173
 
174
- seed_output = gr.Number(
175
- label="Seed Used",
176
- precision=0,
177
- interactive=False,
178
- scale=1
179
- )
180
 
181
- # Info section
182
- with gr.Row():
183
  gr.HTML("""
184
- <div class="info-box">
185
- <strong>ℹ️ Tips for Best Results:</strong><br>
186
- <strong>Use example prompts:</strong> Select a category above and click an example to get started<br>
187
- • <strong>Works with ANY image:</strong> Fashion portraits, anime, landscapes, products, abstract art, etc.<br>
188
- • <strong>For dramatic effects:</strong> Choose prompts with words like "explosive", "dramatic", "swirls", "transforms"<br>
189
- • <strong>Image quality matters:</strong> Higher resolution and clear subjects produce better results
190
  </div>
191
  """)
192
 
193
- # Footer
194
- gr.HTML("""
195
- <div class="footer">
196
- <p style="font-size: 0.9rem;">
197
- <strong>Powered by:</strong><br>
198
- <a href="https://huggingface.co/Wan-AI/Wan2.2-I2V-A14B-Diffusers" target="_blank" style="color: #6366f1; text-decoration: none;">Wan2.2-I2V-A14B</a> (Wan-AI, optimized by <a href="https://huggingface.co/cbensimon" target="_blank" style="color: #6366f1; text-decoration: none;">cbensimon</a>)
199
- · Lightning LoRA (<a href="https://huggingface.co/Kijai/WanVideo_comfy" target="_blank" style="color: #6366f1; text-decoration: none;">Lightx2v</a>)
200
- · <a href="https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct" target="_blank" style="color: #6366f1; text-decoration: none;">Qwen2.5-0.5B</a>
201
- </p>
202
- </div>
203
- """)
204
-
205
- def update_examples(category):
206
- return gr.update(choices=PROMPT_EXAMPLES[category], value=None)
 
 
 
 
 
 
 
 
 
 
 
 
 
207
 
208
- def fill_prompt(selected_example):
209
- return selected_example if selected_example else ""
210
-
211
- def show_resolution_info(image):
212
- if image is None:
213
- return gr.update(value="", visible=False)
214
-
215
- from PIL import Image
216
- original_w, original_h = image.size
217
- resized_image = self.facade.video_engine.resize_image(image)
218
- output_w, output_h = resized_image.width, resized_image.height
219
-
220
- info = f"**📐 Resolution:** Input: {original_w}×{original_h} → Output: {output_w}×{output_h}"
221
- return gr.update(value=info, visible=True)
222
-
223
- category_dropdown.change(fn=update_examples, inputs=[category_dropdown],
224
- outputs=[example_dropdown])
225
- example_dropdown.change(fn=fill_prompt, inputs=[example_dropdown],
226
- outputs=[prompt_input])
227
- image_input.change(fn=show_resolution_info, inputs=[image_input],
228
- outputs=[resolution_info])
229
-
230
- generate_btn.click(
231
- fn=self._handle_generation,
232
- inputs=[
233
- image_input,
234
- prompt_input,
235
- duration_slider,
236
- steps_slider,
237
- guidance_scale,
238
- guidance_scale_2,
239
- seed_input,
240
- randomize_seed,
241
- enable_ai_prompt
242
- ],
243
- outputs=[video_output, prompt_output, seed_output],
244
- show_progress=True
245
- )
246
 
247
- return interface
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
248
 
249
- def _handle_generation(self, image: Image.Image, prompt: str, duration: float,
250
- steps: int, guidance_1: float, guidance_2: float, seed: int,
251
- randomize: bool, enable_ai: bool,
252
- progress=gr.Progress()) -> Tuple[str, str, int]:
253
  try:
254
- if image is None:
255
- raise gr.Error("❌ Please upload an image")
256
- if not prompt or prompt.strip() == "":
257
- raise gr.Error("❌ Please provide a motion instruction")
258
- if not self.facade.validate_image(image):
259
- raise gr.Error("❌ Image dimensions invalid (256-4096px)")
260
 
261
- video_path, final_prompt, seed_used = self.facade.generate_video_from_image(
262
- image=image,
263
- user_instruction=prompt,
264
- duration_seconds=duration,
265
- num_inference_steps=steps,
266
- guidance_scale=guidance_1,
267
- guidance_scale_2=guidance_2,
268
- seed=int(seed),
269
- randomize_seed=randomize,
270
- enable_prompt_expansion=enable_ai,
271
- progress=progress
272
  )
273
 
274
- return video_path, final_prompt, seed_used
275
-
276
- except gr.Error:
277
- raise
 
 
 
 
 
 
 
278
 
279
  except Exception as e:
280
- import traceback
281
- import os
282
- error_msg = str(e)
283
-
284
- if os.environ.get('DEBUG'):
285
- print(f"\n✗ UI Error: {type(e).__name__}")
286
- print(traceback.format_exc())
287
-
288
- if "CUDA out of memory" in error_msg or "OutOfMemoryError" in error_msg:
289
- raise gr.Error("❌ GPU memory insufficient. Try reducing duration/steps or restart.")
290
- else:
291
- raise gr.Error(f"❌ Generation failed: {error_msg}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
292
 
293
- def launch(self, share: bool = False, server_name: str = "0.0.0.0",
294
- server_port: int = None, **kwargs) -> None:
295
- interface = self.create_interface()
296
- interface.launch(share=share, server_name=server_name,
297
- server_port=server_port, **kwargs)
 
1
  import gradio as gr
2
  from PIL import Image
3
+ from typing import Tuple, Optional, Dict, Any
4
+ import os
5
+ import logging
6
+
7
  from FlowFacade import FlowFacade
8
+ from BackgroundEngine import BackgroundEngine
9
+ from scene_templates import SceneTemplateManager
10
  from css_style import DELTAFLOW_CSS
11
  from prompt_examples import PROMPT_EXAMPLES
12
 
13
+ try:
14
+ import spaces
15
+ SPACES_AVAILABLE = True
16
+ except ImportError:
17
+ SPACES_AVAILABLE = False
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
 
22
  class UIManager:
23
    def __init__(self, facade: FlowFacade, background_engine: BackgroundEngine):
        """Wire the UI layer to its backing engines.

        Args:
            facade: Orchestrates image-to-video generation.
            background_engine: Background-replacement engine for the second tab.
        """
        self.facade = facade
        self.background_engine = background_engine
        # Scene presets consumed by the Background Generation tab.
        self.template_manager = SceneTemplateManager()
27
 
28
  def create_interface(self) -> gr.Blocks:
29
  with gr.Blocks(
30
  theme=gr.themes.Soft(),
31
  css=DELTAFLOW_CSS,
32
+ title="VividFlow - AI Image Enhancement & Video Generation"
33
  ) as interface:
34
 
35
  # Header
 
37
  <div class="header-container">
38
  <h1 class="header-title">🌊 VividFlow</h1>
39
  <p class="header-subtitle">
40
+ AI-Powered Image Enhancement & Video Generation<br>
41
+ Transform images with background replacement, then bring them to life with AI
42
  </p>
43
  </div>
44
  """)
45
 
46
+ # Main Tabs
47
+ with gr.Tabs() as main_tabs:
48
+
49
+ # Tab 1: Image to Video (Original Functionality)
50
+ with gr.Tab("🎬 Image to Video"):
51
+ self._create_i2v_tab()
52
+
53
+ # Tab 2: Background Generation (New Feature)
54
+ with gr.Tab("🎨 Background Generation"):
55
+ self._create_background_tab()
56
+
57
+ # Footer
58
+ gr.HTML("""
59
+ <div class="footer">
60
+ <p>Powered by Wan2.2-I2V-A14B, SDXL, and OpenCLIP | Built with Gradio</p>
61
+ </div>
62
+ """)
63
+
64
+ return interface
65
 
66
+ def _create_i2v_tab(self):
67
+ """Create Image to Video tab (original VividFlow functionality)"""
68
+ with gr.Row():
69
+ # Left Panel: Input
70
+ with gr.Column(scale=1, elem_classes="input-card"):
71
+ gr.Markdown("### 📤 Input")
72
+
73
+ image_input = gr.Image(
74
+ label="Upload Image (any type: photo, art, cartoon, etc.)",
75
+ type="pil",
76
+ elem_classes="image-upload",
77
+ height=320
78
+ )
79
+
80
+ resolution_info = gr.Markdown(
81
+ value="",
82
+ visible=False,
83
+ elem_classes="info-text"
84
+ )
85
+
86
+ prompt_input = gr.Textbox(
87
+ label="Motion Instruction",
88
+ placeholder="Describe camera movements and subject actions...",
89
+ lines=3,
90
+ max_lines=6
91
+ )
92
+
93
+ category_dropdown = gr.Dropdown(
94
+ choices=list(PROMPT_EXAMPLES.keys()),
95
+ label="💡 Quick Prompt Category",
96
+ value="💃 Fashion / Beauty (Facial Only)",
97
+ interactive=True
98
+ )
99
+
100
+ example_dropdown = gr.Dropdown(
101
+ choices=PROMPT_EXAMPLES["💃 Fashion / Beauty (Facial Only)"],
102
+ label="Example Prompts (click to use)",
103
+ value=None,
104
+ interactive=True
105
+ )
106
+
107
+ gr.HTML("""
108
+ <div class="quality-banner">
109
+ <strong>💡 Choose the Right Prompt Category:</strong><br>
110
+ • <strong>💃 Facial Only:</strong> Safe for headshots without visible hands<br>
111
+ • <strong>🙌 Hands Visible Required:</strong> Only use if hands are fully visible<br>
112
+ • <strong>🌄 Scenery/Objects:</strong> For landscapes, products, abstract content
113
+ </div>
114
+ """)
115
+
116
+ gr.HTML("""
117
+ <div class="patience-banner">
118
+ <strong>⏱️ First-time loading may take a moment!</strong><br>
119
+ Subsequent runs will be much faster.
120
+ </div>
121
+ """)
122
+
123
+ generate_btn = gr.Button(
124
+ "🎬 Generate Video",
125
+ variant="primary",
126
+ elem_classes="primary-button",
127
+ size="lg"
128
+ )
129
+
130
+ with gr.Accordion("⚙️ Advanced Settings", open=False):
131
+ duration_slider = gr.Slider(
132
+ minimum=0.5,
133
+ maximum=5.0,
134
+ value=3.0,
135
+ step=0.5,
136
+ label="Video Duration (seconds)"
137
  )
138
 
139
+ steps_slider = gr.Slider(
140
+ minimum=4,
141
+ maximum=25,
142
+ value=4,
143
+ step=1,
144
+ label="Quality Steps (4=Lightning Fast, 8-25=Higher Quality)"
145
  )
146
 
147
+ fps_slider = gr.Slider(
148
+ minimum=8,
149
+ maximum=24,
150
+ value=16,
151
+ step=1,
152
+ label="Frames Per Second"
153
  )
154
 
155
+ expand_prompt = gr.Checkbox(
156
+ label="AI Prompt Expansion (experimental)",
157
+ value=False
 
 
 
158
  )
159
 
160
+ randomize_seed = gr.Checkbox(
161
+ label="Randomize Seed",
162
+ value=True
 
 
163
  )
164
 
165
+ seed_input = gr.Number(
166
+ label="Manual Seed (if not randomized)",
167
+ value=42,
168
+ precision=0
169
+ )
170
+
171
+ # Right Panel: Output
172
+ with gr.Column(scale=1, elem_classes="output-card"):
173
+ gr.Markdown("### 🎥 Output")
174
+
175
+ video_output = gr.Video(
176
+ label="Generated Video",
177
+ elem_classes="video-player"
178
+ )
179
+
180
+ final_prompt_output = gr.Textbox(
181
+ label="Final Prompt Used",
182
+ interactive=False,
183
+ lines=2
184
+ )
185
+
186
+ seed_output = gr.Number(
187
+ label="Seed Used",
188
+ interactive=False,
189
+ precision=0
190
+ )
191
+
192
+ # Event handlers for I2V tab
193
def update_resolution_display(img):
    """Show how the uploaded image maps to the output resolution.

    Output dimensions are snapped down to multiples of 16; hidden when
    no image is present.
    """
    if img is None:
        return gr.update(visible=False)
    width, height = img.size
    # floor each dimension to the nearest multiple of 16
    out_w = width - width % 16
    out_h = height - height % 16
    return gr.update(
        value=f"📐 **Resolution:** Input: {width}×{height} → Output: {out_w}×{out_h}",
        visible=True
    )
203
+
204
def category_changed(category):
    """Refresh the example dropdown to list prompts for the chosen category."""
    # Unknown category: leave the dropdown untouched.
    if category not in PROMPT_EXAMPLES:
        return gr.update()
    # Known category: swap in its examples and clear the current selection.
    return gr.update(choices=PROMPT_EXAMPLES[category], value=None)
208
+
209
def example_selected(example):
    """Copy the chosen example into the prompt box; blank when nothing is chosen."""
    # `or` collapses every falsy selection (None, "") to the empty string.
    return example or ""
211
+
212
+ image_input.change(
213
+ fn=update_resolution_display,
214
+ inputs=[image_input],
215
+ outputs=[resolution_info]
216
+ )
217
+
218
+ category_dropdown.change(
219
+ fn=category_changed,
220
+ inputs=[category_dropdown],
221
+ outputs=[example_dropdown]
222
+ )
223
+
224
+ example_dropdown.change(
225
+ fn=example_selected,
226
+ inputs=[example_dropdown],
227
+ outputs=[prompt_input]
228
+ )
229
+
230
+ generate_btn.click(
231
+ fn=self._generate_video_handler,
232
+ inputs=[
233
+ image_input, prompt_input, duration_slider,
234
+ steps_slider, fps_slider, expand_prompt,
235
+ randomize_seed, seed_input
236
+ ],
237
+ outputs=[video_output, final_prompt_output, seed_output]
238
+ )
239
+
240
+ def _generate_video_handler(
241
+ self,
242
+ image: Image.Image,
243
+ prompt: str,
244
+ duration: float,
245
+ steps: int,
246
+ fps: int,
247
+ expand_prompt: bool,
248
+ randomize_seed: bool,
249
+ seed: int
250
+ ) -> Tuple[str, str, int]:
251
+ """Handler for video generation"""
252
+ if image is None:
253
+ return None, "Please upload an image", 0
254
+
255
+ if not prompt.strip():
256
+ return None, "Please provide a motion prompt", 0
257
+
258
+ try:
259
+ video_path, final_prompt, seed_used = self.facade.generate_video_from_image(
260
+ image=image,
261
+ user_instruction=prompt,
262
+ duration_seconds=duration,
263
+ num_inference_steps=steps,
264
+ enable_prompt_expansion=expand_prompt,
265
+ randomize_seed=randomize_seed,
266
+ seed=seed
267
+ )
268
+ return video_path, final_prompt, seed_used
269
+
270
+ except Exception as e:
271
+ logger.error(f"Video generation failed: {e}")
272
+ return None, f"Error: {str(e)}", 0
273
+
274
+
275
def _create_background_tab(self):
    """Create Background Generation tab (SceneWeaver functionality).

    Builds a two-column layout — input controls on the left, results
    gallery on the right — then registers the tab's event handlers.
    All widget handles are locals; they are referenced only by the
    .change()/.click() bindings at the bottom of this method.
    """
    with gr.Row():
        # Left Panel: Input
        with gr.Column(scale=1, elem_classes="feature-card"):
            gr.Markdown("### 📸 Upload & Configure")

            gr.HTML("""
                <div class="quality-banner">
                    <strong>💡 Best Results Tips:</strong><br>
                    • Clean portrait photos with simple backgrounds work best<br>
                    • Complex scenes (e.g., pets with grass) may need parameter adjustments<br>
                    • Use Advanced Options below to fine-tune edge blending
                </div>
            """)

            bg_image_input = gr.Image(
                label="Upload Your Image",
                type="pil",  # hand the handler a PIL image, not a numpy array
                height=280
            )

            # Scene Template Selector
            template_dropdown = gr.Dropdown(
                label="Scene Templates (24 curated scenes A-Z)",
                # leading "" entry = "no template selected"
                choices=[""] + self.template_manager.get_template_choices_sorted(),
                value="",
                info="Optional: Select a preset or describe your own",
                elem_classes=["template-dropdown"]
            )

            bg_prompt_input = gr.Textbox(
                label="Background Scene Description",
                placeholder="Select a template above or describe your own scene...",
                lines=3
            )

            combination_mode = gr.Dropdown(
                label="Composition Mode",
                choices=["center", "left_half", "right_half", "full"],
                value="center",
                info="center=Smart Center | full=Full Image"
            )

            focus_mode = gr.Dropdown(
                label="Focus Mode",
                choices=["person", "scene"],
                value="person",
                info="person=Tight Crop | scene=Include Surrounding"
            )

            with gr.Accordion("Advanced Options", open=False):
                gr.HTML("""
                    <div style="padding: 8px; background: #f0f4ff; border-radius: 6px; margin-bottom: 12px; font-size: 13px;">
                        <strong>💡 When to Adjust:</strong><br>
                        <strong>Feather Radius:</strong> Use 5-10 for complex scenes with fine details (hair, fur, foliage). 0 = sharp edges for clean portraits.<br>
                        <strong>Mask Preview:</strong> Check the "Mask Preview" tab after generation. White = kept, Black = replaced. Helps diagnose edge issues.
                    </div>
                """)

                feather_radius_slider = gr.Slider(
                    label="Feather Radius (Edge Softness)",
                    minimum=0,
                    maximum=20,
                    value=0,
                    step=1,
                    info="Softens mask edges. Try 5-10 if edges look harsh."
                )

                bg_negative_prompt = gr.Textbox(
                    label="Negative Prompt",
                    value="blurry, low quality, distorted, people, characters",
                    lines=2,
                    info="Prevents unwanted elements in background"
                )

                bg_steps_slider = gr.Slider(
                    label="Quality Steps",
                    minimum=15,
                    maximum=50,
                    value=25,
                    step=5,
                    info="Higher = better quality but slower"
                )

                bg_guidance_slider = gr.Slider(
                    label="Guidance Scale",
                    minimum=5.0,
                    maximum=15.0,
                    value=7.5,
                    step=0.5,
                    info="How strictly to follow prompt"
                )

            generate_bg_btn = gr.Button(
                "🎨 Generate Background",
                variant="primary",
                elem_classes="primary-button",
                size="lg"
            )

        # Right Panel: Output
        with gr.Column(scale=2, elem_classes="feature-card"):
            gr.Markdown("### 🎭 Results Gallery")

            gr.HTML("""
                <div class="patience-banner">
                    <strong>⏱️ First-time users:</strong> Initial model loading takes 1-2 minutes.
                    Subsequent generations are much faster (~30s).
                </div>
            """)

            with gr.Tabs():
                with gr.TabItem("Final Result"):
                    bg_combined_output = gr.Image(
                        label="Your Generated Image",
                        elem_classes=["result-gallery"]
                    )
                with gr.TabItem("Background"):
                    bg_generated_output = gr.Image(
                        label="Generated Background",
                        elem_classes=["result-gallery"]
                    )
                with gr.TabItem("Original"):
                    bg_original_output = gr.Image(
                        label="Processed Original",
                        elem_classes=["result-gallery"]
                    )
                with gr.TabItem("Mask Preview"):
                    gr.HTML("""
                        <div style="padding: 8px; background: #f0f4ff; border-radius: 6px; margin-bottom: 8px; font-size: 13px;">
                            <strong>📐 How to Read:</strong> White = Original kept | Black = Background replaced<br>
                            Use this to diagnose edge quality. If edges are too harsh, increase Feather Radius.
                        </div>
                    """)
                    bg_mask_output = gr.Image(
                        label="Blending Mask",
                        elem_classes=["result-gallery"]
                    )

            bg_status_output = gr.Textbox(
                label="Status",
                value="Ready to create! Upload an image and describe your vision.",
                interactive=False,
                elem_classes=["status-panel"]
            )

            with gr.Row():
                clear_bg_btn = gr.Button(
                    "Clear All",
                    elem_classes=["secondary-button"]
                )
                memory_btn = gr.Button(
                    "Clean Memory",
                    elem_classes=["secondary-button"]
                )

    # Event handlers for Background Generation tab
    def apply_template(display_name: str, current_negative: str) -> Tuple[str, str, float]:
        # Map the dropdown's display name to a scene template and return
        # (prompt, negative_prompt, guidance_scale) for the three outputs.
        # Falls back to an empty prompt, the user's current negative prompt,
        # and the default guidance (7.5) when nothing is selected or any
        # lookup step fails.
        if not display_name:
            return "", current_negative, 7.5

        template_key = self.template_manager.get_template_key_from_display(display_name)
        if not template_key:
            return "", current_negative, 7.5

        template = self.template_manager.get_template(template_key)
        if template:
            prompt = template.prompt
            negative = self.template_manager.get_negative_prompt_for_template(
                template_key, current_negative
            )
            guidance = template.guidance_scale
            return prompt, negative, guidance

        return "", current_negative, 7.5

    template_dropdown.change(
        fn=apply_template,
        inputs=[template_dropdown, bg_negative_prompt],
        outputs=[bg_prompt_input, bg_negative_prompt, bg_guidance_slider]
    )

    generate_bg_btn.click(
        fn=self._generate_background_handler,
        inputs=[
            bg_image_input, bg_prompt_input, combination_mode,
            focus_mode, bg_negative_prompt, bg_steps_slider, bg_guidance_slider,
            feather_radius_slider
        ],
        outputs=[
            bg_combined_output, bg_generated_output,
            bg_original_output, bg_mask_output, bg_status_output
        ]
    )

    clear_bg_btn.click(
        fn=lambda: (None, None, None, None, "Ready to create!"),
        outputs=[
            bg_combined_output, bg_generated_output,
            bg_original_output, bg_mask_output, bg_status_output
        ]
    )

    memory_btn.click(
        # `or` yields the status string when _memory_cleanup() returns a
        # falsy value (presumably None — confirm against BackgroundEngine)
        fn=lambda: self.background_engine._memory_cleanup() or "Memory cleaned!",
        outputs=[bg_status_output]
    )
483
+
484
+ def _generate_background_handler(
485
+ self,
486
+ image: Image.Image,
487
+ prompt: str,
488
+ combination_mode: str,
489
+ focus_mode: str,
490
+ negative_prompt: str,
491
+ steps: int,
492
+ guidance: float,
493
+ feather_radius: int
494
+ ) -> Tuple[Optional[Image.Image], Optional[Image.Image], Optional[Image.Image], Optional[Image.Image], str]:
495
+ """Handler for background generation"""
496
+ if image is None:
497
+ return None, None, None, None, "Please upload an image to get started!"
498
+
499
+ if not prompt.strip():
500
+ return None, None, None, None, "Please describe the background scene you'd like!"
501
 
 
 
 
 
502
  try:
503
+ # Apply ZeroGPU decorator if available
504
+ if SPACES_AVAILABLE:
505
+ generate_fn = spaces.GPU(duration=60)(self._background_generate_core)
506
+ else:
507
+ generate_fn = self._background_generate_core
 
508
 
509
+ result = generate_fn(
510
+ image, prompt, combination_mode, focus_mode,
511
+ negative_prompt, steps, guidance, feather_radius
 
 
 
 
 
 
 
 
512
  )
513
 
514
+ if result["success"]:
515
+ return (
516
+ result["combined_image"],
517
+ result["generated_scene"],
518
+ result["original_image"],
519
+ result["mask"],
520
+ "Image created successfully!"
521
+ )
522
+ else:
523
+ error_msg = result.get("error", "Something went wrong")
524
+ return None, None, None, None, f"Error: {error_msg}"
525
 
526
  except Exception as e:
527
+ logger.error(f"Background generation failed: {e}")
528
+ return None, None, None, None, f"Error: {str(e)}"
529
+
530
+ def _background_generate_core(
531
+ self,
532
+ image: Image.Image,
533
+ prompt: str,
534
+ combination_mode: str,
535
+ focus_mode: str,
536
+ negative_prompt: str,
537
+ steps: int,
538
+ guidance: float,
539
+ feather_radius: int
540
+ ) -> Dict[str, Any]:
541
+ """Core background generation with models"""
542
+ if not self.background_engine.is_initialized:
543
+ logger.info("Loading background generation models...")
544
+ self.background_engine.load_models()
545
+
546
+ result = self.background_engine.generate_and_combine(
547
+ original_image=image,
548
+ prompt=prompt,
549
+ combination_mode=combination_mode,
550
+ focus_mode=focus_mode,
551
+ negative_prompt=negative_prompt,
552
+ num_inference_steps=int(steps),
553
+ guidance_scale=float(guidance),
554
+ enable_prompt_enhancement=True,
555
+ feather_radius=int(feather_radius)
556
+ )
557
+
558
+ return result
559