Spaces:
Sleeping
Sleeping
File size: 4,234 Bytes
7011a64 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 | import os
# Redirect the Hugging Face, transformers and matplotlib caches to /data.
# This runs before the torch/transformers/gradio imports below, presumably so
# those libraries pick these locations up when they are first imported.
os.environ.update(
    TRANSFORMERS_CACHE='/data/.cache/transformers',
    HF_HOME='/data/.cache/huggingface',
    MPLCONFIGDIR='/data/.cache/matplotlib',
)
import torch
import torch.nn as nn
import yaml
from torchvision import models, transforms
from PIL import Image
import gradio as gr
from transformers import ConvNextV2ForImageClassification
from typing import Dict, Tuple
# Dropdown display name -> path of the checkpoint file to load for it.
MODEL_CHECKPOINTS: Dict[str, str] = {
    "convnext_tiny_best": "checkpoints/convnext_v2_tiny_best.pth",
    "efficientnet_b0": "checkpoints/effnet_b0_best.pth",
    "efficientnet_b3": "checkpoints/effnet_b3_best.pth",
    "vit_b_16": "checkpoints/vit_b_16_best.pth",
}
# Display name preselected in the model dropdown.
DEFAULT_MODEL_NAME = "vit_b_16"
# Display name -> loaded model, populated at startup. Only the nn.Module is
# stored (the label map returned by load_checkpoint is discarded), so the
# values are modules, not (module, labels) tuples.
MODELS: Dict[str, nn.Module] = {}
# Run inference on the GPU when CUDA is available, otherwise on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class HFConvNeXtWrapper(nn.Module):
    """Wrap a Hugging Face ConvNeXt V2 classifier so it returns plain logits.

    The Hugging Face model returns an output object. ``forward`` unwraps its
    ``logits`` attribute, so callers can treat this wrapper like the
    torchvision models built in ``get_model``.
    """

    def __init__(self, model_name, num_labels):
        super().__init__()
        # ignore_mismatched_sizes lets the pretrained weights load even though
        # the classification head is resized to ``num_labels`` outputs.
        self.model = ConvNextV2ForImageClassification.from_pretrained(
            model_name,
            num_labels=num_labels,
            ignore_mismatched_sizes=True,
        )

    def forward(self, x):
        """Run the wrapped model on ``x`` and return only its logits."""
        outputs = self.model(x)
        return outputs.logits
def get_model(model_name: str, num_classes: int) -> nn.Module:
    """Build an untrained architecture with a ``num_classes``-way output head.

    Args:
        model_name: ``"efficientnet_b0"``, ``"efficientnet_b3"``, ``"vit_b_16"``,
            or any Hugging Face model id containing ``"convnextv2"``.
        num_classes: Number of output units of the replacement head.

    Returns:
        The model with randomly initialised weights (``weights=None`` for the
        torchvision architectures).

    Raises:
        ValueError: If ``model_name`` matches none of the supported options.
    """
    if model_name in ("efficientnet_b0", "efficientnet_b3"):
        # Both EfficientNet variants keep their final Linear at classifier[1].
        net = getattr(models, model_name)(weights=None)
        old_head = net.classifier[1]
        net.classifier[1] = nn.Linear(old_head.in_features, num_classes)
        return net

    if model_name == "vit_b_16":
        net = models.vit_b_16(weights=None)
        old_head = net.heads.head
        net.heads.head = nn.Linear(old_head.in_features, num_classes)
        return net

    if "convnextv2" in model_name:
        return HFConvNeXtWrapper(model_name, num_labels=num_classes)

    raise ValueError(f"Model '{model_name}' not supported.")
