Instructions to use dynatrace-oss/chunkable-mamba2 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use dynatrace-oss/chunkable-mamba2 with Transformers:
# Load model directly from transformers import AutoModel model = AutoModel.from_pretrained("dynatrace-oss/chunkable-mamba2", dtype="auto") - Notebooks
- Google Colab
- Kaggle
feat: chunkable mamba2 model code
Browse files- .gitattributes +35 -0
- README.md +63 -0
- chunkable_ssd_combined.py +511 -0
- configuration_chunkable_mamba2.py +12 -0
- modeling_chunkable_mamba2.py +131 -0
.gitattributes
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: apache-2.0
|
| 3 |
+
library_name: transformers
|
| 4 |
+
tags:
|
| 5 |
+
- transformers
|
| 6 |
+
- mamba2
|
| 7 |
+
- vertical-chunking
|
| 8 |
+
---
|
| 9 |
+
|
| 10 |
+
# chunkable-mamba2
|
| 11 |
+
|
| 12 |
+
Custom [Mamba2](https://arxiv.org/abs/2405.21060) model and configuration classes for 🤗 Transformers that add support for **vertically chunked inference**, which processes input sequences in fixed-size vertical chunks through all model layers with constant memory usage, regardless of sequence length.
|
| 13 |
+
|
| 14 |
+
## What this repository provides
|
| 15 |
+
|
| 16 |
+
- **`ChunkableMamba2Config`:** extends `Mamba2Config` with a `use_mem_eff_path` option for the memory-efficient CUDA kernel path.
|
| 17 |
+
- **`ChunkableMamba2Model`:** extends `Mamba2Model` with a chunkable mixer and cache that correctly propagate the recurrent states across vertical chunks (simultaneous `seq_idx` + `initial_states` support).
|
| 18 |
+
- **`chunkable_mamba_split_conv1d_scan_combined`:** modified `mamba_split_conv1d_scan_combined` kernel wrapper that passes cache parameters through the SSD scan so that conv and SSM states are properly initialized and exported during chunked inference.
|
| 19 |
+
|
| 20 |
+
## Usage
|
| 21 |
+
|
| 22 |
+
This repository is designed to be referenced directly from Hugging Face model configs via `auto_map`, so that models can be loaded with `trust_remote_code=True` without any local installation:
|
| 23 |
+
|
| 24 |
+
```json
|
| 25 |
+
"auto_map": {
|
| 26 |
+
"AutoConfig": "dynatrace-oss/chunkable-mamba2--configuration_chunkable_mamba2.ChunkableMamba2Config",
|
| 27 |
+
"AutoModel": "dynatrace-oss/chunkable-mamba2--modeling_chunkable_mamba2.ChunkableMamba2Model"
|
| 28 |
+
}
|
| 29 |
+
```
|
| 30 |
+
|
| 31 |
+
## Models
|
| 32 |
+
|
| 33 |
+
This code was created for the following embedding models:
|
| 34 |
+
|
| 35 |
+
- [dynatrace-oss/llama-embed-mamba2-7b](https://huggingface.co/dynatrace-oss/llama-embed-mamba2-7b)
|
| 36 |
+
- [dynatrace-oss/llama-embed-mamba2-1.3b](https://huggingface.co/dynatrace-oss/llama-embed-mamba2-1.3b)
|
| 37 |
+
|
| 38 |
+
## Requirements
|
| 39 |
+
|
| 40 |
+
> [!IMPORTANT]
|
| 41 |
+
> Requires `transformers>=5.5.0` due to a breaking change to the cache of Mamba2 introduced in `v5.5.0` ([transformers#44950](https://github.com/huggingface/transformers/pull/44950)).
|
| 42 |
+
|
| 43 |
+
```bash
|
| 44 |
+
pip install transformers kernels einops
|
| 45 |
+
```
|
| 46 |
+
|
| 47 |
+
## Open Source Integration Roadmap
|
| 48 |
+
|
| 49 |
+
Our goal is to integrate all necessary changes to simplify the adoption of vertically chunked inference for other models:
|
| 50 |
+
|
| 51 |
+
> [!Note]
|
| 52 |
+
> ⚪ Planned | 🟡 In Progress | 🟢 Integrated
|
| 53 |
+
|
| 54 |
+
- ⚪ **causal-conv1d:** Enable simultaneous `seq_idx` + `initial_states` (required for recurrent processing of chunks with left padding)
|
| 55 |
+
- ⚪ **mamba-ssm:** Use `seq_idx` + `initial_states` in `mamba_split_conv1d_scan_combined` and export final states
|
| 56 |
+
- ⚪ **kernels-community:** Propagate changes in `causal-conv1d` and `mamba-ssm` to their kernel hub equivalents in the `kernels-community` repositories
|
| 57 |
+
- ⚪ **transformers:** Use updated `mamba_split_conv1d_scan_combined` with cache params during inference (currently only used during training, not configurable, problems with left padding)
|
| 58 |
+
|
| 59 |
+
*This list will be updated as integration progresses.*
|
| 60 |
+
|
| 61 |
+
## License
|
| 62 |
+
|
| 63 |
+
Apache-2.0
|
chunkable_ssd_combined.py
ADDED
|
@@ -0,0 +1,511 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
| 2 |
+
|
| 3 |
+
"""We want triton==2.1.0 or 2.2.0 for this"""
|
| 4 |
+
|
| 5 |
+
from packaging import version
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
|
| 10 |
+
import triton
|
| 11 |
+
|
| 12 |
+
from einops import rearrange
|
| 13 |
+
|
| 14 |
+
from transformers.integrations.hub_kernels import get_kernel
|
| 15 |
+
|
| 16 |
+
# Fixed revisions because kernels after 2026-04-14 do not expose the functions we need anymore.
|
| 17 |
+
causal_conv1d = get_kernel("kernels-community/causal-conv1d", revision="dc7072f0e9d799b247a2517a909ebb209d50bea0")
|
| 18 |
+
mamba_ssm = get_kernel("kernels-community/mamba-ssm", revision="00b2ecd499379f9bcf969b6796e53bc867f4ad38")
|
| 19 |
+
|
| 20 |
+
causal_conv1d_fwd_function = causal_conv1d.cpp_functions.causal_conv1d_fwd_function
|
| 21 |
+
causal_conv1d_bwd_function = causal_conv1d.cpp_functions.causal_conv1d_bwd_function
|
| 22 |
+
|
| 23 |
+
custom_fwd = mamba_ssm.utils.torch.custom_fwd
|
| 24 |
+
custom_bwd = mamba_ssm.utils.torch.custom_bwd
|
| 25 |
+
_layer_norm_fwd = mamba_ssm.ops.triton.layernorm_gated._layer_norm_fwd
|
| 26 |
+
_layer_norm_bwd = mamba_ssm.ops.triton.layernorm_gated._layer_norm_bwd
|
| 27 |
+
_swiglu_fwd = mamba_ssm.ops.triton.k_activations._swiglu_fwd
|
| 28 |
+
_swiglu_bwd = mamba_ssm.ops.triton.k_activations._swiglu_bwd
|
| 29 |
+
rearrange_and_update_stride = mamba_ssm.ops.triton.ssd_combined.rearrange_and_update_stride
|
| 30 |
+
_mamba_chunk_scan_combined_fwd = mamba_ssm.ops.triton.ssd_combined._mamba_chunk_scan_combined_fwd
|
| 31 |
+
_mamba_chunk_scan_combined_bwd = mamba_ssm.ops.triton.ssd_combined._mamba_chunk_scan_combined_bwd
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
TRITON_22 = version.parse(triton.__version__) >= version.parse("2.2.0")
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class ChunkableMambaSplitConv1dScanCombinedFn(torch.autograd.Function):
|
| 38 |
+
@staticmethod
|
| 39 |
+
@custom_fwd
|
| 40 |
+
def forward(
|
| 41 |
+
ctx,
|
| 42 |
+
zxbcdt,
|
| 43 |
+
conv1d_weight,
|
| 44 |
+
conv1d_bias,
|
| 45 |
+
dt_bias,
|
| 46 |
+
A,
|
| 47 |
+
D,
|
| 48 |
+
chunk_size,
|
| 49 |
+
initial_conv_states=None,
|
| 50 |
+
initial_ssm_states=None,
|
| 51 |
+
seq_idx=None,
|
| 52 |
+
dt_limit=(0.0, float("inf")),
|
| 53 |
+
return_final_states=False,
|
| 54 |
+
activation="silu",
|
| 55 |
+
rmsnorm_weight=None,
|
| 56 |
+
rmsnorm_eps=1e-6,
|
| 57 |
+
outproj_weight=None,
|
| 58 |
+
outproj_bias=None,
|
| 59 |
+
headdim=None,
|
| 60 |
+
ngroups=1,
|
| 61 |
+
norm_before_gate=True,
|
| 62 |
+
):
|
| 63 |
+
assert activation in [None, "silu", "swish"]
|
| 64 |
+
if D.dim() == 1:
|
| 65 |
+
assert headdim is not None
|
| 66 |
+
(nheads,) = D.shape
|
| 67 |
+
else:
|
| 68 |
+
nheads, headdim = D.shape
|
| 69 |
+
batch, seqlen, _ = zxbcdt.shape
|
| 70 |
+
dim = nheads * headdim
|
| 71 |
+
assert nheads % ngroups == 0
|
| 72 |
+
dstate = (conv1d_weight.shape[0] - dim) // ngroups // 2
|
| 73 |
+
d_nonssm = (zxbcdt.shape[-1] - 2 * dim - 2 * ngroups * dstate - nheads) // 2
|
| 74 |
+
assert d_nonssm >= 0
|
| 75 |
+
assert zxbcdt.shape == (
|
| 76 |
+
batch,
|
| 77 |
+
seqlen,
|
| 78 |
+
2 * d_nonssm + 2 * dim + 2 * ngroups * dstate + nheads,
|
| 79 |
+
)
|
| 80 |
+
assert dt_bias.shape == (nheads,)
|
| 81 |
+
assert A.shape == (nheads,)
|
| 82 |
+
zx0, z, xBC, dt = torch.split(
|
| 83 |
+
zxbcdt, [2 * d_nonssm, dim, dim + ngroups * dstate * 2, nheads], dim=-1
|
| 84 |
+
)
|
| 85 |
+
seq_idx = seq_idx.contiguous() if seq_idx is not None else None
|
| 86 |
+
final_conv_states = (
|
| 87 |
+
torch.empty(
|
| 88 |
+
(batch, conv1d_weight.shape[1] - 1, dim + ngroups * dstate * 2),
|
| 89 |
+
device=xBC.device,
|
| 90 |
+
dtype=xBC.dtype,
|
| 91 |
+
).transpose(1, 2)
|
| 92 |
+
if return_final_states
|
| 93 |
+
else None
|
| 94 |
+
)
|
| 95 |
+
# Workaround because causal_conv1d_fwd_function currently does not support seq_idx when initial_conv_states is not None.
|
| 96 |
+
# Additionally, there is a bug in causal_conv1d_fwd_function when seq_idx is used causing illegal memory access:
|
| 97 |
+
# - Issue: https://github.com/Dao-AILab/causal-conv1d/issues/67
|
| 98 |
+
# - PR: https://github.com/Dao-AILab/causal-conv1d/pull/101
|
| 99 |
+
if seq_idx is not None and initial_conv_states is not None:
|
| 100 |
+
xBC = xBC * (seq_idx.unsqueeze(-1) >= 0).to(xBC.dtype)
|
| 101 |
+
xBC_conv = rearrange(
|
| 102 |
+
causal_conv1d_fwd_function(
|
| 103 |
+
rearrange_and_update_stride(xBC, "b s d -> b d s"),
|
| 104 |
+
conv1d_weight,
|
| 105 |
+
conv1d_bias,
|
| 106 |
+
None,
|
| 107 |
+
initial_conv_states,
|
| 108 |
+
final_conv_states,
|
| 109 |
+
activation in ["silu", "swish"],
|
| 110 |
+
),
|
| 111 |
+
"b d s -> b s d",
|
| 112 |
+
)
|
| 113 |
+
if seq_idx is not None and initial_conv_states is not None:
|
| 114 |
+
xBC_conv = xBC_conv * (seq_idx.unsqueeze(-1) >= 0).to(xBC_conv.dtype)
|
| 115 |
+
x, B, C = torch.split(
|
| 116 |
+
xBC_conv, [dim, ngroups * dstate, ngroups * dstate], dim=-1
|
| 117 |
+
)
|
| 118 |
+
x = rearrange(x, "b l (h p) -> b l h p", h=nheads)
|
| 119 |
+
B = rearrange(B, "b l (g n) -> b l g n", g=ngroups)
|
| 120 |
+
C = rearrange(C, "b l (g n) -> b l g n", g=ngroups)
|
| 121 |
+
z = rearrange(z, "b l (h p) -> b l h p", h=nheads) if z is not None else None
|
| 122 |
+
if rmsnorm_weight is None:
|
| 123 |
+
out, out_x, dt_out, dA_cumsum, states, final_ssm_states = (
|
| 124 |
+
_mamba_chunk_scan_combined_fwd(
|
| 125 |
+
x,
|
| 126 |
+
dt,
|
| 127 |
+
A,
|
| 128 |
+
B,
|
| 129 |
+
C,
|
| 130 |
+
chunk_size=chunk_size,
|
| 131 |
+
D=D,
|
| 132 |
+
z=z,
|
| 133 |
+
dt_bias=dt_bias,
|
| 134 |
+
initial_states=initial_ssm_states,
|
| 135 |
+
seq_idx=seq_idx,
|
| 136 |
+
dt_softplus=True,
|
| 137 |
+
dt_limit=dt_limit,
|
| 138 |
+
)
|
| 139 |
+
)
|
| 140 |
+
out = rearrange(out, "b s h p -> b s (h p)")
|
| 141 |
+
rstd = None
|
| 142 |
+
if d_nonssm > 0:
|
| 143 |
+
out = torch.cat([_swiglu_fwd(zx0), out], dim=-1)
|
| 144 |
+
else:
|
| 145 |
+
out_x, _, dt_out, dA_cumsum, states, final_ssm_states = (
|
| 146 |
+
_mamba_chunk_scan_combined_fwd(
|
| 147 |
+
x,
|
| 148 |
+
dt,
|
| 149 |
+
A,
|
| 150 |
+
B,
|
| 151 |
+
C,
|
| 152 |
+
chunk_size=chunk_size,
|
| 153 |
+
D=D,
|
| 154 |
+
z=None,
|
| 155 |
+
dt_bias=dt_bias,
|
| 156 |
+
initial_states=initial_ssm_states,
|
| 157 |
+
seq_idx=seq_idx,
|
| 158 |
+
dt_softplus=True,
|
| 159 |
+
dt_limit=dt_limit,
|
| 160 |
+
)
|
| 161 |
+
)
|
| 162 |
+
# reshape input data into 2D tensor
|
| 163 |
+
x_rms = rearrange(out_x, "b s h p -> (b s) (h p)")
|
| 164 |
+
z_rms = rearrange(z, "b s h p -> (b s) (h p)")
|
| 165 |
+
rmsnorm_weight = rmsnorm_weight.contiguous()
|
| 166 |
+
if d_nonssm == 0:
|
| 167 |
+
out = None
|
| 168 |
+
else:
|
| 169 |
+
out01 = torch.empty(
|
| 170 |
+
(batch, seqlen, d_nonssm + dim),
|
| 171 |
+
dtype=x_rms.dtype,
|
| 172 |
+
device=x_rms.device,
|
| 173 |
+
)
|
| 174 |
+
out = rearrange(out01[..., d_nonssm:], "b s d -> (b s) d")
|
| 175 |
+
_swiglu_fwd(zx0, out=out01[..., :d_nonssm])
|
| 176 |
+
out, _, rstd = _layer_norm_fwd(
|
| 177 |
+
x_rms,
|
| 178 |
+
rmsnorm_weight,
|
| 179 |
+
None,
|
| 180 |
+
rmsnorm_eps,
|
| 181 |
+
z_rms,
|
| 182 |
+
out=out,
|
| 183 |
+
group_size=dim // ngroups,
|
| 184 |
+
norm_before_gate=norm_before_gate,
|
| 185 |
+
is_rms_norm=True,
|
| 186 |
+
)
|
| 187 |
+
if d_nonssm == 0:
|
| 188 |
+
out = rearrange(out, "(b s) d -> b s d", b=batch)
|
| 189 |
+
else:
|
| 190 |
+
out = out01
|
| 191 |
+
ctx.outproj_weight_dtype = (
|
| 192 |
+
outproj_weight.dtype if outproj_weight is not None else None
|
| 193 |
+
)
|
| 194 |
+
if outproj_weight is not None:
|
| 195 |
+
if torch.is_autocast_enabled():
|
| 196 |
+
dtype = torch.get_autocast_gpu_dtype()
|
| 197 |
+
out, outproj_weight = out.to(dtype), outproj_weight.to(dtype)
|
| 198 |
+
outproj_bias = (
|
| 199 |
+
outproj_bias.to(dtype) if outproj_bias is not None else None
|
| 200 |
+
)
|
| 201 |
+
out = F.linear(out, outproj_weight, outproj_bias)
|
| 202 |
+
else:
|
| 203 |
+
assert outproj_bias is None
|
| 204 |
+
if out is not None and seq_idx is not None:
|
| 205 |
+
out = out * (seq_idx.unsqueeze(-1) >= 0).to(out.dtype)
|
| 206 |
+
ctx.save_for_backward(
|
| 207 |
+
zxbcdt,
|
| 208 |
+
conv1d_weight,
|
| 209 |
+
conv1d_bias,
|
| 210 |
+
out_x,
|
| 211 |
+
A,
|
| 212 |
+
D,
|
| 213 |
+
dt_bias,
|
| 214 |
+
initial_conv_states,
|
| 215 |
+
initial_ssm_states,
|
| 216 |
+
seq_idx,
|
| 217 |
+
rmsnorm_weight,
|
| 218 |
+
rstd,
|
| 219 |
+
outproj_weight,
|
| 220 |
+
outproj_bias,
|
| 221 |
+
)
|
| 222 |
+
ctx.dt_limit = dt_limit
|
| 223 |
+
ctx.return_final_states = return_final_states
|
| 224 |
+
ctx.activation = activation
|
| 225 |
+
ctx.rmsnorm_eps = rmsnorm_eps
|
| 226 |
+
ctx.norm_before_gate = norm_before_gate
|
| 227 |
+
ctx.chunk_size = chunk_size
|
| 228 |
+
ctx.headdim = headdim
|
| 229 |
+
ctx.ngroups = ngroups
|
| 230 |
+
return (
|
| 231 |
+
out
|
| 232 |
+
if not return_final_states
|
| 233 |
+
else (out, final_conv_states, final_ssm_states)
|
| 234 |
+
)
|
| 235 |
+
|
| 236 |
+
@staticmethod
|
| 237 |
+
@custom_bwd
|
| 238 |
+
def backward(ctx, dout, *args):
|
| 239 |
+
(
|
| 240 |
+
zxbcdt,
|
| 241 |
+
conv1d_weight,
|
| 242 |
+
conv1d_bias,
|
| 243 |
+
out,
|
| 244 |
+
A,
|
| 245 |
+
D,
|
| 246 |
+
dt_bias,
|
| 247 |
+
initial_conv_states,
|
| 248 |
+
initial_ssm_states,
|
| 249 |
+
seq_idx,
|
| 250 |
+
rmsnorm_weight,
|
| 251 |
+
rstd,
|
| 252 |
+
outproj_weight,
|
| 253 |
+
outproj_bias,
|
| 254 |
+
) = ctx.saved_tensors
|
| 255 |
+
dfinal_states = args[0] if ctx.return_final_states else None
|
| 256 |
+
headdim = ctx.headdim
|
| 257 |
+
nheads = D.shape[0]
|
| 258 |
+
dim = nheads * headdim
|
| 259 |
+
assert nheads % ctx.ngroups == 0
|
| 260 |
+
dstate = (conv1d_weight.shape[0] - dim) // ctx.ngroups // 2
|
| 261 |
+
d_nonssm = (zxbcdt.shape[-1] - 2 * dim - 2 * ctx.ngroups * dstate - nheads) // 2
|
| 262 |
+
assert d_nonssm >= 0
|
| 263 |
+
recompute_output = outproj_weight is not None
|
| 264 |
+
if recompute_output:
|
| 265 |
+
out_recompute = torch.empty(
|
| 266 |
+
*out.shape[:2], d_nonssm + dim, device=out.device, dtype=out.dtype
|
| 267 |
+
)
|
| 268 |
+
out0_recompute, out1_recompute = out_recompute.split(
|
| 269 |
+
[d_nonssm, dim], dim=-1
|
| 270 |
+
)
|
| 271 |
+
zx0, z, xBC, dt = torch.split(
|
| 272 |
+
zxbcdt, [2 * d_nonssm, dim, dim + 2 * ctx.ngroups * dstate, nheads], dim=-1
|
| 273 |
+
)
|
| 274 |
+
# Recompute x, B, C
|
| 275 |
+
xBC_conv = rearrange(
|
| 276 |
+
causal_conv1d_fwd_function(
|
| 277 |
+
rearrange_and_update_stride(xBC, "b s d -> b d s"),
|
| 278 |
+
conv1d_weight,
|
| 279 |
+
conv1d_bias,
|
| 280 |
+
None,
|
| 281 |
+
initial_conv_states,
|
| 282 |
+
None,
|
| 283 |
+
ctx.activation in ["silu", "swish"],
|
| 284 |
+
),
|
| 285 |
+
"b d s -> b s d",
|
| 286 |
+
)
|
| 287 |
+
x, B, C = torch.split(
|
| 288 |
+
xBC_conv, [dim, ctx.ngroups * dstate, ctx.ngroups * dstate], dim=-1
|
| 289 |
+
)
|
| 290 |
+
x = rearrange(x, "b l (h p) -> b l h p", h=nheads)
|
| 291 |
+
B = rearrange(B, "b l (g n) -> b l g n", g=ctx.ngroups)
|
| 292 |
+
C = rearrange(C, "b l (g n) -> b l g n", g=ctx.ngroups)
|
| 293 |
+
dzxbcdt = torch.empty_like(zxbcdt)
|
| 294 |
+
dzx0, dz, dxBC_given, ddt_given = torch.split(
|
| 295 |
+
dzxbcdt, [2 * d_nonssm, dim, dim + 2 * ctx.ngroups * dstate, nheads], dim=-1
|
| 296 |
+
)
|
| 297 |
+
dxBC = torch.empty_like(xBC)
|
| 298 |
+
dx, dB, dC = torch.split(
|
| 299 |
+
dxBC, [dim, ctx.ngroups * dstate, ctx.ngroups * dstate], dim=-1
|
| 300 |
+
)
|
| 301 |
+
z = rearrange(z, "b l (h p) -> b l h p", h=nheads)
|
| 302 |
+
dx = rearrange(dx, "b l (h p) -> b l h p", h=nheads)
|
| 303 |
+
dB = rearrange(dB, "b l (g n) -> b l g n", g=ctx.ngroups)
|
| 304 |
+
dC = rearrange(dC, "b l (g n) -> b l g n", g=ctx.ngroups)
|
| 305 |
+
if outproj_weight is not None:
|
| 306 |
+
dout_og = dout
|
| 307 |
+
dout = F.linear(dout, outproj_weight.t())
|
| 308 |
+
if d_nonssm > 0:
|
| 309 |
+
dout0, dout = dout.split([d_nonssm, dim], dim=-1)
|
| 310 |
+
_swiglu_bwd(zx0, dout0, dxy=dzx0, recompute_output=True, out=out0_recompute)
|
| 311 |
+
dout = rearrange(dout, "b s (h p) -> b s h p", p=headdim)
|
| 312 |
+
if rmsnorm_weight is None:
|
| 313 |
+
dz = rearrange(dz, "b l (h p) -> b l h p", h=nheads)
|
| 314 |
+
dx, ddt, dA, dB, dC, dD, dz, ddt_bias, dinitial_ssm_states, *rest = (
|
| 315 |
+
_mamba_chunk_scan_combined_bwd(
|
| 316 |
+
dout,
|
| 317 |
+
x,
|
| 318 |
+
dt,
|
| 319 |
+
A,
|
| 320 |
+
B,
|
| 321 |
+
C,
|
| 322 |
+
out,
|
| 323 |
+
ctx.chunk_size,
|
| 324 |
+
D=D,
|
| 325 |
+
z=z,
|
| 326 |
+
dt_bias=dt_bias,
|
| 327 |
+
initial_states=initial_ssm_states,
|
| 328 |
+
dfinal_states=dfinal_states,
|
| 329 |
+
seq_idx=seq_idx,
|
| 330 |
+
dt_softplus=True,
|
| 331 |
+
dt_limit=ctx.dt_limit,
|
| 332 |
+
dx=dx,
|
| 333 |
+
ddt=ddt_given,
|
| 334 |
+
dB=dB,
|
| 335 |
+
dC=dC,
|
| 336 |
+
dz=dz,
|
| 337 |
+
recompute_output=recompute_output,
|
| 338 |
+
)
|
| 339 |
+
)
|
| 340 |
+
out_for_linear = (
|
| 341 |
+
rearrange(rest[0], "b s h p -> b s (h p)") if recompute_output else None
|
| 342 |
+
)
|
| 343 |
+
drmsnorm_weight = None
|
| 344 |
+
else:
|
| 345 |
+
batch = dout.shape[0]
|
| 346 |
+
dy_rms = rearrange(dout, "b s h p -> (b s) (h p)")
|
| 347 |
+
dz = rearrange(dz, "b l d -> (b l) d")
|
| 348 |
+
x_rms = rearrange(out, "b s h p -> (b s) (h p)")
|
| 349 |
+
z_rms = rearrange(z, "b s h p -> (b s) (h p)")
|
| 350 |
+
out1_recompute = (
|
| 351 |
+
rearrange(out1_recompute, "b s d -> (b s) d")
|
| 352 |
+
if recompute_output
|
| 353 |
+
else None
|
| 354 |
+
)
|
| 355 |
+
dout, drmsnorm_weight, _, dz, *rest = _layer_norm_bwd(
|
| 356 |
+
dy_rms,
|
| 357 |
+
x_rms,
|
| 358 |
+
rmsnorm_weight,
|
| 359 |
+
None,
|
| 360 |
+
ctx.rmsnorm_eps,
|
| 361 |
+
None,
|
| 362 |
+
rstd,
|
| 363 |
+
z_rms,
|
| 364 |
+
group_size=dim // ctx.ngroups,
|
| 365 |
+
norm_before_gate=ctx.norm_before_gate,
|
| 366 |
+
is_rms_norm=True,
|
| 367 |
+
recompute_output=recompute_output,
|
| 368 |
+
dz=dz,
|
| 369 |
+
out=out1_recompute if recompute_output else None,
|
| 370 |
+
)
|
| 371 |
+
out_for_linear = out_recompute if recompute_output else None
|
| 372 |
+
dout = rearrange(dout, "(b s) (h p) -> b s h p", b=batch, p=headdim)
|
| 373 |
+
dx, ddt, dA, dB, dC, dD, _, ddt_bias, dinitial_ssm_states = (
|
| 374 |
+
_mamba_chunk_scan_combined_bwd(
|
| 375 |
+
dout,
|
| 376 |
+
x,
|
| 377 |
+
dt,
|
| 378 |
+
A,
|
| 379 |
+
B,
|
| 380 |
+
C,
|
| 381 |
+
out,
|
| 382 |
+
ctx.chunk_size,
|
| 383 |
+
D=D,
|
| 384 |
+
z=None,
|
| 385 |
+
dt_bias=dt_bias,
|
| 386 |
+
initial_states=initial_ssm_states,
|
| 387 |
+
dfinal_states=dfinal_states,
|
| 388 |
+
seq_idx=seq_idx,
|
| 389 |
+
dt_softplus=True,
|
| 390 |
+
dt_limit=ctx.dt_limit,
|
| 391 |
+
dx=dx,
|
| 392 |
+
ddt=ddt_given,
|
| 393 |
+
dB=dB,
|
| 394 |
+
dC=dC,
|
| 395 |
+
)
|
| 396 |
+
)
|
| 397 |
+
|
| 398 |
+
if outproj_weight is not None:
|
| 399 |
+
doutproj_weight = torch.einsum("bso,bsd->od", dout_og, out_for_linear)
|
| 400 |
+
doutproj_bias = (
|
| 401 |
+
dout_og.sum(dim=(0, 1)) if outproj_bias is not None else None
|
| 402 |
+
)
|
| 403 |
+
else:
|
| 404 |
+
doutproj_weight, doutproj_bias = None, None
|
| 405 |
+
dxBC_given = rearrange(dxBC_given, "b s d -> b d s")
|
| 406 |
+
dxBC_given_update, dweight, dbias, dinitial_conv_states, *_ = (
|
| 407 |
+
causal_conv1d_bwd_function(
|
| 408 |
+
rearrange_and_update_stride(xBC, "b s d -> b d s"),
|
| 409 |
+
conv1d_weight,
|
| 410 |
+
conv1d_bias,
|
| 411 |
+
rearrange(dxBC, "b s d -> b d s"),
|
| 412 |
+
# seq_idx,
|
| 413 |
+
seq_idx if initial_conv_states is None else None,
|
| 414 |
+
initial_conv_states,
|
| 415 |
+
None,
|
| 416 |
+
rearrange_and_update_stride(dxBC_given),
|
| 417 |
+
True,
|
| 418 |
+
ctx.activation in ["silu", "swish"],
|
| 419 |
+
)
|
| 420 |
+
)
|
| 421 |
+
if dxBC_given.stride() != dxBC_given_update.stride():
|
| 422 |
+
dxBC_given.copy_(dxBC_given_update)
|
| 423 |
+
else:
|
| 424 |
+
dxBC_given = dxBC_given_update
|
| 425 |
+
dxBC_given = rearrange(dxBC_given, "b d s -> b s d")
|
| 426 |
+
return (
|
| 427 |
+
dzxbcdt,
|
| 428 |
+
dweight,
|
| 429 |
+
dbias,
|
| 430 |
+
ddt_bias,
|
| 431 |
+
dA,
|
| 432 |
+
dD,
|
| 433 |
+
None,
|
| 434 |
+
dinitial_conv_states,
|
| 435 |
+
dinitial_ssm_states,
|
| 436 |
+
None,
|
| 437 |
+
None,
|
| 438 |
+
None,
|
| 439 |
+
None,
|
| 440 |
+
drmsnorm_weight,
|
| 441 |
+
None,
|
| 442 |
+
doutproj_weight,
|
| 443 |
+
doutproj_bias,
|
| 444 |
+
None,
|
| 445 |
+
None,
|
| 446 |
+
None,
|
| 447 |
+
)
|
| 448 |
+
|
| 449 |
+
|
| 450 |
+
def chunkable_mamba_split_conv1d_scan_combined(
|
| 451 |
+
zxbcdt,
|
| 452 |
+
conv1d_weight,
|
| 453 |
+
conv1d_bias,
|
| 454 |
+
dt_bias,
|
| 455 |
+
A,
|
| 456 |
+
D,
|
| 457 |
+
chunk_size,
|
| 458 |
+
initial_conv_states=None,
|
| 459 |
+
initial_ssm_states=None,
|
| 460 |
+
seq_idx=None,
|
| 461 |
+
dt_limit=(0.0, float("inf")),
|
| 462 |
+
return_final_states=False,
|
| 463 |
+
activation="silu",
|
| 464 |
+
rmsnorm_weight=None,
|
| 465 |
+
rmsnorm_eps=1e-6,
|
| 466 |
+
outproj_weight=None,
|
| 467 |
+
outproj_bias=None,
|
| 468 |
+
headdim=None,
|
| 469 |
+
ngroups=1,
|
| 470 |
+
norm_before_gate=True,
|
| 471 |
+
):
|
| 472 |
+
"""
|
| 473 |
+
Argument:
|
| 474 |
+
zxbcdt: (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads) where dim == nheads * headdim
|
| 475 |
+
conv1d_weight: (dim + 2 * ngroups * dstate, width)
|
| 476 |
+
conv1d_bias: (dim + 2 * ngroups * dstate,)
|
| 477 |
+
dt_bias: (nheads,)
|
| 478 |
+
A: (nheads)
|
| 479 |
+
D: (nheads, headdim) or (nheads,)
|
| 480 |
+
initial_states: (batch, nheads, headdim, dstate)
|
| 481 |
+
seq_idx: (batch, seqlen), int32
|
| 482 |
+
rmsnorm_weight: (dim,)
|
| 483 |
+
outproj_weight: (out_dim, dim)
|
| 484 |
+
outproj_bias: (out_dim,)
|
| 485 |
+
headdim: if D is 1D, headdim must be passed in
|
| 486 |
+
norm_before_gate: if True, we do RMSNorm(x) * F.silu(z). If False, we do RMSNorm(x * F.silu(z))
|
| 487 |
+
Return:
|
| 488 |
+
out: (batch, seqlen, dim)
|
| 489 |
+
"""
|
| 490 |
+
return ChunkableMambaSplitConv1dScanCombinedFn.apply(
|
| 491 |
+
zxbcdt,
|
| 492 |
+
conv1d_weight,
|
| 493 |
+
conv1d_bias,
|
| 494 |
+
dt_bias,
|
| 495 |
+
A,
|
| 496 |
+
D,
|
| 497 |
+
chunk_size,
|
| 498 |
+
initial_conv_states,
|
| 499 |
+
initial_ssm_states,
|
| 500 |
+
seq_idx,
|
| 501 |
+
dt_limit,
|
| 502 |
+
return_final_states,
|
| 503 |
+
activation,
|
| 504 |
+
rmsnorm_weight,
|
| 505 |
+
rmsnorm_eps,
|
| 506 |
+
outproj_weight,
|
| 507 |
+
outproj_bias,
|
| 508 |
+
headdim,
|
| 509 |
+
ngroups,
|
| 510 |
+
norm_before_gate,
|
| 511 |
+
)
|
configuration_chunkable_mamba2.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers.models.mamba2.configuration_mamba2 import Mamba2Config
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class ChunkableMamba2Config(Mamba2Config):
|
| 5 |
+
def __init__(
|
| 6 |
+
self,
|
| 7 |
+
*args,
|
| 8 |
+
use_mem_eff_path: bool = True,
|
| 9 |
+
**kwargs,
|
| 10 |
+
):
|
| 11 |
+
super().__init__(*args, **kwargs)
|
| 12 |
+
self.use_mem_eff_path = use_mem_eff_path
|
modeling_chunkable_mamba2.py
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .configuration_chunkable_mamba2 import ChunkableMamba2Config
|
| 2 |
+
from transformers.cache_utils import Cache, is_torchdynamo_compiling
|
| 3 |
+
from transformers.models.mamba2.modeling_mamba2 import (
|
| 4 |
+
Mamba2Block,
|
| 5 |
+
Mamba2Mixer,
|
| 6 |
+
Mamba2Model,
|
| 7 |
+
Mamba2RMSNorm,
|
| 8 |
+
apply_mask_to_padding_states,
|
| 9 |
+
)
|
| 10 |
+
import torch
|
| 11 |
+
from torch import nn
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
mamba_split_conv1d_scan_combined = None
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class ChunkableMamba2Mixer(Mamba2Mixer):
|
| 18 |
+
def __init__(self, config: ChunkableMamba2Config, layer_idx: int):
|
| 19 |
+
super().__init__(config, layer_idx)
|
| 20 |
+
self.use_mem_eff_path = config.use_mem_eff_path
|
| 21 |
+
|
| 22 |
+
global mamba_split_conv1d_scan_combined
|
| 23 |
+
if self.use_mem_eff_path and mamba_split_conv1d_scan_combined is None:
|
| 24 |
+
from .chunkable_ssd_combined import chunkable_mamba_split_conv1d_scan_combined
|
| 25 |
+
mamba_split_conv1d_scan_combined = chunkable_mamba_split_conv1d_scan_combined
|
| 26 |
+
|
| 27 |
+
def cuda_kernels_forward(
|
| 28 |
+
self,
|
| 29 |
+
hidden_states: torch.Tensor,
|
| 30 |
+
cache_params: Cache | None = None,
|
| 31 |
+
attention_mask: torch.Tensor | None = None,
|
| 32 |
+
):
|
| 33 |
+
if (
|
| 34 |
+
cache_params is not None
|
| 35 |
+
and cache_params.has_previous_state(self.layer_idx)
|
| 36 |
+
) and not self.use_mem_eff_path:
|
| 37 |
+
return super().cuda_kernels_forward(
|
| 38 |
+
hidden_states=hidden_states,
|
| 39 |
+
cache_params=cache_params,
|
| 40 |
+
attention_mask=attention_mask,
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
# 1. Gated MLP's linear projection
|
| 44 |
+
hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask[:, -hidden_states.size(1):])
|
| 45 |
+
projected_states = self.in_proj(hidden_states)
|
| 46 |
+
|
| 47 |
+
A = -torch.exp(self.A_log.float()) # (num_heads) or (intermediate_size, state_size)
|
| 48 |
+
dt_limit_kwargs = {} if self.time_step_limit == (0.0, float("inf")) else {"dt_limit": self.time_step_limit}
|
| 49 |
+
|
| 50 |
+
seq_idx = (
|
| 51 |
+
(attention_mask[:, -hidden_states.size(1) :] - 1).to(torch.int32)
|
| 52 |
+
if attention_mask is not None
|
| 53 |
+
else None
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
# 2-4. Fused kernel for conv1d, SSM, and the final projection
|
| 57 |
+
out = mamba_split_conv1d_scan_combined(
|
| 58 |
+
projected_states,
|
| 59 |
+
self.conv1d.weight.squeeze(1),
|
| 60 |
+
self.conv1d.bias,
|
| 61 |
+
self.dt_bias,
|
| 62 |
+
A,
|
| 63 |
+
D=self.D,
|
| 64 |
+
chunk_size=self.chunk_size,
|
| 65 |
+
seq_idx=seq_idx,
|
| 66 |
+
activation=self.activation,
|
| 67 |
+
rmsnorm_weight=self.norm.weight,
|
| 68 |
+
rmsnorm_eps=self.norm.variance_epsilon,
|
| 69 |
+
outproj_weight=self.out_proj.weight,
|
| 70 |
+
outproj_bias=self.out_proj.bias,
|
| 71 |
+
headdim=self.head_dim,
|
| 72 |
+
ngroups=self.n_groups,
|
| 73 |
+
norm_before_gate=False,
|
| 74 |
+
initial_conv_states=cache_params.layers[self.layer_idx].conv_states
|
| 75 |
+
if cache_params is not None
|
| 76 |
+
else None,
|
| 77 |
+
initial_ssm_states=cache_params.layers[self.layer_idx].recurrent_states
|
| 78 |
+
if cache_params is not None
|
| 79 |
+
else None,
|
| 80 |
+
return_final_states=cache_params is not None,
|
| 81 |
+
**dt_limit_kwargs,
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
if cache_params is not None:
|
| 85 |
+
out, conv_states, ssm_state = out
|
| 86 |
+
cache_params.layers[self.layer_idx].has_previous_state = False
|
| 87 |
+
cache_params.update_conv_state(conv_states, layer_idx=self.layer_idx)
|
| 88 |
+
cache_params.update_recurrent_state(ssm_state, layer_idx=self.layer_idx)
|
| 89 |
+
|
| 90 |
+
return out
|
| 91 |
+
|
| 92 |
+
def forward(
|
| 93 |
+
self,
|
| 94 |
+
hidden_states: torch.Tensor,
|
| 95 |
+
cache_params: Cache | None = None,
|
| 96 |
+
attention_mask: torch.Tensor | None = None,
|
| 97 |
+
):
|
| 98 |
+
if "cuda" in self.in_proj.weight.device.type and not is_torchdynamo_compiling():
|
| 99 |
+
return self.cuda_kernels_forward(hidden_states, cache_params, attention_mask)
|
| 100 |
+
return self.torch_forward(hidden_states, cache_params, attention_mask)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class ChunkableMamba2Block(Mamba2Block):
|
| 104 |
+
def __init__(self, config, layer_idx):
|
| 105 |
+
super(Mamba2Block, self).__init__()
|
| 106 |
+
self.config = config
|
| 107 |
+
self.layer_idx = layer_idx
|
| 108 |
+
self.residual_in_fp32 = config.residual_in_fp32
|
| 109 |
+
self.norm = Mamba2RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
| 110 |
+
self.mixer = ChunkableMamba2Mixer(config, layer_idx=layer_idx)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
class ChunkableMamba2Model(Mamba2Model):
|
| 114 |
+
config_class = ChunkableMamba2Config
|
| 115 |
+
|
| 116 |
+
def __init__(self, config):
|
| 117 |
+
super(Mamba2Model, self).__init__(config)
|
| 118 |
+
|
| 119 |
+
self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
|
| 120 |
+
self.layers = nn.ModuleList(
|
| 121 |
+
[
|
| 122 |
+
ChunkableMamba2Block(config, layer_idx=idx)
|
| 123 |
+
for idx in range(config.num_hidden_layers)
|
| 124 |
+
]
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
self.gradient_checkpointing = False
|
| 128 |
+
self.norm_f = Mamba2RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
| 129 |
+
# Initialize weights and apply final processing
|
| 130 |
+
self._register_load_state_dict_pre_hook(self.load_hook)
|
| 131 |
+
self.post_init()
|