File size: 4,234 Bytes
7011a64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import os

# Redirect Hugging Face and Matplotlib cache/config directories under /data.
# NOTE(review): these libraries presumably read the variables when they are
# first imported, which is why this sits above the third-party imports; keep
# it there. TRANSFORMERS_CACHE is the legacy name, HF_HOME the current one.
os.environ.update({
    'TRANSFORMERS_CACHE': '/data/.cache/transformers',
    'HF_HOME': '/data/.cache/huggingface',
    'MPLCONFIGDIR': '/data/.cache/matplotlib',
})

import torch
import torch.nn as nn
import yaml
from torchvision import models, transforms
from PIL import Image
import gradio as gr
from transformers import ConvNextV2ForImageClassification
from typing import Dict, Tuple

# Display name (used as the UI dropdown entry) -> checkpoint file path,
# relative to the working directory.
MODEL_CHECKPOINTS = {
    "convnext_tiny_best": "checkpoints/convnext_v2_tiny_best.pth",
    "efficientnet_b0": "checkpoints/effnet_b0_best.pth",
    "efficientnet_b3": "checkpoints/effnet_b3_best.pth",
    "vit_b_16": "checkpoints/vit_b_16_best.pth",
}
# Model pre-selected in the UI; should be one of the keys above.
DEFAULT_MODEL_NAME = "vit_b_16"

# Filled at startup with display name -> model already moved to DEVICE and put
# in eval mode. Values are bare modules (not (model, class_map) tuples).
MODELS: Dict[str, nn.Module] = {}
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class HFConvNeXtWrapper(nn.Module):
    """Wrap a Hugging Face ConvNeXt V2 classifier so calls return logits only.

    The Hugging Face model returns an output object; ``forward`` unwraps its
    ``logits`` attribute so callers get a plain tensor.
    """

    def __init__(self, model_name, num_labels):
        super().__init__()
        # ``model_name`` is passed straight to ``from_pretrained``, i.e. it is
        # used as a Hugging Face model identifier. ``ignore_mismatched_sizes``
        # lets the classification head be resized to ``num_labels``.
        self.model = ConvNextV2ForImageClassification.from_pretrained(
            model_name,
            num_labels=num_labels,
            ignore_mismatched_sizes=True,
        )

    def forward(self, x):
        """Run the wrapped model and return its ``logits`` field."""
        outputs = self.model(x)
        return outputs.logits

def get_model(model_name: str, num_classes: int) -> nn.Module:
    """Construct the named architecture with an output layer of ``num_classes`` units.

    Torchvision architectures are built with ``weights=None`` and their final
    linear layer is replaced by a freshly sized ``nn.Linear``. Any name that
    contains ``"convnextv2"`` is handed to ``HFConvNeXtWrapper`` as a
    Hugging Face model identifier.

    Raises:
        ValueError: if ``model_name`` matches no supported architecture.
    """
    if model_name in ("efficientnet_b0", "efficientnet_b3"):
        builder = (models.efficientnet_b0 if model_name == "efficientnet_b0"
                   else models.efficientnet_b3)
        net = builder(weights=None)
        # Swap the Linear layer at classifier[1] for one with num_classes outputs.
        in_features = net.classifier[1].in_features
        net.classifier[1] = nn.Linear(in_features, num_classes)
        return net

    if model_name == "vit_b_16":
        net = models.vit_b_16(weights=None)
        in_features = net.heads.head.in_features
        net.heads.head = nn.Linear(in_features, num_classes)
        return net

    if "convnextv2" in model_name:
        return HFConvNeXtWrapper(model_name, num_labels=num_classes)

    raise ValueError(f"Model '{model_name}' not supported.")

def load_checkpoint(checkpoint_path: str, device: torch.device) -> Tuple[nn.Module, Dict[int, str]]:
    """Rebuild a model from a checkpoint file and ready it for inference.

    The checkpoint dict must provide ``'model_name'`` (passed to ``get_model``)
    and ``'state_dict'``. The model is built with a single output unit,
    moved to ``device`` and switched to eval mode.

    Returns:
        ``(model, class_map)`` where ``class_map`` is always an empty dict.

    Raises:
        FileNotFoundError: if ``checkpoint_path`` does not exist.
    """
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint file not found at: {checkpoint_path}")

    # NOTE: torch.load may unpickle arbitrary objects (depending on the torch
    # version / weights_only default) - only load trusted checkpoint files.
    saved = torch.load(checkpoint_path, map_location=device)
    net = get_model(saved['model_name'], num_classes=1)
    net.load_state_dict(saved['state_dict'])
    net.to(device).eval()

    class_map: Dict[int, str] = {}
    return net, class_map

# Load every checkpoint that exists on disk; missing ones are skipped with a
# warning so the app can run with whichever subset is available.
print("--- Loading all models into memory ---")
for display_name, ckpt_path in MODEL_CHECKPOINTS.items():
    if not os.path.exists(ckpt_path):
        print(f"WARNING: Checkpoint for '{display_name}' not found. Skipping.")
        continue
    model, _ = load_checkpoint(ckpt_path, DEVICE)
    MODELS[display_name] = model
    print(f"Loaded '{display_name}' on {DEVICE}.")

# Refuse to start the UI with nothing to choose from.
if not MODELS:
    raise RuntimeError("No models were loaded. Please check your checkpoints directory.")

# Read the input size from the YAML config (presumably the same file used at
# training time - confirm). The encoding is explicit so decoding does not
# depend on the host locale.
with open('cm_config.yaml', 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)
IMG_SIZE = config['data_params']['image_size']

# Resize to a square IMG_SIZE x IMG_SIZE image, convert to a tensor, and
# normalize with the standard ImageNet channel mean/std.
inference_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

def predict(pil_image, model_name: str):
    """Classify one image as "clean" vs "messy" with the selected model.

    Args:
        pil_image: PIL image from the upload widget, or None when empty.
        model_name: key into ``MODELS`` chosen in the dropdown.

    Returns:
        None when no image is given; otherwise a dict with complementary
        probabilities for ``"clean"`` and ``"messy"``.

    Raises:
        gr.Error: if ``model_name`` is not one of the loaded models (e.g. the
            default model's checkpoint was missing, or no model is selected),
            so the UI shows a readable message instead of a bare KeyError.
    """
    if pil_image is None:
        return None

    model = MODELS.get(model_name)
    if model is None:
        available = ", ".join(MODELS)
        raise gr.Error(
            f"Model '{model_name}' is not loaded. Available models: {available}."
        )

    image_tensor = inference_transform(pil_image.convert("RGB")).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        logit = model(image_tensor)
        # The models are built with a single output unit (num_classes=1), so
        # sigmoid of that logit is treated as P(messy).
        prob = torch.sigmoid(logit).item()

    return {"clean": 1 - prob, "messy": prob}

# Pre-select the configured default model, but only if its checkpoint was
# actually loaded; otherwise fall back to the first loaded model so the
# dropdown never starts on a value that is not among its choices.
_initial_model = (
    DEFAULT_MODEL_NAME if DEFAULT_MODEL_NAME in MODELS else next(iter(MODELS), None)
)

iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Dropdown(
            choices=list(MODELS.keys()),
            value=_initial_model,
            label="Select Model",
        ),
    ],
    outputs=gr.Label(num_top_classes=2, label="Predictions"),
    title="Messy vs Clean Image Classifier",
    description="Upload an image and select a model to see its classification for 'messy' vs 'clean'.",
)

iface.launch()