MixFlow Training: Alleviating Exposure Bias with Slowed Interpolation Mixture
Hui Li¹ · Jiayue Lyu¹ · Fu-Yun Wang² · Kaihui Cheng¹
Siyu Zhu¹,⁴,⁵ · Jingdong Wang³
¹Fudan University ²The Chinese University of Hong Kong ³Baidu
⁴Shanghai Innovation Institute ⁵Shanghai Academy of AI for Science
This is the official PyTorch implementation of MixFlow, a novel post-training approach for improving diffusion and flow matching models by alleviating the training-testing discrepancy (exposure bias).
Method Overview
We present a novel post-training approach, named MixFlow, for improving diffusion and flow matching models. Our approach is motivated by the Slow Flow phenomenon: at a given sampling timestep, the ground-truth interpolation nearest to the generated noisy data corresponds to a higher-noise timestep (termed the slowed timestep), i.e., the corresponding ground-truth timestep lags behind the sampling timestep. MixFlow leverages the interpolations at these slowed timesteps, named the slowed interpolation mixture, to post-train the prediction network at each training timestep.
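For intuition, the sketch below shows one MixFlow training step on a linear interpolation path. It is a minimal illustration, not the repository code: it assumes the convention x_t = (1 − t)·x1 + t·x0 with x1 the data and t = 1 pure noise (so a slowed timestep is a larger one), and the `model` interface is an assumption; gamma = 0.4 matches the diff in the Implementation section, while the repository's transport module may orient the time axis differently.

```python
import torch

def mixflow_step(model, x1, gamma=0.4):
    """One MixFlow training step (illustrative sketch, not the repository code).

    Assumed convention: x_t = (1 - t) * x1 + t * x0, with x1 the data,
    x0 Gaussian noise, and t = 1 pure noise, so the slowed (noisier)
    timestep mt lies in [t, t + gamma * (1 - t)].
    """
    x0 = torch.randn_like(x1)                      # noise endpoint
    t = torch.rand(x1.shape[0], device=x1.device)  # nominal training timestep

    # Slowed timestep: uniform over [t, t + gamma * (1 - t)]
    mt = t + torch.rand_like(t) * gamma * (1 - t)

    # Slowed interpolation mixture: interpolate at mt instead of t
    mt_ = mt.view(-1, *([1] * (x1.dim() - 1)))     # broadcast over non-batch dims
    xt = (1 - mt_) * x1 + mt_ * x0

    # Linear-path velocity target, d x_t / d t = x0 - x1
    ut = x0 - x1

    # The network is still conditioned on the nominal timestep t
    vt = model(xt, t)
    return ((vt - ut) ** 2).mean()
```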
Implementation
The implementation is simple. For MixFlow-RAE, for example, four lines are added and one line is modified in the file src/stage2/transport/transport.py:
```diff
def sample(self, x1):
    """Sampling x0 & t based on shape of x1 (if needed)
    Args:
        x1 - data point; [batch, *dim]
    """
    # ...
    if dist_options[0] == "uniform":
        t = th.rand((x1.shape[0],)) * (t1 - t0) + t0
+       t = 1 - th.sqrt(t)  # reweight t: equivalently, 1 - t ~ Beta(2, 1)
    # ...
    return t, x0, x1
```
```diff
def training_losses(
    self,
    model,
    x1,
+   gamma=0.4,  # mixture range coefficient
    model_kwargs=None
):
    # ...
    t, x0, x1 = self.sample(x1)
+   t, _, ut = self.path_sampler.plan(t, x0, x1)  # modified (optional): discard xt here, since it is replaced by the slowed interpolation below
+   mt = t + th.rand(*t.size(), device=t.device, dtype=t.dtype) * gamma * (1 - t)  # sample slowed timestep mt uniformly from [t, t + gamma * (1 - t)]
+   _, xt, __ = self.path_sampler.plan(mt, x0, x1)  # compute the slowed interpolation at mt
    model_output = model(xt, t, **model_kwargs)
```
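The rest of `training_losses` is unchanged: as in the SiT-style transport module, the model output is regressed onto the target velocity `ut`. Roughly, the (unchanged) loss has the following form (a sketch, not the exact repository code):

```python
# Sketch of the unchanged velocity-regression loss that follows,
# assuming the SiT-style "velocity" parameterization.
loss = ((model_output - ut) ** 2).flatten(1).mean(dim=1)  # per-sample MSE
```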
This repository includes four folders. Each folder provides the training scripts, inference pipelines, and model weights for one of the following configurations:
- MixFlow + RAE (folder: MixFlow-RAE)
- MixFlow + REPA (folder: MixFlow-REPA)
- MixFlow + SiT (folder: MixFlow-SiT)
- SD3.5-M + MixFlow (folder: MixFlow-SD3.5) (TBD)
Results
ImageNet 256x256
| Model | Params | FID (w/o cfg) | FID (w/ cfg) | Checkpoint |
|---|---|---|---|---|
| MixFlow + SiT-XL | 675M | 7.56 | 1.97 | Download |
| MixFlow + REPA-XL | 675M | 5.00 | 1.22 | Download |
| MixFlow + RAE-XL | 839M | 1.43 | 1.10 | Download |
ImageNet 512x512
| Model | Params | FID (w/o cfg) | FID (w/ cfg) | Checkpoint |
|---|---|---|---|---|
| MixFlow + RAE-XL | 839M | 1.55 | 1.10 | Download |
Getting Started
1. Environment Setup
To set up our environment, please run:
```bash
git clone https://github.com//MixFlow.git
cd MixFlow

# Using conda
conda create -n mixflow python=3.10 -y
conda activate mixflow

# Or using venv from the Python standard library (optional)
python3.10 -m venv .venv
source .venv/bin/activate

# Install uv
pip install uv

# Install PyTorch 2.8.0 with CUDA 12.4
uv pip install torch==2.8.0 torchvision==0.23.0 --index-url https://download.pytorch.org/whl/cu124

# Install other dependencies
uv pip install timm==0.9.16 accelerate==0.23.0 torchdiffeq==0.2.5 wandb
uv pip install "numpy<2" transformers einops omegaconf diffusers requests ftfy regex
```
2. Data and Model Download
Data preparation (ImageNet-1k) and the pretrained model downloads should follow the procedures detailed in the RAE documentation here.
The pretrained model consists of two parts: the encoder/decoder and the DiTDH-XL. After downloading these checkpoints into the MixFlow-RAE/models folder, the training step below uses the pretrained DiTDH-XL as its starting checkpoint.
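Optionally, you can sanity-check a downloaded checkpoint before training. The sketch below assumes a standard PyTorch checkpoint at the path used in the training command; the exact key layout depends on how the RAE checkpoints are saved:

```python
# Optional: inspect the downloaded DiTDH-XL checkpoint (illustrative sketch).
import torch

ckpt_path = "models/DiTs/Dinov2/wReg_base/ImageNet256/DiTDH-XL/stage2_model.pt"
# Pass weights_only=False only if the checkpoint stores non-tensor objects and you trust the source.
state = torch.load(ckpt_path, map_location="cpu")

# The checkpoint may be a raw state_dict or a dict wrapping one (e.g. under "model" or "ema").
print(type(state))
if isinstance(state, dict):
    print(list(state.keys())[:10])
```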
3. Post-train the RAE model with MixFlow
```bash
cd MixFlow-RAE
torchrun --standalone --nnodes=1 --nproc_per_node=N \
  src/train.py \
  --config configs/stage2/training/ImageNet256/DiTDH-XL_DINOv2-B.yaml \
  --data-path <imagenet_train_split> \
  --results-dir results/mixflow \
  --precision fp32 \
  --ckpt models/DiTs/Dinov2/wReg_base/ImageNet256/DiTDH-XL/stage2_model.pt  # load the pretrained model
```
4. Evaluation
For distributed sampling, please run:
```bash
torchrun --standalone --nnodes=1 --nproc_per_node=N \
  src/sample_ddp.py \
  --config <sample_config> \
  --sample-dir samples \
  --precision fp32 \
  --label-sampling equal
```
Note that we utilize an autoguidance scale of 1.5 for both 256 and 512 resolutions. This differs from the original RAE settings, which use 1.42 for 256 resolution and 1.5 for 512 resolution.
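For context, autoguidance guides the main model with a weaker guiding model and extrapolates the prediction; the scale controls the extrapolation strength. Conceptually, at one sampling step (a sketch, not the repository's sampler code):

```python
# Conceptual autoguidance update for the predicted velocity (sketch).
# v_main:  prediction of the post-trained model
# v_guide: prediction of the weaker guiding model
# w:       autoguidance scale (1.5 here; w = 1 recovers the unguided model)
v_guided = v_guide + w * (v_main - v_guide)
```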
After generating 50k samples, evaluate the results using the ADM evaluation suite. For detailed instructions, please refer to the RAE documentation here.
Citation
If you find this work useful, please cite:
```bibtex
@article{mixflow2025,
  title={MixFlow Training: Alleviating Exposure Bias with Slowed Interpolation Mixture},
  author={Hui Li and Jiayue Lyu and Fu-Yun Wang and Kaihui Cheng and Siyu Zhu and Jingdong Wang},
  journal={arXiv preprint arXiv:2512.19311},
  year={2025}
}
```
Acknowledgements
This codebase builds upon the following excellent works:
- SiT - Scalable Interpolant Transformers
- REPA - Representation Alignment
- RAE - Representation Autoencoders
- diffusion-pipe - A pipeline parallel training script for diffusion models.