Spaces:
Build error
Build error
| import gradio as gr | |
| import torch | |
| import torchvision.transforms as transforms | |
| from PIL import Image | |
| from pytorch_grad_cam import GradCAM | |
| from pytorch_grad_cam.utils.image import show_cam_on_image | |
| from transformers import AutoModel, AutoFeatureExtractor | |
| import timm | |
| import numpy as np | |
| import json | |
| import base64 | |
| from io import BytesIO | |
class AIDetectionGradCAM:
    """Classify an image as real vs. AI-generated and explain it with Grad-CAM.

    Uses a timm Swin-B model with a 2-class head and ``pytorch_grad_cam.GradCAM``
    on the last Swin block to build a saliency overlay.

    NOTE(review): ``timm.create_model(..., pretrained=True, num_classes=2)``
    loads pretrained backbone weights, but timm re-initializes a classifier
    head whose size differs from the checkpoint. Unless fine-tuned weights are
    loaded somewhere not shown here, the Real/AI-Generated scores are not
    meaningful — confirm.
    """

    def __init__(self):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Loaded models keyed by model type (only 'swin' is populated).
        self.models = {}
        # Kept for backward compatibility. Nothing is loaded into it because
        # preprocessing is done manually in _preprocess_image.
        self.feature_extractors = {}
        # Grad-CAM target layers, keyed like self.models.
        self.target_layers = {}
        # Load the models.
        self._load_models()

    def _load_models(self):
        """Load the Swin model and its Grad-CAM target layer.

        Errors are printed rather than raised so the app can still start. In
        that case ``self.models`` stays empty and predictions report an error.
        The model is registered only once it is fully initialized (eval mode,
        on ``self.device``), so a half-configured model is never used.
        """
        try:
            model = timm.create_model(
                'swin_base_patch4_window7_224', pretrained=True, num_classes=2
            )
            # Inference mode (disables dropout etc.) on the target device.
            model.eval()
            model.to(self.device)
            # Grad-CAM target: first norm layer of the last block of the last stage.
            target_layers = [model.layers[-1].blocks[-1].norm1]
        except Exception as e:
            print(f"Erreur lors du chargement des modèles: {e}")
            return
        self.models['swin'] = model
        self.target_layers['swin'] = target_layers

    @staticmethod
    def _swin_reshape_transform(tensor):
        """Convert Swin block activations to the (B, C, H, W) layout Grad-CAM expects.

        Grad-CAM treats dimension 1 as channels and pools over the last two
        dimensions. Depending on the timm version, Swin block activations are
        either token sequences ``(B, L, C)`` on a square ``L = H * W`` grid, or
        channels-last maps ``(B, H, W, C)``. Both are converted to
        ``(B, C, H, W)``.
        """
        if tensor.dim() == 3:
            batch, tokens, channels = tensor.shape
            side = int(round(tokens ** 0.5))
            tensor = tensor.reshape(batch, side, side, channels)
        return tensor.permute(0, 3, 1, 2)

    def _preprocess_image(self, image, model_type='swin'):
        """Turn an input image into a normalized model tensor plus an RGB array.

        Args:
            image: A PIL image, a ``data:image/...;base64,`` URI string, or a
                file path string.
            model_type: Unused; accepted for symmetry with the other methods.

        Returns:
            ``(tensor, rgb)``. ``tensor`` has shape (1, 3, 224, 224), lives on
            ``self.device`` and is normalized with ImageNet mean/std. ``rgb`` is
            the resized image as a (224, 224, 3) float array scaled to [0, 1].
        """
        if isinstance(image, str):
            # Either a base64 data URI or a filesystem path.
            if image.startswith('data:image'):
                # Decode the base64 payload after the "data:...;base64," header.
                header, data = image.split(',', 1)
                image_data = base64.b64decode(data)
                image = Image.open(BytesIO(image_data))
            else:
                image = Image.open(image)
        # Convert to RGB if needed (e.g. RGBA, grayscale, palette images).
        if image.mode != 'RGB':
            image = image.convert('RGB')
        # Resize to the model's 224x224 input resolution.
        image = image.resize((224, 224))
        # Standard ImageNet normalization.
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        tensor = transform(image).unsqueeze(0).to(self.device)
        return tensor, np.array(image) / 255.0

    def _generate_gradcam(self, image_tensor, rgb_img, model_type='swin'):
        """Return the Grad-CAM heatmap blended over ``rgb_img``.

        Args:
            image_tensor: Preprocessed input batch from ``_preprocess_image``.
            rgb_img: Float RGB image in [0, 1] used as the overlay background.
            model_type: Key into ``self.models`` / ``self.target_layers``.

        Returns:
            The overlay produced by ``show_cam_on_image``. On any failure the
            error is printed and ``rgb_img * 255`` is returned instead.
        """
        try:
            model = self.models[model_type]
            target_layers = self.target_layers[model_type]
            # Swin activations are token-shaped and need converting to NCHW.
            reshape_transform = (
                self._swin_reshape_transform if model_type == 'swin' else None
            )
            # The context manager releases the hooks Grad-CAM registers on the
            # shared model as soon as this request is done, instead of relying
            # on garbage collection.
            with GradCAM(model=model,
                         target_layers=target_layers,
                         reshape_transform=reshape_transform) as cam:
                # targets=None -> explain the highest-scoring class.
                grayscale_cam = cam(input_tensor=image_tensor, targets=None)[0, :]
            # Overlay the saliency map on the original image.
            return show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)
        except Exception as e:
            print(f"Erreur GradCAM: {e}")
            return rgb_img * 255

    def predict_and_explain(self, image):
        """Classify ``image`` and build its Grad-CAM explanation.

        Args:
            image: Anything accepted by ``_preprocess_image``.

        Returns:
            ``(cam_image, result)``. ``cam_image`` is a uint8 overlay array.
            ``result`` holds ``prediction`` (class index), ``confidence`` (max
            softmax probability), ``predicted_class`` and per-class
            ``probabilities``. On error, returns ``(image, {'error': message})``.
        """
        try:
            # Preprocessing.
            image_tensor, rgb_img = self._preprocess_image(image)
            # Prediction (no gradients needed; Grad-CAM runs its own pass).
            with torch.no_grad():
                outputs = self.models['swin'](image_tensor)
                probabilities = torch.nn.functional.softmax(outputs, dim=1)
                confidence = probabilities.max().item()
                prediction = probabilities.argmax().item()
            # Grad-CAM explanation.
            cam_image = self._generate_gradcam(image_tensor, rgb_img)
            # Assemble the result; index order matches the model's two outputs.
            class_names = ['Real', 'AI-Generated']
            predicted_class = class_names[prediction]
            result = {
                'prediction': prediction,
                'confidence': confidence,
                'predicted_class': predicted_class,
                'probabilities': {
                    'Real': probabilities[0][0].item(),
                    'AI-Generated': probabilities[0][1].item()
                }
            }
            return cam_image.astype(np.uint8), result
        except Exception as e:
            return image, {'error': str(e)}
