avans06 committed
Commit 8c93973 · 1 Parent(s): 9f78a38

init commit
.gitignore ADDED
@@ -0,0 +1,6 @@
.vs
venv
tmp
*.pyc
models
images
README.md CHANGED
@@ -1,10 +1,10 @@
 ---
 title: SeedVR2 Image Upscaler
-emoji: 😻
+emoji: 🖼️
 colorFrom: gray
 colorTo: blue
 sdk: gradio
-sdk_version: 6.1.0
+sdk_version: 5.50.0
 app_file: app.py
 pinned: false
 license: apache-2.0
app.py ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
@@ -0,0 +1,39 @@
--extra-index-url https://download.pytorch.org/whl/cu130

# Web UI
gradio==5.50.0

# Core numeric / vision
numpy
opencv-python

# PyTorch
torch
torchvision

# Hugging Face helper for downloading weights
huggingface-hub==0.36.0

# Utilities
tqdm

# SeedVR2
psutil
einops
diffusers
rotary-embedding-torch
omegaconf

gguf

triton; sys_platform != 'win32'
triton-windows; sys_platform == 'win32'

#
# flash-attn; sys_platform != 'win32'
# sageattention; sys_platform != 'win32'
#
# if sys_platform == 'win32'
# https://huggingface.co/lldacing/flash-attention-windows-wheel
# https://huggingface.co/ussoewwin/Flash-Attention-2_for_Windows
# https://github.com/woct0rdho/SageAttention
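Because PyTorch resolves from the cu130 extra index above, it is worth confirming after installation that the CUDA build was actually selected. A minimal sanity check; the expected version strings are assumptions inferred from the index URL, not something this commit pins:

    import torch

    print(torch.__version__)          # a "+cu130" suffix means the wheel came from the extra index
    print(torch.version.cuda)         # expected "13.0" for wheels from that index
    print(torch.cuda.is_available())  # True only with a compatible NVIDIA driver installed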
src/optimization/blockswap.py ADDED
@@ -0,0 +1,1032 @@
"""
BlockSwap Module for SeedVR2

This module implements dynamic block swapping between GPU and CPU memory
to enable running large models on limited VRAM systems.

Key Features:
- Dynamic transformer block offloading during inference
- Non-blocking GPU transfers for optimal performance
- RoPE computation fallback to CPU on OOM
- Minimal performance overhead with intelligent caching
- I/O component offloading for maximum memory savings
"""

import time
import types
import torch
import weakref

from typing import Dict, Any, List, Optional
from .memory_manager import clear_memory
from .compatibility import call_rope_with_stability
from ..common.distributed import get_device


def is_blockswap_enabled(config: Optional[Dict[str, Any]]) -> bool:
    """
    Check if BlockSwap configuration indicates BlockSwap should be enabled.

    BlockSwap is enabled if either blocks_to_swap > 0 OR swap_io_components is True.
    This is the authoritative function for determining BlockSwap status from configuration.

    Args:
        config: BlockSwap configuration dictionary with optional keys:
            - blocks_to_swap: Number of blocks to offload (0 = disabled)
            - swap_io_components: Whether to offload I/O components

    Returns:
        True if BlockSwap should be active, False otherwise
    """
    if not config:
        return False

    blocks_to_swap = config.get("blocks_to_swap", 0)
    swap_io_components = config.get("swap_io_components", False)

    return blocks_to_swap > 0 or swap_io_components


def validate_blockswap_config(
    block_swap_config: Optional[Dict[str, Any]],
    dit_device: 'torch.device',
    dit_offload_device: Optional['torch.device'],
    debug: 'Debug'
) -> Optional[Dict[str, Any]]:
    """
    Validate and potentially modify BlockSwap configuration.

    Performs platform-specific validation and configuration adjustment:
    - On macOS (MPS): Auto-disables BlockSwap since unified memory makes it meaningless
    - On other platforms: Validates that offload_device is properly configured

    This is the single authoritative validation point for BlockSwap configuration,
    called early in configure_runner() before any model loading.

    Args:
        block_swap_config: BlockSwap configuration dictionary (may be None)
        dit_device: Target device for DiT model inference
        dit_offload_device: Device for offloading DiT blocks (may be None)
        debug: Debug instance for logging warnings/errors

    Returns:
        Validated/modified block_swap_config (may be None or modified copy)

    Raises:
        ValueError: If BlockSwap is enabled but offload_device is invalid (non-MPS only)
    """
    if not is_blockswap_enabled(block_swap_config):
        return block_swap_config

    blocks_to_swap = block_swap_config.get("blocks_to_swap", 0)
    swap_io_components = block_swap_config.get("swap_io_components", False)

    # Check for macOS unified memory - BlockSwap is meaningless there
    if dit_device.type == "mps":
        debug.log(
            f"BlockSwap disabled: macOS uses unified memory (no separate VRAM/RAM). "
            f"Ignoring blocks_to_swap={blocks_to_swap}, swap_io_components={swap_io_components}",
            level="WARNING", category="blockswap", force=True
        )
        # Return disabled config
        return {
            **block_swap_config,
            "blocks_to_swap": 0,
            "swap_io_components": False
        }

    # Validate offload_device is set and different from dit_device
    offload_device_valid = (
        dit_offload_device is not None and
        str(dit_offload_device) != str(dit_device)
    )

    if not offload_device_valid:
        config_details = []
        if blocks_to_swap > 0:
            config_details.append(f"blocks_to_swap={blocks_to_swap}")
        if swap_io_components:
            config_details.append("swap_io_components=True")

        offload_str = str(dit_offload_device) if dit_offload_device else "none"
        raise ValueError(
            f"BlockSwap enabled ({', '.join(config_details)}) but dit_offload_device is invalid. "
            f"Current: device='{dit_device}', dit_offload_device='{offload_str}'. "
            f"BlockSwap requires offload_device on the DiT Model to be set and different from device. "
            f"Set --dit_offload_device cpu or disable BlockSwap."
        )

    return block_swap_config


# Timing helpers marked to skip torch.compile tracing
# These functions are excluded from Dynamo's graph tracing to avoid warnings
# about non-traceable builtins like time.time(), but they still execute normally
@torch._dynamo.disable
def _get_swap_start_time(debug, enabled: bool) -> Optional[float]:
    """Get start time for swap operation if debug is enabled."""
    return time.time() if debug and enabled else None


@torch._dynamo.disable
def _log_swap_timing(debug, t_start: Optional[float], component_id, component_type: str) -> None:
    """Log swap timing if start time was captured."""
    if debug and t_start is not None:
        debug.log_swap_time(
            component_id=component_id,
            duration=time.time() - t_start,
            component_type=component_type
        )


def get_module_memory_mb(module: torch.nn.Module) -> float:
    """
    Calculate memory usage of a module in MB.

    Args:
        module: PyTorch module to measure

    Returns:
        Memory usage in megabytes
    """
    total_bytes = sum(
        param.nelement() * param.element_size()
        for param in module.parameters()
        if param.data is not None
    )
    return total_bytes / (1024 * 1024)


def apply_block_swap_to_dit(
    runner: 'VideoDiffusionInfer',
    block_swap_config: Dict[str, Any],
    debug: 'Debug'
) -> None:
    """
    Apply block swapping configuration to a DiT model with OOM protection.

    This is the main entry point for configuring block swapping on a model.
    Handles block selection, I/O component offloading, device placement, and
    forward method wrapping for dynamic memory management.

    Args:
        runner: VideoDiffusionInfer instance containing the model
        block_swap_config: Configuration dictionary with keys:
            - blocks_to_swap: Number of blocks to swap (from the start)
            - swap_io_components: Whether to offload I/O components
            - enable_debug: Whether to enable debug logging
            - offload_device: Device to offload to (default: 'cpu')
        debug: Debug instance for logging (required)
    """
    # Early return if BlockSwap not enabled
    if not is_blockswap_enabled(block_swap_config):
        return

    blocks_to_swap = block_swap_config.get("blocks_to_swap", 0)
    swap_io_components = block_swap_config.get("swap_io_components", False)

    # Early return only if both block swap and I/O swap are disabled
    if blocks_to_swap <= 0 and not swap_io_components:
        return

    if debug is None:
        if hasattr(runner, 'debug') and runner.debug is not None:
            debug = runner.debug
        else:
            raise ValueError("Debug instance must be provided to apply_block_swap_to_dit")

    debug.start_timer("apply_blockswap")

    # Get the actual model (handle CompatibleDiT wrapper)
    model = runner.dit
    if hasattr(model, "dit_model"):
        model = model.dit_model

    # Determine devices
    if hasattr(runner, '_dit_device'):
        device = runner._dit_device
    else:
        device = get_device()
    offload_device = block_swap_config.get("offload_device", torch.device('cpu'))

    # Validate model structure
    if not hasattr(model, "blocks"):
        debug.log("Model doesn't have 'blocks' attribute for BlockSwap", level="ERROR", category="blockswap", force=True)
        return

    total_blocks = len(model.blocks)

    # Clamp blocks_to_swap to available blocks BEFORE logging
    effective_blocks = min(blocks_to_swap, total_blocks) if blocks_to_swap > 0 else 0

    # Log configuration clearly based on what's enabled
    block_text = "block" if effective_blocks <= 1 else "blocks"
    if effective_blocks > 0 and swap_io_components:
        debug.log(f"BlockSwap: {effective_blocks}/{total_blocks} transformer {block_text} + I/O components offloaded to {str(offload_device).upper()}", category="blockswap", force=True)
    elif effective_blocks > 0:
        debug.log(f"BlockSwap: {effective_blocks}/{total_blocks} transformer {block_text} offloaded to {str(offload_device).upper()}", category="blockswap", force=True)
    elif swap_io_components:
        debug.log(f"BlockSwap: I/O components offloaded to {str(offload_device).upper()} (0/{total_blocks} blocks swapped)", category="blockswap", force=True)

    # Configure model with blockswap attributes
    if blocks_to_swap > 0:
        model.blocks_to_swap = effective_blocks - 1  # Convert to 0-indexed
    else:
        # No block swapping, set to -1 so no blocks match the swap condition
        model.blocks_to_swap = -1

    model.main_device = device
    model.offload_device = offload_device

    # Configure I/O components
    io_config = _configure_io_components(model, device, offload_device,
                                         swap_io_components, debug)
    memory_stats = _configure_blocks(model, device, offload_device, debug)
    memory_stats['io_components'] = io_config['components']
    memory_stats['io_memory_mb'] = io_config['memory_mb']
    memory_stats['gpu_components'] = io_config['gpu_components']
    memory_stats['io_gpu_memory_mb'] = io_config['gpu_memory_mb']

    # Log memory summary
    _log_memory_summary(memory_stats, offload_device, device, swap_io_components,
                        debug)

    # Initialize Nunchaku-style async management object
    if blocks_to_swap > 0:
        # normalize device objects
        if isinstance(device, str):
            device = torch.device(device)
        model._swap_stream = torch.cuda.Stream(device=device)
        model._block_ready_events = {}

        # Preload first swapped block to seed pipeline (non-blocking on swap_stream)
        try:
            first_idx = 0
            if first_idx <= model.blocks_to_swap:
                with torch.cuda.stream(model._swap_stream):
                    model.blocks[first_idx].to(device, non_blocking=True)
                    ev = torch.cuda.Event(blocking=False)
                    ev.record(model._swap_stream)  # record on swap_stream -> event gets device-bound here
                    model._block_ready_events[first_idx] = ev
        except Exception as e:
            debug.log(f"Failed to initialize swap-stream prefetch: {e}", level="WARNING", category="blockswap", force=True)

    # Wrap block forward methods for dynamic swapping (only if blocks_to_swap > 0)
    if blocks_to_swap > 0:
        for b, block in enumerate(model.blocks):
            if b <= model.blocks_to_swap:
                _wrap_block_forward(block, b, model, debug)

    # Patch RoPE modules for robust error handling
    _patch_rope_for_blockswap(model, debug)

    # Mark BlockSwap as active
    runner._blockswap_active = True

    # Store configuration for debugging and cleanup
    model._block_swap_config = {
        "blocks_swapped": blocks_to_swap,
        "swap_io_components": swap_io_components,
        "total_blocks": total_blocks,
        "offload_device": offload_device,
        "main_device": device,
        "offload_memory": memory_stats['offload_memory'],
        "main_memory": memory_stats['main_memory']
    }

    # Protect model from being moved entirely
    _protect_model_from_move(model, runner, debug)

    debug.log("BlockSwap configuration complete", category="success")
    debug.end_timer("apply_blockswap", "BlockSwap configuration application")
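The block_swap_config consumed by apply_block_swap_to_dit() is plain data. A minimal sketch of a call site, assuming a runner and Debug instance from the surrounding SeedVR2 code; the values are illustrative and the key names follow the docstring above:

    import torch

    block_swap_config = {
        "blocks_to_swap": 16,                   # hypothetical: offload the first 16 transformer blocks
        "swap_io_components": True,             # also offload embeddings / norm layers
        "offload_device": torch.device("cpu"),
    }

    block_swap_config = validate_blockswap_config(
        block_swap_config,
        dit_device=torch.device("cuda"),
        dit_offload_device=torch.device("cpu"),
        debug=debug,                            # assumed Debug instance from the app
    )
    apply_block_swap_to_dit(runner, block_swap_config, debug=debug)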
def _configure_io_components(
    model: torch.nn.Module,
    device: torch.device,
    offload_device: torch.device,
    swap_io_components: bool,
    debug: 'Debug'
) -> Dict[str, Any]:
    """
    Configure I/O component placement and wrapping with memory tracking.

    Handles all non-block modules (embeddings, normalization layers, etc.) by
    either keeping them on GPU or offloading them with dynamic swapping wrappers.

    Args:
        model: DiT model containing named children to configure
        device: Main computation device (typically GPU)
        offload_device: Device for offloaded components (typically CPU)
        swap_io_components: If True, offload I/O components with dynamic swapping
        debug: Debug instance for logging (required)

    Returns:
        Dictionary containing:
        - components: List of offloaded component names
        - memory_mb: Total memory of offloaded components in MB
        - gpu_components: List of components remaining on GPU
        - gpu_memory_mb: Total memory of GPU components in MB
    """
    io_components_offloaded = []
    io_components_on_gpu = []
    io_memory_mb = 0.0
    io_gpu_memory_mb = 0.0

    # Check for pin memory condition
    use_pin_memory = (offload_device == "cpu") if isinstance(offload_device, str) else (offload_device.type == "cpu")

    # Handle I/O modules with dynamic swapping
    for name, module in model.named_children():
        if name != "blocks":
            module_memory = get_module_memory_mb(module)

            if swap_io_components:
                module.to(offload_device)

                # Enable Pin Memory for I/O components
                if use_pin_memory:
                    for p in module.parameters():
                        if not p.is_pinned():
                            p.data = p.data.pin_memory()
                    for buf in module.buffers():
                        if not buf.is_pinned():
                            buf.data = buf.data.pin_memory()

                _wrap_io_forward(module, name, model, debug)
                io_components_offloaded.append(name)
                io_memory_mb += module_memory
                debug.log(f"{name} → {str(offload_device).upper()} ({module_memory:.2f}MB, dynamic swapping)", category="blockswap", indent_level=1)
            else:
                module.to(device)
                io_components_on_gpu.append(name)
                io_gpu_memory_mb += module_memory
                debug.log(f"{name} → {str(device).upper()} ({module_memory:.2f}MB)", category="blockswap", indent_level=1)

    return {
        'components': io_components_offloaded,
        'memory_mb': io_memory_mb,
        'gpu_components': io_components_on_gpu,
        'gpu_memory_mb': io_gpu_memory_mb
    }


def _configure_blocks(
    model: torch.nn.Module,
    device: torch.device,
    offload_device: torch.device,
    debug: 'Debug'
) -> Dict[str, float]:
    """
    Configure transformer block placement and calculate memory statistics.

    Moves blocks to their designated devices based on model.blocks_to_swap
    attribute. Blocks with index <= blocks_to_swap go to offload device,
    others stay on main device.

    Args:
        model: DiT model with blocks attribute and blocks_to_swap configured
        device: Main computation device for non-swapped blocks
        offload_device: Device for swapped blocks
        debug: Debug instance for logging (required)

    Returns:
        Dictionary containing:
        - offload_memory: Total memory of offloaded blocks in MB
        - main_memory: Total memory of blocks on main device in MB
        - io_components: Empty list (populated by caller)
    """
    total_offload_memory = 0.0
    total_main_memory = 0.0

    # Check if we should pin memory (if offloading to CPU)
    # Nunchaku uses pinned memory for faster async transfers
    use_pin_memory = (offload_device == "cpu") if isinstance(offload_device, str) else (offload_device.type == "cpu")

    # Move blocks based on swap configuration
    for b, block in enumerate(model.blocks):
        block_memory = get_module_memory_mb(block)

        if b > model.blocks_to_swap:
            block.to(device)
            total_main_memory += block_memory
        else:
            block.to(offload_device, non_blocking=False)
            total_offload_memory += block_memory

            # Enable Pin Memory optimization for CPU Offload transfer speed
            if use_pin_memory:
                for p in block.parameters():
                    if not p.is_pinned():
                        p.data = p.data.pin_memory()
                for buf in block.buffers():
                    if not buf.is_pinned():
                        buf.data = buf.data.pin_memory()

    # Ensure all buffers match their containing module's device
    for b, block in enumerate(model.blocks):
        target_device = device if b > model.blocks_to_swap else offload_device
        for name, buffer in block.named_buffers():
            if buffer.device != torch.device(target_device):
                # Apply pinning if needed
                if use_pin_memory and target_device.type == "cpu" and not buffer.is_pinned():
                    buffer.data = buffer.data.pin_memory()
                buffer.data = buffer.data.to(target_device, non_blocking=False)

    return {
        "offload_memory": total_offload_memory,
        "main_memory": total_main_memory,
        "io_components": []  # Will be populated by caller
    }


def _log_memory_summary(
    memory_stats: Dict[str, float],
    offload_device: torch.device,
    device: torch.device,
    swap_io_components: bool,
    debug: 'Debug'
) -> None:
    """
    Log comprehensive memory usage summary for BlockSwap configuration.

    Displays detailed breakdown of memory distribution across devices,
    including transformer blocks and I/O components.

    Args:
        memory_stats: Dictionary containing:
            - offload_memory: Memory offloaded from blocks (MB)
            - main_memory: Memory remaining on main device (MB)
            - io_memory_mb: Memory from offloaded I/O components (MB)
            - io_gpu_memory_mb: Memory from I/O components on GPU (MB)
        offload_device: Device used for offloading
        device: Main computation device
        swap_io_components: Whether I/O components are being swapped
        debug: Debug instance for logging (required)
    """
    debug.log("BlockSwap memory configuration:", category="blockswap")

    # Log transformer blocks memory
    blocks_offloaded = memory_stats['offload_memory']
    blocks_on_gpu = memory_stats['main_memory']

    offload_str = str(offload_device)
    device_str = str(device)

    if blocks_on_gpu == 0:
        debug.log(f"Transformer blocks: {blocks_offloaded:.2f}MB on {offload_str} (dynamic swapping)", category="blockswap", indent_level=1)
    else:
        debug.log(f"Transformer blocks: {blocks_on_gpu:.2f}MB on {device_str}, {blocks_offloaded:.2f}MB on {offload_str}", category="blockswap", indent_level=1)

    # Always log I/O components (whether swapping or not)
    io_memory = memory_stats.get('io_memory_mb', 0.0)
    io_gpu_memory = memory_stats.get('io_gpu_memory_mb', 0.0)

    if swap_io_components and io_memory > 0:
        io_components = memory_stats.get('io_components', [])
        debug.log(f"I/O components: {io_memory:.2f}MB on {offload_str} (dynamic swapping)", category="blockswap", indent_level=1)
        debug.log(f"{', '.join(io_components)}", category="blockswap", indent_level=2)
    elif io_gpu_memory > 0:
        io_gpu_components = memory_stats.get('gpu_components', [])
        debug.log(f"I/O components: {io_gpu_memory:.2f}MB on {device_str}", category="blockswap", indent_level=1)
        debug.log(f"{', '.join(io_gpu_components)}", category="blockswap", indent_level=2)

    # Log total VRAM savings
    total_offloaded = blocks_offloaded + (io_memory if swap_io_components else 0)
    if total_offloaded > 0:
        debug.log(f"Total VRAM saved: {total_offloaded:.2f}MB (~{total_offloaded/1024:.2f}GB)", category="blockswap", indent_level=1)


def _wrap_block_forward(
    block: torch.nn.Module,
    block_idx: int,
    model: torch.nn.Module,
    debug: 'Debug'
) -> None:
    """
    Wrap individual transformer block forward for dynamic device swapping.

    Implements Nunchaku-style pipelining: Prefetch Next -> Compute Current -> Offload Current.
    https://github.com/nunchaku-tech/nunchaku/blob/main/nunchaku/models/utils.py

    Creates a wrapped forward method that automatically:
    1. Moves block to GPU before computation
    2. Executes original forward pass
    3. Moves block back to offload device after computation
    4. Logs timing and manages memory pressure

    Uses weak references to prevent memory leaks from closure retention.

    Args:
        block: Individual transformer block to wrap
        block_idx: Index of this block in model.blocks
        model: Parent DiT model (used for device references)
        debug: Debug instance for logging (required)
    """
    if hasattr(block, '_original_forward'):
        return  # Already wrapped

    # Store original forward method
    original_forward = block.forward

    # Create weak references
    model_ref = weakref.ref(model)
    debug_ref = weakref.ref(debug) if debug is not None else (lambda: None)

    # Store block_idx on the block itself to avoid closure issues
    block._block_idx = block_idx

    def wrapped_forward(self, *args, **kwargs):
        # Retrieve weak references
        model = model_ref()
        debug = debug_ref()

        if not model:
            # Model has been garbage collected, fall back to original
            return original_forward(*args, **kwargs)

        # Check if block swap is active for this block
        if hasattr(model, 'blocks_to_swap') and self._block_idx <= model.blocks_to_swap:
            # Use dynamo-disabled helper to get start time (avoids compilation warnings)
            t_start = _get_swap_start_time(debug, debug.enabled if debug else False)

            # Only move to GPU if necessary
            current_device = next(self.parameters()).device
            target_device = torch.device(model.main_device)

            # 1. Ensure CURRENT block is ready on GPU
            # Check if we have a prefetch event waiting
            if hasattr(model, '_block_ready_events') and self._block_idx in model._block_ready_events:
                # Wait for the swap stream to finish moving this block
                torch.cuda.current_stream().wait_event(model._block_ready_events[self._block_idx])
                # Cleanup event
                del model._block_ready_events[self._block_idx]
            elif current_device != target_device:
                # Fallback: First block or missed prefetch, move synchronously (but non-blocking)
                if debug:  # guard: the weakref may have returned None
                    debug.log(f"[blockswap] Block {self._block_idx} missing prefetch event, moving synchronously", level="WARNING", category="blockswap", force=True)
                self.to(model.main_device, non_blocking=True)

            # 2. Trigger Prefetch for NEXT block (Pipelining)
            # Nunchaku logic: Start moving i+1 while i is computing
            next_idx = self._block_idx + 1
            if next_idx <= model.blocks_to_swap:
                next_block = model.blocks[next_idx]
                # Use the dedicated swap stream
                with torch.cuda.stream(model._swap_stream):
                    next_block.to(model.main_device, non_blocking=True)
                    # Record event so next iteration knows when to wait
                    event = torch.cuda.Event(blocking=False)
                    event.record(model._swap_stream)
                    model._block_ready_events[next_idx] = event

            # 3. Execute forward pass (Compute)
            # This runs on the default stream, overlapping with the prefetch above
            output = original_forward(*args, **kwargs)

            # 4. Offload CURRENT block (Async)
            # We record an event on compute stream to ensure we don't move data while it's being used
            compute_done_event = torch.cuda.Event(blocking=False)
            compute_done_event.record(torch.cuda.current_stream())

            with torch.cuda.stream(model._swap_stream):
                # Wait for compute to finish before moving memory out
                model._swap_stream.wait_event(compute_done_event)
                # Move back to offload device
                self.to(model.offload_device, non_blocking=True)

            # Use dynamo-disabled helper to log timing (avoids compilation warnings)
            _log_swap_timing(debug, t_start, self._block_idx, "block (pipelined)")

            # Only clear cache under memory pressure
            clear_memory(debug=debug, deep=False, force=False, timer_name="wrap_block_forward")
        else:
            output = original_forward(*args, **kwargs)

        return output

    # Bind the wrapped function as a method to the block
    block.forward = types.MethodType(wrapped_forward, block)

    # Store reference to original forward for cleanup
    block._original_forward = original_forward


def _wrap_io_forward(
    module: torch.nn.Module,
    module_name: str,
    model: torch.nn.Module,
    debug: 'Debug'
) -> None:
    """
    Wrap I/O component forward for dynamic device swapping.

    Similar to _wrap_block_forward but for I/O components (embeddings,
    normalization layers, etc.). Handles swapping between GPU and CPU
    during forward passes.

    Uses weak references to prevent circular dependencies and memory leaks.

    Args:
        module: I/O component module to wrap
        module_name: Name identifier for logging (e.g., 'x_embedder')
        model: Parent DiT model (used for device references)
        debug: Debug instance for logging (required)
    """
    if hasattr(module, '_is_io_wrapped') and module._is_io_wrapped:
        debug.log(f"Reusing existing I/O wrapper for {module_name}", category="reuse")
        return  # Already wrapped

    # Store original forward method
    original_forward = module.forward

    # Create weak references
    model_ref = weakref.ref(model)
    debug_ref = weakref.ref(debug) if debug else lambda: None

    # Store module name on the module itself
    module._module_name = module_name
    module._original_forward = original_forward

    def wrapped_io_forward(self, *args, **kwargs):
        # Retrieve weak references
        model = model_ref()
        debug = debug_ref()

        if not model:
            # Model has been garbage collected, fall back to original
            return self._original_forward(*args, **kwargs)

        # Use dynamo-disabled helper to get start time (avoids compilation warnings)
        t_start = _get_swap_start_time(debug, debug.enabled if debug else False)

        # Check current device to avoid unnecessary moves
        current_device = next(self.parameters()).device
        target_device = torch.device(model.main_device)

        # Move to GPU for computation if needed
        if current_device != target_device:
            self.to(model.main_device, non_blocking=False)

        # Execute forward pass
        output = self._original_forward(*args, **kwargs)

        # Move back to offload device
        self.to(model.offload_device, non_blocking=False)

        # Use dynamo-disabled helper to log timing (avoids compilation warnings)
        _log_swap_timing(debug, t_start, self._module_name, "I/O")

        # Only clear cache under memory pressure
        clear_memory(debug=debug, deep=False, force=False, timer_name="wrap_block_forward")

        return output

    # Bind as a method
    module.forward = types.MethodType(wrapped_io_forward, module)
    module._is_io_wrapped = True

    # Store module reference for restoration
    if not hasattr(model, '_io_swappers'):
        model._io_swappers = []
    model._io_swappers.append((module, module_name))


def _patch_rope_for_blockswap(
    model: torch.nn.Module,
    debug: 'Debug'
) -> None:
    """
    Patch RoPE (Rotary Position Embedding) modules for device-aware fallback.

    Adds CPU fallback logic to RoPE modules to handle device mismatch errors
    that can occur during BlockSwap operations. Complements the stability
    wrapper from compatibility.py with device-specific error handling.

    Args:
        model: DiT model containing RoPE modules to patch
        debug: Debug instance for logging (required)
    """
    rope_patches = []

    for name, module in model.named_modules():
        if "rope" in name.lower() and hasattr(module, "get_axial_freqs"):
            # Skip if already wrapped by blockswap
            if hasattr(module, '_blockswap_wrapped') and module._blockswap_wrapped:
                continue

            # Get current method (might be stability-wrapped)
            current_method = module.get_axial_freqs

            # Create device-aware wrapper with proper closure handling
            def make_device_aware_wrapper(module_name, current_fn):
                def device_aware_rope_wrapper(self, *args, **kwargs):
                    try:
                        # Try current method (original or stability-wrapped)
                        return current_fn(*args, **kwargs)
                    except (RuntimeError, torch.cuda.OutOfMemoryError) as e:
                        error_msg = str(e).lower()
                        # Only handle device/memory specific errors
                        if any(x in error_msg for x in ["device", "memory", "allocation"]):
                            debug.log(f"RoPE OOM for {module_name}", level="WARNING", category="rope", force=True)
                            debug.log(f"Clearing RoPE cache and retrying", category="info", force=True)

                            # Get current device from parameters
                            try:
                                current_device = next(self.parameters()).device
                            except StopIteration:
                                # Fallback: use model's main_device if BlockSwap has set it, else use offload_device
                                if hasattr(model, 'main_device'):
                                    current_device = torch.device(model.main_device)
                                elif hasattr(model, 'offload_device'):
                                    current_device = torch.device(model.offload_device)
                                else:
                                    # Last resort so current_device is always defined
                                    current_device = torch.device("cpu")

                            # Try clearing cache first (non-invasive fix)
                            if hasattr(current_fn, 'cache_clear'):
                                current_fn.cache_clear()
                            try:
                                # Retry on same device after clearing cache
                                return current_fn(*args, **kwargs)
                            except Exception as retry_error:
                                # Cache clear wasn't enough, need more drastic measures
                                debug.log(f"Cache clear insufficient for {module_name}, falling back to CPU", level="WARNING", category="rope", force=True)

                            # Fallback to CPU computation with stability
                            self.cpu()

                            try:
                                # Use call_rope_with_stability for CPU computation
                                # This ensures cache is cleared and autocast disabled
                                original_fn = getattr(self, '_original_get_axial_freqs', current_fn)
                                result = call_rope_with_stability(original_fn, *args, **kwargs)

                                # Move module back to original device
                                self.to(current_device)

                                # Move result to appropriate device if it's a tensor
                                if hasattr(result, 'to'):
                                    target_device = args[0].device if len(args) > 0 and hasattr(args[0], 'device') else current_device
                                    return result.to(target_device)
                                return result

                            except Exception as cpu_error:
                                # Always restore device even on error
                                self.to(current_device)
                                raise cpu_error
                        else:
                            # Not a device error, let it bubble up
                            raise

                return device_aware_rope_wrapper

            # Apply wrapper
            module.get_axial_freqs = types.MethodType(
                make_device_aware_wrapper(name, current_method),
                module
            )
            module._blockswap_wrapped = True

            # Store for cleanup (use original or previously stored)
            original_method = getattr(module, '_original_get_axial_freqs', current_method)
            rope_patches.append((module, original_method))

    if rope_patches:
        model._rope_patches = rope_patches
        debug.log(f"Patched {len(rope_patches)} RoPE modules with device handling", category="success")


def _protect_model_from_move(
    model: torch.nn.Module,
    runner: 'VideoDiffusionInfer',
    debug: 'Debug'
) -> None:
    """
    Protect model from unintended full device movement during BlockSwap.

    Wraps model.to() method to prevent other code from accidentally moving
    the entire model to GPU, which would defeat BlockSwap's memory savings.
    Allows movement only when explicitly bypassed via model flag.

    Args:
        model: DiT model to protect
        runner: VideoDiffusionInfer instance (for active status check)
        debug: Debug instance for logging (required)
    """
    if not hasattr(model, '_original_to'):
        # Store runner reference as weak reference to avoid circular refs
        model._blockswap_runner_ref = weakref.ref(runner)
        model._original_to = model.to

        # Define the protected method without closures
        def protected_model_to(self, device, *args, **kwargs):
            # Check if protection is temporarily bypassed for offloading
            # Flag is stored on model itself (not runner) to survive runner recreation
            if getattr(self, "_blockswap_bypass_protection", False):
                # Protection bypassed, allow movement
                if hasattr(self, '_original_to'):
                    return self._original_to(device, *args, **kwargs)

            # Get configured offload device directly from model
            blockswap_offload_device = "cpu"  # default
            if hasattr(self, "_block_swap_config"):
                blockswap_offload_device = self._block_swap_config.get("offload_device", "cpu")

            # Check if BlockSwap is currently active via runner weak reference
            runner_ref = getattr(self, '_blockswap_runner_ref', None)
            blockswap_is_active = False
            if runner_ref:
                runner_obj = runner_ref()
                if runner_obj and hasattr(runner_obj, "_blockswap_active"):
                    blockswap_is_active = runner_obj._blockswap_active

            # Block attempts to move model away from configured offload device when active
            if blockswap_is_active and str(device) != str(blockswap_offload_device):
                # Get debug instance from runner if available
                debug_instance = None
                if runner_ref:
                    runner_obj = runner_ref()
                    if runner_obj and hasattr(runner_obj, 'debug'):
                        debug_instance = runner_obj.debug

                if debug_instance:
                    debug_instance.log(
                        f"Blocked attempt to move BlockSwap model from {blockswap_offload_device} to {device}",
                        level="WARNING", category="blockswap", force=True
                    )
                return self

            # Allow movement (either bypass is enabled or target is offload device)
            if hasattr(self, '_original_to'):
                return self._original_to(device, *args, **kwargs)
            else:
                # Fallback - shouldn't happen
                return super(type(self), self).to(device, *args, **kwargs)

        # Bind as a method to the model instance
        model.to = types.MethodType(protected_model_to, model)


def set_blockswap_bypass(runner, bypass: bool, debug):
    """
    Set or unset bypass flag for BlockSwap protection.
    Used for offloading to temporarily allow model movement.

    Args:
        runner: Runner instance with BlockSwap
        bypass: True to bypass protection, False to enforce it
        debug: Debug instance for logging
    """
    if not hasattr(runner, "_blockswap_active") or not runner._blockswap_active:
        return

    # Get the actual model (handle CompatibleDiT wrapper)
    model = runner.dit
    if hasattr(model, "dit_model"):
        model = model.dit_model

    # Store on model so it survives runner recreation during caching
    model._blockswap_bypass_protection = bypass

    if bypass:
        debug.log("BlockSwap protection disabled to allow model DiT offloading", category="success")
    else:
        debug.log("BlockSwap protection re-enabled to avoid accidentally offloading the entire DiT model", category="success")


def cleanup_blockswap(runner, keep_state_for_cache=False):
    """
    Clean up BlockSwap configuration based on caching mode.

    When caching (keep_state_for_cache=True):
    - Keep all BlockSwap configuration intact
    - Only mark as inactive for safety during non-inference operations

    When not caching (keep_state_for_cache=False):
    - Full cleanup of all BlockSwap state

    Args:
        runner: VideoDiffusionInfer instance to clean up
        keep_state_for_cache: If True, preserve BlockSwap state for reuse
    """
    # Get debug instance from runner
    if not hasattr(runner, 'debug') or runner.debug is None:
        raise ValueError("Debug instance must be available on runner for cleanup_blockswap")

    debug = runner.debug

    # Get the actual model (handle CompatibleDiT wrapper)
    model = runner.dit
    if hasattr(model, "dit_model"):
        model = model.dit_model

    # Check if there's any BlockSwap state to clean up (check both runner and model)
    has_blockswap_state = (
        hasattr(runner, "_blockswap_active") or
        hasattr(model, "_block_swap_config") or
        hasattr(model, "_blockswap_bypass_protection")
    )

    if not has_blockswap_state:
        return

    debug.log("Starting BlockSwap cleanup", category="cleanup")

    if keep_state_for_cache:
        # Minimal cleanup for caching - just mark as inactive and allow offloading
        # Everything else stays intact for fast reactivation
        if hasattr(runner, "_blockswap_active") and runner._blockswap_active:
            if not getattr(model, "_blockswap_bypass_protection", False):
                set_blockswap_bypass(runner=runner, bypass=True, debug=debug)
            runner._blockswap_active = False
            debug.log("BlockSwap deactivated for caching (configuration preserved)", category="success")
        return

    # Full cleanup when not caching
    # (model was already resolved from the CompatibleDiT wrapper above)

    # 1. Restore block forward methods
    if hasattr(model, 'blocks'):
        restored_count = 0
        for block in model.blocks:
            if hasattr(block, '_original_forward'):
                block.forward = block._original_forward
                delattr(block, '_original_forward')
                restored_count += 1

            # Clean up wrapper attributes
            for attr in ['_block_idx', '_model_ref', '_debug_ref', '_blockswap_wrapped']:
                if hasattr(block, attr):
                    delattr(block, attr)

        if restored_count > 0:
            debug.log(f"Restored {restored_count} block forward methods", category="success")

    # 2. Restore RoPE patches
    if hasattr(model, '_rope_patches'):
        for module, original_method in model._rope_patches:
            module.get_axial_freqs = original_method
            # Clean up wrapper attributes
            for attr in ['_rope_wrapped', '_original_get_axial_freqs']:
                if hasattr(module, attr):
                    delattr(module, attr)
        debug.log(f"Restored {len(model._rope_patches)} RoPE methods", category="success")
        delattr(model, '_rope_patches')

    # 3. Restore I/O component forward methods and move to offload device
    if hasattr(model, '_io_swappers'):
        for module, module_name in model._io_swappers:
            if hasattr(module, '_original_forward'):
                module.forward = module._original_forward
            # Clean up wrapper attributes
            for attr in ['_original_forward', '_model_ref', '_debug_ref',
                         '_module_name', '_is_io_wrapped']:
                if hasattr(module, attr):
                    delattr(module, attr)
        debug.log(f"Restored {len(model._io_swappers)} I/O components", category="success")
        delattr(model, '_io_swappers')

    # Move all IO components to offload device during full cleanup
    if hasattr(model, 'offload_device'):
        offload_device = model.offload_device
        moved_count = 0
        for name, module in model.named_children():
            if name != "blocks":
                module.to(offload_device)
                moved_count += 1
        if moved_count > 0:
            debug.log(f"Moved {moved_count} IO components to offload device", category="success")

    # 4. Restore original .to() method
    if hasattr(model, '_original_to'):
        model.to = model._original_to
        delattr(model, '_original_to')
        debug.log("Restored original .to() method", category="success")

    # 5. Clean up BlockSwap-specific attributes
    for attr in ['_blockswap_runner_ref', 'blocks_to_swap', 'main_device',
                 'offload_device']:
        if hasattr(model, attr):
            delattr(model, attr)

    # 6. Clean up runner attributes
    runner._blockswap_active = False

    # Clean up pipelining resources on model (synchronize first)
    if hasattr(model, '_swap_stream'):
        try:
            model._swap_stream.synchronize()
        except Exception:
            pass
    for attr in ['_swap_stream', '_block_ready_events']:
        if hasattr(model, attr):
            delattr(model, attr)

    # Remove all config attributes
    for attr in ['_cached_blockswap_config', '_block_swap_config', '_blockswap_debug']:
        if hasattr(runner, attr):
            delattr(runner, attr)

    debug.log("BlockSwap cleanup complete", category="success")
src/optimization/blockswap.py.bak ADDED
@@ -0,0 +1,938 @@
1
+ """
2
+ BlockSwap Module for SeedVR2
3
+
4
+ This module implements dynamic block swapping between GPU and CPU memory
5
+ to enable running large models on limited VRAM systems.
6
+
7
+ Key Features:
8
+ - Dynamic transformer block offloading during inference
9
+ - Non-blocking GPU transfers for optimal performance
10
+ - RoPE computation fallback to CPU on OOM
11
+ - Minimal performance overhead with intelligent caching
12
+ - I/O component offloading for maximum memory savings
13
+ """
14
+
15
+ import time
16
+ import types
17
+ import torch
18
+ import weakref
19
+
20
+ from typing import Dict, Any, List, Optional
21
+ from .memory_manager import clear_memory
22
+ from .compatibility import call_rope_with_stability
23
+ from ..common.distributed import get_device
24
+
25
+
26
+ def is_blockswap_enabled(config: Optional[Dict[str, Any]]) -> bool:
27
+ """
28
+ Check if BlockSwap configuration indicates BlockSwap should be enabled.
29
+
30
+ BlockSwap is enabled if either blocks_to_swap > 0 OR swap_io_components is True.
31
+ This is the authoritative function for determining BlockSwap status from configuration.
32
+
33
+ Args:
34
+ config: BlockSwap configuration dictionary with optional keys:
35
+ - blocks_to_swap: Number of blocks to offload (0 = disabled)
36
+ - swap_io_components: Whether to offload I/O components
37
+
38
+ Returns:
39
+ True if BlockSwap should be active, False otherwise
40
+ """
41
+ if not config:
42
+ return False
43
+
44
+ blocks_to_swap = config.get("blocks_to_swap", 0)
45
+ swap_io_components = config.get("swap_io_components", False)
46
+
47
+ return blocks_to_swap > 0 or swap_io_components
48
+
49
+
50
+ def validate_blockswap_config(
51
+ block_swap_config: Optional[Dict[str, Any]],
52
+ dit_device: 'torch.device',
53
+ dit_offload_device: Optional['torch.device'],
54
+ debug: 'Debug'
55
+ ) -> Optional[Dict[str, Any]]:
56
+ """
57
+ Validate and potentially modify BlockSwap configuration.
58
+
59
+ Performs platform-specific validation and configuration adjustment:
60
+ - On macOS (MPS): Auto-disables BlockSwap since unified memory makes it meaningless
61
+ - On other platforms: Validates that offload_device is properly configured
62
+
63
+ This is the single authoritative validation point for BlockSwap configuration,
64
+ called early in configure_runner() before any model loading.
65
+
66
+ Args:
67
+ block_swap_config: BlockSwap configuration dictionary (may be None)
68
+ dit_device: Target device for DiT model inference
69
+ dit_offload_device: Device for offloading DiT blocks (may be None)
70
+ debug: Debug instance for logging warnings/errors
71
+
72
+ Returns:
73
+ Validated/modified block_swap_config (may be None or modified copy)
74
+
75
+ Raises:
76
+ ValueError: If BlockSwap is enabled but offload_device is invalid (non-MPS only)
77
+ """
78
+ if not is_blockswap_enabled(block_swap_config):
79
+ return block_swap_config
80
+
81
+ blocks_to_swap = block_swap_config.get("blocks_to_swap", 0)
82
+ swap_io_components = block_swap_config.get("swap_io_components", False)
83
+
84
+ # Check for macOS unified memory - BlockSwap is meaningless there
85
+ if dit_device.type == "mps":
86
+ debug.log(
87
+ f"BlockSwap disabled: macOS uses unified memory (no separate VRAM/RAM). "
88
+ f"Ignoring blocks_to_swap={blocks_to_swap}, swap_io_components={swap_io_components}",
89
+ level="WARNING", category="blockswap", force=True
90
+ )
91
+ # Return disabled config
92
+ return {
93
+ **block_swap_config,
94
+ "blocks_to_swap": 0,
95
+ "swap_io_components": False
96
+ }
97
+
98
+ # Validate offload_device is set and different from dit_device
99
+ offload_device_valid = (
100
+ dit_offload_device is not None and
101
+ str(dit_offload_device) != str(dit_device)
102
+ )
103
+
104
+ if not offload_device_valid:
105
+ config_details = []
106
+ if blocks_to_swap > 0:
107
+ config_details.append(f"blocks_to_swap={blocks_to_swap}")
108
+ if swap_io_components:
109
+ config_details.append("swap_io_components=True")
110
+
111
+ offload_str = str(dit_offload_device) if dit_offload_device else "none"
112
+ raise ValueError(
113
+ f"BlockSwap enabled ({', '.join(config_details)}) but dit_offload_device is invalid. "
114
+ f"Current: device='{dit_device}', dit_offload_device='{offload_str}'. "
115
+ f"BlockSwap requires offload_device on the DiT Model to be set and different from device. "
116
+ f"Set --dit_offload_device cpu or disable BlockSwap."
117
+ )
118
+
119
+ return block_swap_config
120
+
121
+
122
+ # Timing helpers marked to skip torch.compile tracing
123
+ # These functions are excluded from Dynamo's graph tracing to avoid warnings
124
+ # about non-traceable builtins like time.time(), but they still execute normally
125
+ @torch._dynamo.disable
126
+ def _get_swap_start_time(debug, enabled: bool) -> Optional[float]:
127
+ """Get start time for swap operation if debug is enabled."""
128
+ return time.time() if debug and enabled else None
129
+
130
+
131
+ @torch._dynamo.disable
132
+ def _log_swap_timing(debug, t_start: Optional[float], component_id, component_type: str) -> None:
133
+ """Log swap timing if start time was captured."""
134
+ if debug and t_start is not None:
135
+ debug.log_swap_time(
136
+ component_id=component_id,
137
+ duration=time.time() - t_start,
138
+ component_type=component_type
139
+ )
140
+
141
+
142
+ def get_module_memory_mb(module: torch.nn.Module) -> float:
143
+ """
144
+ Calculate memory usage of a module in MB.
145
+
146
+ Args:
147
+ module: PyTorch module to measure
148
+
149
+ Returns:
150
+ Memory usage in megabytes
151
+ """
152
+ total_bytes = sum(
153
+ param.nelement() * param.element_size()
154
+ for param in module.parameters()
155
+ if param.data is not None
156
+ )
157
+ return total_bytes / (1024 * 1024)
158
+
159
+
160
+ def apply_block_swap_to_dit(
161
+ runner: 'VideoDiffusionInfer',
162
+ block_swap_config: Dict[str, Any],
163
+ debug: 'Debug'
164
+ ) -> None:
165
+ """
166
+ Apply block swapping configuration to a DiT model with OOM protection.
167
+
168
+ This is the main entry point for configuring block swapping on a model.
169
+ Handles block selection, I/O component offloading, device placement, and
170
+ forward method wrapping for dynamic memory management.
171
+
172
+ Args:
173
+ runner: VideoDiffusionInfer instance containing the model
174
+ block_swap_config: Configuration dictionary with keys:
175
+ - blocks_to_swap: Number of blocks to swap (from the start)
176
+ - swap_io_components: Whether to offload I/O components
177
+ - enable_debug: Whether to enable debug logging
178
+ - offload_device: Device to offload to (default: 'cpu')
179
+ debug: Debug instance for logging (required)
180
+ """
181
+ # Early return if BlockSwap not enabled
182
+ if not is_blockswap_enabled(block_swap_config):
183
+ return
184
+
185
+ blocks_to_swap = block_swap_config.get("blocks_to_swap", 0)
186
+ swap_io_components = block_swap_config.get("swap_io_components", False)
187
+
188
+ # Early return only if both block swap and I/O swap are disabled
189
+ if blocks_to_swap <= 0 and not swap_io_components:
190
+ return
191
+
192
+ if debug is None:
193
+ if hasattr(runner, 'debug') and runner.debug is not None:
194
+ debug = runner.debug
195
+ else:
196
+ raise ValueError("Debug instance must be provided to apply_block_swap_to_dit")
197
+
198
+ debug.start_timer("apply_blockswap")
199
+
200
+ # Get the actual model (handle CompatibleDiT wrapper)
201
+ model = runner.dit
202
+ if hasattr(model, "dit_model"):
203
+ model = model.dit_model
204
+
205
+ # Determine devices
206
+ if hasattr(runner, '_dit_device'):
207
+ device = runner._dit_device
208
+ else:
209
+ device = get_device()
210
+ offload_device = block_swap_config.get("offload_device", torch.device('cpu'))
211
+
212
+ # Validate model structure
213
+ if not hasattr(model, "blocks"):
214
+ debug.log("Model doesn't have 'blocks' attribute for BlockSwap", level="ERROR", category="blockswap", force=True)
215
+ return
216
+
217
+ total_blocks = len(model.blocks)
218
+
219
+ # Clamp blocks_to_swap to available blocks BEFORE logging
220
+ effective_blocks = min(blocks_to_swap, total_blocks) if blocks_to_swap > 0 else 0
221
+
222
+ # Log configuration clearly based on what's enabled
223
+ block_text = "block" if effective_blocks <= 1 else "blocks"
224
+ if effective_blocks > 0 and swap_io_components:
225
+ debug.log(f"BlockSwap: {effective_blocks}/{total_blocks} transformer {block_text} + I/O components offloaded to {str(offload_device).upper()}", category="blockswap", force=True)
226
+ elif effective_blocks > 0:
227
+ debug.log(f"BlockSwap: {effective_blocks}/{total_blocks} transformer {block_text} offloaded to {str(offload_device).upper()}", category="blockswap", force=True)
228
+ elif swap_io_components:
229
+ debug.log(f"BlockSwap: I/O components offloaded to {str(offload_device).upper()} (0/{total_blocks} blocks swapped)", category="blockswap", force=True)
230
+
231
+ # Configure model with blockswap attributes
232
+ if blocks_to_swap > 0:
233
+ model.blocks_to_swap = effective_blocks - 1 # Convert to 0-indexed
234
+ else:
235
+ # No block swapping, set to -1 so no blocks match the swap condition
236
+ model.blocks_to_swap = -1
237
+
238
+ model.main_device = device
239
+ model.offload_device = offload_device
240
+
241
+ # Configure I/O components
242
+ io_config = _configure_io_components(model, device, offload_device,
243
+ swap_io_components, debug)
244
+ memory_stats = _configure_blocks(model, device, offload_device, debug)
245
+ memory_stats['io_components'] = io_config['components']
246
+ memory_stats['io_memory_mb'] = io_config['memory_mb']
247
+ memory_stats['gpu_components'] = io_config['gpu_components']
248
+ memory_stats['io_gpu_memory_mb'] = io_config['gpu_memory_mb']
249
+
250
+ # Log memory summary
251
+ _log_memory_summary(memory_stats, offload_device, device, swap_io_components,
252
+ debug)
253
+
254
+ # Wrap block forward methods for dynamic swapping (only if blocks_to_swap > 0)
255
+ if blocks_to_swap > 0:
256
+ for b, block in enumerate(model.blocks):
257
+ if b <= model.blocks_to_swap:
258
+ _wrap_block_forward(block, b, model, debug)
259
+
260
+ # Patch RoPE modules for robust error handling
261
+ _patch_rope_for_blockswap(model, debug)
262
+
263
+ # Mark BlockSwap as active
264
+ runner._blockswap_active = True
265
+
266
+ # Store configuration for debugging and cleanup
267
+ model._block_swap_config = {
268
+ "blocks_swapped": blocks_to_swap,
269
+ "swap_io_components": swap_io_components,
270
+ "total_blocks": total_blocks,
271
+ "offload_device": offload_device,
272
+ "main_device": device,
273
+ "offload_memory": memory_stats['offload_memory'],
274
+ "main_memory": memory_stats['main_memory']
275
+ }
276
+
277
+ # Protect model from being moved entirely
278
+ _protect_model_from_move(model, runner, debug)
279
+
280
+ debug.log("BlockSwap configuration complete", category="success")
281
+ debug.end_timer("apply_blockswap", "BlockSwap configuration application")
282
+
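+ # Usage sketch (illustrative, assuming the signature apply_block_swap_to_dit(runner,
+ # block_swap_config, debug)): a typical call site with a hypothetical config:
+ #
+ #   block_swap_config = {
+ #       "blocks_to_swap": 16,
+ #       "swap_io_components": True,
+ #       "offload_device": torch.device("cpu"),
+ #   }
+ #   apply_block_swap_to_dit(runner, block_swap_config, debug=runner.debug)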
283
+
284
+ def _configure_io_components(
285
+ model: torch.nn.Module,
286
+ device: torch.device,
287
+ offload_device: torch.device,
288
+ swap_io_components: bool,
289
+ debug: 'Debug'
290
+ ) -> Dict[str, Any]:
291
+ """
292
+ Configure I/O component placement and wrapping with memory tracking.
293
+
294
+ Handles all non-block modules (embeddings, normalization layers, etc.) by
295
+ either keeping them on GPU or offloading them with dynamic swapping wrappers.
296
+
297
+ Args:
298
+ model: DiT model containing named children to configure
299
+ device: Main computation device (typically GPU)
300
+ offload_device: Device for offloaded components (typically CPU)
301
+ swap_io_components: If True, offload I/O components with dynamic swapping
302
+ debug: Debug instance for logging (required)
303
+
304
+ Returns:
305
+ Dictionary containing:
306
+ - components: List of offloaded component names
307
+ - memory_mb: Total memory of offloaded components in MB
308
+ - gpu_components: List of components remaining on GPU
309
+ - gpu_memory_mb: Total memory of GPU components in MB
310
+ """
311
+ io_components_offloaded = []
312
+ io_components_on_gpu = []
313
+ io_memory_mb = 0.0
314
+ io_gpu_memory_mb = 0.0
315
+
316
+ # Handle I/O modules with dynamic swapping
317
+ for name, module in model.named_children():
318
+ if name != "blocks":
319
+ module_memory = get_module_memory_mb(module)
320
+
321
+ if swap_io_components:
322
+ module.to(offload_device)
323
+ _wrap_io_forward(module, name, model, debug)
324
+ io_components_offloaded.append(name)
325
+ io_memory_mb += module_memory
326
+ debug.log(f"{name} → {str(offload_device).upper()} ({module_memory:.2f}MB, dynamic swapping)", category="blockswap", indent_level=1)
327
+ else:
328
+ module.to(device)
329
+ io_components_on_gpu.append(name)
330
+ io_gpu_memory_mb += module_memory
331
+ debug.log(f"{name} → {str(device).upper()} ({module_memory:.2f}MB)", category="blockswap", indent_level=1)
332
+
333
+ return {
334
+ 'components': io_components_offloaded,
335
+ 'memory_mb': io_memory_mb,
336
+ 'gpu_components': io_components_on_gpu,
337
+ 'gpu_memory_mb': io_gpu_memory_mb
338
+ }
339
+
340
+
341
+ def _configure_blocks(
342
+ model: torch.nn.Module,
343
+ device: torch.device,
344
+ offload_device: torch.device,
345
+ debug: 'Debug'
346
+ ) -> Dict[str, float]:
347
+ """
348
+ Configure transformer block placement and calculate memory statistics.
349
+
350
+ Moves blocks to their designated devices based on model.blocks_to_swap
351
+ attribute. Blocks with index <= blocks_to_swap go to offload device,
352
+ others stay on main device.
353
+
354
+ Args:
355
+ model: DiT model with blocks attribute and blocks_to_swap configured
356
+ device: Main computation device for non-swapped blocks
357
+ offload_device: Device for swapped blocks
358
+ debug: Debug instance for logging (required)
359
+
360
+ Returns:
361
+ Dictionary containing:
362
+ - offload_memory: Total memory of offloaded blocks in MB
363
+ - main_memory: Total memory of blocks on main device in MB
364
+ - io_components: Empty list (populated by caller)
365
+ """
366
+ total_offload_memory = 0.0
367
+ total_main_memory = 0.0
368
+
369
+ # Move blocks based on swap configuration
370
+ for b, block in enumerate(model.blocks):
371
+ block_memory = get_module_memory_mb(block)
372
+
373
+ if b > model.blocks_to_swap:
374
+ block.to(device)
375
+ total_main_memory += block_memory
376
+ else:
377
+ block.to(offload_device, non_blocking=False)
378
+ total_offload_memory += block_memory
379
+
380
+ # Ensure all buffers match their containing module's device
381
+ for b, block in enumerate(model.blocks):
382
+ target_device = device if b > model.blocks_to_swap else offload_device
383
+ for name, buffer in block.named_buffers():
384
+ if buffer.device != torch.device(target_device):
385
+ buffer.data = buffer.data.to(target_device, non_blocking=False)
386
+
387
+ return {
388
+ "offload_memory": total_offload_memory,
389
+ "main_memory": total_main_memory,
390
+ "io_components": [] # Will be populated by caller
391
+ }
392
+
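+ # Worked example of the placement rule above: with blocks_to_swap=16 in the user
+ # config, apply_block_swap_to_dit stores model.blocks_to_swap = 15 (0-indexed), so
+ # blocks 0-15 move to the offload device and blocks 16..total_blocks-1 stay on the
+ # main device (the `b > model.blocks_to_swap` branch).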
393
+
394
+ def _log_memory_summary(
395
+ memory_stats: Dict[str, float],
396
+ offload_device: torch.device,
397
+ device: torch.device,
398
+ swap_io_components: bool,
399
+ debug: 'Debug'
400
+ ) -> None:
401
+ """
402
+ Log comprehensive memory usage summary for BlockSwap configuration.
403
+
404
+ Displays detailed breakdown of memory distribution across devices,
405
+ including transformer blocks and I/O components.
406
+
407
+ Args:
408
+ memory_stats: Dictionary containing:
409
+ - offload_memory: Memory offloaded from blocks (MB)
410
+ - main_memory: Memory remaining on main device (MB)
411
+ - io_memory_mb: Memory from offloaded I/O components (MB)
412
+ - io_gpu_memory_mb: Memory from I/O components on GPU (MB)
413
+ offload_device: Device used for offloading
414
+ device: Main computation device
415
+ swap_io_components: Whether I/O components are being swapped
416
+ debug: Debug instance for logging (required)
417
+ """
418
+ debug.log("BlockSwap memory configuration:", category="blockswap")
419
+
420
+ # Log transformer blocks memory
421
+ blocks_offloaded = memory_stats['offload_memory']
422
+ blocks_on_gpu = memory_stats['main_memory']
423
+
424
+ offload_str = str(offload_device)
425
+ device_str = str(device)
426
+
427
+ if blocks_on_gpu == 0:
428
+ debug.log(f"Transformer blocks: {blocks_offloaded:.2f}MB on {offload_str} (dynamic swapping)", category="blockswap", indent_level=1)
429
+ else:
430
+ debug.log(f"Transformer blocks: {blocks_on_gpu:.2f}MB on {device_str}, {blocks_offloaded:.2f}MB on {offload_str}", category="blockswap", indent_level=1)
431
+
432
+ # Always log I/O components (whether swapping or not)
433
+ io_memory = memory_stats.get('io_memory_mb', 0.0)
434
+ io_gpu_memory = memory_stats.get('io_gpu_memory_mb', 0.0)
435
+
436
+ if swap_io_components and io_memory > 0:
437
+ io_components = memory_stats.get('io_components', [])
438
+ debug.log(f"I/O components: {io_memory:.2f}MB on {offload_str} (dynamic swapping)", category="blockswap", indent_level=1)
439
+ debug.log(f"{', '.join(io_components)}", category="blockswap", indent_level=2)
440
+ elif io_gpu_memory > 0:
441
+ io_gpu_components = memory_stats.get('gpu_components', [])
442
+ debug.log(f"I/O components: {io_gpu_memory:.2f}MB on {device_str}", category="blockswap", indent_level=1)
443
+ debug.log(f"{', '.join(io_gpu_components)}", category="blockswap", indent_level=2)
444
+
445
+ # Log total VRAM savings
446
+ total_offloaded = blocks_offloaded + (io_memory if swap_io_components else 0)
447
+ if total_offloaded > 0:
448
+ debug.log(f"Total VRAM saved: {total_offloaded:.2f}MB (~{total_offloaded/1024:.2f}GB)", category="blockswap", indent_level=1)
449
+
450
+
451
+ def _wrap_block_forward(
452
+ block: torch.nn.Module,
453
+ block_idx: int,
454
+ model: torch.nn.Module,
455
+ debug: 'Debug'
456
+ ) -> None:
457
+ """
458
+ Wrap individual transformer block forward for dynamic device swapping.
459
+
460
+ Creates a wrapped forward method that automatically:
461
+ 1. Moves block to GPU before computation
462
+ 2. Executes original forward pass
463
+ 3. Moves block back to offload device after computation
464
+ 4. Logs timing and manages memory pressure
465
+
466
+ Uses weak references to prevent memory leaks from closure retention.
467
+
468
+ Args:
469
+ block: Individual transformer block to wrap
470
+ block_idx: Index of this block in model.blocks
471
+ model: Parent DiT model (used for device references)
472
+ debug: Debug instance for logging (required)
473
+ """
474
+ if hasattr(block, '_original_forward'):
475
+ return # Already wrapped
476
+
477
+ # Store original forward method
478
+ original_forward = block.forward
479
+
480
+ # Create weak references
481
+ model_ref = weakref.ref(model)
482
+ debug_ref = weakref.ref(debug)
483
+
484
+ # Store block_idx on the block itself to avoid closure issues
485
+ block._block_idx = block_idx
486
+
487
+ def wrapped_forward(self, *args, **kwargs):
488
+ # Retrieve weak references
489
+ model = model_ref()
490
+ debug = debug_ref()
491
+
492
+ if not model:
493
+ # Model has been garbage collected, fall back to original
494
+ return original_forward(*args, **kwargs)
495
+
496
+ # Check if block swap is active for this block
497
+ if hasattr(model, 'blocks_to_swap') and self._block_idx <= model.blocks_to_swap:
498
+ # Use dynamo-disabled helper to get start time (avoids compilation warnings)
499
+ t_start = _get_swap_start_time(debug, debug.enabled if debug else False)
500
+
501
+ # Only move to GPU if necessary
502
+ current_device = next(self.parameters()).device
503
+ target_device = torch.device(model.main_device)
504
+
505
+ if current_device != target_device:
506
+ self.to(model.main_device, non_blocking=False)
507
+
508
+ # Execute the original forward pass
509
+ output = original_forward(*args, **kwargs)
510
+
511
+ # Move back to offload device
512
+ self.to(model.offload_device, non_blocking=False)
513
+
514
+ # Use dynamo-disabled helper to log timing (avoids compilation warnings)
515
+ _log_swap_timing(debug, t_start, self._block_idx, "block")
516
+
517
+ # Only clear cache under memory pressure
518
+ clear_memory(debug=debug, deep=False, force=False, timer_name="wrap_block_forward")
519
+ else:
520
+ output = original_forward(*args, **kwargs)
521
+
522
+ return output
523
+
524
+ # Bind the wrapped function as a method to the block
525
+ block.forward = types.MethodType(wrapped_forward, block)
526
+
527
+ # Store reference to original forward for cleanup
528
+ block._original_forward = original_forward
529
+
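+ # Toy illustration of the same swap-on-forward pattern on a plain module (a sketch,
+ # not part of this file):
+ #
+ #   import torch, types
+ #   layer = torch.nn.Linear(8, 8)          # parameters live on CPU by default
+ #   _orig = layer.forward
+ #   def _swapping_forward(self, x):
+ #       self.to(x.device)                  # hop to the input's device
+ #       out = _orig(x)                     # run the original computation
+ #       self.to('cpu')                     # hop back to the offload device
+ #       return out
+ #   layer.forward = types.MethodType(_swapping_forward, layer)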
530
+
531
+ def _wrap_io_forward(
532
+ module: torch.nn.Module,
533
+ module_name: str,
534
+ model: torch.nn.Module,
535
+ debug: 'Debug'
536
+ ) -> None:
537
+ """
538
+ Wrap I/O component forward for dynamic device swapping.
539
+
540
+ Similar to _wrap_block_forward but for I/O components (embeddings,
541
+ normalization layers, etc.). Handles swapping between GPU and CPU
542
+ during forward passes.
543
+
544
+ Uses weak references to prevent circular dependencies and memory leaks.
545
+
546
+ Args:
547
+ module: I/O component module to wrap
548
+ module_name: Name identifier for logging (e.g., 'x_embedder')
549
+ model: Parent DiT model (used for device references)
550
+ debug: Debug instance for logging (required)
551
+ """
552
+ if hasattr(module, '_is_io_wrapped') and module._is_io_wrapped:
553
+ debug.log(f"Reusing existing I/O wrapper for {module_name}", category="reuse")
554
+ return # Already wrapped
555
+
556
+ # Store original forward method
557
+ original_forward = module.forward
558
+
559
+ # Create weak references
560
+ model_ref = weakref.ref(model)
561
+ debug_ref = weakref.ref(debug) if debug else lambda: None
562
+
563
+ # Store module name on the module itself
564
+ module._module_name = module_name
565
+ module._original_forward = original_forward
566
+
567
+ def wrapped_io_forward(self, *args, **kwargs):
568
+ # Retrieve weak references
569
+ model = model_ref()
570
+ debug = debug_ref()
571
+
572
+ if not model:
573
+ # Model has been garbage collected, fall back to original
574
+ return self._original_forward(*args, **kwargs)
575
+
576
+ # Use dynamo-disabled helper to get start time (avoids compilation warnings)
577
+ t_start = _get_swap_start_time(debug, debug.enabled if debug else False)
578
+
579
+ # Check current device to avoid unnecessary moves
580
+ current_device = next(self.parameters()).device
581
+ target_device = torch.device(model.main_device)
582
+
583
+ # Move to GPU for computation if needed
584
+ if current_device != target_device:
585
+ self.to(model.main_device, non_blocking=False)
586
+
587
+ # Execute forward pass
588
+ output = self._original_forward(*args, **kwargs)
589
+
590
+ # Move back to offload device
591
+ self.to(model.offload_device, non_blocking=False)
592
+
593
+ # Use dynamo-disabled helper to log timing (avoids compilation warnings)
594
+ _log_swap_timing(debug, t_start, self._module_name, "I/O")
595
+
596
+ # Only clear cache under memory pressure
597
+ clear_memory(debug=debug, deep=False, force=False, timer_name="wrap_io_forward")
598
+
599
+ return output
600
+
601
+ # Bind as a method
602
+ module.forward = types.MethodType(wrapped_io_forward, module)
603
+ module._is_io_wrapped = True
604
+
605
+ # Store module reference for restoration
606
+ if not hasattr(model, '_io_swappers'):
607
+ model._io_swappers = []
608
+ model._io_swappers.append((module, module_name))
609
+
610
+
611
+ def _patch_rope_for_blockswap(
612
+ model: torch.nn.Module,
613
+ debug: 'Debug'
614
+ ) -> None:
615
+ """
616
+ Patch RoPE (Rotary Position Embedding) modules for device-aware fallback.
617
+
618
+ Adds CPU fallback logic to RoPE modules to handle device mismatch errors
619
+ that can occur during BlockSwap operations. Complements the stability
620
+ wrapper from compatibility.py with device-specific error handling.
621
+
622
+ Args:
623
+ model: DiT model containing RoPE modules to patch
624
+ debug: Debug instance for logging (required)
625
+ """
626
+ rope_patches = []
627
+
628
+ for name, module in model.named_modules():
629
+ if "rope" in name.lower() and hasattr(module, "get_axial_freqs"):
630
+ # Skip if already wrapped by blockswap
631
+ if hasattr(module, '_blockswap_wrapped') and module._blockswap_wrapped:
632
+ continue
633
+
634
+ # Get current method (might be stability-wrapped)
635
+ current_method = module.get_axial_freqs
636
+
637
+ # Create device-aware wrapper with proper closure handling
638
+ def make_device_aware_wrapper(module_name, current_fn):
639
+ def device_aware_rope_wrapper(self, *args, **kwargs):
640
+ try:
641
+ # Try current method (original or stability-wrapped)
642
+ return current_fn(*args, **kwargs)
643
+ except (RuntimeError, torch.cuda.OutOfMemoryError) as e:
644
+ error_msg = str(e).lower()
645
+ # Only handle device/memory specific errors
646
+ if any(x in error_msg for x in ["device", "memory", "allocation"]):
647
+ debug.log(f"RoPE OOM for {module_name}", level="WARNING", category="rope", force=True)
648
+ debug.log(f"Clearing RoPE cache and retrying", category="info", force=True)
649
+
650
+ # Get current device from parameters
651
+ try:
652
+ current_device = next(self.parameters()).device
653
+ except StopIteration:
654
+ # Fallback: use model's main_device if BlockSwap has set it, else offload_device, else CPU
654
+ if hasattr(model, 'main_device'):
655
+ current_device = torch.device(model.main_device)
656
+ elif hasattr(model, 'offload_device'):
657
+ current_device = torch.device(model.offload_device)
+ else:
+ current_device = torch.device('cpu')  # last-resort default so current_device is always bound
659
+
660
+ # Try clearing cache first (non-invasive fix)
661
+ if hasattr(current_fn, 'cache_clear'):
662
+ current_fn.cache_clear()
663
+ try:
664
+ # Retry on same device after clearing cache
665
+ return current_fn(*args, **kwargs)
666
+ except Exception as retry_error:
667
+ # Cache clear wasn't enough, need more drastic measures
668
+ debug.log(f"Cache clear insufficient for {module_name}, falling back to CPU", level="WARNING", category="rope", force=True)
669
+
670
+ # Fallback to CPU computation with stability
671
+ self.cpu()
672
+
673
+ try:
674
+ # Use call_rope_with_stability for CPU computation
675
+ # This ensures cache is cleared and autocast disabled
676
+ original_fn = getattr(self, '_original_get_axial_freqs', current_fn)
677
+ result = call_rope_with_stability(original_fn, *args, **kwargs)
678
+
679
+ # Move module back to original device
680
+ self.to(current_device)
681
+
682
+ # Move result to appropriate device if it's a tensor
683
+ if hasattr(result, 'to'):
684
+ target_device = args[0].device if len(args) > 0 and hasattr(args[0], 'device') else current_device
685
+ return result.to(target_device)
686
+ return result
687
+
688
+ except Exception as cpu_error:
689
+ # Always restore device even on error
690
+ self.to(current_device)
691
+ raise cpu_error
692
+ else:
693
+ # Not a device error, let it bubble up
694
+ raise
695
+
696
+ return device_aware_rope_wrapper
697
+
698
+ # Apply wrapper
699
+ module.get_axial_freqs = types.MethodType(
700
+ make_device_aware_wrapper(name, current_method),
701
+ module
702
+ )
703
+ module._blockswap_wrapped = True
704
+
705
+ # Store for cleanup (use original or previously stored)
706
+ original_method = getattr(module, '_original_get_axial_freqs', current_method)
707
+ rope_patches.append((module, original_method))
708
+
709
+ if rope_patches:
710
+ model._rope_patches = rope_patches
711
+ debug.log(f"Patched {len(rope_patches)} RoPE modules with device handling", category="success")
712
+
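+ # Background note: the `cache_clear` attribute probed in the wrapper above is the
+ # standard hook that functools.lru_cache attaches to a decorated callable, e.g.
+ # (hypothetical function):
+ #
+ #   from functools import lru_cache
+ #
+ #   @lru_cache(maxsize=None)
+ #   def axial_freqs(*dims): ...
+ #
+ #   axial_freqs.cache_clear()   # drops every cached frequency tensor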
713
+
714
+ def _protect_model_from_move(
715
+ model: torch.nn.Module,
716
+ runner: 'VideoDiffusionInfer',
717
+ debug: 'Debug'
718
+ ) -> None:
719
+ """
720
+ Protect model from unintended full device movement during BlockSwap.
721
+
722
+ Wraps model.to() method to prevent other code from accidentally moving
723
+ the entire model to GPU, which would defeat BlockSwap's memory savings.
724
+ Allows movement only when explicitly bypassed via model flag.
725
+
726
+ Args:
727
+ model: DiT model to protect
728
+ runner: VideoDiffusionInfer instance (for active status check)
729
+ debug: Debug instance for logging (required)
730
+ """
731
+ if not hasattr(model, '_original_to'):
732
+ # Store runner reference as weak reference to avoid circular refs
733
+ model._blockswap_runner_ref = weakref.ref(runner)
734
+ model._original_to = model.to
735
+
736
+ # Define the protected method without closures
737
+ def protected_model_to(self, device, *args, **kwargs):
738
+ # Check if protection is temporarily bypassed for offloading
739
+ # Flag is stored on model itself (not runner) to survive runner recreation
740
+ if getattr(self, "_blockswap_bypass_protection", False):
741
+ # Protection bypassed, allow movement
742
+ if hasattr(self, '_original_to'):
743
+ return self._original_to(device, *args, **kwargs)
744
+
745
+ # Get configured offload device directly from model
746
+ blockswap_offload_device = "cpu" # default
747
+ if hasattr(self, "_block_swap_config"):
748
+ blockswap_offload_device = self._block_swap_config.get("offload_device", "cpu")
749
+
750
+ # Check if BlockSwap is currently active via runner weak reference
751
+ runner_ref = getattr(self, '_blockswap_runner_ref', None)
752
+ blockswap_is_active = False
753
+ if runner_ref:
754
+ runner_obj = runner_ref()
755
+ if runner_obj and hasattr(runner_obj, "_blockswap_active"):
756
+ blockswap_is_active = runner_obj._blockswap_active
757
+
758
+ # Block attempts to move model away from configured offload device when active
759
+ if blockswap_is_active and str(device) != str(blockswap_offload_device):
760
+ # Get debug instance from runner if available
761
+ debug_instance = None
762
+ if runner_ref:
763
+ runner_obj = runner_ref()
764
+ if runner_obj and hasattr(runner_obj, 'debug'):
765
+ debug_instance = runner_obj.debug
766
+
767
+ if debug_instance:
768
+ debug_instance.log(
769
+ f"Blocked attempt to move BlockSwap model from {blockswap_offload_device} to {device}",
770
+ level="WARNING", category="blockswap", force=True
771
+ )
772
+ return self
773
+
774
+ # Allow movement (either bypass is enabled or target is offload device)
775
+ if hasattr(self, '_original_to'):
776
+ return self._original_to(device, *args, **kwargs)
777
+ else:
778
+ # Fallback - shouldn't happen
779
+ return super(type(self), self).to(device, *args, **kwargs)
780
+
781
+ # Bind as a method to the model instance
782
+ model.to = types.MethodType(protected_model_to, model)
783
+
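+ # Resulting behaviour (illustrative, BlockSwap active with offload_device='cpu'):
+ #
+ #   model.to('cuda:0')   # blocked: a warning is logged and `model` is returned unchanged
+ #   model.to('cpu')      # allowed: target matches the configured offload device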
784
+
785
+ def set_blockswap_bypass(runner, bypass: bool, debug):
786
+ """
787
+ Set or unset bypass flag for BlockSwap protection.
788
+ Used for offloading to temporarily allow model movement.
789
+
790
+ Args:
791
+ runner: Runner instance with BlockSwap
792
+ bypass: True to bypass protection, False to enforce it
793
+ debug: Debug instance for logging
794
+ """
795
+ if not hasattr(runner, "_blockswap_active") or not runner._blockswap_active:
796
+ return
797
+
798
+ # Get the actual model (handle CompatibleDiT wrapper)
799
+ model = runner.dit
800
+ if hasattr(model, "dit_model"):
801
+ model = model.dit_model
802
+
803
+ # Store on model so it survives runner recreation during caching
804
+ model._blockswap_bypass_protection = bypass
805
+
806
+ if bypass:
807
+ debug.log("BlockSwap protection disabled to allow model DiT offloading", category="success")
808
+ else:
809
+ debug.log("BlockSwap protection renabled to avoid accidentally offloading the entire DiT model", category="success")
810
+
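+ # Usage sketch: bracket an intentional whole-model move with the bypass flag,
+ # assuming a hypothetical `offload_device` variable:
+ #
+ #   set_blockswap_bypass(runner, bypass=True, debug=debug)
+ #   runner.dit.to(offload_device)   # whole-model move now permitted
+ #   set_blockswap_bypass(runner, bypass=False, debug=debug)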
811
+
812
+ def cleanup_blockswap(runner, keep_state_for_cache=False):
813
+ """
814
+ Clean up BlockSwap configuration based on caching mode.
815
+
816
+ When caching (keep_state_for_cache=True):
817
+ - Keep all BlockSwap configuration intact
818
+ - Only mark as inactive for safety during non-inference operations
819
+
820
+ When not caching (keep_state_for_cache=False):
821
+ - Full cleanup of all BlockSwap state
822
+
823
+ Args:
824
+ runner: VideoDiffusionInfer instance to clean up
825
+ keep_state_for_cache: If True, preserve BlockSwap state for reuse
826
+ """
827
+ # Get debug instance from runner
828
+ if not hasattr(runner, 'debug') or runner.debug is None:
829
+ raise ValueError("Debug instance must be available on runner for cleanup_blockswap")
830
+
831
+ debug = runner.debug
832
+
833
+ # Get the actual model (handle CompatibleDiT wrapper)
834
+ model = runner.dit
835
+ if hasattr(model, "dit_model"):
836
+ model = model.dit_model
837
+
838
+ # Check if there's any BlockSwap state to clean up (check both runner and model)
839
+ has_blockswap_state = (
840
+ hasattr(runner, "_blockswap_active") or
841
+ hasattr(model, "_block_swap_config") or
842
+ hasattr(model, "_blockswap_bypass_protection")
843
+ )
844
+
845
+ if not has_blockswap_state:
846
+ return
847
+
848
+ debug.log("Starting BlockSwap cleanup", category="cleanup")
849
+
850
+ if keep_state_for_cache:
851
+ # Minimal cleanup for caching - just mark as inactive and allow offloading
852
+ # Everything else stays intact for fast reactivation
853
+ if hasattr(runner, "_blockswap_active") and runner._blockswap_active:
854
+ if not getattr(model, "_blockswap_bypass_protection", False):
855
+ set_blockswap_bypass(runner=runner, bypass=True, debug=debug)
856
+ runner._blockswap_active = False
857
+ debug.log("BlockSwap deactivated for caching (configuration preserved)", category="success")
858
+ return
859
+
860
+ # Full cleanup when not caching
861
+ # Get the actual model (handle CompatibleDiT wrapper)
862
+ model = runner.dit
863
+ if hasattr(model, "dit_model"):
864
+ model = model.dit_model
865
+
866
+ # 1. Restore block forward methods
867
+ if hasattr(model, 'blocks'):
868
+ restored_count = 0
869
+ for block in model.blocks:
870
+ if hasattr(block, '_original_forward'):
871
+ block.forward = block._original_forward
872
+ delattr(block, '_original_forward')
873
+ restored_count += 1
874
+
875
+ # Clean up wrapper attributes
876
+ for attr in ['_block_idx', '_model_ref', '_debug_ref', '_blockswap_wrapped']:
877
+ if hasattr(block, attr):
878
+ delattr(block, attr)
879
+
880
+ if restored_count > 0:
881
+ debug.log(f"Restored {restored_count} block forward methods", category="success")
882
+
883
+ # 2. Restore RoPE patches
884
+ if hasattr(model, '_rope_patches'):
885
+ for module, original_method in model._rope_patches:
886
+ module.get_axial_freqs = original_method
887
+ # Clean up wrapper attributes
888
+ for attr in ['_rope_wrapped', '_original_get_axial_freqs']:
889
+ if hasattr(module, attr):
890
+ delattr(module, attr)
891
+ debug.log(f"Restored {len(model._rope_patches)} RoPE methods", category="success")
892
+ delattr(model, '_rope_patches')
893
+
894
+ # 3. Restore I/O component forward methods and move to offload device
895
+ if hasattr(model, '_io_swappers'):
896
+ for module, module_name in model._io_swappers:
897
+ if hasattr(module, '_original_forward'):
898
+ module.forward = module._original_forward
899
+ # Clean up wrapper attributes
900
+ for attr in ['_original_forward', '_model_ref', '_debug_ref',
901
+ '_module_name', '_is_io_wrapped']:
902
+ if hasattr(module, attr):
903
+ delattr(module, attr)
904
+ debug.log(f"Restored {len(model._io_swappers)} I/O components", category="success")
905
+ delattr(model, '_io_swappers')
906
+
907
+ # Move all IO components to offload device during full cleanup
908
+ if hasattr(model, 'offload_device'):
909
+ offload_device = model.offload_device
910
+ moved_count = 0
911
+ for name, module in model.named_children():
912
+ if name != "blocks":
913
+ module.to(offload_device)
914
+ moved_count += 1
915
+ if moved_count > 0:
916
+ debug.log(f"Moved {moved_count} IO components to offload device", category="success")
917
+
918
+ # 4. Restore original .to() method
919
+ if hasattr(model, '_original_to'):
920
+ model.to = model._original_to
921
+ delattr(model, '_original_to')
922
+ debug.log("Restored original .to() method", category="success")
923
+
924
+ # 5. Clean up BlockSwap-specific attributes
925
+ for attr in ['_blockswap_runner_ref', 'blocks_to_swap', 'main_device',
926
+ 'offload_device']:
927
+ if hasattr(model, attr):
928
+ delattr(model, attr)
929
+
930
+ # 6. Clean up runner attributes
931
+ runner._blockswap_active = False
932
+
933
+ # Remove all config attributes
934
+ for attr in ['_cached_blockswap_config', '_block_swap_config', '_blockswap_debug']:
935
+ if hasattr(runner, attr):
936
+ delattr(runner, attr)
937
+
938
+ debug.log("BlockSwap cleanup complete", category="success")
src/optimization/memory_manager.py ADDED
@@ -0,0 +1,1285 @@
1
+ """
2
+ Memory management module for SeedVR2
3
+ Handles VRAM usage, cache management, and memory optimization
4
+
5
+ Extracted from: seedvr2.py (lines 373-405, 607-626, 1016-1044)
6
+ """
7
+
8
+ import torch
9
+ import gc
10
+ import sys
11
+ import time
12
+ import psutil
13
+ import platform
14
+ from typing import Tuple, Dict, Any, Optional, List, Union
15
+
16
+
17
+ def _device_str(device: Union[torch.device, str]) -> str:
18
+ """Normalized uppercase device string for comparison and logging. MPS variants → 'MPS'."""
19
+ s = str(device).upper()
20
+ return 'MPS' if s.startswith('MPS') else s
21
+
22
+
23
+ def is_mps_available() -> bool:
24
+ """Check if MPS (Apple Metal) backend is available."""
25
+ return hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
26
+
27
+
28
+ def is_cuda_available() -> bool:
29
+ """Check if CUDA backend is available."""
30
+ return torch.cuda.is_available()
31
+
32
+
33
+ def get_gpu_backend() -> str:
34
+ """Get the active GPU backend type.
35
+
36
+ Returns:
37
+ 'cuda': NVIDIA CUDA
38
+ 'mps': Apple Metal Performance Shaders
39
+ 'cpu': No GPU backend available
40
+ """
41
+ if is_cuda_available():
42
+ return 'cuda'
43
+ if is_mps_available():
44
+ return 'mps'
45
+ return 'cpu'
46
+
47
+
48
+ def get_device_list(include_none: bool = False, include_cpu: bool = False) -> List[str]:
49
+ """
50
+ Get list of available compute devices for SeedVR2
51
+
52
+ Args:
53
+ include_none: If True, prepend "none" to the device list (for offload options)
54
+ include_cpu: If True, include "cpu" in the device list (for offload options only)
55
+ Note: On MPS-only systems, "cpu" is automatically excluded since
56
+ unified memory architecture makes CPU offloading meaningless
57
+
58
+ Returns:
59
+ List of device strings (e.g., ["cuda:0", "cuda:1"] or ["none", "cpu", "cuda:0", "cuda:1"])
60
+ """
61
+ devs = []
62
+ has_cuda = False
63
+ has_mps = False
64
+
65
+ try:
66
+ if is_cuda_available():
67
+ devs += [f"cuda:{i}" for i in range(torch.cuda.device_count())]
68
+ has_cuda = True
69
+ except Exception:
70
+ pass
71
+
72
+ try:
73
+ if is_mps_available():
74
+ devs.append("mps") # MPS doesn't use device indices
75
+ has_mps = True
76
+ except Exception:
77
+ pass
78
+
79
+ # Build result list with optional prefixes
80
+ result = []
81
+ if include_none:
82
+ result.append("none")
83
+
84
+ # Only include "cpu" option if:
85
+ # 1. It was requested (include_cpu=True), AND
86
+ # 2. Either CUDA is available OR MPS is not the only option
87
+ # Rationale: On MPS-only systems with unified memory architecture,
88
+ # CPU offloading is semantically meaningless as CPU and GPU share the same memory pool
89
+ if include_cpu and (has_cuda or not has_mps):
90
+ result.append("cpu")
91
+
92
+ result.extend(devs)
93
+
94
+ return result if result else []
95
+
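+ # Example return values (hypothetical machines):
+ #   two-GPU CUDA host:  get_device_list() -> ["cuda:0", "cuda:1"]
+ #   same host, offload: get_device_list(include_none=True, include_cpu=True)
+ #                       -> ["none", "cpu", "cuda:0", "cuda:1"]
+ #   Apple Silicon:      get_device_list(include_cpu=True) -> ["mps"] (cpu excluded by design)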
96
+
97
+ def get_basic_vram_info(device: Optional[torch.device] = None) -> Dict[str, Any]:
98
+ """
99
+ Get basic VRAM availability info (free and total memory).
100
+ Used for capacity planning and initial checks.
101
+
102
+ Args:
103
+ device: Optional device to query. If None, uses cuda:0
104
+
105
+ Returns:
106
+ dict: {"free_gb": float, "total_gb": float} or {"error": str}
107
+ """
108
+ try:
109
+ if is_cuda_available():
110
+ if device is None:
111
+ device = torch.device("cuda:0")
112
+ elif not isinstance(device, torch.device):
113
+ device = torch.device(device)
114
+ free_memory, total_memory = torch.cuda.mem_get_info(device)
115
+ elif is_mps_available():
116
+ # MPS doesn't support per-device queries or mem_get_info
117
+ # Use system memory as proxy
118
+ mem = psutil.virtual_memory()
119
+ free_memory = mem.total - mem.used
120
+ total_memory = mem.total
121
+ else:
122
+ return {"error": "No GPU backend available (CUDA/MPS)"}
123
+
124
+ return {
125
+ "free_gb": free_memory / (1024**3),
126
+ "total_gb": total_memory / (1024**3)
127
+ }
128
+ except Exception as e:
129
+ return {"error": f"Failed to get memory info: {str(e)}"}
130
+
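+ # Example results (hypothetical values):
+ #   24GB CUDA card: get_basic_vram_info() -> {"free_gb": 21.4, "total_gb": 24.0}
+ #   CPU-only host:  get_basic_vram_info() -> {"error": "No GPU backend available (CUDA/MPS)"}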
131
+
132
+ # Initial VRAM check at module load
133
+ vram_info = get_basic_vram_info(device=None)
134
+ if "error" not in vram_info:
135
+ backend = "MPS" if is_mps_available() else "CUDA"
136
+ print(f"📊 Initial {backend} memory: {vram_info['free_gb']:.2f}GB free / {vram_info['total_gb']:.2f}GB total")
137
+ else:
138
+ print(f"⚠️ Memory check failed: {vram_info['error']} - No available backend!")
139
+
140
+
141
+ def get_vram_usage(device: Optional[torch.device] = None, debug: Optional['Debug'] = None) -> Tuple[float, float, float, float]:
142
+ """
143
+ Get current VRAM usage metrics for monitoring.
144
+ Used for tracking memory consumption during processing.
145
+
146
+ Args:
147
+ device: Optional device to query. If None, uses cuda:0
148
+ debug: Optional debug instance for logging
149
+
150
+ Returns:
151
+ tuple: (allocated_gb, reserved_gb, peak_allocated_gb, peak_reserved_gb)
152
+ Returns (0, 0, 0, 0) if no GPU available
153
+ """
154
+ try:
155
+ if is_cuda_available():
156
+ if device is None:
157
+ device = torch.device("cuda:0")
158
+ elif not isinstance(device, torch.device):
159
+ device = torch.device(device)
160
+ allocated = torch.cuda.memory_allocated(device) / (1024**3)
161
+ reserved = torch.cuda.memory_reserved(device) / (1024**3)
162
+ peak_allocated = torch.cuda.max_memory_allocated(device) / (1024**3)
163
+ peak_reserved = torch.cuda.max_memory_reserved(device) / (1024**3)
164
+ return allocated, reserved, peak_allocated, peak_reserved
165
+ elif is_mps_available():
166
+ # MPS doesn't support per-device queries - uses global memory tracking
167
+ allocated = torch.mps.current_allocated_memory() / (1024**3)
168
+ reserved = torch.mps.driver_allocated_memory() / (1024**3)
169
+ # MPS doesn't track peak separately
170
+ return allocated, reserved, allocated, reserved
171
+ except Exception as e:
172
+ if debug:
173
+ debug.log(f"Failed to get VRAM usage: {e}", level="WARNING", category="memory", force=True)
174
+ return 0.0, 0.0, 0.0, 0.0
175
+
176
+
177
+ def get_ram_usage(debug: Optional['Debug'] = None) -> Tuple[float, float, float, float]:
178
+ """
179
+ Get current RAM usage metrics for the current process.
180
+ Provides accurate tracking of process-specific memory consumption.
181
+
182
+ Args:
183
+ debug: Optional debug instance for logging
184
+
185
+ Returns:
186
+ tuple: (process_gb, available_gb, total_gb, used_by_others_gb)
187
+ Returns (0, 0, 0, 0) if psutil not available or on error
188
+ """
189
+ try:
190
+ if not psutil:
191
+ return 0.0, 0.0, 0.0, 0.0
192
+
193
+ # Get current process memory
194
+ process = psutil.Process()
195
+ process_memory = process.memory_info()
196
+ process_gb = process_memory.rss / (1024**3)
197
+
198
+ # Get system memory
199
+ sys_memory = psutil.virtual_memory()
200
+ total_gb = sys_memory.total / (1024**3)
201
+ available_gb = sys_memory.available / (1024**3)
202
+
203
+ # Calculate memory used by other processes
204
+ # This is the CORRECT calculation:
205
+ total_used_gb = total_gb - available_gb # Total memory used by ALL processes
206
+ used_by_others_gb = max(0, total_used_gb - process_gb) # Subtract current process
207
+
208
+ return process_gb, available_gb, total_gb, used_by_others_gb
209
+
210
+ except Exception as e:
211
+ if debug:
212
+ debug.log(f"Failed to get RAM usage: {e}", level="WARNING", category="memory", force=True)
213
+ return 0.0, 0.0, 0.0, 0.0
214
+
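+ # Worked example (hypothetical numbers): 32GB total, 20GB available, process RSS 6GB:
+ #   total_used_gb     = 32 - 20 = 12
+ #   used_by_others_gb = 12 - 6  = 6   -> returns (6.0, 20.0, 32.0, 6.0)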
215
+
216
+ # Global cache for OS libraries (initialized once)
217
+ _os_memory_lib = None
218
+
219
+
220
+ def clear_memory(debug: Optional['Debug'] = None, deep: bool = False, force: bool = True,
221
+ timer_name: Optional[str] = None) -> None:
222
+ """
223
+ Clear memory caches with two-tier approach for optimal performance.
224
+
225
+ Args:
226
+ debug: Debug instance for logging (optional)
227
+ force: If True, always clear. If False, only clear when <5% free
228
+ deep: If True, perform deep cleanup including GC and OS operations.
229
+ If False (default), only perform minimal GPU cache clearing.
230
+ timer_name: Optional suffix for timer names to make them unique per invocation
231
+
232
+ Two-tier approach:
233
+ - Minimal mode (deep=False): GPU cache operations (~1-5ms)
234
+ Used for frequent calls during batch processing
235
+ - Deep mode (deep=True): Complete cleanup with GC and OS operations (~10-50ms)
236
+ Used at key points like model switches or final cleanup
237
+ """
238
+ global _os_memory_lib
239
+
240
+ # Create unique timer names if suffix provided
241
+ if timer_name:
242
+ main_timer = f"memory_clear_{timer_name}"
243
+ gpu_timer = f"gpu_cache_clear_{timer_name}"
244
+ gc_timer = f"garbage_collection_{timer_name}"
245
+ os_timer = f"os_memory_release_{timer_name}"
246
+ completion_msg = f"clear_memory() completion ({timer_name})"
247
+ else:
248
+ main_timer = "memory_clear"
249
+ gpu_timer = "gpu_cache_clear"
250
+ gc_timer = "garbage_collection"
251
+ os_timer = "os_memory_release"
252
+ completion_msg = "clear_memory() completion"
253
+
254
+ # Start timer for entire operation
255
+ if debug:
256
+ debug.start_timer(main_timer)
257
+
258
+ # Check if we should clear based on memory pressure
259
+ if not force:
260
+ should_clear = False
261
+
262
+ # Use existing function for memory info
263
+ mem_info = get_basic_vram_info(device=None)
264
+
265
+ if "error" not in mem_info and mem_info["total_gb"] > 0:
266
+ # Check VRAM/MPS memory pressure (5% free threshold)
267
+ free_ratio = mem_info["free_gb"] / mem_info["total_gb"]
268
+ if free_ratio < 0.05:
269
+ should_clear = True
270
+ if debug:
271
+ backend = "Unified Memory" if is_mps_available() else "VRAM"
272
+ debug.log(f"{backend} pressure: {mem_info['free_gb']:.2f}GB free of {mem_info['total_gb']:.2f}GB", category="memory")
273
+
274
+ # For non-MPS systems, also check system RAM separately
275
+ if not should_clear and not is_mps_available():
276
+ mem = psutil.virtual_memory()
277
+ if mem.available < mem.total * 0.05:
278
+ should_clear = True
279
+ if debug:
280
+ debug.log(f"RAM pressure: {mem.available/(1024**3):.2f}GB free of {mem.total/(1024**3):.2f}GB", category="memory")
281
+
282
+ if not should_clear:
283
+ # End timer before early return to keep stack clean
284
+ if debug:
285
+ debug.end_timer(main_timer)
286
+ return
287
+
288
+ # Determine cleanup level
289
+ cleanup_mode = "deep" if deep else "minimal"
290
+ if debug:
291
+ debug.log(f"Clearing memory caches ({cleanup_mode})...", category="cleanup")
292
+
293
+ # ===== MINIMAL OPERATIONS (Always performed) =====
294
+ # Step 1: Clear GPU caches - Fast operations (~1-5ms)
295
+ if debug:
296
+ debug.start_timer(gpu_timer)
297
+
298
+ if is_cuda_available():
299
+ torch.cuda.empty_cache()
300
+ torch.cuda.ipc_collect()
301
+ elif is_mps_available():
302
+ torch.mps.empty_cache()
303
+
304
+ if debug:
305
+ debug.end_timer(gpu_timer, "GPU cache clearing")
306
+
307
+ # ===== DEEP OPERATIONS (Only when deep=True) =====
308
+ if deep:
309
+ # Step 2: Deep garbage collection (expensive ~5-20ms)
310
+ if debug:
311
+ debug.start_timer(gc_timer)
312
+
313
+ gc.collect(2)
314
+
315
+ if debug:
316
+ debug.end_timer(gc_timer, "Garbage collection")
317
+
318
+ # Step 3: Return memory to OS (platform-specific, ~5-30ms)
319
+ if debug:
320
+ debug.start_timer(os_timer)
321
+
322
+ try:
323
+ if sys.platform == 'linux':
324
+ # Linux: malloc_trim
325
+ import ctypes # Import only when needed
326
+ if _os_memory_lib is None:
327
+ _os_memory_lib = ctypes.CDLL("libc.so.6")
328
+ _os_memory_lib.malloc_trim(0)
329
+
330
+ elif sys.platform == 'win32':
331
+ # Windows: Trim working set
332
+ import ctypes # Import only when needed
333
+ if _os_memory_lib is None:
334
+ _os_memory_lib = ctypes.windll.kernel32
335
+ handle = _os_memory_lib.GetCurrentProcess()
336
+ _os_memory_lib.SetProcessWorkingSetSize(handle, -1, -1)
337
+
338
+ elif is_mps_available():
339
+ # macOS with MPS
340
+ import ctypes # Import only when needed
341
+ import ctypes.util
342
+ if _os_memory_lib is None:
343
+ libc_path = ctypes.util.find_library('c')
344
+ if libc_path:
345
+ _os_memory_lib = ctypes.CDLL(libc_path)
346
+
347
+ if _os_memory_lib:
348
+ _os_memory_lib.sync()
349
+ except Exception as e:
350
+ if debug:
351
+ debug.log(f"Failed to perform OS memory operations: {e}", level="WARNING", category="memory", force=True)
352
+
353
+ if debug:
354
+ debug.end_timer(os_timer, "OS memory release")
355
+
356
+ # End overall timer
357
+ if debug:
358
+ debug.end_timer(main_timer, completion_msg)
359
+
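+ # Usage sketch of the two tiers: a cheap call inside a batch loop versus a deep
+ # cleanup at a model switch:
+ #
+ #   clear_memory(debug=debug, force=False)            # minimal; only acts when <5% is free
+ #   clear_memory(debug=debug, deep=True, force=True)  # adds GC and OS release (~10-50ms)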
360
+
361
+ def retry_on_oom(func, *args, debug=None, operation_name="operation", **kwargs):
362
+ """
363
+ Execute function with single OOM retry after memory cleanup.
364
+
365
+ Args:
366
+ func: Callable to execute
367
+ *args: Positional arguments for func
368
+ debug: Debug instance for logging (optional)
369
+ operation_name: Name for logging
370
+ **kwargs: Keyword arguments for func
371
+
372
+ Returns:
373
+ Result of func(*args, **kwargs)
374
+ """
375
+ try:
376
+ return func(*args, **kwargs)
377
+ except (torch.cuda.OutOfMemoryError, RuntimeError) as e:
378
+ # Only handle OOM errors
379
+ if not any(x in str(e).lower() for x in ["out of memory", "allocation on device"]):
380
+ raise
381
+
382
+ if debug:
383
+ debug.log(f"OOM during {operation_name}: {e}", level="WARNING", category="memory", force=True)
384
+ debug.log(f"Clearing memory and retrying", category="info", force=True)
385
+
386
+ # Clear memory
387
+ clear_memory(debug=debug, deep=True, force=True, timer_name=operation_name)
388
+ # Let memory settle
389
+ time.sleep(0.5)
390
+ if debug: debug.log_memory_state("After memory clearing", show_tensors=False, detailed_tensors=False)
391
+
392
+ # Single retry
393
+ try:
394
+ result = func(*args, **kwargs)
395
+ if debug:
396
+ debug.log(f"Retry successful for {operation_name}", category="success", force=True)
397
+ return result
398
+ except Exception as retry_e:
399
+ if debug:
400
+ debug.log(f"Retry failed for {operation_name}: {retry_e}", level="ERROR", category="memory", force=True)
401
+ raise
402
+
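+ # Usage sketch (hypothetical `vae.decode` callable): give an OOM-prone call one
+ # cleanup-and-retry cycle:
+ #
+ #   result = retry_on_oom(vae.decode, latents, debug=debug, operation_name="vae decode")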
403
+
404
+ def reset_vram_peak(device: Optional[torch.device] = None, debug: Optional['Debug'] = None) -> None:
405
+ """
406
+ Reset VRAM peak memory statistics for fresh tracking.
407
+
408
+ Args:
409
+ device: Optional device to reset stats for. If None, uses cuda:0
410
+ debug: Optional debug instance for logging
411
+ """
412
+ if debug and debug.enabled:
413
+ debug.log("Resetting VRAM peak memory statistics", category="memory")
414
+ try:
415
+ if is_cuda_available():
416
+ if device is None:
417
+ device = torch.device("cuda:0")
418
+ elif not isinstance(device, torch.device):
419
+ device = torch.device(device)
420
+ torch.cuda.reset_peak_memory_stats(device)
421
+ # Note: MPS doesn't support peak memory reset - no action needed
422
+ except Exception as e:
423
+ if debug and debug.enabled:
424
+ debug.log(f"Failed to reset peak memory stats: {e}", level="WARNING", category="memory", force=True)
425
+
426
+
427
+ def clear_rope_lru_caches(model: Optional[torch.nn.Module], debug: Optional['Debug'] = None) -> int:
428
+ """
429
+ Clear ALL LRU caches from RoPE modules.
430
+
431
+ Args:
432
+ model: PyTorch model to clear caches from
433
+ debug: Optional debug instance for logging
434
+
435
+ Returns:
436
+ Number of caches cleared
437
+ """
438
+ if model is None:
439
+ return 0
440
+
441
+ cleared_count = 0
442
+ try:
443
+ for name, module in model.named_modules():
444
+ if hasattr(module, 'get_axial_freqs') and hasattr(module.get_axial_freqs, 'cache_clear'):
445
+ try:
446
+ module.get_axial_freqs.cache_clear()
447
+ cleared_count += 1
448
+ except Exception as e:
449
+ if debug:
450
+ debug.log(f"Failed to clear RoPE LRU cache for module {name}: {e}", level="WARNING", category="memory", force=True)
451
+ except (AttributeError, RuntimeError) as e:
452
+ if debug:
453
+ debug.log(f"Failed to iterate model modules for RoPE LRU cache clearing: {e}", level="WARNING", category="memory", force=True)
454
+
455
+ return cleared_count
456
+
457
+
458
+ def release_tensor_memory(tensor: Optional[torch.Tensor]) -> None:
459
+ """Release tensor memory from any device (CPU/CUDA/MPS)"""
460
+ if tensor is not None and torch.is_tensor(tensor):
461
+ # Release storage for all devices (CPU, CUDA, MPS)
462
+ if tensor.numel() > 0:
463
+ tensor.data.set_()
464
+ tensor.grad = None
465
+
466
+
467
+ def release_tensor_collection(collection: Any, recursive: bool = True) -> None:
468
+ """
469
+ Release GPU memory from tensors in any collection (list, tuple, dict, or single tensor).
470
+
471
+ Args:
472
+ collection: Tensor, list, tuple, dict, or nested structure to release
473
+ recursive: If True, handle nested structures recursively
474
+
475
+ Examples:
476
+ release_tensor_collection(tensor) # Single tensor
477
+ release_tensor_collection([tensor1, tensor2]) # List of tensors
478
+ release_tensor_collection([[t1, t2], [t3, t4]]) # Nested lists
479
+ release_tensor_collection({'a': tensor}) # Dict values
480
+ """
481
+ if collection is None:
482
+ return
483
+
484
+ if torch.is_tensor(collection):
485
+ release_tensor_memory(collection)
486
+ elif isinstance(collection, dict):
487
+ for value in collection.values():
488
+ if recursive:
489
+ release_tensor_collection(value, recursive=True)
490
+ elif torch.is_tensor(value):
491
+ release_tensor_memory(value)
492
+ elif isinstance(collection, (list, tuple)):
493
+ for item in collection:
494
+ if recursive:
495
+ release_tensor_collection(item, recursive=True)
496
+ elif torch.is_tensor(item):
497
+ release_tensor_memory(item)
498
+
499
+
500
+ def release_text_embeddings(*embeddings: torch.Tensor, debug: Optional['Debug'] = None, names: Optional[List[str]] = None) -> None:
501
+ """
502
+ Release memory for text embeddings
503
+
504
+ Args:
505
+ *embeddings: Variable number of embedding tensors to release
506
+ debug: Optional debug instance for logging
507
+ names: Optional list of names for logging
508
+ """
509
+ for i, embedding in enumerate(embeddings):
510
+ if embedding is not None:
511
+ release_tensor_memory(embedding)
512
+ if debug and names and i < len(names):
513
+ debug.log(f"Cleaned up {names[i]}", category="cleanup")
514
+
515
+
516
+ def cleanup_text_embeddings(ctx: Dict[str, Any], debug: Optional['Debug'] = None) -> None:
517
+ """
518
+ Clean up text embeddings from a context dictionary.
519
+ Extracts embeddings, releases memory, and clears the context entry.
520
+
521
+ Args:
522
+ ctx: Context dictionary potentially containing 'text_embeds'
523
+ debug: Optional debug instance for logging
524
+ """
525
+ if not ctx or not ctx.get('text_embeds'):
526
+ return
527
+
528
+ embeddings = []
529
+ names = []
530
+ for key, embeds_list in ctx['text_embeds'].items():
531
+ if embeds_list:
532
+ embeddings.extend(embeds_list)
533
+ names.append(key)
534
+
535
+ if embeddings:
536
+ release_text_embeddings(*embeddings, debug=debug, names=names)
537
+
538
+ if debug:
539
+ debug.log(f"Cleaned up text embeddings: {', '.join(names)}", category="cleanup")
540
+
541
+ ctx['text_embeds'] = None
542
+
543
+
544
+ def release_model_memory(model: Optional[torch.nn.Module], debug: Optional['Debug'] = None) -> None:
545
+ """
546
+ Release all GPU/MPS memory from model in-place without CPU transfer.
547
+
548
+ Args:
549
+ model: PyTorch model to release memory from
550
+ debug: Optional debug instance for logging
551
+ """
552
+ if model is None:
553
+ return
554
+
555
+ # If the model has pipelining resources (swap stream), synchronize to ensure no pending async ops
556
+ try:
557
+ if hasattr(model, "_swap_stream"):
558
+ try:
559
+ model._swap_stream.synchronize()
560
+ except Exception:
561
+ if debug:
562
+ debug.log("Failed to synchronize model._swap_stream before releasing memory", level="WARNING", category="memory", force=True)
563
+ except Exception:
564
+ pass
565
+
566
+ try:
567
+ # Clear gradients first
568
+ model.zero_grad(set_to_none=True)
569
+
570
+ # Release GPU memory directly without CPU transfer
571
+ released_params = 0
572
+ released_buffers = 0
573
+
574
+ for param in model.parameters():
575
+ if param.is_cuda or param.is_mps:
576
+ if param.numel() > 0:
577
+ param.data.set_()
578
+ released_params += 1
579
+ param.grad = None
580
+
581
+ for buffer in model.buffers():
582
+ if buffer.is_cuda or buffer.is_mps:
583
+ if buffer.numel() > 0:
584
+ buffer.data.set_()
585
+ released_buffers += 1
586
+
587
+ if debug and (released_params > 0 or released_buffers > 0):
588
+ debug.log(f"Released memory from {released_params} params and {released_buffers} buffers", category="success")
589
+
590
+ except (AttributeError, RuntimeError) as e:
591
+ if debug:
592
+ debug.log(f"Failed to release model memory: {e}", level="WARNING", category="memory", force=True)
593
+
594
+
595
+ def manage_tensor(
596
+ tensor: torch.Tensor,
597
+ target_device: torch.device,
598
+ tensor_name: str = "tensor",
599
+ dtype: Optional[torch.dtype] = None,
600
+ non_blocking: bool = False,
601
+ debug: Optional['Debug'] = None,
602
+ reason: Optional[str] = None,
603
+ indent_level: int = 0
604
+ ) -> torch.Tensor:
605
+ """
606
+ Unified tensor management for device movement and dtype conversion.
607
+
608
+ Handles both device transfers (CPU ↔ GPU) and dtype conversions (e.g., float16 → bfloat16)
609
+ with intelligent early-exit optimization and comprehensive logging.
610
+
611
+ Args:
612
+ tensor: Tensor to manage
613
+ target_device: Target device (torch.device object)
614
+ tensor_name: Descriptive name for logging (e.g., "latent", "sample", "alpha_channel")
615
+ dtype: Optional target dtype to cast to (if None, keeps original dtype)
616
+ non_blocking: Whether to use non-blocking transfer
617
+ debug: Debug instance for logging
618
+ reason: Optional reason for the operation (e.g., "inference", "offload", "dtype alignment")
619
+ indent_level: Indentation level for debug logging (0=no indent, 1=2 spaces, etc.)
620
+
621
+ Returns:
622
+ Tensor on target device with optional dtype conversion
623
+
624
+ Note:
625
+ - Skips operation if tensor already has target device and dtype (zero-copy)
626
+ - Uses PyTorch's optimized .to() for efficient device/dtype handling
627
+ - Logs all operations consistently for tracking and debugging
628
+ """
629
+ if tensor is None:
630
+ return tensor
631
+
632
+ # Get current state
633
+ current_device = tensor.device
634
+ current_dtype = tensor.dtype
635
+ target_dtype = dtype if dtype is not None else current_dtype
636
+
637
+ # Check if movement is actually needed
638
+ needs_device_move = _device_str(current_device) != _device_str(target_device)
639
+ needs_dtype_change = dtype is not None and current_dtype != target_dtype
640
+
641
+ if not needs_device_move and not needs_dtype_change:
642
+ # Already on target device and dtype - skip
643
+ return tensor
644
+
645
+ # Determine reason for movement
646
+ if reason is None:
647
+ if needs_device_move and needs_dtype_change:
648
+ reason = "device and dtype conversion"
649
+ elif needs_device_move:
650
+ reason = "device movement"
651
+ else:
652
+ reason = "dtype conversion"
653
+
654
+ # Log the movement
655
+ if debug:
656
+ current_device_str = _device_str(current_device)
657
+ target_device_str = _device_str(target_device)
658
+
659
+ dtype_info = ""
660
+ if needs_dtype_change:
661
+ dtype_info = f", {current_dtype} → {target_dtype}"
662
+
663
+ debug.log(
664
+ f"Moving {tensor_name} from {current_device_str} to {target_device_str}{dtype_info} ({reason})",
665
+ category="general",
666
+ indent_level=indent_level
667
+ )
668
+
669
+ # Perform the operation based on what needs to change
670
+ if needs_device_move and needs_dtype_change:
671
+ # Both device and dtype need to change
672
+ return tensor.to(target_device, dtype=target_dtype, non_blocking=non_blocking)
673
+ elif needs_device_move:
674
+ # Only device needs to change
675
+ return tensor.to(target_device, non_blocking=non_blocking)
676
+ else:
677
+ # Only dtype needs to change
678
+ return tensor.to(dtype=target_dtype)
679
+
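+ # Usage sketch: move a latent to the GPU and align its dtype in one call:
+ #
+ #   latent = manage_tensor(latent, torch.device('cuda:0'), tensor_name="latent",
+ #                          dtype=torch.bfloat16, debug=debug, reason="inference")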
680
+
681
+ def manage_model_device(model: torch.nn.Module, target_device: torch.device, model_name: str,
682
+ debug: Optional['Debug'] = None, reason: Optional[str] = None,
683
+ runner: Optional[Any] = None) -> bool:
684
+ """
685
+ Move model to target device with optimizations.
686
+ Handles BlockSwap-enabled models transparently.
687
+
688
+ Args:
689
+ model: The model to move
690
+ target_device: Target device (torch.device object, e.g., torch.device('cuda:0'))
691
+ model_name: Name for logging (e.g., "VAE", "DiT")
692
+ debug: Debug instance for logging
693
+ reason: Optional custom reason for the movement
694
+ runner: Optional runner instance for BlockSwap detection
695
+
696
+ Returns:
697
+ bool: True if model was moved, False if already on target device
698
+ """
699
+ if model is None:
700
+ return False
701
+
702
+ # Check if this is a BlockSwap-enabled DiT model
703
+ is_blockswap_model = False
704
+ actual_model = model
705
+ if runner and model_name == "DiT":
706
+ # Import here to avoid circular dependency
707
+ from .blockswap import is_blockswap_enabled
708
+ # Check if BlockSwap config exists and is enabled
709
+ has_blockswap_config = (
710
+ hasattr(runner, '_dit_block_swap_config') and
711
+ is_blockswap_enabled(runner._dit_block_swap_config)
712
+ )
713
+
714
+ if has_blockswap_config:
715
+ is_blockswap_model = True
716
+ # Get the actual model (handle CompatibleDiT wrapper)
717
+ if hasattr(model, "dit_model"):
718
+ actual_model = model.dit_model
719
+
720
+ # Get current device
721
+ try:
722
+ current_device = next(model.parameters()).device
723
+ except StopIteration:
724
+ return False
725
+
726
+ # Extract device type for comparison (both are torch.device objects)
727
+ target_type = target_device.type
728
+ current_device_upper = _device_str(current_device)
729
+ target_device_upper = _device_str(target_device)
730
+
731
+ # Compare normalized device types
732
+ if current_device_upper == target_device_upper and not is_blockswap_model:
733
+ # Already on target device type, no movement needed
734
+ if debug:
735
+ debug.log(f"{model_name} already on {current_device_upper}, skipping movement", category="general")
736
+ return False
737
+
738
+ # Handle BlockSwap models specially
739
+ if is_blockswap_model:
740
+ return _handle_blockswap_model_movement(
741
+ runner, actual_model, current_device, target_device, target_type,
742
+ model_name, debug, reason
743
+ )
744
+
745
+ # Standard model movement (non-BlockSwap)
746
+ return _standard_model_movement(
747
+ model, current_device, target_device, target_type, model_name,
748
+ debug, reason
749
+ )
750
+
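+ # Usage sketch (hypothetical `vae` module): offload between stages, reload for decoding:
+ #
+ #   manage_model_device(vae, torch.device('cpu'), "VAE", debug=debug, reason="offload")
+ #   manage_model_device(vae, torch.device('cuda:0'), "VAE", debug=debug, reason="decode")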
751
+
752
+ def _handle_blockswap_model_movement(runner: Any, model: torch.nn.Module,
753
+ current_device: torch.device, target_device: torch.device,
754
+ target_type: str, model_name: str,
755
+ debug: Optional['Debug'] = None, reason: Optional[str] = None) -> bool:
756
+ """
757
+ Handle device movement for BlockSwap-enabled models.
758
+
759
+ Args:
760
+ runner: Runner instance with BlockSwap configuration
761
+ model: Model to move (actual unwrapped model)
762
+ current_device: Current device of the model
763
+ target_device: Target device (torch.device object)
764
+ target_type: Target device type (cpu/cuda/mps)
765
+ model_name: Model name for logging
766
+ debug: Debug instance
767
+ reason: Movement reason
768
+
769
+ Returns:
770
+ bool: True if model was moved
771
+ """
772
+ # Import here to avoid circular dependency
773
+ from .blockswap import set_blockswap_bypass
774
+
775
+ if target_type == "cpu":
776
+ # Moving to offload device (typically CPU)
777
+ # Check if any parameter is on GPU (for accurate logging)
778
+ actual_source_device = None
779
+ for param in model.parameters():
780
+ if param.device.type in ['cuda', 'mps']:
781
+ actual_source_device = param.device
782
+ break
783
+
784
+ source_device_desc = _device_str(actual_source_device) if actual_source_device else _device_str(target_device)
785
+
786
+ if debug:
787
+ debug.log(f"Moving {model_name} from {source_device_desc} to {_device_str(target_device)} ({reason or 'model caching'})", category="general")
788
+
789
+ # Enable bypass to allow movement
790
+ set_blockswap_bypass(runner=runner, bypass=True, debug=debug)
791
+
792
+ # If a pipelined swap stream exists, synchronize it to ensure no pending async transfers
793
+ if hasattr(model, "_swap_stream"):
794
+ try:
795
+ model._swap_stream.synchronize()
796
+ except Exception:
797
+ # Best-effort; don't fail the movement if synchronize not supported
798
+ if debug:
799
+ debug.log("Failed to synchronize model._swap_stream before offload", level="WARNING", category="memory", force=True)
800
+
801
+ # Start timer
802
+ timer_name = f"{model_name.lower()}_to_{target_type}"
803
+ if debug:
804
+ debug.start_timer(timer_name)
805
+
806
+ # Move entire model to target offload device
807
+ model.to(target_device)
808
+ model.zero_grad(set_to_none=True)
809
+
810
+ # After moving to CPU, attempt to pin CPU tensors to enable non-blocking async copies later.
811
+ try:
812
+ for p in model.parameters():
813
+ if p.device.type == "cpu" and p.numel() > 0 and not p.data.is_pinned():
814
+ p.data = p.data.pin_memory()
815
+ for b in model.buffers():
816
+ if b.device.type == "cpu" and b.numel() > 0 and not b.data.is_pinned():
817
+ b.data = b.data.pin_memory()
818
+ except Exception as e:
819
+ # Pinning is best-effort; log and continue
820
+ if debug:
821
+ debug.log(f"Pin-memory on offloaded model failed: {e}", level="WARNING", category="memory", force=True)
822
+
823
+ if debug:
824
+ debug.end_timer(timer_name, f"BlockSwap model offloaded to {_device_str(target_device)}")
832
+
833
+ return True
834
+
835
+ else:
836
+ # Moving to GPU (reload)
837
+ # Check if we're in bypass mode (coming from offload)
838
+ if not getattr(model, "_blockswap_bypass_protection", False):
839
+ # Not in bypass mode, blocks are already configured
840
+ if debug:
841
+ debug.log(f"{model_name} with BlockSwap active - blocks already distributed across devices, skipping movement", category="general")
842
+ return False
843
+
844
+ # Get actual current device for accurate logging
845
+ actual_current_device = None
846
+ for param in model.parameters():
847
+ if param.device.type != 'meta':
848
+ actual_current_device = param.device
849
+ break
850
+
851
+ current_device_desc = _device_str(actual_current_device) if actual_current_device else "OFFLOAD"
852
+
853
+ if debug:
854
+ debug.log(f"Moving {model_name} from {current_device_desc} to {_device_str(target_device)} ({reason or 'inference requirement'})", category="general")
855
+
856
+ timer_name = f"{model_name.lower()}_to_gpu"
857
+ if debug:
858
+ debug.start_timer(timer_name)
859
+
860
+ # Restore blocks to their configured devices
861
+ if hasattr(model, "blocks") and hasattr(model, "blocks_to_swap"):
862
+ # Use configured offload_device from BlockSwap config
863
+ offload_device = model._block_swap_config.get("offload_device")
864
+ if not offload_device:
865
+ raise ValueError("BlockSwap config missing offload_device")
866
+
867
+ # Move blocks according to BlockSwap configuration
868
+ for b, block in enumerate(model.blocks):
869
+ if b > model.blocks_to_swap:
870
+ # This block should be on GPU
871
+ block.to(target_device)
872
+ else:
873
+ # This block stays on offload device (will be swapped during forward)
874
+ block.to(offload_device)
875
+
876
+ # Handle I/O components
877
+ if not model._block_swap_config.get("swap_io_components", False):
878
+ # I/O components should be on GPU if not offloaded
879
+ for name, module in model.named_children():
880
+ if name != "blocks":
881
+ module.to(target_device)
882
+ else:
883
+ # I/O components stay on offload device
884
+ for name, module in model.named_children():
885
+ if name != "blocks":
886
+ module.to(offload_device)
887
+
888
+ if debug:
889
+ # Get actual configuration from runner
890
+ if hasattr(model, '_block_swap_config'):
891
+ blocks_on_gpu = model._block_swap_config.get('total_blocks', 32) - model._block_swap_config.get('blocks_swapped', 16)
892
+ total_blocks = model._block_swap_config.get('total_blocks', 32)
893
+ main_device = model._block_swap_config.get('main_device', 'GPU')
894
+ debug.log(f"BlockSwap blocks restored to configured devices ({blocks_on_gpu}/{total_blocks} blocks on {_device_str(main_device)})", category="success")
895
+ else:
896
+ debug.log("BlockSwap blocks restored to configured devices", category="success")
897
+
898
+
899
+ # Reactivate BlockSwap now that blocks are restored to their configured devices
900
+ runner._blockswap_active = True
901
+
902
+ # Disable bypass, re-enable protection
903
+ set_blockswap_bypass(runner=runner, bypass=False, debug=debug)
904
+
905
+ if debug:
906
+ debug.end_timer(timer_name, "BlockSwap model restored")
907
+
908
+ return True
909
+
910
+
911
+ def _standard_model_movement(model: torch.nn.Module, current_device: torch.device,
912
+ target_device: torch.device, target_type: str, model_name: str,
913
+ debug: Optional['Debug'] = None, reason: Optional[str] = None) -> bool:
914
+ """
915
+ Handle standard (non-BlockSwap) model movement.
916
+
917
+ Args:
918
+ model: Model to move
919
+ current_device: Current device of the model
920
+ target_device: Target device (torch.device object)
921
+ target_type: Target device type
922
+ model_name: Model name for logging
923
+ debug: Debug instance
924
+ reason: Movement reason
925
+
926
+ Returns:
927
+ bool: True if model was moved
928
+ """
929
+ # Check if model is on meta device - can't move meta tensors
930
+ if current_device.type == 'meta':
931
+ if debug:
932
+ debug.log(f"{model_name} is on meta device - skipping movement (will materialize when needed)",
933
+ category=model_name.lower())
934
+ return False
935
+
936
+ # Determine reason for movement
937
+ reason = reason or "inference requirement"
938
+
939
+ # Log the movement with full device strings
940
+ if debug:
941
+ current_device_str = _device_str(current_device)
942
+ target_device_str = _device_str(target_device)
943
+ debug.log(f"Moving {model_name} from {current_device_str} to {target_device_str} ({reason})", category="general")
944
+
945
+ # Start timer based on direction
946
+ timer_name = f"{model_name.lower()}_to_{'gpu' if target_type != 'cpu' else 'cpu'}"
947
+ if debug:
948
+ debug.start_timer(timer_name)
949
+
950
+ # Move model and clear gradients
951
+ model.to(target_device)
952
+ model.zero_grad(set_to_none=True)
953
+
954
+ # Clear VAE memory buffers when moving to CPU
955
+ if target_type == 'cpu' and model_name == "VAE":
956
+ cleared_count = 0
957
+ for module in model.modules():
958
+ if hasattr(module, 'memory') and module.memory is not None:
959
+ if torch.is_tensor(module.memory) and (module.memory.is_cuda or module.memory.is_mps):
960
+ module.memory = None
961
+ cleared_count += 1
962
+ if cleared_count > 0 and debug:
963
+ debug.log(f"Cleared {cleared_count} VAE memory buffers", category="success")
964
+
965
+ # End timer
966
+ if debug:
967
+ debug.end_timer(timer_name, f"{model_name} moved to {_device_str(target_device)}")
968
+
969
+ return True
970
+
971
+
972
+ def clear_runtime_caches(runner: Any, debug: Optional['Debug'] = None) -> int:
973
+ """
974
+ Clear all runtime caches and temporary attributes.
975
+ """
976
+ if not runner:
977
+ return 0
978
+
979
+ if debug:
980
+ debug.start_timer("runtime_cache_clear")
981
+
982
+ cleaned_items = 0
983
+
984
+ # 1. Clear main runner cache
985
+ if hasattr(runner, 'cache') and hasattr(runner.cache, 'cache'):
986
+ if debug:
987
+ debug.start_timer("runner_cache_clear")
988
+
989
+ cache_entries = len(runner.cache.cache)
990
+
991
+ # Properly release tensor memory and delete as we go
992
+ for key in list(runner.cache.cache.keys()):
993
+ value = runner.cache.cache[key]
994
+ if torch.is_tensor(value):
995
+ release_tensor_memory(value)
996
+ elif isinstance(value, (list, tuple)):
997
+ for item in value:
998
+ if torch.is_tensor(item):
999
+ release_tensor_memory(item)
1000
+ # Delete immediately to release reference
1001
+ del runner.cache.cache[key]
1002
+
1003
+ # Final clear for safety
1004
+ runner.cache.cache.clear()
1005
+ cleaned_items += cache_entries
1006
+
1007
+ if debug:
1008
+ debug.end_timer("runner_cache_clear", "Clearing main runner cache entries")
1009
+
1010
+ if cache_entries > 0:
1011
+ debug.log(f"Cleared {cache_entries} runtime cache entries", category="success")
1012
+
1013
+ # 2. Clear RoPE caches
1014
+ if hasattr(runner, 'dit'):
1015
+ if debug:
1016
+ debug.start_timer("rope_cache_clear")
1017
+
1018
+ model = runner.dit
1019
+ if hasattr(model, 'dit_model'): # Handle wrapper
1020
+ model = model.dit_model
1021
+
1022
+ rope_cleared = clear_rope_lru_caches(model=model, debug=debug)
1023
+ cleaned_items += rope_cleared
1024
+ if debug:
1025
+ debug.end_timer("rope_cache_clear", "Clearing RoPE LRU caches")
1026
+
1027
+ if rope_cleared > 0:
1028
+ debug.log(f"Cleared {rope_cleared} RoPE LRU caches", category="success")
1029
+
1030
+ # 3. Clear temporary attributes
1031
+ temp_attrs = ['_temp_cache', '_block_cache', '_swap_cache', '_generation_cache',
1032
+ '_rope_cache', '_intermediate_cache', '_backward_cache']
1033
+
1034
+ for obj in [runner, getattr(runner, 'dit', None), getattr(runner, 'vae', None)]:
1035
+ if obj is None:
1036
+ continue
1037
+
1038
+ actual_obj = obj.dit_model if hasattr(obj, 'dit_model') else obj
1039
+
1040
+ for attr in temp_attrs:
1041
+ if hasattr(actual_obj, attr):
1042
+ delattr(actual_obj, attr)
1043
+ cleaned_items += 1
1044
+
1045
+ if debug:
1046
+ debug.end_timer("runtime_cache_clear", "clear_runtime_caches() completion")
1047
+
1048
+ return cleaned_items
1049
+
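+ # Illustrative call between generations (assumes the runner cache layout
+ # handled above; `debug` may be None):
+ #
+ #   cleaned = clear_runtime_caches(runner=runner, debug=debug)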
1050
+
1051
+ def cleanup_dit(runner: Any, debug: Optional['Debug'] = None, cache_model: bool = False) -> None:
1052
+ """
1053
+ Cleanup DiT model and BlockSwap state after upscaling phase.
1054
+ Called at the end of upscale_all_batches when DiT is no longer needed.
1055
+
1056
+ Args:
1057
+ runner: Runner instance containing DiT model
1058
+ debug: Debug instance for logging
1059
+ cache_model: If True, move DiT to offload_device; if False, delete completely
1060
+ """
1061
+ if not runner or not hasattr(runner, 'dit'):
1062
+ return
1063
+
1064
+ if debug:
1065
+ debug.log("Cleaning up DiT components", category="cleanup")
1066
+
1067
+ # 1. Clear DiT-specific runtime caches first
1068
+ if hasattr(runner, 'dit'):
1069
+ model = runner.dit
1070
+ if hasattr(model, 'dit_model'): # Handle wrapper
1071
+ model = model.dit_model
1072
+
1073
+ # Clear RoPE caches
1074
+ rope_cleared = clear_rope_lru_caches(model=model, debug=debug)
1075
+ if rope_cleared > 0 and debug:
1076
+ debug.log(f"Cleared {rope_cleared} RoPE LRU caches", category="success")
1077
+
1078
+ # Clear DiT temporary attributes
1079
+ temp_attrs = ['_temp_cache', '_block_cache', '_swap_cache', '_generation_cache',
1080
+ '_rope_cache', '_intermediate_cache', '_backward_cache']
1081
+
1082
+ actual_obj = model.dit_model if hasattr(model, 'dit_model') else model
1083
+ for attr in temp_attrs:
1084
+ if hasattr(actual_obj, attr):
1085
+ delattr(actual_obj, attr)
1086
+
1087
+ # 2. Handle model offloading (for caching or before deletion)
1088
+ try:
1089
+ param_device = next(runner.dit.parameters()).device
1090
+
1091
+ # Move model off GPU if needed
1092
+ if param_device.type not in ['meta', 'cpu']:
1093
+ # MPS: skip CPU movement before deletion (unified memory, just causes sync)
1094
+ if param_device.type == 'mps' and not cache_model:
1095
+ if debug:
1096
+ debug.log("DiT on MPS - skipping CPU movement before deletion", category="cleanup")
1097
+ else:
1098
+ offload_target = getattr(runner, '_dit_offload_device', None)
1099
+ if offload_target is None or offload_target == 'none':
1100
+ offload_target = torch.device('cpu')
1101
+ reason = "model caching" if cache_model else "releasing GPU memory"
1102
+ manage_model_device(model=runner.dit, target_device=offload_target, model_name="DiT",
1103
+ debug=debug, reason=reason, runner=runner)
1104
+ elif param_device.type == 'meta' and debug:
1105
+ debug.log("DiT on meta device - keeping structure for cache", category="cleanup")
1106
+ except StopIteration:
1107
+ pass
1108
+
1109
+ # 3. Clean BlockSwap after model movement
1110
+ if hasattr(runner, "_blockswap_active") and runner._blockswap_active:
1111
+ # Import here to avoid circular dependency
1112
+ from .blockswap import cleanup_blockswap
1113
+
1114
+ # If model had a swap stream, synchronize before cleanup to avoid races
1115
+ try:
1116
+ model_for_sync = runner.dit.dit_model if hasattr(runner.dit, 'dit_model') else runner.dit
1117
+ if hasattr(model_for_sync, "_swap_stream"):
1118
+ try:
1119
+ model_for_sync._swap_stream.synchronize()
1120
+ except Exception:
1121
+ if debug:
1122
+ debug.log("Failed to synchronize model._swap_stream before cleanup_blockswap", level="WARNING", category="cleanup", force=True)
1123
+ except Exception:
1124
+ pass
1125
+
1126
+ cleanup_blockswap(runner=runner, keep_state_for_cache=cache_model)
1127
+
1128
+
1129
+ # 4. Complete cleanup if not caching
1130
+ if not cache_model:
1131
+ release_model_memory(model=runner.dit, debug=debug)
1132
+ runner.dit = None
1133
+ if debug:
1134
+ debug.log("DiT model deleted", category="cleanup")
1135
+
1136
+ # Clear DiT config attributes - not needed when model is not cached (will be recreated)
1137
+ if hasattr(runner, '_dit_compile_args'):
1138
+ delattr(runner, '_dit_compile_args')
1139
+ if hasattr(runner, '_dit_block_swap_config'):
1140
+ delattr(runner, '_dit_block_swap_config')
1141
+ if hasattr(runner, '_dit_attention_mode'):
1142
+ delattr(runner, '_dit_attention_mode')
1143
+
1144
+ # 5. Clear DiT temporary attributes (should already be cleared in materialize_model)
1145
+ runner._dit_checkpoint = None
1146
+ runner._dit_dtype_override = None
1147
+
1148
+ # 6. Clear DiT-related components and temporary attributes
1149
+ runner.sampler = None
1150
+ runner.sampling_timesteps = None
1151
+ runner.schedule = None
1152
+
1153
+
1154
+ def cleanup_vae(runner: Any, debug: Optional['Debug'] = None, cache_model: bool = False) -> None:
1155
+ """
1156
+ Cleanup VAE model after decoding phase.
1157
+ Called at the end of decode_all_batches when VAE is no longer needed.
1158
+
1159
+ Args:
1160
+ runner: Runner instance containing VAE model
1161
+ debug: Debug instance for logging
1162
+ cache_model: If True, move VAE to offload_device; if False, delete completely
1163
+ """
1164
+ if not runner or not hasattr(runner, 'vae'):
1165
+ return
1166
+
1167
+ if debug:
1168
+ debug.log("Cleaning up VAE components", category="cleanup")
1169
+
1170
+ # 1. Clear VAE-specific temporary attributes
1171
+ if hasattr(runner, 'vae'):
1172
+ temp_attrs = ['_temp_cache', '_block_cache', '_swap_cache', '_generation_cache',
1173
+ '_rope_cache', '_intermediate_cache', '_backward_cache']
1174
+
1175
+ for attr in temp_attrs:
1176
+ if hasattr(runner.vae, attr):
1177
+ delattr(runner.vae, attr)
1178
+
1179
+ # 2. Handle model offloading (for caching or before deletion)
1180
+ try:
1181
+ param_device = next(runner.vae.parameters()).device
1182
+
1183
+ # Move model off GPU if needed
1184
+ if param_device.type not in ['meta', 'cpu']:
1185
+ # MPS: skip CPU movement before deletion (unified memory, just causes sync)
1186
+ if param_device.type == 'mps' and not cache_model:
1187
+ if debug:
1188
+ debug.log("VAE on MPS - skipping CPU movement before deletion", category="cleanup")
1189
+ else:
1190
+ offload_target = getattr(runner, '_vae_offload_device', None)
1191
+ if offload_target is None or offload_target == 'none':
1192
+ offload_target = torch.device('cpu')
1193
+ reason = "model caching" if cache_model else "releasing GPU memory"
1194
+ manage_model_device(model=runner.vae, target_device=offload_target, model_name="VAE",
1195
+ debug=debug, reason=reason, runner=runner)
1196
+ elif param_device.type == 'meta' and debug:
1197
+ debug.log("VAE on meta device - keeping structure for cache", category="cleanup")
1198
+ except StopIteration:
1199
+ pass
1200
+
1201
+ # 3. Complete cleanup if not caching
1202
+ if not cache_model:
1203
+ release_model_memory(model=runner.vae, debug=debug)
1204
+ runner.vae = None
1205
+ if debug:
1206
+ debug.log("VAE model deleted", category="cleanup")
1207
+
1208
+ # Clear VAE config attributes - not needed when model is not cached (will be recreated)
1209
+ if hasattr(runner, '_vae_compile_args'):
1210
+ delattr(runner, '_vae_compile_args')
1211
+ if hasattr(runner, '_vae_tiling_config'):
1212
+ delattr(runner, '_vae_tiling_config')
1213
+
1214
+ # 4. Clear VAE temporary attributes (should already be cleared in materialize_model)
1215
+ runner._vae_checkpoint = None
1216
+ runner._vae_dtype_override = None
1217
+
1218
+
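+ # Phase-end sketch (hypothetical `runner`): release the DiT after upscaling
+ # but keep the VAE cached on its offload device for the next run:
+ #
+ #   cleanup_dit(runner=runner, debug=debug, cache_model=False)
+ #   cleanup_vae(runner=runner, debug=debug, cache_model=True)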
1219
+ def complete_cleanup(runner: Any, debug: Optional['Debug'] = None, dit_cache: bool = False, vae_cache: bool = False) -> None:
1220
+ """
1221
+ Complete cleanup of runner and remaining components with independent model caching support.
1222
+ This is a lightweight cleanup for final stage, as model-specific cleanup
1223
+ happens in their respective phases (cleanup_dit, cleanup_vae).
1224
+
1225
+ Args:
1226
+ runner: Runner instance to clean up
1227
+ debug: Debug instance for logging
1228
+ dit_cache: If True, preserve DiT model on offload_device for future runs
1229
+ vae_cache: If True, preserve VAE model on offload_device for future runs
1230
+
1231
+ Behavior:
1232
+ - Can cache DiT and VAE independently for flexible memory management
1233
+ - Preserves _dit_model_name and _vae_model_name when either model is cached for change detection
1234
+ - Clears all temporary attributes and runtime caches
1235
+ - Performs deep memory cleanup only when both models are fully released
1236
+
1237
+ Note:
1238
+ Model name tracking (_dit_model_name, _vae_model_name) is only cleared if neither
1239
+ model is cached, enabling proper model change detection on subsequent runs.
1240
+ """
1241
+ if not runner:
1242
+ return
1243
+
1244
+ if debug:
1245
+ cleanup_type = "partial cleanup" if (dit_cache or vae_cache) else "full cleanup"
1246
+ debug.log(f"Starting {cleanup_type}", category="cleanup")
1247
+
1248
+ # 1. Cleanup any remaining models if they still exist
1249
+ # (This handles cases where phases were skipped or errored)
1250
+ if hasattr(runner, 'dit') and runner.dit is not None:
1251
+ cleanup_dit(runner=runner, debug=debug, cache_model=dit_cache)
1252
+
1253
+ if hasattr(runner, 'vae') and runner.vae is not None:
1254
+ cleanup_vae(runner=runner, debug=debug, cache_model=vae_cache)
1255
+
1256
+ # 2. Clear remaining runtime caches
1257
+ clear_runtime_caches(runner=runner, debug=debug)
1258
+
1259
+ # 3. Clear config and other non-model components when fully releasing runner
1260
+ if not (dit_cache or vae_cache):
1261
+ # Full cleanup - clear config and model tracking
1262
+ runner.config = None
1263
+ runner._dit_model_name = None
1264
+ runner._vae_model_name = None
1265
+
1266
+ # 4. Final memory cleanup
1267
+ clear_memory(debug=debug, deep=True, force=True, timer_name="complete_cleanup")
1268
+
1269
+ # 5. Clear cuBLAS workspaces
1270
+ if hasattr(torch._C, '_cuda_clearCublasWorkspaces'): torch._C._cuda_clearCublasWorkspaces()
1271
+
1272
+ # Log what models are cached for next run
1273
+ if dit_cache or vae_cache:
1274
+ cached_models = []
1275
+ if dit_cache and hasattr(runner, '_dit_model_name'):
1276
+ cached_models.append(f"DiT ({runner._dit_model_name})")
1277
+ if vae_cache and hasattr(runner, '_vae_model_name'):
1278
+ cached_models.append(f"VAE ({runner._vae_model_name})")
1279
+
1280
+ if cached_models and debug:
1281
+ models_str = " and ".join(cached_models)
1282
+ debug.log(f"Models cached for next run: {models_str}", category="cache", force=True)
1283
+
1284
+ if debug:
1285
+ debug.log(f"Completed {cleanup_type}", category="success")
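+ # Final-stage sketch: cache both models across runs (pass False/False to
+ # release everything, triggering the deep memory cleanup above):
+ #
+ #   complete_cleanup(runner=runner, debug=debug, dit_cache=True, vae_cache=True)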
src/optimization/memory_manager.py.bak ADDED
@@ -0,0 +1,1231 @@
1
+ """
2
+ Memory management module for SeedVR2
3
+ Handles VRAM usage, cache management, and memory optimization
4
+
5
+ Extracted from: seedvr2.py (lines 373-405, 607-626, 1016-1044)
6
+ """
7
+
8
+ import torch
9
+ import gc
10
+ import sys
11
+ import time
12
+ import psutil
13
+ import platform
14
+ from typing import Tuple, Dict, Any, Optional, List, Union
15
+
16
+
17
+ def _device_str(device: Union[torch.device, str]) -> str:
18
+ """Normalized uppercase device string for comparison and logging. MPS variants → 'MPS'."""
19
+ s = str(device).upper()
20
+ return 'MPS' if s.startswith('MPS') else s
21
+
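+ # Normalization examples: _device_str(torch.device("cuda:0")) -> "CUDA:0";
+ # _device_str("mps:0") -> "MPS", so all MPS variants compare equal.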
22
+
23
+ def is_mps_available() -> bool:
24
+ """Check if MPS (Apple Metal) backend is available."""
25
+ return hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
26
+
27
+
28
+ def is_cuda_available() -> bool:
29
+ """Check if CUDA backend is available."""
30
+ return torch.cuda.is_available()
31
+
32
+
33
+ def get_gpu_backend() -> str:
34
+ """Get the active GPU backend type.
35
+
36
+ Returns:
37
+ 'cuda': NVIDIA CUDA
38
+ 'mps': Apple Metal Performance Shaders
39
+ 'cpu': No GPU backend available
40
+ """
41
+ if is_cuda_available():
42
+ return 'cuda'
43
+ if is_mps_available():
44
+ return 'mps'
45
+ return 'cpu'
46
+
47
+
48
+ def get_device_list(include_none: bool = False, include_cpu: bool = False) -> List[str]:
49
+ """
50
+ Get list of available compute devices for SeedVR2
51
+
52
+ Args:
53
+ include_none: If True, prepend "none" to the device list (for offload options)
54
+ include_cpu: If True, include "cpu" in the device list (for offload options only)
55
+ Note: On MPS-only systems, "cpu" is automatically excluded since
56
+ unified memory architecture makes CPU offloading meaningless
57
+
58
+ Returns:
59
+ List of device strings (e.g., ["cuda:0", "cuda:1"] or ["none", "cpu", "cuda:0", "cuda:1"])
60
+ """
61
+ devs = []
62
+ has_cuda = False
63
+ has_mps = False
64
+
65
+ try:
66
+ if is_cuda_available():
67
+ devs += [f"cuda:{i}" for i in range(torch.cuda.device_count())]
68
+ has_cuda = True
69
+ except Exception:
70
+ pass
71
+
72
+ try:
73
+ if is_mps_available():
74
+ devs.append("mps") # MPS doesn't use device indices
75
+ has_mps = True
76
+ except Exception:
77
+ pass
78
+
79
+ # Build result list with optional prefixes
80
+ result = []
81
+ if include_none:
82
+ result.append("none")
83
+
84
+ # Only include "cpu" option if:
85
+ # 1. It was requested (include_cpu=True), AND
86
+ # 2. Either CUDA is available OR MPS is not the only option
87
+ # Rationale: On MPS-only systems with unified memory architecture,
88
+ # CPU offloading is semantically meaningless as CPU and GPU share the same memory pool
89
+ if include_cpu and (has_cuda or not has_mps):
90
+ result.append("cpu")
91
+
92
+ result.extend(devs)
93
+
94
+ return result if result else []
95
+
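+ # Example return values: a two-GPU CUDA box gives
+ # get_device_list() -> ["cuda:0", "cuda:1"] and
+ # get_device_list(include_none=True, include_cpu=True) -> ["none", "cpu", "cuda:0", "cuda:1"];
+ # an MPS-only Mac gives ["none", "mps"] for the same call, with "cpu" dropped
+ # per the unified-memory rationale above.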
96
+
97
+ def get_basic_vram_info(device: Optional[torch.device] = None) -> Dict[str, Any]:
98
+ """
99
+ Get basic VRAM availability info (free and total memory).
100
+ Used for capacity planning and initial checks.
101
+
102
+ Args:
103
+ device: Optional device to query. If None, uses cuda:0
104
+
105
+ Returns:
106
+ dict: {"free_gb": float, "total_gb": float} or {"error": str}
107
+ """
108
+ try:
109
+ if is_cuda_available():
110
+ if device is None:
111
+ device = torch.device("cuda:0")
112
+ elif not isinstance(device, torch.device):
113
+ device = torch.device(device)
114
+ free_memory, total_memory = torch.cuda.mem_get_info(device)
115
+ elif is_mps_available():
116
+ # MPS doesn't support per-device queries or mem_get_info
117
+ # Use system memory as proxy
118
+ mem = psutil.virtual_memory()
119
+ free_memory = mem.total - mem.used
120
+ total_memory = mem.total
121
+ else:
122
+ return {"error": "No GPU backend available (CUDA/MPS)"}
123
+
124
+ return {
125
+ "free_gb": free_memory / (1024**3),
126
+ "total_gb": total_memory / (1024**3)
127
+ }
128
+ except Exception as e:
129
+ return {"error": f"Failed to get memory info: {str(e)}"}
130
+
131
+
132
+ # Initial VRAM check at module load
133
+ vram_info = get_basic_vram_info(device=None)
134
+ if "error" not in vram_info:
135
+ backend = "MPS" if is_mps_available() else "CUDA"
136
+ print(f"📊 Initial {backend} memory: {vram_info['free_gb']:.2f}GB free / {vram_info['total_gb']:.2f}GB total")
137
+ else:
138
+ print(f"⚠️ Memory check failed: {vram_info['error']} - No available backend!")
139
+
140
+
141
+ def get_vram_usage(device: Optional[torch.device] = None, debug: Optional['Debug'] = None) -> Tuple[float, float, float, float]:
142
+ """
143
+ Get current VRAM usage metrics for monitoring.
144
+ Used for tracking memory consumption during processing.
145
+
146
+ Args:
147
+ device: Optional device to query. If None, uses cuda:0
148
+ debug: Optional debug instance for logging
149
+
150
+ Returns:
151
+ tuple: (allocated_gb, reserved_gb, peak_allocated_gb, peak_reserved_gb)
152
+ Returns (0, 0, 0, 0) if no GPU available
153
+ """
154
+ try:
155
+ if is_cuda_available():
156
+ if device is None:
157
+ device = torch.device("cuda:0")
158
+ elif not isinstance(device, torch.device):
159
+ device = torch.device(device)
160
+ allocated = torch.cuda.memory_allocated(device) / (1024**3)
161
+ reserved = torch.cuda.memory_reserved(device) / (1024**3)
162
+ peak_allocated = torch.cuda.max_memory_allocated(device) / (1024**3)
163
+ peak_reserved = torch.cuda.max_memory_reserved(device) / (1024**3)
164
+ return allocated, reserved, peak_allocated, peak_reserved
165
+ elif is_mps_available():
166
+ # MPS doesn't support per-device queries - uses global memory tracking
167
+ allocated = torch.mps.current_allocated_memory() / (1024**3)
168
+ reserved = torch.mps.driver_allocated_memory() / (1024**3)
169
+ # MPS doesn't track peak separately
170
+ return allocated, reserved, allocated, reserved
171
+ except Exception as e:
172
+ if debug:
173
+ debug.log(f"Failed to get VRAM usage: {e}", level="WARNING", category="memory", force=True)
174
+ return 0.0, 0.0, 0.0, 0.0
175
+
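+ # Monitoring sketch combining the two helpers above (CUDA assumed;
+ # `debug` may be None):
+ #
+ #   info = get_basic_vram_info(device=torch.device("cuda:0"))
+ #   if "error" not in info:
+ #       print(f"{info['free_gb']:.1f}/{info['total_gb']:.1f} GB free")
+ #   allocated, reserved, peak_alloc, peak_res = get_vram_usage(debug=debug)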
176
+
177
+ def get_ram_usage(debug: Optional['Debug'] = None) -> Tuple[float, float, float, float]:
178
+ """
179
+ Get current RAM usage metrics for the current process.
180
+ Provides accurate tracking of process-specific memory consumption.
181
+
182
+ Args:
183
+ debug: Optional debug instance for logging
184
+
185
+ Returns:
186
+ tuple: (process_gb, available_gb, total_gb, used_by_others_gb)
187
+ Returns (0, 0, 0, 0) if psutil not available or on error
188
+ """
189
+ try:
190
+ if not psutil:
191
+ return 0.0, 0.0, 0.0, 0.0
192
+
193
+ # Get current process memory
194
+ process = psutil.Process()
195
+ process_memory = process.memory_info()
196
+ process_gb = process_memory.rss / (1024**3)
197
+
198
+ # Get system memory
199
+ sys_memory = psutil.virtual_memory()
200
+ total_gb = sys_memory.total / (1024**3)
201
+ available_gb = sys_memory.available / (1024**3)
202
+
203
+ # Calculate memory used by other processes
204
+ # This is the CORRECT calculation:
205
+ total_used_gb = total_gb - available_gb # Total memory used by ALL processes
206
+ used_by_others_gb = max(0, total_used_gb - process_gb) # Subtract current process
207
+
208
+ return process_gb, available_gb, total_gb, used_by_others_gb
209
+
210
+ except Exception as e:
211
+ if debug:
212
+ debug.log(f"Failed to get RAM usage: {e}", level="WARNING", category="memory", force=True)
213
+ return 0.0, 0.0, 0.0, 0.0
214
+
215
+
216
+ # Global cache for OS libraries (initialized once)
217
+ _os_memory_lib = None
218
+
219
+
220
+ def clear_memory(debug: Optional['Debug'] = None, deep: bool = False, force: bool = True,
221
+ timer_name: Optional[str] = None) -> None:
222
+ """
223
+ Clear memory caches with two-tier approach for optimal performance.
224
+
225
+ Args:
226
+ debug: Debug instance for logging (optional)
227
+ force: If True, always clear. If False, only clear when <5% free
228
+ deep: If True, perform deep cleanup including GC and OS operations.
229
+ If False (default), only perform minimal GPU cache clearing.
230
+ timer_name: Optional suffix for timer names to make them unique per invocation
231
+
232
+ Two-tier approach:
233
+ - Minimal mode (deep=False): GPU cache operations (~1-5ms)
234
+ Used for frequent calls during batch processing
235
+ - Deep mode (deep=True): Complete cleanup with GC and OS operations (~10-50ms)
236
+ Used at key points like model switches or final cleanup
237
+ """
238
+ global _os_memory_lib
239
+
240
+ # Create unique timer names if suffix provided
241
+ if timer_name:
242
+ main_timer = f"memory_clear_{timer_name}"
243
+ gpu_timer = f"gpu_cache_clear_{timer_name}"
244
+ gc_timer = f"garbage_collection_{timer_name}"
245
+ os_timer = f"os_memory_release_{timer_name}"
246
+ completion_msg = f"clear_memory() completion ({timer_name})"
247
+ else:
248
+ main_timer = "memory_clear"
249
+ gpu_timer = "gpu_cache_clear"
250
+ gc_timer = "garbage_collection"
251
+ os_timer = "os_memory_release"
252
+ completion_msg = "clear_memory() completion"
253
+
254
+ # Start timer for entire operation
255
+ if debug:
256
+ debug.start_timer(main_timer)
257
+
258
+ # Check if we should clear based on memory pressure
259
+ if not force:
260
+ should_clear = False
261
+
262
+ # Use existing function for memory info
263
+ mem_info = get_basic_vram_info(device=None)
264
+
265
+ if "error" not in mem_info and mem_info["total_gb"] > 0:
266
+ # Check VRAM/MPS memory pressure (5% free threshold)
267
+ free_ratio = mem_info["free_gb"] / mem_info["total_gb"]
268
+ if free_ratio < 0.05:
269
+ should_clear = True
270
+ if debug:
271
+ backend = "Unified Memory" if is_mps_available() else "VRAM"
272
+ debug.log(f"{backend} pressure: {mem_info['free_gb']:.2f}GB free of {mem_info['total_gb']:.2f}GB", category="memory")
273
+
274
+ # For non-MPS systems, also check system RAM separately
275
+ if not should_clear and not is_mps_available():
276
+ mem = psutil.virtual_memory()
277
+ if mem.available < mem.total * 0.05:
278
+ should_clear = True
279
+ if debug:
280
+ debug.log(f"RAM pressure: {mem.available/(1024**3):.2f}GB free of {mem.total/(1024**3):.2f}GB", category="memory")
281
+
282
+ if not should_clear:
283
+ # End timer before early return to keep stack clean
284
+ if debug:
285
+ debug.end_timer(main_timer)
286
+ return
287
+
288
+ # Determine cleanup level
289
+ cleanup_mode = "deep" if deep else "minimal"
290
+ if debug:
291
+ debug.log(f"Clearing memory caches ({cleanup_mode})...", category="cleanup")
292
+
293
+ # ===== MINIMAL OPERATIONS (Always performed) =====
294
+ # Step 1: Clear GPU caches - Fast operations (~1-5ms)
295
+ if debug:
296
+ debug.start_timer(gpu_timer)
297
+
298
+ if is_cuda_available():
299
+ torch.cuda.empty_cache()
300
+ torch.cuda.ipc_collect()
301
+ elif is_mps_available():
302
+ torch.mps.empty_cache()
303
+
304
+ if debug:
305
+ debug.end_timer(gpu_timer, "GPU cache clearing")
306
+
307
+ # ===== DEEP OPERATIONS (Only when deep=True) =====
308
+ if deep:
309
+ # Step 2: Deep garbage collection (expensive ~5-20ms)
310
+ if debug:
311
+ debug.start_timer(gc_timer)
312
+
313
+ gc.collect(2)
314
+
315
+ if debug:
316
+ debug.end_timer(gc_timer, "Garbage collection")
317
+
318
+ # Step 3: Return memory to OS (platform-specific, ~5-30ms)
319
+ if debug:
320
+ debug.start_timer(os_timer)
321
+
322
+ try:
323
+ if sys.platform == 'linux':
324
+ # Linux: malloc_trim
325
+ import ctypes # Import only when needed
326
+ if _os_memory_lib is None:
327
+ _os_memory_lib = ctypes.CDLL("libc.so.6")
328
+ _os_memory_lib.malloc_trim(0)
329
+
330
+ elif sys.platform == 'win32':
331
+ # Windows: Trim working set
332
+ import ctypes # Import only when needed
333
+ if _os_memory_lib is None:
334
+ _os_memory_lib = ctypes.windll.kernel32
335
+ handle = _os_memory_lib.GetCurrentProcess()
336
+ _os_memory_lib.SetProcessWorkingSetSize(handle, -1, -1)
337
+
338
+ elif is_mps_available():
339
+ # macOS with MPS
340
+ import ctypes # Import only when needed
341
+ import ctypes.util
342
+ if _os_memory_lib is None:
343
+ libc_path = ctypes.util.find_library('c')
344
+ if libc_path:
345
+ _os_memory_lib = ctypes.CDLL(libc_path)
346
+
347
+ if _os_memory_lib:
348
+ _os_memory_lib.sync()
349
+ except Exception as e:
350
+ if debug:
351
+ debug.log(f"Failed to perform OS memory operations: {e}", level="WARNING", category="memory", force=True)
352
+
353
+ if debug:
354
+ debug.end_timer(os_timer, "OS memory release")
355
+
356
+ # End overall timer
357
+ if debug:
358
+ debug.end_timer(main_timer, completion_msg)
359
+
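+ # Two-tier usage sketch: cheap call inside the batch loop, deep call at
+ # model-switch boundaries (`debug` may be None):
+ #
+ #   clear_memory(debug=debug)                    # minimal: GPU cache only, ~1-5ms
+ #   clear_memory(debug=debug, deep=True, force=True,
+ #                timer_name="model_switch")      # adds GC + OS release, ~10-50ms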
360
+
361
+ def retry_on_oom(func, *args, debug=None, operation_name="operation", **kwargs):
362
+ """
363
+ Execute function with single OOM retry after memory cleanup.
364
+
365
+ Args:
366
+ func: Callable to execute
367
+ *args: Positional arguments for func
368
+ debug: Debug instance for logging (optional)
369
+ operation_name: Name for logging
370
+ **kwargs: Keyword arguments for func
371
+
372
+ Returns:
373
+ Result of func(*args, **kwargs)
374
+ """
375
+ try:
376
+ return func(*args, **kwargs)
377
+ except (torch.cuda.OutOfMemoryError, RuntimeError) as e:
378
+ # Only handle OOM errors
379
+ if not any(x in str(e).lower() for x in ["out of memory", "allocation on device"]):
380
+ raise
381
+
382
+ if debug:
383
+ debug.log(f"OOM during {operation_name}: {e}", level="WARNING", category="memory", force=True)
384
+ debug.log("Clearing memory and retrying", category="info", force=True)
385
+
386
+ # Clear memory
387
+ clear_memory(debug=debug, deep=True, force=True, timer_name=operation_name)
388
+ # Let memory settle
389
+ time.sleep(0.5)
390
+ if debug: debug.log_memory_state("After memory clearing", show_tensors=False, detailed_tensors=False)
391
+
392
+ # Single retry
393
+ try:
394
+ result = func(*args, **kwargs)
395
+ if debug:
396
+ debug.log(f"Retry successful for {operation_name}", category="success", force=True)
397
+ return result
398
+ except Exception as retry_e:
399
+ if debug:
400
+ debug.log(f"Retry failed for {operation_name}: {retry_e}", level="ERROR", category="memory", force=True)
401
+ raise
402
+
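+ # Wrapping sketch (hypothetical `vae.decode` callable): the call is retried
+ # once after a deep memory clear if it hits an OOM error:
+ #
+ #   samples = retry_on_oom(vae.decode, latents, debug=debug,
+ #                          operation_name="vae_decode")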
403
+
404
+ def reset_vram_peak(device: Optional[torch.device] = None, debug: Optional['Debug'] = None) -> None:
405
+ """
406
+ Reset VRAM peak memory statistics for fresh tracking.
407
+
408
+ Args:
409
+ device: Optional device to reset stats for. If None, uses cuda:0
410
+ debug: Optional debug instance for logging
411
+ """
412
+ if debug and debug.enabled:
413
+ debug.log("Resetting VRAM peak memory statistics", category="memory")
414
+ try:
415
+ if is_cuda_available():
416
+ if device is None:
417
+ device = torch.device("cuda:0")
418
+ elif not isinstance(device, torch.device):
419
+ device = torch.device(device)
420
+ torch.cuda.reset_peak_memory_stats(device)
421
+ # Note: MPS doesn't support peak memory reset - no action needed
422
+ except Exception as e:
423
+ if debug and debug.enabled:
424
+ debug.log(f"Failed to reset peak memory stats: {e}", level="WARNING", category="memory", force=True)
425
+
426
+
427
+ def clear_rope_lru_caches(model: Optional[torch.nn.Module], debug: Optional['Debug'] = None) -> int:
428
+ """
429
+ Clear ALL LRU caches from RoPE modules.
430
+
431
+ Args:
432
+ model: PyTorch model to clear caches from
433
+ debug: Optional debug instance for logging
434
+
435
+ Returns:
436
+ Number of caches cleared
437
+ """
438
+ if model is None:
439
+ return 0
440
+
441
+ cleared_count = 0
442
+ try:
443
+ for name, module in model.named_modules():
444
+ if hasattr(module, 'get_axial_freqs') and hasattr(module.get_axial_freqs, 'cache_clear'):
445
+ try:
446
+ module.get_axial_freqs.cache_clear()
447
+ cleared_count += 1
448
+ except Exception as e:
449
+ if debug:
450
+ debug.log(f"Failed to clear RoPE LRU cache for module {name}: {e}", level="WARNING", category="memory", force=True)
451
+ except (AttributeError, RuntimeError) as e:
452
+ if debug:
453
+ debug.log(f"Failed to iterate model modules for RoPE LRU cache clearing: {e}", level="WARNING", category="memory", force=True)
454
+
455
+ return cleared_count
456
+
457
+
458
+ def release_tensor_memory(tensor: Optional[torch.Tensor]) -> None:
459
+ """Release tensor memory from any device (CPU/CUDA/MPS)"""
460
+ if tensor is not None and torch.is_tensor(tensor):
461
+ # Release storage for all devices (CPU, CUDA, MPS)
462
+ if tensor.numel() > 0:
463
+ tensor.data.set_()
464
+ tensor.grad = None
465
+
466
+
467
+ def release_tensor_collection(collection: Any, recursive: bool = True) -> None:
468
+ """
469
+ Release GPU memory from tensors in any collection (list, tuple, dict, or single tensor).
470
+
471
+ Args:
472
+ collection: Tensor, list, tuple, dict, or nested structure to release
473
+ recursive: If True, handle nested structures recursively
474
+
475
+ Examples:
476
+ release_tensor_collection(tensor) # Single tensor
477
+ release_tensor_collection([tensor1, tensor2]) # List of tensors
478
+ release_tensor_collection([[t1, t2], [t3, t4]]) # Nested lists
479
+ release_tensor_collection({'a': tensor}) # Dict values
480
+ """
481
+ if collection is None:
482
+ return
483
+
484
+ if torch.is_tensor(collection):
485
+ release_tensor_memory(collection)
486
+ elif isinstance(collection, dict):
487
+ for value in collection.values():
488
+ if recursive:
489
+ release_tensor_collection(value, recursive=True)
490
+ elif torch.is_tensor(value):
491
+ release_tensor_memory(value)
492
+ elif isinstance(collection, (list, tuple)):
493
+ for item in collection:
494
+ if recursive:
495
+ release_tensor_collection(item, recursive=True)
496
+ elif torch.is_tensor(item):
497
+ release_tensor_memory(item)
498
+
499
+
500
+ def release_text_embeddings(*embeddings: torch.Tensor, debug: Optional['Debug'] = None, names: Optional[List[str]] = None) -> None:
501
+ """
502
+ Release memory for text embeddings
503
+
504
+ Args:
505
+ *embeddings: Variable number of embedding tensors to release
506
+ debug: Optional debug instance for logging
507
+ names: Optional list of names for logging
508
+ """
509
+ for i, embedding in enumerate(embeddings):
510
+ if embedding is not None:
511
+ release_tensor_memory(embedding)
512
+ if debug and names and i < len(names):
513
+ debug.log(f"Cleaned up {names[i]}", category="cleanup")
514
+
515
+
516
+ def cleanup_text_embeddings(ctx: Dict[str, Any], debug: Optional['Debug'] = None) -> None:
517
+ """
518
+ Clean up text embeddings from a context dictionary.
519
+ Extracts embeddings, releases memory, and clears the context entry.
520
+
521
+ Args:
522
+ ctx: Context dictionary potentially containing 'text_embeds'
523
+ debug: Optional debug instance for logging
524
+ """
525
+ if not ctx or not ctx.get('text_embeds'):
526
+ return
527
+
528
+ embeddings = []
529
+ names = []
530
+ for key, embeds_list in ctx['text_embeds'].items():
531
+ if embeds_list:
532
+ embeddings.extend(embeds_list)
533
+ names.append(key)
534
+
535
+ if embeddings:
536
+ release_text_embeddings(*embeddings, debug=debug, names=names)
537
+
538
+ if debug:
539
+ debug.log(f"Cleaned up text embeddings: {', '.join(names)}", category="cleanup")
540
+
541
+ ctx['text_embeds'] = None
542
+
543
+
544
+ def release_model_memory(model: Optional[torch.nn.Module], debug: Optional['Debug'] = None) -> None:
545
+ """
546
+ Release all GPU/MPS memory from model in-place without CPU transfer.
547
+
548
+ Args:
549
+ model: PyTorch model to release memory from
550
+ debug: Optional debug instance for logging
551
+ """
552
+ if model is None:
553
+ return
554
+
555
+ try:
556
+ # Clear gradients first
557
+ model.zero_grad(set_to_none=True)
558
+
559
+ # Release GPU memory directly without CPU transfer
560
+ released_params = 0
561
+ released_buffers = 0
562
+
563
+ for param in model.parameters():
564
+ if param.is_cuda or param.is_mps:
565
+ if param.numel() > 0:
566
+ param.data.set_()
567
+ released_params += 1
568
+ param.grad = None
569
+
570
+ for buffer in model.buffers():
571
+ if buffer.is_cuda or buffer.is_mps:
572
+ if buffer.numel() > 0:
573
+ buffer.data.set_()
574
+ released_buffers += 1
575
+
576
+ if debug and (released_params > 0 or released_buffers > 0):
577
+ debug.log(f"Released memory from {released_params} params and {released_buffers} buffers", category="success")
578
+
579
+ except (AttributeError, RuntimeError) as e:
580
+ if debug:
581
+ debug.log(f"Failed to release model memory: {e}", level="WARNING", category="memory", force=True)
582
+
583
+
584
+ def manage_tensor(
585
+ tensor: torch.Tensor,
586
+ target_device: torch.device,
587
+ tensor_name: str = "tensor",
588
+ dtype: Optional[torch.dtype] = None,
589
+ non_blocking: bool = False,
590
+ debug: Optional['Debug'] = None,
591
+ reason: Optional[str] = None,
592
+ indent_level: int = 0
593
+ ) -> torch.Tensor:
594
+ """
595
+ Unified tensor management for device movement and dtype conversion.
596
+
597
+ Handles both device transfers (CPU ↔ GPU) and dtype conversions (e.g., float16 → bfloat16)
598
+ with intelligent early-exit optimization and comprehensive logging.
599
+
600
+ Args:
601
+ tensor: Tensor to manage
602
+ target_device: Target device (torch.device object)
603
+ tensor_name: Descriptive name for logging (e.g., "latent", "sample", "alpha_channel")
604
+ dtype: Optional target dtype to cast to (if None, keeps original dtype)
605
+ non_blocking: Whether to use non-blocking transfer
606
+ debug: Debug instance for logging
607
+ reason: Optional reason for the operation (e.g., "inference", "offload", "dtype alignment")
608
+ indent_level: Indentation level for debug logging (0=no indent, 1=2 spaces, etc.)
609
+
610
+ Returns:
611
+ Tensor on target device with optional dtype conversion
612
+
613
+ Note:
614
+ - Skips operation if tensor already has target device and dtype (zero-copy)
615
+ - Uses PyTorch's optimized .to() for efficient device/dtype handling
616
+ - Logs all operations consistently for tracking and debugging
617
+ """
618
+ if tensor is None:
619
+ return tensor
620
+
621
+ # Get current state
622
+ current_device = tensor.device
623
+ current_dtype = tensor.dtype
624
+ target_dtype = dtype if dtype is not None else current_dtype
625
+
626
+ # Check if movement is actually needed
627
+ needs_device_move = _device_str(current_device) != _device_str(target_device)
628
+ needs_dtype_change = dtype is not None and current_dtype != target_dtype
629
+
630
+ if not needs_device_move and not needs_dtype_change:
631
+ # Already on target device and dtype - skip
632
+ return tensor
633
+
634
+ # Determine reason for movement
635
+ if reason is None:
636
+ if needs_device_move and needs_dtype_change:
637
+ reason = "device and dtype conversion"
638
+ elif needs_device_move:
639
+ reason = "device movement"
640
+ else:
641
+ reason = "dtype conversion"
642
+
643
+ # Log the movement
644
+ if debug:
645
+ current_device_str = _device_str(current_device)
646
+ target_device_str = _device_str(target_device)
647
+
648
+ dtype_info = ""
649
+ if needs_dtype_change:
650
+ dtype_info = f", {current_dtype} → {target_dtype}"
651
+
652
+ debug.log(
653
+ f"Moving {tensor_name} from {current_device_str} to {target_device_str}{dtype_info} ({reason})",
654
+ category="general",
655
+ indent_level=indent_level
656
+ )
657
+
658
+ # Perform the operation based on what needs to change
659
+ if needs_device_move and needs_dtype_change:
660
+ # Both device and dtype need to change
661
+ return tensor.to(target_device, dtype=target_dtype, non_blocking=non_blocking)
662
+ elif needs_device_move:
663
+ # Only device needs to change
664
+ return tensor.to(target_device, non_blocking=non_blocking)
665
+ else:
666
+ # Only dtype needs to change
667
+ return tensor.to(dtype=target_dtype)
668
+
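+ # Example: move a latent to GPU and align dtype in one call (tensor and
+ # device names are illustrative):
+ #
+ #   latent = manage_tensor(tensor=latent, target_device=torch.device("cuda:0"),
+ #                          tensor_name="latent", dtype=torch.bfloat16,
+ #                          non_blocking=True, debug=debug, reason="inference")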
669
+
670
+ def manage_model_device(model: torch.nn.Module, target_device: torch.device, model_name: str,
671
+ debug: Optional['Debug'] = None, reason: Optional[str] = None,
672
+ runner: Optional[Any] = None) -> bool:
673
+ """
674
+ Move model to target device with optimizations.
675
+ Handles BlockSwap-enabled models transparently.
676
+
677
+ Args:
678
+ model: The model to move
679
+ target_device: Target device (torch.device object, e.g., torch.device('cuda:0'))
680
+ model_name: Name for logging (e.g., "VAE", "DiT")
681
+ debug: Debug instance for logging
682
+ reason: Optional custom reason for the movement
683
+ runner: Optional runner instance for BlockSwap detection
684
+
685
+ Returns:
686
+ bool: True if model was moved, False if already on target device
687
+ """
688
+ if model is None:
689
+ return False
690
+
691
+ # Check if this is a BlockSwap-enabled DiT model
692
+ is_blockswap_model = False
693
+ actual_model = model
694
+ if runner and model_name == "DiT":
695
+ # Import here to avoid circular dependency
696
+ from .blockswap import is_blockswap_enabled
697
+ # Check if BlockSwap config exists and is enabled
698
+ has_blockswap_config = (
699
+ hasattr(runner, '_dit_block_swap_config') and
700
+ is_blockswap_enabled(runner._dit_block_swap_config)
701
+ )
702
+
703
+ if has_blockswap_config:
704
+ is_blockswap_model = True
705
+ # Get the actual model (handle CompatibleDiT wrapper)
706
+ if hasattr(model, "dit_model"):
707
+ actual_model = model.dit_model
708
+
709
+ # Get current device
710
+ try:
711
+ current_device = next(model.parameters()).device
712
+ except StopIteration:
713
+ return False
714
+
715
+ # Extract device type for comparison (both are torch.device objects)
716
+ target_type = target_device.type
717
+ current_device_upper = _device_str(current_device)
718
+ target_device_upper = _device_str(target_device)
719
+
720
+ # Compare normalized device types
721
+ if current_device_upper == target_device_upper and not is_blockswap_model:
722
+ # Already on target device type, no movement needed
723
+ if debug:
724
+ debug.log(f"{model_name} already on {current_device_upper}, skipping movement", category="general")
725
+ return False
726
+
727
+ # Handle BlockSwap models specially
728
+ if is_blockswap_model:
729
+ return _handle_blockswap_model_movement(
730
+ runner, actual_model, current_device, target_device, target_type,
731
+ model_name, debug, reason
732
+ )
733
+
734
+ # Standard model movement (non-BlockSwap)
735
+ return _standard_model_movement(
736
+ model, current_device, target_device, target_type, model_name,
737
+ debug, reason
738
+ )
739
+
740
+
741
+ def _handle_blockswap_model_movement(runner: Any, model: torch.nn.Module,
742
+ current_device: torch.device, target_device: torch.device,
743
+ target_type: str, model_name: str,
744
+ debug: Optional['Debug'] = None, reason: Optional[str] = None) -> bool:
745
+ """
746
+ Handle device movement for BlockSwap-enabled models.
747
+
748
+ Args:
749
+ runner: Runner instance with BlockSwap configuration
750
+ model: Model to move (actual unwrapped model)
751
+ current_device: Current device of the model
752
+ target_device: Target device (torch.device object)
753
+ target_type: Target device type (cpu/cuda/mps)
754
+ model_name: Model name for logging
755
+ debug: Debug instance
756
+ reason: Movement reason
757
+
758
+ Returns:
759
+ bool: True if model was moved
760
+ """
761
+ # Import here to avoid circular dependency
762
+ from .blockswap import set_blockswap_bypass
763
+
764
+ if target_type == "cpu":
765
+ # Moving to offload device (typically CPU)
766
+ # Check if any parameter is on GPU (for accurate logging)
767
+ actual_source_device = None
768
+ for param in model.parameters():
769
+ if param.device.type in ['cuda', 'mps']:
770
+ actual_source_device = param.device
771
+ break
772
+
773
+ source_device_desc = _device_str(actual_source_device) if actual_source_device else _device_str(target_device)
774
+
775
+ if debug:
776
+ debug.log(f"Moving {model_name} from {source_device_desc} to {_device_str(target_device)} ({reason or 'model caching'})", category="general")
777
+
778
+ # Enable bypass to allow movement
779
+ set_blockswap_bypass(runner=runner, bypass=True, debug=debug)
780
+
781
+ # Start timer
782
+ timer_name = f"{model_name.lower()}_to_{target_type}"
783
+ if debug:
784
+ debug.start_timer(timer_name)
785
+
786
+ # Move entire model to target offload device
787
+ model.to(target_device)
788
+ model.zero_grad(set_to_none=True)
789
+
790
+ if debug:
791
+ debug.end_timer(timer_name, f"BlockSwap model offloaded to {_device_str(target_device)}")
792
+
793
+ return True
794
+
795
+ else:
796
+ # Moving to GPU (reload)
797
+ # Check if we're in bypass mode (coming from offload)
798
+ if not getattr(model, "_blockswap_bypass_protection", False):
799
+ # Not in bypass mode, blocks are already configured
800
+ if debug:
801
+ debug.log(f"{model_name} with BlockSwap active - blocks already distributed across devices, skipping movement", category="general")
802
+ return False
803
+
804
+ # Get actual current device for accurate logging
805
+ actual_current_device = None
806
+ for param in model.parameters():
807
+ if param.device.type != 'meta':
808
+ actual_current_device = param.device
809
+ break
810
+
811
+ current_device_desc = _device_str(actual_current_device) if actual_current_device else "OFFLOAD"
812
+
813
+ if debug:
814
+ debug.log(f"Moving {model_name} from {current_device_desc} to {_device_str(target_device)} ({reason or 'inference requirement'})", category="general")
815
+
816
+ timer_name = f"{model_name.lower()}_to_gpu"
817
+ if debug:
818
+ debug.start_timer(timer_name)
819
+
820
+ # Restore blocks to their configured devices
821
+ if hasattr(model, "blocks") and hasattr(model, "blocks_to_swap"):
822
+ # Use configured offload_device from BlockSwap config
823
+ offload_device = model._block_swap_config.get("offload_device")
824
+ if not offload_device:
825
+ raise ValueError("BlockSwap config missing offload_device")
826
+
827
+ # Move blocks according to BlockSwap configuration
828
+ for b, block in enumerate(model.blocks):
829
+ if b > model.blocks_to_swap:
830
+ # This block should be on GPU
831
+ block.to(target_device)
832
+ else:
833
+ # This block stays on offload device (will be swapped during forward)
834
+ block.to(offload_device)
835
+
836
+ # Handle I/O components
837
+ if not model._block_swap_config.get("swap_io_components", False):
838
+ # I/O components should be on GPU if not offloaded
839
+ for name, module in model.named_children():
840
+ if name != "blocks":
841
+ module.to(target_device)
842
+ else:
843
+ # I/O components stay on offload device
844
+ for name, module in model.named_children():
845
+ if name != "blocks":
846
+ module.to(offload_device)
847
+
848
+ if debug:
849
+ # Get actual configuration from runner
850
+ if hasattr(model, '_block_swap_config'):
851
+ blocks_on_gpu = model._block_swap_config.get('total_blocks', 32) - model._block_swap_config.get('blocks_swapped', 16)
852
+ total_blocks = model._block_swap_config.get('total_blocks', 32)
853
+ main_device = model._block_swap_config.get('main_device', 'GPU')
854
+ debug.log(f"BlockSwap blocks restored to configured devices ({blocks_on_gpu}/{total_blocks} blocks on {_device_str(main_device)})", category="success")
855
+ else:
856
+ debug.log("BlockSwap blocks restored to configured devices", category="success")
857
+
858
+
859
+ # Reactivate BlockSwap now that blocks are restored to their configured devices
860
+ runner._blockswap_active = True
861
+
862
+ # Disable bypass, re-enable protection
863
+ set_blockswap_bypass(runner=runner, bypass=False, debug=debug)
864
+
865
+ if debug:
866
+ debug.end_timer(timer_name, "BlockSwap model restored")
867
+
868
+ return True
869
+
870
+
871
+ def _standard_model_movement(model: torch.nn.Module, current_device: torch.device,
+                              target_device: torch.device, target_type: str, model_name: str,
+                              debug: Optional['Debug'] = None, reason: Optional[str] = None) -> bool:
+     """
+     Handle standard (non-BlockSwap) model movement.
+
+     Args:
+         model: Model to move
+         current_device: Current device of the model
+         target_device: Target device (torch.device object)
+         target_type: Target device type
+         model_name: Model name for logging
+         debug: Debug instance
+         reason: Movement reason
+
+     Returns:
+         bool: True if the model was moved
+     """
+     # Check if model is on meta device - can't move meta tensors
+     if current_device.type == 'meta':
+         if debug:
+             debug.log(f"{model_name} is on meta device - skipping movement (will materialize when needed)",
+                       category=model_name.lower())
+         return False
+
+     # Determine reason for movement
+     reason = reason or "inference requirement"
+
+     # Log the movement with full device strings
+     if debug:
+         current_device_str = _device_str(current_device)
+         target_device_str = _device_str(target_device)
+         debug.log(f"Moving {model_name} from {current_device_str} to {target_device_str} ({reason})", category="general")
+
+     # Start timer based on direction
+     timer_name = f"{model_name.lower()}_to_{'gpu' if target_type != 'cpu' else 'cpu'}"
+     if debug:
+         debug.start_timer(timer_name)
+
+     # Move model and clear gradients
+     model.to(target_device)
+     model.zero_grad(set_to_none=True)
+
+     # Clear VAE memory buffers when moving to CPU
+     if target_type == 'cpu' and model_name == "VAE":
+         cleared_count = 0
+         for module in model.modules():
+             if hasattr(module, 'memory') and module.memory is not None:
+                 if torch.is_tensor(module.memory) and (module.memory.is_cuda or module.memory.is_mps):
+                     module.memory = None
+                     cleared_count += 1
+         if cleared_count > 0 and debug:
+             debug.log(f"Cleared {cleared_count} VAE memory buffers", category="success")
+
+     # End timer
+     if debug:
+         debug.end_timer(timer_name, f"{model_name} moved to {_device_str(target_device)}")
+
+     return True
+
+
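Two details in the function above are easy to miss: meta-device parameters have no storage and must not be moved with `.to()`, and `zero_grad(set_to_none=True)` frees gradient storage outright rather than keeping zeroed tensors alive on the new device. A self-contained sketch of the same guard-then-move pattern, using only PyTorch:

import torch

def safe_move(model: torch.nn.Module, target: torch.device) -> bool:
    """Move a model unless its parameters live on the meta device."""
    try:
        current = next(model.parameters()).device
    except StopIteration:
        return False  # no parameters, nothing to move
    if current.type == 'meta':
        return False  # meta tensors are shape-only; they cannot be moved
    model.to(target)
    model.zero_grad(set_to_none=True)  # release grad storage on the old device
    return True

print(safe_move(torch.nn.Linear(4, 4), torch.device('cpu')))  # True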
932
+ def clear_runtime_caches(runner: Any, debug: Optional['Debug'] = None) -> int:
+     """
+     Clear all runtime caches and temporary attributes.
+     """
+     if not runner:
+         return 0
+
+     if debug:
+         debug.start_timer("runtime_cache_clear")
+
+     cleaned_items = 0
+
+     # 1. Clear main runner cache
+     if hasattr(runner, 'cache') and hasattr(runner.cache, 'cache'):
+         if debug:
+             debug.start_timer("runner_cache_clear")
+
+         cache_entries = len(runner.cache.cache)
+
+         # Properly release tensor memory and delete as we go
+         for key in list(runner.cache.cache.keys()):
+             value = runner.cache.cache[key]
+             if torch.is_tensor(value):
+                 release_tensor_memory(value)
+             elif isinstance(value, (list, tuple)):
+                 for item in value:
+                     if torch.is_tensor(item):
+                         release_tensor_memory(item)
+             # Delete immediately to release the reference
+             del runner.cache.cache[key]
+
+         # Final clear for safety
+         runner.cache.cache.clear()
+         cleaned_items += cache_entries
+
+         if debug:
+             debug.end_timer("runner_cache_clear", "Clearing main runner cache entries")
+
+             if cache_entries > 0:
+                 debug.log(f"Cleared {cache_entries} runtime cache entries", category="success")
+
+     # 2. Clear RoPE caches
+     if hasattr(runner, 'dit'):
+         if debug:
+             debug.start_timer("rope_cache_clear")
+
+         model = runner.dit
+         if hasattr(model, 'dit_model'):  # Handle wrapper
+             model = model.dit_model
+
+         rope_cleared = clear_rope_lru_caches(model=model, debug=debug)
+         cleaned_items += rope_cleared
+         if debug:
+             debug.end_timer("rope_cache_clear", "Clearing RoPE LRU caches")
+
+             if rope_cleared > 0:
+                 debug.log(f"Cleared {rope_cleared} RoPE LRU caches", category="success")
+
+     # 3. Clear temporary attributes
+     temp_attrs = ['_temp_cache', '_block_cache', '_swap_cache', '_generation_cache',
+                   '_rope_cache', '_intermediate_cache', '_backward_cache']
+
+     for obj in [runner, getattr(runner, 'dit', None), getattr(runner, 'vae', None)]:
+         if obj is None:
+             continue
+
+         actual_obj = obj.dit_model if hasattr(obj, 'dit_model') else obj
+
+         for attr in temp_attrs:
+             if hasattr(actual_obj, attr):
+                 delattr(actual_obj, attr)
+                 cleaned_items += 1
+
+     if debug:
+         debug.end_timer("runtime_cache_clear", "clear_runtime_caches() completion")
+
+     return cleaned_items
+
+
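The per-key `del` inside the loop above is deliberate: clearing the dict only at the end would keep every tensor reachable until the final `clear()`, delaying reuse by the allocator. A small sketch of the same release-as-you-iterate pattern; `release_tensor` is a hypothetical stand-in for this module's `release_tensor_memory`:

import torch

def release_tensor(t: torch.Tensor) -> None:
    # Hypothetical stand-in for release_tensor_memory():
    # detach in place so autograd no longer keeps the storage alive
    t.detach_()

def drain_cache(cache: dict) -> int:
    """Release tensor entries one key at a time so memory frees incrementally."""
    entries = len(cache)
    for key in list(cache.keys()):      # list() allows mutation while iterating
        value = cache.pop(key)          # popping drops the dict's reference now,
        if torch.is_tensor(value):      # not at a final clear() after the loop
            release_tensor(value)
        elif isinstance(value, (list, tuple)):
            for item in value:
                if torch.is_tensor(item):
                    release_tensor(item)
    return entries

print(drain_cache({'a': torch.zeros(4), 'b': [torch.ones(2)]}))  # -> 2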
1011
+ def cleanup_dit(runner: Any, debug: Optional['Debug'] = None, cache_model: bool = False) -> None:
+     """
+     Cleanup DiT model and BlockSwap state after the upscaling phase.
+     Called at the end of upscale_all_batches when DiT is no longer needed.
+
+     Args:
+         runner: Runner instance containing the DiT model
+         debug: Debug instance for logging
+         cache_model: If True, move DiT to offload_device; if False, delete completely
+     """
+     if not runner or not hasattr(runner, 'dit'):
+         return
+
+     if debug:
+         debug.log("Cleaning up DiT components", category="cleanup")
+
+     # 1. Clear DiT-specific runtime caches first
+     if hasattr(runner, 'dit'):
+         model = runner.dit
+         if hasattr(model, 'dit_model'):  # Handle wrapper
+             model = model.dit_model
+
+         # Clear RoPE caches
+         rope_cleared = clear_rope_lru_caches(model=model, debug=debug)
+         if rope_cleared > 0 and debug:
+             debug.log(f"Cleared {rope_cleared} RoPE LRU caches", category="success")
+
+         # Clear DiT temporary attributes
+         temp_attrs = ['_temp_cache', '_block_cache', '_swap_cache', '_generation_cache',
+                       '_rope_cache', '_intermediate_cache', '_backward_cache']
+
+         actual_obj = model.dit_model if hasattr(model, 'dit_model') else model
+         for attr in temp_attrs:
+             if hasattr(actual_obj, attr):
+                 delattr(actual_obj, attr)
+
+     # 2. Handle model offloading (for caching or before deletion)
+     try:
+         param_device = next(runner.dit.parameters()).device
+
+         # Move model off GPU if needed
+         if param_device.type not in ['meta', 'cpu']:
+             # MPS: skip CPU movement before deletion (unified memory, just causes sync)
+             if param_device.type == 'mps' and not cache_model:
+                 if debug:
+                     debug.log("DiT on MPS - skipping CPU movement before deletion", category="cleanup")
+             else:
+                 offload_target = getattr(runner, '_dit_offload_device', None)
+                 if offload_target is None or offload_target == 'none':
+                     offload_target = torch.device('cpu')
+                 reason = "model caching" if cache_model else "releasing GPU memory"
+                 manage_model_device(model=runner.dit, target_device=offload_target, model_name="DiT",
+                                     debug=debug, reason=reason, runner=runner)
+         elif param_device.type == 'meta' and debug:
+             debug.log("DiT on meta device - keeping structure for cache", category="cleanup")
+     except StopIteration:
+         pass
+
+     # 3. Clean BlockSwap after model movement
+     if hasattr(runner, "_blockswap_active") and runner._blockswap_active:
+         # Import here to avoid a circular dependency
+         from .blockswap import cleanup_blockswap
+         cleanup_blockswap(runner=runner, keep_state_for_cache=cache_model)
+
+     # 4. Complete cleanup if not caching
+     if not cache_model:
+         release_model_memory(model=runner.dit, debug=debug)
+         runner.dit = None
+         if debug:
+             debug.log("DiT model deleted", category="cleanup")
+
+         # Clear DiT config attributes - not needed when the model is not cached (it will be recreated)
+         if hasattr(runner, '_dit_compile_args'):
+             delattr(runner, '_dit_compile_args')
+         if hasattr(runner, '_dit_block_swap_config'):
+             delattr(runner, '_dit_block_swap_config')
+         if hasattr(runner, '_dit_attention_mode'):
+             delattr(runner, '_dit_attention_mode')
+
+     # 5. Clear DiT temporary attributes (should already be cleared in materialize_model)
+     runner._dit_checkpoint = None
+     runner._dit_dtype_override = None
+
+     # 6. Clear DiT-related components and temporary attributes
+     runner.sampler = None
+     runner.sampling_timesteps = None
+     runner.schedule = None
+
+
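In practice the `cache_model` flag is what separates "keep the DiT warm on the offload device" from "free everything". A hedged sketch of the two call sites implied by the docstring above; `runner` and `debug` are assumed to come from the surrounding pipeline:

# End of upscale_all_batches, user wants fast re-runs: keep DiT on the offload device
cleanup_dit(runner=runner, debug=debug, cache_model=True)

# One-shot run: reclaim everything (also tears down BlockSwap state)
cleanup_dit(runner=runner, debug=debug, cache_model=False)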
1100
+ def cleanup_vae(runner: Any, debug: Optional['Debug'] = None, cache_model: bool = False) -> None:
+     """
+     Cleanup VAE model after the decoding phase.
+     Called at the end of decode_all_batches when the VAE is no longer needed.
+
+     Args:
+         runner: Runner instance containing the VAE model
+         debug: Debug instance for logging
+         cache_model: If True, move VAE to offload_device; if False, delete completely
+     """
+     if not runner or not hasattr(runner, 'vae'):
+         return
+
+     if debug:
+         debug.log("Cleaning up VAE components", category="cleanup")
+
+     # 1. Clear VAE-specific temporary attributes
+     if hasattr(runner, 'vae'):
+         temp_attrs = ['_temp_cache', '_block_cache', '_swap_cache', '_generation_cache',
+                       '_rope_cache', '_intermediate_cache', '_backward_cache']
+
+         for attr in temp_attrs:
+             if hasattr(runner.vae, attr):
+                 delattr(runner.vae, attr)
+
+     # 2. Handle model offloading (for caching or before deletion)
+     try:
+         param_device = next(runner.vae.parameters()).device
+
+         # Move model off GPU if needed
+         if param_device.type not in ['meta', 'cpu']:
+             # MPS: skip CPU movement before deletion (unified memory, just causes sync)
+             if param_device.type == 'mps' and not cache_model:
+                 if debug:
+                     debug.log("VAE on MPS - skipping CPU movement before deletion", category="cleanup")
+             else:
+                 offload_target = getattr(runner, '_vae_offload_device', None)
+                 if offload_target is None or offload_target == 'none':
+                     offload_target = torch.device('cpu')
+                 reason = "model caching" if cache_model else "releasing GPU memory"
+                 manage_model_device(model=runner.vae, target_device=offload_target, model_name="VAE",
+                                     debug=debug, reason=reason, runner=runner)
+         elif param_device.type == 'meta' and debug:
+             debug.log("VAE on meta device - keeping structure for cache", category="cleanup")
+     except StopIteration:
+         pass
+
+     # 3. Complete cleanup if not caching
+     if not cache_model:
+         release_model_memory(model=runner.vae, debug=debug)
+         runner.vae = None
+         if debug:
+             debug.log("VAE model deleted", category="cleanup")
+
+         # Clear VAE config attributes - not needed when the model is not cached (it will be recreated)
+         if hasattr(runner, '_vae_compile_args'):
+             delattr(runner, '_vae_compile_args')
+         if hasattr(runner, '_vae_tiling_config'):
+             delattr(runner, '_vae_tiling_config')
+
+     # 4. Clear VAE temporary attributes (should already be cleared in materialize_model)
+     runner._vae_checkpoint = None
+     runner._vae_dtype_override = None
+
+
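Both cleanup paths probe the model's device with `next(model.parameters())`, and the `except StopIteration: pass` is what keeps them safe on parameterless modules (for example, a bare container left behind after weights were released). A small self-contained illustration of that guard:

import torch
from typing import Optional

def first_param_device(model: torch.nn.Module) -> Optional[torch.device]:
    try:
        return next(model.parameters()).device
    except StopIteration:
        # Modules with no registered parameters raise StopIteration here,
        # so report "unknown" instead of crashing the cleanup path
        return None

print(first_param_device(torch.nn.Linear(2, 2)))   # cpu
print(first_param_device(torch.nn.Sequential()))   # None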
1165
+ def complete_cleanup(runner: Any, debug: Optional['Debug'] = None, dit_cache: bool = False, vae_cache: bool = False) -> None:
+     """
+     Complete cleanup of the runner and remaining components, with independent model caching support.
+     This is a lightweight cleanup for the final stage, as model-specific cleanup
+     happens in the respective phases (cleanup_dit, cleanup_vae).
+
+     Args:
+         runner: Runner instance to clean up
+         debug: Debug instance for logging
+         dit_cache: If True, preserve the DiT model on offload_device for future runs
+         vae_cache: If True, preserve the VAE model on offload_device for future runs
+
+     Behavior:
+         - Can cache DiT and VAE independently for flexible memory management
+         - Preserves _dit_model_name and _vae_model_name when either model is cached, for change detection
+         - Clears all temporary attributes and runtime caches
+         - Performs deep memory cleanup only when both models are fully released
+
+     Note:
+         Model name tracking (_dit_model_name, _vae_model_name) is only cleared if neither
+         model is cached, enabling proper model change detection on subsequent runs.
+     """
+     if not runner:
+         return
+
+     cleanup_type = "partial cleanup" if (dit_cache or vae_cache) else "full cleanup"
+     if debug:
+         debug.log(f"Starting {cleanup_type}", category="cleanup")
+
+     # 1. Cleanup any remaining models if they still exist
+     # (This handles cases where phases were skipped or errored)
+     if hasattr(runner, 'dit') and runner.dit is not None:
+         cleanup_dit(runner=runner, debug=debug, cache_model=dit_cache)
+
+     if hasattr(runner, 'vae') and runner.vae is not None:
+         cleanup_vae(runner=runner, debug=debug, cache_model=vae_cache)
+
+     # 2. Clear remaining runtime caches
+     clear_runtime_caches(runner=runner, debug=debug)
+
+     # 3. Clear config and other non-model components when fully releasing the runner
+     if not (dit_cache or vae_cache):
+         # Full cleanup - clear config and model tracking
+         runner.config = None
+         runner._dit_model_name = None
+         runner._vae_model_name = None
+
+     # 4. Final memory cleanup
+     clear_memory(debug=debug, deep=True, force=True, timer_name="complete_cleanup")
+
+     # 5. Clear cuBLAS workspaces
+     if hasattr(torch._C, '_cuda_clearCublasWorkspaces'):
+         torch._C._cuda_clearCublasWorkspaces()
+
+     # Log which models are cached for the next run
+     if debug and (dit_cache or vae_cache):
+         cached_models = []
+         if dit_cache and hasattr(runner, '_dit_model_name'):
+             cached_models.append(f"DiT ({runner._dit_model_name})")
+         if vae_cache and hasattr(runner, '_vae_model_name'):
+             cached_models.append(f"VAE ({runner._vae_model_name})")
+
+         if cached_models:
+             models_str = " and ".join(cached_models)
+             debug.log(f"Models cached for next run: {models_str}", category="cache", force=True)
+
+     if debug:
+         debug.log(f"Completed {cleanup_type}", category="success")
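Because DiT and VAE caching are independent flags, the final-stage call can express mixed policies. A hedged usage sketch; only the flag combinations come from this module, the surrounding `runner` and `debug` objects are assumed:

# Keep the (large) DiT resident for the next run, but rebuild the VAE each time
complete_cleanup(runner=runner, debug=debug, dit_cache=True, vae_cache=False)

# Full teardown: also clears runner.config and the *_model_name trackers,
# so the next run is treated as a cold start
complete_cleanup(runner=runner, debug=debug, dit_cache=False, vae_cache=False)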
webui.bat ADDED
@@ -0,0 +1,187 @@
+ @echo off
+
+ chcp 65001
+ set PYTHONUTF8=1
+ :: The original source of the webui.bat file is stable-diffusion-webui
+ :: Modified and enhanced by Gemini with features for venv management and requirements handling.
+
+ :: --------- Configuration ---------
+ set COMMANDLINE_ARGS=
+
+ :: Define the application directory (folder name)
+ :: Leave empty if the app is in the root directory.
+ set APP_DIR=
+
+ :: Define the name of the launch application
+ set APPLICATION_NAME=app.py
+
+ :: Define the requirements filename; the default is requirements.txt
+ set REQUIREMENTS_FILE=requirements.txt
+
+ :: Define the name of the virtual environment directory
+ set VENV_NAME=venv
+
+ :: Set to 1 to always attempt to update packages from requirements.txt on every launch
+ set ALWAYS_UPDATE_REQS=1
+ :: ---------------------------------
+
+ :: --------- Path Setup Logic ---------
+ :: Handle paths based on whether APP_DIR is set
+ if defined APP_DIR (
+     set "TARGET_REQ=%~dp0%APP_DIR%\%REQUIREMENTS_FILE%"
+     set "TARGET_SCRIPT=%~dp0%APP_DIR%\%APPLICATION_NAME%"
+     echo Working in subdirectory: %APP_DIR%
+ ) else (
+     set "TARGET_REQ=%~dp0%REQUIREMENTS_FILE%"
+     set "TARGET_SCRIPT=%~dp0%APPLICATION_NAME%"
+     echo Working in root directory.
+ )
+ :: ------------------------------------
+
+ :: Set PYTHON executable if not already defined
+ if not defined PYTHON (set PYTHON=python)
+ :: Set VENV_DIR using VENV_NAME if not already defined
+ if not defined VENV_DIR (set "VENV_DIR=%~dp0%VENV_NAME%")
+
+ mkdir tmp 2>NUL
+
+ :: Check if Python is callable
+ %PYTHON% -c "" >tmp/stdout.txt 2>tmp/stderr.txt
+ if %ERRORLEVEL% == 0 goto :check_pip
+ echo Couldn't launch python
+ goto :show_stdout_stderr
+
+ :check_pip
+ :: Check if pip is available
+ %PYTHON% -mpip --help >tmp/stdout.txt 2>tmp/stderr.txt
+ if %ERRORLEVEL% == 0 goto :start_venv
+ :: If pip is not available and PIP_INSTALLER_LOCATION is set, try to install pip
+ if "%PIP_INSTALLER_LOCATION%" == "" goto :show_stdout_stderr
+ %PYTHON% "%PIP_INSTALLER_LOCATION%" >tmp/stdout.txt 2>tmp/stderr.txt
+ if %ERRORLEVEL% == 0 goto :start_venv
+ echo Couldn't install pip
+ goto :show_stdout_stderr
+
+ :start_venv
+ :: Skip venv creation/activation if VENV_DIR is explicitly set to "-"
+ if ["%VENV_DIR%"] == ["-"] goto :skip_venv_entirely
+ :: Skip venv creation/activation if SKIP_VENV is set to "1"
+ if ["%SKIP_VENV%"] == ["1"] goto :skip_venv_entirely
+
+ :: Check if the venv already exists by looking for Python.exe in its Scripts directory
+ dir "%VENV_DIR%\Scripts\Python.exe" >tmp/stdout.txt 2>tmp/stderr.txt
+ if %ERRORLEVEL% == 0 goto :activate_venv_and_maybe_update
+
+ :: Venv does not exist, create it
+ echo Virtual environment not found in "%VENV_DIR%". Creating a new one.
+ for /f "delims=" %%i in ('CALL %PYTHON% -c "import sys; print(sys.executable)"') do set PYTHON_FULLNAME="%%i"
+ echo Creating venv in directory %VENV_DIR% using python %PYTHON_FULLNAME%
+ %PYTHON_FULLNAME% -m venv "%VENV_DIR%" >tmp/stdout.txt 2>tmp/stderr.txt
+ if %ERRORLEVEL% NEQ 0 (
+     echo Unable to create venv in directory "%VENV_DIR%"
+     goto :show_stdout_stderr
+ )
+ echo Venv created.
+
+ :: Install requirements for the first time if the venv was just created
+ :: This section handles the initial installation of packages from requirements.txt
+ :: immediately after a new virtual environment is created.
+ echo Checking for %REQUIREMENTS_FILE% for initial setup...
+ if exist "%TARGET_REQ%" (
+     echo Found %REQUIREMENTS_FILE% at "%TARGET_REQ%", attempting to install for initial setup...
+     call "%VENV_DIR%\Scripts\activate.bat"
+     echo Installing packages from %REQUIREMENTS_FILE% ^(initial setup^)...
+     "%VENV_DIR%\Scripts\python.exe" -m pip install -r "%TARGET_REQ%"
+     REM "if errorlevel 1" is evaluated at run time; %ERRORLEVEL% would be
+     REM expanded once when this whole parenthesized block is parsed.
+     if errorlevel 1 (
+         echo Failed to install requirements during initial setup. Please check the output above.
+         pause
+         goto :show_stdout_stderr_custom_pip_initial
+     )
+     echo Initial requirements installed successfully.
+     call "%VENV_DIR%\Scripts\deactivate.bat"
+ ) else (
+     echo No %REQUIREMENTS_FILE% found at "%TARGET_REQ%", skipping package installation.
+ )
+ goto :activate_venv_and_maybe_update
+
+
+ :activate_venv_and_maybe_update
+ :: This label is reached if the venv exists or was just created.
+ :: Set PYTHON to point to the venv's Python interpreter.
+ set PYTHON="%VENV_DIR%\Scripts\Python.exe"
+ echo Activating venv: %PYTHON%
+
+ :: Always update requirements if ALWAYS_UPDATE_REQS is 1
+ :: This section allows packages from requirements.txt to be updated on every launch
+ :: when the ALWAYS_UPDATE_REQS variable is set to 1.
+ if defined ALWAYS_UPDATE_REQS (
+     if "%ALWAYS_UPDATE_REQS%"=="1" (
+         echo ALWAYS_UPDATE_REQS is enabled.
+         if exist "%TARGET_REQ%" (
+             echo Attempting to update packages from "%TARGET_REQ%"...
+             REM No need to call activate.bat here again; PYTHON already points to the venv's python
+             %PYTHON% -m pip install -r "%TARGET_REQ%"
+             if errorlevel 1 (
+                 echo Failed to update requirements. Please check the output above.
+                 pause
+                 goto :endofscript
+             )
+             echo Requirements updated successfully.
+         ) else (
+             echo ALWAYS_UPDATE_REQS is enabled, but no %REQUIREMENTS_FILE% found. Skipping update.
+         )
+     ) else (
+         echo ALWAYS_UPDATE_REQS is not enabled or not set to 1. Skipping routine update.
+     )
+ )
+
+ goto :launch
+
+ :skip_venv_entirely
+ :: This label is reached if venv usage is explicitly skipped.
+ echo Skipping venv.
+ goto :launch
+
+ :launch
+ :: Launch the main application
+ echo Launching Web UI with arguments: %COMMANDLINE_ARGS% %*
+ echo Script path: %TARGET_SCRIPT%
+ %PYTHON% "%TARGET_SCRIPT%" %COMMANDLINE_ARGS% %*
+ echo Launch finished.
+ pause
+ exit /b
+
+ :show_stdout_stderr_custom_pip_initial
+ :: Custom error handler for failures during the initial pip install process.
+ echo.
+ echo exit code ^(pip initial install^): %errorlevel%
+ echo Errors during initial pip install. See output above.
+ echo.
+ echo Launch unsuccessful. Exiting.
+ pause
+ exit /b
+
+
+ :show_stdout_stderr
+ :: General error handler: displays stdout and stderr from the tmp directory.
+ echo.
+ echo exit code: %errorlevel%
+
+ for /f %%i in ("tmp\stdout.txt") do set size=%%~zi
+ if %size% equ 0 goto :show_stderr
+ echo.
+ echo stdout:
+ type tmp\stdout.txt
+
+ :show_stderr
+ for /f %%i in ("tmp\stderr.txt") do set size=%%~zi
+ if %size% equ 0 goto :endofscript
+ echo.
+ echo stderr:
+ type tmp\stderr.txt
+
+ :endofscript
+ echo.
+ echo Launch unsuccessful. Exiting.
+ pause
+ exit /b
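Note on argument handling in the launcher above: `%COMMANDLINE_ARGS%` is prepended and `%*` forwards whatever the user typed, so a call such as `webui.bat --share` would reach `app.py` unchanged, assuming the Python side parses its own arguments; `--share` is only an illustrative flag here, not one this script is known to define.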