pmukhop committed on
Commit
3679789
·
0 Parent(s):

initial commit post neutron star

Browse files
Files changed (2) hide show
  1. coalesced.pth +3 -0
  2. extended_config.yaml +269 -0
coalesced.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:118dc6735c440435c7765096c46c117cb28b18cdb90e8dbb3457bc63b406f91a
3
+ size 5179778007
extended_config.yaml ADDED
@@ -0,0 +1,269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data_workers: 10
2
+ name: WalrusFT_pns_1step_realglobalnorm-postn-delta-Isotr[Space-Adapt-]-AdamW-8e-05
3
+ finetune: true
4
+ automatic_setup: true
5
+ trainer:
6
+ _target_: walrus.trainer.Trainer
7
+ max_epoch: 50
8
+ val_frequency: 5
9
+ rollout_val_frequency: 5
10
+ short_validation_length: 20
11
+ max_rollout_steps: 30
12
+ num_time_intervals: 5
13
+ enable_amp: false
14
+ loss_fn:
15
+ _target_: the_well.benchmark.metrics.MAE
16
+ formatter:
17
+ _target_: hydra.utils.get_class
18
+ path: walrus.data.well_to_multi_transformer.ChannelsFirstWithTimeFormatter
19
+ revin:
20
+ _target_: walrus.trainer.normalization_strat.GlobalRevNormalization
21
+ _partial_: true
22
+ prediction_type: delta
23
+ grad_acc_steps: 1
24
+ image_validation: true
25
+ video_validation: true
26
+ gradient_log_level: 0
27
+ clip_gradient: 10
28
+ log_interval: 200
29
+ loss_multiplier: 100.0
30
+ lr_scheduler_per_step: false
31
+ validation_suite:
32
+ - _target_: the_well.benchmark.metrics.NRMSE
33
+ - _target_: the_well.benchmark.metrics.VRMSE
34
+ - _target_: the_well.benchmark.metrics.PearsonR
35
+ batch_aggregation_fns:
36
+ - torch.mean
37
+ - torch.median
38
+ - torch.std
39
+ skip_spectral_metrics: true
40
+ optimizer:
41
+ _target_: torch.optim.AdamW
42
+ weight_decay: 0.0001
43
+ eps: 1.0e-10
44
+ lr: 8.0e-05
45
+ lr_scheduler:
46
+ _target_: walrus.optim.schedulers.InverseSqrtLinearWarmupSqrtCooldown
47
+ warmup_epochs: 10
48
+ cooldown_epochs: 10
49
+ warmup_lr_factor: 0.1
50
+ cooldown_lr_factor: 0.001
51
+ model:
52
+ encoder:
53
+ _partial_: true
54
+ _target_: walrus.models.encoders.vstride_encoder.SpaceBagAdaptiveDVstrideEncoder
55
+ learned_pad: true
56
+ base_kernel_size1d:
57
+ - - 4
58
+ - 4
59
+ base_kernel_size2d:
60
+ - - 8
61
+ - 4
62
+ - - 8
63
+ - 4
64
+ base_kernel_size3d:
65
+ - - 8
66
+ - 4
67
+ - - 8
68
+ - 4
69
+ - - 8
70
+ - 4
71
+ groups: 12
72
+ kernel_scales_seq:
73
+ - - 2
74
+ - 2
75
+ - - 4
76
+ - 2
77
+ - - 4
78
+ - 4
79
+ - - 8
80
+ - 4
81
+ variable_downsample: true
82
+ variable_deterministic_ds: true
83
+ activation:
84
+ _partial_: true
85
+ _target_: torch.nn.SiLU
86
+ decoder:
87
+ _partial_: true
88
+ _target_: walrus.models.decoders.vstride_decoder.AdaptiveDVstrideDecoder
89
+ learned_pad: true
90
+ base_kernel_size1d:
91
+ - - 4
92
+ - 4
93
+ base_kernel_size2d:
94
+ - - 8
95
+ - 4
96
+ - - 8
97
+ - 4
98
+ base_kernel_size3d:
99
+ - - 8
100
+ - 4
101
+ - - 8
102
+ - 4
103
+ - - 8
104
+ - 4
105
+ groups: 12
106
+ activation:
107
+ _partial_: true
108
+ _target_: torch.nn.SiLU
109
+ processor:
110
+ space_mixing:
111
+ _partial_: true
112
+ _target_: walrus.models.spatial_blocks.full_attention.FullAttention
113
+ num_heads: 16
114
+ mlp_dim: null
115
+ time_mixing:
116
+ _partial_: true
117
+ _target_: walrus.models.temporal_blocks.axial_time_attention.AxialTimeAttention
118
+ num_heads: 16
119
+ bias_type: rel
120
+ channel_mixing:
121
+ _partial_: true
122
+ _target_: torch.nn.Identity
123
+ _partial_: true
124
+ _target_: walrus.models.spatiotemporal_blocks.space_time_split.SpaceTimeSplitBlock
125
+ norm_layer:
126
+ _partial_: true
127
+ _target_: walrus.models.shared_utils.normalization.RMSGroupNorm
128
+ _target_: walrus.models.IsotropicModel
129
+ hidden_dim: 1408
130
+ projection_dim: 48
131
+ intermediate_dim: 352
132
+ processor_blocks: 40
133
+ drop_path: 0.0
134
+ groups: 16
135
+ max_d: 3
136
+ static_axes: true
137
+ weight_tied_axes: false
138
+ causal_in_time: true
139
+ include_d:
140
+ - 2
141
+ - 3
142
+ override_dimensionality: 0
143
+ jitter_patches: true
144
+ gradient_checkpointing_freq: 2
145
+ use_periodic_fixed_jitter: true
146
+ input_field_drop: 0
147
+ data:
148
+ field_index_map_override:
149
+ closed_boundary: 0
150
+ open_boundary: 1
151
+ bias_correction: 2
152
+ pressure: 3
153
+ velocity_x: 4
154
+ velocity_y: 5
155
+ velocity_z: 6
156
+ zeros_like_density: 7
157
+ speed_of_sound: 8
158
+ concentration: 9
159
+ D_xx: 10
160
+ D_xy: 11
161
+ D_xz: 12
162
+ D_yx: 13
163
+ D_yy: 14
164
+ D_yz: 15
165
+ D_zx: 16
166
+ D_zy: 17
167
+ D_zz: 18
168
+ E_xx: 19
169
+ E_xy: 20
170
+ E_xz: 21
171
+ E_yx: 22
172
+ E_yy: 23
173
+ E_yz: 24
174
+ E_zx: 25
175
+ E_zy: 26
176
+ E_zz: 27
177
+ density: 28
178
+ energy: 29
179
+ velocity_r: 30
180
+ velocity_theta: 31
181
+ velocity_phi: 32
182
+ momentum_x: 33
183
+ momentum_y: 34
184
+ momentum_z: 35
185
+ pressure_re: 36
186
+ pressure_im: 37
187
+ mask: 38
188
+ magnetic_field_x: 39
189
+ magnetic_field_y: 40
190
+ magnetic_field_z: 41
191
+ A: 42
192
+ B: 43
193
+ height: 44
194
+ internal_energy: 45
195
+ temperature: 46
196
+ electron_fraction: 47
197
+ entropy: 48
198
+ magnetic_field_log_r: 49
199
+ magnetic_field_theta: 50
200
+ magnetic_field_phi: 51
201
+ velocity_log_r: 52
202
+ buoyancy: 53
203
+ tracer: 54
204
+ log10_density: 55
205
+ log10_temperature: 56
206
+ c_zz: 57
207
+ C_xx: 58
208
+ C_xy: 59
209
+ C_xz: 60
210
+ C_yx: 61
211
+ C_yy: 62
212
+ C_yz: 63
213
+ C_zx: 64
214
+ C_zy: 65
215
+ C_zz: 66
216
+ log10_internal_energy: 67
217
+ log10_pressure: 68
218
+ log10_entropy: 69
219
+ well_base_path: /mnt/gpuxl/polymathic/the_well/datasets/
220
+ wandb_data_name: post_neutron_star_merger
221
+ module_parameters:
222
+ _target_: walrus.data.MixedWellDataModule
223
+ batch_size: 1
224
+ n_steps_input: 3
225
+ n_steps_output: 1
226
+ min_dt_stride: 1
227
+ max_dt_stride: 1
228
+ max_samples: 2000
229
+ well_dataset_info:
230
+ post_neutron_star_merger:
231
+ include_filters: []
232
+ exclude_filters: []
233
+ normalization_path: logged_stats.yaml
234
+ field_transforms:
235
+ density: torch.log10
236
+ temperature: torch.log10
237
+ pressure: torch.log10
238
+ entropy: torch.log10
239
+ internal_energy: torch.log10
240
+ auto_resume: true
241
+ folder_override: ''
242
+ checkpoint_override: ''
243
+ config_override: /mnt/home/polymathic/ceph/walrus_logging/platinum_checkpoints/extended_config.yaml
244
+ validation_mode: false
245
+ frozen_components:
246
+ - model
247
+ distribution:
248
+ distribution_type: fsdp
249
+ local_size: null
250
+ logger:
251
+ wandb: true
252
+ wandb_project_name: walrus_Finetuning_Runs
253
+ checkpoint:
254
+ _target_: walrus.trainer.checkpoints.CheckPointer
255
+ save_dir: /mnt/home/polymathic/ceph/walrus_logging/runs/WalrusFT_pns_1step_realglobalnorm-postn-delta-Isotr[Space-Adapt-]-AdamW-8e-05/finetune/0/checkpoints
256
+ load_checkpoint_path: null
257
+ coalesced_checkpoint_path: /mnt/home/polymathic/ceph/walrus_logging/platinum_checkpoints/final_base_model/walrus.pt
258
+ save_best: true
259
+ checkpoint_frequency: 20
260
+ align_fields: true
261
+ load_chkpt_after_finetuning_expansion: false
262
+ finetuning_mods:
263
+ learnable_rope: true
264
+ rope_per_axis: true
265
+ ape_shape:
266
+ - 17
267
+ - 17
268
+ - 16
269
+ experiment_dir: /mnt/home/polymathic/ceph/walrus_logging/runs