# Single detector instance shared by every Gradio request; its constructor
# loads the model, so this runs once when the module is imported.
detector: AIDetectionGradCAM = AIDetectionGradCAM()
def analyze_image(image):
    """Gradio callback: run the detector and format its result as Markdown.

    Args:
        image: The uploaded image (a PIL image from the ``gr.Image(type="pil")``
            input), or None when nothing was uploaded.

    Returns:
        An ``(image, markdown)`` pair for the output image and the results
        component. On success the image is the Grad-CAM overlay. On error the
        input image is echoed back together with an error message.
    """
    if image is None:
        return None, "Veuillez télécharger une image"
    try:
        cam_image, result = detector.predict_and_explain(image)
        if 'error' in result:
            return image, f"Erreur: {result['error']}"
        # Format the result.
        confidence_percent = result['confidence'] * 100
        predicted_class = result['predicted_class']
        probabilities = result['probabilities']
        # Blank lines separate the Markdown blocks. Without them the bold fields
        # merge into a single paragraph, and the closing sentence is read as a
        # lazy continuation of the last list item.
        analysis_text = "\n".join([
            "## 🔍 Analyse de l'image",
            "",
            f"**Prédiction:** {predicted_class}",
            "",
            f"**Confiance:** {confidence_percent:.1f}%",
            "",
            "**Probabilités détaillées:**",
            "",
            f"- Real: {probabilities['Real']:.3f}",
            f"- AI-Generated: {probabilities['AI-Generated']:.3f}",
            "",
            "La carte de saillance (GradCAM) montre les zones que le modèle "
            "considère comme importantes pour sa décision.",
        ])
        return cam_image, analysis_text
    except Exception as e:
        return image, f"Erreur lors de l'analyse: {str(e)}"
# Gradio interface: upload column on the left, Grad-CAM overlay and results on
# the right, and an interpretation guide below. Component placement follows
# statement order inside the nested `with` contexts.
with gr.Blocks(theme=gr.themes.Soft(), title="VerifAI - Détection d'images IA avec GradCAM") as app:
    # Page header.
    gr.Markdown("""
# 🔍 VerifAI - Détecteur d'images IA avec GradCAM
Téléchargez une image pour déterminer si elle a été générée par une IA.
L'application utilise GradCAM pour expliquer visuellement sa décision.
""")
    with gr.Row():
        # Left column: image input and the button that triggers the analysis.
        with gr.Column():
            input_image = gr.Image(
                type="pil",  # analyze_image receives a PIL image
                label="📸 Téléchargez votre image",
                height=400
            )
            analyze_btn = gr.Button("🔍 Analyser l'image", variant="primary", size="lg")
        # Right column: Grad-CAM overlay and the Markdown summary.
        with gr.Column():
            output_image = gr.Image(
                label="🎯 Carte de saillance GradCAM",
                height=400
            )
            result_text = gr.Markdown(label="📊 Résultats de l'analyse")
    # analyze_image returns (overlay image, markdown text), matching `outputs`.
    analyze_btn.click(
        fn=analyze_image,
        inputs=[input_image],
        outputs=[output_image, result_text]
    )
    # Guide for interpreting the results.
    gr.Markdown("""
---
### 💡 Comment interpréter les résultats
- **Real**: L'image semble être une vraie photo
- **AI-Generated**: L'image semble être générée par IA
- **Carte de saillance**: Les zones colorées indiquent les régions importantes pour la décision
""")
def _main():
    """Start the Gradio server for the VerifAI app."""
    app.launch()


# Launch only when executed as a script, not when imported.
if __name__ == "__main__":
    _main()