
MixFlow Training: Alleviating Exposure Bias with Slowed Interpolation Mixture

Hui Li1 · Jiayue Lyu1 · Fu-Yun Wang2 · Kaihui Cheng1

Siyu Zhu1,4,5 · Jingdong Wang3

1Fudan University   2The Chinese University of Hong Kong   3Baidu  

4Shanghai Innovation Institute   5Shanghai Academy of AI for Science

🌐 Project page · 🤗 Models · 📄 Paper


This is the official PyTorch implementation of MixFlow, a novel post-training approach for improving diffusion and flow matching models by alleviating the training-testing discrepancy (exposure bias).

Method Overview

We present MixFlow, a post-training approach for improving the generation quality of diffusion and flow matching models. Our approach is motivated by the Slow Flow phenomenon: at a given sampling timestep, the ground-truth interpolation nearest to the generated noisy data is observed to correspond to a higher-noise timestep (termed the slowed timestep); that is, the corresponding ground-truth timestep lags behind the sampling timestep. MixFlow leverages the interpolations at these slowed timesteps, called the slowed interpolation mixture, to post-train the prediction network at each training timestep.
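
For concreteness, here is a minimal, self-contained sketch of the idea for the linear path x_t = t * x1 + (1 - t) * x0 (t = 1 at data, t = 0 at noise, as in SiT). The function name and standalone form are ours for exposition; the actual repository changes appear in the next section.

import torch

def mixflow_training_pair(x1, gamma=0.4):
    """Build one MixFlow training example for the linear path
    x_t = t * x1 + (1 - t) * x0 (t = 1 is data, t = 0 is noise)."""
    x0 = torch.randn_like(x1)                      # noise endpoint
    t = torch.rand(x1.shape[0], device=x1.device)  # sampling timestep
    # slowed timestep: a noisier (smaller) timestep, mt ~ U[(1 - gamma) t, t]
    mt = t - gamma * t * torch.rand_like(t)
    mt_ = mt.view(-1, *([1] * (x1.dim() - 1)))     # broadcast over data dims
    x_mt = mt_ * x1 + (1.0 - mt_) * x0             # slowed interpolation
    u_t = x1 - x0                                  # velocity target at t
    return x_mt, t, u_t

# The prediction network is conditioned on t but fed the noisier input x_mt:
#   loss = ((model(x_mt, t) - u_t) ** 2).mean()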

Implementation

The implementation is lightweight: for MixFlow-RAE, for example, only 4 lines are added and 1 line is modified in the file src/stage2/transport/transport.py:

def sample(self, x1):
        """Sampling x0 & t based on shape of x1 (if needed)
          Args:
            x1 - data point; [batch, *dim]
        """

        # ...
        if dist_options[0] == "uniform":
            t = th.rand((x1.shape[0],)) * (t1 - t0) + t0
+           t = 1 - th.sqrt(t) # t ~ Beta(1,2), i.e. 1 - t ~ Beta(2,1): bias toward high-noise timesteps
        # ...
        return t, x0, x1

def training_losses(
        self,
        model,
        x1,
+       gamma=0.4,  # mixture range coefficient
        model_kwargs=None
    ):
        # ...

        t, x0, x1 = self.sample(x1)
+       t, _, ut = self.path_sampler.plan(t, x0, x1) # modified (optional): discard xt here, since it is recomputed below
+       mt = t - gamma * t * th.rand(*t.size(), device=t.device, dtype=t.dtype) # sample slowed timestep mt from U[(1-gamma)t, t]
+       _, xt, _ = self.path_sampler.plan(mt, x0, x1) # compute the slowed interpolation
        model_output = model(xt, t, **model_kwargs)
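
Only the network input xt moves to the slowed timestep mt; the conditioning timestep t and the velocity target ut stay at t, so the model is trained to make the correct prediction at t even when its input lags behind, mimicking the drift that accumulates during sampling. A hypothetical call site might then look as follows (the create_transport factory and the loss-dictionary return value are assumptions based on the SiT-style API; check the repository's training script for the exact usage):

# hypothetical usage sketch, assuming the SiT-style transport API
transport = create_transport("Linear", "velocity")
loss_dict = transport.training_losses(model, x1, gamma=0.4, model_kwargs=dict(y=y))
loss = loss_dict["loss"].mean()
loss.backward()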

This repository includes four folders, each providing the training scripts, inference pipelines, and model weights for one of the following configurations:

  • MixFlow + RAE (Folder: MixFlow-RAE)
  • MixFlow + REPA (Folder: MixFlow-REPA)
  • MixFlow + SiT (Folder: MixFlow-SiT)
  • MixFlow + SD3.5-M (Folder: MixFlow-SD3.5) (TBD)

Results

ImageNet 256x256

| Model | Params | FID (w/o cfg) | FID (w/ cfg) | Checkpoint |
|---|---|---|---|---|
| MixFlow + SiT-XL | 675M | 7.56 | 1.97 | Download |
| MixFlow + REPA-XL | 675M | 5.00 | 1.22 | Download |
| MixFlow + RAE-XL | 839M | 1.43 | 1.10 | Download |

ImageNet 512x512

| Model | Params | FID (w/o cfg) | FID (w/ cfg) | Checkpoint |
|---|---|---|---|---|
| MixFlow + RAE-XL | 839M | 1.55 | 1.10 | Download |

Getting Started

1. Environment Setup

To set up our environment, please run:

git clone https://github.com//MixFlow.git
cd MixFlow

# Using conda
conda create -n mixflow python=3.10 -y
conda activate mixflow

# Or using venv from the Python standard library (optional)
python3.10 -m venv .venv
source .venv/bin/activate

# Install uv
pip install uv

# Install PyTorch 2.8.0 with CUDA 12.4
uv pip install torch==2.8.0 torchvision==0.23.0 --index-url https://download.pytorch.org/whl/cu124

# Install other dependencies
uv pip install timm==0.9.16 accelerate==0.23.0 torchdiffeq==0.2.5 wandb 
uv pip install "numpy<2" transformers einops omegaconf diffusers requests ftfy regex
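
# Optional sanity check (our addition, not part of the original setup):
# confirms the pinned PyTorch build is importable and sees the GPU
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"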

2. Data and Model Download

Data preparation (ImageNet-1k) and pretrained model downloads should follow the procedures detailed in the RAE documentation here.

The pretrained model comprises two parts: the encoder/decoder pair and the DiTDH-XL. After downloading these checkpoints into the MixFlow-RAE/models folder, the training step below uses the pretrained DiTDH-XL as its starting checkpoint.
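
Based on the checkpoint path used in the training command below, the expected layout is roughly as follows (illustrative only; place the encoder/decoder checkpoints wherever the RAE documentation specifies):

MixFlow-RAE/
└── models/
    └── DiTs/Dinov2/wReg_base/ImageNet256/DiTDH-XL/
        └── stage2_model.pt   # pretrained DiTDH-XL starting checkpoint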

3. Post-train the RAE model with MixFlow

cd MixFlow-RAE
torchrun --standalone --nnodes=1 --nproc_per_node=N \
  src/train.py \
  --config configs/stage2/training/ImageNet256/DiTDH-XL_DINOv2-B.yaml \
  --data-path <imagenet_train_split> \
  --results-dir results/mixflow \
  --precision fp32 \
  --ckpt models/DiTs/Dinov2/wReg_base/ImageNet256/DiTDH-XL/stage2_model.pt # load the pretrained model

4. Evaluation

For distributed sampling, please run:

torchrun --standalone --nnodes=1 --nproc_per_node=N \
  src/sample_ddp.py \
  --config <sample_config> \
  --sample-dir samples \
  --precision fp32 \
  --label-sampling equal

Note that we use an autoguidance scale of 1.5 at both 256 and 512 resolution. This differs from the original RAE settings, which use 1.42 at 256 resolution and 1.5 at 512 resolution.

After generating 50k samples, evaluate the results using the ADM evaluation suite. For detailed instructions, please refer to the RAE documentation here.
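
As a rough sketch of that step, the ADM suite's evaluator compares a packed .npz of the 50k generated samples against a reference batch (the script path and file names below are illustrative; see the RAE documentation for the exact reference files):

# illustrative ADM evaluation call; paths and file names are placeholders
python evaluations/evaluator.py VIRTUAL_imagenet256_labeled.npz samples/samples_50k.npz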

Citation

If you find this work useful, please cite:

@article{mixflow2025,
  title={MixFlow Training: Alleviating Exposure Bias with Slowed Interpolation Mixture},
  author={Hui Li and Jiayue Lyu and Fu-Yun Wang and Kaihui Cheng and Siyu Zhu and Jingdong Wang},
  journal={arXiv preprint arXiv:2512.19311},
  year={2025}
}

Acknowledgements

This codebase builds upon the following excellent works:

  • SiT - Scalable Interpolant Transformers
  • REPA - Representation Alignment
  • RAE - Representation Autoencoders
  • diffusion-pipe - A pipeline-parallel training script for diffusion models

