# GPT-2-Large Latent Reasoning (Coconut)
**NOTE:** I quickly vibecoded this without much oversight for some quick experiments, so I don't vouch for the quality of this model or its training process. You might want to try it for yourself a bit before using it.
GPT-2-large (774M) fine-tuned with Coconut (Chain of Continuous Thought) to perform multi-digit addition using latent reasoning — hidden states fed back as inputs instead of decoded to text.
## Checkpoints

| File | Description | Val Accuracy (teacher-forced) |
|---|---|---|
| `stage0.pt` | Full text chain-of-thought | N/A (text mode) |
| `stage1.pt` | Last CoT step latent | 99.2% |
| `stage2.pt` | Last 2 CoT steps latent | Degraded (see below) |
| `stage3_alllatent.pt` | All CoT steps latent | Degraded (see below) |
## How It Works
The model solves multi-digit addition problems with step-by-step carry propagation:
```
Problem: 347 + 285 =
CoT:     7+5=12 write 2 carry 1 | 4+8+1=13 write 3 carry 1 | 3+2+1=6 write 6
Answer:  632
```
Training follows a four-stage curriculum. Stage 0 trains with the full text CoT; each subsequent stage replaces one more CoT step, starting from the last, with a latent continuous thought vector, i.e. the hidden state at that position is fed back as the next input embedding instead of being decoded to text.
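As a rough illustration (not the repo's actual training code), one latent step with the Hugging Face API might look like the following sketch; `latent_step` is a hypothetical helper:

```python
import torch
from transformers import GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained("gpt2-large")

def latent_step(model, inputs_embeds, attention_mask):
    """One continuous-thought step: feed the last hidden state back
    as the next input embedding instead of decoding it to a token."""
    out = model(inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                output_hidden_states=True)
    # Final layer's hidden state at the last position becomes the next input.
    thought = out.hidden_states[-1][:, -1:, :]  # (batch, 1, 1280)
    inputs_embeds = torch.cat([inputs_embeds, thought], dim=1)
    attention_mask = torch.cat(
        [attention_mask, attention_mask.new_ones(attention_mask.size(0), 1)],
        dim=1,
    )
    return inputs_embeds, attention_mask
```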
### Architecture details

- Base model: `gpt2-large` (36 layers, 1280 hidden dim)
- Special tokens: `<bot>` (begin of thought), `<sep>` (step separator), `<eot>` (end of thought), `<act>` (activation marker)
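The exact serialization used in training isn't shown in this card; purely for illustration, assuming the special tokens delimit the CoT region (the placement of `<act>` isn't clear from the card, so it's omitted), a stage-0 example might be laid out as:

```
347 + 285 = <bot> 7+5=12 write 2 carry 1 <sep> 4+8+1=13 write 3 carry 1 <sep> 3+2+1=6 write 6 <eot> 632
```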
## Key Results
GPT-2-large dramatically improved over GPT-2-small for Coconut training:
- Stage 1: 99.2% teacher-forced accuracy (vs 69% for GPT-2-small)
- The latent reasoning at Stage 1 is nearly perfect
However, accuracy still degrades sharply as more steps become latent (Stages 2-3), consistent with the original Coconut paper's findings that latent reasoning requires extensive training.
## Usage
These are raw PyTorch state dicts for GPT-2-large with a resized token embedding (4 extra special tokens). To load:
```python
import torch
from transformers import GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained("gpt2-large")
model.resize_token_embeddings(50261)  # 50257 base vocab + 4 special tokens
state_dict = torch.load("stage1.pt", map_location="cpu")
model.load_state_dict(state_dict)
```
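You will also need a tokenizer that matches the resized vocabulary. The order in which the four special tokens were added isn't documented here, so treat this as a sketch under that assumption; if the order differs, the token IDs won't line up with the trained embedding rows:

```python
from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2-large")
# Assumed addition order — must match training for IDs 50257-50260
# to map to the right embedding rows.
tokenizer.add_special_tokens(
    {"additional_special_tokens": ["<bot>", "<sep>", "<eot>", "<act>"]}
)
assert len(tokenizer) == 50261
```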
## Training
- Data: 100K synthetic multi-digit addition problems (2-4 digits) with carry-propagation CoT
- Hardware: NVIDIA RTX 4090 (24GB)
- Optimizer: AdamW, lr=1e-5, cosine schedule with warmup (see the sketch after this list)
- Epochs: 3 per stage
- Gradient accumulation: 16 steps
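The warmup length and batch size aren't stated, so the following is only a sketch of the listed hyperparameters, with placeholder step counts and an assumed `dataloader`:

```python
import torch
from transformers import GPT2LMHeadModel, get_cosine_schedule_with_warmup

model = GPT2LMHeadModel.from_pretrained("gpt2-large")

# Placeholder counts — the card doesn't state warmup length or dataset size.
total_steps = 10_000   # total optimizer steps
warmup_steps = 500     # assumed warmup

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
scheduler = get_cosine_schedule_with_warmup(
    optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps
)

accum = 16  # gradient accumulation steps, as listed above
for step, batch in enumerate(dataloader):  # dataloader: assumed training data
    loss = model(**batch).loss / accum
    loss.backward()
    if (step + 1) % accum == 0:
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
```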
Full training code: [github.com/syvb/cocoracle](https://github.com/syvb/cocoracle)
## References
- Hao et al., "Training Large Language Models to Reason in a Continuous Latent Space" (2024). arXiv:2412.06769