"""
Lightweight AI Enhancement for Limited VRAM (< 4 GB)

Optimized for the RTX 3050 Laptop GPU: uses small, efficient models
that keep quality high while staying within a tight memory budget.
"""

import os
import warnings
from typing import Optional

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image

warnings.filterwarnings('ignore')


class RRDBNet_arch(nn.Module):
    """Simplified ESRGAN-style 2x upscaler for low VRAM.

    Note: despite the name, this is a plain convolutional trunk without
    stacked RRDB blocks, which keeps the parameter count small.
    """

    def __init__(self, in_nc=3, out_nc=3, nf=32, nb=16):
        super(RRDBNet_arch, self).__init__()
        # nb is accepted for API compatibility but unused in this trimmed
        # variant, which has no stacked RRDB blocks.
        self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
        self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        fea = self.conv_first(x)
        trunk = self.trunk_conv(fea)
        fea = fea + trunk  # residual connection around the trunk

        # Single 2x nearest-neighbor upsample. The original double upsample
        # produced 4x output, which disagreed with the scale=2 assumption in
        # the tiling code below.
        fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
        fea = self.lrelu(self.upconv2(fea))

        out = self.conv_last(self.lrelu(self.HRconv(fea)))
        return out
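
# Quick shape check (illustrative, not part of the pipeline): the trimmed net
# maps (1, 3, H, W) to (1, 3, 2H, 2W), e.g.
#   net = RRDBNet_arch()
#   y = net(torch.rand(1, 3, 64, 64))   # torch.Size([1, 3, 128, 128])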


class LightweightEnhancer:
    """Lightweight AI enhancer for GPUs with less than 4 GB of VRAM."""

    def __init__(self, device=None):
        """Initialize the lightweight enhancer."""
        if device is None:
            if torch.cuda.is_available():
                self.device = torch.device('cuda:0')
                print(f"Using GPU: {torch.cuda.get_device_name(0)}")

                torch.backends.cudnn.benchmark = True
                # Leave headroom for the rest of the system on small cards.
                torch.cuda.set_per_process_memory_fraction(0.7)

                props = torch.cuda.get_device_properties(0)
                self.vram_gb = props.total_memory / (1024 ** 3)
                print(f"VRAM: {self.vram_gb:.1f} GB")
            else:
                self.device = torch.device('cpu')
                print("Using CPU (GPU not available)")
                self.vram_gb = 0
        else:
            self.device = device
            self.vram_gb = 4  # assume ~4 GB when the caller picks the device

        self.model_dir = 'models_lightweight'
        os.makedirs(self.model_dir, exist_ok=True)

        self.esrgan_model = None
        self.face_model = None

        # Smaller tiles for very tight VRAM; FP16 in both cases.
        if self.vram_gb < 4:
            self.tile_size = 256
            self.use_fp16 = True
        else:
            self.tile_size = 384
            self.use_fp16 = True
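
        # Rough budget behind the tile choice (a sketch, assuming FP32
        # activations dominate): one 256x256 tile at nf=32 feature maps costs
        # about 256 * 256 * 32 * 4 bytes ~= 8 MB per conv activation, so even
        # a dozen such buffers stay far below a 4 GB card's limit.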

    def load_lightweight_esrgan(self):
        """Load the lightweight ESRGAN model."""
        try:
            print("Loading lightweight ESRGAN...")

            self.esrgan_model = RRDBNet_arch()

            model_path = os.path.join(self.model_dir, 'lightweight_esrgan.pth')
            if os.path.exists(model_path):
                self.esrgan_model.load_state_dict(torch.load(model_path, map_location=self.device))
                print("Loaded pretrained lightweight model")
            else:
                print("Warning: no pretrained model found, using random initialization")

            self.esrgan_model = self.esrgan_model.to(self.device)
            self.esrgan_model.eval()

            if self.use_fp16 and self.device.type == 'cuda':
                self.esrgan_model = self.esrgan_model.half()
                print("Using FP16 for memory efficiency")

            return True

        except Exception as e:
            print(f"Failed to load lightweight ESRGAN: {e}")
            return False

    def enhance_with_lightweight_esrgan(self, img):
        """Enhance an image with the lightweight ESRGAN, processed in tiles."""
        if self.esrgan_model is None:
            if not self.load_lightweight_esrgan():
                return self.fallback_upscale(img, 2)

        try:
            img_tensor = self.img_to_tensor(img)
            result = self.process_with_tiles(img_tensor, self.esrgan_model, scale=2)
            return self.tensor_to_img(result)

        except Exception as e:
            print(f"Enhancement failed: {e}")
            return self.fallback_upscale(img, 2)

    def process_with_tiles(self, img_tensor, model, scale=2):
        """Process an image in overlapping tiles to bound VRAM usage."""
        _, _, h, w = img_tensor.shape

        target_h = h * scale
        target_w = w * scale

        # Accumulate the full-resolution result first; the 2K cap is applied
        # with a single resize at the end so tile placement stays aligned.
        # (Writing 2x tiles into a pre-shrunk buffer would crop, not scale,
        # anything past the cap.)
        output = torch.zeros((1, 3, target_h, target_w), device=self.device)

        tile_size = self.tile_size
        pad = 16  # overlap between adjacent tiles to hide seams

        for y in range(0, h, tile_size - pad):
            for x in range(0, w, tile_size - pad):
                y_end = min(y + tile_size, h)
                x_end = min(x + tile_size, w)
                tile = img_tensor[:, :, y:y_end, x:x_end]

                with torch.no_grad():
                    if self.use_fp16 and self.device.type == 'cuda':
                        tile = tile.half()

                    tile_out = model(tile)

                    if self.use_fp16:
                        tile_out = tile_out.float()

                out_y = y * scale
                out_x = x * scale
                out_y_end = min(out_y + tile_out.shape[2], target_h)
                out_x_end = min(out_x + tile_out.shape[3], target_w)

                output[:, :, out_y:out_y_end, out_x:out_x_end] = \
                    tile_out[:, :, :out_y_end - out_y, :out_x_end - out_x]

        if self.device.type == 'cuda':
            torch.cuda.empty_cache()

        # Cap the output at roughly 2K (2048x1080) to limit downstream memory.
        if target_w > 2048 or target_h > 1080:
            limit_scale = min(2048 / target_w, 1080 / target_h)
            out_w = int(target_w * limit_scale)
            out_h = int(target_h * limit_scale)
            print(f"   Limiting output to {out_w}x{out_h} (2K max)")
            output = F.interpolate(output, size=(out_h, out_w), mode='bilinear', align_corners=False)

        return output
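
    # Tiling sketch: with tile_size=256 and pad=16, tiles start every 240 px,
    # so neighbors overlap by 16 px and each later tile overwrites the shared
    # strip. A 600-px edge, for example, yields tile origins at x = 0, 240, 480.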

    def img_to_tensor(self, img):
        """Convert a BGR ndarray or PIL image to a normalized RGB tensor."""
        was_pil = isinstance(img, Image.Image)
        if was_pil:
            img = np.array(img)  # PIL images are already RGB

        if img.ndim == 2:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        elif img.shape[2] == 4:
            img = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)
        elif not was_pil:
            # Only OpenCV-loaded images are BGR; swapping channels on PIL
            # input (as the original code did) would corrupt the colors.
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        img = img.astype(np.float32) / 255.0

        # HWC -> NCHW
        img_tensor = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0)

        return img_tensor.to(self.device)

    def tensor_to_img(self, tensor):
        """Convert a (1, 3, H, W) tensor back to a BGR uint8 image."""
        img = tensor.squeeze(0).permute(1, 2, 0).cpu().numpy()
        img = (img * 255).clip(0, 255).astype(np.uint8)
        return cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

    def fallback_upscale(self, img, scale):
        """Fallback upscaling with OpenCV, capped at roughly 2K output."""
        print("   Using optimized fallback upscaling...")

        h, w = img.shape[:2]

        # Cap the target size at 2048x1080 to bound memory use.
        target_scale = min(scale, 2048 / w, 1080 / h)
        new_w = int(w * target_scale)
        new_h = int(h * target_scale)

        upscaled = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_CUBIC)

        # Sharpen; the kernel sums to 1, so overall brightness is preserved.
        kernel = np.array([[-1, -1, -1], [-1, 9, -1], [-1, -1, -1]])
        upscaled = cv2.filter2D(upscaled, -1, kernel)

        # Smooth ringing from the sharpen while keeping edges.
        upscaled = cv2.bilateralFilter(upscaled, 5, 50, 50)

        return upscaled

    def enhance_faces_lightweight(self, img):
        """Lightweight face enhancement using a Haar-cascade detector."""
        try:
            face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
            gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            faces = face_cascade.detectMultiScale(gray, 1.1, 4)

            if len(faces) == 0:
                return img

            print(f"   Enhancing {len(faces)} faces...")

            for (x, y, w, h) in faces:
                # Pad the detected box slightly so enhancement covers the
                # hairline and chin, clamped to the image bounds.
                pad = int(w * 0.1)
                x_start = max(0, x - pad)
                y_start = max(0, y - pad)
                x_end = min(img.shape[1], x + w + pad)
                y_end = min(img.shape[0], y + h + pad)

                face = img[y_start:y_end, x_start:x_end]
                face = self.enhance_face_region_lightweight(face)
                img[y_start:y_end, x_start:x_end] = face

            return img

        except Exception as e:
            print(f"Warning: face enhancement failed: {e}")
            return img

    def enhance_face_region_lightweight(self, face):
        """Lightweight enhancement of a single face crop."""
        # Smooth skin while preserving edges.
        face = cv2.bilateralFilter(face, 9, 75, 75)

        # Boost local contrast on the lightness channel only.
        lab = cv2.cvtColor(face, cv2.COLOR_BGR2LAB)
        l, a, b = cv2.split(lab)

        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        l = clahe.apply(l)

        face = cv2.merge([l, a, b])
        face = cv2.cvtColor(face, cv2.COLOR_LAB2BGR)

        # Mild sharpening; the kernel sums to 1, preserving brightness.
        kernel = np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]])
        face = cv2.filter2D(face, -1, kernel)

        return face

    def enhance_image_pipeline(self, image_path: str, output_path: Optional[str] = None) -> str:
        """Complete enhancement pipeline for low VRAM."""
        print(f"Enhancing {os.path.basename(image_path)} (Lightweight Mode)...")

        try:
            img = cv2.imread(image_path)
            if img is None:
                print(f"Failed to load image: {image_path}")
                return image_path

            original_shape = img.shape[:2]
            print(f"   Original: {original_shape[1]}x{original_shape[0]}")

            print("   Applying lightweight upscaling (max 2K)...")
            print(f"   Input: {img.shape[1]}x{img.shape[0]}")
            enhanced = self.enhance_with_lightweight_esrgan(img)

            print("   Enhancing faces...")
            enhanced = self.enhance_faces_lightweight(enhanced)

            print("   Applying color correction...")
            enhanced = self.color_correction(enhanced)

            if output_path is None:
                # splitext avoids mangling dots elsewhere in the path, which
                # str.replace('.', '_enhanced.') would do.
                base, ext = os.path.splitext(image_path)
                output_path = f"{base}_enhanced{ext}"

            cv2.imwrite(output_path, enhanced, [cv2.IMWRITE_JPEG_QUALITY, 95])

            new_shape = enhanced.shape[:2]
            print(f"   Enhanced: {new_shape[1]}x{new_shape[0]}")

            self.clear_memory()

            return output_path

        except Exception as e:
            print(f"Pipeline failed: {e}")
            return image_path

    def color_correction(self, img):
        """Lightweight color correction."""
        lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
        l, a, b = cv2.split(lab)

        # Boost local contrast on the lightness channel.
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        l = clahe.apply(l)

        # Boost saturation around the neutral point (128 in 8-bit LAB).
        # Scaling without the offset would shift neutral grays and add a
        # color cast: 1.1 * 128 - 12.8 == 128 keeps gray at gray.
        a = cv2.convertScaleAbs(a, alpha=1.1, beta=-12.8)
        b = cv2.convertScaleAbs(b, alpha=1.1, beta=-12.8)

        enhanced = cv2.merge([l, a, b])
        enhanced = cv2.cvtColor(enhanced, cv2.COLOR_LAB2BGR)

        return enhanced

    def clear_memory(self):
        """Release cached GPU memory."""
        if self.device.type == 'cuda':
            torch.cuda.empty_cache()
            torch.cuda.synchronize()


_lightweight_enhancer = None


def get_lightweight_enhancer():
    """Get or create the global lightweight enhancer."""
    global _lightweight_enhancer
    if _lightweight_enhancer is None:
        _lightweight_enhancer = LightweightEnhancer()
    return _lightweight_enhancer
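

# Minimal usage sketch (the input path 'input.jpg' is illustrative; point it
# at any image on disk):
if __name__ == '__main__':
    enhancer = get_lightweight_enhancer()
    result_path = enhancer.enhance_image_pipeline('input.jpg')
    print(f"Saved enhanced image to {result_path}")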