def load_checkpoint(
    checkpoint_path: str,
    device: torch.device,
    num_classes: int = 1,
) -> Tuple[nn.Module, Dict[int, str]]:
    """Rebuild a model from a saved checkpoint and prepare it for inference.

    Args:
        checkpoint_path: Path to a ``torch.save`` file. It must hold a mapping
            with ``'model_name'`` (an identifier accepted by ``get_model``) and
            ``'state_dict'`` entries.
        device: Device used as ``map_location`` and that the model is moved to.
        num_classes: Output units of the rebuilt classifier head. The default
            of 1 is the single-logit head the sigmoid-based ``predict`` in this
            app expects.

    Returns:
        ``(model, class_map)``. ``model`` is in eval mode on ``device``;
        ``class_map`` is always an empty dict because no label mapping is read
        from the checkpoint.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` does not exist.
        KeyError: If the checkpoint lacks ``'model_name'`` or ``'state_dict'``.
        ValueError: If ``get_model`` does not support the stored model name.
    """
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint file not found at: {checkpoint_path}")
    # NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model = get_model(checkpoint['model_name'], num_classes=num_classes)
    model.load_state_dict(checkpoint['state_dict'])
    model.to(device)
    model.eval()
    return model, {}
# Load every available checkpoint up front so switching models in the UI is
# instant. Missing or unloadable checkpoints are skipped with a warning; the
# app refuses to start only if no model could be loaded at all.
print("--- Loading all models into memory ---")
for display_name, ckpt_path in MODEL_CHECKPOINTS.items():
    if not os.path.exists(ckpt_path):
        print(f"WARNING: Checkpoint for '{display_name}' not found. Skipping.")
        continue
    try:
        model, _ = load_checkpoint(ckpt_path, DEVICE)
    except (KeyError, ValueError, RuntimeError, OSError) as err:
        # A corrupt or mismatched checkpoint (missing keys, unsupported
        # architecture, state_dict shape mismatch, unreachable hub weights) is
        # treated like a missing one so the other models can still be served.
        print(f"WARNING: Failed to load '{display_name}' from {ckpt_path}: {err}. Skipping.")
        continue
    MODELS[display_name] = model
    print(f"Loaded '{display_name}' on {DEVICE}.")
if not MODELS:
    raise RuntimeError("No models were loaded. Please check your checkpoints directory.")
# Read the image size from the YAML config (presumably the one used for
# training). Decode explicitly as UTF-8 rather than the platform locale codec.
with open('cm_config.yaml', 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)
# Assumes data_params.image_size is a single int; it is used for both sides.
IMG_SIZE = config['data_params']['image_size']
# Resize to a square IMG_SIZE x IMG_SIZE image, convert to a tensor, then
# normalise with the standard ImageNet per-channel mean and std.
inference_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
def predict(pil_image, model_name: str):
    """Classify one image as clean vs. messy with the selected model.

    Args:
        pil_image: PIL image from the Gradio upload widget, or None when no
            image has been provided.
        model_name: Display name of a loaded model (the dropdown value).

    Returns:
        None when there is no image. Otherwise a ``{"clean": p, "messy": q}``
        dict for ``gr.Label``, where ``q`` is the sigmoid of the model's single
        output logit and ``p = 1 - q``.

    Raises:
        gr.Error: If ``model_name`` is not among the loaded models. This shows
            the user a readable message instead of a bare KeyError.
    """
    if pil_image is None:
        return None
    model = MODELS.get(model_name)
    if model is None:
        raise gr.Error(f"Model '{model_name}' is not loaded. Please select another model.")
    # Force 3 channels so grayscale/RGBA/palette uploads match the normaliser.
    batch = inference_transform(pil_image.convert("RGB")).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        logit = model(batch)
    # Single-logit binary head: .item() requires exactly one output element.
    prob = torch.sigmoid(logit).item()
    return {"clean": 1 - prob, "messy": prob}
# Preselect the configured default model only if its checkpoint was actually
# loaded. Otherwise fall back to the first loaded model, so the dropdown never
# starts on a value that is missing from its choices. MODELS is non-empty here
# because startup raises when no model could be loaded.
_available_models = list(MODELS.keys())
_initial_model = (
    DEFAULT_MODEL_NAME if DEFAULT_MODEL_NAME in MODELS else _available_models[0]
)
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Dropdown(
            choices=_available_models,
            value=_initial_model,
            label="Select Model",
        ),
    ],
    outputs=gr.Label(num_top_classes=2, label="Predictions"),
    title="Messy vs Clean Image Classifier",
    description="Upload an image and select a model to see its classification for 'messy' vs 'clean'.",
)
iface.launch()