init commit

- .gitignore +6 -0
- README.md +2 -2
- app.py +0 -0
- requirements.txt +39 -0
- src/optimization/blockswap.py +1032 -0
- src/optimization/blockswap.py.bak +938 -0
- src/optimization/memory_manager.py +1285 -0
- src/optimization/memory_manager.py.bak +1231 -0
- webui.bat +187 -0
.gitignore
ADDED
@@ -0,0 +1,6 @@
.vs
venv
tmp
*.pyc
models
images
README.md
CHANGED
@@ -1,10 +1,10 @@
 ---
 title: SeedVR2 Image Upscaler
-emoji:
+emoji: 🖼️
 colorFrom: gray
 colorTo: blue
 sdk: gradio
-sdk_version:
+sdk_version: 5.50.0
 app_file: app.py
 pinned: false
 license: apache-2.0
app.py
ADDED
The diff for this file is too large to render. See raw diff.
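Since the app.py diff is not rendered, here is a minimal sketch of what a Gradio entry point consistent with the README (`sdk: gradio`, `app_file: app.py`) and requirements (`gradio==5.50.0`) might look like. Every name below is an assumption for illustration, not the actual contents of app.py.

```python
# Hypothetical sketch only: the real app.py is not rendered in this diff.
import gradio as gr
from PIL import Image


def upscale(image: Image.Image, scale: int = 2) -> Image.Image:
    # Placeholder for the actual SeedVR2 inference pipeline.
    width, height = image.size
    return image.resize((width * scale, height * scale), Image.LANCZOS)


demo = gr.Interface(
    fn=upscale,
    inputs=[gr.Image(type="pil"), gr.Slider(2, 4, value=2, step=1, label="Scale")],
    outputs=gr.Image(type="pil"),
    title="SeedVR2 Image Upscaler",
)

if __name__ == "__main__":
    demo.launch()
```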
requirements.txt
ADDED
@@ -0,0 +1,39 @@
--extra-index-url https://download.pytorch.org/whl/cu130

# Web UI
gradio==5.50.0

# Core numeric / vision
numpy
opencv-python

# PyTorch
torch
torchvision

# Hugging Face helper for downloading weights
huggingface-hub==0.36.0

# Utilities
tqdm

# SeedVR2
psutil
einops
diffusers
rotary-embedding-torch
omegaconf

gguf

triton; sys_platform != 'win32'
triton-windows; sys_platform == 'win32'

#
# flash-attn; sys_platform != 'win32'
# sageattention; sys_platform != 'win32'
#
# if sys_platform == 'win32'
# https://huggingface.co/lldacing/flash-attention-windows-wheel
# https://huggingface.co/ussoewwin/Flash-Attention-2_for_Windows
# https://github.com/woct0rdho/SageAttention
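The commented-out flash-attn and sageattention entries above suggest optional attention backends that are installed separately (via the linked Windows wheels or from source). A common pattern for handling such optional dependencies at runtime is a probe-and-fallback, sketched below; whether this project selects backends this way is an assumption, and `pick_attention_backend` is a hypothetical helper.

```python
# Hypothetical sketch: probing for optional attention backends at runtime.
def pick_attention_backend() -> str:
    try:
        import flash_attn  # noqa: F401  # provided by flash-attention wheels
        return "flash_attn"
    except ImportError:
        pass
    try:
        import sageattention  # noqa: F401  # https://github.com/woct0rdho/SageAttention
        return "sageattention"
    except ImportError:
        # Fall back to PyTorch's built-in scaled_dot_product_attention.
        return "sdpa"


print(f"Using attention backend: {pick_attention_backend()}")
```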
src/optimization/blockswap.py
ADDED
@@ -0,0 +1,1032 @@
"""
BlockSwap Module for SeedVR2

This module implements dynamic block swapping between GPU and CPU memory
to enable running large models on limited VRAM systems.

Key Features:
- Dynamic transformer block offloading during inference
- Non-blocking GPU transfers for optimal performance
- RoPE computation fallback to CPU on OOM
- Minimal performance overhead with intelligent caching
- I/O component offloading for maximum memory savings
"""

import time
import types
import torch
import weakref

from typing import Dict, Any, List, Optional
from .memory_manager import clear_memory
from .compatibility import call_rope_with_stability
from ..common.distributed import get_device


def is_blockswap_enabled(config: Optional[Dict[str, Any]]) -> bool:
    """
    Check if BlockSwap configuration indicates BlockSwap should be enabled.

    BlockSwap is enabled if either blocks_to_swap > 0 OR swap_io_components is True.
    This is the authoritative function for determining BlockSwap status from configuration.

    Args:
        config: BlockSwap configuration dictionary with optional keys:
            - blocks_to_swap: Number of blocks to offload (0 = disabled)
            - swap_io_components: Whether to offload I/O components

    Returns:
        True if BlockSwap should be active, False otherwise
    """
    if not config:
        return False

    blocks_to_swap = config.get("blocks_to_swap", 0)
    swap_io_components = config.get("swap_io_components", False)

    return blocks_to_swap > 0 or swap_io_components


def validate_blockswap_config(
    block_swap_config: Optional[Dict[str, Any]],
    dit_device: 'torch.device',
    dit_offload_device: Optional['torch.device'],
    debug: 'Debug'
) -> Optional[Dict[str, Any]]:
    """
    Validate and potentially modify BlockSwap configuration.

    Performs platform-specific validation and configuration adjustment:
    - On macOS (MPS): Auto-disables BlockSwap since unified memory makes it meaningless
    - On other platforms: Validates that offload_device is properly configured

    This is the single authoritative validation point for BlockSwap configuration,
    called early in configure_runner() before any model loading.

    Args:
        block_swap_config: BlockSwap configuration dictionary (may be None)
        dit_device: Target device for DiT model inference
        dit_offload_device: Device for offloading DiT blocks (may be None)
        debug: Debug instance for logging warnings/errors

    Returns:
        Validated/modified block_swap_config (may be None or modified copy)

    Raises:
        ValueError: If BlockSwap is enabled but offload_device is invalid (non-MPS only)
    """
    if not is_blockswap_enabled(block_swap_config):
        return block_swap_config

    blocks_to_swap = block_swap_config.get("blocks_to_swap", 0)
    swap_io_components = block_swap_config.get("swap_io_components", False)

    # Check for macOS unified memory - BlockSwap is meaningless there
    if dit_device.type == "mps":
        debug.log(
            f"BlockSwap disabled: macOS uses unified memory (no separate VRAM/RAM). "
            f"Ignoring blocks_to_swap={blocks_to_swap}, swap_io_components={swap_io_components}",
            level="WARNING", category="blockswap", force=True
        )
        # Return disabled config
        return {
            **block_swap_config,
            "blocks_to_swap": 0,
            "swap_io_components": False
        }

    # Validate offload_device is set and different from dit_device
    offload_device_valid = (
        dit_offload_device is not None and
        str(dit_offload_device) != str(dit_device)
    )

    if not offload_device_valid:
        config_details = []
        if blocks_to_swap > 0:
            config_details.append(f"blocks_to_swap={blocks_to_swap}")
        if swap_io_components:
            config_details.append("swap_io_components=True")

        offload_str = str(dit_offload_device) if dit_offload_device else "none"
        raise ValueError(
            f"BlockSwap enabled ({', '.join(config_details)}) but dit_offload_device is invalid. "
            f"Current: device='{dit_device}', dit_offload_device='{offload_str}'. "
            f"BlockSwap requires offload_device on the DiT Model to be set and different from device. "
            f"Set --dit_offload_device cpu or disable BlockSwap."
        )

    return block_swap_config


# Timing helpers marked to skip torch.compile tracing
# These functions are excluded from Dynamo's graph tracing to avoid warnings
# about non-traceable builtins like time.time(), but they still execute normally
@torch._dynamo.disable
def _get_swap_start_time(debug, enabled: bool) -> Optional[float]:
    """Get start time for swap operation if debug is enabled."""
    return time.time() if debug and enabled else None


@torch._dynamo.disable
def _log_swap_timing(debug, t_start: Optional[float], component_id, component_type: str) -> None:
    """Log swap timing if start time was captured."""
    if debug and t_start is not None:
        debug.log_swap_time(
            component_id=component_id,
            duration=time.time() - t_start,
            component_type=component_type
        )


def get_module_memory_mb(module: torch.nn.Module) -> float:
    """
    Calculate memory usage of a module in MB.

    Args:
        module: PyTorch module to measure

    Returns:
        Memory usage in megabytes
    """
    total_bytes = sum(
        param.nelement() * param.element_size()
        for param in module.parameters()
        if param.data is not None
    )
    return total_bytes / (1024 * 1024)


def apply_block_swap_to_dit(
    runner: 'VideoDiffusionInfer',
    block_swap_config: Dict[str, Any],
    debug: 'Debug'
) -> None:
    """
    Apply block swapping configuration to a DiT model with OOM protection.

    This is the main entry point for configuring block swapping on a model.
    Handles block selection, I/O component offloading, device placement, and
    forward method wrapping for dynamic memory management.

    Args:
        runner: VideoDiffusionInfer instance containing the model
        block_swap_config: Configuration dictionary with keys:
            - blocks_to_swap: Number of blocks to swap (from the start)
            - swap_io_components: Whether to offload I/O components
            - enable_debug: Whether to enable debug logging
            - offload_device: Device to offload to (default: 'cpu')
        debug: Debug instance for logging (required)
    """
    # Early return if BlockSwap not enabled
    if not is_blockswap_enabled(block_swap_config):
        return

    blocks_to_swap = block_swap_config.get("blocks_to_swap", 0)
    swap_io_components = block_swap_config.get("swap_io_components", False)

    # Early return only if both block swap and I/O swap are disabled
    if blocks_to_swap <= 0 and not swap_io_components:
        return

    if debug is None:
        if hasattr(runner, 'debug') and runner.debug is not None:
            debug = runner.debug
        else:
            raise ValueError("Debug instance must be provided to apply_block_swap_to_dit")

    debug.start_timer("apply_blockswap")

    # Get the actual model (handle CompatibleDiT wrapper)
    model = runner.dit
    if hasattr(model, "dit_model"):
        model = model.dit_model

    # Determine devices
    if hasattr(runner, '_dit_device'):
        device = runner._dit_device
    else:
        device = get_device()
    offload_device = block_swap_config.get("offload_device", torch.device('cpu'))

    # Validate model structure
    if not hasattr(model, "blocks"):
        debug.log("Model doesn't have 'blocks' attribute for BlockSwap", level="ERROR", category="blockswap", force=True)
        return

    total_blocks = len(model.blocks)

    # Clamp blocks_to_swap to available blocks BEFORE logging
    effective_blocks = min(blocks_to_swap, total_blocks) if blocks_to_swap > 0 else 0

    # Log configuration clearly based on what's enabled
    block_text = "block" if effective_blocks <= 1 else "blocks"
    if effective_blocks > 0 and swap_io_components:
        debug.log(f"BlockSwap: {effective_blocks}/{total_blocks} transformer {block_text} + I/O components offloaded to {str(offload_device).upper()}", category="blockswap", force=True)
    elif effective_blocks > 0:
        debug.log(f"BlockSwap: {effective_blocks}/{total_blocks} transformer {block_text} offloaded to {str(offload_device).upper()}", category="blockswap", force=True)
    elif swap_io_components:
        debug.log(f"BlockSwap: I/O components offloaded to {str(offload_device).upper()} (0/{total_blocks} blocks swapped)", category="blockswap", force=True)

    # Configure model with blockswap attributes
    if blocks_to_swap > 0:
        model.blocks_to_swap = effective_blocks - 1  # Convert to 0-indexed
    else:
        # No block swapping, set to -1 so no blocks match the swap condition
        model.blocks_to_swap = -1

    model.main_device = device
    model.offload_device = offload_device

    # Configure I/O components
    io_config = _configure_io_components(model, device, offload_device,
                                         swap_io_components, debug)
    memory_stats = _configure_blocks(model, device, offload_device, debug)
    memory_stats['io_components'] = io_config['components']
    memory_stats['io_memory_mb'] = io_config['memory_mb']
    memory_stats['gpu_components'] = io_config['gpu_components']
    memory_stats['io_gpu_memory_mb'] = io_config['gpu_memory_mb']

    # Log memory summary
    _log_memory_summary(memory_stats, offload_device, device, swap_io_components,
                        debug)

    # Initialize Nunchaku-style async management object
    if blocks_to_swap > 0:
        # normalize device objects
        if isinstance(device, str):
            device = torch.device(device)
        model._swap_stream = torch.cuda.Stream(device=device)
        model._block_ready_events = {}

        # Preload first swapped block to seed pipeline (non-blocking on swap_stream)
        try:
            first_idx = 0
            if first_idx <= model.blocks_to_swap:
                with torch.cuda.stream(model._swap_stream):
                    model.blocks[first_idx].to(device, non_blocking=True)
                    ev = torch.cuda.Event(blocking=False)
                    ev.record(model._swap_stream)  # record on swap_stream -> event gets device-bound here
                    model._block_ready_events[first_idx] = ev
        except Exception as e:
            debug.log(f"Failed to initialize swap-stream prefetch: {e}", level="WARNING", category="blockswap", force=True)

    # Wrap block forward methods for dynamic swapping (only if blocks_to_swap > 0)
    if blocks_to_swap > 0:
        for b, block in enumerate(model.blocks):
            if b <= model.blocks_to_swap:
                _wrap_block_forward(block, b, model, debug)

    # Patch RoPE modules for robust error handling
    _patch_rope_for_blockswap(model, debug)

    # Mark BlockSwap as active
    runner._blockswap_active = True

    # Store configuration for debugging and cleanup
    model._block_swap_config = {
        "blocks_swapped": blocks_to_swap,
        "swap_io_components": swap_io_components,
        "total_blocks": total_blocks,
        "offload_device": offload_device,
        "main_device": device,
        "offload_memory": memory_stats['offload_memory'],
        "main_memory": memory_stats['main_memory']
    }

    # Protect model from being moved entirely
    _protect_model_from_move(model, runner, debug)

    debug.log("BlockSwap configuration complete", category="success")
    debug.end_timer("apply_blockswap", "BlockSwap configuration application")


def _configure_io_components(
    model: torch.nn.Module,
    device: torch.device,
    offload_device: torch.device,
    swap_io_components: bool,
    debug: 'Debug'
) -> Dict[str, Any]:
    """
    Configure I/O component placement and wrapping with memory tracking.

    Handles all non-block modules (embeddings, normalization layers, etc.) by
    either keeping them on GPU or offloading them with dynamic swapping wrappers.

    Args:
        model: DiT model containing named children to configure
        device: Main computation device (typically GPU)
        offload_device: Device for offloaded components (typically CPU)
        swap_io_components: If True, offload I/O components with dynamic swapping
        debug: Debug instance for logging (required)

    Returns:
        Dictionary containing:
        - components: List of offloaded component names
        - memory_mb: Total memory of offloaded components in MB
        - gpu_components: List of components remaining on GPU
        - gpu_memory_mb: Total memory of GPU components in MB
    """
    io_components_offloaded = []
    io_components_on_gpu = []
    io_memory_mb = 0.0
    io_gpu_memory_mb = 0.0

    # Check for pin memory condition
    use_pin_memory = (offload_device == "cpu") if isinstance(offload_device, str) else (offload_device.type == "cpu")

    # Handle I/O modules with dynamic swapping
    for name, module in model.named_children():
        if name != "blocks":
            module_memory = get_module_memory_mb(module)

            if swap_io_components:
                module.to(offload_device)

                # Enable Pin Memory for I/O components
                if use_pin_memory:
                    for p in module.parameters():
                        if not p.is_pinned():
                            p.data = p.data.pin_memory()
                    for buf in module.buffers():
                        if not buf.is_pinned():
                            buf.data = buf.data.pin_memory()

                _wrap_io_forward(module, name, model, debug)
                io_components_offloaded.append(name)
                io_memory_mb += module_memory
                debug.log(f"{name} → {str(offload_device).upper()} ({module_memory:.2f}MB, dynamic swapping)", category="blockswap", indent_level=1)
            else:
                module.to(device)
                io_components_on_gpu.append(name)
                io_gpu_memory_mb += module_memory
                debug.log(f"{name} → {str(device).upper()} ({module_memory:.2f}MB)", category="blockswap", indent_level=1)

    return {
        'components': io_components_offloaded,
        'memory_mb': io_memory_mb,
        'gpu_components': io_components_on_gpu,
        'gpu_memory_mb': io_gpu_memory_mb
    }


def _configure_blocks(
    model: torch.nn.Module,
    device: torch.device,
    offload_device: torch.device,
    debug: 'Debug'
) -> Dict[str, float]:
    """
    Configure transformer block placement and calculate memory statistics.

    Moves blocks to their designated devices based on model.blocks_to_swap
    attribute. Blocks with index <= blocks_to_swap go to offload device,
    others stay on main device.

    Args:
        model: DiT model with blocks attribute and blocks_to_swap configured
        device: Main computation device for non-swapped blocks
        offload_device: Device for swapped blocks
        debug: Debug instance for logging (required)

    Returns:
        Dictionary containing:
        - offload_memory: Total memory of offloaded blocks in MB
        - main_memory: Total memory of blocks on main device in MB
        - io_components: Empty list (populated by caller)
    """
    total_offload_memory = 0.0
    total_main_memory = 0.0

    # Check if we should pin memory (if offloading to CPU)
    # Nunchaku uses pinned memory for faster async transfers
    use_pin_memory = (offload_device == "cpu") if isinstance(offload_device, str) else (offload_device.type == "cpu")

    # Move blocks based on swap configuration
    for b, block in enumerate(model.blocks):
        block_memory = get_module_memory_mb(block)

        if b > model.blocks_to_swap:
            block.to(device)
            total_main_memory += block_memory
        else:
            block.to(offload_device, non_blocking=False)
            total_offload_memory += block_memory

            # Enable Pin Memory optimization for CPU Offload transfer speed
            if use_pin_memory:
                for p in block.parameters():
                    if not p.is_pinned():
                        p.data = p.data.pin_memory()
                for buf in block.buffers():
                    if not buf.is_pinned():
                        buf.data = buf.data.pin_memory()

    # Ensure all buffers match their containing module's device
    for b, block in enumerate(model.blocks):
        target_device = device if b > model.blocks_to_swap else offload_device
        for name, buffer in block.named_buffers():
            if buffer.device != torch.device(target_device):
                # Apply pinning if needed
                if use_pin_memory and target_device.type == "cpu" and not buffer.is_pinned():
                    buffer.data = buffer.data.pin_memory()
                buffer.data = buffer.data.to(target_device, non_blocking=False)

    return {
        "offload_memory": total_offload_memory,
        "main_memory": total_main_memory,
        "io_components": []  # Will be populated by caller
    }


def _log_memory_summary(
    memory_stats: Dict[str, float],
    offload_device: torch.device,
    device: torch.device,
    swap_io_components: bool,
    debug: 'Debug'
) -> None:
    """
    Log comprehensive memory usage summary for BlockSwap configuration.

    Displays detailed breakdown of memory distribution across devices,
    including transformer blocks and I/O components.

    Args:
        memory_stats: Dictionary containing:
            - offload_memory: Memory offloaded from blocks (MB)
            - main_memory: Memory remaining on main device (MB)
            - io_memory_mb: Memory from offloaded I/O components (MB)
            - io_gpu_memory_mb: Memory from I/O components on GPU (MB)
        offload_device: Device used for offloading
        device: Main computation device
        swap_io_components: Whether I/O components are being swapped
        debug: Debug instance for logging (required)
    """
    debug.log("BlockSwap memory configuration:", category="blockswap")

    # Log transformer blocks memory
    blocks_offloaded = memory_stats['offload_memory']
    blocks_on_gpu = memory_stats['main_memory']

    offload_str = str(offload_device)
    device_str = str(device)

    if blocks_on_gpu == 0:
        debug.log(f"Transformer blocks: {blocks_offloaded:.2f}MB on {offload_str} (dynamic swapping)", category="blockswap", indent_level=1)
    else:
        debug.log(f"Transformer blocks: {blocks_on_gpu:.2f}MB on {device_str}, {blocks_offloaded:.2f}MB on {offload_str}", category="blockswap", indent_level=1)

    # Always log I/O components (whether swapping or not)
    io_memory = memory_stats.get('io_memory_mb', 0.0)
    io_gpu_memory = memory_stats.get('io_gpu_memory_mb', 0.0)

    if swap_io_components and io_memory > 0:
        io_components = memory_stats.get('io_components', [])
        debug.log(f"I/O components: {io_memory:.2f}MB on {offload_str} (dynamic swapping)", category="blockswap", indent_level=1)
        debug.log(f"{', '.join(io_components)}", category="blockswap", indent_level=2)
    elif io_gpu_memory > 0:
        io_gpu_components = memory_stats.get('gpu_components', [])
        debug.log(f"I/O components: {io_gpu_memory:.2f}MB on {device_str}", category="blockswap", indent_level=1)
        debug.log(f"{', '.join(io_gpu_components)}", category="blockswap", indent_level=2)

    # Log total VRAM savings
    total_offloaded = blocks_offloaded + (io_memory if swap_io_components else 0)
    if total_offloaded > 0:
        debug.log(f"Total VRAM saved: {total_offloaded:.2f}MB (~{total_offloaded/1024:.2f}GB)", category="blockswap", indent_level=1)


def _wrap_block_forward(
    block: torch.nn.Module,
    block_idx: int,
    model: torch.nn.Module,
    debug: 'Debug'
) -> None:
    """
    Wrap individual transformer block forward for dynamic device swapping.

    Implements Nunchaku-style pipelining: Prefetch Next -> Compute Current -> Offload Current.
    https://github.com/nunchaku-tech/nunchaku/blob/main/nunchaku/models/utils.py

    Creates a wrapped forward method that automatically:
    1. Moves block to GPU before computation
    2. Executes original forward pass
    3. Moves block back to offload device after computation
    4. Logs timing and manages memory pressure

    Uses weak references to prevent memory leaks from closure retention.

    Args:
        block: Individual transformer block to wrap
        block_idx: Index of this block in model.blocks
        model: Parent DiT model (used for device references)
        debug: Debug instance for logging (required)
    """
    if hasattr(block, '_original_forward'):
        return  # Already wrapped

    # Store original forward method
    original_forward = block.forward

    # Create weak references
    model_ref = weakref.ref(model)
    debug_ref = weakref.ref(debug) if debug is not None else (lambda: None)

    # Store block_idx on the block itself to avoid closure issues
    block._block_idx = block_idx

    def wrapped_forward(self, *args, **kwargs):
        # Retrieve weak references
        model = model_ref()
        debug = debug_ref()

        if not model:
            # Model has been garbage collected, fall back to original
            return original_forward(*args, **kwargs)

        # Check if block swap is active for this block
        if hasattr(model, 'blocks_to_swap') and self._block_idx <= model.blocks_to_swap:
            # Use dynamo-disabled helper to get start time (avoids compilation warnings)
            t_start = _get_swap_start_time(debug, debug.enabled if debug else False)

            # Only move to GPU if necessary
            current_device = next(self.parameters()).device
            target_device = torch.device(model.main_device)

            # 1. Ensure CURRENT block is ready on GPU
            # Check if we have a prefetch event waiting
            if hasattr(model, '_block_ready_events') and self._block_idx in model._block_ready_events:
                # Wait for the swap stream to finish moving this block
                torch.cuda.current_stream().wait_event(model._block_ready_events[self._block_idx])
                # Cleanup event
                del model._block_ready_events[self._block_idx]
            elif current_device != target_device:
                # Fallback: First block or missed prefetch, move synchronously (but non-blocking)
                debug.log(f"[blockswap] Block {self._block_idx} missing prefetch event, moving synchronously", level="WARNING", category="blockswap", force=True)
                self.to(model.main_device, non_blocking=True)

            # 2. Trigger Prefetch for NEXT block (Pipelining)
            # Nunchaku logic: Start moving i+1 while i is computing
            next_idx = self._block_idx + 1
            if next_idx <= model.blocks_to_swap:
                next_block = model.blocks[next_idx]
                # Use the dedicated swap stream
                with torch.cuda.stream(model._swap_stream):
                    next_block.to(model.main_device, non_blocking=True)
                    # Record event so next iteration knows when to wait
                    event = torch.cuda.Event(blocking=False)
                    event.record(model._swap_stream)
                    model._block_ready_events[next_idx] = event

            # 3. Execute forward pass (Compute)
            # This runs on the default stream, overlapping with the prefetch above
            output = original_forward(*args, **kwargs)

            # 4. Offload CURRENT block (Async)
            # We record an event on compute stream to ensure we don't move data while it's being used
            compute_done_event = torch.cuda.Event(blocking=False)
            compute_done_event.record(torch.cuda.current_stream())

            with torch.cuda.stream(model._swap_stream):
                # Wait for compute to finish before moving memory out
                model._swap_stream.wait_event(compute_done_event)
                # Move back to offload device
                self.to(model.offload_device, non_blocking=True)

            # Use dynamo-disabled helper to log timing (avoids compilation warnings)
            _log_swap_timing(debug, t_start, self._block_idx, "block (pipelined)")

            # Only clear cache under memory pressure
            clear_memory(debug=debug, deep=False, force=False, timer_name="wrap_block_forward")
        else:
            output = original_forward(*args, **kwargs)

        return output

    # Bind the wrapped function as a method to the block
    block.forward = types.MethodType(wrapped_forward, block)

    # Store reference to original forward for cleanup
    block._original_forward = original_forward


def _wrap_io_forward(
    module: torch.nn.Module,
    module_name: str,
    model: torch.nn.Module,
    debug: 'Debug'
) -> None:
    """
    Wrap I/O component forward for dynamic device swapping.

    Similar to _wrap_block_forward but for I/O components (embeddings,
    normalization layers, etc.). Handles swapping between GPU and CPU
    during forward passes.

    Uses weak references to prevent circular dependencies and memory leaks.

    Args:
        module: I/O component module to wrap
        module_name: Name identifier for logging (e.g., 'x_embedder')
        model: Parent DiT model (used for device references)
        debug: Debug instance for logging (required)
    """
    if hasattr(module, '_is_io_wrapped') and module._is_io_wrapped:
        debug.log(f"Reusing existing I/O wrapper for {module_name}", category="reuse")
        return  # Already wrapped

    # Store original forward method
    original_forward = module.forward

    # Create weak references
    model_ref = weakref.ref(model)
    debug_ref = weakref.ref(debug) if debug else lambda: None

    # Store module name on the module itself
    module._module_name = module_name
    module._original_forward = original_forward

    def wrapped_io_forward(self, *args, **kwargs):
        # Retrieve weak references
        model = model_ref()
        debug = debug_ref()

        if not model:
            # Model has been garbage collected, fall back to original
            return self._original_forward(*args, **kwargs)

        # Use dynamo-disabled helper to get start time (avoids compilation warnings)
        t_start = _get_swap_start_time(debug, debug.enabled if debug else False)

        # Check current device to avoid unnecessary moves
        current_device = next(self.parameters()).device
        target_device = torch.device(model.main_device)

        # Move to GPU for computation if needed
        if current_device != target_device:
            self.to(model.main_device, non_blocking=False)

        # Execute forward pass
        output = self._original_forward(*args, **kwargs)

        # Move back to offload device
        self.to(model.offload_device, non_blocking=False)

        # Use dynamo-disabled helper to log timing (avoids compilation warnings)
        _log_swap_timing(debug, t_start, self._module_name, "I/O")

        # Only clear cache under memory pressure
        clear_memory(debug=debug, deep=False, force=False, timer_name="wrap_block_forward")

        return output

    # Bind as a method
    module.forward = types.MethodType(wrapped_io_forward, module)
    module._is_io_wrapped = True

    # Store module reference for restoration
    if not hasattr(model, '_io_swappers'):
        model._io_swappers = []
    model._io_swappers.append((module, module_name))


def _patch_rope_for_blockswap(
    model: torch.nn.Module,
    debug: 'Debug'
) -> None:
    """
    Patch RoPE (Rotary Position Embedding) modules for device-aware fallback.

    Adds CPU fallback logic to RoPE modules to handle device mismatch errors
    that can occur during BlockSwap operations. Complements the stability
    wrapper from compatibility.py with device-specific error handling.

    Args:
        model: DiT model containing RoPE modules to patch
        debug: Debug instance for logging (required)
    """
    rope_patches = []

    for name, module in model.named_modules():
        if "rope" in name.lower() and hasattr(module, "get_axial_freqs"):
            # Skip if already wrapped by blockswap
            if hasattr(module, '_blockswap_wrapped') and module._blockswap_wrapped:
                continue

            # Get current method (might be stability-wrapped)
            current_method = module.get_axial_freqs

            # Create device-aware wrapper with proper closure handling
            def make_device_aware_wrapper(module_name, current_fn):
                def device_aware_rope_wrapper(self, *args, **kwargs):
                    try:
                        # Try current method (original or stability-wrapped)
                        return current_fn(*args, **kwargs)
                    except (RuntimeError, torch.cuda.OutOfMemoryError) as e:
                        error_msg = str(e).lower()
                        # Only handle device/memory specific errors
                        if any(x in error_msg for x in ["device", "memory", "allocation"]):
                            debug.log(f"RoPE OOM for {module_name}", level="WARNING", category="rope", force=True)
                            debug.log(f"Clearing RoPE cache and retrying", category="info", force=True)

                            # Get current device from parameters
                            try:
                                current_device = next(self.parameters()).device
                            except StopIteration:
                                # Fallback: use model's main_device if BlockSwap has set it, else use offload_device
                                if hasattr(model, 'main_device'):
                                    current_device = torch.device(model.main_device)
                                elif hasattr(model, 'offload_device'):
                                    current_device = torch.device(model.offload_device)

                            # Try clearing cache first (non-invasive fix)
                            if hasattr(current_fn, 'cache_clear'):
                                current_fn.cache_clear()
                                try:
                                    # Retry on same device after clearing cache
                                    return current_fn(*args, **kwargs)
                                except Exception as retry_error:
                                    # Cache clear wasn't enough, need more drastic measures
                                    debug.log(f"Cache clear insufficient for {module_name}, falling back to CPU", level="WARNING", category="rope", force=True)

                            # Fallback to CPU computation with stability
                            self.cpu()

                            try:
                                # Use call_rope_with_stability for CPU computation
                                # This ensures cache is cleared and autocast disabled
                                original_fn = getattr(self, '_original_get_axial_freqs', current_fn)
                                result = call_rope_with_stability(original_fn, *args, **kwargs)

                                # Move module back to original device
                                self.to(current_device)

                                # Move result to appropriate device if it's a tensor
                                if hasattr(result, 'to'):
                                    target_device = args[0].device if len(args) > 0 and hasattr(args[0], 'device') else current_device
                                    return result.to(target_device)
                                return result

                            except Exception as cpu_error:
                                # Always restore device even on error
                                self.to(current_device)
                                raise cpu_error
                        else:
                            # Not a device error, let it bubble up
                            raise

                return device_aware_rope_wrapper

            # Apply wrapper
            module.get_axial_freqs = types.MethodType(
                make_device_aware_wrapper(name, current_method),
                module
            )
            module._blockswap_wrapped = True

            # Store for cleanup (use original or previously stored)
            original_method = getattr(module, '_original_get_axial_freqs', current_method)
            rope_patches.append((module, original_method))

    if rope_patches:
        model._rope_patches = rope_patches
        debug.log(f"Patched {len(rope_patches)} RoPE modules with device handling", category="success")


def _protect_model_from_move(
    model: torch.nn.Module,
    runner: 'VideoDiffusionInfer',
    debug: 'Debug'
) -> None:
    """
    Protect model from unintended full device movement during BlockSwap.

    Wraps model.to() method to prevent other code from accidentally moving
    the entire model to GPU, which would defeat BlockSwap's memory savings.
    Allows movement only when explicitly bypassed via model flag.

    Args:
        model: DiT model to protect
        runner: VideoDiffusionInfer instance (for active status check)
        debug: Debug instance for logging (required)
    """
    if not hasattr(model, '_original_to'):
        # Store runner reference as weak reference to avoid circular refs
        model._blockswap_runner_ref = weakref.ref(runner)
        model._original_to = model.to

        # Define the protected method without closures
        def protected_model_to(self, device, *args, **kwargs):
            # Check if protection is temporarily bypassed for offloading
            # Flag is stored on model itself (not runner) to survive runner recreation
            if getattr(self, "_blockswap_bypass_protection", False):
                # Protection bypassed, allow movement
                if hasattr(self, '_original_to'):
                    return self._original_to(device, *args, **kwargs)

            # Get configured offload device directly from model
            blockswap_offload_device = "cpu"  # default
            if hasattr(self, "_block_swap_config"):
                blockswap_offload_device = self._block_swap_config.get("offload_device", "cpu")

            # Check if BlockSwap is currently active via runner weak reference
            runner_ref = getattr(self, '_blockswap_runner_ref', None)
            blockswap_is_active = False
            if runner_ref:
                runner_obj = runner_ref()
                if runner_obj and hasattr(runner_obj, "_blockswap_active"):
                    blockswap_is_active = runner_obj._blockswap_active

            # Block attempts to move model away from configured offload device when active
            if blockswap_is_active and str(device) != str(blockswap_offload_device):
                # Get debug instance from runner if available
                debug_instance = None
                if runner_ref:
                    runner_obj = runner_ref()
                    if runner_obj and hasattr(runner_obj, 'debug'):
                        debug_instance = runner_obj.debug

                if debug_instance:
                    debug_instance.log(
                        f"Blocked attempt to move BlockSwap model from {blockswap_offload_device} to {device}",
                        level="WARNING", category="blockswap", force=True
                    )
                return self

            # Allow movement (either bypass is enabled or target is offload device)
            if hasattr(self, '_original_to'):
                return self._original_to(device, *args, **kwargs)
            else:
                # Fallback - shouldn't happen
                return super(type(self), self).to(device, *args, **kwargs)

        # Bind as a method to the model instance
        model.to = types.MethodType(protected_model_to, model)


def set_blockswap_bypass(runner, bypass: bool, debug):
    """
    Set or unset bypass flag for BlockSwap protection.
    Used for offloading to temporarily allow model movement.

    Args:
        runner: Runner instance with BlockSwap
        bypass: True to bypass protection, False to enforce it
        debug: Debug instance for logging
    """
    if not hasattr(runner, "_blockswap_active") or not runner._blockswap_active:
        return

    # Get the actual model (handle CompatibleDiT wrapper)
    model = runner.dit
    if hasattr(model, "dit_model"):
        model = model.dit_model

    # Store on model so it survives runner recreation during caching
    model._blockswap_bypass_protection = bypass

    if bypass:
        debug.log("BlockSwap protection disabled to allow model DiT offloading", category="success")
    else:
        debug.log("BlockSwap protection renabled to avoid accidentally offloading the entire DiT model", category="success")


def cleanup_blockswap(runner, keep_state_for_cache=False):
    """
    Clean up BlockSwap configuration based on caching mode.

    When caching (keep_state_for_cache=True):
    - Keep all BlockSwap configuration intact
    - Only mark as inactive for safety during non-inference operations

    When not caching (keep_state_for_cache=False):
    - Full cleanup of all BlockSwap state

    Args:
        runner: VideoDiffusionInfer instance to clean up
        keep_state_for_cache: If True, preserve BlockSwap state for reuse
    """
    # Get debug instance from runner
    if not hasattr(runner, 'debug') or runner.debug is None:
        raise ValueError("Debug instance must be available on runner for cleanup_blockswap")

    debug = runner.debug

    # Get the actual model (handle CompatibleDiT wrapper)
    model = runner.dit
    if hasattr(model, "dit_model"):
        model = model.dit_model

    # Check if there's any BlockSwap state to clean up (check both runner and model)
    has_blockswap_state = (
        hasattr(runner, "_blockswap_active") or
        hasattr(model, "_block_swap_config") or
        hasattr(model, "_blockswap_bypass_protection")
    )

    if not has_blockswap_state:
        return

    debug.log("Starting BlockSwap cleanup", category="cleanup")

    if keep_state_for_cache:
        # Minimal cleanup for caching - just mark as inactive and allow offloading
        # Everything else stays intact for fast reactivation
        if hasattr(runner, "_blockswap_active") and runner._blockswap_active:
            if not getattr(model, "_blockswap_bypass_protection", False):
                set_blockswap_bypass(runner=runner, bypass=True, debug=debug)
            runner._blockswap_active = False
            debug.log("BlockSwap deactivated for caching (configuration preserved)", category="success")
        return

    # Full cleanup when not caching
    # Get the actual model (handle CompatibleDiT wrapper)
    model = runner.dit
    if hasattr(model, "dit_model"):
        model = model.dit_model

    # 1. Restore block forward methods
    if hasattr(model, 'blocks'):
        restored_count = 0
        for block in model.blocks:
            if hasattr(block, '_original_forward'):
                block.forward = block._original_forward
                delattr(block, '_original_forward')
                restored_count += 1

            # Clean up wrapper attributes
            for attr in ['_block_idx', '_model_ref', '_debug_ref', '_blockswap_wrapped']:
                if hasattr(block, attr):
                    delattr(block, attr)

        if restored_count > 0:
            debug.log(f"Restored {restored_count} block forward methods", category="success")

    # 2. Restore RoPE patches
    if hasattr(model, '_rope_patches'):
        for module, original_method in model._rope_patches:
            module.get_axial_freqs = original_method
            # Clean up wrapper attributes
            for attr in ['_rope_wrapped', '_original_get_axial_freqs']:
                if hasattr(module, attr):
                    delattr(module, attr)
        debug.log(f"Restored {len(model._rope_patches)} RoPE methods", category="success")
        delattr(model, '_rope_patches')

    # 3. Restore I/O component forward methods and move to offload device
    if hasattr(model, '_io_swappers'):
        for module, module_name in model._io_swappers:
            if hasattr(module, '_original_forward'):
                module.forward = module._original_forward
                # Clean up wrapper attributes
                for attr in ['_original_forward', '_model_ref', '_debug_ref',
                             '_module_name', '_is_io_wrapped']:
                    if hasattr(module, attr):
                        delattr(module, attr)
        debug.log(f"Restored {len(model._io_swappers)} I/O components", category="success")
        delattr(model, '_io_swappers')

    # Move all IO components to offload device during full cleanup
    if hasattr(model, 'offload_device'):
        offload_device = model.offload_device
        moved_count = 0
        for name, module in model.named_children():
            if name != "blocks":
                module.to(offload_device)
                moved_count += 1
        if moved_count > 0:
            debug.log(f"Moved {moved_count} IO components to offload device", category="success")

    # 4. Restore original .to() method
    if hasattr(model, '_original_to'):
        model.to = model._original_to
        delattr(model, '_original_to')
        debug.log("Restored original .to() method", category="success")

    # 5. Clean up BlockSwap-specific attributes
    for attr in ['_blockswap_runner_ref', 'blocks_to_swap', 'main_device',
                 'offload_device']:
        if hasattr(model, attr):
            delattr(model, attr)

    # 6. Clean up runner attributes
    runner._blockswap_active = False

    # Clean up pipelining resources on model (synchronize first)
    if hasattr(model, '_swap_stream'):
        try:
            model._swap_stream.synchronize()
        except Exception:
            pass
        for attr in ['_swap_stream', '_block_ready_events']:
            if hasattr(model, attr):
                delattr(model, attr)

    # Remove all config attributes
    for attr in ['_cached_blockswap_config', '_block_swap_config', '_blockswap_debug']:
        if hasattr(runner, attr):
            delattr(runner, attr)

    debug.log("BlockSwap cleanup complete", category="success")
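For reference, a usage sketch of how the public functions above fit together, based only on the signatures in this file. The `runner` and `debug` parameters stand in for the project's VideoDiffusionInfer and Debug objects, the config values are arbitrary examples, and the import path assumes `src` is importable as a package.

```python
import torch

from src.optimization.blockswap import (
    is_blockswap_enabled,
    validate_blockswap_config,
    apply_block_swap_to_dit,
    cleanup_blockswap,
)


def run_with_blockswap(runner, debug):
    """Usage sketch: wire BlockSwap into a runner around one inference pass."""
    config = {
        "blocks_to_swap": 16,            # offload the first 16 transformer blocks
        "swap_io_components": True,      # also offload embeddings / norms
        "offload_device": torch.device("cpu"),
    }
    if is_blockswap_enabled(config):
        config = validate_blockswap_config(
            config,
            dit_device=torch.device("cuda"),
            dit_offload_device=torch.device("cpu"),
            debug=debug,
        )
        apply_block_swap_to_dit(runner, config, debug)
    try:
        pass  # ... run the upscaling inference here ...
    finally:
        cleanup_blockswap(runner, keep_state_for_cache=False)
```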
src/optimization/blockswap.py.bak
ADDED
@@ -0,0 +1,938 @@
src/optimization/memory_manager.py
ADDED
@@ -0,0 +1,1285 @@
"""
Memory management module for SeedVR2
Handles VRAM usage, cache management, and memory optimization

Extracted from: seedvr2.py (lines 373-405, 607-626, 1016-1044)
"""

import torch
import gc
import sys
import time
import psutil
import platform
from typing import Tuple, Dict, Any, Optional, List, Union


def _device_str(device: Union[torch.device, str]) -> str:
    """Normalized uppercase device string for comparison and logging. MPS variants → 'MPS'."""
    s = str(device).upper()
    return 'MPS' if s.startswith('MPS') else s


def is_mps_available() -> bool:
    """Check if MPS (Apple Metal) backend is available."""
    return hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()


def is_cuda_available() -> bool:
    """Check if CUDA backend is available."""
    return torch.cuda.is_available()


def get_gpu_backend() -> str:
    """Get the active GPU backend type.

    Returns:
        'cuda': NVIDIA CUDA
        'mps': Apple Metal Performance Shaders
        'cpu': No GPU backend available
    """
    if is_cuda_available():
        return 'cuda'
    if is_mps_available():
        return 'mps'
    return 'cpu'


def get_device_list(include_none: bool = False, include_cpu: bool = False) -> List[str]:
    """
    Get list of available compute devices for SeedVR2

    Args:
        include_none: If True, prepend "none" to the device list (for offload options)
        include_cpu: If True, include "cpu" in the device list (for offload options only)
            Note: On MPS-only systems, "cpu" is automatically excluded since
            unified memory architecture makes CPU offloading meaningless

    Returns:
        List of device strings (e.g., ["cuda:0", "cuda:1"] or ["none", "cpu", "cuda:0", "cuda:1"])
    """
    devs = []
    has_cuda = False
    has_mps = False

    try:
        if is_cuda_available():
            devs += [f"cuda:{i}" for i in range(torch.cuda.device_count())]
            has_cuda = True
    except Exception:
        pass

    try:
        if is_mps_available():
            devs.append("mps")  # MPS doesn't use device indices
            has_mps = True
    except Exception:
        pass

    # Build result list with optional prefixes
    result = []
    if include_none:
        result.append("none")

    # Only include "cpu" option if:
    # 1. It was requested (include_cpu=True), AND
    # 2. Either CUDA is available OR MPS is not the only option
    # Rationale: On MPS-only systems with unified memory architecture,
    # CPU offloading is semantically meaningless as CPU and GPU share the same memory pool
    if include_cpu and (has_cuda or not has_mps):
        result.append("cpu")

    result.extend(devs)

    return result if result else []

| 97 |
+
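A minimal usage sketch for the device helpers above, assuming the module is importable under the repo's `src/optimization` layout; the variable names here are illustrative only, not part of this file:

```python
# Illustrative sketch: populate UI choices from the device helpers.
from src.optimization.memory_manager import get_device_list, get_gpu_backend

main_choices = get_device_list()                      # e.g. ["cuda:0"] or ["mps"]
offload_choices = get_device_list(include_none=True,  # e.g. ["none", "cpu", "cuda:0"]
                                  include_cpu=True)
default_main = main_choices[0] if main_choices else "cpu"
print(f"backend={get_gpu_backend()} main={default_main} offload options={offload_choices}")
```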
def get_basic_vram_info(device: Optional[torch.device] = None) -> Dict[str, Any]:
    """
    Get basic VRAM availability info (free and total memory).
    Used for capacity planning and initial checks.

    Args:
        device: Optional device to query. If None, uses cuda:0

    Returns:
        dict: {"free_gb": float, "total_gb": float} or {"error": str}
    """
    try:
        if is_cuda_available():
            if device is None:
                device = torch.device("cuda:0")
            elif not isinstance(device, torch.device):
                device = torch.device(device)
            free_memory, total_memory = torch.cuda.mem_get_info(device)
        elif is_mps_available():
            # MPS doesn't support per-device queries or mem_get_info
            # Use system memory as proxy
            mem = psutil.virtual_memory()
            free_memory = mem.total - mem.used
            total_memory = mem.total
        else:
            return {"error": "No GPU backend available (CUDA/MPS)"}

        return {
            "free_gb": free_memory / (1024**3),
            "total_gb": total_memory / (1024**3)
        }
    except Exception as e:
        return {"error": f"Failed to get memory info: {str(e)}"}


# Initial VRAM check at module load
vram_info = get_basic_vram_info(device=None)
if "error" not in vram_info:
    backend = "MPS" if is_mps_available() else "CUDA"
    print(f"📊 Initial {backend} memory: {vram_info['free_gb']:.2f}GB free / {vram_info['total_gb']:.2f}GB total")
else:
    print(f"⚠️ Memory check failed: {vram_info['error']} - No available backend!")


def get_vram_usage(device: Optional[torch.device] = None, debug: Optional['Debug'] = None) -> Tuple[float, float, float, float]:
    """
    Get current VRAM usage metrics for monitoring.
    Used for tracking memory consumption during processing.

    Args:
        device: Optional device to query. If None, uses cuda:0
        debug: Optional debug instance for logging

    Returns:
        tuple: (allocated_gb, reserved_gb, peak_allocated_gb, peak_reserved_gb)
        Returns (0, 0, 0, 0) if no GPU available
    """
    try:
        if is_cuda_available():
            if device is None:
                device = torch.device("cuda:0")
            elif not isinstance(device, torch.device):
                device = torch.device(device)
            allocated = torch.cuda.memory_allocated(device) / (1024**3)
            reserved = torch.cuda.memory_reserved(device) / (1024**3)
            peak_allocated = torch.cuda.max_memory_allocated(device) / (1024**3)
            peak_reserved = torch.cuda.max_memory_reserved(device) / (1024**3)
            return allocated, reserved, peak_allocated, peak_reserved
        elif is_mps_available():
            # MPS doesn't support per-device queries - uses global memory tracking
            allocated = torch.mps.current_allocated_memory() / (1024**3)
            reserved = torch.mps.driver_allocated_memory() / (1024**3)
            # MPS doesn't track peak separately
            return allocated, reserved, allocated, reserved
    except Exception as e:
        if debug:
            debug.log(f"Failed to get VRAM usage: {e}", level="WARNING", category="memory", force=True)
    # No GPU backend (or query failed): return the documented fallback
    return 0.0, 0.0, 0.0, 0.0


def get_ram_usage(debug: Optional['Debug'] = None) -> Tuple[float, float, float, float]:
    """
    Get current RAM usage metrics for the current process.
    Provides accurate tracking of process-specific memory consumption.

    Args:
        debug: Optional debug instance for logging

    Returns:
        tuple: (process_gb, available_gb, total_gb, used_by_others_gb)
        Returns (0, 0, 0, 0) if psutil not available or on error
    """
    try:
        if not psutil:
            return 0.0, 0.0, 0.0, 0.0

        # Get current process memory
        process = psutil.Process()
        process_memory = process.memory_info()
        process_gb = process_memory.rss / (1024**3)

        # Get system memory
        sys_memory = psutil.virtual_memory()
        total_gb = sys_memory.total / (1024**3)
        available_gb = sys_memory.available / (1024**3)

        # Calculate memory used by other processes
        total_used_gb = total_gb - available_gb  # Total memory used by ALL processes
        used_by_others_gb = max(0, total_used_gb - process_gb)  # Subtract current process

        return process_gb, available_gb, total_gb, used_by_others_gb

    except Exception as e:
        if debug:
            debug.log(f"Failed to get RAM usage: {e}", level="WARNING", category="memory", force=True)
        return 0.0, 0.0, 0.0, 0.0


# Global cache for OS libraries (initialized once)
_os_memory_lib = None

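A small sketch showing how the two probes above can bracket a processing step; the `log_memory_snapshot` helper is illustrative and not part of this module:

```python
# Illustrative helper built on get_vram_usage() / get_ram_usage() defined above.
def log_memory_snapshot(tag: str) -> None:
    alloc, reserved, peak_alloc, peak_res = get_vram_usage()
    proc_gb, avail_gb, total_gb, others_gb = get_ram_usage()
    print(f"[{tag}] VRAM alloc={alloc:.2f}GB reserved={reserved:.2f}GB peak={peak_alloc:.2f}GB | "
          f"RAM process={proc_gb:.2f}GB available={avail_gb:.2f}GB/{total_gb:.2f}GB")

log_memory_snapshot("before batch")
# ... run a batch ...
log_memory_snapshot("after batch")
```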
def clear_memory(debug: Optional['Debug'] = None, deep: bool = False, force: bool = True,
                 timer_name: Optional[str] = None) -> None:
    """
    Clear memory caches with two-tier approach for optimal performance.

    Args:
        debug: Debug instance for logging (optional)
        force: If True, always clear. If False, only clear when <5% free
        deep: If True, perform deep cleanup including GC and OS operations.
              If False (default), only perform minimal GPU cache clearing.
        timer_name: Optional suffix for timer names to make them unique per invocation

    Two-tier approach:
    - Minimal mode (deep=False): GPU cache operations (~1-5ms)
      Used for frequent calls during batch processing
    - Deep mode (deep=True): Complete cleanup with GC and OS operations (~10-50ms)
      Used at key points like model switches or final cleanup
    """
    global _os_memory_lib

    # Create unique timer names if suffix provided
    if timer_name:
        main_timer = f"memory_clear_{timer_name}"
        gpu_timer = f"gpu_cache_clear_{timer_name}"
        gc_timer = f"garbage_collection_{timer_name}"
        os_timer = f"os_memory_release_{timer_name}"
        completion_msg = f"clear_memory() completion ({timer_name})"
    else:
        main_timer = "memory_clear"
        gpu_timer = "gpu_cache_clear"
        gc_timer = "garbage_collection"
        os_timer = "os_memory_release"
        completion_msg = "clear_memory() completion"

    # Start timer for entire operation
    if debug:
        debug.start_timer(main_timer)

    # Check if we should clear based on memory pressure
    if not force:
        should_clear = False

        # Use existing function for memory info
        mem_info = get_basic_vram_info(device=None)

        if "error" not in mem_info and mem_info["total_gb"] > 0:
            # Check VRAM/MPS memory pressure (5% free threshold)
            free_ratio = mem_info["free_gb"] / mem_info["total_gb"]
            if free_ratio < 0.05:
                should_clear = True
                if debug:
                    backend = "Unified Memory" if is_mps_available() else "VRAM"
                    debug.log(f"{backend} pressure: {mem_info['free_gb']:.2f}GB free of {mem_info['total_gb']:.2f}GB", category="memory")

        # For non-MPS systems, also check system RAM separately
        if not should_clear and not is_mps_available():
            mem = psutil.virtual_memory()
            if mem.available < mem.total * 0.05:
                should_clear = True
                if debug:
                    debug.log(f"RAM pressure: {mem.available/(1024**3):.2f}GB free of {mem.total/(1024**3):.2f}GB", category="memory")

        if not should_clear:
            # End timer before early return to keep stack clean
            if debug:
                debug.end_timer(main_timer)
            return

    # Determine cleanup level
    cleanup_mode = "deep" if deep else "minimal"
    if debug:
        debug.log(f"Clearing memory caches ({cleanup_mode})...", category="cleanup")

    # ===== MINIMAL OPERATIONS (Always performed) =====
    # Step 1: Clear GPU caches - Fast operations (~1-5ms)
    if debug:
        debug.start_timer(gpu_timer)

    if is_cuda_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
    elif is_mps_available():
        torch.mps.empty_cache()

    if debug:
        debug.end_timer(gpu_timer, "GPU cache clearing")

    # ===== DEEP OPERATIONS (Only when deep=True) =====
    if deep:
        # Step 2: Deep garbage collection (expensive ~5-20ms)
        if debug:
            debug.start_timer(gc_timer)

        gc.collect(2)

        if debug:
            debug.end_timer(gc_timer, "Garbage collection")

        # Step 3: Return memory to OS (platform-specific, ~5-30ms)
        if debug:
            debug.start_timer(os_timer)

        try:
            if sys.platform == 'linux':
                # Linux: malloc_trim
                import ctypes  # Import only when needed
                if _os_memory_lib is None:
                    _os_memory_lib = ctypes.CDLL("libc.so.6")
                _os_memory_lib.malloc_trim(0)

            elif sys.platform == 'win32':
                # Windows: Trim working set
                import ctypes  # Import only when needed
                if _os_memory_lib is None:
                    _os_memory_lib = ctypes.windll.kernel32
                handle = _os_memory_lib.GetCurrentProcess()
                _os_memory_lib.SetProcessWorkingSetSize(handle, -1, -1)

            elif is_mps_available():
                # macOS with MPS
                import ctypes  # Import only when needed
                import ctypes.util
                if _os_memory_lib is None:
                    libc_path = ctypes.util.find_library('c')
                    if libc_path:
                        _os_memory_lib = ctypes.CDLL(libc_path)

                if _os_memory_lib:
                    _os_memory_lib.sync()
        except Exception as e:
            if debug:
                debug.log(f"Failed to perform OS memory operations: {e}", level="WARNING", category="memory", force=True)

        if debug:
            debug.end_timer(os_timer, "OS memory release")

    # End overall timer
    if debug:
        debug.end_timer(main_timer, completion_msg)

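A sketch of the intended two-tier call pattern for `clear_memory()` (cheap conditional clears inside a loop, one deep clear at a phase boundary); `batches` and `process_batch` are caller-supplied placeholders:

```python
# Illustrative call pattern; not part of this module.
def run_batches(batches, process_batch, debug=None):
    for batch in batches:
        process_batch(batch)
        clear_memory(debug=debug, deep=False, force=False)  # cheap; only acts under memory pressure
    clear_memory(debug=debug, deep=True, force=True,        # full GC + OS release at phase end
                 timer_name="phase_end")
```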
def retry_on_oom(func, *args, debug=None, operation_name="operation", **kwargs):
    """
    Execute function with single OOM retry after memory cleanup.

    Args:
        func: Callable to execute
        *args: Positional arguments for func
        debug: Debug instance for logging (optional)
        operation_name: Name for logging
        **kwargs: Keyword arguments for func

    Returns:
        Result of func(*args, **kwargs)
    """
    try:
        return func(*args, **kwargs)
    except (torch.cuda.OutOfMemoryError, RuntimeError) as e:
        # Only handle OOM errors
        if not any(x in str(e).lower() for x in ["out of memory", "allocation on device"]):
            raise

        if debug:
            debug.log(f"OOM during {operation_name}: {e}", level="WARNING", category="memory", force=True)
            debug.log("Clearing memory and retrying", category="info", force=True)

        # Clear memory
        clear_memory(debug=debug, deep=True, force=True, timer_name=operation_name)
        # Let memory settle
        time.sleep(0.5)
        if debug:
            debug.log_memory_state("After memory clearing", show_tensors=False, detailed_tensors=False)

        # Single retry
        try:
            result = func(*args, **kwargs)
            if debug:
                debug.log(f"Retry successful for {operation_name}", category="success", force=True)
            return result
        except Exception as retry_e:
            if debug:
                debug.log(f"Retry failed for {operation_name}: {retry_e}", level="ERROR", category="memory", force=True)
            raise

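Usage sketch for the OOM retry wrapper; `risky_decode` stands in for whatever callable (e.g., a VAE decode) might hit the allocator:

```python
import torch

def risky_decode(x):
    # Stand-in for a decode step that may raise an out-of-memory error.
    return x * 2

out = retry_on_oom(risky_decode, torch.ones(4), operation_name="vae_decode")
```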
def reset_vram_peak(device: Optional[torch.device] = None, debug: Optional['Debug'] = None) -> None:
    """
    Reset VRAM peak memory statistics for fresh tracking.

    Args:
        device: Optional device to reset stats for. If None, uses cuda:0
        debug: Optional debug instance for logging
    """
    if debug and debug.enabled:
        debug.log("Resetting VRAM peak memory statistics", category="memory")
    try:
        if is_cuda_available():
            if device is None:
                device = torch.device("cuda:0")
            elif not isinstance(device, torch.device):
                device = torch.device(device)
            torch.cuda.reset_peak_memory_stats(device)
        # Note: MPS doesn't support peak memory reset - no action needed
    except Exception as e:
        if debug and debug.enabled:
            debug.log(f"Failed to reset peak memory stats: {e}", level="WARNING", category="memory", force=True)


def clear_rope_lru_caches(model: Optional[torch.nn.Module], debug: Optional['Debug'] = None) -> int:
    """
    Clear ALL LRU caches from RoPE modules.

    Args:
        model: PyTorch model to clear caches from
        debug: Optional debug instance for logging

    Returns:
        Number of caches cleared
    """
    if model is None:
        return 0

    cleared_count = 0
    try:
        for name, module in model.named_modules():
            if hasattr(module, 'get_axial_freqs') and hasattr(module.get_axial_freqs, 'cache_clear'):
                try:
                    module.get_axial_freqs.cache_clear()
                    cleared_count += 1
                except Exception as e:
                    if debug:
                        debug.log(f"Failed to clear RoPE LRU cache for module {name}: {e}", level="WARNING", category="memory", force=True)
    except (AttributeError, RuntimeError) as e:
        if debug:
            debug.log(f"Failed to iterate model modules for RoPE LRU cache clearing: {e}", level="WARNING", category="memory", force=True)

    return cleared_count


def release_tensor_memory(tensor: Optional[torch.Tensor]) -> None:
    """Release tensor memory from any device (CPU/CUDA/MPS)"""
    if tensor is not None and torch.is_tensor(tensor):
        # Release storage for all devices (CPU, CUDA, MPS)
        if tensor.numel() > 0:
            tensor.data.set_()
        tensor.grad = None


def release_tensor_collection(collection: Any, recursive: bool = True) -> None:
    """
    Release GPU memory from tensors in any collection (list, tuple, dict, or single tensor).

    Args:
        collection: Tensor, list, tuple, dict, or nested structure to release
        recursive: If True, handle nested structures recursively

    Examples:
        release_tensor_collection(tensor)                # Single tensor
        release_tensor_collection([tensor1, tensor2])    # List of tensors
        release_tensor_collection([[t1, t2], [t3, t4]])  # Nested lists
        release_tensor_collection({'a': tensor})         # Dict values
    """
    if collection is None:
        return

    if torch.is_tensor(collection):
        release_tensor_memory(collection)
    elif isinstance(collection, dict):
        for value in collection.values():
            if recursive:
                release_tensor_collection(value, recursive=True)
            elif torch.is_tensor(value):
                release_tensor_memory(value)
    elif isinstance(collection, (list, tuple)):
        for item in collection:
            if recursive:
                release_tensor_collection(item, recursive=True)
            elif torch.is_tensor(item):
                release_tensor_memory(item)


def release_text_embeddings(*embeddings: torch.Tensor, debug: Optional['Debug'] = None, names: Optional[List[str]] = None) -> None:
    """
    Release memory for text embeddings

    Args:
        *embeddings: Variable number of embedding tensors to release
        debug: Optional debug instance for logging
        names: Optional list of names for logging
    """
    for i, embedding in enumerate(embeddings):
        if embedding is not None:
            release_tensor_memory(embedding)
            if debug and names and i < len(names):
                debug.log(f"Cleaned up {names[i]}", category="cleanup")


def cleanup_text_embeddings(ctx: Dict[str, Any], debug: Optional['Debug'] = None) -> None:
    """
    Clean up text embeddings from a context dictionary.
    Extracts embeddings, releases memory, and clears the context entry.

    Args:
        ctx: Context dictionary potentially containing 'text_embeds'
        debug: Optional debug instance for logging
    """
    if not ctx or not ctx.get('text_embeds'):
        return

    embeddings = []
    names = []
    for key, embeds_list in ctx['text_embeds'].items():
        if embeds_list:
            embeddings.extend(embeds_list)
            names.append(key)

    if embeddings:
        # Unpack the tensors; debug/names are keyword-only in release_text_embeddings()
        release_text_embeddings(*embeddings, debug=debug, names=names)

        if debug:
            debug.log(f"Cleaned up text embeddings: {', '.join(names)}", category="cleanup")

    ctx['text_embeds'] = None


def release_model_memory(model: Optional[torch.nn.Module], debug: Optional['Debug'] = None) -> None:
    """
    Release all GPU/MPS memory from model in-place without CPU transfer.

    Args:
        model: PyTorch model to release memory from
        debug: Optional debug instance for logging
    """
    if model is None:
        return

    # If the model has pipelining resources (swap stream), synchronize to ensure no pending async ops
    try:
        if hasattr(model, "_swap_stream"):
            try:
                model._swap_stream.synchronize()
            except Exception:
                if debug:
                    debug.log("Failed to synchronize model._swap_stream before releasing memory", level="WARNING", category="memory", force=True)
    except Exception:
        pass

    try:
        # Clear gradients first
        model.zero_grad(set_to_none=True)

        # Release GPU memory directly without CPU transfer
        released_params = 0
        released_buffers = 0

        for param in model.parameters():
            if param.is_cuda or param.is_mps:
                if param.numel() > 0:
                    param.data.set_()
                    released_params += 1
                param.grad = None

        for buffer in model.buffers():
            if buffer.is_cuda or buffer.is_mps:
                if buffer.numel() > 0:
                    buffer.data.set_()
                    released_buffers += 1

        if debug and (released_params > 0 or released_buffers > 0):
            debug.log(f"Released memory from {released_params} params and {released_buffers} buffers", category="success")

    except (AttributeError, RuntimeError) as e:
        if debug:
            debug.log(f"Failed to release model memory: {e}", level="WARNING", category="memory", force=True)

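A short sketch of releasing nested intermediates with the helpers above; the tensors are dummies created just for illustration:

```python
import torch

# Illustrative: drop storage held by nested intermediates, then do a cheap cache clear.
intermediates = {"samples": [torch.zeros(8, 8), torch.zeros(8, 8)], "alpha": torch.zeros(4)}
release_tensor_collection(intermediates)  # walks the dict/list nesting and releases each tensor
intermediates = None
clear_memory(deep=False)
```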
def manage_tensor(
    tensor: torch.Tensor,
    target_device: torch.device,
    tensor_name: str = "tensor",
    dtype: Optional[torch.dtype] = None,
    non_blocking: bool = False,
    debug: Optional['Debug'] = None,
    reason: Optional[str] = None,
    indent_level: int = 0
) -> torch.Tensor:
    """
    Unified tensor management for device movement and dtype conversion.

    Handles both device transfers (CPU ↔ GPU) and dtype conversions (e.g., float16 → bfloat16)
    with intelligent early-exit optimization and comprehensive logging.

    Args:
        tensor: Tensor to manage
        target_device: Target device (torch.device object)
        tensor_name: Descriptive name for logging (e.g., "latent", "sample", "alpha_channel")
        dtype: Optional target dtype to cast to (if None, keeps original dtype)
        non_blocking: Whether to use non-blocking transfer
        debug: Debug instance for logging
        reason: Optional reason for the operation (e.g., "inference", "offload", "dtype alignment")
        indent_level: Indentation level for debug logging (0=no indent, 1=2 spaces, etc.)

    Returns:
        Tensor on target device with optional dtype conversion

    Note:
        - Skips operation if tensor already has target device and dtype (zero-copy)
        - Uses PyTorch's optimized .to() for efficient device/dtype handling
        - Logs all operations consistently for tracking and debugging
    """
    if tensor is None:
        return tensor

    # Get current state
    current_device = tensor.device
    current_dtype = tensor.dtype
    target_dtype = dtype if dtype is not None else current_dtype

    # Check if movement is actually needed
    needs_device_move = _device_str(current_device) != _device_str(target_device)
    needs_dtype_change = dtype is not None and current_dtype != target_dtype

    if not needs_device_move and not needs_dtype_change:
        # Already on target device and dtype - skip
        return tensor

    # Determine reason for movement
    if reason is None:
        if needs_device_move and needs_dtype_change:
            reason = "device and dtype conversion"
        elif needs_device_move:
            reason = "device movement"
        else:
            reason = "dtype conversion"

    # Log the movement
    if debug:
        current_device_str = _device_str(current_device)
        target_device_str = _device_str(target_device)

        dtype_info = ""
        if needs_dtype_change:
            dtype_info = f", {current_dtype} → {target_dtype}"

        debug.log(
            f"Moving {tensor_name} from {current_device_str} to {target_device_str}{dtype_info} ({reason})",
            category="general",
            indent_level=indent_level
        )

    # Perform the operation based on what needs to change
    if needs_device_move and needs_dtype_change:
        # Both device and dtype need to change
        return tensor.to(target_device, dtype=target_dtype, non_blocking=non_blocking)
    elif needs_device_move:
        # Only device needs to change
        return tensor.to(target_device, non_blocking=non_blocking)
    else:
        # Only dtype needs to change
        return tensor.to(dtype=target_dtype)

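Usage sketch for `manage_tensor()`: move a latent to the compute device and align its dtype in one call; the shapes, dtypes, and device choice are examples only:

```python
import torch

latent = torch.randn(1, 4, 64, 64, dtype=torch.float16)  # example tensor
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
latent = manage_tensor(latent, device, tensor_name="latent",
                       dtype=torch.bfloat16, non_blocking=True,
                       reason="inference")  # no-op if device and dtype already match
```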
def manage_model_device(model: torch.nn.Module, target_device: torch.device, model_name: str,
                        debug: Optional['Debug'] = None, reason: Optional[str] = None,
                        runner: Optional[Any] = None) -> bool:
    """
    Move model to target device with optimizations.
    Handles BlockSwap-enabled models transparently.

    Args:
        model: The model to move
        target_device: Target device (torch.device object, e.g., torch.device('cuda:0'))
        model_name: Name for logging (e.g., "VAE", "DiT")
        debug: Debug instance for logging
        reason: Optional custom reason for the movement
        runner: Optional runner instance for BlockSwap detection

    Returns:
        bool: True if model was moved, False if already on target device
    """
    if model is None:
        return False

    # Check if this is a BlockSwap-enabled DiT model
    is_blockswap_model = False
    actual_model = model
    if runner and model_name == "DiT":
        # Import here to avoid circular dependency
        from .blockswap import is_blockswap_enabled
        # Check if BlockSwap config exists and is enabled
        has_blockswap_config = (
            hasattr(runner, '_dit_block_swap_config') and
            is_blockswap_enabled(runner._dit_block_swap_config)
        )

        if has_blockswap_config:
            is_blockswap_model = True
            # Get the actual model (handle CompatibleDiT wrapper)
            if hasattr(model, "dit_model"):
                actual_model = model.dit_model

    # Get current device
    try:
        current_device = next(model.parameters()).device
    except StopIteration:
        return False

    # Extract device type for comparison (both are torch.device objects)
    target_type = target_device.type
    current_device_upper = _device_str(current_device)
    target_device_upper = _device_str(target_device)

    # Compare normalized device types
    if current_device_upper == target_device_upper and not is_blockswap_model:
        # Already on target device type, no movement needed
        if debug:
            debug.log(f"{model_name} already on {current_device_upper}, skipping movement", category="general")
        return False

    # Handle BlockSwap models specially
    if is_blockswap_model:
        return _handle_blockswap_model_movement(
            runner, actual_model, current_device, target_device, target_type,
            model_name, debug, reason
        )

    # Standard model movement (non-BlockSwap)
    return _standard_model_movement(
        model, current_device, target_device, target_type, model_name,
        debug, reason
    )


def _handle_blockswap_model_movement(runner: Any, model: torch.nn.Module,
                                     current_device: torch.device, target_device: torch.device,
                                     target_type: str, model_name: str,
                                     debug: Optional['Debug'] = None, reason: Optional[str] = None) -> bool:
    """
    Handle device movement for BlockSwap-enabled models.

    Args:
        runner: Runner instance with BlockSwap configuration
        model: Model to move (actual unwrapped model)
        current_device: Current device of the model
        target_device: Target device (torch.device object)
        target_type: Target device type (cpu/cuda/mps)
        model_name: Model name for logging
        debug: Debug instance
        reason: Movement reason

    Returns:
        bool: True if model was moved
    """
    # Import here to avoid circular dependency
    from .blockswap import set_blockswap_bypass

    if target_type == "cpu":
        # Moving to offload device (typically CPU)
        # Check if any parameter is on GPU (for accurate logging)
        actual_source_device = None
        for param in model.parameters():
            if param.device.type in ['cuda', 'mps']:
                actual_source_device = param.device
                break

        source_device_desc = _device_str(actual_source_device) if actual_source_device else _device_str(target_device)

        if debug:
            debug.log(f"Moving {model_name} from {source_device_desc} to {_device_str(target_device)} ({reason or 'model caching'})", category="general")

        # Enable bypass to allow movement
        set_blockswap_bypass(runner=runner, bypass=True, debug=debug)

        # If a pipelined swap stream exists, synchronize it to ensure no pending async transfers
        if hasattr(model, "_swap_stream"):
            try:
                model._swap_stream.synchronize()
            except Exception:
                # Best-effort; don't fail the movement if synchronize not supported
                if debug:
                    debug.log("Failed to synchronize model._swap_stream before offload", level="WARNING", category="memory", force=True)

        # Start timer
        timer_name = f"{model_name.lower()}_to_{target_type}"
        if debug:
            debug.start_timer(timer_name)

        # Move entire model to target offload device
        model.to(target_device)
        model.zero_grad(set_to_none=True)

        # After moving to CPU, attempt to pin CPU tensors to enable non-blocking async copies later.
        try:
            for p in model.parameters():
                if p.device.type == "cpu" and p.numel() > 0 and not p.data.is_pinned():
                    p.data = p.data.pin_memory()
            for b in model.buffers():
                if b.device.type == "cpu" and b.numel() > 0 and not b.data.is_pinned():
                    b.data = b.data.pin_memory()
        except Exception as e:
            # Pinning is best-effort; log and continue
            if debug:
                debug.log(f"Pin-memory on offloaded model failed: {e}", level="WARNING", category="memory", force=True)

        if debug:
            debug.end_timer(timer_name, f"BlockSwap model offloaded to {_device_str(target_device)}")

        return True

    else:
        # Moving to GPU (reload)
        # Check if we're in bypass mode (coming from offload)
        if not getattr(model, "_blockswap_bypass_protection", False):
            # Not in bypass mode, blocks are already configured
            if debug:
                debug.log(f"{model_name} with BlockSwap active - blocks already distributed across devices, skipping movement", category="general")
            return False

        # Get actual current device for accurate logging
        actual_current_device = None
        for param in model.parameters():
            if param.device.type != 'meta':
                actual_current_device = param.device
                break

        current_device_desc = _device_str(actual_current_device) if actual_current_device else "OFFLOAD"

        if debug:
            debug.log(f"Moving {model_name} from {current_device_desc} to {_device_str(target_device)} ({reason or 'inference requirement'})", category="general")

        timer_name = f"{model_name.lower()}_to_gpu"
        if debug:
            debug.start_timer(timer_name)

        # Restore blocks to their configured devices
        if hasattr(model, "blocks") and hasattr(model, "blocks_to_swap"):
            # Use configured offload_device from BlockSwap config
            offload_device = model._block_swap_config.get("offload_device")
            if not offload_device:
                raise ValueError("BlockSwap config missing offload_device")

            # Move blocks according to BlockSwap configuration
            for b, block in enumerate(model.blocks):
                if b > model.blocks_to_swap:
                    # This block should be on GPU
                    block.to(target_device)
                else:
                    # This block stays on offload device (will be swapped during forward)
                    block.to(offload_device)

            # Handle I/O components
            if not model._block_swap_config.get("swap_io_components", False):
                # I/O components should be on GPU if not offloaded
                for name, module in model.named_children():
                    if name != "blocks":
                        module.to(target_device)
            else:
                # I/O components stay on offload device
                for name, module in model.named_children():
                    if name != "blocks":
                        module.to(offload_device)

            if debug:
                # Get actual configuration from runner
                if hasattr(model, '_block_swap_config'):
                    blocks_on_gpu = model._block_swap_config.get('total_blocks', 32) - model._block_swap_config.get('blocks_swapped', 16)
                    total_blocks = model._block_swap_config.get('total_blocks', 32)
                    main_device = model._block_swap_config.get('main_device', 'GPU')
                    debug.log(f"BlockSwap blocks restored to configured devices ({blocks_on_gpu}/{total_blocks} blocks on {_device_str(main_device)})", category="success")
                else:
                    debug.log("BlockSwap blocks restored to configured devices", category="success")

        # Reactivate BlockSwap now that blocks are restored to their configured devices
        runner._blockswap_active = True

        # Disable bypass, re-enable protection
        set_blockswap_bypass(runner=runner, bypass=False, debug=debug)

        if debug:
            debug.end_timer(timer_name, "BlockSwap model restored")

        return True


def _standard_model_movement(model: torch.nn.Module, current_device: torch.device,
                             target_device: torch.device, target_type: str, model_name: str,
                             debug: Optional['Debug'] = None, reason: Optional[str] = None) -> bool:
    """
    Handle standard (non-BlockSwap) model movement.

    Args:
        model: Model to move
        current_device: Current device of the model
        target_device: Target device (torch.device object)
        target_type: Target device type
        model_name: Model name for logging
        debug: Debug instance
        reason: Movement reason

    Returns:
        bool: True if model was moved
    """
    # Check if model is on meta device - can't move meta tensors
    if current_device.type == 'meta':
        if debug:
            debug.log(f"{model_name} is on meta device - skipping movement (will materialize when needed)",
                      category=model_name.lower())
        return False

    # Determine reason for movement
    reason = reason or "inference requirement"

    # Log the movement with full device strings
    if debug:
        current_device_str = _device_str(current_device)
        target_device_str = _device_str(target_device)
        debug.log(f"Moving {model_name} from {current_device_str} to {target_device_str} ({reason})", category="general")

    # Start timer based on direction
    timer_name = f"{model_name.lower()}_to_{'gpu' if target_type != 'cpu' else 'cpu'}"
    if debug:
        debug.start_timer(timer_name)

    # Move model and clear gradients
    model.to(target_device)
    model.zero_grad(set_to_none=True)

    # Clear VAE memory buffers when moving to CPU
    if target_type == 'cpu' and model_name == "VAE":
        cleared_count = 0
        for module in model.modules():
            if hasattr(module, 'memory') and module.memory is not None:
                if torch.is_tensor(module.memory) and (module.memory.is_cuda or module.memory.is_mps):
                    module.memory = None
                    cleared_count += 1
        if cleared_count > 0 and debug:
            debug.log(f"Cleared {cleared_count} VAE memory buffers", category="success")

    # End timer
    if debug:
        debug.end_timer(timer_name, f"{model_name} moved to {_device_str(target_device)}")

    return True

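A sketch of the offload/reload cycle the movers above implement, using a stand-in `nn.Linear` instead of a real VAE/DiT:

```python
import torch

gpu = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
cpu = torch.device("cpu")
model = torch.nn.Linear(16, 16)                                  # stand-in module

manage_model_device(model, gpu, "VAE", reason="inference")        # onto the compute device
# ... run inference ...
manage_model_device(model, cpu, "VAE", reason="model caching")    # back to the offload device
```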
def clear_runtime_caches(runner: Any, debug: Optional['Debug'] = None) -> int:
    """
    Clear all runtime caches and temporary attributes.
    """
    if not runner:
        return 0

    if debug:
        debug.start_timer("runtime_cache_clear")

    cleaned_items = 0

    # 1. Clear main runner cache
    if hasattr(runner, 'cache') and hasattr(runner.cache, 'cache'):
        if debug:
            debug.start_timer("runner_cache_clear")

        cache_entries = len(runner.cache.cache)

        # Properly release tensor memory and delete as we go
        for key in list(runner.cache.cache.keys()):
            value = runner.cache.cache[key]
            if torch.is_tensor(value):
                release_tensor_memory(value)
            elif isinstance(value, (list, tuple)):
                for item in value:
                    if torch.is_tensor(item):
                        release_tensor_memory(item)
            # Delete immediately to release reference
            del runner.cache.cache[key]

        # Final clear for safety
        runner.cache.cache.clear()
        cleaned_items += cache_entries

        if debug:
            debug.end_timer("runner_cache_clear", "Clearing main runner cache entries")

            if cache_entries > 0:
                debug.log(f"Cleared {cache_entries} runtime cache entries", category="success")

    # 2. Clear RoPE caches
    if hasattr(runner, 'dit'):
        if debug:
            debug.start_timer("rope_cache_clear")

        model = runner.dit
        if hasattr(model, 'dit_model'):  # Handle wrapper
            model = model.dit_model

        rope_cleared = clear_rope_lru_caches(model=model, debug=debug)
        cleaned_items += rope_cleared
        if debug:
            debug.end_timer("rope_cache_clear", "Clearing RoPE LRU caches")

            if rope_cleared > 0:
                debug.log(f"Cleared {rope_cleared} RoPE LRU caches", category="success")

    # 3. Clear temporary attributes
    temp_attrs = ['_temp_cache', '_block_cache', '_swap_cache', '_generation_cache',
                  '_rope_cache', '_intermediate_cache', '_backward_cache']

    for obj in [runner, getattr(runner, 'dit', None), getattr(runner, 'vae', None)]:
        if obj is None:
            continue

        actual_obj = obj.dit_model if hasattr(obj, 'dit_model') else obj

        for attr in temp_attrs:
            if hasattr(actual_obj, attr):
                delattr(actual_obj, attr)
                cleaned_items += 1

    if debug:
        debug.end_timer("runtime_cache_clear", "clear_runtime_caches() completion")

    return cleaned_items


def cleanup_dit(runner: Any, debug: Optional['Debug'] = None, cache_model: bool = False) -> None:
    """
    Cleanup DiT model and BlockSwap state after upscaling phase.
    Called at the end of upscale_all_batches when DiT is no longer needed.

    Args:
        runner: Runner instance containing DiT model
        debug: Debug instance for logging
        cache_model: If True, move DiT to offload_device; if False, delete completely
    """
    if not runner or not hasattr(runner, 'dit'):
        return

    if debug:
        debug.log("Cleaning up DiT components", category="cleanup")

    # 1. Clear DiT-specific runtime caches first
    if hasattr(runner, 'dit'):
        model = runner.dit
        if hasattr(model, 'dit_model'):  # Handle wrapper
            model = model.dit_model

        # Clear RoPE caches
        rope_cleared = clear_rope_lru_caches(model=model, debug=debug)
        if rope_cleared > 0 and debug:
            debug.log(f"Cleared {rope_cleared} RoPE LRU caches", category="success")

        # Clear DiT temporary attributes
        temp_attrs = ['_temp_cache', '_block_cache', '_swap_cache', '_generation_cache',
                      '_rope_cache', '_intermediate_cache', '_backward_cache']

        actual_obj = model.dit_model if hasattr(model, 'dit_model') else model
        for attr in temp_attrs:
            if hasattr(actual_obj, attr):
                delattr(actual_obj, attr)

    # 2. Handle model offloading (for caching or before deletion)
    try:
        param_device = next(runner.dit.parameters()).device

        # Move model off GPU if needed
        if param_device.type not in ['meta', 'cpu']:
            # MPS: skip CPU movement before deletion (unified memory, just causes sync)
            if param_device.type == 'mps' and not cache_model:
                if debug:
                    debug.log("DiT on MPS - skipping CPU movement before deletion", category="cleanup")
            else:
                offload_target = getattr(runner, '_dit_offload_device', None)
                if offload_target is None or offload_target == 'none':
                    offload_target = torch.device('cpu')
                reason = "model caching" if cache_model else "releasing GPU memory"
                manage_model_device(model=runner.dit, target_device=offload_target, model_name="DiT",
                                    debug=debug, reason=reason, runner=runner)
        elif param_device.type == 'meta' and debug:
            debug.log("DiT on meta device - keeping structure for cache", category="cleanup")
    except StopIteration:
        pass

    # 3. Clean BlockSwap after model movement
    if hasattr(runner, "_blockswap_active") and runner._blockswap_active:
        # Import here to avoid circular dependency
        from .blockswap import cleanup_blockswap

        # If model had a swap stream, synchronize before cleanup to avoid races
        try:
            model_for_sync = runner.dit.dit_model if hasattr(runner.dit, 'dit_model') else runner.dit
            if hasattr(model_for_sync, "_swap_stream"):
                try:
                    model_for_sync._swap_stream.synchronize()
                except Exception:
                    if debug:
                        debug.log("Failed to synchronize model._swap_stream before cleanup_blockswap", level="WARNING", category="cleanup", force=True)
        except Exception:
            pass

        cleanup_blockswap(runner=runner, keep_state_for_cache=cache_model)

    # 4. Complete cleanup if not caching
    if not cache_model:
        release_model_memory(model=runner.dit, debug=debug)
        runner.dit = None
        if debug:
            debug.log("DiT model deleted", category="cleanup")

        # Clear DiT config attributes - not needed when model is not cached (will be recreated)
        if hasattr(runner, '_dit_compile_args'):
            delattr(runner, '_dit_compile_args')
        if hasattr(runner, '_dit_block_swap_config'):
            delattr(runner, '_dit_block_swap_config')
        if hasattr(runner, '_dit_attention_mode'):
            delattr(runner, '_dit_attention_mode')

    # 5. Clear DiT temporary attributes (should be already cleared in materialize_model)
    runner._dit_checkpoint = None
    runner._dit_dtype_override = None

    # 6. Clear DiT-related components and temporary attributes
    runner.sampler = None
    runner.sampling_timesteps = None
    runner.schedule = None


def cleanup_vae(runner: Any, debug: Optional['Debug'] = None, cache_model: bool = False) -> None:
    """
    Cleanup VAE model after decoding phase.
    Called at the end of decode_all_batches when VAE is no longer needed.

    Args:
        runner: Runner instance containing VAE model
        debug: Debug instance for logging
        cache_model: If True, move VAE to offload_device; if False, delete completely
    """
    if not runner or not hasattr(runner, 'vae'):
        return

    if debug:
        debug.log("Cleaning up VAE components", category="cleanup")

    # 1. Clear VAE-specific temporary attributes
    if hasattr(runner, 'vae'):
        temp_attrs = ['_temp_cache', '_block_cache', '_swap_cache', '_generation_cache',
                      '_rope_cache', '_intermediate_cache', '_backward_cache']

        for attr in temp_attrs:
            if hasattr(runner.vae, attr):
                delattr(runner.vae, attr)

    # 2. Handle model offloading (for caching or before deletion)
    try:
        param_device = next(runner.vae.parameters()).device

        # Move model off GPU if needed
        if param_device.type not in ['meta', 'cpu']:
            # MPS: skip CPU movement before deletion (unified memory, just causes sync)
            if param_device.type == 'mps' and not cache_model:
                if debug:
                    debug.log("VAE on MPS - skipping CPU movement before deletion", category="cleanup")
            else:
                offload_target = getattr(runner, '_vae_offload_device', None)
                if offload_target is None or offload_target == 'none':
                    offload_target = torch.device('cpu')
                reason = "model caching" if cache_model else "releasing GPU memory"
                manage_model_device(model=runner.vae, target_device=offload_target, model_name="VAE",
                                    debug=debug, reason=reason, runner=runner)
        elif param_device.type == 'meta' and debug:
            debug.log("VAE on meta device - keeping structure for cache", category="cleanup")
    except StopIteration:
        pass

    # 3. Complete cleanup if not caching
    if not cache_model:
        release_model_memory(model=runner.vae, debug=debug)
        runner.vae = None
        if debug:
            debug.log("VAE model deleted", category="cleanup")

        # Clear VAE config attributes - not needed when model is not cached (will be recreated)
        if hasattr(runner, '_vae_compile_args'):
            delattr(runner, '_vae_compile_args')
        if hasattr(runner, '_vae_tiling_config'):
            delattr(runner, '_vae_tiling_config')

    # 4. Clear VAE temporary attributes (should be already cleared in materialize_model)
    runner._vae_checkpoint = None
    runner._vae_dtype_override = None


def complete_cleanup(runner: Any, debug: Optional['Debug'] = None, dit_cache: bool = False, vae_cache: bool = False) -> None:
    """
    Complete cleanup of runner and remaining components with independent model caching support.
    This is a lightweight cleanup for final stage, as model-specific cleanup
    happens in their respective phases (cleanup_dit, cleanup_vae).

    Args:
        runner: Runner instance to clean up
        debug: Debug instance for logging
        dit_cache: If True, preserve DiT model on offload_device for future runs
        vae_cache: If True, preserve VAE model on offload_device for future runs

    Behavior:
        - Can cache DiT and VAE independently for flexible memory management
        - Preserves _dit_model_name and _vae_model_name when either model is cached for change detection
        - Clears all temporary attributes and runtime caches
        - Performs deep memory cleanup only when both models are fully released

    Note:
        Model name tracking (_dit_model_name, _vae_model_name) is only cleared if neither
        model is cached, enabling proper model change detection on subsequent runs.
    """
    if not runner:
        return

    if debug:
        cleanup_type = "partial cleanup" if (dit_cache or vae_cache) else "full cleanup"
        debug.log(f"Starting {cleanup_type}", category="cleanup")

    # 1. Cleanup any remaining models if they still exist
    # (This handles cases where phases were skipped or errored)
    if hasattr(runner, 'dit') and runner.dit is not None:
        cleanup_dit(runner=runner, debug=debug, cache_model=dit_cache)

    if hasattr(runner, 'vae') and runner.vae is not None:
        cleanup_vae(runner=runner, debug=debug, cache_model=vae_cache)

    # 2. Clear remaining runtime caches
    clear_runtime_caches(runner=runner, debug=debug)

    # 3. Clear config and other non-model components when fully releasing runner
    if not (dit_cache or vae_cache):
        # Full cleanup - clear config and model tracking
        runner.config = None
        runner._dit_model_name = None
        runner._vae_model_name = None

    # 4. Final memory cleanup
    clear_memory(debug=debug, deep=True, force=True, timer_name="complete_cleanup")

    # 5. Clear cuBLAS workspaces
    torch._C._cuda_clearCublasWorkspaces() if hasattr(torch._C, '_cuda_clearCublasWorkspaces') else None

    # Log what models are cached for next run
    if debug and (dit_cache or vae_cache):
        cached_models = []
        if dit_cache and hasattr(runner, '_dit_model_name'):
            cached_models.append(f"DiT ({runner._dit_model_name})")
        if vae_cache and hasattr(runner, '_vae_model_name'):
            cached_models.append(f"VAE ({runner._vae_model_name})")

        if cached_models:
            models_str = " and ".join(cached_models)
            debug.log(f"Models cached for next run: {models_str}", category="cache", force=True)

    if debug:
        debug.log(f"Completed {cleanup_type}", category="success")
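Finally, a sketch of how the cleanup entry points are meant to be sequenced at the end of a run; `runner` and `debug` are assumed to be the objects this module already expects (exposing `.dit`/`.vae` and the Debug logging API):

```python
# Illustrative end-of-run sequence built on the functions above.
def finish_run(runner, debug=None):
    cleanup_dit(runner=runner, debug=debug, cache_model=True)    # keep DiT cached on its offload device
    cleanup_vae(runner=runner, debug=debug, cache_model=False)   # fully release the VAE
    complete_cleanup(runner=runner, debug=debug, dit_cache=True, vae_cache=False)
```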
src/optimization/memory_manager.py.bak
ADDED
|
@@ -0,0 +1,1231 @@
| 1 |
+
"""
|
| 2 |
+
Memory management module for SeedVR2
|
| 3 |
+
Handles VRAM usage, cache management, and memory optimization
|
| 4 |
+
|
| 5 |
+
Extracted from: seedvr2.py (lines 373-405, 607-626, 1016-1044)
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import gc
|
| 10 |
+
import sys
|
| 11 |
+
import time
|
| 12 |
+
import psutil
|
| 13 |
+
import platform
|
| 14 |
+
from typing import Tuple, Dict, Any, Optional, List, Union
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def _device_str(device: Union[torch.device, str]) -> str:
|
| 18 |
+
"""Normalized uppercase device string for comparison and logging. MPS variants → 'MPS'."""
|
| 19 |
+
s = str(device).upper()
|
| 20 |
+
return 'MPS' if s.startswith('MPS') else s
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def is_mps_available() -> bool:
|
| 24 |
+
"""Check if MPS (Apple Metal) backend is available."""
|
| 25 |
+
return hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def is_cuda_available() -> bool:
|
| 29 |
+
"""Check if CUDA backend is available."""
|
| 30 |
+
return torch.cuda.is_available()
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def get_gpu_backend() -> str:
|
| 34 |
+
"""Get the active GPU backend type.
|
| 35 |
+
|
| 36 |
+
Returns:
|
| 37 |
+
'cuda': NVIDIA CUDA
|
| 38 |
+
'mps': Apple Metal Performance Shaders
|
| 39 |
+
'cpu': No GPU backend available
|
| 40 |
+
"""
|
| 41 |
+
if is_cuda_available():
|
| 42 |
+
return 'cuda'
|
| 43 |
+
if is_mps_available():
|
| 44 |
+
return 'mps'
|
| 45 |
+
return 'cpu'
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def get_device_list(include_none: bool = False, include_cpu: bool = False) -> List[str]:
|
| 49 |
+
"""
|
| 50 |
+
Get list of available compute devices for SeedVR2
|
| 51 |
+
|
| 52 |
+
Args:
|
| 53 |
+
include_none: If True, prepend "none" to the device list (for offload options)
|
| 54 |
+
include_cpu: If True, include "cpu" in the device list (for offload options only)
|
| 55 |
+
Note: On MPS-only systems, "cpu" is automatically excluded since
|
| 56 |
+
unified memory architecture makes CPU offloading meaningless
|
| 57 |
+
|
| 58 |
+
Returns:
|
| 59 |
+
List of device strings (e.g., ["cuda:0", "cuda:1"] or ["none", "cpu", "cuda:0", "cuda:1"])
|
| 60 |
+
"""
|
| 61 |
+
devs = []
|
| 62 |
+
has_cuda = False
|
| 63 |
+
has_mps = False
|
| 64 |
+
|
| 65 |
+
try:
|
| 66 |
+
if is_cuda_available():
|
| 67 |
+
devs += [f"cuda:{i}" for i in range(torch.cuda.device_count())]
|
| 68 |
+
has_cuda = True
|
| 69 |
+
except Exception:
|
| 70 |
+
pass
|
| 71 |
+
|
| 72 |
+
try:
|
| 73 |
+
if is_mps_available():
|
| 74 |
+
devs.append("mps") # MPS doesn't use device indices
|
| 75 |
+
has_mps = True
|
| 76 |
+
except Exception:
|
| 77 |
+
pass
|
| 78 |
+
|
| 79 |
+
# Build result list with optional prefixes
|
| 80 |
+
result = []
|
| 81 |
+
if include_none:
|
| 82 |
+
result.append("none")
|
| 83 |
+
|
| 84 |
+
# Only include "cpu" option if:
|
| 85 |
+
# 1. It was requested (include_cpu=True), AND
|
| 86 |
+
# 2. Either CUDA is available OR MPS is not the only option
|
| 87 |
+
# Rationale: On MPS-only systems with unified memory architecture,
|
| 88 |
+
# CPU offloading is semantically meaningless as CPU and GPU share the same memory pool
|
| 89 |
+
if include_cpu and (has_cuda or not has_mps):
|
| 90 |
+
result.append("cpu")
|
| 91 |
+
|
| 92 |
+
result.extend(devs)
|
| 93 |
+
|
| 94 |
+
return result if result else []
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def get_basic_vram_info(device: Optional[torch.device] = None) -> Dict[str, Any]:
|
| 98 |
+
"""
|
| 99 |
+
Get basic VRAM availability info (free and total memory).
|
| 100 |
+
Used for capacity planning and initial checks.
|
| 101 |
+
|
| 102 |
+
Args:
|
| 103 |
+
device: Optional device to query. If None, uses cuda:0
|
| 104 |
+
|
| 105 |
+
Returns:
|
| 106 |
+
dict: {"free_gb": float, "total_gb": float} or {"error": str}
|
| 107 |
+
"""
|
| 108 |
+
try:
|
| 109 |
+
if is_cuda_available():
|
| 110 |
+
if device is None:
|
| 111 |
+
device = torch.device("cuda:0")
|
| 112 |
+
elif not isinstance(device, torch.device):
|
| 113 |
+
device = torch.device(device)
|
| 114 |
+
free_memory, total_memory = torch.cuda.mem_get_info(device)
|
| 115 |
+
elif is_mps_available():
|
| 116 |
+
# MPS doesn't support per-device queries or mem_get_info
|
| 117 |
+
# Use system memory as proxy
|
| 118 |
+
mem = psutil.virtual_memory()
|
| 119 |
+
free_memory = mem.total - mem.used
|
| 120 |
+
total_memory = mem.total
|
| 121 |
+
else:
|
| 122 |
+
return {"error": "No GPU backend available (CUDA/MPS)"}
|
| 123 |
+
|
| 124 |
+
return {
|
| 125 |
+
"free_gb": free_memory / (1024**3),
|
| 126 |
+
"total_gb": total_memory / (1024**3)
|
| 127 |
+
}
|
| 128 |
+
except Exception as e:
|
| 129 |
+
return {"error": f"Failed to get memory info: {str(e)}"}
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
# Initial VRAM check at module load
|
| 133 |
+
vram_info = get_basic_vram_info(device=None)
|
| 134 |
+
if "error" not in vram_info:
|
| 135 |
+
backend = "MPS" if is_mps_available() else "CUDA"
|
| 136 |
+
print(f"📊 Initial {backend} memory: {vram_info['free_gb']:.2f}GB free / {vram_info['total_gb']:.2f}GB total")
|
| 137 |
+
else:
|
| 138 |
+
print(f"⚠️ Memory check failed: {vram_info['error']} - No available backend!")
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def get_vram_usage(device: Optional[torch.device] = None, debug: Optional['Debug'] = None) -> Tuple[float, float, float, float]:
|
| 142 |
+
"""
|
| 143 |
+
Get current VRAM usage metrics for monitoring.
|
| 144 |
+
Used for tracking memory consumption during processing.
|
| 145 |
+
|
| 146 |
+
Args:
|
| 147 |
+
device: Optional device to query. If None, uses cuda:0
|
| 148 |
+
debug: Optional debug instance for logging
|
| 149 |
+
|
| 150 |
+
Returns:
|
| 151 |
+
tuple: (allocated_gb, reserved_gb, peak_allocated_gb, peak_reserved_gb)
|
| 152 |
+
Returns (0, 0, 0, 0) if no GPU available
|
| 153 |
+
"""
|
| 154 |
+
try:
|
| 155 |
+
if is_cuda_available():
|
| 156 |
+
if device is None:
|
| 157 |
+
device = torch.device("cuda:0")
|
| 158 |
+
elif not isinstance(device, torch.device):
|
| 159 |
+
device = torch.device(device)
|
| 160 |
+
allocated = torch.cuda.memory_allocated(device) / (1024**3)
|
| 161 |
+
reserved = torch.cuda.memory_reserved(device) / (1024**3)
|
| 162 |
+
peak_allocated = torch.cuda.max_memory_allocated(device) / (1024**3)
|
| 163 |
+
peak_reserved = torch.cuda.max_memory_reserved(device) / (1024**3)
|
| 164 |
+
return allocated, reserved, peak_allocated, peak_reserved
|
| 165 |
+
elif is_mps_available():
|
| 166 |
+
# MPS doesn't support per-device queries - uses global memory tracking
|
| 167 |
+
allocated = torch.mps.current_allocated_memory() / (1024**3)
|
| 168 |
+
reserved = torch.mps.driver_allocated_memory() / (1024**3)
|
| 169 |
+
# MPS doesn't track peak separately
|
| 170 |
+
return allocated, reserved, allocated, reserved
|
| 171 |
+
except Exception as e:
|
| 172 |
+
if debug:
|
| 173 |
+
debug.log(f"Failed to get VRAM usage: {e}", level="WARNING", category="memory", force=True)
|
| 174 |
+
return 0.0, 0.0, 0.0, 0.0
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def get_ram_usage(debug: Optional['Debug'] = None) -> Tuple[float, float, float, float]:
|
| 178 |
+
"""
|
| 179 |
+
Get current RAM usage metrics for the current process.
|
| 180 |
+
Provides accurate tracking of process-specific memory consumption.
|
| 181 |
+
|
| 182 |
+
Args:
|
| 183 |
+
debug: Optional debug instance for logging
|
| 184 |
+
|
| 185 |
+
Returns:
|
| 186 |
+
tuple: (process_gb, available_gb, total_gb, used_by_others_gb)
|
| 187 |
+
Returns (0, 0, 0, 0) if psutil not available or on error
|
| 188 |
+
"""
|
| 189 |
+
try:
|
| 190 |
+
if not psutil:
|
| 191 |
+
return 0.0, 0.0, 0.0, 0.0
|
| 192 |
+
|
| 193 |
+
# Get current process memory
|
| 194 |
+
process = psutil.Process()
|
| 195 |
+
process_memory = process.memory_info()
|
| 196 |
+
process_gb = process_memory.rss / (1024**3)
|
| 197 |
+
|
| 198 |
+
# Get system memory
|
| 199 |
+
sys_memory = psutil.virtual_memory()
|
| 200 |
+
total_gb = sys_memory.total / (1024**3)
|
| 201 |
+
available_gb = sys_memory.available / (1024**3)
|
| 202 |
+
|
| 203 |
+
# Calculate memory used by other processes
|
| 204 |
+
# This is the CORRECT calculation:
|
| 205 |
+
total_used_gb = total_gb - available_gb # Total memory used by ALL processes
|
| 206 |
+
used_by_others_gb = max(0, total_used_gb - process_gb) # Subtract current process
|
| 207 |
+
|
| 208 |
+
return process_gb, available_gb, total_gb, used_by_others_gb
|
| 209 |
+
|
| 210 |
+
except Exception as e:
|
| 211 |
+
if debug:
|
| 212 |
+
debug.log(f"Failed to get RAM usage: {e}", level="WARNING", category="memory", force=True)
|
| 213 |
+
return 0.0, 0.0, 0.0, 0.0
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
# Global cache for OS libraries (initialized once)
|
| 217 |
+
_os_memory_lib = None
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def clear_memory(debug: Optional['Debug'] = None, deep: bool = False, force: bool = True,
|
| 221 |
+
timer_name: Optional[str] = None) -> None:
|
| 222 |
+
"""
|
| 223 |
+
Clear memory caches with two-tier approach for optimal performance.
|
| 224 |
+
|
| 225 |
+
Args:
|
| 226 |
+
debug: Debug instance for logging (optional)
|
| 227 |
+
force: If True, always clear. If False, only clear when <5% free
|
| 228 |
+
deep: If True, perform deep cleanup including GC and OS operations.
|
| 229 |
+
If False (default), only perform minimal GPU cache clearing.
|
| 230 |
+
timer_name: Optional suffix for timer names to make them unique per invocation
|
| 231 |
+
|
| 232 |
+
Two-tier approach:
|
| 233 |
+
- Minimal mode (deep=False): GPU cache operations (~1-5ms)
|
| 234 |
+
Used for frequent calls during batch processing
|
| 235 |
+
- Deep mode (deep=True): Complete cleanup with GC and OS operations (~10-50ms)
|
| 236 |
+
Used at key points like model switches or final cleanup
|
| 237 |
+
"""
|
| 238 |
+
global _os_memory_lib
|
| 239 |
+
|
| 240 |
+
# Create unique timer names if suffix provided
|
| 241 |
+
if timer_name:
|
| 242 |
+
main_timer = f"memory_clear_{timer_name}"
|
| 243 |
+
gpu_timer = f"gpu_cache_clear_{timer_name}"
|
| 244 |
+
gc_timer = f"garbage_collection_{timer_name}"
|
| 245 |
+
os_timer = f"os_memory_release_{timer_name}"
|
| 246 |
+
completion_msg = f"clear_memory() completion ({timer_name})"
|
| 247 |
+
else:
|
| 248 |
+
main_timer = "memory_clear"
|
| 249 |
+
gpu_timer = "gpu_cache_clear"
|
| 250 |
+
gc_timer = "garbage_collection"
|
| 251 |
+
os_timer = "os_memory_release"
|
| 252 |
+
completion_msg = "clear_memory() completion"
|
| 253 |
+
|
| 254 |
+
# Start timer for entire operation
|
| 255 |
+
if debug:
|
| 256 |
+
debug.start_timer(main_timer)
|
| 257 |
+
|
| 258 |
+
# Check if we should clear based on memory pressure
|
| 259 |
+
if not force:
|
| 260 |
+
should_clear = False
|
| 261 |
+
|
| 262 |
+
# Use existing function for memory info
|
| 263 |
+
mem_info = get_basic_vram_info(device=None)
|
| 264 |
+
|
| 265 |
+
if "error" not in mem_info and mem_info["total_gb"] > 0:
|
| 266 |
+
# Check VRAM/MPS memory pressure (5% free threshold)
|
| 267 |
+
free_ratio = mem_info["free_gb"] / mem_info["total_gb"]
|
| 268 |
+
if free_ratio < 0.05:
|
| 269 |
+
should_clear = True
|
| 270 |
+
if debug:
|
| 271 |
+
backend = "Unified Memory" if is_mps_available() else "VRAM"
|
| 272 |
+
debug.log(f"{backend} pressure: {mem_info['free_gb']:.2f}GB free of {mem_info['total_gb']:.2f}GB", category="memory")
|
| 273 |
+
|
| 274 |
+
# For non-MPS systems, also check system RAM separately
|
| 275 |
+
if not should_clear and not is_mps_available():
|
| 276 |
+
mem = psutil.virtual_memory()
|
| 277 |
+
if mem.available < mem.total * 0.05:
|
| 278 |
+
should_clear = True
|
| 279 |
+
if debug:
|
| 280 |
+
debug.log(f"RAM pressure: {mem.available/(1024**3):.2f}GB free of {mem.total/(1024**3):.2f}GB", category="memory")
|
| 281 |
+
|
| 282 |
+
if not should_clear:
|
| 283 |
+
# End timer before early return to keep stack clean
|
| 284 |
+
if debug:
|
| 285 |
+
debug.end_timer(main_timer)
|
| 286 |
+
return
|
| 287 |
+
|
| 288 |
+
# Determine cleanup level
|
| 289 |
+
cleanup_mode = "deep" if deep else "minimal"
|
| 290 |
+
if debug:
|
| 291 |
+
debug.log(f"Clearing memory caches ({cleanup_mode})...", category="cleanup")
|
| 292 |
+
|
| 293 |
+
# ===== MINIMAL OPERATIONS (Always performed) =====
|
| 294 |
+
# Step 1: Clear GPU caches - Fast operations (~1-5ms)
|
| 295 |
+
if debug:
|
| 296 |
+
debug.start_timer(gpu_timer)
|
| 297 |
+
|
| 298 |
+
if is_cuda_available():
|
| 299 |
+
torch.cuda.empty_cache()
|
| 300 |
+
torch.cuda.ipc_collect()
|
| 301 |
+
elif is_mps_available():
|
| 302 |
+
torch.mps.empty_cache()
|
| 303 |
+
|
| 304 |
+
if debug:
|
| 305 |
+
debug.end_timer(gpu_timer, "GPU cache clearing")
|
| 306 |
+
|
| 307 |
+
# ===== DEEP OPERATIONS (Only when deep=True) =====
|
| 308 |
+
if deep:
|
| 309 |
+
# Step 2: Deep garbage collection (expensive ~5-20ms)
|
| 310 |
+
if debug:
|
| 311 |
+
debug.start_timer(gc_timer)
|
| 312 |
+
|
| 313 |
+
gc.collect(2)
|
| 314 |
+
|
| 315 |
+
if debug:
|
| 316 |
+
debug.end_timer(gc_timer, "Garbage collection")
|
| 317 |
+
|
| 318 |
+
# Step 3: Return memory to OS (platform-specific, ~5-30ms)
|
| 319 |
+
if debug:
|
| 320 |
+
debug.start_timer(os_timer)
|
| 321 |
+
|
| 322 |
+
try:
|
| 323 |
+
if sys.platform == 'linux':
|
| 324 |
+
# Linux: malloc_trim
|
| 325 |
+
import ctypes # Import only when needed
|
| 326 |
+
if _os_memory_lib is None:
|
| 327 |
+
_os_memory_lib = ctypes.CDLL("libc.so.6")
|
| 328 |
+
_os_memory_lib.malloc_trim(0)
|
| 329 |
+
|
| 330 |
+
elif sys.platform == 'win32':
|
| 331 |
+
# Windows: Trim working set
|
| 332 |
+
import ctypes # Import only when needed
|
| 333 |
+
if _os_memory_lib is None:
|
| 334 |
+
_os_memory_lib = ctypes.windll.kernel32
|
| 335 |
+
handle = _os_memory_lib.GetCurrentProcess()
|
| 336 |
+
_os_memory_lib.SetProcessWorkingSetSize(handle, -1, -1)
|
| 337 |
+
|
| 338 |
+
elif is_mps_available():
|
| 339 |
+
# macOS with MPS
|
| 340 |
+
import ctypes # Import only when needed
|
| 341 |
+
import ctypes.util
|
| 342 |
+
if _os_memory_lib is None:
|
| 343 |
+
libc_path = ctypes.util.find_library('c')
|
| 344 |
+
if libc_path:
|
| 345 |
+
_os_memory_lib = ctypes.CDLL(libc_path)
|
| 346 |
+
|
| 347 |
+
if _os_memory_lib:
|
| 348 |
+
_os_memory_lib.sync()
|
| 349 |
+
except Exception as e:
|
| 350 |
+
if debug:
|
| 351 |
+
debug.log(f"Failed to perform OS memory operations: {e}", level="WARNING", category="memory", force=True)
|
| 352 |
+
|
| 353 |
+
if debug:
|
| 354 |
+
debug.end_timer(os_timer, "OS memory release")
|
| 355 |
+
|
| 356 |
+
# End overall timer
|
| 357 |
+
if debug:
|
| 358 |
+
debug.end_timer(main_timer, completion_msg)
|
| 359 |
+
|
| 360 |
+
|
| 361 |
+
def retry_on_oom(func, *args, debug=None, operation_name="operation", **kwargs):
|
| 362 |
+
"""
|
| 363 |
+
Execute function with single OOM retry after memory cleanup.
|
| 364 |
+
|
| 365 |
+
Args:
|
| 366 |
+
func: Callable to execute
|
| 367 |
+
*args: Positional arguments for func
|
| 368 |
+
debug: Debug instance for logging (optional)
|
| 369 |
+
operation_name: Name for logging
|
| 370 |
+
**kwargs: Keyword arguments for func
|
| 371 |
+
|
| 372 |
+
Returns:
|
| 373 |
+
Result of func(*args, **kwargs)
|
| 374 |
+
"""
|
| 375 |
+
try:
|
| 376 |
+
return func(*args, **kwargs)
|
| 377 |
+
except (torch.cuda.OutOfMemoryError, RuntimeError) as e:
|
| 378 |
+
# Only handle OOM errors
|
| 379 |
+
if not any(x in str(e).lower() for x in ["out of memory", "allocation on device"]):
|
| 380 |
+
raise
|
| 381 |
+
|
| 382 |
+
if debug:
|
| 383 |
+
debug.log(f"OOM during {operation_name}: {e}", level="WARNING", category="memory", force=True)
|
| 384 |
+
debug.log(f"Clearing memory and retrying", category="info", force=True)
|
| 385 |
+
|
| 386 |
+
# Clear memory
|
| 387 |
+
clear_memory(debug=debug, deep=True, force=True, timer_name=operation_name)
|
| 388 |
+
# Let memory settle
|
| 389 |
+
time.sleep(0.5)
|
| 390 |
+
if debug: debug.log_memory_state("After memory clearing", show_tensors=False, detailed_tensors=False)
|
| 391 |
+
|
| 392 |
+
# Single retry
|
| 393 |
+
try:
|
| 394 |
+
result = func(*args, **kwargs)
|
| 395 |
+
if debug:
|
| 396 |
+
debug.log(f"Retry successful for {operation_name}", category="success", force=True)
|
| 397 |
+
return result
|
| 398 |
+
except Exception as retry_e:
|
| 399 |
+
if debug:
|
| 400 |
+
debug.log(f"Retry failed for {operation_name}: {retry_e}", level="ERROR", category="memory", force=True)
|
| 401 |
+
raise
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
def reset_vram_peak(device: Optional[torch.device] = None, debug: Optional['Debug'] = None) -> None:
|
| 405 |
+
"""
|
| 406 |
+
Reset VRAM peak memory statistics for fresh tracking.
|
| 407 |
+
|
| 408 |
+
Args:
|
| 409 |
+
device: Optional device to reset stats for. If None, uses cuda:0
|
| 410 |
+
debug: Optional debug instance for logging
|
| 411 |
+
"""
|
| 412 |
+
if debug and debug.enabled:
|
| 413 |
+
debug.log("Resetting VRAM peak memory statistics", category="memory")
|
| 414 |
+
try:
|
| 415 |
+
if is_cuda_available():
|
| 416 |
+
if device is None:
|
| 417 |
+
device = torch.device("cuda:0")
|
| 418 |
+
elif not isinstance(device, torch.device):
|
| 419 |
+
device = torch.device(device)
|
| 420 |
+
torch.cuda.reset_peak_memory_stats(device)
|
| 421 |
+
# Note: MPS doesn't support peak memory reset - no action needed
|
| 422 |
+
except Exception as e:
|
| 423 |
+
if debug and debug.enabled:
|
| 424 |
+
debug.log(f"Failed to reset peak memory stats: {e}", level="WARNING", category="memory", force=True)
|
| 425 |
+
|
| 426 |
+
|
| 427 |
+
def clear_rope_lru_caches(model: Optional[torch.nn.Module], debug: Optional['Debug'] = None) -> int:
|
| 428 |
+
"""
|
| 429 |
+
Clear ALL LRU caches from RoPE modules.
|
| 430 |
+
|
| 431 |
+
Args:
|
| 432 |
+
model: PyTorch model to clear caches from
|
| 433 |
+
debug: Optional debug instance for logging
|
| 434 |
+
|
| 435 |
+
Returns:
|
| 436 |
+
Number of caches cleared
|
| 437 |
+
"""
|
| 438 |
+
if model is None:
|
| 439 |
+
return 0
|
| 440 |
+
|
| 441 |
+
cleared_count = 0
|
| 442 |
+
try:
|
| 443 |
+
for name, module in model.named_modules():
|
| 444 |
+
if hasattr(module, 'get_axial_freqs') and hasattr(module.get_axial_freqs, 'cache_clear'):
|
| 445 |
+
try:
|
| 446 |
+
module.get_axial_freqs.cache_clear()
|
| 447 |
+
cleared_count += 1
|
| 448 |
+
except Exception as e:
|
| 449 |
+
if debug:
|
| 450 |
+
debug.log(f"Failed to clear RoPE LRU cache for module {name}: {e}", level="WARNING", category="memory", force=True)
|
| 451 |
+
except (AttributeError, RuntimeError) as e:
|
| 452 |
+
if debug:
|
| 453 |
+
debug.log(f"Failed to iterate model modules for RoPE LRU cache clearing: {e}", level="WARNING", category="memory", force=True)
|
| 454 |
+
|
| 455 |
+
return cleared_count
|
| 456 |
+
|
| 457 |
+
|
| 458 |
+
def release_tensor_memory(tensor: Optional[torch.Tensor]) -> None:
|
| 459 |
+
"""Release tensor memory from any device (CPU/CUDA/MPS)"""
|
| 460 |
+
if tensor is not None and torch.is_tensor(tensor):
|
| 461 |
+
# Release storage for all devices (CPU, CUDA, MPS)
|
| 462 |
+
if tensor.numel() > 0:
|
| 463 |
+
tensor.data.set_()
|
| 464 |
+
tensor.grad = None
|
| 465 |
+
|
| 466 |
+
|
| 467 |
+
def release_tensor_collection(collection: Any, recursive: bool = True) -> None:
|
| 468 |
+
"""
|
| 469 |
+
Release GPU memory from tensors in any collection (list, tuple, dict, or single tensor).
|
| 470 |
+
|
| 471 |
+
Args:
|
| 472 |
+
collection: Tensor, list, tuple, dict, or nested structure to release
|
| 473 |
+
recursive: If True, handle nested structures recursively
|
| 474 |
+
|
| 475 |
+
Examples:
|
| 476 |
+
release_tensor_collection(tensor) # Single tensor
|
| 477 |
+
release_tensor_collection([tensor1, tensor2]) # List of tensors
|
| 478 |
+
release_tensor_collection([[t1, t2], [t3, t4]]) # Nested lists
|
| 479 |
+
release_tensor_collection({'a': tensor}) # Dict values
|
| 480 |
+
"""
|
| 481 |
+
if collection is None:
|
| 482 |
+
return
|
| 483 |
+
|
| 484 |
+
if torch.is_tensor(collection):
|
| 485 |
+
release_tensor_memory(collection)
|
| 486 |
+
elif isinstance(collection, dict):
|
| 487 |
+
for value in collection.values():
|
| 488 |
+
if recursive:
|
| 489 |
+
release_tensor_collection(value, recursive=True)
|
| 490 |
+
elif torch.is_tensor(value):
|
| 491 |
+
release_tensor_memory(value)
|
| 492 |
+
elif isinstance(collection, (list, tuple)):
|
| 493 |
+
for item in collection:
|
| 494 |
+
if recursive:
|
| 495 |
+
release_tensor_collection(item, recursive=True)
|
| 496 |
+
elif torch.is_tensor(item):
|
| 497 |
+
release_tensor_memory(item)
|
| 498 |
+
|
| 499 |
+
|
| 500 |
+
def release_text_embeddings(*embeddings: torch.Tensor, debug: Optional['Debug'] = None, names: Optional[List[str]] = None) -> None:
|
| 501 |
+
"""
|
| 502 |
+
Release memory for text embeddings
|
| 503 |
+
|
| 504 |
+
Args:
|
| 505 |
+
*embeddings: Variable number of embedding tensors to release
|
| 506 |
+
debug: Optional debug instance for logging
|
| 507 |
+
names: Optional list of names for logging
|
| 508 |
+
"""
|
| 509 |
+
for i, embedding in enumerate(embeddings):
|
| 510 |
+
if embedding is not None:
|
| 511 |
+
release_tensor_memory(embedding)
|
| 512 |
+
if debug and names and i < len(names):
|
| 513 |
+
debug.log(f"Cleaned up {names[i]}", category="cleanup")
|
| 514 |
+
|
| 515 |
+
|
| 516 |
+
def cleanup_text_embeddings(ctx: Dict[str, Any], debug: Optional['Debug'] = None) -> None:
|
| 517 |
+
"""
|
| 518 |
+
Clean up text embeddings from a context dictionary.
|
| 519 |
+
Extracts embeddings, releases memory, and clears the context entry.
|
| 520 |
+
|
| 521 |
+
Args:
|
| 522 |
+
ctx: Context dictionary potentially containing 'text_embeds'
|
| 523 |
+
debug: Optional debug instance for logging
|
| 524 |
+
"""
|
| 525 |
+
if not ctx or not ctx.get('text_embeds'):
|
| 526 |
+
return
|
| 527 |
+
|
| 528 |
+
embeddings = []
|
| 529 |
+
names = []
|
| 530 |
+
for key, embeds_list in ctx['text_embeds'].items():
|
| 531 |
+
if embeds_list:
|
| 532 |
+
embeddings.extend(embeds_list)
|
| 533 |
+
names.append(key)
|
| 534 |
+
|
| 535 |
+
if embeddings:
|
| 536 |
+
release_text_embeddings(*embeddings, debug=debug, names=names)
|
| 537 |
+
|
| 538 |
+
if debug:
|
| 539 |
+
debug.log(f"Cleaned up text embeddings: {', '.join(names)}", category="cleanup")
|
| 540 |
+
|
| 541 |
+
ctx['text_embeds'] = None
|
| 542 |
+
|
| 543 |
+
|
| 544 |
+
def release_model_memory(model: Optional[torch.nn.Module], debug: Optional['Debug'] = None) -> None:
|
| 545 |
+
"""
|
| 546 |
+
Release all GPU/MPS memory from model in-place without CPU transfer.
|
| 547 |
+
|
| 548 |
+
Args:
|
| 549 |
+
model: PyTorch model to release memory from
|
| 550 |
+
debug: Optional debug instance for logging
|
| 551 |
+
"""
|
| 552 |
+
if model is None:
|
| 553 |
+
return
|
| 554 |
+
|
| 555 |
+
try:
|
| 556 |
+
# Clear gradients first
|
| 557 |
+
model.zero_grad(set_to_none=True)
|
| 558 |
+
|
| 559 |
+
# Release GPU memory directly without CPU transfer
|
| 560 |
+
released_params = 0
|
| 561 |
+
released_buffers = 0
|
| 562 |
+
|
| 563 |
+
for param in model.parameters():
|
| 564 |
+
if param.is_cuda or param.is_mps:
|
| 565 |
+
if param.numel() > 0:
|
| 566 |
+
param.data.set_()
|
| 567 |
+
released_params += 1
|
| 568 |
+
param.grad = None
|
| 569 |
+
|
| 570 |
+
for buffer in model.buffers():
|
| 571 |
+
if buffer.is_cuda or buffer.is_mps:
|
| 572 |
+
if buffer.numel() > 0:
|
| 573 |
+
buffer.data.set_()
|
| 574 |
+
released_buffers += 1
|
| 575 |
+
|
| 576 |
+
if debug and (released_params > 0 or released_buffers > 0):
|
| 577 |
+
debug.log(f"Released memory from {released_params} params and {released_buffers} buffers", category="success")
|
| 578 |
+
|
| 579 |
+
except (AttributeError, RuntimeError) as e:
|
| 580 |
+
if debug:
|
| 581 |
+
debug.log(f"Failed to release model memory: {e}", level="WARNING", category="memory", force=True)
|
| 582 |
+
|
| 583 |
+
|
| 584 |
+
def manage_tensor(
|
| 585 |
+
tensor: torch.Tensor,
|
| 586 |
+
target_device: torch.device,
|
| 587 |
+
tensor_name: str = "tensor",
|
| 588 |
+
dtype: Optional[torch.dtype] = None,
|
| 589 |
+
non_blocking: bool = False,
|
| 590 |
+
debug: Optional['Debug'] = None,
|
| 591 |
+
reason: Optional[str] = None,
|
| 592 |
+
indent_level: int = 0
|
| 593 |
+
) -> torch.Tensor:
|
| 594 |
+
"""
|
| 595 |
+
Unified tensor management for device movement and dtype conversion.
|
| 596 |
+
|
| 597 |
+
Handles both device transfers (CPU ↔ GPU) and dtype conversions (e.g., float16 → bfloat16)
|
| 598 |
+
with intelligent early-exit optimization and comprehensive logging.
|
| 599 |
+
|
| 600 |
+
Args:
|
| 601 |
+
tensor: Tensor to manage
|
| 602 |
+
target_device: Target device (torch.device object)
|
| 603 |
+
tensor_name: Descriptive name for logging (e.g., "latent", "sample", "alpha_channel")
|
| 604 |
+
dtype: Optional target dtype to cast to (if None, keeps original dtype)
|
| 605 |
+
non_blocking: Whether to use non-blocking transfer
|
| 606 |
+
debug: Debug instance for logging
|
| 607 |
+
reason: Optional reason for the operation (e.g., "inference", "offload", "dtype alignment")
|
| 608 |
+
indent_level: Indentation level for debug logging (0=no indent, 1=2 spaces, etc.)
|
| 609 |
+
|
| 610 |
+
Returns:
|
| 611 |
+
Tensor on target device with optional dtype conversion
|
| 612 |
+
|
| 613 |
+
Note:
|
| 614 |
+
- Skips operation if tensor already has target device and dtype (zero-copy)
|
| 615 |
+
- Uses PyTorch's optimized .to() for efficient device/dtype handling
|
| 616 |
+
- Logs all operations consistently for tracking and debugging
|
| 617 |
+
"""
|
| 618 |
+
if tensor is None:
|
| 619 |
+
return tensor
|
| 620 |
+
|
| 621 |
+
# Get current state
|
| 622 |
+
current_device = tensor.device
|
| 623 |
+
current_dtype = tensor.dtype
|
| 624 |
+
target_dtype = dtype if dtype is not None else current_dtype
|
| 625 |
+
|
| 626 |
+
# Check if movement is actually needed
|
| 627 |
+
needs_device_move = _device_str(current_device) != _device_str(target_device)
|
| 628 |
+
needs_dtype_change = dtype is not None and current_dtype != target_dtype
|
| 629 |
+
|
| 630 |
+
if not needs_device_move and not needs_dtype_change:
|
| 631 |
+
# Already on target device and dtype - skip
|
| 632 |
+
return tensor
|
| 633 |
+
|
| 634 |
+
# Determine reason for movement
|
| 635 |
+
if reason is None:
|
| 636 |
+
if needs_device_move and needs_dtype_change:
|
| 637 |
+
reason = "device and dtype conversion"
|
| 638 |
+
elif needs_device_move:
|
| 639 |
+
reason = "device movement"
|
| 640 |
+
else:
|
| 641 |
+
reason = "dtype conversion"
|
| 642 |
+
|
| 643 |
+
# Log the movement
|
| 644 |
+
if debug:
|
| 645 |
+
current_device_str = _device_str(current_device)
|
| 646 |
+
target_device_str = _device_str(target_device)
|
| 647 |
+
|
| 648 |
+
dtype_info = ""
|
| 649 |
+
if needs_dtype_change:
|
| 650 |
+
dtype_info = f", {current_dtype} → {target_dtype}"
|
| 651 |
+
|
| 652 |
+
debug.log(
|
| 653 |
+
f"Moving {tensor_name} from {current_device_str} to {target_device_str}{dtype_info} ({reason})",
|
| 654 |
+
category="general",
|
| 655 |
+
indent_level=indent_level
|
| 656 |
+
)
|
| 657 |
+
|
| 658 |
+
# Perform the operation based on what needs to change
|
| 659 |
+
if needs_device_move and needs_dtype_change:
|
| 660 |
+
# Both device and dtype need to change
|
| 661 |
+
return tensor.to(target_device, dtype=target_dtype, non_blocking=non_blocking)
|
| 662 |
+
elif needs_device_move:
|
| 663 |
+
# Only device needs to change
|
| 664 |
+
return tensor.to(target_device, non_blocking=non_blocking)
|
| 665 |
+
else:
|
| 666 |
+
# Only dtype needs to change
|
| 667 |
+
return tensor.to(dtype=target_dtype)
|
| 668 |
+
|
| 669 |
+
|
| 670 |
+
def manage_model_device(model: torch.nn.Module, target_device: torch.device, model_name: str,
|
| 671 |
+
debug: Optional['Debug'] = None, reason: Optional[str] = None,
|
| 672 |
+
runner: Optional[Any] = None) -> bool:
|
| 673 |
+
"""
|
| 674 |
+
Move model to target device with optimizations.
|
| 675 |
+
Handles BlockSwap-enabled models transparently.
|
| 676 |
+
|
| 677 |
+
Args:
|
| 678 |
+
model: The model to move
|
| 679 |
+
target_device: Target device (torch.device object, e.g., torch.device('cuda:0'))
|
| 680 |
+
model_name: Name for logging (e.g., "VAE", "DiT")
|
| 681 |
+
debug: Debug instance for logging
|
| 682 |
+
reason: Optional custom reason for the movement
|
| 683 |
+
runner: Optional runner instance for BlockSwap detection
|
| 684 |
+
|
| 685 |
+
Returns:
|
| 686 |
+
bool: True if model was moved, False if already on target device
|
| 687 |
+
"""
|
| 688 |
+
if model is None:
|
| 689 |
+
return False
|
| 690 |
+
|
| 691 |
+
# Check if this is a BlockSwap-enabled DiT model
|
| 692 |
+
is_blockswap_model = False
|
| 693 |
+
actual_model = model
|
| 694 |
+
if runner and model_name == "DiT":
|
| 695 |
+
# Import here to avoid circular dependency
|
| 696 |
+
from .blockswap import is_blockswap_enabled
|
| 697 |
+
# Check if BlockSwap config exists and is enabled
|
| 698 |
+
has_blockswap_config = (
|
| 699 |
+
hasattr(runner, '_dit_block_swap_config') and
|
| 700 |
+
is_blockswap_enabled(runner._dit_block_swap_config)
|
| 701 |
+
)
|
| 702 |
+
|
| 703 |
+
if has_blockswap_config:
|
| 704 |
+
is_blockswap_model = True
|
| 705 |
+
# Get the actual model (handle CompatibleDiT wrapper)
|
| 706 |
+
if hasattr(model, "dit_model"):
|
| 707 |
+
actual_model = model.dit_model
|
| 708 |
+
|
| 709 |
+
# Get current device
|
| 710 |
+
try:
|
| 711 |
+
current_device = next(model.parameters()).device
|
| 712 |
+
except StopIteration:
|
| 713 |
+
return False
|
| 714 |
+
|
| 715 |
+
# Extract device type for comparison (both are torch.device objects)
|
| 716 |
+
target_type = target_device.type
|
| 717 |
+
current_device_upper = _device_str(current_device)
|
| 718 |
+
target_device_upper = _device_str(target_device)
|
| 719 |
+
|
| 720 |
+
# Compare normalized device types
|
| 721 |
+
if current_device_upper == target_device_upper and not is_blockswap_model:
|
| 722 |
+
# Already on target device type, no movement needed
|
| 723 |
+
if debug:
|
| 724 |
+
debug.log(f"{model_name} already on {current_device_upper}, skipping movement", category="general")
|
| 725 |
+
return False
|
| 726 |
+
|
| 727 |
+
# Handle BlockSwap models specially
|
| 728 |
+
if is_blockswap_model:
|
| 729 |
+
return _handle_blockswap_model_movement(
|
| 730 |
+
runner, actual_model, current_device, target_device, target_type,
|
| 731 |
+
model_name, debug, reason
|
| 732 |
+
)
|
| 733 |
+
|
| 734 |
+
# Standard model movement (non-BlockSwap)
|
| 735 |
+
return _standard_model_movement(
|
| 736 |
+
model, current_device, target_device, target_type, model_name,
|
| 737 |
+
debug, reason
|
| 738 |
+
)
|
| 739 |
+
|
| 740 |
+
|
| 741 |
+
def _handle_blockswap_model_movement(runner: Any, model: torch.nn.Module,
|
| 742 |
+
current_device: torch.device, target_device: torch.device,
|
| 743 |
+
target_type: str, model_name: str,
|
| 744 |
+
debug: Optional['Debug'] = None, reason: Optional[str] = None) -> bool:
|
| 745 |
+
"""
|
| 746 |
+
Handle device movement for BlockSwap-enabled models.
|
| 747 |
+
|
| 748 |
+
Args:
|
| 749 |
+
runner: Runner instance with BlockSwap configuration
|
| 750 |
+
model: Model to move (actual unwrapped model)
|
| 751 |
+
current_device: Current device of the model
|
| 752 |
+
target_device: Target device (torch.device object)
|
| 753 |
+
target_type: Target device type (cpu/cuda/mps)
|
| 754 |
+
model_name: Model name for logging
|
| 755 |
+
debug: Debug instance
|
| 756 |
+
reason: Movement reason
|
| 757 |
+
|
| 758 |
+
Returns:
|
| 759 |
+
bool: True if model was moved
|
| 760 |
+
"""
|
| 761 |
+
# Import here to avoid circular dependency
|
| 762 |
+
from .blockswap import set_blockswap_bypass
|
| 763 |
+
|
| 764 |
+
if target_type == "cpu":
|
| 765 |
+
# Moving to offload device (typically CPU)
|
| 766 |
+
# Check if any parameter is on GPU (for accurate logging)
|
| 767 |
+
actual_source_device = None
|
| 768 |
+
for param in model.parameters():
|
| 769 |
+
if param.device.type in ['cuda', 'mps']:
|
| 770 |
+
actual_source_device = param.device
|
| 771 |
+
break
|
| 772 |
+
|
| 773 |
+
source_device_desc = _device_str(actual_source_device) if actual_source_device else _device_str(target_device)
|
| 774 |
+
|
| 775 |
+
if debug:
|
| 776 |
+
debug.log(f"Moving {model_name} from {source_device_desc} to {_device_str(target_device)} ({reason or 'model caching'})", category="general")
|
| 777 |
+
|
| 778 |
+
# Enable bypass to allow movement
|
| 779 |
+
set_blockswap_bypass(runner=runner, bypass=True, debug=debug)
|
| 780 |
+
|
| 781 |
+
# Start timer
|
| 782 |
+
timer_name = f"{model_name.lower()}_to_{target_type}"
|
| 783 |
+
if debug:
|
| 784 |
+
debug.start_timer(timer_name)
|
| 785 |
+
|
| 786 |
+
# Move entire model to target offload device
|
| 787 |
+
model.to(target_device)
|
| 788 |
+
model.zero_grad(set_to_none=True)
|
| 789 |
+
|
| 790 |
+
if debug:
|
| 791 |
+
debug.end_timer(timer_name, f"BlockSwap model offloaded to {_device_str(target_device)}")
|
| 792 |
+
|
| 793 |
+
return True
|
| 794 |
+
|
| 795 |
+
else:
|
| 796 |
+
# Moving to GPU (reload)
|
| 797 |
+
# Check if we're in bypass mode (coming from offload)
|
| 798 |
+
if not getattr(model, "_blockswap_bypass_protection", False):
|
| 799 |
+
# Not in bypass mode, blocks are already configured
|
| 800 |
+
if debug:
|
| 801 |
+
debug.log(f"{model_name} with BlockSwap active - blocks already distributed across devices, skipping movement", category="general")
|
| 802 |
+
return False
|
| 803 |
+
|
| 804 |
+
# Get actual current device for accurate logging
|
| 805 |
+
actual_current_device = None
|
| 806 |
+
for param in model.parameters():
|
| 807 |
+
if param.device.type != 'meta':
|
| 808 |
+
actual_current_device = param.device
|
| 809 |
+
break
|
| 810 |
+
|
| 811 |
+
current_device_desc = _device_str(actual_current_device) if actual_current_device else "OFFLOAD"
|
| 812 |
+
|
| 813 |
+
if debug:
|
| 814 |
+
debug.log(f"Moving {model_name} from {current_device_desc} to {_device_str(target_device)} ({reason or 'inference requirement'})", category="general")
|
| 815 |
+
|
| 816 |
+
timer_name = f"{model_name.lower()}_to_gpu"
|
| 817 |
+
if debug:
|
| 818 |
+
debug.start_timer(timer_name)
|
| 819 |
+
|
| 820 |
+
# Restore blocks to their configured devices
|
| 821 |
+
if hasattr(model, "blocks") and hasattr(model, "blocks_to_swap"):
|
| 822 |
+
# Use configured offload_device from BlockSwap config
|
| 823 |
+
offload_device = model._block_swap_config.get("offload_device")
|
| 824 |
+
if not offload_device:
|
| 825 |
+
raise ValueError("BlockSwap config missing offload_device")
|
| 826 |
+
|
| 827 |
+
# Move blocks according to BlockSwap configuration
|
| 828 |
+
for b, block in enumerate(model.blocks):
|
| 829 |
+
if b > model.blocks_to_swap:
|
| 830 |
+
# This block should be on GPU
|
| 831 |
+
block.to(target_device)
|
| 832 |
+
else:
|
| 833 |
+
# This block stays on offload device (will be swapped during forward)
|
| 834 |
+
block.to(offload_device)
|
| 835 |
+
|
| 836 |
+
# Handle I/O components
|
| 837 |
+
if not model._block_swap_config.get("swap_io_components", False):
|
| 838 |
+
# I/O components should be on GPU if not offloaded
|
| 839 |
+
for name, module in model.named_children():
|
| 840 |
+
if name != "blocks":
|
| 841 |
+
module.to(target_device)
|
| 842 |
+
else:
|
| 843 |
+
# I/O components stay on offload device
|
| 844 |
+
for name, module in model.named_children():
|
| 845 |
+
if name != "blocks":
|
| 846 |
+
module.to(offload_device)
|
| 847 |
+
|
| 848 |
+
if debug:
|
| 849 |
+
# Get actual configuration from runner
|
| 850 |
+
if hasattr(model, '_block_swap_config'):
|
| 851 |
+
blocks_on_gpu = model._block_swap_config.get('total_blocks', 32) - model._block_swap_config.get('blocks_swapped', 16)
|
| 852 |
+
total_blocks = model._block_swap_config.get('total_blocks', 32)
|
| 853 |
+
main_device = model._block_swap_config.get('main_device', 'GPU')
|
| 854 |
+
debug.log(f"BlockSwap blocks restored to configured devices ({blocks_on_gpu}/{total_blocks} blocks on {_device_str(main_device)})", category="success")
|
| 855 |
+
else:
|
| 856 |
+
debug.log("BlockSwap blocks restored to configured devices", category="success")
|
| 857 |
+
|
| 858 |
+
|
| 859 |
+
# Reactivate BlockSwap now that blocks are restored to their configured devices
|
| 860 |
+
runner._blockswap_active = True
|
| 861 |
+
|
| 862 |
+
# Disable bypass, re-enable protection
|
| 863 |
+
set_blockswap_bypass(runner=runner, bypass=False, debug=debug)
|
| 864 |
+
|
| 865 |
+
if debug:
|
| 866 |
+
debug.end_timer(timer_name, "BlockSwap model restored")
|
| 867 |
+
|
| 868 |
+
return True
|
| 869 |
+
|
| 870 |
+
|
| 871 |
+
def _standard_model_movement(model: torch.nn.Module, current_device: torch.device,
|
| 872 |
+
target_device: torch.device, target_type: str, model_name: str,
|
| 873 |
+
debug: Optional['Debug'] = None, reason: Optional[str] = None) -> bool:
|
| 874 |
+
"""
|
| 875 |
+
Handle standard (non-BlockSwap) model movement.
|
| 876 |
+
|
| 877 |
+
Args:
|
| 878 |
+
model: Model to move
|
| 879 |
+
current_device: Current device of the model
|
| 880 |
+
target_device: Target device (torch.device object)
|
| 881 |
+
target_type: Target device type
|
| 882 |
+
model_name: Model name for logging
|
| 883 |
+
debug: Debug instance
|
| 884 |
+
reason: Movement reason
|
| 885 |
+
|
| 886 |
+
Returns:
|
| 887 |
+
bool: True if model was moved
|
| 888 |
+
"""
|
| 889 |
+
# Check if model is on meta device - can't move meta tensors
|
| 890 |
+
if current_device.type == 'meta':
|
| 891 |
+
if debug:
|
| 892 |
+
debug.log(f"{model_name} is on meta device - skipping movement (will materialize when needed)",
|
| 893 |
+
category=model_name.lower())
|
| 894 |
+
return False
|
| 895 |
+
|
| 896 |
+
# Determine reason for movement
|
| 897 |
+
reason = reason or "inference requirement"
|
| 898 |
+
|
| 899 |
+
# Log the movement with full device strings
|
| 900 |
+
if debug:
|
| 901 |
+
current_device_str = _device_str(current_device)
|
| 902 |
+
target_device_str = _device_str(target_device)
|
| 903 |
+
debug.log(f"Moving {model_name} from {current_device_str} to {target_device_str} ({reason})", category="general")
|
| 904 |
+
|
| 905 |
+
# Start timer based on direction
|
| 906 |
+
timer_name = f"{model_name.lower()}_to_{'gpu' if target_type != 'cpu' else 'cpu'}"
|
| 907 |
+
if debug:
|
| 908 |
+
debug.start_timer(timer_name)
|
| 909 |
+
|
| 910 |
+
# Move model and clear gradients
|
| 911 |
+
model.to(target_device)
|
| 912 |
+
model.zero_grad(set_to_none=True)
|
| 913 |
+
|
| 914 |
+
# Clear VAE memory buffers when moving to CPU
|
| 915 |
+
if target_type == 'cpu' and model_name == "VAE":
|
| 916 |
+
cleared_count = 0
|
| 917 |
+
for module in model.modules():
|
| 918 |
+
if hasattr(module, 'memory') and module.memory is not None:
|
| 919 |
+
if torch.is_tensor(module.memory) and (module.memory.is_cuda or module.memory.is_mps):
|
| 920 |
+
module.memory = None
|
| 921 |
+
cleared_count += 1
|
| 922 |
+
if cleared_count > 0 and debug:
|
| 923 |
+
debug.log(f"Cleared {cleared_count} VAE memory buffers", category="success")
|
| 924 |
+
|
| 925 |
+
# End timer
|
| 926 |
+
if debug:
|
| 927 |
+
debug.end_timer(timer_name, f"{model_name} moved to {_device_str(target_device)}")
|
| 928 |
+
|
| 929 |
+
return True
|
| 930 |
+
|
| 931 |
+
|
| 932 |
+
def clear_runtime_caches(runner: Any, debug: Optional['Debug'] = None) -> int:
|
| 933 |
+
"""
|
| 934 |
+
Clear all runtime caches and temporary attributes.
|
| 935 |
+
"""
|
| 936 |
+
if not runner:
|
| 937 |
+
return 0
|
| 938 |
+
|
| 939 |
+
if debug:
|
| 940 |
+
debug.start_timer("runtime_cache_clear")
|
| 941 |
+
|
| 942 |
+
cleaned_items = 0
|
| 943 |
+
|
| 944 |
+
# 1. Clear main runner cache
|
| 945 |
+
if hasattr(runner, 'cache') and hasattr(runner.cache, 'cache'):
|
| 946 |
+
if debug:
|
| 947 |
+
debug.start_timer("runner_cache_clear")
|
| 948 |
+
|
| 949 |
+
cache_entries = len(runner.cache.cache)
|
| 950 |
+
|
| 951 |
+
# Properly release tensor memory and delete as we go
|
| 952 |
+
for key in list(runner.cache.cache.keys()):
|
| 953 |
+
value = runner.cache.cache[key]
|
| 954 |
+
if torch.is_tensor(value):
|
| 955 |
+
release_tensor_memory(value)
|
| 956 |
+
elif isinstance(value, (list, tuple)):
|
| 957 |
+
for item in value:
|
| 958 |
+
if torch.is_tensor(item):
|
| 959 |
+
release_tensor_memory(item)
|
| 960 |
+
# Delete immediately to release reference
|
| 961 |
+
del runner.cache.cache[key]
|
| 962 |
+
|
| 963 |
+
# Final clear for safety
|
| 964 |
+
runner.cache.cache.clear()
|
| 965 |
+
cleaned_items += cache_entries
|
| 966 |
+
|
| 967 |
+
if debug:
|
| 968 |
+
debug.end_timer("runner_cache_clear", f"Clearing main runner cache entries")
|
| 969 |
+
|
| 970 |
+
if cache_entries > 0:
|
| 971 |
+
debug.log(f"Cleared {cache_entries} runtime cache entries", category="success")
|
| 972 |
+
|
| 973 |
+
# 2. Clear RoPE caches
|
| 974 |
+
if hasattr(runner, 'dit'):
|
| 975 |
+
if debug:
|
| 976 |
+
debug.start_timer("rope_cache_clear")
|
| 977 |
+
|
| 978 |
+
model = runner.dit
|
| 979 |
+
if hasattr(model, 'dit_model'): # Handle wrapper
|
| 980 |
+
model = model.dit_model
|
| 981 |
+
|
| 982 |
+
rope_cleared = clear_rope_lru_caches(model=model, debug=debug)
|
| 983 |
+
cleaned_items += rope_cleared
|
| 984 |
+
if debug:
|
| 985 |
+
debug.end_timer("rope_cache_clear", "Clearing RoPE LRU caches")
|
| 986 |
+
|
| 987 |
+
if rope_cleared > 0:
|
| 988 |
+
debug.log(f"Cleared {rope_cleared} RoPE LRU caches", category="success")
|
| 989 |
+
|
| 990 |
+
# 3. Clear temporary attributes
|
| 991 |
+
temp_attrs = ['_temp_cache', '_block_cache', '_swap_cache', '_generation_cache',
|
| 992 |
+
'_rope_cache', '_intermediate_cache', '_backward_cache']
|
| 993 |
+
|
| 994 |
+
for obj in [runner, getattr(runner, 'dit', None), getattr(runner, 'vae', None)]:
|
| 995 |
+
if obj is None:
|
| 996 |
+
continue
|
| 997 |
+
|
| 998 |
+
actual_obj = obj.dit_model if hasattr(obj, 'dit_model') else obj
|
| 999 |
+
|
| 1000 |
+
for attr in temp_attrs:
|
| 1001 |
+
if hasattr(actual_obj, attr):
|
| 1002 |
+
delattr(actual_obj, attr)
|
| 1003 |
+
cleaned_items += 1
|
| 1004 |
+
|
| 1005 |
+
if debug:
|
| 1006 |
+
debug.end_timer("runtime_cache_clear", f"clear_runtime_caches() completion")
|
| 1007 |
+
|
| 1008 |
+
return cleaned_items
|
| 1009 |
+
|
| 1010 |
+
|
| 1011 |
+
def cleanup_dit(runner: Any, debug: Optional['Debug'] = None, cache_model: bool = False) -> None:
|
| 1012 |
+
"""
|
| 1013 |
+
Cleanup DiT model and BlockSwap state after upscaling phase.
|
| 1014 |
+
Called at the end of upscale_all_batches when DiT is no longer needed.
|
| 1015 |
+
|
| 1016 |
+
Args:
|
| 1017 |
+
runner: Runner instance containing DiT model
|
| 1018 |
+
debug: Debug instance for logging
|
| 1019 |
+
cache_model: If True, move DiT to offload_device; if False, delete completely
|
| 1020 |
+
"""
|
| 1021 |
+
if not runner or not hasattr(runner, 'dit'):
|
| 1022 |
+
return
|
| 1023 |
+
|
| 1024 |
+
if debug:
|
| 1025 |
+
debug.log("Cleaning up DiT components", category="cleanup")
|
| 1026 |
+
|
| 1027 |
+
# 1. Clear DiT-specific runtime caches first
|
| 1028 |
+
if hasattr(runner, 'dit'):
|
| 1029 |
+
model = runner.dit
|
| 1030 |
+
if hasattr(model, 'dit_model'): # Handle wrapper
|
| 1031 |
+
model = model.dit_model
|
| 1032 |
+
|
| 1033 |
+
# Clear RoPE caches
|
| 1034 |
+
rope_cleared = clear_rope_lru_caches(model=model, debug=debug)
|
| 1035 |
+
if rope_cleared > 0 and debug:
|
| 1036 |
+
debug.log(f"Cleared {rope_cleared} RoPE LRU caches", category="success")
|
| 1037 |
+
|
| 1038 |
+
# Clear DiT temporary attributes
|
| 1039 |
+
temp_attrs = ['_temp_cache', '_block_cache', '_swap_cache', '_generation_cache',
|
| 1040 |
+
'_rope_cache', '_intermediate_cache', '_backward_cache']
|
| 1041 |
+
|
| 1042 |
+
actual_obj = model.dit_model if hasattr(model, 'dit_model') else model
|
| 1043 |
+
for attr in temp_attrs:
|
| 1044 |
+
if hasattr(actual_obj, attr):
|
| 1045 |
+
delattr(actual_obj, attr)
|
| 1046 |
+
|
| 1047 |
+
# 2. Handle model offloading (for caching or before deletion)
|
| 1048 |
+
try:
|
| 1049 |
+
param_device = next(runner.dit.parameters()).device
|
| 1050 |
+
|
| 1051 |
+
# Move model off GPU if needed
|
| 1052 |
+
if param_device.type not in ['meta', 'cpu']:
|
| 1053 |
+
# MPS: skip CPU movement before deletion (unified memory, just causes sync)
|
| 1054 |
+
if param_device.type == 'mps' and not cache_model:
|
| 1055 |
+
if debug:
|
| 1056 |
+
debug.log("DiT on MPS - skipping CPU movement before deletion", category="cleanup")
|
| 1057 |
+
else:
|
| 1058 |
+
offload_target = getattr(runner, '_dit_offload_device', None)
|
| 1059 |
+
if offload_target is None or offload_target == 'none':
|
| 1060 |
+
offload_target = torch.device('cpu')
|
| 1061 |
+
reason = "model caching" if cache_model else "releasing GPU memory"
|
| 1062 |
+
manage_model_device(model=runner.dit, target_device=offload_target, model_name="DiT",
|
| 1063 |
+
debug=debug, reason=reason, runner=runner)
|
| 1064 |
+
elif param_device.type == 'meta' and debug:
|
| 1065 |
+
debug.log("DiT on meta device - keeping structure for cache", category="cleanup")
|
| 1066 |
+
except StopIteration:
|
| 1067 |
+
pass
|
| 1068 |
+
|
| 1069 |
+
# 3. Clean BlockSwap after model movement
|
| 1070 |
+
if hasattr(runner, "_blockswap_active") and runner._blockswap_active:
|
| 1071 |
+
# Import here to avoid circular dependency
|
| 1072 |
+
from .blockswap import cleanup_blockswap
|
| 1073 |
+
cleanup_blockswap(runner=runner, keep_state_for_cache=cache_model)
|
| 1074 |
+
|
| 1075 |
+
# 4. Complete cleanup if not caching
|
| 1076 |
+
if not cache_model:
|
| 1077 |
+
release_model_memory(model=runner.dit, debug=debug)
|
| 1078 |
+
runner.dit = None
|
| 1079 |
+
if debug:
|
| 1080 |
+
debug.log("DiT model deleted", category="cleanup")
|
| 1081 |
+
|
| 1082 |
+
# Clear DiT config attributes - not needed when model is not cached (will be recreated)
|
| 1083 |
+
if hasattr(runner, '_dit_compile_args'):
|
| 1084 |
+
delattr(runner, '_dit_compile_args')
|
| 1085 |
+
if hasattr(runner, '_dit_block_swap_config'):
|
| 1086 |
+
delattr(runner, '_dit_block_swap_config')
|
| 1087 |
+
if hasattr(runner, '_dit_attention_mode'):
|
| 1088 |
+
delattr(runner, '_dit_attention_mode')
|
| 1089 |
+
|
| 1090 |
+
# 5. Clear DiT temporary attributes (should be already cleared in materialize_model)
|
| 1091 |
+
runner._dit_checkpoint = None
|
| 1092 |
+
runner._dit_dtype_override = None
|
| 1093 |
+
|
| 1094 |
+
# 6. Clear DiT-related components and temporary attributes
|
| 1095 |
+
runner.sampler = None
|
| 1096 |
+
runner.sampling_timesteps = None
|
| 1097 |
+
runner.schedule = None
|
| 1098 |
+
|
| 1099 |
+
|
| 1100 |
+
def cleanup_vae(runner: Any, debug: Optional['Debug'] = None, cache_model: bool = False) -> None:
|
| 1101 |
+
"""
|
| 1102 |
+
Cleanup VAE model after decoding phase.
|
| 1103 |
+
Called at the end of decode_all_batches when VAE is no longer needed.
|
| 1104 |
+
|
| 1105 |
+
Args:
|
| 1106 |
+
runner: Runner instance containing VAE model
|
| 1107 |
+
debug: Debug instance for logging
|
| 1108 |
+
cache_model: If True, move VAE to offload_device; if False, delete completely
|
| 1109 |
+
"""
|
| 1110 |
+
if not runner or not hasattr(runner, 'vae'):
|
| 1111 |
+
return
|
| 1112 |
+
|
| 1113 |
+
if debug:
|
| 1114 |
+
debug.log("Cleaning up VAE components", category="cleanup")
|
| 1115 |
+
|
| 1116 |
+
# 1. Clear VAE-specific temporary attributes
|
| 1117 |
+
if hasattr(runner, 'vae'):
|
| 1118 |
+
temp_attrs = ['_temp_cache', '_block_cache', '_swap_cache', '_generation_cache',
|
| 1119 |
+
'_rope_cache', '_intermediate_cache', '_backward_cache']
|
| 1120 |
+
|
| 1121 |
+
for attr in temp_attrs:
|
| 1122 |
+
if hasattr(runner.vae, attr):
|
| 1123 |
+
delattr(runner.vae, attr)
|
| 1124 |
+
|
| 1125 |
+
# 2. Handle model offloading (for caching or before deletion)
|
| 1126 |
+
try:
|
| 1127 |
+
param_device = next(runner.vae.parameters()).device
|
| 1128 |
+
|
| 1129 |
+
# Move model off GPU if needed
|
| 1130 |
+
if param_device.type not in ['meta', 'cpu']:
|
| 1131 |
+
# MPS: skip CPU movement before deletion (unified memory, just causes sync)
|
| 1132 |
+
if param_device.type == 'mps' and not cache_model:
|
| 1133 |
+
if debug:
|
| 1134 |
+
debug.log("VAE on MPS - skipping CPU movement before deletion", category="cleanup")
|
| 1135 |
+
else:
|
| 1136 |
+
offload_target = getattr(runner, '_vae_offload_device', None)
|
| 1137 |
+
if offload_target is None or offload_target == 'none':
|
| 1138 |
+
offload_target = torch.device('cpu')
|
| 1139 |
+
reason = "model caching" if cache_model else "releasing GPU memory"
|
| 1140 |
+
manage_model_device(model=runner.vae, target_device=offload_target, model_name="VAE",
|
| 1141 |
+
debug=debug, reason=reason, runner=runner)
|
| 1142 |
+
elif param_device.type == 'meta' and debug:
|
| 1143 |
+
debug.log("VAE on meta device - keeping structure for cache", category="cleanup")
|
| 1144 |
+
except StopIteration:
|
| 1145 |
+
pass
|
| 1146 |
+
|
| 1147 |
+
# 3. Complete cleanup if not caching
|
| 1148 |
+
if not cache_model:
|
| 1149 |
+
release_model_memory(model=runner.vae, debug=debug)
|
| 1150 |
+
runner.vae = None
|
| 1151 |
+
if debug:
|
| 1152 |
+
debug.log("VAE model deleted", category="cleanup")
|
| 1153 |
+
|
| 1154 |
+
# Clear VAE config attributes - not needed when model is not cached (will be recreated)
|
| 1155 |
+
if hasattr(runner, '_vae_compile_args'):
|
| 1156 |
+
delattr(runner, '_vae_compile_args')
|
| 1157 |
+
if hasattr(runner, '_vae_tiling_config'):
|
| 1158 |
+
delattr(runner, '_vae_tiling_config')
|
| 1159 |
+
|
| 1160 |
+
# 3. Clear VAE temporary attributes (should be already cleared in materialize_model)
|
| 1161 |
+
runner._vae_checkpoint = None
|
| 1162 |
+
runner._vae_dtype_override = None
|
| 1163 |
+
|
| 1164 |
+
|
| 1165 |
+
def complete_cleanup(runner: Any, debug: Optional['Debug'] = None, dit_cache: bool = False, vae_cache: bool = False) -> None:
|
| 1166 |
+
"""
|
| 1167 |
+
Complete cleanup of runner and remaining components with independent model caching support.
|
| 1168 |
+
This is a lightweight cleanup for final stage, as model-specific cleanup
|
| 1169 |
+
happens in their respective phases (cleanup_dit, cleanup_vae).
|
| 1170 |
+
|
| 1171 |
+
Args:
|
| 1172 |
+
runner: Runner instance to clean up
|
| 1173 |
+
debug: Debug instance for logging
|
| 1174 |
+
dit_cache: If True, preserve DiT model on offload_device for future runs
|
| 1175 |
+
vae_cache: If True, preserve VAE model on offload_device for future runs
|
| 1176 |
+
|
| 1177 |
+
Behavior:
|
| 1178 |
+
- Can cache DiT and VAE independently for flexible memory management
|
| 1179 |
+
- Preserves _dit_model_name and _vae_model_name when either model is cached for change detection
|
| 1180 |
+
- Clears all temporary attributes and runtime caches
|
| 1181 |
+
- Performs deep memory cleanup only when both models are fully released
|
| 1182 |
+
|
| 1183 |
+
Note:
|
| 1184 |
+
Model name tracking (_dit_model_name, _vae_model_name) is only cleared if neither
|
| 1185 |
+
model is cached, enabling proper model change detection on subsequent runs.
|
| 1186 |
+
"""
|
| 1187 |
+
    if not runner:
        return

    if debug:
        cleanup_type = "partial cleanup" if (dit_cache or vae_cache) else "full cleanup"
        debug.log(f"Starting {cleanup_type}", category="cleanup")

    # 1. Clean up any remaining models if they still exist
    # (This handles cases where phases were skipped or errored)
    if hasattr(runner, 'dit') and runner.dit is not None:
        cleanup_dit(runner=runner, debug=debug, cache_model=dit_cache)

    if hasattr(runner, 'vae') and runner.vae is not None:
        cleanup_vae(runner=runner, debug=debug, cache_model=vae_cache)

    # 2. Clear remaining runtime caches
    clear_runtime_caches(runner=runner, debug=debug)

    # 3. Clear config and other non-model components when fully releasing the runner
    if not (dit_cache or vae_cache):
        # Full cleanup - clear config and model tracking
        runner.config = None
        runner._dit_model_name = None
        runner._vae_model_name = None

    # 4. Final memory cleanup
    clear_memory(debug=debug, deep=True, force=True, timer_name="complete_cleanup")

    # 5. Clear cuBLAS workspaces
    if hasattr(torch._C, '_cuda_clearCublasWorkspaces'):
        torch._C._cuda_clearCublasWorkspaces()

    # Log which models are cached for the next run
    if dit_cache or vae_cache:
        cached_models = []
        if dit_cache and hasattr(runner, '_dit_model_name'):
            cached_models.append(f"DiT ({runner._dit_model_name})")
        if vae_cache and hasattr(runner, '_vae_model_name'):
            cached_models.append(f"VAE ({runner._vae_model_name})")

        if cached_models and debug:
            models_str = " and ".join(cached_models)
            debug.log(f"Models cached for next run: {models_str}", category="cache", force=True)

    if debug:
        debug.log(f"Completed {cleanup_type}", category="success")
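For reference, a minimal usage sketch of the cleanup entry point above (illustration only, not part of this commit; it assumes complete_cleanup is importable from src.optimization.memory_manager and that `runner` is the pipeline object whose .dit/.vae attributes the functions above manage, created earlier by the loading phase):

# Sketch only: keep the DiT weights parked on the offload device between runs,
# but fully release the VAE. _dit_model_name stays set so a model change is
# detected next run, while the VAE will be re-created from scratch.
from src.optimization.memory_manager import complete_cleanup

complete_cleanup(runner, debug=None, dit_cache=True, vae_cache=False)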
webui.bat
ADDED
@@ -0,0 +1,187 @@
@echo off

chcp 65001
set PYTHONUTF8=1
:: The original source of the webui.bat file is stable-diffusion-webui
:: Modified and enhanced by Gemini with features for venv management and requirements handling.

:: --------- Configuration ---------
set COMMANDLINE_ARGS=

:: Define the application directory (folder name)
:: Leave empty if the app is in the root directory.
set APP_DIR=

:: Define the name of the Launch application
set APPLICATION_NAME=app.py

:: Define the requirements filename, default is requirements.txt
set REQUIREMENTS_FILE=requirements.txt

:: Define the name of the virtual environment directory
set VENV_NAME=venv

:: Set to 1 to always attempt to update packages from requirements.txt on every launch
set ALWAYS_UPDATE_REQS=1
:: ---------------------------------

:: --------- Path Setup Logic ---------
:: Logic to handle paths based on whether APP_DIR is set
if defined APP_DIR (
    set "TARGET_REQ=%~dp0%APP_DIR%\%REQUIREMENTS_FILE%"
    set "TARGET_SCRIPT=%~dp0%APP_DIR%\%APPLICATION_NAME%"
    echo Working in subdirectory: %APP_DIR%
) else (
    set "TARGET_REQ=%~dp0%REQUIREMENTS_FILE%"
    set "TARGET_SCRIPT=%~dp0%APPLICATION_NAME%"
    echo Working in root directory.
)
:: ------------------------------------

:: Set PYTHON executable if not already defined
if not defined PYTHON (set PYTHON=python)
:: Set VENV_DIR using VENV_NAME if not already defined
if not defined VENV_DIR (set "VENV_DIR=%~dp0%VENV_NAME%")

mkdir tmp 2>NUL

:: Check if Python is callable
%PYTHON% -c "" >tmp/stdout.txt 2>tmp/stderr.txt
if %ERRORLEVEL% == 0 goto :check_pip
echo Couldn't launch python
goto :show_stdout_stderr

:check_pip
:: Check if pip is available
%PYTHON% -mpip --help >tmp/stdout.txt 2>tmp/stderr.txt
if %ERRORLEVEL% == 0 goto :start_venv
:: If pip is not available and PIP_INSTALLER_LOCATION is set, try to install pip
if "%PIP_INSTALLER_LOCATION%" == "" goto :show_stdout_stderr
%PYTHON% "%PIP_INSTALLER_LOCATION%" >tmp/stdout.txt 2>tmp/stderr.txt
if %ERRORLEVEL% == 0 goto :start_venv
echo Couldn't install pip
goto :show_stdout_stderr

:start_venv
:: Skip venv creation/activation if VENV_DIR is explicitly set to "-"
if ["%VENV_DIR%"] == ["-"] goto :skip_venv_entirely
:: Skip venv creation/activation if SKIP_VENV is set to "1"
if ["%SKIP_VENV%"] == ["1"] goto :skip_venv_entirely

:: Check if the venv already exists by looking for Python.exe in its Scripts directory
dir "%VENV_DIR%\Scripts\Python.exe" >tmp/stdout.txt 2>tmp/stderr.txt
if %ERRORLEVEL% == 0 goto :activate_venv_and_maybe_update

:: Venv does not exist, create it
echo Virtual environment not found in "%VENV_DIR%". Creating a new one.
for /f "delims=" %%i in ('CALL %PYTHON% -c "import sys; print(sys.executable)"') do set PYTHON_FULLNAME="%%i"
echo Creating venv in directory %VENV_DIR% using python %PYTHON_FULLNAME%
%PYTHON_FULLNAME% -m venv "%VENV_DIR%" >tmp/stdout.txt 2>tmp/stderr.txt
if %ERRORLEVEL% NEQ 0 (
    echo Unable to create venv in directory "%VENV_DIR%"
    goto :show_stdout_stderr
)
echo Venv created.

:: Install requirements for the first time if venv was just created
:: This section handles the initial installation of packages from requirements.txt
:: immediately after a new virtual environment is created.
echo Checking for %REQUIREMENTS_FILE% for initial setup...
if exist "%TARGET_REQ%" (
    echo Found %REQUIREMENTS_FILE% at "%TARGET_REQ%", attempting to install for initial setup...
    call "%VENV_DIR%\Scripts\activate.bat"
    echo Installing packages from %REQUIREMENTS_FILE% ^(initial setup^)...
    "%VENV_DIR%\Scripts\python.exe" -m pip install -r "%TARGET_REQ%"
    if errorlevel 1 (
        echo Failed to install requirements during initial setup. Please check the output above.
        pause
        goto :show_stdout_stderr_custom_pip_initial
    )
    echo Initial requirements installed successfully.
    call "%VENV_DIR%\Scripts\deactivate.bat"
) else (
    echo No %REQUIREMENTS_FILE% found at "%TARGET_REQ%", skipping package installation.
)
goto :activate_venv_and_maybe_update


:activate_venv_and_maybe_update
:: This label is reached if the venv exists or was just created.
:: Set PYTHON to point to the venv's Python interpreter.
set PYTHON="%VENV_DIR%\Scripts\Python.exe"
echo Activating venv: %PYTHON%

:: Always update requirements if ALWAYS_UPDATE_REQS is 1
:: This section allows for updating packages from requirements.txt on every launch
:: if the ALWAYS_UPDATE_REQS variable is set to 1.
if defined ALWAYS_UPDATE_REQS (
    if "%ALWAYS_UPDATE_REQS%"=="1" (
        echo ALWAYS_UPDATE_REQS is enabled.
        if exist "%TARGET_REQ%" (
            echo Attempting to update packages from "%TARGET_REQ%"...
            REM No need to call activate.bat here again, PYTHON is already set to the venv's python
            %PYTHON% -m pip install -r "%TARGET_REQ%"
            if errorlevel 1 (
                echo Failed to update requirements. Please check the output above.
                pause
                goto :endofscript
            )
            echo Requirements updated successfully.
        ) else (
            echo ALWAYS_UPDATE_REQS is enabled, but no %REQUIREMENTS_FILE% found. Skipping update.
        )
    ) else (
        echo ALWAYS_UPDATE_REQS is not enabled or not set to 1. Skipping routine update.
    )
)

goto :launch

:skip_venv_entirely
:: This label is reached if venv usage is explicitly skipped.
echo Skipping venv.
goto :launch

:launch
:: Launch the main application
echo Launching Web UI with arguments: %COMMANDLINE_ARGS% %*
echo Script path: %TARGET_SCRIPT%
%PYTHON% "%TARGET_SCRIPT%" %COMMANDLINE_ARGS% %*
echo Launch finished.
pause
exit /b

:show_stdout_stderr_custom_pip_initial
:: Custom error handler for failures during the initial pip install process.
echo.
echo exit code ^(pip initial install^): %errorlevel%
echo Errors during initial pip install. See output above.
echo.
echo Launch unsuccessful. Exiting.
pause
exit /b


:show_stdout_stderr
:: General error handler: displays stdout and stderr from the tmp directory.
echo.
echo exit code: %errorlevel%

for /f %%i in ("tmp\stdout.txt") do set size=%%~zi
if %size% equ 0 goto :show_stderr
echo.
echo stdout:
type tmp\stdout.txt

:show_stderr
for /f %%i in ("tmp\stderr.txt") do set size=%%~zi
if %size% equ 0 goto :endofscript
echo.
echo stderr:
type tmp\stderr.txt

:endofscript
echo.
echo Launch unsuccessful. Exiting.
pause
exit /b