# ---- Run settings for this demo -------------------------------------------
image_path: str = './image001.png'        # input image (converted to RGB on load)
sentence: str = 'spoon on the dish'       # referring expression to segment
weights: str = 'checkpoints/gradio.pth'   # checkpoint file passed to torch.load
device: str = 'cpu'                       # device the model and inputs are moved to
|
|
| |
| from PIL import Image |
| import torchvision.transforms as T |
| import numpy as np |
| import datetime |
| import os |
| import time |
|
|
| import torch |
| import torch.utils.data |
| from torch import nn |
|
|
| from bert.multimodal_bert import MultiModalBert |
| import torchvision |
|
|
| from lib import multimodal_segmentation_ppm |
| |
| import utils |
|
|
| import numpy as np |
| from PIL import Image |
| import torch.nn.functional as F |
|
|
| from modeling.MaskFormerModel import MaskFormerHead |
| from addict import Dict |
| |
| import cv2 |
| import textwrap |
|
|
class WrapperModel(nn.Module):
    """Bundle an image backbone, a language encoder and a segmentation head.

    The two encoders are advanced stage by stage in lock-step. After each of
    the four stages the ``forward_pwam{k}`` hooks exchange information between
    image and language features, and the image-side residual of every stage is
    collected into ``{'s1': ..., 's4': ...}`` for the ``classifier``.
    """

    def __init__(self, image_model, language_model, classifier, num_classes=1):
        """Store the sub-modules.

        Args:
            image_model: backbone exposing ``forward_stem``,
                ``forward_stage{1..4}`` and ``forward_pwam{1..4}``.
            language_model: text encoder exposing the same stage methods.
            classifier: head called with the dict of stage features.
            num_classes: number of foreground classes, used by
                ``_get_binary_mask``.
        """
        super().__init__()
        self.image_model = image_model
        self.language_model = language_model
        self.classifier = classifier
        self.num_classes = num_classes

    def _get_binary_mask(self, target):
        """Turn a 2-D integer label map into one binary mask per class.

        Label 0 is treated as background and dropped, so the result has shape
        ``(num_classes, H, W)``; channel ``k`` is 1.0 where ``target == k + 1``.
        """
        y, x = target.size()
        target_onehot = torch.zeros(self.num_classes + 1, y, x, device=target.device)
        # scatter needs an int64 index tensor with the same rank as the output.
        target_onehot = target_onehot.scatter(
            dim=0, index=target.unsqueeze(0).long(), value=1)
        return target_onehot[1:]

    def semantic_inference(self, mask_cls, mask_pred):
        """Weight per-query mask probabilities by per-query class scores.

        Args:
            mask_cls: class logits, indexed as (batch, query, class) by the
                einsum below.
            mask_pred: mask logits, indexed as (batch, query, H, W).

        Returns:
            Tensor of shape (batch, class - 1, H, W).
        """
        # NOTE(review): reference MaskFormer code takes the softmax over the
        # class axis (dim=-1) and drops the trailing "no object" column. Here
        # the softmax runs over dim=1 (the query axis per the einsum) and the
        # *first* column is dropped; with a single query this softmax is
        # identically 1. Left unchanged because the checkpoint may depend on
        # it -- confirm against the training code.
        mask_cls = F.softmax(mask_cls, dim=1)[..., 1:]
        mask_pred = mask_pred.sigmoid()
        return torch.einsum("bqc,bqhw->bchw", mask_cls, mask_pred)

    def forward(self, image, sentences, attentions):
        """Run both encoders with per-stage fusion, then the classifier.

        Args:
            image: image batch passed to ``image_model.forward_stem``.
            sentences: token ids passed to ``language_model.forward_stem``.
            attentions: attention mask aligned with ``sentences``.

        Returns:
            Whatever ``classifier`` returns for ``{'s1': ..., 's4': ...}``.
        """
        l_mask = attentions.unsqueeze(dim=-1)

        img_feat, Wh, Ww = self.image_model.forward_stem(image)
        lang_feat, extended_attention_mask = self.language_model.forward_stem(
            sentences, attentions)

        outputs = {}
        for stage in range(1, 5):
            img_stage = getattr(self.image_model, f"forward_stage{stage}")
            lang_stage = getattr(self.language_model, f"forward_stage{stage}")
            img_pwam = getattr(self.image_model, f"forward_pwam{stage}")
            lang_pwam = getattr(self.language_model, f"forward_pwam{stage}")

            img_feat = img_stage(img_feat, Wh, Ww)
            lang_feat = lang_stage(lang_feat, extended_attention_mask)
            # Image PWAM returns (residual, H, W, next_input, Wh, Ww); H and W
            # are not needed here.
            residual, _, _, next_img_feat, Wh, Ww = img_pwam(
                img_feat, Wh, Ww, lang_feat, l_mask)
            # The language PWAM must see this stage's image features *before*
            # they are replaced by next_img_feat.
            _, lang_feat = lang_pwam(img_feat, lang_feat, extended_attention_mask)
            img_feat = next_img_feat
            outputs[f"s{stage}"] = residual

        return self.classifier(outputs)
