AlienChen committed on
Commit 3c92e07 · verified · 1 Parent(s): df3a13c

Delete modules/dna_module.py

Files changed (1)
  1. modules/dna_module.py +0 -301
modules/dna_module.py DELETED
@@ -1,301 +0,0 @@
- import copy
- import math
- from collections import defaultdict
-
- import PIL
- import numpy as np
- import pandas as pd
- import torch, time, os
- import wandb
- import seaborn as sns
- import yaml
-
- sns.set_style('whitegrid')
- from matplotlib import pyplot as plt
- from torch import optim
-
- from models.dna_models import MLPModel, CNNModel, TransformerModel, DeepFlyBrainModel
- from utils.flow_utils import DirichletConditionalFlow, expand_simplex, sample_cond_prob_path, simplex_proj, \
-     get_wasserstein_dist, update_ema, load_flybrain_designed_seqs
- from modules.general_module import GeneralModule
- from utils.log import get_logger
-
- from flow_matching.path import MixtureDiscreteProbPath
- from flow_matching.path.scheduler import PolynomialConvexScheduler
- from flow_matching.solver import MixtureDiscreteEulerSolver
- from flow_matching.utils import ModelWrapper
- from flow_matching.loss import MixturePathGeneralizedKL
-
- import pdb
-
-
- logger = get_logger(__name__)
-
-
- class DNAModule(GeneralModule):
-     def __init__(self, args, alphabet_size, num_cls, source_distribution="uniform"):
-         super().__init__(args)
-         self.alphabet_size = alphabet_size
-         self.source_distribution = source_distribution
-         self.epsilon = 1e-3
-
-         if source_distribution == "uniform":
-             added_token = 0
-         elif source_distribution == "mask":
-             self.mask_token = alphabet_size  # tokens starting from zero
-             added_token = 1
-         else:
-             raise NotImplementedError
-         self.alphabet_size += added_token
-
-         self.load_model(self.alphabet_size, num_cls)
-
-         self.scheduler = PolynomialConvexScheduler(n=args.scheduler_n)
-         self.path = MixtureDiscreteProbPath(scheduler=self.scheduler)
-         self.loss_fn = MixturePathGeneralizedKL(path=self.path)
-
-         self.val_outputs = defaultdict(list)
-         self.train_outputs = defaultdict(list)
-         self.train_out_initialized = False
-         self.mean_log_ema = {}
-         if self.args.taskiran_seq_path is not None:
-             self.taskiran_fly_seqs = load_flybrain_designed_seqs(self.args.taskiran_seq_path).to(self.device)
-
-     def on_load_checkpoint(self, checkpoint):
-         checkpoint['state_dict'] = {k: v for k, v in checkpoint['state_dict'].items() if 'cls_model' not in k and 'distill_model' not in k}
-
-     def training_step(self, batch, batch_idx):
-         self.stage = 'train'
-         loss = self.general_step(batch, batch_idx)
-         if self.args.ckpt_iterations is not None and self.trainer.global_step in self.args.ckpt_iterations:
-             self.trainer.save_checkpoint(os.path.join(os.environ["MODEL_DIR"], f"epoch={self.trainer.current_epoch}-step={self.trainer.global_step}.ckpt"))
-         # self.try_print_log()
-         return loss
-
-     def validation_step(self, batch, batch_idx):
-         self.stage = 'val'
-         loss = self.general_step(batch, batch_idx)
-         # if self.args.validate:
-         #     self.try_print_log()
-
-     def general_step(self, batch, batch_idx=None):
-         self.iter_step += 1
-         x_1, cls = batch
-         B, L = x_1.shape
-         x_1 = x_1.to(self.device)
-
-         if self.source_distribution == "uniform":
-             x_0 = torch.randint_like(x_1, high=self.alphabet_size)
-         elif self.source_distribution == "mask":
-             x_0 = torch.zeros_like(x_1) + self.mask_token
-         else:
-             raise NotImplementedError
-         # pdb.set_trace()
-         t = torch.rand(x_1.shape[0]) * (1 - self.epsilon)
-         t = t.to(x_1.device)
-         path_sample = self.path.sample(t=t, x_0=x_0, x_1=x_1)
-
-         logits = self.model(x_t=path_sample.x_t, t=path_sample.t)
-         loss = self.loss_fn(logits=logits, x_1=x_1, x_t=path_sample.x_t, t=path_sample.t)
-         # pdb.set_trace()
-
-         self.lg('loss', loss)
-         if self.stage == "val":
-             predicted = logits.argmax(dim=-1)
-             accuracy = (predicted == x_1).float().mean()
-             self.lg('acc', accuracy)
-         self.last_log_time = time.time()
-         return loss
-
-     @torch.no_grad()
-     def dirichlet_flow_inference(self, seq, cls, model, args):
-         B, L = seq.shape
-         K = model.alphabet_size
-         x0 = torch.distributions.Dirichlet(torch.ones(B, L, model.alphabet_size, device=seq.device)).sample()
-         eye = torch.eye(K).to(x0)
-         xt = x0.clone()
-
-         t_span = torch.linspace(1, args.alpha_max, self.args.num_integration_steps, device=self.device)
-         for i, (s, t) in enumerate(zip(t_span[:-1], t_span[1:])):
-             xt_expanded, prior_weights = expand_simplex(xt, s[None].expand(B), args.prior_pseudocount)
-
-             logits = model(xt_expanded, t=s[None].expand(B))
-             flow_probs = torch.nn.functional.softmax(logits / args.flow_temp, -1)  # [B, L, K]
-
-             if not torch.allclose(flow_probs.sum(2), torch.ones((B, L), device=self.device), atol=1e-4) or not (flow_probs >= 0).all():
-                 print(f'WARNING: flow_probs.min(): {flow_probs.min()}. Some values of flow_probs do not lie on the simplex. There are {(flow_probs < 0).sum()} negative values in flow_probs of shape {flow_probs.shape}. Projecting them onto the simplex.')
-                 flow_probs = simplex_proj(flow_probs)
-
-             c_factor = self.condflow.c_factor(xt.cpu().numpy(), s.item())
-             c_factor = torch.from_numpy(c_factor).to(xt)
-
-             self.inf_counter += 1
-
-             if not (flow_probs >= 0).all(): print(f'flow_probs.min(): {flow_probs.min()}')
-             cond_flows = (eye - xt.unsqueeze(-1)) * c_factor.unsqueeze(-2)
-             flow = (flow_probs.unsqueeze(-2) * cond_flows).sum(-1)
-
-             xt = xt + flow * (t - s)
-
-             if not torch.allclose(xt.sum(2), torch.ones((B, L), device=self.device), atol=1e-4) or not (xt >= 0).all():
-                 print(f'WARNING: xt.min(): {xt.min()}. Some values of xt do not lie on the simplex. There are {(xt < 0).sum()} negative values in xt of shape {xt.shape}. Projecting them onto the simplex.')
-                 xt = simplex_proj(xt)
-         return logits, x0
-
-     def on_validation_epoch_start(self):
-         self.inf_counter = 1
-         self.nan_inf_counter = 0
-
-     def on_validation_epoch_end(self):
-         self.generator = np.random.default_rng()
-         log = self._log
-         log = {key: log[key] for key in log if "val_" in key}
-         log = self.gather_log(log, self.trainer.world_size)
-         mean_log = self.get_log_mean(log)
-         mean_log.update({'val_nan_inf_step_fraction': self.nan_inf_counter / self.inf_counter})
-
-         mean_log.update({'epoch': float(self.trainer.current_epoch), 'step': float(self.trainer.global_step), 'iter_step': float(self.iter_step)})
-
-         self.mean_log_ema = update_ema(current_dict=mean_log, prev_ema=self.mean_log_ema, gamma=0.9)
-         mean_log.update(self.mean_log_ema)
-         if self.trainer.is_global_zero:
-             logger.info(str(mean_log))
-             self.log_dict(mean_log, batch_size=1)
-             if self.args.wandb:
-                 wandb.log(mean_log)
-
-             path = os.path.join(os.environ["MODEL_DIR"], f"val_{self.trainer.global_step}.csv")
-             pd.DataFrame(log).to_csv(path)
-
-         for key in list(log.keys()):
-             if "val_" in key:
-                 del self._log[key]
-         self.val_outputs = defaultdict(list)
-
-
-     def on_train_epoch_start(self) -> None:
-         self.inf_counter = 1
-         self.nan_inf_counter = 0
-         # if not self.loaded_distill_model and self.args.distill_ckpt is not None:
-         #     self.load_distill_model()
-         #     self.loaded_distill_model = True
-         # if not self.loaded_classifiers:
-         #     self.load_classifiers(load_cls=self.args.cls_ckpt is not None, load_clean_cls=self.args.clean_cls_ckpt is not None)
-         #     self.loaded_classifiers = True
-
-     def on_train_epoch_end(self):
-         self.train_out_initialized = True
-         log = self._log
-         log = {key: log[key] for key in log if "train_" in key}
-         log = self.gather_log(log, self.trainer.world_size)
-         mean_log = self.get_log_mean(log)
-         mean_log.update(
-             {'epoch': float(self.trainer.current_epoch), 'step': float(self.trainer.global_step), 'iter_step': float(self.iter_step)})
-
-         if self.trainer.is_global_zero:
-             logger.info(str(mean_log))
-             self.log_dict(mean_log, batch_size=1)
-             if self.args.wandb:
-                 wandb.log(mean_log)
-
-         for key in list(log.keys()):
-             if "train_" in key:
-                 del self._log[key]
-
-     def lg(self, key, data):
-         if isinstance(data, torch.Tensor):
-             data = data.detach().cpu().numpy()
-         log = self._log
-         if self.args.validate or self.stage == 'train':
-             log["iter_" + key].append(data)
-         log[self.stage + "_" + key].append(data)
-
-     def configure_optimizers(self):
-         optimizer = optim.Adam(self.parameters(), lr=self.args.lr)
-         return optimizer
-
-     def plot_empirical_and_true(self, empirical_dist, true_dist):
-         num_datasets_to_plot = min(4, empirical_dist.shape[0])
-         width = 1
-         # Creating a figure and axes
-         fig, axes = plt.subplots(math.ceil(num_datasets_to_plot / 2), 2, figsize=(10, 8))
-         for i in range(num_datasets_to_plot):
-             row, col = i // 2, i % 2
-             x = np.arange(len(empirical_dist[i]))
-             axes[row, col].bar(x, empirical_dist[i], width, label='empirical')
-             axes[row, col].plot(x, true_dist[i], label='true density', color='orange')
-             axes[row, col].legend()
-             axes[row, col].set_title(f'Sequence position {i + 1}')
-             axes[row, col].set_xlabel('Category')
-             axes[row, col].set_ylabel('Density')
-         plt.tight_layout()
-         fig.canvas.draw()
-         pil_img = PIL.Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
-         plt.close()
-         return pil_img
-
-     def load_model(self, alphabet_size, num_cls):
-         if self.args.model == 'cnn':
-             self.model = CNNModel(self.args, alphabet_size=alphabet_size)
-         elif self.args.model == 'mlp':
-             self.model = MLPModel(input_dim=alphabet_size, time_dim=1, hidden_dim=self.args.hidden_dim, length=self.args.length)
-         elif self.args.model == 'transformer':
-             self.model = TransformerModel(alphabet_size=alphabet_size, seq_length=self.args.length, embed_dim=self.args.hidden_dim,
-                                           num_layers=self.args.num_layers, num_heads=self.args.num_heads, dropout=self.args.dropout)
-         elif self.args.model == 'deepflybrain':
-             self.model = DeepFlyBrainModel(self.args, alphabet_size=alphabet_size, num_cls=num_cls)
-         else:
-             raise NotImplementedError()
-
-     def plot_score_and_probs(self):
-         clss = torch.cat(self.val_outputs['clss_noisycls'])
-         probs = torch.softmax(torch.cat(self.val_outputs['logits_noisycls']), dim=-1)
-         scores = torch.cat(self.val_outputs['scores_noisycls']).cpu().numpy()
-         score_norms = np.linalg.norm(scores, axis=-1)
-         alphas = torch.cat(self.val_outputs['alphas_noisycls']).cpu().numpy()
-         true_probs = probs[torch.arange(len(probs)), clss].cpu().numpy()
-         bins = np.linspace(min(alphas), 12, 20)
-         indices = np.digitize(alphas, bins)
-         bin_means = [np.mean(true_probs[indices == i]) for i in range(1, len(bins))]
-         bin_std = [np.std(true_probs[indices == i]) for i in range(1, len(bins))]
-         bin_centers = 0.5 * (bins[:-1] + bins[1:])
-
-         bin_pos_std = [np.std(true_probs[indices == i][true_probs[indices == i] > np.mean(true_probs[indices == i])]) for i in range(1, len(bins))]
-         bin_neg_std = [np.std(true_probs[indices == i][true_probs[indices == i] < np.mean(true_probs[indices == i])]) for i in range(1, len(bins))]
-         plot_data = pd.DataFrame({'Alphas': bin_centers, 'Means': bin_means, 'Std': bin_std, 'Pos_Std': bin_pos_std, 'Neg_Std': bin_neg_std})
-         plt.figure(figsize=(10, 6))
-         sns.lineplot(x='Alphas', y='Means', data=plot_data)
-         plt.fill_between(plot_data['Alphas'], plot_data['Means'] - plot_data['Neg_Std'], plot_data['Means'] + plot_data['Pos_Std'], alpha=0.3)
-         plt.xlabel('Binned alphas values')
-         plt.ylabel('Mean of predicted probs for true class')
-         fig = plt.gcf()
-         fig.canvas.draw()
-         pil_probs = PIL.Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
-
-         plt.close()
-         bin_means = [np.mean(score_norms[indices == i]) for i in range(1, len(bins))]
-         bin_std = [np.std(score_norms[indices == i]) for i in range(1, len(bins))]
-         bin_pos_std = [np.std(score_norms[indices == i][score_norms[indices == i] > np.mean(score_norms[indices == i])]) for i in range(1, len(bins))]
-         bin_neg_std = [np.std(score_norms[indices == i][score_norms[indices == i] < np.mean(score_norms[indices == i])]) for i in range(1, len(bins))]
-         plot_data = pd.DataFrame({'Alphas': bin_centers, 'Means': bin_means, 'Std': bin_std, 'Pos_Std': bin_pos_std, 'Neg_Std': bin_neg_std})
-         plt.figure(figsize=(10, 6))
-         sns.lineplot(x='Alphas', y='Means', data=plot_data)
-         plt.fill_between(plot_data['Alphas'], plot_data['Means'] - plot_data['Neg_Std'],
-                          plot_data['Means'] + plot_data['Pos_Std'], alpha=0.3)
-         plt.xlabel('Binned alphas values')
-         plt.ylabel('Mean of norm of the scores')
-         fig = plt.gcf()
-         fig.canvas.draw()
-         pil_score_norms = PIL.Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
-         return pil_probs, pil_score_norms
-
-     def log_data_similarities(self, seq_pred):
-         similarities1 = seq_pred.cpu()[:, None, :].eq(self.toy_data.data_class1[None, :, :])  # batchsize, dataset_size, seq_len
-         similarities2 = seq_pred.cpu()[:, None, :].eq(self.toy_data.data_class2[None, :, :])  # batchsize, dataset_size, seq_len
-         similarities = seq_pred.cpu()[:, None, :].eq(torch.cat([self.toy_data.data_class2[None, :, :], self.toy_data.data_class1[None, :, :]], dim=1))  # batchsize, dataset_size, seq_len
-         self.lg('data1_sim', similarities1.float().mean(-1).max(-1)[0])
-         self.lg('data2_sim', similarities2.float().mean(-1).max(-1)[0])
-         self.lg('data_sim', similarities.float().mean(-1).max(-1)[0])
-         self.lg('mean_data1_sim', similarities1.float().mean(-1).mean(-1))
-         self.lg('mean_data2_sim', similarities2.float().mean(-1).mean(-1))
-         self.lg('mean_data_sim', similarities.float().mean(-1).mean(-1))
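
The core of the deleted module is the discrete flow-matching step in general_step: sample a time t in [0, 1 - epsilon), corrupt the data tokens x_1 toward the source x_0 along a MixtureDiscreteProbPath, score the corrupted sequence with the model, and minimize the generalized-KL mixture-path loss. Below is a minimal, self-contained sketch of that step, calling the flow_matching API the same way the file above does; the toy network, its sizes, and the hyperparameter values are illustrative assumptions, not part of the original repository (the real models live in models/dna_models.py).

import torch
from flow_matching.path import MixtureDiscreteProbPath
from flow_matching.path.scheduler import PolynomialConvexScheduler
from flow_matching.loss import MixturePathGeneralizedKL

ALPHABET_SIZE, SEQ_LEN, BATCH = 4, 128, 8   # assumed toy sizes (A/C/G/T alphabet)
EPSILON = 1e-3                              # same clamp on t as in the module above

class ToyModel(torch.nn.Module):
    """Placeholder network mapping (x_t, t) to per-position logits over the alphabet."""
    def __init__(self):
        super().__init__()
        self.embed = torch.nn.Embedding(ALPHABET_SIZE, 64)
        self.head = torch.nn.Linear(64 + 1, ALPHABET_SIZE)

    def forward(self, x_t, t):
        h = self.embed(x_t)                                     # [B, L, 64]
        t_feat = t[:, None, None].expand(-1, x_t.shape[1], 1)   # broadcast t to every position
        return self.head(torch.cat([h, t_feat], dim=-1))        # [B, L, K]

model = ToyModel()
path = MixtureDiscreteProbPath(scheduler=PolynomialConvexScheduler(n=1.0))  # n stands in for args.scheduler_n
loss_fn = MixturePathGeneralizedKL(path=path)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

x_1 = torch.randint(0, ALPHABET_SIZE, (BATCH, SEQ_LEN))   # data sequences
x_0 = torch.randint_like(x_1, high=ALPHABET_SIZE)         # "uniform" source distribution
t = torch.rand(BATCH) * (1 - EPSILON)                     # keep t strictly below 1

sample = path.sample(t=t, x_0=x_0, x_1=x_1)               # x_t: x_1 partially corrupted toward x_0
logits = model(x_t=sample.x_t, t=sample.t)
loss = loss_fn(logits=logits, x_1=x_1, x_t=sample.x_t, t=sample.t)

optimizer.zero_grad()
loss.backward()
optimizer.step()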