New logo
Browse files
- images/bertin.png +0 -0
- run_mlm_flax_stream.py +55 -3
images/bertin.png
CHANGED
|
|
run_mlm_flax_stream.py
CHANGED
|
@@ -25,6 +25,7 @@ import json
|
|
| 25 |
import os
|
| 26 |
import shutil
|
| 27 |
import sys
|
|
|
|
| 28 |
import time
|
| 29 |
from collections import defaultdict
|
| 30 |
from dataclasses import dataclass, field
|
|
@@ -60,6 +61,8 @@ from transformers import (
|
|
| 60 |
TrainingArguments,
|
| 61 |
is_tensorboard_available,
|
| 62 |
set_seed,
|
|
|
|
|
|
|
| 63 |
)
|
| 64 |
|
| 65 |
|
|
@@ -376,6 +379,27 @@ def rotate_checkpoints(path, max_checkpoints=5):
|
|
| 376 |
os.remove(path_to_delete)
|
| 377 |
|
| 378 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 379 |
if __name__ == "__main__":
|
| 380 |
# See all possible arguments in src/transformers/training_args.py
|
| 381 |
# or by passing the --help flag to this script.
|
|
@@ -749,7 +773,8 @@ if __name__ == "__main__":
|
|
| 749 |
eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
|
| 750 |
|
| 751 |
# Update progress bar
|
| 752 |
-
steps.desc = f"Step... ({step}/{num_train_steps} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
|
|
|
|
| 753 |
|
| 754 |
if has_tensorboard and jax.process_index() == 0:
|
| 755 |
write_eval_metric(summary_writer, eval_metrics, step)
|
|
@@ -762,8 +787,7 @@ if __name__ == "__main__":
|
|
| 762 |
model.save_pretrained(
|
| 763 |
training_args.output_dir,
|
| 764 |
params=params,
|
| 765 |
-
push_to_hub=training_args.push_to_hub,
|
| 766 |
-
commit_message=f"Saving weights and logs of step {step + 1}",
|
| 767 |
)
|
| 768 |
save_checkpoint_files(state, data_collator, training_args, training_args.output_dir)
|
| 769 |
checkpoints_dir = Path(training_args.output_dir) / "checkpoints" / f"checkpoint-{step}"
|
|
@@ -774,6 +798,34 @@ if __name__ == "__main__":
|
|
| 774 |
Path(training_args.output_dir) / "checkpoints",
|
| 775 |
max_checkpoints=training_args.save_total_limit
|
| 776 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 777 |
|
| 778 |
# update tqdm bar
|
| 779 |
steps.update(1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
import os
|
| 26 |
import shutil
|
| 27 |
import sys
|
| 28 |
+
import tempfile
|
| 29 |
import time
|
| 30 |
from collections import defaultdict
|
| 31 |
from dataclasses import dataclass, field
|
|
|
|
| 61 |
TrainingArguments,
|
| 62 |
is_tensorboard_available,
|
| 63 |
set_seed,
|
| 64 |
+
FlaxRobertaForMaskedLM,
|
| 65 |
+
RobertaForMaskedLM,
|
| 66 |
)
|
| 67 |
|
| 68 |
|
|
|
|
| 379 |
os.remove(path_to_delete)
|
| 380 |
|
| 381 |
|
| 382 |
+
def to_f32(t):
|
| 383 |
+
return jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t)
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
def convert(output_dir, destination_dir="./"):
|
| 387 |
+
shutil.copyfile(Path(output_dir) / "flax_model.msgpack", destination_dir)
|
| 388 |
+
shutil.copyfile(Path(output_dir) / "config.json", destination_dir)
|
| 389 |
+
# Saving extra files from config.json and tokenizer.json files
|
| 390 |
+
tokenizer = AutoTokenizer.from_pretrained(destination_dir)
|
| 391 |
+
tokenizer.save_pretrained(destination_dir)
|
| 392 |
+
|
| 393 |
+
# Temporary saving bfloat16 Flax model into float32
|
| 394 |
+
tmp = tempfile.mkdtemp()
|
| 395 |
+
flax_model = FlaxRobertaForMaskedLM.from_pretrained(destination_dir)
|
| 396 |
+
flax_model.params = to_f32(flax_model.params)
|
| 397 |
+
flax_model.save_pretrained(tmp)
|
| 398 |
+
# Converting float32 Flax to PyTorch
|
| 399 |
+
model = RobertaForMaskedLM.from_pretrained(tmp, from_flax=True)
|
| 400 |
+
model.save_pretrained(destination_dir, save_config=False)
|
| 401 |
+
|
| 402 |
+
|
| 403 |
if __name__ == "__main__":
|
| 404 |
# See all possible arguments in src/transformers/training_args.py
|
| 405 |
# or by passing the --help flag to this script.
|
|
|
|
| 773 |
eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
|
| 774 |
|
| 775 |
# Update progress bar
|
| 776 |
+
steps.desc = f"Step... ({step}/{num_train_steps} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
|
| 777 |
+
last_desc = steps.desc
|
| 778 |
|
| 779 |
if has_tensorboard and jax.process_index() == 0:
|
| 780 |
write_eval_metric(summary_writer, eval_metrics, step)
|
|
|
|
| 787 |
model.save_pretrained(
|
| 788 |
training_args.output_dir,
|
| 789 |
params=params,
|
| 790 |
+
push_to_hub=False,
|
|
|
|
| 791 |
)
|
| 792 |
save_checkpoint_files(state, data_collator, training_args, training_args.output_dir)
|
| 793 |
checkpoints_dir = Path(training_args.output_dir) / "checkpoints" / f"checkpoint-{step}"
|
|
|
|
| 798 |
Path(training_args.output_dir) / "checkpoints",
|
| 799 |
max_checkpoints=training_args.save_total_limit
|
| 800 |
)
|
| 801 |
+
convert(training_args.output_dir, "./")
|
| 802 |
+
model.save_pretrained(
|
| 803 |
+
training_args.output_dir,
|
| 804 |
+
params=params,
|
| 805 |
+
push_to_hub=training_args.push_to_hub,
|
| 806 |
+
commit_message=last_desc,
|
| 807 |
+
)
|
| 808 |
|
| 809 |
# update tqdm bar
|
| 810 |
steps.update(1)
|
| 811 |
+
|
| 812 |
+
if jax.process_index() == 0:
|
| 813 |
+
logger.info(f"Saving checkpoint at {step} steps")
|
| 814 |
+
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
|
| 815 |
+
model.save_pretrained(
|
| 816 |
+
training_args.output_dir,
|
| 817 |
+
params=params,
|
| 818 |
+
push_to_hub=False,
|
| 819 |
+
)
|
| 820 |
+
save_checkpoint_files(state, data_collator, training_args, training_args.output_dir)
|
| 821 |
+
checkpoints_dir = Path(training_args.output_dir) / "checkpoints" / f"checkpoint-{step}"
|
| 822 |
+
checkpoints_dir.mkdir(parents=True, exist_ok=True)
|
| 823 |
+
model.save_pretrained(checkpoints_dir, params=params)
|
| 824 |
+
save_checkpoint_files(state, data_collator, training_args, checkpoints_dir)
|
| 825 |
+
convert(training_args.output_dir, "./")
|
| 826 |
+
model.save_pretrained(
|
| 827 |
+
training_args.output_dir,
|
| 828 |
+
params=params,
|
| 829 |
+
push_to_hub=training_args.push_to_hub,
|
| 830 |
+
commit_message=last_desc,
|
| 831 |
+
)
|