fmgreco committed on
Commit 199170e · 1 Parent(s): 346e086

Add ROCm dual GEMM, MXFP4, mask compaction, group GEMM

AMDmi300asubmission.py ADDED
@@ -0,0 +1,118 @@
1
+ import torch
2
+ import triton
3
+ import triton.language as tl
4
+ import os
5
+
6
+ # 1. HARDWARE DIAGNOSTICS
7
+ def check_environment():
8
+ print(f"--- Environment Check ---")
9
+ cuda_avail = torch.cuda.is_available()
10
+ print(f"Is CUDA/ROCm available? {cuda_avail}")
11
+
12
+ if cuda_avail:
13
+ device_name = torch.cuda.get_device_name(0)
14
+ print(f"GPU Detected: {device_name}")
15
+ prop = torch.cuda.get_device_properties(0)
16
+ if hasattr(prop, 'major'):
17
+ print(f"Compute Capability: {prop.major}.{prop.minor}")
18
+ # Optimization: keep kernel arguments in device memory on ROCm (reduces launch latency)
19
+ os.environ["HIP_FORCE_DEV_KERNARG"] = "1"
20
+ else:
21
+ print("No NVIDIA/AMD GPU detected. Triton kernels will not run on this hardware.")
22
+ print(f"-------------------------\n")
23
+
24
+ check_environment()
25
+
26
+ # 2. OPTIMIZED DUAL GEMM KERNEL
27
+ @triton.autotune(
28
+ configs=[
29
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
30
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=8),
31
+ ],
32
+ key=['M', 'N', 'K'],
33
+ )
34
+ @triton.jit
35
+ def dual_gemm_kernel(
36
+ a_ptr, b1_ptr, b2_ptr, c_ptr,
37
+ sfa_ptr, sfb1_ptr, sfb2_ptr,
38
+ M, N, K, L,
39
+ stride_am, stride_ak, stride_al,
40
+ stride_bn, stride_bk, stride_bl,
41
+ stride_cm, stride_cn, stride_cl,
42
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
43
+ GROUP_SIZE_M: tl.constexpr,
44
+ ):
45
+ # Program ID & Work distribution
46
+ pid = tl.program_id(0)
47
+ num_pid_m = tl.cdiv(M, BLOCK_M)
48
+ num_pid_n = tl.cdiv(N, BLOCK_N)
49
+
50
+ # Persistent Grid Loop (Iterates over batches L and tiles)
51
+ total_tiles = num_pid_m * num_pid_n * L
52
+ for tile_idx in tl.range(pid, total_tiles, tl.num_programs(0)):
53
+ l_idx = tile_idx // (num_pid_m * num_pid_n)
54
+ tile_rem = tile_idx % (num_pid_m * num_pid_n)
55
+
56
+ pid_m = tile_rem // num_pid_n
57
+ pid_n = tile_rem % num_pid_n
58
+
59
+ # Memory Offsets
60
+ offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M))
61
+ offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N))
62
+ offs_k = tl.arange(0, BLOCK_K)
63
+
64
+ a_ptrs = a_ptr + l_idx * stride_al + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
65
+ b1_ptrs = b1_ptr + l_idx * stride_bl + (offs_bn[None, :] * stride_bn + offs_k[:, None] * stride_bk)
66
+ b2_ptrs = b2_ptr + l_idx * stride_bl + (offs_bn[None, :] * stride_bn + offs_k[:, None] * stride_bk)
67
+
68
+ # Accumulators
69
+ acc1 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
70
+ acc2 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
71
+
72
+ # Inner K-loop
73
+ for k in range(0, tl.cdiv(K, BLOCK_K)):
74
+ a = tl.load(a_ptrs)
75
+ b1 = tl.load(b1_ptrs)
76
+ b2 = tl.load(b2_ptrs)
77
+
78
+ # Using tl.dot_scaled for hardware-native scaling if available
79
+ # Note: this variant passes None for the scales (plain e2m1 dot); hardware_submission.py loads and applies sfa/sfb
80
+ acc1 = tl.dot_scaled(a, None, "e2m1", b1, None, "e2m1", acc1)
81
+ acc2 = tl.dot_scaled(a, None, "e2m1", b2, None, "e2m1", acc2)
82
+
83
+ a_ptrs += BLOCK_K * stride_ak
84
+ b1_ptrs += BLOCK_K * stride_bk
85
+ b2_ptrs += BLOCK_K * stride_bk
86
+
87
+ # 3. FUSED EPILOGUE (SiLU + Gating)
88
+ # res = SiLU(A @ B1) * (A @ B2)
89
+ res1 = acc1.to(tl.float16)
90
+ activated_res1 = res1 * tl.sigmoid(res1)
91
+ final_out = activated_res1 * acc2.to(tl.float16)
92
+
93
+ # Store result
94
+ offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
95
+ offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
96
+ c_ptrs = c_ptr + l_idx * stride_cl + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn
97
+ tl.store(c_ptrs, final_out)
98
+
99
+ # 4. HARNESS INTERFACE
100
+ def dual_gemm_submission(data):
101
+ # Unpack the tuple provided by the benchmark harness
102
+ a, b1, b2, sfa, sfb1, sfb2, c = data
103
+ M, K_packed, L = a.shape
104
+ N = b1.shape[0]
105
+ K = K_packed * 2 # Assuming FP4 packing
106
+
107
+ # Grid size: Launch exactly the number of SMs/CUs for a persistent wave
108
+ num_sms = torch.cuda.get_device_properties(0).multi_processor_count
109
+ grid = (num_sms,)
110
+
111
+ dual_gemm_kernel[grid](
112
+ a, b1, b2, c, sfa, sfb1, sfb2,
113
+ M, N, K, L,
114
+ a.stride(0), a.stride(1), a.stride(2),
115
+ b1.stride(0), b1.stride(1), b1.stride(2),
116
+ c.stride(0), c.stride(1), c.stride(2)
117
+ )
118
+ return c
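For reference, a minimal sketch of the tuple layout dual_gemm_submission unpacks, inferred from the code above (a: [M, K//2, L] packed FP4, b1/b2: [N, K//2, L], c: [M, N, L]); the uint8 packing, fp16 output dtype, and placeholder scale contents are assumptions, and real inputs come from the benchmark harness:

import torch
from AMDmi300asubmission import dual_gemm_submission

M, N, K, L = 128, 256, 512, 2  # K is the unpacked inner dimension
a    = torch.randint(0, 256, (M, K // 2, L), dtype=torch.uint8, device="cuda")  # packed FP4 (assumed)
b1   = torch.randint(0, 256, (N, K // 2, L), dtype=torch.uint8, device="cuda")
b2   = torch.randint(0, 256, (N, K // 2, L), dtype=torch.uint8, device="cuda")
sfa  = torch.ones((M, K // 16, L), dtype=torch.uint8, device="cuda")  # placeholder scales; this kernel passes None to tl.dot_scaled
sfb1 = torch.ones((N, K // 16, L), dtype=torch.uint8, device="cuda")
sfb2 = torch.ones((N, K // 16, L), dtype=torch.uint8, device="cuda")
c    = torch.empty((M, N, L), dtype=torch.float16, device="cuda")  # output buffer (dtype assumed)

out = dual_gemm_submission((a, b1, b2, sfa, sfb1, sfb2, c))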
amd_dual_gemm_swiglu.py ADDED
@@ -0,0 +1,318 @@
1
+ """
2
+ AMD Triton fused Dual GEMM + SwiGLU kernel.
3
+ Computes: silu(A @ B1) * (A @ B2) in a single fused kernel.
4
+ Uses triton-kernels testing.py: assert_close (maxtol=2e-2, rmstol=4e-3).
5
+ """
6
+
7
+ import argparse
8
+ import os
9
+ import sys
10
+ import time
11
+
12
+ # Allow importing testing.py from same directory (when run from kernels/)
13
+ _script_dir = os.path.dirname(os.path.abspath(__file__))
14
+ if _script_dir not in sys.path:
15
+ sys.path.insert(0, _script_dir)
16
+
17
+ import torch
18
+ import triton
19
+ import triton.language as tl
20
+
21
+ # Optional MXFP4 pre-dequant (option 1: upcast before GEMM)
22
+ try:
23
+ from numerics_details.mxfp_details import upcast_mxfp4_to_fp16
24
+ _HAS_MXFP = True
25
+ except ImportError:
26
+ upcast_mxfp4_to_fp16 = None
27
+ _HAS_MXFP = False
28
+
29
+
30
+ def _maybe_upcast_mxfp(b, name: str) -> torch.Tensor:
31
+ """If b is MXFP4 (mx_tensor, mx_scale), upcast to fp16. Else return b."""
32
+ if not isinstance(b, (tuple, list)) or len(b) != 2:
33
+ return b
34
+ mx_tensor, mx_scale = b
35
+ if not (isinstance(mx_tensor, torch.Tensor) and isinstance(mx_scale, torch.Tensor)):
36
+ return b
37
+ if mx_tensor.dtype != torch.uint8 or mx_scale.dtype != torch.uint8:
38
+ return b
39
+ if not _HAS_MXFP:
40
+ raise ImportError("MXFP4 weights require numerics_details.mxfp_details")
41
+ return upcast_mxfp4_to_fp16(mx_tensor, mx_scale, verbose=False)
42
+
43
+
44
+ @triton.autotune(
45
+ configs=[
46
+ triton.Config(
47
+ {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "GROUP_SIZE_M": 8},
48
+ num_warps=4,
49
+ num_stages=3,
50
+ ),
51
+ triton.Config(
52
+ {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "GROUP_SIZE_M": 8},
53
+ num_warps=8,
54
+ num_stages=3,
55
+ ),
56
+ triton.Config(
57
+ {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "GROUP_SIZE_M": 8},
58
+ num_warps=8,
59
+ num_stages=2,
60
+ ),
61
+ triton.Config(
62
+ {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64, "GROUP_SIZE_M": 8},
63
+ num_warps=8,
64
+ num_stages=2,
65
+ ),
66
+ triton.Config(
67
+ {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 64, "GROUP_SIZE_M": 4},
68
+ num_warps=4,
69
+ num_stages=3,
70
+ ),
71
+ triton.Config(
72
+ {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 64, "GROUP_SIZE_M": 4},
73
+ num_warps=4,
74
+ num_stages=3,
75
+ ),
76
+ triton.Config(
77
+ {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "GROUP_SIZE_M": 8},
78
+ num_warps=4,
79
+ num_stages=4,
80
+ ),
81
+ triton.Config(
82
+ {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "GROUP_SIZE_M": 8},
83
+ num_warps=4,
84
+ num_stages=4,
85
+ ),
86
+ triton.Config(
87
+ {"BLOCK_M": 256, "BLOCK_N": 256, "BLOCK_K": 32, "GROUP_SIZE_M": 8},
88
+ num_warps=8,
89
+ num_stages=2,
90
+ ),
91
+ ],
92
+ key=["M", "N", "K"],
93
+ )
94
+ @triton.heuristics(
95
+ {
96
+ "EVEN_K": lambda args: args["K"] % args["BLOCK_K"] == 0,
97
+ "EVEN_M": lambda args: args["M"] % args["BLOCK_M"] == 0,
98
+ "EVEN_N": lambda args: args["N"] % args["BLOCK_N"] == 0,
99
+ }
100
+ )
101
+ @triton.jit
102
+ def dual_gemm_swiglu_kernel(
103
+ a_ptr,
104
+ b1_ptr,
105
+ b2_ptr,
106
+ c_ptr,
107
+ M,
108
+ N,
109
+ K,
110
+ stride_am,
111
+ stride_ak,
112
+ stride_b1k,
113
+ stride_b1n,
114
+ stride_b2k,
115
+ stride_b2n,
116
+ stride_cm,
117
+ stride_cn,
118
+ BLOCK_M: tl.constexpr,
119
+ BLOCK_N: tl.constexpr,
120
+ BLOCK_K: tl.constexpr,
121
+ GROUP_SIZE_M: tl.constexpr,
122
+ EVEN_K: tl.constexpr,
123
+ EVEN_M: tl.constexpr,
124
+ EVEN_N: tl.constexpr,
125
+ ):
126
+ pid = tl.program_id(axis=0)
127
+ num_pid_m = tl.cdiv(M, BLOCK_M)
128
+ num_pid_n = tl.cdiv(N, BLOCK_N)
129
+ num_pid_in_group = GROUP_SIZE_M * num_pid_n
130
+ group_id = pid // num_pid_in_group
131
+ first_pid_m = group_id * GROUP_SIZE_M
132
+ group_size_m = tl.minimum(num_pid_m - first_pid_m, GROUP_SIZE_M)
133
+ pid_in_group = pid % num_pid_in_group
134
+ pid_m = first_pid_m + (pid_in_group % group_size_m)
135
+ pid_n = pid_in_group // group_size_m
136
+
137
+ offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
138
+ offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
139
+ offs_k = tl.arange(0, BLOCK_K)
140
+
141
+ a_ptrs = a_ptr + (offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak)
142
+ b1_ptrs = b1_ptr + (offs_k[:, None] * stride_b1k + offs_n[None, :] * stride_b1n)
143
+ b2_ptrs = b2_ptr + (offs_k[:, None] * stride_b2k + offs_n[None, :] * stride_b2n)
144
+
145
+ acc1 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
146
+ acc2 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
147
+ m_mask = offs_m[:, None] < M
148
+ n_mask = offs_n[None, :] < N
149
+
150
+ for k in range(0, tl.cdiv(K, BLOCK_K)):
151
+ if EVEN_K:
152
+ if EVEN_M:
153
+ a = tl.load(a_ptrs)
154
+ else:
155
+ a = tl.load(a_ptrs, mask=m_mask, other=0.0)
156
+
157
+ if EVEN_N:
158
+ b1 = tl.load(b1_ptrs)
159
+ b2 = tl.load(b2_ptrs)
160
+ else:
161
+ b1 = tl.load(b1_ptrs, mask=n_mask, other=0.0)
162
+ b2 = tl.load(b2_ptrs, mask=n_mask, other=0.0)
163
+ else:
164
+ k_rem = K - k * BLOCK_K
165
+ k_mask_m = offs_k[None, :] < k_rem
166
+ k_mask_n = offs_k[:, None] < k_rem
167
+ a = tl.load(a_ptrs, mask=m_mask & k_mask_m, other=0.0)
168
+ b1 = tl.load(b1_ptrs, mask=k_mask_n & n_mask, other=0.0)
169
+ b2 = tl.load(b2_ptrs, mask=k_mask_n & n_mask, other=0.0)
170
+
171
+ tl.multiple_of(a_ptrs, [16, 16])
172
+ tl.multiple_of(b1_ptrs, [16, 16])
173
+ tl.multiple_of(b2_ptrs, [16, 16])
174
+
175
+ acc1 += tl.dot(a, b1)
176
+ acc2 += tl.dot(a, b2)
177
+
178
+ a_ptrs += BLOCK_K * stride_ak
179
+ b1_ptrs += BLOCK_K * stride_b1k
180
+ b2_ptrs += BLOCK_K * stride_b2k
181
+
182
+ silu = acc1 * tl.sigmoid(acc1)
183
+ out = silu * acc2
184
+
185
+ c_ptrs = c_ptr + (offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn)
186
+ c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
187
+ tl.store(c_ptrs, out.to(tl.float16), mask=c_mask)
188
+
189
+
190
+ def dual_gemm_swiglu(
191
+ a: torch.Tensor,
192
+ b1: torch.Tensor | tuple,
193
+ b2: torch.Tensor | tuple,
194
+ ) -> torch.Tensor:
195
+ """Fused Dual GEMM + SwiGLU. b1/b2 can be fp16 [K,N] or MXFP4 (mx_tensor, mx_scale)."""
196
+ b1 = _maybe_upcast_mxfp(b1, "b1")
197
+ b2 = _maybe_upcast_mxfp(b2, "b2")
198
+
199
+ if a.ndim != 2 or b1.ndim != 2 or b2.ndim != 2:
200
+ raise ValueError("Expected 2D tensors: a[M,K], b1[K,N], b2[K,N].")
201
+ if a.shape[1] != b1.shape[0] or a.shape[1] != b2.shape[0]:
202
+ raise ValueError("Incompatible shapes for dual GEMM.")
203
+ if b1.shape[1] != b2.shape[1]:
204
+ raise ValueError("b1 and b2 must have same N dimension.")
205
+ if not (a.is_cuda and b1.is_cuda and b2.is_cuda):
206
+ raise ValueError("All tensors must be on a CUDA/ROCm device.")
207
+ if a.dtype != torch.float16 or b1.dtype != torch.float16 or b2.dtype != torch.float16:
208
+ raise ValueError("This kernel currently expects float16 inputs.")
209
+
210
+ a = a.contiguous()
211
+ b1 = b1.contiguous()
212
+ b2 = b2.contiguous()
213
+
214
+ M, K = a.shape
215
+ _, N = b1.shape
216
+ c = torch.empty((M, N), device=a.device, dtype=torch.float16)
217
+
218
+ grid = lambda META: (
219
+ triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
220
+ )
221
+
222
+ dual_gemm_swiglu_kernel[grid](
223
+ a, b1, b2, c,
224
+ M=M, N=N, K=K,
225
+ stride_am=a.stride(0), stride_ak=a.stride(1),
226
+ stride_b1k=b1.stride(0), stride_b1n=b1.stride(1),
227
+ stride_b2k=b2.stride(0), stride_b2n=b2.stride(1),
228
+ stride_cm=c.stride(0), stride_cn=c.stride(1),
229
+ )
230
+ return c
231
+
232
+
233
+ def reference_dual_gemm_swiglu(a: torch.Tensor, b1: torch.Tensor, b2: torch.Tensor) -> torch.Tensor:
234
+ x1 = a @ b1
235
+ x2 = a @ b2
236
+ return torch.nn.functional.silu(x1) * x2
237
+
238
+
239
+ def test_correctness(device: str = "cuda", maxtol: float = 2e-2, rmstol: float = 4e-3) -> bool:
240
+ """Run correctness tests using triton-kernels testing.assert_close."""
241
+ from testing import assert_close
242
+
243
+ torch.manual_seed(42)
244
+ shapes = [(128, 64, 128), (256, 256, 512), (1024, 512, 1024), (4096, 3648, 8192),
245
+ (7, 13, 17), (100, 200, 150)]
246
+ input_scale = 0.125
247
+ all_pass = True
248
+ for m, n, k in shapes:
249
+ a = torch.randn((m, k), device=device, dtype=torch.float16) * input_scale
250
+ b1 = torch.randn((k, n), device=device, dtype=torch.float16) * input_scale
251
+ b2 = torch.randn((k, n), device=device, dtype=torch.float16) * input_scale
252
+ with torch.no_grad():
253
+ ref = reference_dual_gemm_swiglu(a.float(), b1.float(), b2.float()).to(torch.float16)
254
+ out = dual_gemm_swiglu(a, b1, b2)
255
+ desc = f"[shape ({m},{n},{k})]"
256
+ try:
257
+ assert_close(ref, out, maxtol=maxtol, rmstol=rmstol, description=desc, verbose=True)
258
+ print(f" {desc} PASS")
259
+ except AssertionError:
260
+ print(f" {desc} FAIL")
261
+ all_pass = False
262
+ return all_pass
263
+
264
+
265
+ def benchmark(m: int, n: int, k: int, warmup: int, iters: int, input_scale: float) -> None:
266
+ device = "cuda"
267
+ a = torch.randn((m, k), device=device, dtype=torch.float16) * input_scale
268
+ b1 = torch.randn((k, n), device=device, dtype=torch.float16) * input_scale
269
+ b2 = torch.randn((k, n), device=device, dtype=torch.float16) * input_scale
270
+ for _ in range(warmup):
271
+ _ = dual_gemm_swiglu(a, b1, b2)
272
+ torch.cuda.synchronize()
273
+ start = torch.cuda.Event(enable_timing=True)
274
+ end = torch.cuda.Event(enable_timing=True)
275
+ start.record()
276
+ for _ in range(iters):
277
+ _ = dual_gemm_swiglu(a, b1, b2)
278
+ end.record()
279
+ torch.cuda.synchronize()
280
+ avg_ms = start.elapsed_time(end) / iters
281
+ total_flops = 4 * m * n * k
282
+ tflops = (total_flops / (avg_ms * 1e-3)) / 1e12
283
+ print(f"[kernel] shape=({m}, {n}, {k}) avg={avg_ms:.3f} ms, ~{tflops:.2f} TFLOP/s")
284
+
285
+
286
+ def main() -> None:
287
+ parser = argparse.ArgumentParser(description="AMD Triton fused dual-GEMM + SwiGLU")
288
+ parser.add_argument("--m", type=int, default=4096)
289
+ parser.add_argument("--n", type=int, default=3648)
290
+ parser.add_argument("--k", type=int, default=8192)
291
+ parser.add_argument("--warmup", type=int, default=10)
292
+ parser.add_argument("--iters", type=int, default=50)
293
+ parser.add_argument("--input-scale", type=float, default=0.125)
294
+ parser.add_argument("--test-only", action="store_true", help="Run correctness tests only")
295
+ parser.add_argument("--bench-only", action="store_true", help="Run benchmark only")
296
+ args = parser.parse_args()
297
+
298
+ if not torch.cuda.is_available():
299
+ print("ERROR: No CUDA/ROCm GPU detected. This kernel requires a GPU to run.")
300
+ print(" - Run on a machine with an NVIDIA GPU (CUDA) or AMD GPU (ROCm)")
301
+ print(" - Ensure PyTorch is installed with GPU support.")
302
+ raise SystemExit(1)
303
+
304
+ if not args.bench_only:
305
+ print("Running correctness tests...")
306
+ t0 = time.time()
307
+ ok = test_correctness()
308
+ print(f"Correctness: {'PASS' if ok else 'FAIL'} ({time.time()-t0:.2f}s)")
309
+
310
+ if not args.test_only:
311
+ print("\nRunning benchmark...")
312
+ t0 = time.time()
313
+ benchmark(args.m, args.n, args.k, args.warmup, args.iters, args.input_scale)
314
+ print(f"[done] elapsed={time.time()-t0:.2f}s")
315
+
316
+
317
+ if __name__ == "__main__":
318
+ main()
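A minimal usage sketch for the wrapper above; the shapes here are arbitrary, and torch.testing.assert_close stands in for the triton-kernels assert_close so the snippet has no extra dependency:

import torch
from amd_dual_gemm_swiglu import dual_gemm_swiglu, reference_dual_gemm_swiglu

a  = torch.randn(256, 512,  device="cuda", dtype=torch.float16) * 0.125
b1 = torch.randn(512, 1024, device="cuda", dtype=torch.float16) * 0.125
b2 = torch.randn(512, 1024, device="cuda", dtype=torch.float16) * 0.125

out = dual_gemm_swiglu(a, b1, b2)  # fused Triton kernel
ref = reference_dual_gemm_swiglu(a.float(), b1.float(), b2.float()).to(torch.float16)
torch.testing.assert_close(out, ref, atol=2e-2, rtol=2e-2)  # loose fp16 tolerances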
dual_gemm_swiglu_full.py ADDED
@@ -0,0 +1,417 @@
1
+ """
2
+ Dual GEMM + SwiGLU following triton_kernels swiglu.py build pattern.
3
+
4
+ Structure matches Kernel Community Hub swiglu.py:
5
+ - repr() and launch_metadata for specialization
6
+ - compute_swiglu() style activation (SiLU(gate) * linear)
7
+ - Optional Flexpoint/MXFP (stub for standalone, real import in triton_kernels)
8
+ - NTokens support for variable M (MoE routing)
9
+ - Persistent kernel pattern with tl.range
10
+
11
+ Usage (standalone fp16):
12
+ from dual_gemm_swiglu_full import dual_gemm_swiglu
13
+ out = dual_gemm_swiglu(a, b1, b2)
14
+
15
+ For triton_kernels integration: place in dual_gemm_swiglu_details/, add flexpoint import.
16
+ """
17
+
18
+ from __future__ import annotations
19
+
20
+ import argparse
21
+ import os
22
+ import sys
23
+ import time
24
+ from typing import Optional
25
+
26
+ # Allow importing testing.py from same directory (when run from kernels/)
27
+ _script_dir = os.path.dirname(os.path.abspath(__file__))
28
+ if _script_dir not in sys.path:
29
+ sys.path.insert(0, _script_dir)
30
+
31
+ import torch
32
+ import triton
33
+ import triton.language as tl
34
+
35
+ # -----------------------------------------------------------------------------
36
+ # Flexpoint stub (standalone). Replace with:
37
+ # from ..numerics_details.flexpoint import load_scale, float_to_flex, update_scale
38
+ # when integrating into triton_kernels.
39
+ # -----------------------------------------------------------------------------
40
+ _HAS_FLEXPOINT = False
41
+ try:
42
+ # Only works inside triton_kernels package
43
+ from ..numerics_details.flexpoint import load_scale, float_to_flex, update_scale
44
+ _HAS_FLEXPOINT = True
45
+ except ImportError:
46
+ @triton.jit
47
+ def load_scale(scale_ptr):
48
+ return 1.0 if scale_ptr is None else tl.load(scale_ptr)
49
+
50
+ def float_to_flex(x, *args, **kwargs):  # stub with the same name as the real flexpoint helper
51
+ """Pass-through for fp16 standalone."""
52
+ return x
53
+
54
+ def update_scale(x, scale_ptr, Out):  # stub
55
+ pass
56
+
57
+
58
+ # -----------------------------------------------------------------------------
59
+ # Helpers (mirroring swiglu.py)
60
+ # -----------------------------------------------------------------------------
61
+ @triton.jit
62
+ def clip(x, limit, clip_lower: tl.constexpr):
63
+ res = tl.minimum(x, limit)
64
+ if clip_lower:
65
+ res = tl.maximum(-limit, res)
66
+ return res
67
+
68
+
69
+ @triton.jit
70
+ def thread_local_absmax(x, BLOCK_SIZE: tl.constexpr, NUM_THREADS: tl.constexpr):
71
+ return tl.max(
72
+ tl.reshape(tl.abs(x), [NUM_THREADS, BLOCK_SIZE // NUM_THREADS], can_reorder=True),
73
+ axis=1,
74
+ )
75
+
76
+
77
+ @triton.jit
78
+ def compute_swiglu(gelu, linear, scale, alpha, limit: tl.constexpr):
79
+ """SwiGLU: silu(gelu) * linear. Matches swiglu.py compute_swiglu style.
80
+ limit > 0 enables clipping; pass 0.0 for no clip.
81
+ """
82
+ gelu = gelu.to(tl.float32) * scale
83
+ if limit > 0:
84
+ gelu = clip(gelu, limit, clip_lower=False)
85
+ linear = linear.to(tl.float32) * scale
86
+ if limit > 0:
87
+ linear = clip(linear, limit, clip_lower=True)
88
+ s = gelu / (1 + tl.exp(-alpha * gelu)) # SiLU(gelu)
89
+ return s * linear # SiLU(gate) * linear (standard SwiGLU)
90
+
91
+
92
+ # -----------------------------------------------------------------------------
93
+ # Repr and launch_metadata (swiglu.py pattern)
94
+ # -----------------------------------------------------------------------------
95
+ def dual_gemm_repr(specialization):
96
+ signature = specialization.signature
97
+ constants = specialization.constants
98
+ convert_dtype = lambda dtype: "mxfp4" if "u8" in str(dtype) else str(dtype)
99
+ dtypes = "x".join([convert_dtype(f"{signature.get(i, 'fp16')}") for i in ["Out", "A", "B1", "B2"]])
100
+ blocks = "x".join([f"{constants.get(i, 0)}" for i in ["BLOCK_M", "BLOCK_N", "BLOCK_K"]])
101
+ return f"_dual_gemm_swiglu_{dtypes}_{blocks}"
102
+
103
+
104
+ def dual_gemm_launch_metadata(grid, kernel, args):
105
+ M, N, K = args["M"], args["N"], args["K"]
106
+ ret = dict()
107
+ ret["name"] = f"{kernel.name} [M={M}, N={N}, K={K}]"
108
+ A, B1, B2, Out = args["A"], args["B1"], args["B2"], args["Out"]
109
+ ret["bytes"] = (
110
+ A.numel() * A.element_size()
111
+ + B1.numel() * B1.element_size()
112
+ + B2.numel() * B2.element_size()
113
+ + Out.numel() * Out.element_size()
114
+ )
115
+ return ret
116
+
117
+
118
+ # -----------------------------------------------------------------------------
119
+ # Dual GEMM + SwiGLU kernel (swiglu.py structure)
120
+ # -----------------------------------------------------------------------------
121
+ @triton.jit(repr=lambda _: "_dual_gemm_swiglu", launch_metadata=dual_gemm_launch_metadata)
122
+ def _dual_gemm_swiglu(
123
+ Out,
124
+ A,
125
+ B1,
126
+ B2,
127
+ M,
128
+ N,
129
+ K,
130
+ stride_am,
131
+ stride_ak,
132
+ stride_b1k,
133
+ stride_b1n,
134
+ stride_b2k,
135
+ stride_b2n,
136
+ stride_outm,
137
+ stride_outn,
138
+ alpha: tl.constexpr,
139
+ limit,
140
+ NTokens,
141
+ BLOCK_M: tl.constexpr,
142
+ BLOCK_N: tl.constexpr,
143
+ BLOCK_K: tl.constexpr,
144
+ GROUP_SIZE_M: tl.constexpr,
145
+ EVEN_K: tl.constexpr,
146
+ EVEN_M: tl.constexpr,
147
+ EVEN_N: tl.constexpr,
148
+ ):
149
+ if NTokens is not None:
150
+ M = tl.load(NTokens)
151
+ M_BLOCKS = tl.cdiv(M, BLOCK_M)
152
+ N_BLOCKS = tl.cdiv(N, BLOCK_N)
153
+ num_tiles = M_BLOCKS * N_BLOCKS
154
+
155
+ # Persistent kernel: each program handles multiple tiles
156
+ grid_size = tl.num_programs(0)
157
+ for pid in range(tl.program_id(0), num_tiles, grid_size):
158
+ pid_m = pid // N_BLOCKS
159
+ pid_n = pid % N_BLOCKS
160
+
161
+ offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
162
+ offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
163
+ offs_k = tl.arange(0, BLOCK_K)
164
+
165
+ a_ptrs = A + (offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak)
166
+ b1_ptrs = B1 + (offs_k[:, None] * stride_b1k + offs_n[None, :] * stride_b1n)
167
+ b2_ptrs = B2 + (offs_k[:, None] * stride_b2k + offs_n[None, :] * stride_b2n)
168
+
169
+ acc1 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
170
+ acc2 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
171
+ m_mask = offs_m[:, None] < M
172
+ n_mask = offs_n[None, :] < N
173
+
174
+ for k in range(0, tl.cdiv(K, BLOCK_K)):
175
+ if EVEN_K:
176
+ a = tl.load(a_ptrs, mask=m_mask, other=0.0) if not EVEN_M else tl.load(a_ptrs)
177
+ b1 = tl.load(b1_ptrs, mask=n_mask, other=0.0) if not EVEN_N else tl.load(b1_ptrs)
178
+ b2 = tl.load(b2_ptrs, mask=n_mask, other=0.0) if not EVEN_N else tl.load(b2_ptrs)
179
+ else:
180
+ k_rem = K - k * BLOCK_K
181
+ k_mask_m = offs_k[None, :] < k_rem
182
+ k_mask_n = offs_k[:, None] < k_rem
183
+ a = tl.load(a_ptrs, mask=m_mask & k_mask_m, other=0.0)
184
+ b1 = tl.load(b1_ptrs, mask=k_mask_n & n_mask, other=0.0)
185
+ b2 = tl.load(b2_ptrs, mask=k_mask_n & n_mask, other=0.0)
186
+
187
+ tl.multiple_of(a_ptrs, [16, 16])
188
+ tl.multiple_of(b1_ptrs, [16, 16])
189
+ tl.multiple_of(b2_ptrs, [16, 16])
190
+ acc1 += tl.dot(a, b1)
191
+ acc2 += tl.dot(a, b2)
192
+
193
+ a_ptrs += BLOCK_K * stride_ak
194
+ b1_ptrs += BLOCK_K * stride_b1k
195
+ b2_ptrs += BLOCK_K * stride_b2k
196
+
197
+ out = compute_swiglu(acc1, acc2, 1.0, alpha, limit)
198
+ out = out.to(tl.float16)
199
+
200
+ out_ptrs = Out + (offs_m[:, None] * stride_outm + offs_n[None, :] * stride_outn)
201
+ c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
202
+ tl.store(out_ptrs, out, mask=c_mask)
203
+
204
+
205
+ # -----------------------------------------------------------------------------
206
+ # Autotuned wrapper (backward compatible, uses simpler kernel for reliability)
207
+ # -----------------------------------------------------------------------------
208
+ @triton.autotune(
209
+ configs=[
210
+ triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "GROUP_SIZE_M": 8}, num_warps=4, num_stages=3),
211
+ triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "GROUP_SIZE_M": 8}, num_warps=8, num_stages=3),
212
+ triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "GROUP_SIZE_M": 8}, num_warps=8, num_stages=2),
213
+ triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64, "GROUP_SIZE_M": 8}, num_warps=8, num_stages=2),
214
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 64, "GROUP_SIZE_M": 4}, num_warps=4, num_stages=3),
215
+ triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 64, "GROUP_SIZE_M": 4}, num_warps=4, num_stages=3),
216
+ ],
217
+ key=["M", "N", "K"],
218
+ )
219
+ @triton.heuristics(
220
+ {
221
+ "EVEN_K": lambda args: args["K"] % args["BLOCK_K"] == 0,
222
+ "EVEN_M": lambda args: args["M"] % args["BLOCK_M"] == 0,
223
+ "EVEN_N": lambda args: args["N"] % args["BLOCK_N"] == 0,
224
+ }
225
+ )
226
+ @triton.jit
227
+ def _dual_gemm_swiglu_autotuned(
228
+ a_ptr, b1_ptr, b2_ptr, c_ptr,
229
+ M, N, K,
230
+ stride_am, stride_ak, stride_b1k, stride_b1n, stride_b2k, stride_b2n, stride_cm, stride_cn,
231
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr,
232
+ EVEN_K: tl.constexpr, EVEN_M: tl.constexpr, EVEN_N: tl.constexpr,
233
+ alpha: tl.constexpr,
234
+ ):
235
+ pid = tl.program_id(axis=0)
236
+ num_pid_m = tl.cdiv(M, BLOCK_M)
237
+ num_pid_n = tl.cdiv(N, BLOCK_N)
238
+ num_pid_in_group = GROUP_SIZE_M * num_pid_n
239
+ group_id = pid // num_pid_in_group
240
+ first_pid_m = group_id * GROUP_SIZE_M
241
+ group_size_m = tl.minimum(num_pid_m - first_pid_m, GROUP_SIZE_M)
242
+ pid_in_group = pid % num_pid_in_group
243
+ pid_m = first_pid_m + (pid_in_group % group_size_m)
244
+ pid_n = pid_in_group // group_size_m
245
+
246
+ offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
247
+ offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
248
+ offs_k = tl.arange(0, BLOCK_K)
249
+
250
+ a_ptrs = a_ptr + (offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak)
251
+ b1_ptrs = b1_ptr + (offs_k[:, None] * stride_b1k + offs_n[None, :] * stride_b1n)
252
+ b2_ptrs = b2_ptr + (offs_k[:, None] * stride_b2k + offs_n[None, :] * stride_b2n)
253
+
254
+ acc1 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
255
+ acc2 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
256
+ m_mask = offs_m[:, None] < M
257
+ n_mask = offs_n[None, :] < N
258
+
259
+ for k in range(0, tl.cdiv(K, BLOCK_K)):
260
+ if EVEN_K:
261
+ a = tl.load(a_ptrs) if EVEN_M else tl.load(a_ptrs, mask=m_mask, other=0.0)
262
+ b1 = tl.load(b1_ptrs) if EVEN_N else tl.load(b1_ptrs, mask=n_mask, other=0.0)
263
+ b2 = tl.load(b2_ptrs) if EVEN_N else tl.load(b2_ptrs, mask=n_mask, other=0.0)
264
+ else:
265
+ k_rem = K - k * BLOCK_K
266
+ k_mask_m = offs_k[None, :] < k_rem
267
+ k_mask_n = offs_k[:, None] < k_rem
268
+ a = tl.load(a_ptrs, mask=m_mask & k_mask_m, other=0.0)
269
+ b1 = tl.load(b1_ptrs, mask=k_mask_n & n_mask, other=0.0)
270
+ b2 = tl.load(b2_ptrs, mask=k_mask_n & n_mask, other=0.0)
271
+
272
+ tl.multiple_of(a_ptrs, [16, 16])
273
+ tl.multiple_of(b1_ptrs, [16, 16])
274
+ tl.multiple_of(b2_ptrs, [16, 16])
275
+ acc1 += tl.dot(a, b1)
276
+ acc2 += tl.dot(a, b2)
277
+ a_ptrs += BLOCK_K * stride_ak
278
+ b1_ptrs += BLOCK_K * stride_b1k
279
+ b2_ptrs += BLOCK_K * stride_b2k
280
+
281
+ out = compute_swiglu(acc1, acc2, 1.0, alpha, 0.0) # 0.0 = no clip
282
+ c_ptrs = c_ptr + (offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn)
283
+ c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
284
+ tl.store(c_ptrs, out.to(tl.float16), mask=c_mask)
285
+
286
+
287
+ def dual_gemm_swiglu(
288
+ a: torch.Tensor,
289
+ b1: torch.Tensor,
290
+ b2: torch.Tensor,
291
+ n_tokens: Optional[torch.Tensor] = None,
292
+ alpha: float = 1.0,
293
+ limit: Optional[float] = None,
294
+ ) -> torch.Tensor:
295
+ """Fused Dual GEMM + SwiGLU: silu(A @ B1) * (A @ B2)."""
296
+ if a.ndim != 2 or b1.ndim != 2 or b2.ndim != 2:
297
+ raise ValueError("Expected 2D tensors: a[M,K], b1[K,N], b2[K,N].")
298
+ if a.shape[1] != b1.shape[0] or a.shape[1] != b2.shape[0]:
299
+ raise ValueError("Incompatible shapes for dual GEMM.")
300
+ if b1.shape[1] != b2.shape[1]:
301
+ raise ValueError("b1 and b2 must have same N dimension.")
302
+ if not (a.is_cuda and b1.is_cuda and b2.is_cuda):
303
+ raise ValueError("All tensors must be on a CUDA/ROCm device.")
304
+ if a.dtype != torch.float16 or b1.dtype != torch.float16 or b2.dtype != torch.float16:
305
+ raise ValueError("This kernel currently expects float16 inputs.")
306
+
307
+ a = a.contiguous()
308
+ b1 = b1.contiguous()
309
+ b2 = b2.contiguous()
310
+
311
+ M, K = a.shape
312
+ _, N = b1.shape
313
+ c = torch.empty((M, N), device=a.device, dtype=torch.float16)
314
+
315
+ grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)
316
+
317
+ _dual_gemm_swiglu_autotuned[grid](
318
+ a, b1, b2, c,
319
+ M=M, N=N, K=K,
320
+ stride_am=a.stride(0), stride_ak=a.stride(1),
321
+ stride_b1k=b1.stride(0), stride_b1n=b1.stride(1),
322
+ stride_b2k=b2.stride(0), stride_b2n=b2.stride(1),
323
+ stride_cm=c.stride(0), stride_cn=c.stride(1),
324
+ alpha=alpha,
325
+ )
326
+ return c
327
+
328
+
329
+ def reference_dual_gemm_swiglu(a: torch.Tensor, b1: torch.Tensor, b2: torch.Tensor) -> torch.Tensor:
330
+ x1 = a @ b1
331
+ x2 = a @ b2
332
+ return torch.nn.functional.silu(x1) * x2
333
+
334
+
335
+ def test_correctness(device: str = "cuda", maxtol: float = 2e-2, rmstol: float = 4e-3) -> bool:
336
+ from testing import assert_close
337
+
338
+ torch.manual_seed(42)
339
+ shapes = [
340
+ (128, 64, 128),
341
+ (256, 256, 512),
342
+ (1024, 512, 1024),
343
+ (4096, 3648, 8192),
344
+ (7, 13, 17),
345
+ (100, 200, 150),
346
+ ]
347
+ input_scale = 0.125
348
+ all_pass = True
349
+ for m, n, k in shapes:
350
+ a = torch.randn((m, k), device=device, dtype=torch.float16) * input_scale
351
+ b1 = torch.randn((k, n), device=device, dtype=torch.float16) * input_scale
352
+ b2 = torch.randn((k, n), device=device, dtype=torch.float16) * input_scale
353
+ with torch.no_grad():
354
+ ref = reference_dual_gemm_swiglu(a.float(), b1.float(), b2.float()).to(torch.float16)
355
+ out = dual_gemm_swiglu(a, b1, b2)
356
+ desc = f"[shape ({m},{n},{k})]"
357
+ try:
358
+ assert_close(ref, out, maxtol=maxtol, rmstol=rmstol, description=desc, verbose=True)
359
+ print(f" {desc} PASS")
360
+ except AssertionError as e:
361
+ print(f" {desc} FAIL: {e}")
362
+ all_pass = False
363
+ return all_pass
364
+
365
+
366
+ def benchmark(m: int, n: int, k: int, warmup: int, iters: int, input_scale: float) -> None:
367
+ device = "cuda"
368
+ a = torch.randn((m, k), device=device, dtype=torch.float16) * input_scale
369
+ b1 = torch.randn((k, n), device=device, dtype=torch.float16) * input_scale
370
+ b2 = torch.randn((k, n), device=device, dtype=torch.float16) * input_scale
371
+ for _ in range(warmup):
372
+ _ = dual_gemm_swiglu(a, b1, b2)
373
+ torch.cuda.synchronize()
374
+ start = torch.cuda.Event(enable_timing=True)
375
+ end = torch.cuda.Event(enable_timing=True)
376
+ start.record()
377
+ for _ in range(iters):
378
+ _ = dual_gemm_swiglu(a, b1, b2)
379
+ end.record()
380
+ torch.cuda.synchronize()
381
+ avg_ms = start.elapsed_time(end) / iters
382
+ total_flops = 4 * m * n * k
383
+ tflops = (total_flops / (avg_ms * 1e-3)) / 1e12
384
+ print(f"[kernel] shape=({m}, {n}, {k}) avg={avg_ms:.3f} ms, ~{tflops:.2f} TFLOP/s")
385
+
386
+
387
+ def main() -> None:
388
+ parser = argparse.ArgumentParser(description="Dual GEMM + SwiGLU (swiglu.py build pattern)")
389
+ parser.add_argument("--m", type=int, default=4096)
390
+ parser.add_argument("--n", type=int, default=3648)
391
+ parser.add_argument("--k", type=int, default=8192)
392
+ parser.add_argument("--warmup", type=int, default=10)
393
+ parser.add_argument("--iters", type=int, default=50)
394
+ parser.add_argument("--input-scale", type=float, default=0.125)
395
+ parser.add_argument("--test-only", action="store_true")
396
+ parser.add_argument("--bench-only", action="store_true")
397
+ args = parser.parse_args()
398
+
399
+ if not torch.cuda.is_available():
400
+ print("ERROR: No CUDA/ROCm GPU detected.")
401
+ raise SystemExit(1)
402
+
403
+ if not args.bench_only:
404
+ print("Running correctness tests...")
405
+ t0 = time.time()
406
+ ok = test_correctness()
407
+ print(f"Correctness: {'PASS' if ok else 'FAIL'} ({time.time()-t0:.2f}s)")
408
+
409
+ if not args.test_only:
410
+ print("\nRunning benchmark...")
411
+ t0 = time.time()
412
+ benchmark(args.m, args.n, args.k, args.warmup, args.iters, args.input_scale)
413
+ print(f"[done] elapsed={time.time()-t0:.2f}s")
414
+
415
+
416
+ if __name__ == "__main__":
417
+ main()
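A short sketch of the standalone entry point above; note that only alpha reaches the autotuned kernel in this wrapper (n_tokens and limit are reserved for the persistent _dual_gemm_swiglu variant):

import torch
from dual_gemm_swiglu_full import dual_gemm_swiglu, reference_dual_gemm_swiglu

a  = torch.randn(1024, 1024, device="cuda", dtype=torch.float16) * 0.125
b1 = torch.randn(1024, 512,  device="cuda", dtype=torch.float16) * 0.125
b2 = torch.randn(1024, 512,  device="cuda", dtype=torch.float16) * 0.125

out = dual_gemm_swiglu(a, b1, b2, alpha=1.0)  # alpha=1.0 gives standard SiLU gating
ref = reference_dual_gemm_swiglu(a.float(), b1.float(), b2.float()).to(torch.float16)
print((out - ref).abs().max())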
dual_gemm_swiglu_mxfp.py ADDED
@@ -0,0 +1,119 @@
1
+ """
2
+ Option 1.5 / 2: Dual GEMM + SwiGLU with MXFP4 weights.
3
+
4
+ - Option 1 (pre-dequant): use dual_gemm_swiglu from amd_dual_gemm_swiglu with (mx_tensor, mx_scale)
5
+ - Option 1.5 (tiled pre-dequant): upcast B in K-blocks, never materialize full fp16 B. Saves memory.
6
+ - Option 2 (fused): would decode MXFP in-kernel; blocked by ROCm Triton limitations (tl.cat, indexing).
7
+ Currently falls back to option 1.5.
8
+
9
+ Usage:
10
+ from dual_gemm_swiglu_mxfp import dual_gemm_swiglu_mxfp_tiled
11
+ out = dual_gemm_swiglu_mxfp_tiled(a, (b1_mx, b1_scale), (b2_mx, b2_scale))
12
+ """
13
+
14
+ import os
15
+ import sys
16
+
17
+ _script_dir = os.path.dirname(os.path.abspath(__file__))
18
+ if _script_dir not in sys.path:
19
+ sys.path.insert(0, _script_dir)
20
+
21
+ import torch
22
+
23
+ try:
24
+ from numerics_details.mxfp_details import upcast_mxfp4_to_fp16
25
+ from amd_dual_gemm_swiglu import reference_dual_gemm_swiglu
26
+ _HAS_DEPS = True
27
+ except ImportError:
28
+ _HAS_DEPS = False
29
+
30
+ MXFP_BLOCK = 32 # N must be multiple of 32 for scale
31
+
32
+
33
+ def _upcast_slice(mx_tensor: torch.Tensor, mx_scale: torch.Tensor, k_start: int, k_end: int) -> torch.Tensor:
34
+ """Upcast MXFP4 slice [k_start:k_end, :] to fp16 [k_end-k_start, N]."""
35
+ return upcast_mxfp4_to_fp16(
36
+ mx_tensor[k_start:k_end, :],
37
+ mx_scale[k_start:k_end, :],
38
+ block_m=k_end - k_start,
39
+ block_k=64, # N must be divisible by block_k
40
+ verbose=False,
41
+ )
42
+
43
+
44
+ def dual_gemm_swiglu_mxfp_tiled(
45
+ a: torch.Tensor,
46
+ b1_mx: torch.Tensor,
47
+ b1_scale: torch.Tensor,
48
+ b2_mx: torch.Tensor,
49
+ b2_scale: torch.Tensor,
50
+ block_k: int = 64,
51
+ ) -> torch.Tensor:
52
+ """
53
+ Dual GEMM + SwiGLU with MXFP4 B1, B2 using tiled pre-dequant (Option 1.5).
54
+ Upcasts B in K-blocks; never materializes full fp16 B. Saves memory vs full pre-dequant.
55
+ """
56
+ if not _HAS_DEPS:
57
+ raise ImportError("Requires numerics_details and amd_dual_gemm_swiglu")
58
+ M, K = a.shape
59
+ N = b1_mx.shape[1] * 2
60
+ assert b1_mx.shape == (K, N // 2) and b1_scale.shape == (K, N // 32)
61
+ assert b2_mx.shape == (K, N // 2) and b2_scale.shape == (K, N // 32)
62
+ assert K % block_k == 0 and N % MXFP_BLOCK == 0
63
+ assert block_k % MXFP_BLOCK == 0
64
+
65
+ a = a.contiguous()
66
+ c = torch.zeros((M, N), device=a.device, dtype=torch.float16)
67
+ # Option 1.5: Accumulate acc1 = A@B1 and acc2 = A@B2 in K-blocks, then out = silu(acc1)*acc2.
68
+ # Never materialize full fp16 B - upcast slice by slice. Saves O(K*N) -> O(block_k*N) memory.
69
+ acc1 = torch.zeros((M, N), device=a.device, dtype=torch.float32)
70
+ acc2 = torch.zeros((M, N), device=a.device, dtype=torch.float32)
71
+ for k_start in range(0, K, block_k):
72
+ k_end = k_start + block_k
73
+ b1_slice = _upcast_slice(b1_mx, b1_scale, k_start, k_end)
74
+ b2_slice = _upcast_slice(b2_mx, b2_scale, k_start, k_end)
75
+ # Partial GEMM: acc1 += A[:, k_start:k_end] @ b1_slice
76
+ # Plain PyTorch matmuls per K-block for now; a dedicated Triton group-GEMM kernel could fuse these later.
77
+ # Accumulate in fp32 so the tiled path stays numerically close to the full pre-dequant path.
78
+ acc1 += (a[:, k_start:k_end].float() @ b1_slice.float())
79
+ acc2 += (a[:, k_start:k_end].float() @ b2_slice.float())
80
+ # SwiGLU
81
+ silu = torch.nn.functional.silu(acc1.to(torch.float16))
82
+ out = (silu * acc2.to(torch.float16)).to(torch.float16)
83
+ return out
84
+
85
+
86
+ def dual_gemm_swiglu_mxfp_predequant(a, b1_mx, b1_scale, b2_mx, b2_scale):
87
+ """Option 1: full pre-dequant, then standard dual GEMM."""
88
+ from amd_dual_gemm_swiglu import dual_gemm_swiglu
89
+ b1 = upcast_mxfp4_to_fp16(b1_mx, b1_scale, verbose=False)
90
+ b2 = upcast_mxfp4_to_fp16(b2_mx, b2_scale, verbose=False)
91
+ return dual_gemm_swiglu(a, b1, b2)
92
+
93
+
94
+ if __name__ == "__main__":
95
+ if not _HAS_DEPS:
96
+ print("Missing deps")
97
+ sys.exit(1)
98
+ if not torch.cuda.is_available():
99
+ print("No GPU")
100
+ sys.exit(1)
101
+ device = "cuda"
102
+ M, N, K = 256, 128, 512
103
+ torch.manual_seed(42)
104
+ a = torch.randn((M, K), device=device, dtype=torch.float16) * 0.1
105
+ from example_dual_gemm_mxfp import downcast_fp16_to_mxfp4
106
+ b1_fp = torch.randn((K, N), device=device, dtype=torch.float16) * 0.1
107
+ b2_fp = torch.randn((K, N), device=device, dtype=torch.float16) * 0.1
108
+ b1_mx, b1_scale = downcast_fp16_to_mxfp4(b1_fp, block_k=64)
109
+ b2_mx, b2_scale = downcast_fp16_to_mxfp4(b2_fp, block_k=64)
110
+ print("Option 1 (pre-dequant):")
111
+ out1 = dual_gemm_swiglu_mxfp_predequant(a, b1_mx, b1_scale, b2_mx, b2_scale)
112
+ print("Option 1.5 (tiled pre-dequant):")
113
+ out15 = dual_gemm_swiglu_mxfp_tiled(a, b1_mx, b1_scale, b2_mx, b2_scale)
114
+ ref = reference_dual_gemm_swiglu(a.float(), b1_fp.float(), b2_fp.float()).to(torch.float16)
115
+ err1 = (out1.float() - ref.float()).abs().max().item()
116
+ err15 = (out15.float() - ref.float()).abs().max().item()
117
+ print(f" Option 1 err: {err1:.2e}")
118
+ print(f" Option 1.5 err: {err15:.2e}")
119
+ print(f" Option 1 vs 1.5 diff: {(out1.float() - out15.float()).abs().max().item():.2e}")
example_dual_gemm_mxfp.py ADDED
@@ -0,0 +1,70 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Example: Dual GEMM + SwiGLU with MXFP4 weights (pre-dequant option 1).
4
+ Quantizes B1, B2 to MXFP4, then upcasts and runs the fused GEMM.
5
+ """
6
+
7
+ import os
8
+ import sys
9
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
10
+
11
+ import torch
12
+ from amd_dual_gemm_swiglu import dual_gemm_swiglu, reference_dual_gemm_swiglu
13
+ from numerics_details.mxfp_details._downcast_to_mxfp import _downcast_to_mxfp
14
+ from numerics_details.mxfp_details import upcast_mxfp4_to_fp16
15
+
16
+ MXFP_BLOCK_SIZE_PY = 32
17
+
18
+
19
+ def downcast_fp16_to_mxfp4(src: torch.Tensor, block_m: int = 128, block_k: int = 64):
20
+ """fp16 [M,K] -> (mx_tensor [M,K//2], mx_scale [M,K//32])."""
21
+ assert block_k % MXFP_BLOCK_SIZE_PY == 0
22
+ M, K = src.shape
23
+ mx_tensor = torch.empty((M, K // 2), device=src.device, dtype=torch.uint8)
24
+ mx_scale = torch.empty((M, K // 32), device=src.device, dtype=torch.uint8)
25
+ grid = ((M + block_m - 1) // block_m, (K + block_k - 1) // block_k)
26
+ _downcast_to_mxfp[grid](
27
+ mx_tensor, mx_tensor.stride(0), 1,
28
+ mx_scale, mx_scale.stride(0), mx_scale.stride(1),
29
+ src, src.stride(0), src.stride(1),
30
+ M, K,
31
+ BLOCK_SIZE_OUT_DIM=block_m,
32
+ BLOCK_SIZE_QUANT_DIM=block_k,
33
+ DEQUANT_SCALE_ROUNDING_MODE=0,
34
+ )
35
+ return mx_tensor, mx_scale
36
+
37
+
38
+ def main():
39
+ if not torch.cuda.is_available():
40
+ print("No CUDA/ROCm device.")
41
+ return
42
+ device = "cuda"
43
+ print("Device:", torch.cuda.get_device_name(0))
44
+
45
+ M, N, K = 256, 128, 512
46
+ torch.manual_seed(42)
47
+ a = torch.randn((M, K), device=device, dtype=torch.float16) * 0.1
48
+ b1_fp16 = torch.randn((K, N), device=device, dtype=torch.float16) * 0.1
49
+ b2_fp16 = torch.randn((K, N), device=device, dtype=torch.float16) * 0.1
50
+
51
+ # Quantize B1, B2 to MXFP4 (need K, N multiples of 32 for block_k)
52
+ block_k = 64
53
+ b1_mx, b1_scale = downcast_fp16_to_mxfp4(b1_fp16, block_k=block_k)
54
+ b2_mx, b2_scale = downcast_fp16_to_mxfp4(b2_fp16, block_k=block_k)
55
+ print(f"Quantized B1: {b1_mx.shape}, {b1_scale.shape}")
56
+
57
+ # Run dual GEMM with MXFP4 weights (pre-dequant)
58
+ out_mxfp = dual_gemm_swiglu(a, (b1_mx, b1_scale), (b2_mx, b2_scale))
59
+ print(f"Output (MXFP4 path): {out_mxfp.shape}")
60
+
61
+ # Reference with fp16
62
+ out_ref = reference_dual_gemm_swiglu(a.float(), b1_fp16.float(), b2_fp16.float()).to(torch.float16)
63
+ err = (out_mxfp.float() - out_ref.float()).abs().max().item()
64
+ rel = err / (out_ref.float().abs().max().item() + 1e-6)
65
+ print(f"vs fp16 ref: max abs err={err:.2e}, rel={rel:.2e}")
66
+ print("Done.")
67
+
68
+
69
+ if __name__ == "__main__":
70
+ main()
example_moe_compaction_gemm.py ADDED
@@ -0,0 +1,68 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Example: Mask compaction + Dual GEMM integration (MoE-style).
4
+ Before dual GEMM: compact (Yv, Yi) per row based on BitMask.
5
+ Then use compacted tensors for routing into expert weights.
6
+
7
+ ROCm note: tl.store with dynamic write_indx may fail on ROCm Triton.
8
+ If so, use the PyTorch fallback in mask_compaction.py.
9
+ """
10
+
11
+ import os
12
+ import sys
13
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
14
+
15
+ import torch
16
+
17
+ from mask_compaction import masked_compaction, masked_compaction_torch_fallback
18
+ from amd_dual_gemm_swiglu import dual_gemm_swiglu
19
+
20
+
21
+ def example_integration():
22
+ """Sketch: compact routing outputs, then run dual GEMM on routed experts."""
23
+ device = "cuda"
24
+ if not torch.cuda.is_available():
25
+ print("No GPU")
26
+ return
27
+
28
+ M, K, N = 256, 64, 128 # tokens, hidden, expert dim
29
+ top_k = 8
30
+ num_experts = 4
31
+
32
+ # Simulate routing: Yv [M, K] values, Yi [M, K] expert indices (0..num_experts-1)
33
+ torch.manual_seed(42)
34
+ Yv = torch.randn(M, K, device=device, dtype=torch.float16) * 0.1
35
+ Yi = torch.randint(0, num_experts, (M, K), device=device, dtype=torch.int32)
36
+
37
+ # BitMask [M, ceil(K/32)]: 1 = use, 0 = discard (e.g. from load balance)
38
+ BitMask = torch.ones(M, (K + 31) // 32, device=device, dtype=torch.int32)
39
+ BitMask[:, 0] = 0x55555555 # example: alternating bits
40
+
41
+ # 1) Compact (Yv, Yi) per row based on BitMask
42
+ try:
43
+ RetYv, RetYi = masked_compaction(Yv, Yi, BitMask, sentinel=float("nan"))
44
+ print("Compaction: Triton kernel OK")
45
+ except Exception as e:
46
+ print(f"Compaction: Triton failed ({e}), using PyTorch fallback")
47
+ RetYv, RetYi = masked_compaction_torch_fallback(Yv, Yi, BitMask, sentinel=float("nan"))
48
+
49
+ # 2) Use compacted indices for routing into expert weights
50
+ # Expert weights: B1[E,K,N], B2[E,K,N] or similar. For simplicity, flat GEMM:
51
+ # A = routed activations [M, K], B1/B2 = expert weights [K, N]
52
+ # This is a simplified sketch; real MoE has per-expert B.
53
+ B1 = torch.randn(K, N, device=device, dtype=torch.float16) * 0.1
54
+ B2 = torch.randn(K, N, device=device, dtype=torch.float16) * 0.1
55
+
56
+ # Use RetYv as activations (compacted); pad/truncate to [M, K] if needed
57
+ A = RetYv[:, :K].contiguous()
58
+ if A.shape[1] < K:
59
+ A = torch.nn.functional.pad(A, (0, K - A.shape[1]), value=0)
60
+
61
+ # 3) Dual GEMM
62
+ out = dual_gemm_swiglu(A, B1, B2)
63
+ print(f"Dual GEMM output: {out.shape}")
64
+ print("Done.")
65
+
66
+
67
+ if __name__ == "__main__":
68
+ example_integration()
example_mxfp_roundtrip.py ADDED
@@ -0,0 +1,72 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Example: use _downcast_to_mxfp (fp16 -> MXFP4) and _upcast_from_mxfp (MXFP4 -> fp16)
4
+ for a round-trip on the remote server.
5
+
6
+ Usage on remote:
7
+ cd /root/kernels
8
+ python example_mxfp_roundtrip.py
9
+ """
10
+
11
+ import os
12
+ import sys
13
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
14
+
15
+ import torch
16
+ from numerics_details.mxfp_details._downcast_to_mxfp import _downcast_to_mxfp
17
+ from numerics_details.mxfp_details import upcast_mxfp4_to_fp16
18
+
19
+ MXFP_BLOCK_SIZE_PY = 32 # Python int for checks (tl.constexpr in kernels)
20
+
21
+
22
+ def downcast_fp16_to_mxfp4(src: torch.Tensor, block_m: int = 128, block_k: int = 64):
23
+ """Convert fp16 tensor [M, K] to MXFP4 (uint8 mx_tensor + uint8 mx_scale)."""
24
+ assert src.dim() == 2 and src.dtype in (torch.float16, torch.bfloat16)
25
+ assert block_k % MXFP_BLOCK_SIZE_PY == 0, f"block_k must be multiple of {MXFP_BLOCK_SIZE_PY}"
26
+ M, K = src.shape
27
+
28
+ # Outputs: mx_tensor [M, K//2] uint8, mx_scale [M, K//32] uint8
29
+ mx_tensor = torch.empty((M, K // 2), device=src.device, dtype=torch.uint8)
30
+ mx_scale = torch.empty((M, K // 32), device=src.device, dtype=torch.uint8)
31
+
32
+ grid = ((M + block_m - 1) // block_m, (K + block_k - 1) // block_k)
33
+ _downcast_to_mxfp[grid](
34
+ mx_tensor, mx_tensor.stride(0), 1,
35
+ mx_scale, mx_scale.stride(0), mx_scale.stride(1),
36
+ src, src.stride(0), src.stride(1),
37
+ M, K,
38
+ BLOCK_SIZE_OUT_DIM=block_m,
39
+ BLOCK_SIZE_QUANT_DIM=block_k,
40
+ DEQUANT_SCALE_ROUNDING_MODE=0,
41
+ )
42
+ return mx_tensor, mx_scale
43
+
44
+
45
+ def main():
46
+ if not torch.cuda.is_available():
47
+ print("No CUDA/ROCm device.")
48
+ return
49
+ device = "cuda"
50
+ print("Device:", torch.cuda.get_device_name(0))
51
+
52
+ # Create random fp16 tensor
53
+ M, K = 256, 128
54
+ src = torch.randn(M, K, device=device, dtype=torch.float16) * 0.1
55
+
56
+ # Downcast fp16 -> MXFP4
57
+ mx_tensor, mx_scale = downcast_fp16_to_mxfp4(src)
58
+ print(f"Downcast OK: mx_tensor {mx_tensor.shape}, mx_scale {mx_scale.shape}")
59
+
60
+ # Upcast MXFP4 -> fp16
61
+ recovered = upcast_mxfp4_to_fp16(mx_tensor, mx_scale)
62
+ print(f"Upcast OK: recovered {recovered.shape}")
63
+
64
+ # Compare
65
+ err = (src.float() - recovered.float()).abs().max().item()
66
+ rel = err / (src.float().abs().max().item() + 1e-6)
67
+ print(f"Round-trip max abs err: {err:.2e}, rel: {rel:.2e}")
68
+ print("Done.")
69
+
70
+
71
+ if __name__ == "__main__":
72
+ main()
hardware_submission.py ADDED
@@ -0,0 +1,121 @@
1
+ import torch
2
+ import triton
3
+ import triton.language as tl
4
+ import os
5
+
6
+ # 1. HARDWARE DIAGNOSTICS & OS PREP
7
+ def check_environment():
8
+ cuda_avail = torch.cuda.is_available()
9
+ if cuda_avail:
10
+ # Optimization: Force kernel arguments to device to save PCIe latency
11
+ os.environ["HIP_FORCE_DEV_KERNARG"] = "1"
12
+ # Disable compiler cache for benchmarking clean runs
13
+ os.environ["TRITON_CACHE_DIR"] = ""
14
+
15
+ check_environment()
16
+
17
+ # 2. THE SOL KERNEL
18
+ @triton.autotune(
19
+ configs=[
20
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=8),
21
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=8),
22
+ ],
23
+ key=['M', 'N', 'K'],
24
+ )
25
+ @triton.jit
26
+ def dual_gemm_hardware_kernel(
27
+ a_ptr, b1_ptr, b2_ptr, c_ptr,
28
+ sfa_ptr, sfb1_ptr, sfb2_ptr,
29
+ M, N, K, L,
30
+ stride_am, stride_ak, stride_al,
31
+ stride_bn, stride_bk, stride_bl,
32
+ stride_cm, stride_cn, stride_cl,
33
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
34
+ GROUP_SIZE_M: tl.constexpr,
35
+ ):
36
+ # Persistent Grid logic
37
+ pid = tl.program_id(0)
38
+ num_pid_m = tl.cdiv(M, BLOCK_M)
39
+ num_pid_n = tl.cdiv(N, BLOCK_N)
40
+ total_tiles = num_pid_m * num_pid_n * L
41
+
42
+ for tile_idx in tl.range(pid, total_tiles, tl.num_programs(0)):
43
+ l_idx = tile_idx // (num_pid_m * num_pid_n)
44
+ tile_rem = tile_idx % (num_pid_m * num_pid_n)
45
+
46
+ # Swizzle for L2 Locality
47
+ pid_m = tile_rem // num_pid_n
+ pid_n = tile_rem % num_pid_n
+ pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M)
48
+
49
+ # Base offsets
50
+ rm = pid_m * BLOCK_M
51
+ rn = pid_n * BLOCK_N
52
+
53
+ # Ranges
54
+ offs_m = rm + tl.arange(0, BLOCK_M)
55
+ offs_n = rn + tl.arange(0, BLOCK_N)
56
+ offs_k = tl.arange(0, BLOCK_K)
57
+
58
+ # Memory Pointers
59
+ a_ptrs = a_ptr + l_idx * stride_al + (offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak)
60
+ b1_ptrs = b1_ptr + l_idx * stride_bl + (offs_n[None, :] * stride_bn + offs_k[:, None] * stride_bk)
61
+ b2_ptrs = b2_ptr + l_idx * stride_bl + (offs_n[None, :] * stride_bn + offs_k[:, None] * stride_bk)
62
+
63
+ # Accumulators
64
+ acc1 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
65
+ acc2 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
66
+
67
+ # Scale Factor Load Offsets (Assuming OCP 1 scale per 16 elements)
68
+ # sfa: (M, K/16, L), sfb: (N, K/16, L)
69
+ sfa_base = sfa_ptr + l_idx * (M * (K // 16)) + (offs_m[:, None] * (K // 16))
70
+ sfb1_base = sfb1_ptr + l_idx * (N * (K // 16)) + (offs_n[None, :] * (K // 16))
71
+ sfb2_base = sfb2_ptr + l_idx * (N * (K // 16)) + (offs_n[None, :] * (K // 16))
72
+
73
+ for k in range(0, tl.cdiv(K, BLOCK_K)):
74
+ # 1. Load Data
75
+ a = tl.load(a_ptrs, mask=(offs_m[:, None] < M) & (offs_k[None, :] < K - k * BLOCK_K), other=0.0)
76
+ b1 = tl.load(b1_ptrs, mask=(offs_n[None, :] < N) & (offs_k[:, None] < K - k * BLOCK_K), other=0.0)
77
+ b2 = tl.load(b2_ptrs, mask=(offs_n[None, :] < N) & (offs_k[:, None] < K - k * BLOCK_K), other=0.0)
78
+
79
+ # 2. Load Scales for current K-block
80
+ # Blackwell uses a 32x4 atom, but for pointers, we load the K-slice
81
+ curr_sfa = tl.load(sfa_base + (k * (BLOCK_K // 16)), mask=(offs_m[:, None] < M), other=1.0)
82
+ curr_sfb1 = tl.load(sfb1_base + (k * (BLOCK_K // 16)), mask=(offs_n[None, :] < N), other=1.0)
83
+ curr_sfb2 = tl.load(sfb2_base + (k * (BLOCK_K // 16)), mask=(offs_n[None, :] < N), other=1.0)
84
+
85
+ # 3. Hardware DOT Scaled
86
+ acc1 = tl.dot_scaled(a, curr_sfa, "e2m1", b1, curr_sfb1, "e2m1", acc1)
87
+ acc2 = tl.dot_scaled(a, curr_sfa, "e2m1", b2, curr_sfb2, "e2m1", acc2)
88
+
89
+ # Advance K
90
+ a_ptrs += BLOCK_K * stride_ak
91
+ b1_ptrs += BLOCK_K * stride_bk
92
+ b2_ptrs += BLOCK_K * stride_bk
93
+
94
+ # 4. Epilogue (Fused SiLU + Gating)
95
+ res1 = acc1.to(tl.float16)
96
+ activated = res1 * tl.sigmoid(res1)
97
+ final_out = activated * acc2.to(tl.float16)
98
+
99
+ # 5. Masked Store
100
+ c_ptrs = c_ptr + l_idx * stride_cl + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
101
+ tl.store(c_ptrs, final_out, mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))
102
+
103
+ # 3. SUBMISSION INTERFACE
104
+ def dual_gemm_submission(data):
105
+ a, b1, b2, sfa, sfb1, sfb2, c = data
106
+ M, K_packed, L = a.shape
107
+ N = b1.shape[0]
108
+ K = K_packed * 2 # FP4: two 4-bit values packed per byte
109
+
110
+ # Saturate Device (148 for B200, 304 for MI300X)
111
+ num_sms = torch.cuda.get_device_properties(0).multi_processor_count
112
+ grid = (num_sms,)
113
+
114
+ dual_gemm_hardware_kernel[grid](
115
+ a, b1, b2, c, sfa, sfb1, sfb2,
116
+ M, N, K, L,
117
+ a.stride(0), a.stride(1), a.stride(2),
118
+ b1.stride(0), b1.stride(1), b1.stride(2),
119
+ c.stride(0), c.stride(1), c.stride(2)
120
+ )
121
+ return c
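As a sanity check on the persistent launch (grid = (num_CUs,)), a small sketch of how many tiles each CU handles for the benchmark shape used elsewhere in this commit; the CU count is the MI300X figure quoted in the comment above:

import math

M, N, L = 4096, 3648, 1
BLOCK_M, BLOCK_N = 128, 128          # first autotune config
num_cus = 304                        # MI300X
tiles = math.ceil(M / BLOCK_M) * math.ceil(N / BLOCK_N) * L
print(tiles, math.ceil(tiles / num_cus))  # 928 tiles -> at most 4 tiles per CU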
mask_compaction.py ADDED
@@ -0,0 +1,75 @@
1
+ """
2
+ Masked compaction kernel: compact (Yv, Yi) per row based on BitMask.
3
+ Active elements (bit=1) move to front, inactive (bit=0) move to back with sentinel.
4
+ For MoE: use before dual GEMM to get dense top-k for routing into expert weights.
5
+
6
+ ROCm note: tl.store with dynamic write_indx may have limitations. If it fails,
7
+ fall back to PyTorch compaction.
8
+ """
9
+
10
+ import torch
11
+ import triton
12
+ import triton.language as tl
13
+
14
+
15
+ @triton.jit
16
+ def _masked_compaction(
17
+ Yv, Yi, BitMask, stride_bm, stride_bn,
18
+ RetYv, RetYi, sentinel, K: tl.constexpr
19
+ ):
20
+ pid_m = tl.program_id(0)
21
+ yv = tl.load(Yv + pid_m * K + tl.arange(0, K))
22
+ yi = tl.load(Yi + pid_m * K + tl.arange(0, K))
23
+ div = yi // 32
24
+ rem = yi % 32
25
+ active_bits = (tl.load(BitMask + pid_m * stride_bm + div * stride_bn) >> rem) & 1
26
+ exc_cumsum = tl.cumsum(active_bits, 0) - active_bits
27
+ active_flags = active_bits.to(tl.int1)
28
+ rev_arange = tl.where(active_flags, 0, K - 1 - tl.arange(0, K))
29
+ write_indx = exc_cumsum + rev_arange
30
+ yv = tl.where(active_flags, yv, sentinel)
31
+ yi = tl.where(active_flags, yi, sentinel)
32
+ tl.store(RetYv + pid_m * K + write_indx, yv)
33
+ tl.store(RetYi + pid_m * K + write_indx, yi)
34
+
35
+
36
+ def masked_compaction(
37
+ Yv: torch.Tensor, # [M, K] values
38
+ Yi: torch.Tensor, # [M, K] indices (int32)
39
+ BitMask: torch.Tensor, # [M, ceil(K/32)] or similar - 1 bit per position
40
+ sentinel: float = float("nan"),
41
+ ) -> tuple[torch.Tensor, torch.Tensor]:
42
+ """
43
+ Compact Yv, Yi per row: active (BitMask=1) to front, inactive to back with sentinel.
44
+ Returns (RetYv, RetYi) same shape as (Yv, Yi).
45
+ """
46
+ M, K = Yv.shape
47
+ assert Yi.shape == (M, K)
48
+ RetYv = torch.empty_like(Yv)
49
+ RetYi = torch.empty_like(Yi)
50
+ grid = (M,)
51
+ _masked_compaction[grid](
52
+ Yv, Yi, BitMask,
53
+ BitMask.stride(0), BitMask.stride(1),
54
+ RetYv, RetYi, sentinel, K=K,
55
+ )
56
+ return RetYv, RetYi
57
+
58
+
59
+ def masked_compaction_torch_fallback(Yv, Yi, BitMask, sentinel=float("nan")):
60
+ """PyTorch fallback if Triton kernel fails on ROCm."""
61
+ M, K = Yv.shape
62
+ RetYv = torch.full_like(Yv, sentinel)
63
+ RetYi = torch.full_like(Yi, -1)
64
+ for m in range(M):
65
+ # Bit per position k: div=k//32, rem=k%32
66
+ div = torch.arange(K, device=Yv.device) // 32
67
+ rem = torch.arange(K, device=Yv.device) % 32
68
+ active = ((BitMask[m, div] >> rem) & 1).bool()
69
+ n_active = active.sum().item()
70
+ RetYv[m, :n_active] = Yv[m, active]
71
+ RetYi[m, :n_active] = Yi[m, active]
72
+ return RetYv, RetYi
73
+
74
+
75
+ masked_compaction_pytorch = masked_compaction_torch_fallback # alias for import
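A tiny worked example of the compaction semantics (the bit checked for each slot is bit Yi[m, k] of the row's mask, as in the Triton kernel above); the values are made up:

import torch
from mask_compaction import masked_compaction

Yv = torch.tensor([[0.1, 0.2, 0.3, 0.4]], device="cuda", dtype=torch.float16)
Yi = torch.tensor([[5, 2, 7, 0]], device="cuda", dtype=torch.int32)
BitMask = torch.tensor([[0b10100001]], device="cuda", dtype=torch.int32)  # bits 0, 5, 7 set; bit 2 clear

RetYv, RetYi = masked_compaction(Yv, Yi, BitMask, sentinel=float("nan"))
# The slot with Yi == 2 is masked out, so the active entries move to the front:
# RetYv ≈ [0.1, 0.3, 0.4, nan], RetYi = [5, 7, 0, <sentinel>]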
numerics_details/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ # numerics_details: MXFP only (use mxfp_details)
2
+ from .mxfp_details import upcast_mxfp4_to_fp16
3
+
4
+ __all__ = ["upcast_mxfp4_to_fp16"]
numerics_details/mxfp_details/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ # mxfp_details: MXFP quantize/dequantize kernels
2
+ from .upcast_mxfp4 import upcast_mxfp4_to_fp16
3
+
4
+ __all__ = ["upcast_mxfp4_to_fp16"]
numerics_details/mxfp_details/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (240 Bytes).
 
numerics_details/mxfp_details/__pycache__/_downcast_to_mxfp.cpython-312.pyc ADDED
Binary file (8.91 kB).
 
numerics_details/mxfp_details/__pycache__/_upcast_from_mxfp.cpython-312.pyc ADDED
Binary file (7.88 kB).
 
numerics_details/mxfp_details/__pycache__/upcast_mxfp4.cpython-312.pyc ADDED
Binary file (5.19 kB). View file
 
numerics_details/mxfp_details/_downcast_to_mxfp.py ADDED
@@ -0,0 +1,163 @@
+ # From https://huggingface.co/kernels-community/triton-kernels/blob/main/build/torch-cuda/numerics_details/mxfp_details/_downcast_to_mxfp.py
+
+ import triton
+ import triton.language as tl
+
+ # fmt: off
+
+
+ MXFP_BLOCK_SIZE = tl.constexpr(32)
+
+
+ @triton.jit
+ def _get_max_quant_val(dtype: tl.constexpr):
+     if dtype == tl.uint8:
+         return 6.0
+     elif dtype == tl.float8e5:
+         return 57344.0
+     elif dtype == tl.float8e4nv:
+         return 448.0
+     else:
+         tl.static_assert(False, f"Invalid {dtype=}")
+
+
+ @triton.jit
+ def _compute_quant_and_scale(src_tensor, valid_src_mask, mx_tensor_dtype: tl.constexpr,
+                              DEQUANT_SCALE_ROUNDING_MODE: tl.constexpr = 0):
+     is_fp8: tl.constexpr = mx_tensor_dtype == tl.float8e4nv or mx_tensor_dtype == tl.float8e5
+     BLOCK_SIZE_OUT_DIM: tl.constexpr = src_tensor.shape[0]
+     BLOCK_SIZE_QUANT_DIM: tl.constexpr = src_tensor.shape[1]
+     BLOCK_SIZE_QUANT_MX_SCALE: tl.constexpr = src_tensor.shape[1] // MXFP_BLOCK_SIZE
+
+     # Explicit cast to fp32 since most ops are not supported on bfloat16. We avoid needless conversions to and from bf16
+     f32_tensor = src_tensor.to(tl.float32)
+     abs_tensor = tl.abs(f32_tensor)
+     abs_tensor = tl.where(valid_src_mask, abs_tensor, -1.0)  # Don't consider padding tensors in scale computation
+     abs_tensor = tl.reshape(abs_tensor, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, MXFP_BLOCK_SIZE])
+     max_val = tl.max(abs_tensor, axis=2, keep_dims=True)
+     dequant_scale = max_val / _get_max_quant_val(mx_tensor_dtype)
+     if DEQUANT_SCALE_ROUNDING_MODE == 0:
+         # DequantScaleRoundingMode.ROUND_UP
+         # compute 2 ** ceil(log2(dequant_scale))
+         # Adding 0x007FFFFF adds exponent by 1 unless mantissa is all zeros
+         # A corner case: exponent is 0xFF that will overflow but that's already
+         # NaN so assume we don't care.
+         dequant_scale_exponent = (dequant_scale.to(tl.uint32, bitcast=True) + 0x007FFFFF) & 0x7F800000
+     else:
+         # DequantScaleRoundingMode.ROUND_DOWN
+         # compute 2 ** floor(log2(dequant_scale))
+         assert DEQUANT_SCALE_ROUNDING_MODE == 1
+         dequant_scale_exponent = dequant_scale.to(tl.uint32, bitcast=True) & 0x7F800000
+     dequant_scale_rounded = dequant_scale_exponent.to(tl.float32, bitcast=True)
+     quant_scale = tl.where(dequant_scale_rounded == 0, 0, 1.0 / dequant_scale_rounded)
+
+     f32_tensor = tl.reshape(f32_tensor, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, MXFP_BLOCK_SIZE])
+     quant_tensor = f32_tensor * quant_scale
+
+     # Reshape the tensors after scaling
+     quant_tensor = quant_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM])
+     # Set the invalid portions of the tensor to 0. This will ensure that any padding tensors are 0 in the mx format.
+     quant_tensor = tl.where(valid_src_mask, quant_tensor, 0)
+     dequant_scale_exponent = dequant_scale_exponent.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE])
+
+     # First, we simply extract the exponent part of the scales and store the result
+     dequant_scale_exponent = (dequant_scale_exponent >> 23).to(tl.uint8)
+     # Now we must convert the tensors to the mx format.
+     if is_fp8:
+         out_tensor = quant_tensor.to(mx_tensor_dtype)
+     else:
+         quant_tensor = quant_tensor.to(tl.uint32, bitcast=True)
+         signs = quant_tensor & 0x80000000
+         exponents = (quant_tensor >> 23) & 0xFF
+         mantissas = (quant_tensor & 0x7FFFFF)
+
+         # 0.25 <= x < 0.75 maps to 0.5, a denormal number
+         E8_BIAS = 127
+         E2_BIAS = 1
+         # Move implicit bit 1 at the beginning to mantissa for denormals
+         # tl.core.sub not available in Triton ROCm; use plain subtraction
+         adjusted_exponents = E8_BIAS - (exponents + 1)
+         mantissas = tl.where(exponents < E8_BIAS, (0x400000 | (mantissas >> 1)) >> adjusted_exponents, mantissas)
+
+         # For normal numbers, we change the bias from 127 to 1, and for subnormals, we keep exponent as 0.
+         exponents = tl.maximum(exponents, E8_BIAS - E2_BIAS) - (E8_BIAS - E2_BIAS)
+
+         # Combine sign, exponent, and mantissa, while saturating
+         # rounding nearest with tie breaking up by adding +1 to one bit right of the LSB, then shift right
+         e2m1_tmp = tl.minimum((((exponents << 2) | (mantissas >> 21)) + 1) >> 1, 0x7)
+         e2m1_value = ((signs >> 28) | e2m1_tmp).to(tl.uint8)
+
+         e2m1_value = tl.reshape(e2m1_value, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM // 2, 2])
+         evens, odds = tl.split(e2m1_value)
+         out_tensor = evens | (odds << 4)
+
+     return out_tensor, dequant_scale_exponent
+
+
+ @triton.jit
+ def _downcast_to_mxfp(mx_tensor_ptr, stride_mxt_outer, stride_mxt_quant: tl.constexpr,
+                       mx_scale_ptr, stride_mx_scale_outer, stride_mx_scale_quant,
+                       src_ptr, stride_src_outer, stride_src_quant,
+                       outer_dim, quant_dim,
+                       BLOCK_SIZE_OUT_DIM: tl.constexpr, BLOCK_SIZE_QUANT_DIM: tl.constexpr,
+                       DEQUANT_SCALE_ROUNDING_MODE: tl.constexpr):
+
+     tl.static_assert(stride_mxt_quant == 1, f"Output stride, {stride_mxt_quant=} must be 1.")
+     tl.static_assert(BLOCK_SIZE_QUANT_DIM % MXFP_BLOCK_SIZE == 0, f"{BLOCK_SIZE_QUANT_DIM=} must be a multiple of 32")
+
+     # uint8 signifies two fp4 e2m1 values packed into a single byte
+     mx_tensor_dtype: tl.constexpr = mx_tensor_ptr.dtype.element_ty
+     tl.static_assert(mx_tensor_dtype == tl.uint8 or (mx_tensor_dtype == tl.float8e4nv or mx_tensor_dtype == tl.float8e5),
+                      f"Invalid {mx_tensor_dtype=}. Must be uint8 or float8.")
+
+     src_dtype: tl.constexpr = src_ptr.dtype.element_ty
+     tl.static_assert(mx_scale_ptr.dtype.element_ty == tl.uint8, f"{mx_scale_ptr.dtype.element_ty=} must be uint8")
+     tl.static_assert((src_dtype == tl.bfloat16) or (src_dtype == tl.float16), f"{src_dtype=} must be bfloat16 or float16")
+     is_fp4: tl.constexpr = mx_tensor_dtype == tl.uint8
+
+     outer_block = tl.program_id(0).to(tl.int64)
+     quant_block = tl.program_id(1).to(tl.int64)
+
+     K_DIVISOR: tl.constexpr = 2 if is_fp4 else 1
+     BLOCK_SIZE_QUANT_MX_SCALE: tl.constexpr = BLOCK_SIZE_QUANT_DIM // MXFP_BLOCK_SIZE
+     BLOCK_SIZE_QUANT_MX_TENSOR: tl.constexpr = BLOCK_SIZE_QUANT_DIM // K_DIVISOR
+
+     start_src_quant = quant_block * BLOCK_SIZE_QUANT_DIM
+     start_mx_scale_quant = quant_block * BLOCK_SIZE_QUANT_MX_SCALE
+     start_mx_quant = quant_block * BLOCK_SIZE_QUANT_MX_TENSOR
+     start_out = outer_block * BLOCK_SIZE_OUT_DIM
+
+     src_ptr += start_src_quant * stride_src_quant + start_out * stride_src_outer
+     mx_scale_ptr += start_mx_scale_quant * stride_mx_scale_quant + start_out * stride_mx_scale_outer
+     mx_tensor_ptr += start_mx_quant * stride_mxt_quant + start_out * stride_mxt_outer
+
+     offs_src_quant = tl.arange(0, BLOCK_SIZE_QUANT_DIM)[None, :].to(tl.int64)
+     offs_mxt_quant = tl.arange(0, BLOCK_SIZE_QUANT_MX_TENSOR)[None, :].to(tl.int64)
+     offs_scale_quant = tl.arange(0, BLOCK_SIZE_QUANT_MX_SCALE)[None, :].to(tl.int64)
+     offs_outer = tl.arange(0, BLOCK_SIZE_OUT_DIM)[:, None].to(tl.int64)
+
+     mask_src_quant = start_src_quant + offs_src_quant < quant_dim
+     mask_n = start_out + offs_outer < outer_dim
+     full_mask_src = mask_src_quant & mask_n
+
+     mask_mxt_quant = start_mx_quant + offs_mxt_quant < tl.cdiv(quant_dim, K_DIVISOR)
+     full_mask_mxt = mask_mxt_quant & mask_n
+
+     scale_mask_k = start_mx_scale_quant + offs_scale_quant < tl.cdiv(quant_dim, MXFP_BLOCK_SIZE)
+     full_scale_mask = scale_mask_k & mask_n
+
+     src_tensor_offsets = offs_src_quant * stride_src_quant + offs_outer * stride_src_outer
+     mx_scale_offsets = offs_scale_quant * stride_mx_scale_quant + offs_outer * stride_mx_scale_outer
+     mx_tensor_offsets = offs_mxt_quant * stride_mxt_quant + offs_outer * stride_mxt_outer
+     src_tensor = tl.load(src_ptr + src_tensor_offsets, mask=full_mask_src)
+
+     out_tensor, scale_tensor = _compute_quant_and_scale(src_tensor, full_mask_src, mx_tensor_dtype,
+                                                         DEQUANT_SCALE_ROUNDING_MODE)
+
+     tl.store(mx_scale_ptr + mx_scale_offsets, scale_tensor, mask=full_scale_mask)
+     tl.store(mx_tensor_ptr + mx_tensor_offsets, out_tensor, mask=full_mask_mxt)
+
+
+ @triton.jit(repr=lambda _: "_dequantize_mxfp8")
+ def _dequantize_mxfp8_fn(input, mask, pid=None):
+     return _compute_quant_and_scale(input, mask, tl.float8e4nv)
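
The ROUND_UP branch computes 2 ** ceil(log2(dequant_scale)) purely with integer bit tricks on the fp32 representation: adding 0x007FFFFF carries into the exponent unless the mantissa is already zero, and the mantissa is then masked off. A small host-side sketch of the same rounding (torch is used only for the bit reinterpretation):

import math
import torch

def round_up_pow2_bits(x: torch.Tensor) -> torch.Tensor:
    # Same trick as the kernel: add 0x007FFFFF, then keep only the exponent bits.
    bits = x.to(torch.float32).view(torch.int32)
    rounded = (bits + 0x007FFFFF) & 0x7F800000
    return rounded.view(torch.float32)

x = torch.tensor([0.75, 1.0, 1.1, 3.0, 4.0])
print(round_up_pow2_bits(x))                                  # tensor([1., 1., 2., 4., 4.])
print([2.0 ** math.ceil(math.log2(v)) for v in x.tolist()])   # [1.0, 1.0, 2.0, 4.0, 4.0]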
numerics_details/mxfp_details/_upcast_from_mxfp.py ADDED
@@ -0,0 +1,126 @@
+ import triton
+ import triton.language as tl
+ from ._downcast_to_mxfp import MXFP_BLOCK_SIZE
+
+
+ # fmt: off
+ @triton.jit
+ def _upcast_from_mxfp(out_ptr, stride_o_outer, stride_o_quant: tl.constexpr, mx_scale_ptr, stride_scale_outer,
+                       stride_scale_quant, mx_tensor_ptr, stride_tensor_outer, stride_tensor_quant: tl.constexpr,
+                       outer_dim, quant_dim, BLOCK_SIZE_OUT_DIM: tl.constexpr, BLOCK_SIZE_QUANT_DIM: tl.constexpr):
+
+     tl.static_assert(stride_o_quant == 1, "the weight must be contiguous in the k dimension for mx")
+     tl.static_assert(BLOCK_SIZE_QUANT_DIM % MXFP_BLOCK_SIZE == 0, "BLOCK_SIZE_K must be a multiple of 32")
+     # uint8 signifies two fp4 e2m1 values packed into a single byte
+     mx_tensor_dtype: tl.constexpr = mx_tensor_ptr.dtype.element_ty
+     dst_dtype: tl.constexpr = out_ptr.dtype.element_ty
+     tl.static_assert(dst_dtype == tl.float16 or dst_dtype == tl.bfloat16)
+     tl.static_assert(
+         mx_tensor_dtype == tl.uint8
+         or ((mx_tensor_dtype == tl.float8e4nv or mx_tensor_dtype == tl.float8e5) or mx_tensor_dtype == dst_dtype),
+         "mx_tensor_ptr must be uint8 or float8 or dst_dtype")
+     tl.static_assert(mx_scale_ptr.dtype.element_ty == tl.uint8, "mx_scale_ptr must be uint8")
+
+     # Determine if we are dealing with fp8 types.
+     is_fp4: tl.constexpr = mx_tensor_dtype == tl.uint8
+     is_fp8: tl.constexpr = mx_tensor_dtype == tl.float8e4nv or mx_tensor_dtype == tl.float8e5
+     K_DIVISOR: tl.constexpr = 2 if is_fp4 else 1
+     BLOCK_SIZE_QUANT_MX_SCALE: tl.constexpr = BLOCK_SIZE_QUANT_DIM // MXFP_BLOCK_SIZE
+     BLOCK_SIZE_QUANT_MX_TENSOR: tl.constexpr = BLOCK_SIZE_QUANT_DIM // K_DIVISOR
+
+     # Compute starting indices for the quantized (packed) dimension and the outer dimension.
+     outer_block = tl.program_id(0).to(tl.int64)
+     quant_block = tl.program_id(1).to(tl.int64)
+
+     start_mxt_quant = quant_block * BLOCK_SIZE_QUANT_MX_TENSOR
+     start_out_quant = quant_block * BLOCK_SIZE_QUANT_DIM
+     start_mx_scale_quant = quant_block * BLOCK_SIZE_QUANT_MX_SCALE
+     start_out = outer_block * BLOCK_SIZE_OUT_DIM
+
+     mx_tensor_ptr += start_mxt_quant * stride_tensor_quant + start_out * stride_tensor_outer
+     mx_scale_ptr += start_mx_scale_quant * stride_scale_quant + start_out * stride_scale_outer
+     out_ptr += start_out * stride_o_outer + start_out_quant * stride_o_quant
+
+     # Compute offsets and masks.
+     offs_src_quant = tl.arange(0, BLOCK_SIZE_QUANT_MX_TENSOR)[None, :].to(tl.int64)
+     offs_out_quant = tl.arange(0, BLOCK_SIZE_QUANT_DIM)[None, :].to(tl.int64)
+     offs_outer = tl.arange(0, BLOCK_SIZE_OUT_DIM)[:, None].to(tl.int64)
+     offs_scale = tl.arange(0, BLOCK_SIZE_QUANT_MX_SCALE)[None, :].to(tl.int64)
+
+     mask_outer = start_out + offs_outer < outer_dim
+     mask_out_quant = start_out_quant + offs_out_quant < quant_dim
+     full_mask_out = mask_out_quant & mask_outer
+
+     mask_src_quant = start_mxt_quant + offs_src_quant < tl.cdiv(quant_dim, K_DIVISOR)
+     full_mask_src = mask_src_quant & mask_outer
+
+     mask_scale = start_mx_scale_quant + offs_scale < tl.cdiv(quant_dim, MXFP_BLOCK_SIZE)
+     full_scale_mask = mask_scale & mask_outer
+
+     tensor_offsets = offs_src_quant * stride_tensor_quant + offs_outer * stride_tensor_outer
+     scale_offsets = offs_scale * stride_scale_quant + offs_outer * stride_scale_outer
+     out_offsets = offs_out_quant * stride_o_quant + offs_outer * stride_o_outer
+
+     # Load the packed tensor and scale.
+     tensor = tl.load(mx_tensor_ptr + tensor_offsets, mask=full_mask_src)
+     scale = tl.load(mx_scale_ptr + scale_offsets, mask=full_scale_mask)
+
+     # Upcast the scale to the destination type.
+     if dst_dtype == tl.bfloat16:
+         dst_scale = (scale.to(tl.uint16) << 7).to(dst_dtype, bitcast=True)
+     else:
+         tl.static_assert(dst_dtype == tl.float16)
+         dst_scale = (scale.to(tl.uint32) << 23).to(tl.float32, bitcast=True)
+         dst_scale = dst_scale.to(tl.float16)
+
+     # Now upcast the tensor.
+     if is_fp8:
+         dst_tensor = tensor.to(dst_dtype)
+         if mx_tensor_dtype == tl.float8e5:
+             from_e_bits: tl.constexpr = 5
+             from_m_bits: tl.constexpr = 2
+             to_e_bits: tl.constexpr = 8 if dst_dtype == tl.bfloat16 else 5
+             to_m_bits: tl.constexpr = 7 if dst_dtype == tl.bfloat16 else 10
+
+             # Preserve infs and nans. FIXME Fp8E5M2_to_Bf16 doesn't preserve them!
+             non_finite_mask_src: tl.constexpr = ((1 << from_e_bits) - 1) << from_m_bits
+             non_finite_mask_dst: tl.constexpr = ((1 << to_e_bits) - 1) << to_m_bits
+             dst_tensor = tl.where(
+                 (tensor.to(tl.uint8, bitcast=True) & non_finite_mask_src) == non_finite_mask_src,
+                 (dst_tensor.to(tl.uint16, bitcast=True) | non_finite_mask_dst).to(dst_dtype, bitcast=True),
+                 dst_tensor,
+             )
+     else:
+         tl.static_assert(is_fp4)
+         dst_bias: tl.constexpr = 127 if dst_dtype == tl.bfloat16 else 15
+         dst_0p5: tl.constexpr = 16128 if dst_dtype == tl.bfloat16 else 0x3800
+         dst_m_bits: tl.constexpr = 7 if dst_dtype == tl.bfloat16 else 10
+         # e2m1
+         em0 = tensor & 0x07
+         em1 = tensor & 0x70
+         x0 = (em0.to(tl.uint16) << (dst_m_bits - 1)) | ((tensor & 0x08).to(tl.uint16) << 12)
+         x1 = (em1.to(tl.uint16) << (dst_m_bits - 5)) | ((tensor & 0x80).to(tl.uint16) << 8)
+         # Three cases:
+         # 1) x is normal and non-zero: Correct bias
+         x0 = tl.where((em0 & 0x06) != 0, x0 + ((dst_bias - 1) << dst_m_bits), x0)
+         x1 = tl.where((em1 & 0x60) != 0, x1 + ((dst_bias - 1) << dst_m_bits), x1)
+         # 2) x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in the dst type
+         x0 = tl.where(em0 == 0x01, dst_0p5 | (x0 & 0x8000), x0)
+         x1 = tl.where(em1 == 0x10, dst_0p5 | (x1 & 0x8000), x1)
+         # 3) x is zero, do nothing
+         # Interleave x0,x1: use tl.where (ROCm tl.cat only supports 1D)
+         idx_k = tl.arange(0, BLOCK_SIZE_QUANT_DIM)
+         is_even = (idx_k % 2) == 0
+         val_x0 = x0[:, idx_k // 2]
+         val_x1 = x1[:, idx_k // 2]
+         dst_tensor = tl.where(is_even[None, :], val_x0, val_x1).to(dst_dtype, bitcast=True)
+
+     # Group dst_tensor into 32-wide MX blocks so it lines up with the per-block scales.
+     dst_tensor = dst_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, MXFP_BLOCK_SIZE])
+     dst_scale = dst_scale.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, 1])
+     scale = scale.reshape(dst_scale.shape)
+
+     out_tensor = dst_tensor * dst_scale
+     # Correct any NaNs encoded via the scale.
+     out_tensor = tl.where(scale == 0xFF, float("nan"), out_tensor)
+     out_tensor = out_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM])
+     tl.store(out_ptr + out_offsets, out_tensor, mask=full_mask_out)
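
For cross-checking the bit-level E2M1 decode above, a table-based reference is handy: E2M1 can only represent the magnitudes {0, 0.5, 1, 1.5, 2, 3, 4, 6} with a sign bit, and each scale byte is a plain E8M0 exponent, i.e. 2 ** (byte - 127). A minimal PyTorch sketch that ignores the 0xFF NaN-scale encoding:

import torch

# The 8 non-negative E2M1 magnitudes, indexed by the 3 low bits (exp << 1 | mantissa).
E2M1_VALUES = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def upcast_mxfp4_reference(mx_tensor: torch.Tensor, mx_scale: torch.Tensor) -> torch.Tensor:
    """Table-based reference: two e2m1 nibbles per byte, one power-of-two scale per 32 values."""
    lo, hi = (mx_tensor & 0x0F).long(), (mx_tensor >> 4).long()

    def decode(nibble):
        sign = torch.where((nibble & 0x8) != 0, -1.0, 1.0)
        return sign * E2M1_VALUES.to(mx_tensor.device)[nibble & 0x7]

    M, K_half = mx_tensor.shape
    vals = torch.stack([decode(lo), decode(hi)], dim=-1).reshape(M, K_half * 2)
    scale = torch.pow(2.0, mx_scale.float() - 127.0)   # E8M0 scale: 2 ** (byte - 127)
    return vals * scale.repeat_interleave(32, dim=1)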
numerics_details/mxfp_details/upcast_mxfp4.py ADDED
@@ -0,0 +1,88 @@
+ """
+ Reusable MXFP4 upcast: MXFP4 (uint8 mx_tensor + uint8 mx_scale) -> fp16/bf16.
+ Uses Triton kernel when available; falls back to PyTorch on ROCm (tl.cat limitation).
+ """
+
+ import torch
+
+ from ._upcast_from_mxfp import _upcast_from_mxfp
+
+ try:
+     from triton.compiler.errors import CompilationError
+ except ImportError:
+     CompilationError = Exception
+
+ MXFP_BLOCK_SIZE_PY = 32
+
+
+ def _upcast_mxfp4_to_fp16_pytorch(
+     mx_tensor: torch.Tensor, mx_scale: torch.Tensor, dtype: torch.dtype = torch.float16
+ ) -> torch.Tensor:
+     """PyTorch fallback (used when Triton kernel fails on ROCm)."""
+     M, K_half = mx_tensor.shape
+     K = K_half * 2
+     dst_bias = 15
+     dst_0p5 = 0x3800
+     dst_m_bits = 10
+
+     tensor = mx_tensor.to(torch.int32)
+     em0 = tensor & 0x07
+     em1 = tensor & 0x70
+     x0 = (em0 << (dst_m_bits - 1)) | ((tensor & 0x08) << 12)
+     x1 = (em1 << (dst_m_bits - 5)) | ((tensor & 0x80) << 8)
+
+     x0 = torch.where((em0 & 0x06) != 0, x0 + ((dst_bias - 1) << dst_m_bits), x0)
+     x1 = torch.where((em1 & 0x60) != 0, x1 + ((dst_bias - 1) << dst_m_bits), x1)
+     x0 = torch.where(em0 == 0x01, torch.full_like(x0, dst_0p5) | (x0 & 0x8000), x0)
+     x1 = torch.where(em1 == 0x10, torch.full_like(x1, dst_0p5) | (x1 & 0x8000), x1)
+
+     out_u16 = torch.empty((M, K), device=mx_tensor.device, dtype=torch.uint16)
+     out_u16[:, 0::2] = (x0 & 0xFFFF).to(torch.uint16)
+     out_u16[:, 1::2] = (x1 & 0xFFFF).to(torch.uint16)
+     dst_tensor = out_u16.view(dtype)
+
+     scale_u32 = mx_scale.to(torch.int32) << 23
+     dst_scale = scale_u32.view(torch.float32).to(dtype)
+     dst_scale = dst_scale.unsqueeze(-1).repeat(1, 1, 32).reshape(M, K)
+
+     out_tensor = dst_tensor * dst_scale
+     out_tensor = torch.where(
+         mx_scale.unsqueeze(-1).expand(-1, -1, 32).reshape(M, K) == 0xFF,
+         float("nan"),
+         out_tensor,
+     )
+     return out_tensor
+
+
+ def upcast_mxfp4_to_fp16(
+     mx_tensor: torch.Tensor,
+     mx_scale: torch.Tensor,
+     block_m: int = 128,
+     block_k: int = 64,
+     dtype: torch.dtype = torch.float16,
+     verbose: bool = False,
+ ) -> torch.Tensor:
+     """Convert MXFP4 [M,K/2]+[M,K/32] -> fp16/bf16 [M,K]. Falls back to PyTorch if Triton fails."""
+     assert mx_tensor.dim() == 2 and mx_tensor.dtype == torch.uint8
+     assert mx_scale.dim() == 2 and mx_scale.dtype == torch.uint8
+     M = mx_tensor.shape[0]
+     K = mx_tensor.shape[1] * 2
+     assert mx_scale.shape == (M, K // 32)
+     assert block_k % MXFP_BLOCK_SIZE_PY == 0
+
+     try:
+         out = torch.empty((M, K), device=mx_tensor.device, dtype=dtype)
+         grid = ((M + block_m - 1) // block_m, (K + block_k - 1) // block_k)
+         _upcast_from_mxfp[grid](
+             out, out.stride(0), 1,
+             mx_scale, mx_scale.stride(0), mx_scale.stride(1),
+             mx_tensor, mx_tensor.stride(0), 1,
+             M, K,
+             BLOCK_SIZE_OUT_DIM=block_m,
+             BLOCK_SIZE_QUANT_DIM=block_k,
+         )
+         return out
+     except CompilationError:
+         if verbose:
+             print("Triton upcast failed (e.g. ROCm), using PyTorch fallback.")
+         return _upcast_mxfp4_to_fp16_pytorch(mx_tensor, mx_scale, dtype)
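
A minimal call sketch for the wrapper above, assuming the package is importable from the repo root (as in test_mxfp.py) and using an arbitrary packed payload with a unit scale:

import torch
from numerics_details import upcast_mxfp4_to_fp16

M, K = 128, 256                                    # K must be a multiple of 32
mx_tensor = torch.randint(0, 256, (M, K // 2), device="cuda", dtype=torch.uint8)
mx_scale = torch.full((M, K // 32), 127, device="cuda", dtype=torch.uint8)   # scale byte 127 -> 2**0

x = upcast_mxfp4_to_fp16(mx_tensor, mx_scale, dtype=torch.float16, verbose=True)
print(x.shape, x.dtype)                            # torch.Size([128, 256]) torch.float16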
submission.py ADDED
@@ -0,0 +1,37 @@
+ import torch
+ import triton
+ import triton.language as tl
+
+ # 1. HARDWARE DIAGNOSTICS
+ def check_environment():
+     print("--- Environment Check ---")
+     cuda_avail = torch.cuda.is_available()
+     print(f"Is CUDA/ROCm available? {cuda_avail}")
+
+     if cuda_avail:
+         device_name = torch.cuda.get_device_name(0)
+         print(f"GPU Detected: {device_name}")
+
+         # Check for Blackwell (SM 10.0) or MI300X (gfx942)
+         prop = torch.cuda.get_device_properties(0)
+         if hasattr(prop, 'major'):
+             print(f"Compute Capability: {prop.major}.{prop.minor}")
+     else:
+         print("No NVIDIA/AMD GPU detected. Triton kernels will not run on this hardware.")
+     print("-------------------------\n")
+
+ # Call diagnostic immediately on import
+ check_environment()
+
+ # 2. PLACEHOLDER KERNEL (Logic from previous steps)
+ @triton.jit
+ def dual_gemm_kernel(a_ptr, b1_ptr, b2_ptr, c_ptr, M, N, K, **meta):
+     # Kernel code here...
+     pass
+
+ # 3. HARNESS INTERFACE
+ def dual_gemm_submission(data):
+     # This is what the leaderboard/benchmark calls
+     a, b1, b2, sfa, sfb1, sfb2, c = data
+     # ... launch logic ...
+     return c
test_mxfp.py ADDED
@@ -0,0 +1,87 @@
+ #!/usr/bin/env python3
+ """
+ Minimal test for MXFP _downcast_to_mxfp on MI300X (ROCm).
+ Tests fp16 -> uint8 (fp4 packed) path; float8 path may not work on ROCm yet.
+ Run on remote: cd /root/kernels && python test_mxfp.py
+ """
+
+ import sys
+ import os
+
+ # Allow imports from /root/kernels
+ _script_dir = os.path.dirname(os.path.abspath(__file__))
+ if _script_dir not in sys.path:
+     sys.path.insert(0, _script_dir)
+
+ def main():
+     print("=== MXFP Import Test ===")
+     try:
+         from numerics_details.mxfp_details._downcast_to_mxfp import (
+             _downcast_to_mxfp,
+             _compute_quant_and_scale,
+             MXFP_BLOCK_SIZE,
+         )
+         print(" Import OK: _downcast_to_mxfp, _compute_quant_and_scale, MXFP_BLOCK_SIZE")
+     except Exception as e:
+         print(f" Import FAILED: {e}")
+         return 1
+
+     print("\n=== Triton + CUDA/ROCm Check ===")
+     import torch
+     if not torch.cuda.is_available():
+         print(" No GPU available. Skipping kernel test.")
+         return 0
+     print(f" Device: {torch.cuda.get_device_name(0)}")
+
+     import triton
+     import triton.language as tl
+
+     print("\n=== MXFP Downcast Test (fp16 -> fp4 uint8) ===")
+     # Use fp4 path (uint8 output) - avoids float8 dtypes which may lack ROCm support
+     BLOCK_SIZE_OUT_DIM = 64
+     BLOCK_SIZE_QUANT_DIM = 64  # must be multiple of 32
+     outer_dim = 128
+     quant_dim = 128
+     DEQUANT_SCALE_ROUNDING_MODE = 0
+
+     device = "cuda"
+     src = torch.randn(outer_dim, quant_dim, device=device, dtype=torch.float16) * 0.1
+
+     # Output shapes for fp4 (uint8): mx_tensor [outer, quant//2], mx_scale [outer, quant//32]
+     mx_tensor = torch.empty(outer_dim, quant_dim // 2, device=device, dtype=torch.uint8)
+     mx_scale = torch.empty(outer_dim, quant_dim // 32, device=device, dtype=torch.uint8)
+
+     num_outer_blocks = (outer_dim + BLOCK_SIZE_OUT_DIM - 1) // BLOCK_SIZE_OUT_DIM
+     num_quant_blocks = (quant_dim + BLOCK_SIZE_QUANT_DIM - 1) // BLOCK_SIZE_QUANT_DIM
+     grid = (num_outer_blocks, num_quant_blocks)
+
+     try:
+         _downcast_to_mxfp[grid](
+             mx_tensor,
+             mx_tensor.stride(0), 1,
+             mx_scale,
+             mx_scale.stride(0), mx_scale.stride(1),
+             src,
+             src.stride(0), src.stride(1),
+             outer_dim, quant_dim,
+             BLOCK_SIZE_OUT_DIM=BLOCK_SIZE_OUT_DIM,
+             BLOCK_SIZE_QUANT_DIM=BLOCK_SIZE_QUANT_DIM,
+             DEQUANT_SCALE_ROUNDING_MODE=DEQUANT_SCALE_ROUNDING_MODE,
+         )
+         torch.cuda.synchronize()
+         print(" Kernel launch OK")
+         print(f" mx_tensor shape: {mx_tensor.shape}, dtype: {mx_tensor.dtype}")
+         print(f" mx_scale shape: {mx_scale.shape}")
+         print(f" mx_tensor sample (first row): {mx_tensor[0, :8].tolist()}")
+     except Exception as e:
+         print(f" Kernel FAILED: {e}")
+         import traceback
+         traceback.print_exc()
+         return 1
+
+     print("\n=== Done ===")
+     return 0
+
+
+ if __name__ == "__main__":
+     sys.exit(main())
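
A natural follow-up to the launch check is a round-trip comparison: upcast the freshly quantized tensors back to fp16 and compare against src with the tolerance helper from testing.py. A sketch of the extra lines to run after the kernel launch in main(); MXFP4 is coarse, so the tolerances here are deliberately loose and purely illustrative:

from numerics_details import upcast_mxfp4_to_fp16
from testing import assert_close

recon = upcast_mxfp4_to_fp16(mx_tensor, mx_scale, dtype=torch.float16)
# fp4 with a shared power-of-two scale per 32 values quantizes very coarsely.
assert_close(src, recon, maxtol=0.5, rmstol=0.25, description="mxfp4 round-trip")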
testing.py ADDED
@@ -0,0 +1,206 @@
+ """
+ Testing utilities matching triton-kernels/build/torch-cuda/testing.py.
+ https://huggingface.co/kernels-community/triton-kernels/blob/main/build/torch-cuda/testing.py
+ """
+
+ import enum
+ import functools
+ import os
+ import subprocess
+ import sys
+
+ import torch
+
+ # Numerics constants - use triton_kernels.numerics when available
+ try:
+     from .numerics import MAX_FINITE_FLOAT8E4B8, MAX_FINITE_FLOAT8E4NV, MAX_FINITE_FLOAT8E5
+ except ImportError:
+     # Standalone fallback: standard float8 max finite values
+     MAX_FINITE_FLOAT8E5 = 57344.0    # float8 e5m2
+     MAX_FINITE_FLOAT8E4NV = 448.0    # float8 e4m3fn
+     MAX_FINITE_FLOAT8E4B8 = 448.0    # float8 e4m3fnuz
+
+
+ def assert_equal(ref, tri):
+     if isinstance(ref, torch.Tensor):
+         assert torch.all(ref == tri)
+     else:
+         assert ref == tri
+
+
+ def assert_close(ref, tri, maxtol=None, rmstol=None, description="--", verbose=True):
+     """
+     Compare reference values against obtained values.
+     """
+     if tri.dtype.itemsize == 1:
+         ref_as_type = ref.to(tri.dtype)
+         if ref.dtype == tri.dtype:
+             assert torch.all(ref_as_type == tri)
+             return
+         ref = ref_as_type
+
+     if maxtol is None:
+         maxtol = 2e-2
+     if rmstol is None:
+         rmstol = 4e-3
+
+     # cast to float32:
+     ref = ref.to(torch.float32).detach()
+     tri = tri.to(torch.float32).detach()
+     assert ref.shape == tri.shape, f"Tensors must have same size {ref.shape=} {tri.shape=}"
+
+     # deal with infinite elements:
+     inf_mask_ref = torch.isinf(ref)
+     inf_mask_tri = torch.isinf(tri)
+     assert torch.equal(inf_mask_ref, inf_mask_tri), "Tensor must have same infinite elements"
+     refn = torch.where(inf_mask_ref, 0, ref)
+     trin = torch.where(inf_mask_tri, 0, tri)
+
+     # normalise so that RMS calculation doesn't overflow:
+     eps = 1.0e-30
+     multiplier = 1.0 / (torch.max(torch.abs(refn)) + eps)
+     refn *= multiplier
+     trin *= multiplier
+
+     ref_rms = torch.sqrt(torch.square(refn).mean()) + eps
+
+     rel_err = torch.abs(refn - trin) / torch.maximum(ref_rms, torch.abs(refn))
+     max_err = torch.max(rel_err).item()
+     rms_err = torch.sqrt(torch.square(rel_err).mean()).item()
+
+     if verbose:
+         print("%s maximum relative error = %s (threshold = %s)" % (description, max_err, maxtol))
+         print("%s RMS relative error = %s (threshold = %s)" % (description, rms_err, rmstol))
+
+     if max_err > maxtol:
+         bad_idxs = torch.nonzero(rel_err > maxtol)
+         num_nonzero = bad_idxs.size(0)
+         bad_idxs = bad_idxs[:1000]
+         print("%d / %d mismatched elements (shape = %s) at coords %s" %
+               (num_nonzero, rel_err.numel(), tuple(rel_err.shape), bad_idxs.tolist()))
+
+         bad_idxs = bad_idxs.unbind(-1)
+         print("ref values: ", ref[tuple(bad_idxs)].cpu())
+         print("tri values: ", tri[tuple(bad_idxs)].cpu())
+
+     assert max_err <= maxtol
+     assert rms_err <= rmstol
+
+
+ class ComputeSanitizerTool(enum.Enum):
+     MEMCHECK = "memcheck"
+     RACECHECK = "racecheck"
+     SYNCCHECK = "synccheck"
+     INITCHECK = "initcheck"
+
+
+ def compute_sanitizer(**target_kwargs):
+     """
+     Decorator to run a test with compute sanitizer enabled and pytorch caching allocator disabled,
+     to expose potential memory access errors.
+     This decorator requires the `request` fixture to be present.
+     If `run_sanitizer` argument is present and set to False, the sanitizer is not run.
+     Running tests under compute sanitizer requires launching subprocess and is slow,
+     so use sparingly
+     """
+
+     def decorator(test_fn):
+
+         @functools.wraps(test_fn)
+         def wrapper(*args, **kwargs):
+             if os.environ.get("SKIP_COMPUTE_SANITIZER") == "1":
+                 test_fn(*args, **kwargs)
+                 return
+
+             import psutil
+
+             if target_kwargs.pop("clear_torch_cache", False):
+                 # If we don't pop clear_torch_cache, it won't pass
+                 # target_kwargs.items() <= kwargs.items() condition below.
+                 torch.cuda.empty_cache()
+             tools_to_check = target_kwargs.pop("tools_to_check", [ComputeSanitizerTool.MEMCHECK])
+             assert isinstance(tools_to_check, list), f"{tools_to_check=}"
+             assert all(tool in ComputeSanitizerTool for tool in tools_to_check), (
+                 f"{(tool for tool in tools_to_check if tool not in ComputeSanitizerTool)=}")
+
+             ppid_name = psutil.Process(os.getppid()).exe()
+             run_compute_sanitizer = target_kwargs.items() <= kwargs.items()
+             if "run_sanitizer" in kwargs:
+                 run_compute_sanitizer &= kwargs["run_sanitizer"]
+             if run_compute_sanitizer and "compute-sanitizer" not in ppid_name:
+                 for tool in tools_to_check:
+                     path = os.path.realpath(test_fn.__globals__["__file__"])
+                     # get path of current file
+                     env = {
+                         "PATH": os.environ["PATH"],
+                         "PYTORCH_NO_CUDA_MEMORY_CACHING": "1",
+                         "TORCH_SHOW_CPP_STACKTRACES": "1",
+                         "CUDA_LAUNCH_BLOCKING": "1",
+                     }
+                     if "CUDA_VISIBLE_DEVICES" in os.environ:
+                         env["CUDA_VISIBLE_DEVICES"] = os.environ["CUDA_VISIBLE_DEVICES"]
+                     assert "request_fixture" in kwargs, (
+                         "memcheck'ed test must have a (possibly unused) `request` fixture")
+                     test_id = kwargs["request_fixture"].node.callspec.id
+                     cmd = f"{path}::{test_fn.__name__}[{test_id}]"
+                     cmd = [
+                         "compute-sanitizer",
+                         "--target-processes=application-only",
+                         "--destroy-on-device-error=context",
+                         f"--tool={tool.value}",
+                         sys.executable,
+                         "-m",
+                         "pytest",
+                         "-vsx",
+                         cmd,
+                     ]
+                     for opt in ["--update_checksum", "--ignore_checksum_error"]:
+                         if opt in sys.argv:
+                             cmd.append(opt)
+                     out = subprocess.run(
+                         cmd,
+                         stdout=subprocess.PIPE,
+                         stderr=subprocess.STDOUT,
+                         env=env,
+                     )
+                     sanitizer_ok = "ERROR SUMMARY: 0 errors" in str(
+                         out.stdout) or "RACECHECK SUMMARY: 0 hazards displayed" in str(out.stdout)
+                     test_output = out.stdout
+                     if type(test_output) is bytes:
+                         test_output = test_output.decode()
+
+                     fail = False
+                     if not sanitizer_ok:
+                         print("compute-sanitizer returned an error")
+                         fail = True
+                     elif out.returncode != 0:
+                         print(
+                             "The test failed due to some other reason: consider running without compute-sanitizer to verify."
+                         )
+                         print(f"{out.returncode=}")
+                         fail = True
+
+                     if fail:
+                         print("*****************************************************")
+                         print("******************** TEST OUTPUT ********************")
+                         print("*****************************************************")
+                         print(test_output)
+                         print("*****************************************************")
+                         print("****************** TEST OUTPUT END ******************")
+                         print("*****************************************************")
+                         assert None
+             else:
+                 test_fn(*args, **kwargs)
+
+         return wrapper
+
+     return decorator
+
+
+ def compute_actual_scale(x, dtype):
+     max_finite = {
+         torch.float8_e5m2: MAX_FINITE_FLOAT8E5,
+         torch.float8_e4m3fn: MAX_FINITE_FLOAT8E4NV,
+         torch.float8_e4m3fnuz: MAX_FINITE_FLOAT8E4B8,
+     }[dtype]
+     return x.abs().max() / max_finite
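
A short usage sketch for the two helpers most relevant here: compute_actual_scale picks a per-tensor scale so the data fits the finite range of the target float8 format, and assert_close compares a dequantized result against the reference with max / RMS relative-error thresholds. The tolerances below are illustrative, not prescribed values:

import torch
from testing import assert_close, compute_actual_scale

x = torch.randn(256, 256, device="cuda", dtype=torch.float32)

# Scale x so its largest magnitude maps to the largest finite float8 e4m3 value.
scale = compute_actual_scale(x, torch.float8_e4m3fn)
x_fp8 = (x / scale).to(torch.float8_e4m3fn)

# Dequantize and compare against the original with loose fp8 tolerances.
x_deq = x_fp8.to(torch.float32) * scale
assert_close(x, x_deq, maxtol=1e-1, rmstol=5e-2, description="fp8 round-trip")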