|
|
| |
# Load the input image. Image.open keeps the file handle open lazily; the
# context manager closes it, while convert() returns an independent image.
with Image.open(image_path) as _raw_img:
    img = _raw_img.convert("RGB")
img_ndarray = np.array(img)  # (H, W, 3) uint8 copy used for the final overlay
original_w, original_h = img.size  # PIL reports (width, height)

# Network input: resize to 480x480, scale to [0, 1], normalize per channel.
image_transforms = T.Compose(
    [
        T.Resize((480, 480)),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Add a batch dimension -> (1, 3, 480, 480) and move to the target device.
img = image_transforms(img).unsqueeze(0).to(device)
|
|
| |
# Tokenize the referring expression into fixed-length token ids plus an
# attention mask, each an int64 tensor of shape (1, MAX_TOKENS).
from bert.tokenization_bert import BertTokenizer
import torch

MAX_TOKENS = 20   # fixed text length fed to the model
PAD_TOKEN_ID = 0  # id written into the padding positions

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
sentence_tokenized = tokenizer.encode(text=sentence, add_special_tokens=True)
# NOTE(review): truncating can cut off a trailing special token (presumably
# [SEP], added by add_special_tokens=True) for long sentences; kept as-is so
# inputs match what the model is assumed to have been trained with.
sentence_tokenized = sentence_tokenized[:MAX_TOKENS]
num_padding = MAX_TOKENS - len(sentence_tokenized)

# Real tokens first, then padding; the mask is 1 for real tokens, 0 for padding.
padded_sent_toks = torch.tensor(
    [sentence_tokenized + [PAD_TOKEN_ID] * num_padding], device=device)
attention_mask = torch.tensor(
    [[1] * len(sentence_tokenized) + [0] * num_padding], device=device)
|
|
| |
| |
| |
|
|
| |
|
|
|
|
class args:
    """Option namespace handed to ``lavt(..., args=args)`` when building the backbone."""

    swin_type = 'base'   # Swin transformer variant name
    window12 = True      # backbone flag (presumably window size 12 -- confirm)
    mha = ''             # multi-head attention option, left empty here
    fusion_drop = 0.0    # dropout rate passed to the fusion modules
|
|
|
|
| |
# Build the LAVT model (its ``backbone`` is what gets wrapped later) and the
# multimodal BERT text encoder, sized to the backbone's embedding dimension.
single_model = multimodal_segmentation_ppm.lavt(args=args, pretrained='')
single_model.to(device)

model_class = MultiModalBert
single_bert_model = model_class.from_pretrained(
    'bert-base-uncased',
    embed_dim=single_model.backbone.embed_dim,
)
# The BERT pooler head is removed.
single_bert_model.pooler = None
|
|
# Per-stage feature description for the head: channel width and stride of
# each backbone output 's1'..'s4'.
input_shape = {
    's1': Dict({'channel': 128, 'stride': 4}),
    's2': Dict({'channel': 256, 'stride': 8}),
    's3': Dict({'channel': 512, 'stride': 16}),
    's4': Dict({'channel': 1024, 'stride': 32}),
}

# Head configuration as a single nested literal; addict turns the nested
# dicts into Dicts so attribute access (cfg.MODEL.SEM_SEG_HEAD...) works.
cfg = Dict({
    'MODEL': {
        'SEM_SEG_HEAD': {
            'COMMON_STRIDE': 4,
            'TRANSFORMER_ENC_LAYERS': 4,
            'CONVS_DIM': 256,
            'MASK_DIM': 256,
            'DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES': ["s1", "s2", "s3", "s4"],
            'NUM_CLASSES': 1,
        },
        'MASK_FORMER': {
            'DROPOUT': 0.0,
            'NHEADS': 8,
            'HIDDEN_DIM': 256,
            'NUM_OBJECT_QUERIES': 1,
            'DIM_FEEDFORWARD': 2048,
            'DEC_LAYERS': 10,
            'PRE_NORM': False,
        },
    },
})

maskformer_head = MaskFormerHead(cfg, input_shape)
|
|
|
|
# Combine backbone, text encoder and head into one module.
model = WrapperModel(single_model.backbone, single_bert_model, maskformer_head)

# NOTE(security): torch.load unpickles the file and can execute arbitrary
# code -- only load checkpoints from a trusted source.
checkpoint = torch.load(weights, map_location='cpu')

# strict=False tolerates key mismatches, but parameters whose keys are missing
# keep their initial values; report mismatches instead of ignoring them.
load_result = model.load_state_dict(checkpoint['model'], strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(
        f"load_state_dict: {len(load_result.missing_keys)} missing key(s), "
        f"{len(load_result.unexpected_keys)} unexpected key(s) in {weights}"
    )
    for missing_key in load_result.missing_keys[:10]:
        print("  missing:", missing_key)
    for unexpected_key in load_result.unexpected_keys[:10]:
        print("  unexpected:", unexpected_key)

model.to(device)
model.eval()
| |
| |
| |
| |
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
# Inference only: disable autograd so no computation graph is kept alive.
with torch.no_grad():
    prediction = model(img, padded_sent_toks, attention_mask)[0]
    mask_cls_results = prediction["pred_logits"]
    mask_pred_results = prediction["pred_masks"]

    # (H, W) of the original image, the size of the final mask.
    target_shape = img_ndarray.shape[:2]

    # Upsample the mask logits to the network input resolution (taken from
    # the input tensor rather than repeating the resize constant).
    mask_pred_results = F.interpolate(
        mask_pred_results, size=img.shape[-2:], mode='bilinear', align_corners=True)

    pred_masks = model.semantic_inference(mask_cls_results, mask_pred_results)

    # Resize to the original image size (interpolate's default 'nearest'
    # mode) and binarize at 0.5 into a boolean numpy array.
    output = F.interpolate(pred_masks, target_shape)
    output = (output > 0.5).cpu().numpy()
|
|
|
|
| |
def overlay_davis(image, mask, colors=None, cscale=1, alpha=0.4):
    """Blend colored object masks onto an image and outline each object in black.

    Args:
        image: (H, W, 3) array; the result is cast back to ``image.dtype``.
        mask: (H, W) integer label map. 0 is background; every other value is
            used as an index into ``colors``.
        colors: palette of RGB triples (anything reshapeable to (-1, 3)).
            Defaults to black for label 0 and red for label 1.
        cscale: multiplier applied to the palette.
        alpha: weight of the original pixel inside a mask; the palette color
            gets weight ``1 - alpha``.

    Returns:
        A new array with the overlay; ``image`` itself is not modified.
    """
    # scipy.ndimage.morphology is a deprecated namespace; use scipy.ndimage.
    from scipy.ndimage import binary_dilation

    if colors is None:
        colors = [[0, 0, 0], [255, 0, 0]]
    colors = np.atleast_2d(np.reshape(colors, (-1, 3))) * cscale

    im_overlay = image.copy()
    object_ids = np.unique(mask)

    # Skip background explicitly: slicing off the first unique value would
    # drop a real object whenever the mask contains no background pixels.
    for object_id in object_ids[object_ids != 0]:
        binary_mask = mask == object_id

        # Blend only the pixels that belong to this object.
        color = np.array(colors[object_id])
        im_overlay[binary_mask] = image[binary_mask] * alpha + (1 - alpha) * color

        # One-pixel ring just outside the object, drawn in black.
        contours = binary_dilation(binary_mask) ^ binary_mask
        im_overlay[contours, :] = 0

    return im_overlay.astype(image.dtype)
|
|
|
|
# Labels 0/1 for the first batch item and first class, blended onto the
# original image.
output = output.astype(np.uint8)
visualization = Image.fromarray(overlay_davis(img_ndarray, output[0, 0]))

# Image.save does not create missing directories.
# NOTE: the output file name is fixed and does not follow `sentence`.
os.makedirs('./demo', exist_ok=True)
visualization.save('./demo/spoon_on_the_dish.jpg')
|
|
|
|
|
|
|
|
|
|