Training time and compute infrastructure for MrBERT-legal
Hi! Thanks for releasing MrBERT-legal — very valuable work.
I’m interested in estimating the compute requirements to reproduce or approximate the domain adaptation phase. Could you share:
Compute & time
Total wall-clock training time (10 epochs, 8B tokens)
Number and type of GPUs/TPUs used
Cluster configuration (single-node vs multi-node)
Training efficiency
Effective batch size (tokens per batch or sequences × length × devices)
Throughput (tokens/sec or samples/sec, if available)
Sequence length used during training
Distributed setup
Training strategy (DDP, FSDP, DeepSpeed/ZeRO, etc.)
Precision (FP16, BF16, etc.)
This would be extremely helpful for estimating compute scaling and benchmarking legal-domain models. Thanks in advance!
Compute & Time
The domain adaptation was run for 80B tokens across 32 H100 GPUs (64 GB each), distributed over 8 nodes with 4 GPUs per node. Total wall-clock time was approximately 9 hours. This is slower than the raw throughput numbers would suggest, as we ran validation set evaluations at regular intervals, a deliberate trade-off that proved valuable for ablation studies.
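As a sanity check, the wall-clock figure can be reconciled with the measured throughput quoted later in this reply. A quick back-of-the-envelope script, using only numbers from this post:

```python
# Back-of-the-envelope check on the wall-clock figure, using only
# numbers quoted in this reply.
total_tokens = 80e9    # domain-adaptation token budget
throughput = 3.33e6    # measured cluster throughput, tokens/sec

compute_hours = total_tokens / throughput / 3600
print(f"pure-compute estimate: {compute_hours:.1f} h")  # ~6.7 h

# The gap to the ~9 h observed wall-clock time is the overhead of the
# periodic validation-set evaluations mentioned above.
```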
Training Efficiency
We used a sequence length of 8,192 tokens with a global batch size of 512 sequences, yielding approximately 4M tokens per gradient update. The MLM probability was set to 0.3 during training and 0.15 during evaluation.
Measured throughput was ~3.33M tokens/sec across the full cluster. Scaling was not perfectly linear relative to the ~450K tokens/sec single-node baseline (about 92% efficiency), but it was sufficient for our purposes.
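Both figures above follow directly from the batch geometry and the quoted throughput numbers; a small sketch (all values taken from this reply):

```python
seq_len = 8192            # training sequence length
global_batch = 512        # sequences per gradient update
tokens_per_update = seq_len * global_batch
print(tokens_per_update)  # 4194304, i.e. ~4.2M tokens per update

# Scaling efficiency of the 8-node cluster vs. the single-node baseline
single_node = 450e3       # tokens/sec on one node (4 GPUs)
measured = 3.33e6         # tokens/sec on all 8 nodes
efficiency = measured / (single_node * 8)
print(f"scaling efficiency: {efficiency:.1%}")  # 92.5%
```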
Distributed Setup
We used the official ModernBERT repository for the distributed training configuration. Our setup closely follows the original context extension config, with one notable change: we replaced the scheduler with a cosine schedule with warmup. Training was run in amp_bf16 precision with the StableAdamW optimizer (lr: 3e-3, weight_decay: 1e-5).
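For readers estimating learning-rate behaviour, the replacement schedule can be sketched as a standalone function. This is a hypothetical re-implementation matching the `t_warmup`, `t_max`, `lr`, and `alpha_f` values in the config, not the Composer scheduler the training actually used:

```python
import math

def lr_at(tokens, lr_peak=3e-3, t_warmup=8e9, t_max=80e9, alpha_f=0.001):
    """LR after `tokens` tokens: linear warmup over t_warmup tokens,
    then cosine decay to alpha_f * lr_peak at t_max tokens."""
    if tokens < t_warmup:
        return lr_peak * tokens / t_warmup
    progress = (tokens - t_warmup) / (t_max - t_warmup)
    cosine = 0.5 * (1 + math.cos(math.pi * progress))
    return lr_peak * (alpha_f + (1 - alpha_f) * cosine)

print(lr_at(4e9))   # mid-warmup, half of peak (~1.5e-3)
print(lr_at(8e9))   # end of warmup, peak lr (~3e-3)
print(lr_at(80e9))  # end of training, alpha_f * peak (~3e-6)
```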
The training config is included below for reference. Note that I no longer have access to the original files (I am no longer officially at BSC), so this was reconstructed from partial documentation; please reach out if anything needs clarification.
model:
  name: flex_bert
  model_config:
    normalization: layernorm
    hidden_act: gelu
    vocab_size: 256128
    init_method: full_megatron
    num_hidden_layers: 22
    hidden_size: 768
    intermediate_size: 1152
    num_attention_heads: 12
    attention_layer: rope
    attention_probs_dropout_prob: 0.0
    attn_out_bias: false
    attn_out_dropout_prob: 0.1
    attn_qkv_bias: false
    bert_layer: prenorm
    embed_dropout_prob: 0.0
    embed_norm: true
    final_norm: true
    skip_first_prenorm: true
    embedding_layer: sans_pos
    loss_function: fa_cross_entropy
    loss_kwargs:
      reduction: mean
    mlp_dropout_prob: 0.0
    mlp_in_bias: false
    mlp_layer: glu
    mlp_out_bias: false
    norm_kwargs:
      eps: 1.0e-05
      bias: false
    head_pred_act: gelu
    activation_function: gelu
    padding: unpadded
    rotary_emb_dim: null
    rotary_emb_base: 160000.0
    rotary_emb_scale_base: null
    rotary_emb_interleaved: false
    local_attn_rotary_emb_base: 10000.0
    local_attn_rotary_emb_dim: null
    allow_embedding_resizing: true
    sliding_window: 128
    global_attn_every_n_layers: 3
    unpad_embeddings: true
    compile_model: true
    masked_prediction: true
  pretrained_model_name: bert-base-uncased
  tokenizer_name: ${tokenizer_name}
  disable_train_metrics: true

data_local: path/to/data
data_remote: null
max_seq_len: 8192
tokenizer_name: /path/to/tokenizer
mlm_probability: 0.3
count_padding_tokens: false
run_name: MrBERT-legal

train_loader:
  name: text
  dataset:
    local: ${data_local}
    remote: ${data_remote}
    split: train
    tokenizer_name: ${tokenizer_name}
    max_seq_len: ${max_seq_len}
    shuffle: true
    mlm_probability: ${mlm_probability}
    streaming: true
  drop_last: true
  num_workers: 6
  sequence_packing: false

eval_loader:
  name: text
  dataset:
    local: ${data_local}
    remote: ${data_remote}
    split: valid
    tokenizer_name: ${tokenizer_name}
    max_seq_len: ${max_seq_len}
    shuffle: false
    mlm_probability: 0.15
    streaming: true
  drop_last: false
  num_workers: 3
  sequence_packing: false

scheduler:
  name: cosine_with_warmup
  t_warmup: 8_000_000_000tok
  alpha_f: 0.001
  t_max: ${max_duration}

optimizer:
  name: decoupled_stableadamw
  lr: 0.003
  betas:
    - 0.9
    - 0.98
  eps: 1.0e-06
  weight_decay: 1.0e-05
  filter_bias_norm_wd: true
  log_grad_norm: true

max_duration: 80_000_000_000tok
eval_interval: 100ba
global_train_batch_size: 512
device_train_microbatch_size: 4
global_eval_batch_size: 512
device_eval_batch_size: 16

seed: 17
precision: amp_bf16
progress_bar: true
log_to_console: true
console_log_interval: 10ba

callbacks:
  speed_monitor:
    window_size: 20
  lr_monitor: {}
  scheduled_gc: {}
  log_grad_norm:
    batch_log_interval: 100
  packing_efficiency:
    log_interval: 5

loggers:
  wandb:
    project: modernbert
    entity: bsc-langtech

save_interval: 100ba
save_num_checkpoints_to_keep: 50
save_folder: results/{run_name}

n_gpus: 32
device_train_batch_size: 16
device_eval_microbatch_size: 4
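A note on the `${...}` placeholders in the config: these are OmegaConf-style interpolations that the ModernBERT training harness resolves against the top-level keys. A minimal stand-in illustrating the semantics (this is illustrative only, not the actual resolver):

```python
import re

# Toy version of ${var} interpolation as used in the config above:
# a string value of the exact form "${key}" is replaced by the value
# of that top-level key.
def resolve(cfg, root=None):
    root = root if root is not None else cfg
    out = {}
    for key, val in cfg.items():
        if isinstance(val, dict):
            out[key] = resolve(val, root)       # recurse into sub-configs
        elif isinstance(val, str):
            m = re.fullmatch(r"\$\{(\w+)\}", val)
            out[key] = root[m.group(1)] if m else val
        else:
            out[key] = val
    return out

cfg = {"max_seq_len": 8192,
       "train_loader": {"dataset": {"max_seq_len": "${max_seq_len}"}}}
print(resolve(cfg)["train_loader"]["dataset"]["max_seq_len"])  # 8192
```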
Nice! Thanks.