# sophiat44
# model upload
# 5a87d8d
import pytorch_lightning as pl
import torch
from sklearn.cluster import KMeans
import numpy as np
class RBFNetwork(pl.LightningModule):
    """Radial-basis-function network ``h(x)`` fitted to be close to 1 on data.

    The network evaluates
    ``h(x) = sum_k W_k * exp(-0.5 * lamda_k * ||x - C_k||^2)``
    where the centers ``C`` and bandwidths ``lamda`` are obtained from a
    K-means fit in :meth:`on_train_start` and only the weights ``W`` are
    trained. Both training and validation minimise ``mean((1 - h(x)) ** 2)``.
    :meth:`compute_metric` turns ``h(x)`` into ``1 / (h(x) + epsilon) ** alpha``.

    NOTE(review): ``C`` and ``lamda`` are stored as plain tensor attributes,
    not as registered buffers, so they do not appear to be saved in the
    module's ``state_dict`` -- confirm how checkpoints are restored.
    """

    def __init__(
        self,
        current_timestep,
        next_timestep,
        n_centers: int = 100,
        kappa: float = 1.0,
        lr=1e-2,
        datamodule=None,
        image_data=False,
        args=None,
    ):
        """Store the configuration and create the trainable weights.

        Args:
            current_timestep: Stored on the instance; not used in this class.
            next_timestep: Stored on the instance; not used in this class.
            n_centers: Number of K-means clusters / RBF centers ``K``.
            kappa: Scale applied to the per-cluster spread when computing
                the bandwidths ``lamda``.
            lr: Learning rate handed to Adam.
            datamodule: Stored on the instance; the fitting data itself is
                read from ``self.trainer.datamodule``.
            image_data: Selects the centroid and spread computation used in
                :meth:`on_train_start`.
            args: Configuration object; ``data_type`` and ``branches`` are
                read from it when unpacking batches.
        """
        super().__init__()
        self.K = n_centers
        self.current_timestep = current_timestep
        self.next_timestep = next_timestep
        self.clustering_model = KMeans(n_clusters=self.K)
        self.kappa = kappa
        # Plain number until the first validation step replaces it with a
        # detached loss tensor.
        self.last_val_loss = 1
        self.lr = lr
        self.W = torch.nn.Parameter(torch.rand(self.K, 1))
        self.datamodule = datamodule
        self.image_data = image_data
        self.args = args

    def on_before_zero_grad(self, *args, **kwargs):
        """Clamp the weights ``W`` from below at ``1e-4``."""
        self.W.data = torch.clamp(self.W.data, min=0.0001)

    def on_train_start(self):
        """Fit K-means on the first training batch and derive ``C`` and ``lamda``.

        Sets ``self.C`` (cluster centers) and ``self.lamda``
        (``0.5 / (kappa * sigma_k) ** 2`` per cluster), both as float32
        tensors on ``self.device``.
        """
        with torch.no_grad():
            # Only the first batch of the training dataloader is used.
            batch = next(iter(self.trainer.datamodule.train_dataloader()))
            metric_samples = batch[0]["metric_samples"][0]
            all_data = torch.cat(metric_samples)

            print("Fitting Clustering model...")
            self.clustering_model.fit(all_data)
            labels = self.clustering_model.labels_

            if self.image_data:
                clusters = self.calculate_centroids(all_data, labels)
            else:
                clusters = self.clustering_model.cluster_centers_
            self.C = torch.tensor(clusters, dtype=torch.float32).to(self.device)

            # NOTE(review): the arithmetic below mixes a torch tensor with
            # NumPy arrays and presumably relies on the data living on the
            # CPU -- confirm. A cluster with zero spread would yield an
            # infinite ``lamda``; no guard is applied here.
            sigmas = np.zeros((self.K, 1))
            for k in range(self.K):
                points = all_data[labels == k, :]
                variance = ((points - clusters[k]) ** 2).mean(axis=0)
                spread = variance.sum() if self.image_data else variance.mean()
                sigmas[k, :] = np.sqrt(spread)

            self.lamda = torch.tensor(
                0.5 / (self.kappa * sigmas) ** 2, dtype=torch.float32
            ).to(self.device)

    def forward(self, x):
        """Evaluate ``h(x)`` for a batch of points.

        Inputs with more than two dimensions are flattened to
        ``(x.shape[0], -1)``. The per-center activations are also kept on
        ``self.phi_x``.

        Returns:
            Tensor obtained by summing ``W * phi_x`` over the center axis.
        """
        if x.dim() > 2:
            x = x.reshape(x.shape[0], -1)
        x = x.to(self.C.device)
        dist2 = torch.cdist(x, self.C) ** 2
        self.phi_x = torch.exp(-0.5 * self.lamda[None, :, :] * dist2[:, :, None])
        return (self.W.to(x.device) * self.phi_x).sum(dim=1)

    def _collect_inputs(self, batch, split_key):
        """Concatenate the start and end samples found under ``split_key``."""
        if self.args.data_type in ("scrna", "tahoe"):
            main_batch = batch[0][split_key][0]
        else:
            main_batch = batch[split_key][0]

        pieces = [main_batch["x0"][0]]
        if self.args.branches == 1:
            pieces.append(main_batch["x1"][0])
        else:
            pieces.append(main_batch["x1_1"][0])
            pieces.append(main_batch["x1_2"][0])
        return torch.cat(pieces, dim=0).to(self.device)

    def training_step(self, batch, batch_idx):
        """Return ``mean((1 - h(x)) ** 2)`` over the training samples and log it."""
        inputs = self._collect_inputs(batch, "train_samples")
        loss = ((1 - self.forward(inputs)) ** 2).mean()
        self.log(
            "MetricModel/train_loss_learn_metric",
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
        )
        return loss

    def validation_step(self, batch, batch_idx):
        """Return the validation loss, log it and remember it on the instance."""
        inputs = self._collect_inputs(batch, "val_samples")
        h = self.forward(inputs)
        loss = ((1 - h) ** 2).mean()
        self.log(
            "MetricModel/val_loss_learn_metric",
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
        )
        # Read back by compute_metric when it is given a negative epsilon.
        self.last_val_loss = loss.detach()
        return loss

    def calculate_centroids(self, all_data, labels):
        """Return the mean of ``all_data`` rows for each distinct label.

        Rows of the result follow the sorted order of ``np.unique(labels)``.
        """
        unique_labels = np.unique(labels)
        centroids = np.zeros((len(unique_labels), all_data.shape[1]))
        for row, label in enumerate(unique_labels):
            centroids[row] = all_data[labels == label].mean(axis=0)
        return centroids

    def configure_optimizers(self):
        """Return an Adam optimizer over the module parameters."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def compute_metric(self, x, alpha=1, epsilon=1e-2, image_hx=False):
        """Compute the metric value derived from ``h(x)``.

        Args:
            x: Points at which to evaluate the network.
            alpha: Exponent applied in the denominator.
            epsilon: Additive term in the denominator. A negative value is
                replaced by ``(1 - last_val_loss) / abs(epsilon)``.
            image_hx: When true, ``h(x)`` is first mapped to
                ``1 - |1 - h(x)|`` and the result is
                ``1 / (h(x) ** alpha + epsilon)``.

        Returns:
            ``1 / (h(x) + epsilon) ** alpha`` unless ``image_hx`` is set.
        """
        if epsilon < 0:
            # float() accepts both the initial plain number and the detached
            # tensor stored by validation_step; ``.item()`` failed on the
            # former when no validation step had run yet.
            epsilon = (1 - float(self.last_val_loss)) / abs(epsilon)
        h_x = self.forward(x)
        if image_hx:
            h_x = 1 - torch.abs(1 - h_x)
            return 1 / (h_x**alpha + epsilon)
        return 1 / (h_x + epsilon) ** alpha