File size: 12,386 Bytes
25fff03 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 |
# Copyright (c) 2021, EleutherAI.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import torch
import wandb
from megatron import mpu, print_rank_0
from megatron.utils import report_memory
class Tee:
    """Duplicate output to both stdout/err and file.

    On construction the instance installs itself as ``sys.stdout`` (or
    ``sys.stderr`` when ``err=True``). Every ``write`` then goes to both the
    file and the original stream.

    Call :meth:`close` to restore the original stream and close the file.
    Because ``sys.stdout``/``sys.stderr`` keeps a reference to the instance,
    ``__del__`` alone will normally not run until that stream has been
    replaced again or the interpreter shuts down.

    Args:
        file: path of the file that receives a copy of the output (opened
            with mode ``"w"``, i.e. truncated).
        err: if True, tee ``sys.stderr`` instead of ``sys.stdout``.
    """

    def __init__(self, file, err=False):
        self.file = open(file, "w")
        self.err = err
        if not err:
            self.std = sys.stdout
            sys.stdout = self
        else:
            self.std = sys.stderr
            sys.stderr = self

    def close(self):
        """Restore the original stream (if still installed) and close the file.

        Safe to call more than once, and safe on a partially constructed
        instance (e.g. when ``open`` raised inside ``__init__``).
        """
        std = getattr(self, "std", None)
        if std is not None:
            # Only undo our own redirection; don't clobber a stream that
            # somebody else installed after us.
            if self.err:
                if sys.stderr is self:
                    sys.stderr = std
            elif sys.stdout is self:
                sys.stdout = std
        file = getattr(self, "file", None)
        if file is not None and not file.closed:
            file.close()

    def __del__(self):
        self.close()

    def write(self, data):
        # Best effort: a failure on one sink must not block the other.
        # ValueError covers writes after the file has been closed.
        try:
            self.file.write(data)
        except (OSError, ValueError):
            pass
        try:
            self.std.write(data)
        except OSError:
            pass

    def flush(self):
        # Flush both sinks, not just the file, so buffered console output
        # is emitted as well.
        try:
            self.file.flush()
        except (OSError, ValueError):
            pass
        try:
            self.std.flush()
        except OSError:
            pass
def human_readable_flops(num):
    """Format a FLOP/s figure with an SI prefix, e.g. ``2.5e12 -> "2.5TFLOPS"``.

    The value is divided by 1000 until its magnitude drops below 1000. If it
    is still too large after the zetta step, it is reported in yotta
    (``YFLOPS``). The result is rendered with one decimal place.

    Args:
        num: a (possibly negative) number of FLOPs or FLOP/s.

    Returns:
        The formatted string, e.g. ``"999.0FLOPS"``, ``"1.5KFLOPS"``.
    """
    for unit in (
        "FLOPS",
        "KFLOPS",
        "MFLOPS",
        "GFLOPS",
        "TFLOPS",
        "PFLOPS",
        "EFLOPS",
        "ZFLOPS",
    ):
        if abs(num) < 1000.0:
            return "%3.1f%s" % (num, unit)
        num /= 1000.0
    # Beyond zetta: use the yotta prefix (the previous "Yi" suffix was a
    # leftover binary prefix and carried no FLOPS unit).
    return "%.1f%s" % (num, "YFLOPS")
def get_flops(neox_args, model, iter_time_s):
    """Estimate achieved training FLOP/s per process (GPU) for one iteration.

    Uses the standard dense-transformer approximation for forward + backward
    (3x forward, no activation recomputation). Per token it counts:

    * ``6 * N`` FLOPs for the parameter matmuls (``N = model.total_params``);
    * ``12 * seq_length * hidden_size * num_layers`` FLOPs for the attention
      score / value matmuls (QK^T and AV).

    The total is multiplied by the tokens per step and divided by the step
    time and the number of processes.

    Args:
        neox_args: config exposing ``train_batch_size``, ``seq_length``,
            ``hidden_size`` and ``num_layers``.
        model: object exposing ``total_params``.
        iter_time_s: wall-clock seconds per training iteration.

    Returns:
        Approximate FLOPs per second per process.
    """
    # Treat a process without an initialized process group as a world of 1
    # instead of raising.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        world_size = torch.distributed.get_world_size()
    else:
        world_size = 1
    ff = model.total_params * 6
    # The coefficient was previously 60, which overstated the attention term
    # 5x relative to the 3x (fwd+bwd) convention used for the 6*N term.
    attn = neox_args.seq_length * neox_args.hidden_size * neox_args.num_layers * 12
    tokens_per_step = neox_args.train_batch_size * neox_args.seq_length
    return tokens_per_step * (ff + attn) / (iter_time_s * world_size)
