import numpy as np
import pytorch_lightning as pl
import torch
from sklearn.cluster import KMeans
class RBFNetwork(pl.LightningModule):
    """Radial-basis-function network fitted so that its output is close to 1 on data.

    The network evaluates

        h(x) = sum_k W[k] * exp(-0.5 * lamda[k] * ||x - C[k]||^2)

    where the centres ``C`` and the bandwidth terms ``lamda`` are derived from
    a K-means clustering performed in :meth:`on_train_start`, and only the
    weights ``W`` are trained (loss: ``mean((1 - h(x)) ** 2)``).

    ``C`` and ``lamda`` are plain tensor attributes created in
    :meth:`on_train_start`; :meth:`forward` cannot be called before that hook
    has run.
    """

    # Lower bound applied to the weights by ``on_before_zero_grad``.
    _MIN_WEIGHT = 0.0001
    # Lower bound for the per-cluster bandwidth so that a zero-variance
    # cluster does not produce an infinite ``lamda`` (and NaNs downstream).
    _MIN_SIGMA = 1e-8

    def __init__(
        self,
        current_timestep,
        next_timestep,
        n_centers: int = 100,
        kappa: float = 1.0,
        lr=1e-2,
        datamodule=None,
        image_data=False,
        args=None,
    ):
        """Store the configuration and create the trainable weights.

        Args:
            current_timestep: Stored on the instance; not used by this class.
            next_timestep: Stored on the instance; not used by this class.
            n_centers: Number of K-means clusters / RBF centres.
            kappa: Multiplier applied to the cluster bandwidths.
            lr: Learning rate for the Adam optimizer.
            datamodule: Stored on the instance. The training hook reads the
                data from ``self.trainer.datamodule`` instead.
            image_data: Selects how centres and bandwidths are computed in
                ``on_train_start``.
            args: Configuration object; ``args.data_type`` and
                ``args.branches`` are read when unpacking batches.
        """
        super().__init__()
        self.K = n_centers
        self.current_timestep = current_timestep
        self.next_timestep = next_timestep
        self.clustering_model = KMeans(n_clusters=self.K)
        self.kappa = kappa
        # Placeholder until the first validation step overwrites it with a tensor.
        self.last_val_loss = 1
        self.lr = lr
        self.W = torch.nn.Parameter(torch.rand(self.K, 1))
        self.datamodule = datamodule
        self.image_data = image_data
        self.args = args

    def on_before_zero_grad(self, *args, **kwargs):
        """Clamp the weights from below to keep them strictly positive."""
        self.W.data = torch.clamp(self.W.data, min=self._MIN_WEIGHT)

    def on_train_start(self):
        """Fit K-means on one training batch and derive ``C`` and ``lamda``."""
        with torch.no_grad():
            batch = next(iter(self.trainer.datamodule.train_dataloader()))
            metric_samples = batch[0]["metric_samples"][0]
            all_data = torch.cat(metric_samples)

        # scikit-learn and the statistics below operate on NumPy arrays;
        # convert once, explicitly, so this also works when the samples do
        # not live on the CPU.
        all_data = all_data.detach().cpu().numpy()

        print("Fitting Clustering model...")
        self.clustering_model.fit(all_data)
        labels = self.clustering_model.labels_

        if self.image_data:
            clusters = self.calculate_centroids(all_data, labels)
        else:
            clusters = self.clustering_model.cluster_centers_
        self.C = torch.tensor(clusters, dtype=torch.float32).to(self.device)

        sigmas = np.zeros((self.K, 1))
        for k in range(self.K):
            points = all_data[labels == k, :]
            variance = ((points - clusters[k]) ** 2).mean(axis=0)
            # Image data sums the per-feature variances, other data averages them.
            spread = variance.sum() if self.image_data else variance.mean()
            sigmas[k, :] = np.sqrt(spread)

        # A cluster whose points all coincide has sigma == 0, which would
        # make lamda infinite and turn the kernel into NaN at the centre.
        sigmas = np.maximum(sigmas, self._MIN_SIGMA)

        self.lamda = torch.tensor(
            0.5 / (self.kappa * sigmas) ** 2, dtype=torch.float32
        ).to(self.device)

    def forward(self, x):
        """Evaluate the network on ``x``.

        Inputs with more than two dimensions are flattened to
        ``(x.shape[0], -1)``. The kernel activations are also stored on
        ``self.phi_x``.

        Returns:
            Tensor with one row per input sample and a single column.
        """
        if x.dim() > 2:
            x = x.reshape(x.shape[0], -1)
        x = x.to(self.C.device)

        dist2 = torch.cdist(x, self.C) ** 2
        self.phi_x = torch.exp(-0.5 * self.lamda[None, :, :] * dist2[:, :, None])
        return (self.W.to(x.device) * self.phi_x).sum(dim=1)

    def _collect_inputs(self, batch, split_key):
        """Concatenate the start and end samples stored under ``split_key``."""
        if self.args.data_type in ("scrna", "tahoe"):
            main_batch = batch[0][split_key][0]
        else:
            main_batch = batch[split_key][0]

        x0 = main_batch["x0"][0]
        if self.args.branches == 1:
            endpoints = [main_batch["x1"][0]]
        else:
            endpoints = [main_batch["x1_1"][0], main_batch["x1_2"][0]]
        return torch.cat([x0, *endpoints], dim=0).to(self.device)

    def _reconstruction_loss(self, inputs):
        """Return the mean squared deviation of the network output from 1."""
        return ((1 - self.forward(inputs)) ** 2).mean()

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch."""
        inputs = self._collect_inputs(batch, "train_samples")
        loss = self._reconstruction_loss(inputs)
        self.log(
            "MetricModel/train_loss_learn_metric",
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
        )
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss and remember it for ``compute_metric``."""
        inputs = self._collect_inputs(batch, "val_samples")
        loss = self._reconstruction_loss(inputs)
        self.log(
            "MetricModel/val_loss_learn_metric",
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
        )
        self.last_val_loss = loss.detach()
        return loss

    def calculate_centroids(self, all_data, labels):
        """Return the mean of ``all_data`` for every distinct value in ``labels``.

        Rows of the result follow the sorted order of the distinct labels.
        """
        unique_labels = np.unique(labels)
        centroids = np.zeros((len(unique_labels), all_data.shape[1]))
        for i, label in enumerate(unique_labels):
            centroids[i] = all_data[labels == label].mean(axis=0)
        return centroids

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with learning rate ``self.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def compute_metric(self, x, alpha=1, epsilon=1e-2, image_hx=False):
        """Compute the metric value ``M(x)`` from the network output ``h(x)``.

        Args:
            x: Input passed to :meth:`forward`.
            alpha: Exponent applied in the denominator.
            epsilon: Additive term in the denominator. A negative value is
                replaced by ``(1 - last_val_loss) / abs(epsilon)``.
            image_hx: If true, ``h(x)`` is first mapped to
                ``1 - |1 - h(x)|`` and the result is
                ``1 / (h(x) ** alpha + epsilon)``; otherwise the result is
                ``1 / (h(x) + epsilon) ** alpha``.

        Returns:
            Tensor with the same shape as the output of :meth:`forward`.
        """
        if epsilon < 0:
            # ``last_val_loss`` is a plain number until the first validation
            # step stores a tensor; ``float`` handles both.
            epsilon = (1 - float(self.last_val_loss)) / abs(epsilon)

        h_x = self.forward(x)
        if image_hx:
            h_x = 1 - torch.abs(1 - h_x)
            return 1 / (h_x**alpha + epsilon)
        return 1 / (h_x + epsilon) ** alpha