Spaces:
Build error
Build error
| import random | |
| from dataclasses import dataclass, field | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import threestudio | |
| from threestudio.models.background.base import BaseBackground | |
| from threestudio.models.networks import get_encoding, get_mlp | |
| from threestudio.utils.ops import get_activation | |
| from threestudio.utils.typing import * | |
| class NeuralEnvironmentMapBackground(BaseBackground): | |
| class Config(BaseBackground.Config): | |
| n_output_dims: int = 3 | |
| color_activation: str = "sigmoid" | |
| dir_encoding_config: dict = field( | |
| default_factory=lambda: {"otype": "SphericalHarmonics", "degree": 3} | |
| ) | |
| mlp_network_config: dict = field( | |
| default_factory=lambda: { | |
| "otype": "VanillaMLP", | |
| "activation": "ReLU", | |
| "n_neurons": 16, | |
| "n_hidden_layers": 2, | |
| } | |
| ) | |
| random_aug: bool = False | |
| random_aug_prob: float = 0.5 | |
| eval_color: Optional[Tuple[float, float, float]] = None | |
| cfg: Config | |
| def configure(self) -> None: | |
| self.encoding = get_encoding(3, self.cfg.dir_encoding_config) | |
| self.network = get_mlp( | |
| self.encoding.n_output_dims, | |
| self.cfg.n_output_dims, | |
| self.cfg.mlp_network_config, | |
| ) | |
| def forward(self, dirs: Float[Tensor, "B H W 3"]) -> Float[Tensor, "B H W Nc"]: | |
| if not self.training and self.cfg.eval_color is not None: | |
| return torch.ones(*dirs.shape[:-1], self.cfg.n_output_dims).to( | |
| dirs | |
| ) * torch.as_tensor(self.cfg.eval_color).to(dirs) | |
| # viewdirs must be normalized before passing to this function | |
| dirs = (dirs + 1.0) / 2.0 # (-1, 1) => (0, 1) | |
| dirs_embd = self.encoding(dirs.view(-1, 3)) | |
| color = self.network(dirs_embd).view(*dirs.shape[:-1], self.cfg.n_output_dims) | |
| color = get_activation(self.cfg.color_activation)(color) | |
| if ( | |
| self.training | |
| and self.cfg.random_aug | |
| and random.random() < self.cfg.random_aug_prob | |
| ): | |
| # use random background color with probability random_aug_prob | |
| color = color * 0 + ( # prevent checking for unused parameters in DDP | |
| torch.rand(dirs.shape[0], 1, 1, self.cfg.n_output_dims) | |
| .to(dirs) | |
| .expand(*dirs.shape[:-1], -1) | |
| ) | |
| return color | |