def training_log(
    neox_args,
    timers,
    loss_dict,
    total_loss_dict,
    learning_rate,
    iteration,
    loss_scale,
    report_memory_flag,
    skipped_iter,
    model,
    optimizer,
    noise_scale_logger,
):
    """Log training information such as losses, timing, etc.

    Every call (one per training step):

    * accumulates ``loss_dict`` into ``total_loss_dict``; on skipped
      iterations it instead checks the losses for inf/NaN;
    * logs timers, the learning rate, the per-step losses, the loss scale
      (when ``neox_args.fp16``) and, optionally, the gradient noise scale,
      optimizer-state norms and gradient / parameter norms via
      ``tb_wandb_log``.

    When ``iteration`` is a multiple of ``neox_args.log_interval`` it also:

    * builds a summary line (throughput, approximate FLOPS, interval-averaged
      losses, skipped / NaN counts) and emits it via ``print_rank_0``;
    * resets the accumulators in ``total_loss_dict``;
    * calls ``timers.log``.

    Args:
        neox_args: run configuration.
        timers: timer collection; must be callable by name and expose
            ``.timers``, ``.write`` and ``.log``.
        loss_dict: this step's losses; values must support
            ``.float().sum().item()`` on skipped iterations (presumably
            torch tensors - TODO confirm).
        total_loss_dict: running accumulators, mutated in place.
        learning_rate: current learning rate.
        iteration: current iteration number (used as the logging step).
        loss_scale: current loss scale.
        report_memory_flag: if True, memory is reported at the next interval.
        skipped_iter: whether this step was skipped (added as a count).
        model: model engine.
        optimizer: optimizer; only read when ``log_optimizer_states``.
        noise_scale_logger: only read when ``log_gradient_noise_scale``.

    Returns:
        The updated ``report_memory_flag`` (cleared once memory was reported).
    """

    def log_metric(key, value, **kwargs):
        # Every per-step metric goes to the same wandb/tensorboard targets.
        tb_wandb_log(
            key,
            value,
            iteration,
            use_wandb=neox_args.use_wandb,
            tensorboard_writer=neox_args.tensorboard_writer,
            **kwargs,
        )

    # ---- Update loss accumulators ----
    skipped_iters_key = "skipped iterations"
    got_nan_key = "got nan"
    total_loss_dict[skipped_iters_key] = (
        total_loss_dict.get(skipped_iters_key, 0) + skipped_iter
    )
    got_nan = False
    for key in loss_dict:
        if not skipped_iter:
            total_loss_dict[key] = total_loss_dict.get(key, 0.0) + loss_dict[key]
        else:
            # Losses are only inspected for inf/NaN on skipped iterations
            # (presumably steps the loss scaler skipped on overflow).
            value = loss_dict[key].float().sum().item()
            is_nan = value == float("inf") or value == -float("inf") or value != value
            got_nan = got_nan or is_nan
    total_loss_dict[got_nan_key] = total_loss_dict.get(got_nan_key, 0) + int(got_nan)

    # ---- Timers ----
    timers_to_log = []

    def add_to_logging(name):
        if name in timers.timers:
            timers_to_log.append(name)

    if not neox_args.is_pipe_parallel:
        for timer_name in (
            "forward",
            "backward",
            "backward-backward",
            "backward-allreduce",
            "backward-master-grad",
            "backward-clip-grad",
            "optimizer",
            "batch generator",
        ):
            add_to_logging(timer_name)

        # Log timer info to tensorboard and wandb
        normalizer = iteration % neox_args.log_interval
        if normalizer == 0:
            normalizer = neox_args.log_interval
        if torch.distributed.get_rank() == 0:
            timers.write(
                names=timers_to_log, iteration=iteration, normalizer=normalizer
            )
    else:
        # With pipeline parallel, the megatron timers are overridden by the
        # deepspeed ones. Try to grab timer values from the model engine; this
        # attribute was only recently added to deeperspeed, so check for it.
        if hasattr(model, "timer_values") and model.timer_values is not None:
            if (
                model.wall_clock_breakdown()
                and model.global_steps % model.steps_per_print() == 0
            ):
                timer_values = model.timer_values
                # deepspeed already logs to tensorboard / prints values, so
                # just log to wandb (no tensorboard writer, to avoid duplicates).
                if neox_args.use_wandb and torch.distributed.get_rank() == 0:
                    for key in timer_values:
                        tb_wandb_log(
                            f"timers/{key}",
                            timer_values[key],
                            iteration,
                            use_wandb=neox_args.use_wandb,
                            tensorboard_writer=None,
                        )

    # ---- Per-step scalars: losses, lr, etc. ----
    log_metric("train/learning_rate", learning_rate)
    for key in loss_dict:
        log_metric(f'train/{key.replace(" ", "_")}', loss_dict[key])
    # NOTE(review): this gate reads `neox_args.fp16`, while the interval summary
    # below reads `neox_args.precision == "fp16"`; confirm which is authoritative.
    if neox_args.fp16:
        log_metric("train/loss_scale", loss_scale)

    # Gradient noise scale
    if (
        neox_args.log_gradient_noise_scale
        and noise_scale_logger.noise_scale is not None
    ):
        log_metric("train/noise_scale", noise_scale_logger.noise_scale)

    # (optional) Optimizer state norms, every step
    if neox_args.log_optimizer_states:
        for k, v in optimizer.state_dict()["optimizer_state_dict"]["state"].items():
            for ki, vi in v.items():  # step, module
                if ki != "step":
                    # Tensor-like states are reduced to their norm; other
                    # values are logged as-is.
                    opt_state_norm = torch.norm(vi) if hasattr(vi, "dim") else vi
                    log_metric(f"optimizer_state_norms/{k}_{ki}", opt_state_norm)

    # (optional) Gradient / parameter norms, every step
    if (
        neox_args.log_grad_pct_zeros
        or neox_args.log_grad_norm
        or neox_args.log_param_norm
    ):
        log_grads = neox_args.log_grad_pct_zeros or neox_args.log_grad_norm
        if log_grads:
            model.store_gradients = True  # start storing gradients
        # Loop-invariant: look the stored gradients up once, not per parameter.
        stored_gradients = getattr(model, "stored_gradients", None)
        for i, (name, param) in enumerate(model.module.named_parameters()):
            # stored_gradients is indexed positionally, in named_parameters() order.
            if log_grads and stored_gradients is not None:
                grad = stored_gradients[i]
                if grad is not None:
                    if neox_args.log_grad_pct_zeros:
                        log_metric(
                            f"pct_grad_zeros/{name}",
                            (grad == 0).float().mean().item() * 100,
                            all_ranks=True,
                        )
                    if neox_args.log_grad_norm:
                        log_metric(
                            f"gradient_norms/{name}",
                            torch.norm(grad),
                            all_ranks=True,
                        )
            if neox_args.log_param_norm:
                log_metric(
                    f"parameter_norms/{name}", torch.norm(param), all_ranks=True
                )

    # ---- Interval summary, every log_interval iterations ----
    if iteration % neox_args.log_interval == 0:
        elapsed_time = timers("interval time").elapsed()
        iteration_time = elapsed_time / neox_args.log_interval
        samples_per_sec = neox_args.train_batch_size / iteration_time
        log_metric("runtime/samples_per_sec", samples_per_sec)
        log_metric("runtime/iteration_time", iteration_time)

        log_string = " samples/sec: {:.3f} |".format(samples_per_sec)
        log_string += " iteration {:8d}/{:8d} |".format(
            iteration, neox_args.train_iters
        )
        log_string += " elapsed time per iteration (ms): {:.1f} |".format(
            elapsed_time * 1000.0 / neox_args.log_interval
        )
        log_string += " learning rate: {:.3E} |".format(learning_rate)
        # Losses are averaged over the non-skipped iterations of the interval.
        num_iterations = max(
            1, neox_args.log_interval - total_loss_dict[skipped_iters_key]
        )

        # log tflop / gpu
        flops_per_s_per_gpu = get_flops(
            neox_args=neox_args, model=model, iter_time_s=iteration_time
        )
        log_string += (
            f" approx flops per GPU: {human_readable_flops(flops_per_s_per_gpu)} |"
        )
        log_metric("runtime/flops_per_sec_per_gpu", flops_per_s_per_gpu)

        for key in total_loss_dict:
            if key not in (skipped_iters_key, got_nan_key):
                total = total_loss_dict[key]
                v = total.item() if hasattr(total, "item") else total
                avg = v / float(num_iterations)
                log_string += " {}: {:.6E} |".format(key, avg)
                total_loss_dict[key] = 0.0
        if neox_args.precision == "fp16":
            log_string += " loss scale: {:.1f} |".format(loss_scale)
        log_string += " number of skipped iterations: {:3d} |".format(
            total_loss_dict[skipped_iters_key]
        )
        log_string += " number of nan iterations: {:3d} |".format(
            total_loss_dict[got_nan_key]
        )
        total_loss_dict[skipped_iters_key] = 0
        total_loss_dict[got_nan_key] = 0
        print_rank_0(log_string)
        if report_memory_flag:
            report_memory("after {} iterations".format(iteration))
            report_memory_flag = False

        timers.log(timers_to_log, normalizer=neox_args.log_interval)

    return report_memory_flag
def tb_wandb_log(
    key, value, iteration_no, use_wandb, tensorboard_writer=None, all_ranks=False
):
    """Log one scalar to TensorBoard and/or Weights & Biases.

    By default only rank 0 logs; pass ``all_ranks=True`` to log from every
    rank. ``None`` values are ignored. A process without an initialized
    ``torch.distributed`` process group is treated as rank 0.

    Args:
        key: metric name.
        value: value to log (passed unchanged to ``add_scalar`` / ``wandb.log``).
        iteration_no: step associated with the value.
        use_wandb: whether to call ``wandb.log``.
        tensorboard_writer: object with ``add_scalar(key, value, step)``;
            skipped when falsy.
        all_ranks: log on every rank instead of only rank 0.
    """
    if value is None:
        return
    if not all_ranks:
        # Only query the rank when it matters; don't raise outside a
        # distributed run.
        distributed = (
            torch.distributed.is_available() and torch.distributed.is_initialized()
        )
        if distributed and torch.distributed.get_rank() != 0:
            return
    if tensorboard_writer:
        tensorboard_writer.add_scalar(key, value, iteration_no)
    if use_wandb:
        wandb.log({key: value}, step=iteration_no)
|