# NSFW Image Classifier

This model classifies images as SFW (Safe For Work) or NSFW (Not Safe For Work).
## Model Details
- Model Architecture: convnext_tiny.fb_in22k_ft_in1k
- Input Size: 224x224
- Best Validation Accuracy: 0.9736
- Training Epochs: 10
- Batch Size: 32
- Learning Rate: 0.0001
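The hyperparameters above can be collected into a single config dict, which is convenient when reloading or retraining. This is a hedged reconstruction from the values listed on this card; the actual training script is not published here, so the dict name and structure are assumptions.

```python
# Training configuration reconstructed from the Model Details list above
# (assumption: the original training script may have organized this differently).
TRAIN_CONFIG = {
    "arch": "convnext_tiny.fb_in22k_ft_in1k",
    "input_size": (224, 224),
    "num_classes": 2,       # SFW / NSFW
    "epochs": 10,
    "batch_size": 32,
    "learning_rate": 1e-4,
}
```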
## Usage
### PyTorch

```python
import torch
import timm
from PIL import Image
from torchvision import transforms

# Build the architecture and load the fine-tuned weights
model = timm.create_model("convnext_tiny.fb_in22k_ft_in1k", pretrained=False, num_classes=2)
checkpoint = torch.load("pytorch_model.pt", map_location="cpu")
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()

# Standard ImageNet preprocessing at the model's 224x224 input size
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

image = Image.open("image.jpg").convert("RGB")
input_tensor = transform(image).unsqueeze(0)

with torch.no_grad():
    output = model(input_tensor)
probs = torch.softmax(output, dim=1)
label = "NSFW" if probs[0][1] > 0.5 else "SFW"
```
### ONNX Runtime

```python
import onnxruntime as ort
import numpy as np
from PIL import Image

session = ort.InferenceSession("model.onnx")

# Preprocess: resize to 224x224, scale to [0, 1], ImageNet-normalize, HWC -> NCHW float32
image = Image.open("image.jpg").convert("RGB").resize((224, 224))
x = (np.asarray(image, dtype=np.float32) / 255.0 - [0.485, 0.456, 0.406]) / [0.229, 0.224, 0.225]
x = x.transpose(2, 0, 1)[np.newaxis].astype(np.float32)

# Run inference; the output is logits of shape (1, 2)
logits = session.run(None, {session.get_inputs()[0].name: x})[0]
label = "NSFW" if logits[0].argmax() == 1 else "SFW"
```
## Labels
- 0: SFW (Safe For Work)
- 1: NSFW (Not Safe For Work)
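The index-to-label mapping above can be applied to any pair of class scores (logits or probabilities) with a simple argmax helper. This is a minimal stdlib-only sketch; the function name and the example scores are illustrative, not part of the model's API.

```python
# Map a pair of class scores to a label using the index mapping above.
LABELS = {0: "SFW", 1: "NSFW"}

def to_label(scores):
    # scores: [sfw_score, nsfw_score] from the classifier head
    return LABELS[max(range(len(scores)), key=lambda i: scores[i])]

print(to_label([0.9, 0.1]))  # SFW
print(to_label([0.2, 0.8]))  # NSFW
```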
## Training Dataset
- CaveduckAI/nsfw-sfw-img-dataset
## License

MIT