import torch
import torch.nn as nn
import segmentation_models_pytorch as smp


class UNet(nn.Module):
    """
    UNet model for multi-class segmentation.

    Designed for multi-spectral input images (e.g., 13 Sentinel-2 bands) and
    multiple output classes. This is a thin wrapper around ``smp.Unet`` that
    adds optional encoder freezing and a ``predict`` convenience method.
    """

    def __init__(self,
                 encoder_name='tu-regnetz_d8',
                 encoder_weights=None,
                 in_channels=13,     # Number of input channels (13 for Sentinel-2 multi-spectral images)
                 num_classes=4,      # Number of output classes (e.g., clear, thick cloud, thin cloud, cloud shadow)
                 freeze_encoder=False):  # Whether to freeze the encoder's weights
        """
        Args:
            encoder_name (str): Encoder backbone name forwarded to ``smp.Unet``.
            encoder_weights (str or None): Weights for the encoder, typically
                'imagenet' or None.
            in_channels (int): Number of input channels (e.g., 13 for
                Sentinel-2 images).
            num_classes (int): Number of output classes (e.g., 4 for clear,
                cloud types, and shadow).
            freeze_encoder (bool): If True, sets ``requires_grad = False`` on
                every encoder parameter so the optimizer does not update them.
                Buffers (e.g. normalization running statistics, if the encoder
                has any) are not parameters and are still updated whenever the
                model is in train mode.
        """
        super().__init__()
        self.unet = smp.Unet(
            encoder_name=encoder_name,
            encoder_weights=encoder_weights,
            in_channels=in_channels,
            classes=num_classes,
        )

        if freeze_encoder:
            for param in self.unet.encoder.parameters():
                param.requires_grad = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the model.

        Args:
            x (torch.Tensor): Input tensor of shape (B, in_channels, H, W).

        Returns:
            torch.Tensor: Output logits of shape (B, num_classes, H, W).
        """
        return self.unet(x)

    @torch.no_grad()
    def predict(self, x: torch.Tensor) -> torch.Tensor:
        """
        Predicts multi-class segmentation labels for each pixel in the input image.

        Runs the forward pass in eval mode without gradient tracking, then
        restores whatever train/eval mode the model was in before the call.

        Args:
            x (torch.Tensor): Input tensor of shape (B, in_channels, H, W).

        Returns:
            torch.Tensor: Predicted labels of shape (B, H, W).
        """
        was_training = self.training
        self.eval()
        try:
            logits = self.forward(x)  # (B, num_classes, H, W)
        finally:
            # Restore the caller's mode so calling predict() in the middle of a
            # training loop does not silently leave the model in eval mode.
            self.train(was_training)
        # Softmax is monotonic along dim=1, so argmax over the raw logits gives
        # the same labels without an extra full-tensor pass.
        return logits.argmax(dim=1)  # (B, H, W)