---
license: apache-2.0
base_model: microsoft/swinv2-large-patch4-window12to24-192to384-22kto1k-ft
model-index:
- name: THW
  results:
  - task:
      name: Image Classification
      type: image-classification
    dataset:
      name: None
      type: None
      config: None
      split: None
      args: None
    metrics:
    - name: None
      type: None
      value: None
---

# Normal1919/THW

This model is a fine-tuned version of [microsoft/swinv2-large-patch4-window12to24-192to384-22kto1k-ft](https://huggingface.co/microsoft/swinv2-large-patch4-window12to24-192to384-22kto1k-ft) on a private dataset.

# How to use

```python
import torch
import torchvision
import torchvision.transforms as transforms
from transformers import AutoModelForImageClassification
from matplotlib import pyplot as plt

model_name = "Normal1919/THW"
model = AutoModelForImageClassification.from_pretrained(model_name)
model.eval()
# model = torch.compile(model)  # optional; may speed up repeated inference

# Preprocessing used by this example: resize to 256x256, convert to a float
# tensor, then normalize with the per-channel mean/std below.
image_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.697, 0.633, 0.635], std=[0.3135, 0.320, 0.315]),
])

with torch.no_grad():
    # Decode directly to 3-channel RGB so grayscale, grayscale+alpha and RGBA
    # files all produce the 3 channels that Normalize above expects.
    image_raw = torchvision.io.read_image(
        "test_img/c9f00dbb7e8fe20538fcc71b1dc0fbb913029959.png",
        mode=torchvision.io.ImageReadMode.RGB,
    )
    pixel_values = image_transform(image_raw).unsqueeze(0)  # add batch dimension
    outputs = model(pixel_values=pixel_values)
    # Sigmoid maps each class logit to an independent score in [0, 1];
    # the scores are not normalized to sum to 1.
    probs = torch.sigmoid(outputs.logits)[0]

ind = probs.argmax().item()
print(model.config.id2label[ind])

# Pair every class name with its score, highest score first. The class count
# is taken from the model output instead of being hard-coded.
num_labels = probs.shape[0]
cha_names = [model.config.id2label[i] for i in range(num_labels)]
cha_probs = probs.numpy()
names_probs = sorted(zip(cha_names, cha_probs), key=lambda x: x[1], reverse=True)
print(names_probs)

top_k = 10
top = names_probs[:top_k]  # slicing is safe even when there are fewer than top_k classes
names_show = [name for name, _ in top]
probs_show = [prob for _, prob in top]

# SimHei is a Chinese font; it is selected so that non-Latin class names
# render in the chart (the font must be installed on the system).
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.figure(figsize=(12, 8))
plt.bar(names_show, probs_show)
plt.show()
```