Transformers
mamba2
vertical-chunking
grantner commited on
Commit
fb2906b
·
verified ·
1 Parent(s): aae9874

feat: chunkable mamba2 model code

Browse files
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ library_name: transformers
4
+ tags:
5
+ - transformers
6
+ - mamba2
7
+ - vertical-chunking
8
+ ---
9
+
10
+ # chunkable-mamba2
11
+
12
+ Custom [Mamba2](https://arxiv.org/abs/2405.21060) model and configuration classes for 🤗 Transformers that add support for **vertically chunked inference**, which processes input sequences in fixed-size vertical chunks through all model layers with constant memory usage, regardless of sequence length.
13
+
14
+ ## What this repository provides
15
+
16
+ - **`ChunkableMamba2Config`:** extends `Mamba2Config` with a `use_mem_eff_path` option for the memory-efficient CUDA kernel path.
17
+ - **`ChunkableMamba2Model`:** extends `Mamba2Model` with a chunkable mixer and cache that correctly propagate the recurrent states across vertical chunks (simultaneous `seq_idx` + `initial_states` support).
18
+ - **`chunkable_mamba_split_conv1d_scan_combined`:** modified `mamba_split_conv1d_scan_combined` kernel wrapper that passes cache parameters through the SSD scan so that conv and SSM states are properly initialized and exported during chunked inference.
19
+
20
+ ## Usage
21
+
22
+ This repository is designed to be referenced directly from Hugging Face model configs via `auto_map`, so that models can be loaded with `trust_remote_code=True` without any local installation:
23
+
24
+ ```json
25
+ "auto_map": {
26
+ "AutoConfig": "dynatrace-oss/chunkable-mamba2--configuration_chunkable_mamba2.ChunkableMamba2Config",
27
+ "AutoModel": "dynatrace-oss/chunkable-mamba2--modeling_chunkable_mamba2.ChunkableMamba2Model"
28
+ }
29
+ ```
30
+
31
+ ## Models
32
+
33
+ This code was created for the following embedding models:
34
+
35
+ - [dynatrace-oss/llama-embed-mamba2-7b](https://huggingface.co/dynatrace-oss/llama-embed-mamba2-7b)
36
+ - [dynatrace-oss/llama-embed-mamba2-1.3b](https://huggingface.co/dynatrace-oss/llama-embed-mamba2-1.3b)
37
+
38
+ ## Requirements
39
+
40
+ > [!IMPORTANT]
41
+ > Requires `transformers>=5.5.0` due to a breaking change to the cache of Mamba2 introduced in `v5.5.0` ([transformers#44950](https://github.com/huggingface/transformers/pull/44950)).
42
+
43
+ ```bash
44
+ pip install transformers kernels einops
45
+ ```
46
+
47
+ ## Open Source Integration Roadmap
48
+
49
+ Our goal is to integrate all necessary changes to simplify the adoption of vertically chunked inference for other models:
50
+
51
+ > [!Note]
52
+ > ⚪ Planned | 🟡 In Progress | 🟢 Integrated
53
+
54
+ - ⚪ **causal-conv1d:** Enable simultaneous `seq_idx` + `initial_states` (required for recurrent processing of chunks with left padding)
55
+ - ⚪ **mamba-ssm:** Use `seq_idx` + `initial_states` in `mamba_split_conv1d_scan_combined` and export final states
56
+ - ⚪ **kernels-community:** Propagate changes in `causal-conv1d` and `mamba-ssm` to their kernel hub equivalents in the `kernels-community` repositories
57
+ - ⚪ **transformers:** Use updated `mamba_split_conv1d_scan_combined` with cache params during inference (currently only used during training, not configurable, problems with left padding)
58
+
59
+ *This list will be updated as integration progresses.*
60
+
61
+ ## License
62
+
63
+ Apache-2.0
chunkable_ssd_combined.py ADDED
@@ -0,0 +1,511 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024, Tri Dao, Albert Gu.
2
+
3
+ """We want triton==2.1.0 or 2.2.0 for this"""
4
+
5
+ from packaging import version
6
+
7
+ import torch
8
+ import torch.nn.functional as F
9
+
10
+ import triton
11
+
12
+ from einops import rearrange
13
+
14
+ from transformers.integrations.hub_kernels import get_kernel
15
+
16
+ # Fixed revisions because kernels after 2026-04-14 do not expose the functions we need anymore.
17
+ causal_conv1d = get_kernel("kernels-community/causal-conv1d", revision="dc7072f0e9d799b247a2517a909ebb209d50bea0")
18
+ mamba_ssm = get_kernel("kernels-community/mamba-ssm", revision="00b2ecd499379f9bcf969b6796e53bc867f4ad38")
19
+
20
+ causal_conv1d_fwd_function = causal_conv1d.cpp_functions.causal_conv1d_fwd_function
21
+ causal_conv1d_bwd_function = causal_conv1d.cpp_functions.causal_conv1d_bwd_function
22
+
23
+ custom_fwd = mamba_ssm.utils.torch.custom_fwd
24
+ custom_bwd = mamba_ssm.utils.torch.custom_bwd
25
+ _layer_norm_fwd = mamba_ssm.ops.triton.layernorm_gated._layer_norm_fwd
26
+ _layer_norm_bwd = mamba_ssm.ops.triton.layernorm_gated._layer_norm_bwd
27
+ _swiglu_fwd = mamba_ssm.ops.triton.k_activations._swiglu_fwd
28
+ _swiglu_bwd = mamba_ssm.ops.triton.k_activations._swiglu_bwd
29
+ rearrange_and_update_stride = mamba_ssm.ops.triton.ssd_combined.rearrange_and_update_stride
30
+ _mamba_chunk_scan_combined_fwd = mamba_ssm.ops.triton.ssd_combined._mamba_chunk_scan_combined_fwd
31
+ _mamba_chunk_scan_combined_bwd = mamba_ssm.ops.triton.ssd_combined._mamba_chunk_scan_combined_bwd
32
+
33
+
34
+ TRITON_22 = version.parse(triton.__version__) >= version.parse("2.2.0")
35
+
36
+
37
+ class ChunkableMambaSplitConv1dScanCombinedFn(torch.autograd.Function):
38
+ @staticmethod
39
+ @custom_fwd
40
+ def forward(
41
+ ctx,
42
+ zxbcdt,
43
+ conv1d_weight,
44
+ conv1d_bias,
45
+ dt_bias,
46
+ A,
47
+ D,
48
+ chunk_size,
49
+ initial_conv_states=None,
50
+ initial_ssm_states=None,
51
+ seq_idx=None,
52
+ dt_limit=(0.0, float("inf")),
53
+ return_final_states=False,
54
+ activation="silu",
55
+ rmsnorm_weight=None,
56
+ rmsnorm_eps=1e-6,
57
+ outproj_weight=None,
58
+ outproj_bias=None,
59
+ headdim=None,
60
+ ngroups=1,
61
+ norm_before_gate=True,
62
+ ):
63
+ assert activation in [None, "silu", "swish"]
64
+ if D.dim() == 1:
65
+ assert headdim is not None
66
+ (nheads,) = D.shape
67
+ else:
68
+ nheads, headdim = D.shape
69
+ batch, seqlen, _ = zxbcdt.shape
70
+ dim = nheads * headdim
71
+ assert nheads % ngroups == 0
72
+ dstate = (conv1d_weight.shape[0] - dim) // ngroups // 2
73
+ d_nonssm = (zxbcdt.shape[-1] - 2 * dim - 2 * ngroups * dstate - nheads) // 2
74
+ assert d_nonssm >= 0
75
+ assert zxbcdt.shape == (
76
+ batch,
77
+ seqlen,
78
+ 2 * d_nonssm + 2 * dim + 2 * ngroups * dstate + nheads,
79
+ )
80
+ assert dt_bias.shape == (nheads,)
81
+ assert A.shape == (nheads,)
82
+ zx0, z, xBC, dt = torch.split(
83
+ zxbcdt, [2 * d_nonssm, dim, dim + ngroups * dstate * 2, nheads], dim=-1
84
+ )
85
+ seq_idx = seq_idx.contiguous() if seq_idx is not None else None
86
+ final_conv_states = (
87
+ torch.empty(
88
+ (batch, conv1d_weight.shape[1] - 1, dim + ngroups * dstate * 2),
89
+ device=xBC.device,
90
+ dtype=xBC.dtype,
91
+ ).transpose(1, 2)
92
+ if return_final_states
93
+ else None
94
+ )
95
+ # Workaround because causal_conv1d_fwd_function currently does not support seq_idx when initial_conv_states is not None.
96
+ # Additionally, there is a bug in causal_conv1d_fwd_function when seq_idx is used causing illegal memory access:
97
+ # - Issue: https://github.com/Dao-AILab/causal-conv1d/issues/67
98
+ # - PR: https://github.com/Dao-AILab/causal-conv1d/pull/101
99
+ if seq_idx is not None and initial_conv_states is not None:
100
+ xBC = xBC * (seq_idx.unsqueeze(-1) >= 0).to(xBC.dtype)
101
+ xBC_conv = rearrange(
102
+ causal_conv1d_fwd_function(
103
+ rearrange_and_update_stride(xBC, "b s d -> b d s"),
104
+ conv1d_weight,
105
+ conv1d_bias,
106
+ None,
107
+ initial_conv_states,
108
+ final_conv_states,
109
+ activation in ["silu", "swish"],
110
+ ),
111
+ "b d s -> b s d",
112
+ )
113
+ if seq_idx is not None and initial_conv_states is not None:
114
+ xBC_conv = xBC_conv * (seq_idx.unsqueeze(-1) >= 0).to(xBC_conv.dtype)
115
+ x, B, C = torch.split(
116
+ xBC_conv, [dim, ngroups * dstate, ngroups * dstate], dim=-1
117
+ )
118
+ x = rearrange(x, "b l (h p) -> b l h p", h=nheads)
119
+ B = rearrange(B, "b l (g n) -> b l g n", g=ngroups)
120
+ C = rearrange(C, "b l (g n) -> b l g n", g=ngroups)
121
+ z = rearrange(z, "b l (h p) -> b l h p", h=nheads) if z is not None else None
122
+ if rmsnorm_weight is None:
123
+ out, out_x, dt_out, dA_cumsum, states, final_ssm_states = (
124
+ _mamba_chunk_scan_combined_fwd(
125
+ x,
126
+ dt,
127
+ A,
128
+ B,
129
+ C,
130
+ chunk_size=chunk_size,
131
+ D=D,
132
+ z=z,
133
+ dt_bias=dt_bias,
134
+ initial_states=initial_ssm_states,
135
+ seq_idx=seq_idx,
136
+ dt_softplus=True,
137
+ dt_limit=dt_limit,
138
+ )
139
+ )
140
+ out = rearrange(out, "b s h p -> b s (h p)")
141
+ rstd = None
142
+ if d_nonssm > 0:
143
+ out = torch.cat([_swiglu_fwd(zx0), out], dim=-1)
144
+ else:
145
+ out_x, _, dt_out, dA_cumsum, states, final_ssm_states = (
146
+ _mamba_chunk_scan_combined_fwd(
147
+ x,
148
+ dt,
149
+ A,
150
+ B,
151
+ C,
152
+ chunk_size=chunk_size,
153
+ D=D,
154
+ z=None,
155
+ dt_bias=dt_bias,
156
+ initial_states=initial_ssm_states,
157
+ seq_idx=seq_idx,
158
+ dt_softplus=True,
159
+ dt_limit=dt_limit,
160
+ )
161
+ )
162
+ # reshape input data into 2D tensor
163
+ x_rms = rearrange(out_x, "b s h p -> (b s) (h p)")
164
+ z_rms = rearrange(z, "b s h p -> (b s) (h p)")
165
+ rmsnorm_weight = rmsnorm_weight.contiguous()
166
+ if d_nonssm == 0:
167
+ out = None
168
+ else:
169
+ out01 = torch.empty(
170
+ (batch, seqlen, d_nonssm + dim),
171
+ dtype=x_rms.dtype,
172
+ device=x_rms.device,
173
+ )
174
+ out = rearrange(out01[..., d_nonssm:], "b s d -> (b s) d")
175
+ _swiglu_fwd(zx0, out=out01[..., :d_nonssm])
176
+ out, _, rstd = _layer_norm_fwd(
177
+ x_rms,
178
+ rmsnorm_weight,
179
+ None,
180
+ rmsnorm_eps,
181
+ z_rms,
182
+ out=out,
183
+ group_size=dim // ngroups,
184
+ norm_before_gate=norm_before_gate,
185
+ is_rms_norm=True,
186
+ )
187
+ if d_nonssm == 0:
188
+ out = rearrange(out, "(b s) d -> b s d", b=batch)
189
+ else:
190
+ out = out01
191
+ ctx.outproj_weight_dtype = (
192
+ outproj_weight.dtype if outproj_weight is not None else None
193
+ )
194
+ if outproj_weight is not None:
195
+ if torch.is_autocast_enabled():
196
+ dtype = torch.get_autocast_gpu_dtype()
197
+ out, outproj_weight = out.to(dtype), outproj_weight.to(dtype)
198
+ outproj_bias = (
199
+ outproj_bias.to(dtype) if outproj_bias is not None else None
200
+ )
201
+ out = F.linear(out, outproj_weight, outproj_bias)
202
+ else:
203
+ assert outproj_bias is None
204
+ if out is not None and seq_idx is not None:
205
+ out = out * (seq_idx.unsqueeze(-1) >= 0).to(out.dtype)
206
+ ctx.save_for_backward(
207
+ zxbcdt,
208
+ conv1d_weight,
209
+ conv1d_bias,
210
+ out_x,
211
+ A,
212
+ D,
213
+ dt_bias,
214
+ initial_conv_states,
215
+ initial_ssm_states,
216
+ seq_idx,
217
+ rmsnorm_weight,
218
+ rstd,
219
+ outproj_weight,
220
+ outproj_bias,
221
+ )
222
+ ctx.dt_limit = dt_limit
223
+ ctx.return_final_states = return_final_states
224
+ ctx.activation = activation
225
+ ctx.rmsnorm_eps = rmsnorm_eps
226
+ ctx.norm_before_gate = norm_before_gate
227
+ ctx.chunk_size = chunk_size
228
+ ctx.headdim = headdim
229
+ ctx.ngroups = ngroups
230
+ return (
231
+ out
232
+ if not return_final_states
233
+ else (out, final_conv_states, final_ssm_states)
234
+ )
235
+
236
+ @staticmethod
237
+ @custom_bwd
238
+ def backward(ctx, dout, *args):
239
+ (
240
+ zxbcdt,
241
+ conv1d_weight,
242
+ conv1d_bias,
243
+ out,
244
+ A,
245
+ D,
246
+ dt_bias,
247
+ initial_conv_states,
248
+ initial_ssm_states,
249
+ seq_idx,
250
+ rmsnorm_weight,
251
+ rstd,
252
+ outproj_weight,
253
+ outproj_bias,
254
+ ) = ctx.saved_tensors
255
+ dfinal_states = args[0] if ctx.return_final_states else None
256
+ headdim = ctx.headdim
257
+ nheads = D.shape[0]
258
+ dim = nheads * headdim
259
+ assert nheads % ctx.ngroups == 0
260
+ dstate = (conv1d_weight.shape[0] - dim) // ctx.ngroups // 2
261
+ d_nonssm = (zxbcdt.shape[-1] - 2 * dim - 2 * ctx.ngroups * dstate - nheads) // 2
262
+ assert d_nonssm >= 0
263
+ recompute_output = outproj_weight is not None
264
+ if recompute_output:
265
+ out_recompute = torch.empty(
266
+ *out.shape[:2], d_nonssm + dim, device=out.device, dtype=out.dtype
267
+ )
268
+ out0_recompute, out1_recompute = out_recompute.split(
269
+ [d_nonssm, dim], dim=-1
270
+ )
271
+ zx0, z, xBC, dt = torch.split(
272
+ zxbcdt, [2 * d_nonssm, dim, dim + 2 * ctx.ngroups * dstate, nheads], dim=-1
273
+ )
274
+ # Recompute x, B, C
275
+ xBC_conv = rearrange(
276
+ causal_conv1d_fwd_function(
277
+ rearrange_and_update_stride(xBC, "b s d -> b d s"),
278
+ conv1d_weight,
279
+ conv1d_bias,
280
+ None,
281
+ initial_conv_states,
282
+ None,
283
+ ctx.activation in ["silu", "swish"],
284
+ ),
285
+ "b d s -> b s d",
286
+ )
287
+ x, B, C = torch.split(
288
+ xBC_conv, [dim, ctx.ngroups * dstate, ctx.ngroups * dstate], dim=-1
289
+ )
290
+ x = rearrange(x, "b l (h p) -> b l h p", h=nheads)
291
+ B = rearrange(B, "b l (g n) -> b l g n", g=ctx.ngroups)
292
+ C = rearrange(C, "b l (g n) -> b l g n", g=ctx.ngroups)
293
+ dzxbcdt = torch.empty_like(zxbcdt)
294
+ dzx0, dz, dxBC_given, ddt_given = torch.split(
295
+ dzxbcdt, [2 * d_nonssm, dim, dim + 2 * ctx.ngroups * dstate, nheads], dim=-1
296
+ )
297
+ dxBC = torch.empty_like(xBC)
298
+ dx, dB, dC = torch.split(
299
+ dxBC, [dim, ctx.ngroups * dstate, ctx.ngroups * dstate], dim=-1
300
+ )
301
+ z = rearrange(z, "b l (h p) -> b l h p", h=nheads)
302
+ dx = rearrange(dx, "b l (h p) -> b l h p", h=nheads)
303
+ dB = rearrange(dB, "b l (g n) -> b l g n", g=ctx.ngroups)
304
+ dC = rearrange(dC, "b l (g n) -> b l g n", g=ctx.ngroups)
305
+ if outproj_weight is not None:
306
+ dout_og = dout
307
+ dout = F.linear(dout, outproj_weight.t())
308
+ if d_nonssm > 0:
309
+ dout0, dout = dout.split([d_nonssm, dim], dim=-1)
310
+ _swiglu_bwd(zx0, dout0, dxy=dzx0, recompute_output=True, out=out0_recompute)
311
+ dout = rearrange(dout, "b s (h p) -> b s h p", p=headdim)
312
+ if rmsnorm_weight is None:
313
+ dz = rearrange(dz, "b l (h p) -> b l h p", h=nheads)
314
+ dx, ddt, dA, dB, dC, dD, dz, ddt_bias, dinitial_ssm_states, *rest = (
315
+ _mamba_chunk_scan_combined_bwd(
316
+ dout,
317
+ x,
318
+ dt,
319
+ A,
320
+ B,
321
+ C,
322
+ out,
323
+ ctx.chunk_size,
324
+ D=D,
325
+ z=z,
326
+ dt_bias=dt_bias,
327
+ initial_states=initial_ssm_states,
328
+ dfinal_states=dfinal_states,
329
+ seq_idx=seq_idx,
330
+ dt_softplus=True,
331
+ dt_limit=ctx.dt_limit,
332
+ dx=dx,
333
+ ddt=ddt_given,
334
+ dB=dB,
335
+ dC=dC,
336
+ dz=dz,
337
+ recompute_output=recompute_output,
338
+ )
339
+ )
340
+ out_for_linear = (
341
+ rearrange(rest[0], "b s h p -> b s (h p)") if recompute_output else None
342
+ )
343
+ drmsnorm_weight = None
344
+ else:
345
+ batch = dout.shape[0]
346
+ dy_rms = rearrange(dout, "b s h p -> (b s) (h p)")
347
+ dz = rearrange(dz, "b l d -> (b l) d")
348
+ x_rms = rearrange(out, "b s h p -> (b s) (h p)")
349
+ z_rms = rearrange(z, "b s h p -> (b s) (h p)")
350
+ out1_recompute = (
351
+ rearrange(out1_recompute, "b s d -> (b s) d")
352
+ if recompute_output
353
+ else None
354
+ )
355
+ dout, drmsnorm_weight, _, dz, *rest = _layer_norm_bwd(
356
+ dy_rms,
357
+ x_rms,
358
+ rmsnorm_weight,
359
+ None,
360
+ ctx.rmsnorm_eps,
361
+ None,
362
+ rstd,
363
+ z_rms,
364
+ group_size=dim // ctx.ngroups,
365
+ norm_before_gate=ctx.norm_before_gate,
366
+ is_rms_norm=True,
367
+ recompute_output=recompute_output,
368
+ dz=dz,
369
+ out=out1_recompute if recompute_output else None,
370
+ )
371
+ out_for_linear = out_recompute if recompute_output else None
372
+ dout = rearrange(dout, "(b s) (h p) -> b s h p", b=batch, p=headdim)
373
+ dx, ddt, dA, dB, dC, dD, _, ddt_bias, dinitial_ssm_states = (
374
+ _mamba_chunk_scan_combined_bwd(
375
+ dout,
376
+ x,
377
+ dt,
378
+ A,
379
+ B,
380
+ C,
381
+ out,
382
+ ctx.chunk_size,
383
+ D=D,
384
+ z=None,
385
+ dt_bias=dt_bias,
386
+ initial_states=initial_ssm_states,
387
+ dfinal_states=dfinal_states,
388
+ seq_idx=seq_idx,
389
+ dt_softplus=True,
390
+ dt_limit=ctx.dt_limit,
391
+ dx=dx,
392
+ ddt=ddt_given,
393
+ dB=dB,
394
+ dC=dC,
395
+ )
396
+ )
397
+
398
+ if outproj_weight is not None:
399
+ doutproj_weight = torch.einsum("bso,bsd->od", dout_og, out_for_linear)
400
+ doutproj_bias = (
401
+ dout_og.sum(dim=(0, 1)) if outproj_bias is not None else None
402
+ )
403
+ else:
404
+ doutproj_weight, doutproj_bias = None, None
405
+ dxBC_given = rearrange(dxBC_given, "b s d -> b d s")
406
+ dxBC_given_update, dweight, dbias, dinitial_conv_states, *_ = (
407
+ causal_conv1d_bwd_function(
408
+ rearrange_and_update_stride(xBC, "b s d -> b d s"),
409
+ conv1d_weight,
410
+ conv1d_bias,
411
+ rearrange(dxBC, "b s d -> b d s"),
412
+ # seq_idx,
413
+ seq_idx if initial_conv_states is None else None,
414
+ initial_conv_states,
415
+ None,
416
+ rearrange_and_update_stride(dxBC_given),
417
+ True,
418
+ ctx.activation in ["silu", "swish"],
419
+ )
420
+ )
421
+ if dxBC_given.stride() != dxBC_given_update.stride():
422
+ dxBC_given.copy_(dxBC_given_update)
423
+ else:
424
+ dxBC_given = dxBC_given_update
425
+ dxBC_given = rearrange(dxBC_given, "b d s -> b s d")
426
+ return (
427
+ dzxbcdt,
428
+ dweight,
429
+ dbias,
430
+ ddt_bias,
431
+ dA,
432
+ dD,
433
+ None,
434
+ dinitial_conv_states,
435
+ dinitial_ssm_states,
436
+ None,
437
+ None,
438
+ None,
439
+ None,
440
+ drmsnorm_weight,
441
+ None,
442
+ doutproj_weight,
443
+ doutproj_bias,
444
+ None,
445
+ None,
446
+ None,
447
+ )
448
+
449
+
450
+ def chunkable_mamba_split_conv1d_scan_combined(
451
+ zxbcdt,
452
+ conv1d_weight,
453
+ conv1d_bias,
454
+ dt_bias,
455
+ A,
456
+ D,
457
+ chunk_size,
458
+ initial_conv_states=None,
459
+ initial_ssm_states=None,
460
+ seq_idx=None,
461
+ dt_limit=(0.0, float("inf")),
462
+ return_final_states=False,
463
+ activation="silu",
464
+ rmsnorm_weight=None,
465
+ rmsnorm_eps=1e-6,
466
+ outproj_weight=None,
467
+ outproj_bias=None,
468
+ headdim=None,
469
+ ngroups=1,
470
+ norm_before_gate=True,
471
+ ):
472
+ """
473
+ Argument:
474
+ zxbcdt: (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads) where dim == nheads * headdim
475
+ conv1d_weight: (dim + 2 * ngroups * dstate, width)
476
+ conv1d_bias: (dim + 2 * ngroups * dstate,)
477
+ dt_bias: (nheads,)
478
+ A: (nheads)
479
+ D: (nheads, headdim) or (nheads,)
480
+ initial_states: (batch, nheads, headdim, dstate)
481
+ seq_idx: (batch, seqlen), int32
482
+ rmsnorm_weight: (dim,)
483
+ outproj_weight: (out_dim, dim)
484
+ outproj_bias: (out_dim,)
485
+ headdim: if D is 1D, headdim must be passed in
486
+ norm_before_gate: if True, we do RMSNorm(x) * F.silu(z). If False, we do RMSNorm(x * F.silu(z))
487
+ Return:
488
+ out: (batch, seqlen, dim)
489
+ """
490
+ return ChunkableMambaSplitConv1dScanCombinedFn.apply(
491
+ zxbcdt,
492
+ conv1d_weight,
493
+ conv1d_bias,
494
+ dt_bias,
495
+ A,
496
+ D,
497
+ chunk_size,
498
+ initial_conv_states,
499
+ initial_ssm_states,
500
+ seq_idx,
501
+ dt_limit,
502
+ return_final_states,
503
+ activation,
504
+ rmsnorm_weight,
505
+ rmsnorm_eps,
506
+ outproj_weight,
507
+ outproj_bias,
508
+ headdim,
509
+ ngroups,
510
+ norm_before_gate,
511
+ )
configuration_chunkable_mamba2.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers.models.mamba2.configuration_mamba2 import Mamba2Config
2
+
3
+
4
+ class ChunkableMamba2Config(Mamba2Config):
5
+ def __init__(
6
+ self,
7
+ *args,
8
+ use_mem_eff_path: bool = True,
9
+ **kwargs,
10
+ ):
11
+ super().__init__(*args, **kwargs)
12
+ self.use_mem_eff_path = use_mem_eff_path
modeling_chunkable_mamba2.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .configuration_chunkable_mamba2 import ChunkableMamba2Config
2
+ from transformers.cache_utils import Cache, is_torchdynamo_compiling
3
+ from transformers.models.mamba2.modeling_mamba2 import (
4
+ Mamba2Block,
5
+ Mamba2Mixer,
6
+ Mamba2Model,
7
+ Mamba2RMSNorm,
8
+ apply_mask_to_padding_states,
9
+ )
10
+ import torch
11
+ from torch import nn
12
+
13
+
14
+ mamba_split_conv1d_scan_combined = None
15
+
16
+
17
+ class ChunkableMamba2Mixer(Mamba2Mixer):
18
+ def __init__(self, config: ChunkableMamba2Config, layer_idx: int):
19
+ super().__init__(config, layer_idx)
20
+ self.use_mem_eff_path = config.use_mem_eff_path
21
+
22
+ global mamba_split_conv1d_scan_combined
23
+ if self.use_mem_eff_path and mamba_split_conv1d_scan_combined is None:
24
+ from .chunkable_ssd_combined import chunkable_mamba_split_conv1d_scan_combined
25
+ mamba_split_conv1d_scan_combined = chunkable_mamba_split_conv1d_scan_combined
26
+
27
+ def cuda_kernels_forward(
28
+ self,
29
+ hidden_states: torch.Tensor,
30
+ cache_params: Cache | None = None,
31
+ attention_mask: torch.Tensor | None = None,
32
+ ):
33
+ if (
34
+ cache_params is not None
35
+ and cache_params.has_previous_state(self.layer_idx)
36
+ ) and not self.use_mem_eff_path:
37
+ return super().cuda_kernels_forward(
38
+ hidden_states=hidden_states,
39
+ cache_params=cache_params,
40
+ attention_mask=attention_mask,
41
+ )
42
+
43
+ # 1. Gated MLP's linear projection
44
+ hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask[:, -hidden_states.size(1):])
45
+ projected_states = self.in_proj(hidden_states)
46
+
47
+ A = -torch.exp(self.A_log.float()) # (num_heads) or (intermediate_size, state_size)
48
+ dt_limit_kwargs = {} if self.time_step_limit == (0.0, float("inf")) else {"dt_limit": self.time_step_limit}
49
+
50
+ seq_idx = (
51
+ (attention_mask[:, -hidden_states.size(1) :] - 1).to(torch.int32)
52
+ if attention_mask is not None
53
+ else None
54
+ )
55
+
56
+ # 2-4. Fused kernel for conv1d, SSM, and the final projection
57
+ out = mamba_split_conv1d_scan_combined(
58
+ projected_states,
59
+ self.conv1d.weight.squeeze(1),
60
+ self.conv1d.bias,
61
+ self.dt_bias,
62
+ A,
63
+ D=self.D,
64
+ chunk_size=self.chunk_size,
65
+ seq_idx=seq_idx,
66
+ activation=self.activation,
67
+ rmsnorm_weight=self.norm.weight,
68
+ rmsnorm_eps=self.norm.variance_epsilon,
69
+ outproj_weight=self.out_proj.weight,
70
+ outproj_bias=self.out_proj.bias,
71
+ headdim=self.head_dim,
72
+ ngroups=self.n_groups,
73
+ norm_before_gate=False,
74
+ initial_conv_states=cache_params.layers[self.layer_idx].conv_states
75
+ if cache_params is not None
76
+ else None,
77
+ initial_ssm_states=cache_params.layers[self.layer_idx].recurrent_states
78
+ if cache_params is not None
79
+ else None,
80
+ return_final_states=cache_params is not None,
81
+ **dt_limit_kwargs,
82
+ )
83
+
84
+ if cache_params is not None:
85
+ out, conv_states, ssm_state = out
86
+ cache_params.layers[self.layer_idx].has_previous_state = False
87
+ cache_params.update_conv_state(conv_states, layer_idx=self.layer_idx)
88
+ cache_params.update_recurrent_state(ssm_state, layer_idx=self.layer_idx)
89
+
90
+ return out
91
+
92
+ def forward(
93
+ self,
94
+ hidden_states: torch.Tensor,
95
+ cache_params: Cache | None = None,
96
+ attention_mask: torch.Tensor | None = None,
97
+ ):
98
+ if "cuda" in self.in_proj.weight.device.type and not is_torchdynamo_compiling():
99
+ return self.cuda_kernels_forward(hidden_states, cache_params, attention_mask)
100
+ return self.torch_forward(hidden_states, cache_params, attention_mask)
101
+
102
+
103
+ class ChunkableMamba2Block(Mamba2Block):
104
+ def __init__(self, config, layer_idx):
105
+ super(Mamba2Block, self).__init__()
106
+ self.config = config
107
+ self.layer_idx = layer_idx
108
+ self.residual_in_fp32 = config.residual_in_fp32
109
+ self.norm = Mamba2RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
110
+ self.mixer = ChunkableMamba2Mixer(config, layer_idx=layer_idx)
111
+
112
+
113
+ class ChunkableMamba2Model(Mamba2Model):
114
+ config_class = ChunkableMamba2Config
115
+
116
+ def __init__(self, config):
117
+ super(Mamba2Model, self).__init__(config)
118
+
119
+ self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
120
+ self.layers = nn.ModuleList(
121
+ [
122
+ ChunkableMamba2Block(config, layer_idx=idx)
123
+ for idx in range(config.num_hidden_layers)
124
+ ]
125
+ )
126
+
127
+ self.gradient_checkpointing = False
128
+ self.norm_f = Mamba2RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
129
+ # Initialize weights and apply final processing
130
+ self._register_load_state_dict_pre_hook(self.load_hook)
131
+ self.post_init()