LLaDA-Distilled-16L (Checkpoint 500)

A 16-layer distilled version of LLaDA-8B-Instruct for use as a draft model in speculative decoding.

Model Details

  • Base model: GSAI-ML/LLaDA-8B-Instruct (32 layers, 8.0B params)
  • Student: 16 layers (first 8 + last 8 from the teacher), 4.5B params
  • Architecture: same as the teacher; full width (d_model=4096, 32 heads)
  • Parameter reduction: 43.5% fewer parameters than the teacher
  • Training: knowledge distillation (KL + CE on masked positions)
  • Checkpoint: step 500 (intermediate; training ongoing)

How It Was Made

Layer Subset Strategy

The student model was created by deep-copying the teacher and keeping only the first 8 and last 8 transformer layers, removing the middle 16 (a sketch follows the list below). This preserves:

  • Input embeddings and tokenizer (identical to teacher)
  • Early layers (low-level features, syntax)
  • Final layers (output projection, high-level semantics)
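
A minimal sketch of this surgery in PyTorch. The attribute path (model.transformer.blocks) and the config field (n_layers) are assumptions about the LLaDA implementation, not confirmed names:

import copy
import torch.nn as nn

def make_student(teacher, keep_front=8, keep_back=8):
    # Deep-copy so embeddings, final norm, and output head carry over unchanged.
    student = copy.deepcopy(teacher)
    blocks = student.model.transformer.blocks  # attribute path is an assumption
    kept = nn.ModuleList(list(blocks[:keep_front]) + list(blocks[-keep_back:]))
    student.model.transformer.blocks = kept
    student.config.n_layers = len(kept)  # config field name is an assumption
    return student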

Distillation Training

  • Objective: Masked diffusion distillation
    • Forward diffusion: randomly mask 10-90% of tokens
    • Loss = 0.7 × KL(student ∥ teacher) + 0.3 × CE(student, ground truth), computed on masked positions only (see the sketch after this list)
  • Dataset: 70% C4 (English web) + 30% StarCoderData (code)
  • Token budget: ~100M tokens (this checkpoint is step 500)
  • Hardware: 1× NVIDIA A100 80GB
  • Precision: bf16 mixed precision
  • Optimizer: AdamW (lr=2e-5, cosine schedule, 5% warmup)
  • Effective batch size: 32 (batch_size=4 × grad_accum=8)
  • Sequence length: 512
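
A minimal sketch of the forward-diffusion masking and the combined loss described above. Tensor shapes and the per-sequence mask-ratio sampling are illustrative assumptions; mask_id matches the LLaDA tokenizer:

import torch
import torch.nn.functional as F

MASK_ID = 126336  # LLaDA's mask token id

def mask_batch(input_ids, lo=0.1, hi=0.9):
    # Forward diffusion: mask a random 10-90% of tokens in each sequence.
    ratio = torch.empty(input_ids.size(0), 1).uniform_(lo, hi)
    mask = torch.rand(input_ids.shape) < ratio
    return input_ids.masked_fill(mask, MASK_ID), mask

def distill_loss(student_logits, teacher_logits, labels, mask,
                 kl_w=0.7, ce_w=0.3):
    # Both loss terms are computed on masked positions only.
    s = student_logits[mask]              # (n_masked, vocab)
    t = teacher_logits[mask]
    log_ps = F.log_softmax(s, dim=-1)
    log_pt = F.log_softmax(t, dim=-1)
    # KL(student || teacher), written out to keep the direction explicit.
    kl = (log_ps.exp() * (log_ps - log_pt)).sum(-1).mean()
    ce = F.cross_entropy(s, labels[mask])
    return kl_w * kl + ce_w * ce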

Intended Use

This model is designed as a draft model for speculative decoding with LLaDA-8B as the target model. It is NOT intended for standalone generation.

In speculative decoding (a toy sketch follows these steps):

  1. This draft model cheaply proposes tokens for masked positions (16 layers vs. 32)
  2. The full LLaDA-8B teacher verifies proposals in one forward pass
  3. Accepted tokens are kept; rejected ones are resampled from the teacher
  4. Result: same output quality as the teacher, faster wall-clock time
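
A toy sketch of one draft-and-verify round, assuming both models return HF-style outputs with a .logits field. The threshold acceptance below is a simplification for illustration; a lossless scheme accepts each proposal with probability min(1, p_teacher/p_draft) so the output distribution provably matches the teacher's:

import torch

@torch.no_grad()
def draft_verify_step(draft_model, teacher, x, mask_id=126336, tau=0.9):
    masked = (x == mask_id)

    # 1) Draft proposes fills for all masked positions in one cheap pass.
    proposal = draft_model(x).logits.argmax(-1)
    x_prop = torch.where(masked, proposal, x)

    # 2) Teacher scores the proposed sequence in a single forward pass.
    t_probs = teacher(x_prop).logits.softmax(-1)
    conf = t_probs.gather(-1, proposal.unsqueeze(-1)).squeeze(-1)

    # 3) Keep proposals the teacher is confident in; resample the rest.
    accept = masked & (conf >= tau)
    reject = masked & ~accept
    x = torch.where(accept, proposal, x)
    x[reject] = t_probs.argmax(-1)[reject]  # greedy resample as a stand-in
    return x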

How to Use

from transformers import AutoConfig
import sys

# Load via the local Fast-dLLM LLaDA classes (recommended)
sys.path.insert(0, "path/to/Fast-dllm/llada")
from model.modeling_llada import LLaDAModelLM
from model.configuration_llada import LLaDAConfig

# Fetch the checkpoint's HF config, then rebuild it as a local LLaDAConfig,
# dropping the HF bookkeeping keys that LLaDAConfig does not accept.
hf_config = AutoConfig.from_pretrained(
    "jaygala223/llada-distilled-16L-checkpoint-500", trust_remote_code=True)
config = LLaDAConfig(**{k: v for k, v in hf_config.to_dict().items()
                        if k not in ("model_type", "transformers_version", "auto_map")})
config.use_cache = False  # diffusion decoding re-runs full forward passes; no KV cache

draft_model = LLaDAModelLM.from_pretrained(
    "jaygala223/llada-distilled-16L-checkpoint-500",
    trust_remote_code=True,
    torch_dtype="bfloat16",
    config=config,
)
draft_model.eval()
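
A quick sanity check on the loaded model; the .logits field and the fully masked dummy prompt are illustrative assumptions:

import torch

MASK_ID = 126336
x = torch.full((1, 32), MASK_ID, dtype=torch.long)  # fully masked dummy prompt
with torch.no_grad():
    out = draft_model(x)
print(out.logits.shape)  # expect (1, 32, vocab_size)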

Limitations

  • Intermediate checkpoint: training was not complete at step 500; further training should improve quality.
  • Not standalone: designed as a speculative-decoding draft, not for direct text generation.
  • Same tokenizer as LLaDA-8B-Instruct (128K vocab, mask_id=126336).
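
Because the tokenizer is shared, it can be loaded straight from the teacher repo:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained(
    "GSAI-ML/LLaDA-8B-Instruct", trust_remote_code=True)
MASK_ID = 126336  # mask token id stated above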