# LLaDA-Distilled-16L (Checkpoint 500)
A 16-layer distilled version of LLaDA-8B-Instruct for use as a draft model in speculative decoding.
## Model Details

| Detail | Value |
|---|---|
| Base model | GSAI-ML/LLaDA-8B-Instruct (32 layers, 8.0B params) |
| Student | 16 layers (first 8 + last 8 from teacher), 4.5B params |
| Architecture | Same as teacher: full width (d_model=4096, 32 heads) |
| Param reduction | 43.5% fewer parameters |
| Training | Knowledge distillation (KL + CE on masked positions) |
| Checkpoint | Step 500 (intermediate, training ongoing) |
## How It Was Made

### Layer Subset Strategy
The student model was created by deep-copying the teacher and keeping only the first 8 and last 8 transformer layers (removing the middle 16), as sketched after this list. This preserves:
- Input embeddings and tokenizer (identical to teacher)
- Early layers (low-level features, syntax)
- Final layers (output projection, high-level semantics)
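A minimal sketch of the subsetting step, assuming the teacher's transformer blocks are exposed as an `nn.ModuleList` at `model.transformer.blocks`; that attribute path and the `n_layers` config field are assumptions, so adjust them to the actual LLaDA implementation:

```python
import copy
import torch.nn as nn

def make_student(teacher, n_keep=8):
    # Deep-copy the teacher so its weights are left untouched.
    student = copy.deepcopy(teacher)
    # ASSUMPTION: the 32 transformer blocks live in an nn.ModuleList at
    # `model.transformer.blocks`; adjust to the actual LLaDA layout.
    blocks = student.model.transformer.blocks
    student.model.transformer.blocks = nn.ModuleList(
        list(blocks[:n_keep]) + list(blocks[-n_keep:]))
    # Keep the config consistent (this field name is also an assumption).
    student.config.n_layers = 2 * n_keep
    return student
```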
### Distillation Training
- Objective: Masked diffusion distillation
- Forward diffusion: randomly mask 10-90% of tokens
- Loss = 0.7 × KL(student ∥ teacher) + 0.3 × CE(student, ground-truth), computed on masked positions only (see the sketch after this list)
- Dataset: 70% C4 (English web) + 30% StarCoderData (code)
- Tokens: ~100M tokens target (checkpoint at step 500)
- Hardware: 1ร NVIDIA A100 80GB
- Precision: bf16 mixed precision
- Optimizer: AdamW (lr=2e-5, cosine schedule, 5% warmup)
- Effective batch size: 32 (batch_size=4 × grad_accum=8)
- Sequence length: 512
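A hedged sketch of how the objective above can be computed; the function name and tensor layout are illustrative, not the actual training code:

```python
import torch.nn.functional as F

def distill_loss(student_logits, teacher_logits, labels, mask,
                 w_kl=0.7, w_ce=0.3):
    """Sketch of the loss above. Logits: (B, T, V); labels: (B, T);
    mask: (B, T) bool, True at positions masked by forward diffusion."""
    s_logp = F.log_softmax(student_logits.float(), dim=-1)
    t_logp = F.log_softmax(teacher_logits.float(), dim=-1)
    # KL(student || teacher), summed over the vocabulary per position.
    kl = (s_logp.exp() * (s_logp - t_logp)).sum(-1)
    # Cross-entropy against ground-truth tokens, per position.
    ce = F.cross_entropy(student_logits.float().transpose(1, 2),
                         labels, reduction="none")
    # Average both terms over masked positions only.
    denom = mask.sum().clamp(min=1)
    return (w_kl * (kl * mask).sum() + w_ce * (ce * mask).sum()) / denom
```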
## Intended Use
This model is designed as a draft model for speculative decoding with LLaDA-8B as the target model. It is NOT intended for standalone generation.
In speculative decoding:
- This draft model proposes token unmaskings cheaply (16 layers vs 32)
- The full LLaDA-8B teacher verifies proposals in one forward pass
- Accepted tokens are kept; rejected ones are resampled from the teacher
- Result: same output quality as the teacher at lower wall-clock cost (see the illustrative sketch below)
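An illustrative draft-and-verify step (not the repository's actual implementation), using the standard speculative acceptance test min(1, p_teacher/p_draft) at masked positions and, as described above, resampling rejected positions from the teacher:

```python
import torch

@torch.no_grad()
def verify_step(draft_model, teacher, x, masked):
    """One illustrative draft-and-verify step for a masked diffusion LM.

    x: (B, T) token ids, with the mask token at still-masked positions.
    masked: (B, T) bool, True where tokens are still masked.
    """
    # 1. Draft model (16 layers) cheaply proposes tokens at masked positions.
    d_probs = draft_model(x).logits.softmax(-1)
    proposal = torch.distributions.Categorical(probs=d_probs).sample()
    # 2. Teacher (32 layers) scores the same input in one forward pass.
    t_probs = teacher(x).logits.softmax(-1)
    # 3. Speculative acceptance: keep a proposed token y with
    #    probability min(1, p_teacher(y) / p_draft(y)).
    p_t = t_probs.gather(-1, proposal.unsqueeze(-1)).squeeze(-1)
    p_d = d_probs.gather(-1, proposal.unsqueeze(-1)).squeeze(-1)
    accept = masked & (torch.rand_like(p_t) < (p_t / p_d).clamp(max=1.0))
    # 4. Rejected positions are resampled from the teacher (simplified;
    #    exact speculative sampling uses the residual distribution).
    resampled = torch.distributions.Categorical(probs=t_probs).sample()
    out = torch.where(accept, proposal, torch.where(masked, resampled, x))
    return out, accept
```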
## How to Use
```python
import sys

from transformers import AutoConfig

# Load via the local Fast-dLLM LLaDA classes (recommended).
sys.path.insert(0, "path/to/Fast-dllm/llada")
from model.modeling_llada import LLaDAModelLM
from model.configuration_llada import LLaDAConfig

# Fetch the checkpoint's config from the Hub, then rebuild it as a
# LLaDAConfig, dropping HF bookkeeping keys the class does not accept.
hf_config = AutoConfig.from_pretrained(
    "jaygala223/llada-distilled-16L-checkpoint-500", trust_remote_code=True)
config = LLaDAConfig(**{k: v for k, v in hf_config.to_dict().items()
                        if k not in ("model_type", "transformers_version", "auto_map")})
config.use_cache = False  # diffusion decoding does not use the KV cache

draft_model = LLaDAModelLM.from_pretrained(
    "jaygala223/llada-distilled-16L-checkpoint-500",
    trust_remote_code=True,
    torch_dtype="bfloat16",
    config=config,
)
draft_model.eval()
```
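The student reuses the teacher's tokenizer unchanged (see Limitations), so it can be loaded straight from the base repository:

```python
from transformers import AutoTokenizer

# Tokenizer is identical to the teacher's (128K vocab, mask_id=126336),
# so load it from the base model repo.
tokenizer = AutoTokenizer.from_pretrained(
    "GSAI-ML/LLaDA-8B-Instruct", trust_remote_code=True)
```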
## Limitations
- Intermediate checkpoint: training was not complete at step 500, and further training is expected to improve quality.
- Not standalone: designed as a speculative-decoding draft model, not for direct text generation.
- Same tokenizer as LLaDA-8B-Instruct (128K vocab, mask_id=126336).