lsatone / backend /lightweight_ai_enhancer.py
3v324v23's picture
Update Comic123 with local comic folder files
83e35a7
"""
Lightweight AI Enhancement for Limited VRAM (< 4GB)
Optimized for RTX 3050 Laptop GPU
Uses efficient models with excellent quality
"""
import os
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import requests
from tqdm import tqdm
from typing import Optional, Dict, Any, Tuple
import warnings
warnings.filterwarnings('ignore')
# Lightweight ESRGAN Architecture
class RRDBNet_arch(nn.Module):
"""Lightweight RRDB Net for ESRGAN - optimized for low VRAM"""
def __init__(self, in_nc=3, out_nc=3, nf=32, nb=16): # Reduced from 64/23 to 32/16
super(RRDBNet_arch, self).__init__()
self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
def forward(self, x):
fea = self.conv_first(x)
trunk = self.trunk_conv(fea)
fea = fea + trunk
fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
out = self.conv_last(self.lrelu(self.HRconv(fea)))
return out
class LightweightEnhancer:
"""Lightweight AI enhancer for <4GB VRAM"""
def __init__(self, device=None):
"""Initialize lightweight enhancer"""
# Set device
if device is None:
if torch.cuda.is_available():
self.device = torch.device('cuda:0')
print(f"πŸš€ Using GPU: {torch.cuda.get_device_name(0)}")
# RTX 3050 Laptop optimization
torch.backends.cudnn.benchmark = True
torch.cuda.set_per_process_memory_fraction(0.7) # Use only 70% VRAM
# Get VRAM info
props = torch.cuda.get_device_properties(0)
self.vram_gb = props.total_memory / (1024**3)
print(f"πŸ“Š VRAM: {self.vram_gb:.1f} GB")
else:
self.device = torch.device('cpu')
print("πŸ’» Using CPU (GPU not available)")
self.vram_gb = 0
else:
self.device = device
self.vram_gb = 4 # Assume 4GB
# Model storage
self.model_dir = 'models_lightweight'
os.makedirs(self.model_dir, exist_ok=True)
# Models
self.esrgan_model = None
self.face_model = None
# Settings based on VRAM
if self.vram_gb < 4:
self.tile_size = 256 # Smaller tiles for <4GB
self.use_fp16 = True # Force FP16
else:
self.tile_size = 384
self.use_fp16 = True
def load_lightweight_esrgan(self):
"""Load lightweight ESRGAN model"""
try:
print("πŸ”„ Loading lightweight ESRGAN...")
# Create lightweight model
self.esrgan_model = RRDBNet_arch()
# Try to load pretrained weights if available
model_path = os.path.join(self.model_dir, 'lightweight_esrgan.pth')
if os.path.exists(model_path):
self.esrgan_model.load_state_dict(torch.load(model_path, map_location=self.device))
print("βœ… Loaded pretrained lightweight model")
else:
print("⚠️ No pretrained model found, using random initialization")
# In practice, you'd train this or download a pretrained one
self.esrgan_model = self.esrgan_model.to(self.device)
self.esrgan_model.eval()
# Convert to FP16 if using GPU
if self.use_fp16 and self.device.type == 'cuda':
self.esrgan_model = self.esrgan_model.half()
print("βœ… Using FP16 for memory efficiency")
return True
except Exception as e:
print(f"❌ Failed to load lightweight ESRGAN: {e}")
return False
def enhance_with_lightweight_esrgan(self, img):
"""Enhance using lightweight ESRGAN with tiling"""
if self.esrgan_model is None:
if not self.load_lightweight_esrgan():
return self.fallback_upscale(img, 2)
try:
# Convert to tensor
img_tensor = self.img_to_tensor(img)
# Process with tiling for low VRAM
result = self.process_with_tiles(img_tensor, self.esrgan_model, scale=2)
# Convert back to numpy
result = self.tensor_to_img(result)
return result
except Exception as e:
print(f"❌ Enhancement failed: {e}")
return self.fallback_upscale(img, 2)
def process_with_tiles(self, img_tensor, model, scale=2):
"""Process image in tiles to save VRAM"""
_, _, h, w = img_tensor.shape
# Calculate output size (max 2K)
target_h = h * scale
target_w = w * scale
# Apply 2K limit
if target_w > 2048 or target_h > 1080:
limit_scale = min(2048/target_w, 1080/target_h)
out_w = int(target_w * limit_scale)
out_h = int(target_h * limit_scale)
print(f" πŸ“ Limiting output to {out_w}x{out_h} (2K max)")
else:
out_h, out_w = target_h, target_w
output = torch.zeros((1, 3, out_h, out_w), device=self.device)
# Tile processing
tile_size = self.tile_size
pad = 16 # Overlap to avoid seams
for y in range(0, h, tile_size - pad):
for x in range(0, w, tile_size - pad):
# Extract tile
y_end = min(y + tile_size, h)
x_end = min(x + tile_size, w)
tile = img_tensor[:, :, y:y_end, x:x_end]
# Process tile
with torch.no_grad():
if self.use_fp16 and self.device.type == 'cuda':
tile = tile.half()
tile_out = model(tile)
if self.use_fp16:
tile_out = tile_out.float()
# Place tile in output
out_y = y * scale
out_x = x * scale
out_y_end = min(out_y + tile_out.shape[2], out_h)
out_x_end = min(out_x + tile_out.shape[3], out_w)
output[:, :, out_y:out_y_end, out_x:out_x_end] = tile_out[:, :, :out_y_end-out_y, :out_x_end-out_x]
# Clear cache to save memory
if self.device.type == 'cuda':
torch.cuda.empty_cache()
return output
def img_to_tensor(self, img):
"""Convert image to tensor"""
if isinstance(img, Image.Image):
img = np.array(img)
# Ensure RGB
if len(img.shape) == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
elif img.shape[2] == 4:
img = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)
elif img.shape[2] == 3 and isinstance(img, np.ndarray):
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# Normalize to [0, 1]
img = img.astype(np.float32) / 255.0
# Convert to tensor
img_tensor = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0)
return img_tensor.to(self.device)
def tensor_to_img(self, tensor):
"""Convert tensor to image"""
img = tensor.squeeze(0).permute(1, 2, 0).cpu().numpy()
img = (img * 255).clip(0, 255).astype(np.uint8)
return cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
def fallback_upscale(self, img, scale):
"""Fallback upscaling using OpenCV with 2K limit"""
print(" πŸ“ˆ Using optimized fallback upscaling...")
h, w = img.shape[:2]
# Calculate new size with 2K limit
target_scale = min(scale, 2048/w, 1080/h)
new_w = int(w * target_scale)
new_h = int(h * target_scale)
# Use EDSR-inspired upscaling
# First, upscale with CUBIC
upscaled = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_CUBIC)
# Apply sharpening
kernel = np.array([[-1,-1,-1], [-1,9,-1], [-1,-1,-1]]) / 1
upscaled = cv2.filter2D(upscaled, -1, kernel)
# Reduce noise
upscaled = cv2.bilateralFilter(upscaled, 5, 50, 50)
return upscaled
def enhance_faces_lightweight(self, img):
"""Lightweight face enhancement"""
try:
# Detect faces using OpenCV
face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
faces = face_cascade.detectMultiScale(gray, 1.1, 4)
if len(faces) == 0:
return img
print(f" πŸ‘€ Enhancing {len(faces)} faces...")
for (x, y, w, h) in faces:
# Extract face with padding
pad = int(w * 0.1)
x_start = max(0, x - pad)
y_start = max(0, y - pad)
x_end = min(img.shape[1], x + w + pad)
y_end = min(img.shape[0], y + h + pad)
face = img[y_start:y_end, x_start:x_end]
# Enhance face
face = self.enhance_face_region_lightweight(face)
# Put back
img[y_start:y_end, x_start:x_end] = face
return img
except Exception as e:
print(f"⚠️ Face enhancement failed: {e}")
return img
def enhance_face_region_lightweight(self, face):
"""Lightweight face enhancement"""
# 1. Denoise
face = cv2.bilateralFilter(face, 9, 75, 75)
# 2. Enhance details
lab = cv2.cvtColor(face, cv2.COLOR_BGR2LAB)
l, a, b = cv2.split(lab)
# CLAHE on L channel
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
l = clahe.apply(l)
face = cv2.merge([l, a, b])
face = cv2.cvtColor(face, cv2.COLOR_LAB2BGR)
# 3. Subtle sharpening
kernel = np.array([[0,-1,0], [-1,5,-1], [0,-1,0]]) / 1
face = cv2.filter2D(face, -1, kernel)
return face
def enhance_image_pipeline(self, image_path: str, output_path: str = None) -> str:
"""Complete enhancement pipeline for low VRAM"""
print(f"🎨 Enhancing {os.path.basename(image_path)} (Lightweight Mode)...")
try:
# Load image
img = cv2.imread(image_path)
if img is None:
print(f"❌ Failed to load image: {image_path}")
return image_path
original_shape = img.shape[:2]
print(f" Original: {original_shape[1]}x{original_shape[0]}")
# Step 1: Lightweight super resolution
print(" πŸš€ Applying lightweight upscaling (max 2K)...")
print(f" πŸ“ Input: {img.shape[1]}x{img.shape[0]}")
enhanced = self.enhance_with_lightweight_esrgan(img)
# Step 2: Face enhancement
print(" πŸ‘€ Enhancing faces...")
enhanced = self.enhance_faces_lightweight(enhanced)
# Step 3: Final color correction
print(" 🎨 Applying color correction...")
enhanced = self.color_correction(enhanced)
# Save
if output_path is None:
output_path = image_path.replace('.', '_enhanced.')
cv2.imwrite(output_path, enhanced, [cv2.IMWRITE_JPEG_QUALITY, 95])
new_shape = enhanced.shape[:2]
print(f" βœ… Enhanced: {new_shape[1]}x{new_shape[0]}")
# Clear memory
self.clear_memory()
return output_path
except Exception as e:
print(f"❌ Pipeline failed: {e}")
return image_path
def color_correction(self, img):
"""Lightweight color correction"""
# Convert to LAB
lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
l, a, b = cv2.split(lab)
# Enhance L channel
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
l = clahe.apply(l)
# Slight color boost
a = cv2.convertScaleAbs(a, alpha=1.1, beta=0)
b = cv2.convertScaleAbs(b, alpha=1.1, beta=0)
# Merge and convert back
enhanced = cv2.merge([l, a, b])
enhanced = cv2.cvtColor(enhanced, cv2.COLOR_LAB2BGR)
return enhanced
def clear_memory(self):
"""Clear GPU memory"""
if self.device.type == 'cuda':
torch.cuda.empty_cache()
torch.cuda.synchronize()
# Global instance
_lightweight_enhancer = None
def get_lightweight_enhancer():
"""Get or create global lightweight enhancer"""
global _lightweight_enhancer
if _lightweight_enhancer is None:
_lightweight_enhancer = LightweightEnhancer()
return _lightweight_enhancer