TerenceG commited on
Commit
da2daf9
·
verified ·
1 Parent(s): 24232c5

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +206 -0
app.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torchvision.transforms as transforms
4
+ from PIL import Image
5
+ from pytorch_grad_cam import GradCAM
6
+ from pytorch_grad_cam.utils.image import show_cam_on_image
7
+ from transformers import AutoModel, AutoFeatureExtractor
8
+ import timm
9
+ import numpy as np
10
+ import json
11
+ import base64
12
+ from io import BytesIO
13
+
14
+ class AIDetectionGradCAM:
15
+ def __init__(self):
16
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
17
+ self.models = {}
18
+ self.feature_extractors = {}
19
+ self.target_layers = {}
20
+
21
+ # Initialiser les modèles
22
+ self._load_models()
23
+
24
+ def _load_models(self):
25
+ """Charge les modèles pour la détection"""
26
+ try:
27
+ # Modèle Swin Transformer
28
+ model_name = "microsoft/swin-base-patch4-window7-224-in22k"
29
+ self.models['swin'] = timm.create_model('swin_base_patch4_window7_224', pretrained=True, num_classes=2)
30
+ self.feature_extractors['swin'] = AutoFeatureExtractor.from_pretrained(model_name)
31
+
32
+ # Définir les couches cibles pour GradCAM
33
+ self.target_layers['swin'] = [self.models['swin'].layers[-1].blocks[-1].norm1]
34
+
35
+ # Mettre en mode évaluation
36
+ for model in self.models.values():
37
+ model.eval()
38
+ model.to(self.device)
39
+
40
+ except Exception as e:
41
+ print(f"Erreur lors du chargement des modèles: {e}")
42
+
43
+ def _preprocess_image(self, image, model_type='swin'):
44
+ """Prétraite l'image pour le modèle"""
45
+ if isinstance(image, str):
46
+ # Si c'est un chemin ou base64
47
+ if image.startswith('data:image'):
48
+ # Décoder base64
49
+ header, data = image.split(',', 1)
50
+ image_data = base64.b64decode(data)
51
+ image = Image.open(BytesIO(image_data))
52
+ else:
53
+ image = Image.open(image)
54
+
55
+ # Convertir en RGB si nécessaire
56
+ if image.mode != 'RGB':
57
+ image = image.convert('RGB')
58
+
59
+ # Redimensionner
60
+ image = image.resize((224, 224))
61
+
62
+ # Normalisation standard
63
+ transform = transforms.Compose([
64
+ transforms.ToTensor(),
65
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
66
+ std=[0.229, 0.224, 0.225])
67
+ ])
68
+
69
+ tensor = transform(image).unsqueeze(0).to(self.device)
70
+ return tensor, np.array(image) / 255.0
71
+
72
+ def _generate_gradcam(self, image_tensor, rgb_img, model_type='swin'):
73
+ """Génère la carte de saillance GradCAM"""
74
+ try:
75
+ model = self.models[model_type]
76
+ target_layers = self.target_layers[model_type]
77
+
78
+ # Créer l'objet GradCAM
79
+ cam = GradCAM(model=model, target_layers=target_layers)
80
+
81
+ # Générer la carte de saillance
82
+ grayscale_cam = cam(input_tensor=image_tensor, targets=None)
83
+ grayscale_cam = grayscale_cam[0, :]
84
+
85
+ # Superposer sur l'image originale
86
+ cam_image = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)
87
+
88
+ return cam_image
89
+
90
+ except Exception as e:
91
+ print(f"Erreur GradCAM: {e}")
92
+ return rgb_img * 255
93
+
94
+ def predict_and_explain(self, image):
95
+ """Prédiction avec explication GradCAM"""
96
+ try:
97
+ # Prétraitement
98
+ image_tensor, rgb_img = self._preprocess_image(image)
99
+
100
+ # Prédiction
101
+ with torch.no_grad():
102
+ outputs = self.models['swin'](image_tensor)
103
+ probabilities = torch.nn.functional.softmax(outputs, dim=1)
104
+ confidence = probabilities.max().item()
105
+ prediction = probabilities.argmax().item()
106
+
107
+ # Génération GradCAM
108
+ cam_image = self._generate_gradcam(image_tensor, rgb_img)
109
+
110
+ # Résultats
111
+ class_names = ['Real', 'AI-Generated']
112
+ predicted_class = class_names[prediction]
113
+
114
+ result = {
115
+ 'prediction': prediction,
116
+ 'confidence': confidence,
117
+ 'predicted_class': predicted_class,
118
+ 'probabilities': {
119
+ 'Real': probabilities[0][0].item(),
120
+ 'AI-Generated': probabilities[0][1].item()
121
+ }
122
+ }
123
+
124
+ return cam_image.astype(np.uint8), result
125
+
126
+ except Exception as e:
127
+ return image, {'error': str(e)}
128
+
129
+ # Initialiser le détecteur
130
+ detector = AIDetectionGradCAM()
131
+
132
+ def analyze_image(image):
133
+ """Fonction pour l'interface Gradio"""
134
+ if image is None:
135
+ return None, "Veuillez télécharger une image"
136
+
137
+ try:
138
+ cam_image, result = detector.predict_and_explain(image)
139
+
140
+ if 'error' in result:
141
+ return image, f"Erreur: {result['error']}"
142
+
143
+ # Formatage du résultat
144
+ confidence_percent = result['confidence'] * 100
145
+ predicted_class = result['predicted_class']
146
+
147
+ analysis_text = f"""
148
+ ## 🔍 Analyse de l'image
149
+
150
+ **Prédiction:** {predicted_class}
151
+ **Confiance:** {confidence_percent:.1f}%
152
+
153
+ **Probabilités détaillées:**
154
+ - Real: {result['probabilities']['Real']:.3f}
155
+ - AI-Generated: {result['probabilities']['AI-Generated']:.3f}
156
+
157
+ La carte de saillance (GradCAM) montre les zones que le modèle considère comme importantes pour sa décision.
158
+ """
159
+
160
+ return cam_image, analysis_text
161
+
162
+ except Exception as e:
163
+ return image, f"Erreur lors de l'analyse: {str(e)}"
164
+
165
+ # Interface Gradio
166
+ with gr.Blocks(theme=gr.themes.Soft(), title="VerifAI - Détection d'images IA avec GradCAM") as app:
167
+ gr.Markdown("""
168
+ # 🔍 VerifAI - Détecteur d'images IA avec GradCAM
169
+
170
+ Téléchargez une image pour déterminer si elle a été générée par une IA.
171
+ L'application utilise GradCAM pour expliquer visuellement sa décision.
172
+ """)
173
+
174
+ with gr.Row():
175
+ with gr.Column():
176
+ input_image = gr.Image(
177
+ type="pil",
178
+ label="📸 Téléchargez votre image",
179
+ height=400
180
+ )
181
+ analyze_btn = gr.Button("🔍 Analyser l'image", variant="primary", size="lg")
182
+
183
+ with gr.Column():
184
+ output_image = gr.Image(
185
+ label="🎯 Carte de saillance GradCAM",
186
+ height=400
187
+ )
188
+ result_text = gr.Markdown(label="📊 Résultats de l'analyse")
189
+
190
+ analyze_btn.click(
191
+ fn=analyze_image,
192
+ inputs=[input_image],
193
+ outputs=[output_image, result_text]
194
+ )
195
+
196
+ gr.Markdown("""
197
+ ---
198
+ ### 💡 Comment interpréter les résultats
199
+
200
+ - **Real**: L'image semble être une vraie photo
201
+ - **AI-Generated**: L'image semble être générée par IA
202
+ - **Carte de saillance**: Les zones colorées indiquent les régions importantes pour la décision
203
+ """)
204
+
205
+ if __name__ == "__main__":
206
+ app.launch()