| import typing as tp |
| import numpy as np |
| import torch |
| from torch import nn |
|
|
|
|
|
|
|
|
|
|
class EncodecModel(nn.Module):
    """Decode discrete codes back to a signal through a quantizer and a decoder.

    The model stores a ``quantizer`` (whose ``decode`` maps codes to a
    continuous latent) and a ``decoder`` (called on that latent). It also
    offers optional volume renormalization through ``preprocess`` and
    ``postprocess``.
    """

    def __init__(self,
                 decoder=None,
                 quantizer=None,
                 frame_rate=None,
                 sample_rate=None,
                 channels=None,
                 causal: bool = False,
                 renormalize: bool = False):
        """Store the sub-modules and the configuration values.

        Args:
            decoder: Callable applied to the latent returned by the quantizer.
            quantizer: Object providing ``decode``, ``total_codebooks``,
                ``num_codebooks``, ``set_num_codebooks`` and ``bins``.
            frame_rate: Stored as-is on the instance.
            sample_rate: Stored as-is on the instance.
            channels: Stored as-is on the instance.
            causal: Whether the model is causal.
            renormalize: Whether ``preprocess`` rescales its input by volume.

        Raises:
            AssertionError: If both ``causal`` and ``renormalize`` are truthy.
        """
        super().__init__()
        self.decoder = decoder
        self.quantizer = quantizer
        self.frame_rate = frame_rate
        self.sample_rate = sample_rate
        self.channels = channels
        self.renormalize = renormalize
        self.causal = causal
        # Raised explicitly (rather than via ``assert``) so the check is not
        # stripped when Python runs with -O; the exception type is unchanged.
        if self.causal and self.renormalize:
            raise AssertionError('Causal model does not support renormalize')

    @property
    def total_codebooks(self):
        """Total number of quantizer codebooks available."""
        return self.quantizer.total_codebooks

    @property
    def num_codebooks(self):
        """Active number of codebooks used by the quantizer."""
        return self.quantizer.num_codebooks

    def set_num_codebooks(self, n):
        """Set the active number of codebooks used by the quantizer."""
        self.quantizer.set_num_codebooks(n)

    @property
    def cardinality(self):
        """Cardinality of each codebook."""
        return self.quantizer.bins

    def preprocess(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
        """Optionally normalize ``x`` by its root-mean-square volume.

        Args:
            x: Input tensor. It is averaged over dim 1 and then over dim 2, so
                it needs at least three dimensions (presumably
                ``[batch, channels, time]`` -- confirm against callers).

        Returns:
            A tuple ``(x, scale)``. When ``renormalize`` is disabled, ``x`` is
            returned untouched and ``scale`` is ``None``. Otherwise ``x`` is
            divided by the scale and ``scale`` is reshaped to two dimensions
            with a trailing size of 1.
        """
        if not self.renormalize:
            return x, None

        mono = x.mean(dim=1, keepdim=True)
        volume = mono.pow(2).mean(dim=2, keepdim=True).sqrt()
        # Small offset keeps the division below finite for all-zero input.
        scale: tp.Optional[torch.Tensor] = 1e-8 + volume
        x = x / scale
        return x, scale.view(-1, 1)

    def postprocess(self,
                    x: torch.Tensor,
                    scale: tp.Optional[torch.Tensor] = None) -> torch.Tensor:
        """Undo the scaling applied by ``preprocess``.

        Args:
            x: Tensor to rescale.
            scale: Scale returned by ``preprocess``, or ``None`` to leave
                ``x`` unchanged.

        Returns:
            ``x`` multiplied by ``scale`` viewed as ``(-1, 1, 1)``, or ``x``
            itself when ``scale`` is ``None``.

        Raises:
            AssertionError: If ``scale`` is given while ``renormalize`` is
                disabled.
        """
        if scale is None:
            return x
        # Explicit raise so the consistency check survives ``python -O``.
        if not self.renormalize:
            raise AssertionError('scale was provided but renormalize is disabled')
        return x * scale.view(-1, 1, 1)

    def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None) -> torch.Tensor:
        """Decode discrete codes into the decoder output.

        Args:
            codes: Codes passed to the quantizer's ``decode``.
            scale: Optional scale forwarded to ``postprocess``.

        Returns:
            The decoder output, rescaled by ``scale`` when one is provided.
        """
        emb = self.decode_latent(codes)
        out = self.decoder(emb)
        return self.postprocess(out, scale)

    def decode_latent(self, codes: torch.Tensor):
        """Decode from the discrete codes to continuous latent space."""
        return self.quantizer.decode(codes)