GPT-2-Large Latent Reasoning (Coconut)

NOTE: I vibecoded this quickly and without much oversight for some experiments; I don't vouch for the quality of this model or its training process. Try it out yourself before relying on it.

GPT-2-large (774M) fine-tuned with Coconut (Chain of Continuous Thought) to perform multi-digit addition using latent reasoning — hidden states fed back as inputs instead of decoded to text.

Checkpoints

File                 Description                  Val Accuracy (teacher-forced)
stage0.pt            Full text chain-of-thought   N/A (text mode)
stage1.pt            Last CoT step latent         99.2%
stage2.pt            Last 2 CoT steps latent      Degraded (see below)
stage3_alllatent.pt  All CoT steps latent         Degraded (see below)

How It Works

The model solves multi-digit addition problems with step-by-step carry propagation:

Problem: 347 + 285 =
CoT: 7+5=12 write 2 carry 1 | 4+8+1=13 write 3 carry 1 | 3+2+1=6 write 6
Answer: 632

Training follows a 4-stage curriculum. Stage 0 trains with full text CoT. Each subsequent stage replaces more CoT steps with latent continuous thought vectors (hidden states fed back as inputs instead of decoded to text).
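
Concretely, a latent step takes the model's last hidden state at the current position and appends it to the input sequence in place of a sampled token. A minimal sketch of that feedback loop, using the base model and illustrative names (the repo's actual implementation may differ):

import torch
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2-large")
model = GPT2LMHeadModel.from_pretrained("gpt2-large")

ids = tokenizer("347 + 285 =", return_tensors="pt").input_ids
embeds = model.transformer.wte(ids)  # (1, seq_len, 1280) token embeddings

with torch.no_grad():
    for _ in range(2):  # two latent thought steps
        out = model(inputs_embeds=embeds, output_hidden_states=True)
        latent = out.hidden_states[-1][:, -1:, :]    # final hidden state at the last position
        embeds = torch.cat([embeds, latent], dim=1)  # fed back as the next input "token"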

Architecture details

  • Base model: gpt2-large (36 layers, 1280 hidden dim)
  • Special tokens: <bot> (begin of thought), <sep> (step separator), <eot> (end of thought), <act> (activation marker); a registration sketch follows this list
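
A sketch of how these tokens could be registered with the tokenizer (the training repo may register them differently):

from transformers import GPT2TokenizerFast

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2-large")
tokenizer.add_special_tokens(
    {"additional_special_tokens": ["<bot>", "<sep>", "<eot>", "<act>"]}
)
assert len(tokenizer) == 50261  # 50257 base + 4, matching the resized embedding in Usage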

Key Results

GPT-2-large improves dramatically over GPT-2-small under Coconut training:

  • Stage 1: 99.2% teacher-forced accuracy (vs 69% for GPT-2-small)
  • The latent reasoning at Stage 1 is nearly perfect

However, accuracy still degrades sharply as more steps become latent (Stages 2-3), consistent with the original Coconut paper's findings that latent reasoning requires extensive training.

Usage

These are raw PyTorch state dicts for GPT-2-large with a resized token embedding (4 extra special tokens). To load:

import torch
from transformers import GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained("gpt2-large")
model.resize_token_embeddings(50261)  # 50257 base tokens + 4 special tokens

# Resizing must happen before loading, or load_state_dict fails with a size mismatch.
state_dict = torch.load("stage1.pt", map_location="cpu")
model.load_state_dict(state_dict)
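
For the text-mode stage0.pt checkpoint, ordinary generation should work once the tokenizer has the four special tokens (see the registration sketch above). A hedged example; the exact prompt format the checkpoint expects is a guess based on the "How It Works" section:

prompt = "Problem: 347 + 285 =\nCoT:"  # guessed prompt format; check the training repo
ids = tokenizer(prompt, return_tensors="pt").input_ids
out = model.generate(ids, max_new_tokens=64, do_sample=False)
print(tokenizer.decode(out[0]))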

Training

  • Data: 100K synthetic multi-digit addition problems (2-4 digits) with carry-propagation CoT; a generator sketch follows this list
  • Hardware: NVIDIA RTX 4090 (24GB)
  • Optimizer: AdamW, lr=1e-5, cosine schedule with warmup
  • Epochs: 3 per stage
  • Gradient accumulation: 16 steps
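
A sketch of a generator reproducing the CoT format from "How It Works" (a hedged reconstruction; the repo's actual data script may differ, e.g. in how a final carry digit is written):

import random

def make_example(n_digits: int) -> str:
    a = random.randint(10 ** (n_digits - 1), 10 ** n_digits - 1)
    b = random.randint(10 ** (n_digits - 1), 10 ** n_digits - 1)
    da, db = str(a)[::-1], str(b)[::-1]  # digits, least significant first
    steps, carry = [], 0
    for i in range(max(len(da), len(db))):
        x = int(da[i]) if i < len(da) else 0
        y = int(db[i]) if i < len(db) else 0
        terms = [str(x), str(y)] + ([str(carry)] if carry else [])
        s = x + y + carry
        digit, carry = s % 10, s // 10
        step = f"{'+'.join(terms)}={s} write {digit}"
        if carry:
            step += f" carry {carry}"
        steps.append(step)
    if carry:  # assumed handling of a leftover top carry
        steps.append(f"write {carry}")
    return f"Problem: {a} + {b} =\nCoT: {' | '.join(steps)}\nAnswer: {a + b}"

print(make_example(3))  # e.g. the 347 + 285 example above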

Full training code: github.com/syvb/cocoracle

References

  • Hao et al., "Training Large Language Models to Reason in a Continuous Latent Space" (2024). arXiv:2412.06769