Add ROCm dual GEMM, MXFP4, mask compaction, group GEMM
- AMDmi300asubmission.py +118 -0
- amd_dual_gemm_swiglu.py +318 -0
- dual_gemm_swiglu_full.py +417 -0
- dual_gemm_swiglu_mxfp.py +119 -0
- example_dual_gemm_mxfp.py +70 -0
- example_moe_compaction_gemm.py +68 -0
- example_mxfp_roundtrip.py +72 -0
- hardware_submission.py +121 -0
- mask_compaction.py +75 -0
- numerics_details/__init__.py +4 -0
- numerics_details/mxfp_details/__init__.py +4 -0
- numerics_details/mxfp_details/__pycache__/__init__.cpython-312.pyc +0 -0
- numerics_details/mxfp_details/__pycache__/_downcast_to_mxfp.cpython-312.pyc +0 -0
- numerics_details/mxfp_details/__pycache__/_upcast_from_mxfp.cpython-312.pyc +0 -0
- numerics_details/mxfp_details/__pycache__/upcast_mxfp4.cpython-312.pyc +0 -0
- numerics_details/mxfp_details/_downcast_to_mxfp.py +163 -0
- numerics_details/mxfp_details/_upcast_from_mxfp.py +126 -0
- numerics_details/mxfp_details/upcast_mxfp4.py +88 -0
- submission.py +37 -0
- test_mxfp.py +87 -0
- testing.py +206 -0
AMDmi300asubmission.py
ADDED
import torch
import triton
import triton.language as tl
import os

# 1. HARDWARE DIAGNOSTICS
def check_environment():
    print(f"--- Environment Check ---")
    cuda_avail = torch.cuda.is_available()
    print(f"Is CUDA/ROCm available? {cuda_avail}")

    if cuda_avail:
        device_name = torch.cuda.get_device_name(0)
        print(f"GPU Detected: {device_name}")
        prop = torch.cuda.get_device_properties(0)
        if hasattr(prop, 'major'):
            print(f"Compute Capability: {prop.major}.{prop.minor}")
        # Optimization: use persistent kernel arguments on ROCm
        os.environ["HIP_FORCE_DEV_KERNARG"] = "1"
    else:
        print("No NVIDIA/AMD GPU detected. Triton kernels will not run on this hardware.")
    print(f"-------------------------\n")

check_environment()

# 2. OPTIMIZED DUAL GEMM KERNEL
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=8),
    ],
    key=['M', 'N', 'K'],
)
@triton.jit
def dual_gemm_kernel(
    a_ptr, b1_ptr, b2_ptr, c_ptr,
    sfa_ptr, sfb1_ptr, sfb2_ptr,
    M, N, K, L,
    stride_am, stride_ak, stride_al,
    stride_bn, stride_bk, stride_bl,
    stride_cm, stride_cn, stride_cl,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    # Program ID & work distribution
    pid = tl.program_id(0)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)

    # Persistent grid loop (iterates over batches L and tiles)
    total_tiles = num_pid_m * num_pid_n * L
    for tile_idx in tl.range(pid, total_tiles, tl.num_programs(0)):
        l_idx = tile_idx // (num_pid_m * num_pid_n)
        tile_rem = tile_idx % (num_pid_m * num_pid_n)

        pid_m = tile_rem // num_pid_n
        pid_n = tile_rem % num_pid_n

        # Memory offsets
        offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M))
        offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N))
        offs_k = tl.arange(0, BLOCK_K)

        a_ptrs = a_ptr + l_idx * stride_al + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
        b1_ptrs = b1_ptr + l_idx * stride_bl + (offs_bn[None, :] * stride_bn + offs_k[:, None] * stride_bk)
        b2_ptrs = b2_ptr + l_idx * stride_bl + (offs_bn[None, :] * stride_bn + offs_k[:, None] * stride_bk)

        # Accumulators
        acc1 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        acc2 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

        # Inner K-loop
        for k in range(0, tl.cdiv(K, BLOCK_K)):
            a = tl.load(a_ptrs)
            b1 = tl.load(b1_ptrs)
            b2 = tl.load(b2_ptrs)

            # Use tl.dot_scaled for hardware-native MXFP scaling where available.
            # Note: the block scales (sfa, sfb1, sfb2) are not loaded yet; passing
            # None treats the operands as unscaled e2m1 data.
            acc1 = tl.dot_scaled(a, None, "e2m1", b1.T, None, "e2m1", acc1)
            acc2 = tl.dot_scaled(a, None, "e2m1", b2.T, None, "e2m1", acc2)

            a_ptrs += BLOCK_K * stride_ak
            b1_ptrs += BLOCK_K * stride_bk
            b2_ptrs += BLOCK_K * stride_bk

        # 3. FUSED EPILOGUE (SiLU + gating)
        # res = SiLU(A @ B1) * (A @ B2)
        res1 = acc1.to(tl.float16)
        activated_res1 = res1 * tl.sigmoid(res1)
        final_out = activated_res1 * acc2.to(tl.float16)

        # Store result
        offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
        c_ptrs = c_ptr + l_idx * stride_cl + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn
        tl.store(c_ptrs, final_out)

# 4. HARNESS INTERFACE
def dual_gemm_submission(data):
    # Unpack the tuple provided by the benchmark harness
    a, b1, b2, sfa, sfb1, sfb2, c = data
    M, K_packed, L = a.shape
    N = b1.shape[0]
    K = K_packed * 2  # assuming FP4 packing (two values per byte)

    # Grid size: launch exactly the number of SMs/CUs for a persistent wave
    num_sms = torch.cuda.get_device_properties(0).multi_processor_count
    grid = (num_sms,)

    dual_gemm_kernel[grid](
        a, b1, b2, c, sfa, sfb1, sfb2,
        M, N, K, L,
        a.stride(0), a.stride(1), a.stride(2),
        b1.stride(0), b1.stride(1), b1.stride(2),
        c.stride(0), c.stride(1), c.stride(2)
    )
    return c
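Because the harness tuple format is only implied by how dual_gemm_submission unpacks and indexes it, a hedged smoke-test sketch can help exercise the launch path locally. Every shape, dtype, and scale layout below is an assumption inferred from the code above (packed FP4, two values per byte, one scale byte per 32 values), not the harness contract.

import torch

def make_fake_harness_data(M=128, N=128, K=256, L=1, device="cuda"):
    # Hypothetical tuple layout: a[M, K//2, L], b1/b2[N, K//2, L] packed FP4 bytes,
    # per-block scales sfa/sfb*, and a pre-allocated fp16 output c[M, N, L].
    K_packed = K // 2
    a = torch.randint(0, 256, (M, K_packed, L), device=device, dtype=torch.uint8)
    b1 = torch.randint(0, 256, (N, K_packed, L), device=device, dtype=torch.uint8)
    b2 = torch.randint(0, 256, (N, K_packed, L), device=device, dtype=torch.uint8)
    # One scale byte per 32 values along K (assumption; unused by the kernel above).
    sfa = torch.full((M, K // 32, L), 127, device=device, dtype=torch.uint8)
    sfb1 = torch.full((N, K // 32, L), 127, device=device, dtype=torch.uint8)
    sfb2 = torch.full((N, K // 32, L), 127, device=device, dtype=torch.uint8)
    c = torch.empty((M, N, L), device=device, dtype=torch.float16)
    return (a, b1, b2, sfa, sfb1, sfb2, c)

# out = dual_gemm_submission(make_fake_harness_data())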
amd_dual_gemm_swiglu.py
ADDED
"""
AMD Triton fused Dual GEMM + SwiGLU kernel.
Computes: silu(A @ B1) * (A @ B2) in a single fused kernel.
Uses triton-kernels testing.py: assert_close (maxtol=2e-2, rmstol=4e-3).
"""

import argparse
import os
import sys
import time

# Allow importing testing.py from same directory (when run from kernels/)
_script_dir = os.path.dirname(os.path.abspath(__file__))
if _script_dir not in sys.path:
    sys.path.insert(0, _script_dir)

import torch
import triton
import triton.language as tl

# Optional MXFP4 pre-dequant (option 1: upcast before GEMM)
try:
    from numerics_details.mxfp_details import upcast_mxfp4_to_fp16
    _HAS_MXFP = True
except ImportError:
    upcast_mxfp4_to_fp16 = None
    _HAS_MXFP = False


def _maybe_upcast_mxfp(b, name: str) -> torch.Tensor:
    """If b is MXFP4 (mx_tensor, mx_scale), upcast to fp16. Else return b."""
    if not isinstance(b, (tuple, list)) or len(b) != 2:
        return b
    mx_tensor, mx_scale = b
    if not (isinstance(mx_tensor, torch.Tensor) and isinstance(mx_scale, torch.Tensor)):
        return b
    if mx_tensor.dtype != torch.uint8 or mx_scale.dtype != torch.uint8:
        return b
    if not _HAS_MXFP:
        raise ImportError("MXFP4 weights require numerics_details.mxfp_details")
    return upcast_mxfp4_to_fp16(mx_tensor, mx_scale, verbose=False)


@triton.autotune(
    configs=[
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "GROUP_SIZE_M": 8}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "GROUP_SIZE_M": 8}, num_warps=8, num_stages=3),
        triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "GROUP_SIZE_M": 8}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64, "GROUP_SIZE_M": 8}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 64, "GROUP_SIZE_M": 4}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 64, "GROUP_SIZE_M": 4}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "GROUP_SIZE_M": 8}, num_warps=4, num_stages=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "GROUP_SIZE_M": 8}, num_warps=4, num_stages=4),
        triton.Config({"BLOCK_M": 256, "BLOCK_N": 256, "BLOCK_K": 32, "GROUP_SIZE_M": 8}, num_warps=8, num_stages=2),
    ],
    key=["M", "N", "K"],
)
@triton.heuristics(
    {
        "EVEN_K": lambda args: args["K"] % args["BLOCK_K"] == 0,
        "EVEN_M": lambda args: args["M"] % args["BLOCK_M"] == 0,
        "EVEN_N": lambda args: args["N"] % args["BLOCK_N"] == 0,
    }
)
@triton.jit
def dual_gemm_swiglu_kernel(
    a_ptr,
    b1_ptr,
    b2_ptr,
    c_ptr,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_b1k,
    stride_b1n,
    stride_b2k,
    stride_b2n,
    stride_cm,
    stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    EVEN_K: tl.constexpr,
    EVEN_M: tl.constexpr,
    EVEN_N: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = tl.minimum(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_in_group = pid % num_pid_in_group
    pid_m = first_pid_m + (pid_in_group % group_size_m)
    pid_n = pid_in_group // group_size_m

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)

    a_ptrs = a_ptr + (offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b1_ptrs = b1_ptr + (offs_k[:, None] * stride_b1k + offs_n[None, :] * stride_b1n)
    b2_ptrs = b2_ptr + (offs_k[:, None] * stride_b2k + offs_n[None, :] * stride_b2n)

    acc1 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    acc2 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    m_mask = offs_m[:, None] < M
    n_mask = offs_n[None, :] < N

    for k in range(0, tl.cdiv(K, BLOCK_K)):
        if EVEN_K:
            if EVEN_M:
                a = tl.load(a_ptrs)
            else:
                a = tl.load(a_ptrs, mask=m_mask, other=0.0)

            if EVEN_N:
                b1 = tl.load(b1_ptrs)
                b2 = tl.load(b2_ptrs)
            else:
                b1 = tl.load(b1_ptrs, mask=n_mask, other=0.0)
                b2 = tl.load(b2_ptrs, mask=n_mask, other=0.0)
        else:
            k_rem = K - k * BLOCK_K
            k_mask_m = offs_k[None, :] < k_rem
            k_mask_n = offs_k[:, None] < k_rem
            a = tl.load(a_ptrs, mask=m_mask & k_mask_m, other=0.0)
            b1 = tl.load(b1_ptrs, mask=k_mask_n & n_mask, other=0.0)
            b2 = tl.load(b2_ptrs, mask=k_mask_n & n_mask, other=0.0)

        tl.multiple_of(a_ptrs, [16, 16])
        tl.multiple_of(b1_ptrs, [16, 16])
        tl.multiple_of(b2_ptrs, [16, 16])

        acc1 += tl.dot(a, b1)
        acc2 += tl.dot(a, b2)

        a_ptrs += BLOCK_K * stride_ak
        b1_ptrs += BLOCK_K * stride_b1k
        b2_ptrs += BLOCK_K * stride_b2k

    silu = acc1 * tl.sigmoid(acc1)
    out = silu * acc2

    c_ptrs = c_ptr + (offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn)
    c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tl.store(c_ptrs, out.to(tl.float16), mask=c_mask)


def dual_gemm_swiglu(
    a: torch.Tensor,
    b1: torch.Tensor | tuple,
    b2: torch.Tensor | tuple,
) -> torch.Tensor:
    """Fused Dual GEMM + SwiGLU. b1/b2 can be fp16 [K,N] or MXFP4 (mx_tensor, mx_scale)."""
    b1 = _maybe_upcast_mxfp(b1, "b1")
    b2 = _maybe_upcast_mxfp(b2, "b2")

    if a.ndim != 2 or b1.ndim != 2 or b2.ndim != 2:
        raise ValueError("Expected 2D tensors: a[M,K], b1[K,N], b2[K,N].")
    if a.shape[1] != b1.shape[0] or a.shape[1] != b2.shape[0]:
        raise ValueError("Incompatible shapes for dual GEMM.")
    if b1.shape[1] != b2.shape[1]:
        raise ValueError("b1 and b2 must have same N dimension.")
    if not (a.is_cuda and b1.is_cuda and b2.is_cuda):
        raise ValueError("All tensors must be on a CUDA/ROCm device.")
    if a.dtype != torch.float16 or b1.dtype != torch.float16 or b2.dtype != torch.float16:
        raise ValueError("This kernel currently expects float16 inputs.")

    a = a.contiguous()
    b1 = b1.contiguous()
    b2 = b2.contiguous()

    M, K = a.shape
    _, N = b1.shape
    c = torch.empty((M, N), device=a.device, dtype=torch.float16)

    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
    )

    dual_gemm_swiglu_kernel[grid](
        a, b1, b2, c,
        M=M, N=N, K=K,
        stride_am=a.stride(0), stride_ak=a.stride(1),
        stride_b1k=b1.stride(0), stride_b1n=b1.stride(1),
        stride_b2k=b2.stride(0), stride_b2n=b2.stride(1),
        stride_cm=c.stride(0), stride_cn=c.stride(1),
    )
    return c


def reference_dual_gemm_swiglu(a: torch.Tensor, b1: torch.Tensor, b2: torch.Tensor) -> torch.Tensor:
    x1 = a @ b1
    x2 = a @ b2
    return torch.nn.functional.silu(x1) * x2


def test_correctness(device: str = "cuda", maxtol: float = 2e-2, rmstol: float = 4e-3) -> bool:
    """Run correctness tests using triton-kernels testing.assert_close."""
    from testing import assert_close

    torch.manual_seed(42)
    shapes = [(128, 64, 128), (256, 256, 512), (1024, 512, 1024), (4096, 3648, 8192),
              (7, 13, 17), (100, 200, 150)]
    input_scale = 0.125
    all_pass = True
    for m, n, k in shapes:
        a = torch.randn((m, k), device=device, dtype=torch.float16) * input_scale
        b1 = torch.randn((k, n), device=device, dtype=torch.float16) * input_scale
        b2 = torch.randn((k, n), device=device, dtype=torch.float16) * input_scale
        with torch.no_grad():
            ref = reference_dual_gemm_swiglu(a.float(), b1.float(), b2.float()).to(torch.float16)
            out = dual_gemm_swiglu(a, b1, b2)
        desc = f"[shape ({m},{n},{k})]"
        try:
            assert_close(ref, out, maxtol=maxtol, rmstol=rmstol, description=desc, verbose=True)
            print(f"  {desc} PASS")
        except AssertionError:
            print(f"  {desc} FAIL")
            all_pass = False
    return all_pass


def benchmark(m: int, n: int, k: int, warmup: int, iters: int, input_scale: float) -> None:
    device = "cuda"
    a = torch.randn((m, k), device=device, dtype=torch.float16) * input_scale
    b1 = torch.randn((k, n), device=device, dtype=torch.float16) * input_scale
    b2 = torch.randn((k, n), device=device, dtype=torch.float16) * input_scale
    for _ in range(warmup):
        _ = dual_gemm_swiglu(a, b1, b2)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        _ = dual_gemm_swiglu(a, b1, b2)
    end.record()
    torch.cuda.synchronize()
    avg_ms = start.elapsed_time(end) / iters
    total_flops = 4 * m * n * k
    tflops = (total_flops / (avg_ms * 1e-3)) / 1e12
    print(f"[kernel] shape=({m}, {n}, {k}) avg={avg_ms:.3f} ms, ~{tflops:.2f} TFLOP/s")


def main() -> None:
    parser = argparse.ArgumentParser(description="AMD Triton fused dual-GEMM + SwiGLU")
    parser.add_argument("--m", type=int, default=4096)
    parser.add_argument("--n", type=int, default=3648)
    parser.add_argument("--k", type=int, default=8192)
    parser.add_argument("--warmup", type=int, default=10)
    parser.add_argument("--iters", type=int, default=50)
    parser.add_argument("--input-scale", type=float, default=0.125)
    parser.add_argument("--test-only", action="store_true", help="Run correctness tests only")
    parser.add_argument("--bench-only", action="store_true", help="Run benchmark only")
    args = parser.parse_args()

    if not torch.cuda.is_available():
        print("ERROR: No CUDA/ROCm GPU detected. This kernel requires a GPU to run.")
        print("  - Run on a machine with an NVIDIA GPU (CUDA) or AMD GPU (ROCm)")
        print("  - Ensure PyTorch is installed with GPU support.")
        raise SystemExit(1)

    if not args.bench_only:
        print("Running correctness tests...")
        t0 = time.time()
        ok = test_correctness()
        print(f"Correctness: {'PASS' if ok else 'FAIL'} ({time.time()-t0:.2f}s)")

    if not args.test_only:
        print("\nRunning benchmark...")
        t0 = time.time()
        benchmark(args.m, args.n, args.k, args.warmup, args.iters, args.input_scale)
        print(f"[done] elapsed={time.time()-t0:.2f}s")


if __name__ == "__main__":
    main()
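The launcher above flattens the (M, N) tile grid into one program axis, and the kernel re-derives (pid_m, pid_n) with a GROUP_SIZE_M swizzle so that consecutive programs reuse the same columns of B1/B2 from cache. A small host-side sketch of that index math (illustration only, not part of the submission) for a 4x4 tile grid with GROUP_SIZE_M=2:

def grouped_order(num_pid_m, num_pid_n, group_size_m):
    # Mirrors the pid -> (pid_m, pid_n) mapping inside dual_gemm_swiglu_kernel.
    order = []
    num_pid_in_group = group_size_m * num_pid_n
    for pid in range(num_pid_m * num_pid_n):
        group_id = pid // num_pid_in_group
        first_pid_m = group_id * group_size_m
        gsm = min(num_pid_m - first_pid_m, group_size_m)
        pid_in_group = pid % num_pid_in_group
        pid_m = first_pid_m + (pid_in_group % gsm)
        pid_n = pid_in_group // gsm
        order.append((pid_m, pid_n))
    return order

# grouped_order(4, 4, 2) ->
# [(0,0), (1,0), (0,1), (1,1), (0,2), (1,2), (0,3), (1,3),
#  (2,0), (3,0), (2,1), (3,1), (2,2), (3,2), (2,3), (3,3)]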
dual_gemm_swiglu_full.py
ADDED
"""
Dual GEMM + SwiGLU following triton_kernels swiglu.py build pattern.

Structure matches Kernel Community Hub swiglu.py:
- repr() and launch_metadata for specialization
- compute_swiglu() style activation (SiLU(gate) * linear)
- Optional Flexpoint/MXFP (stub for standalone, real import in triton_kernels)
- NTokens support for variable M (MoE routing)
- Persistent kernel pattern with tl.range

Usage (standalone fp16):
    from dual_gemm_swiglu_full import dual_gemm_swiglu
    out = dual_gemm_swiglu(a, b1, b2)

For triton_kernels integration: place in dual_gemm_swiglu_details/, add flexpoint import.
"""

from __future__ import annotations

import argparse
import os
import sys
import time
from typing import Optional

# Allow importing testing.py from same directory (when run from kernels/)
_script_dir = os.path.dirname(os.path.abspath(__file__))
if _script_dir not in sys.path:
    sys.path.insert(0, _script_dir)

import torch
import triton
import triton.language as tl

# -----------------------------------------------------------------------------
# Flexpoint stub (standalone). Replace with:
#   from ..numerics_details.flexpoint import load_scale, float_to_flex, update_scale
# when integrating into triton_kernels.
# -----------------------------------------------------------------------------
_HAS_FLEXPOINT = False
try:
    # Only works inside the triton_kernels package
    from ..numerics_details.flexpoint import load_scale, float_to_flex, update_scale
    _HAS_FLEXPOINT = True
except ImportError:
    @triton.jit
    def load_scale(scale_ptr):
        return 1.0 if scale_ptr is None else tl.load(scale_ptr)

    def float_to_flex_stub(x, *args, **kwargs):
        """Pass-through for fp16 standalone."""
        return x

    def update_scale_stub(x, scale_ptr, Out):
        pass


# -----------------------------------------------------------------------------
# Helpers (mirroring swiglu.py)
# -----------------------------------------------------------------------------
@triton.jit
def clip(x, limit, clip_lower: tl.constexpr):
    res = tl.minimum(x, limit)
    if clip_lower:
        res = tl.maximum(-limit, res)
    return res


@triton.jit
def thread_local_absmax(x, BLOCK_SIZE: tl.constexpr, NUM_THREADS: tl.constexpr):
    return tl.max(
        tl.reshape(tl.abs(x), [NUM_THREADS, BLOCK_SIZE // NUM_THREADS], can_reorder=True),
        axis=1,
    )


@triton.jit
def compute_swiglu(gelu, linear, scale, alpha, limit: tl.constexpr):
    """SwiGLU: silu(gelu) * linear. Matches swiglu.py compute_swiglu style.
    limit > 0 enables clipping; pass 0.0 for no clip.
    """
    gelu = gelu.to(tl.float32) * scale
    if limit > 0:
        gelu = clip(gelu, limit, clip_lower=False)
    linear = linear.to(tl.float32) * scale
    if limit > 0:
        linear = clip(linear, limit, clip_lower=True)
    s = gelu / (1 + tl.exp(-alpha * gelu))  # SiLU(gelu)
    return s * linear  # SiLU(gate) * linear (standard SwiGLU)


# -----------------------------------------------------------------------------
# Repr and launch_metadata (swiglu.py pattern)
# -----------------------------------------------------------------------------
def dual_gemm_repr(specialization):
    signature = specialization.signature
    constants = specialization.constants
    convert_dtype = lambda dtype: "mxfp4" if "u8" in str(dtype) else str(dtype)
    dtypes = "x".join([convert_dtype(f"{signature.get(i, 'fp16')}") for i in ["Out", "A", "B1", "B2"]])
    blocks = "x".join([f"{constants.get(i, 0)}" for i in ["BLOCK_M", "BLOCK_N", "BLOCK_K"]])
    return f"_dual_gemm_swiglu_{dtypes}_{blocks}"


def dual_gemm_launch_metadata(grid, kernel, args):
    M, N, K = args["M"], args["N"], args["K"]
    ret = dict()
    ret["name"] = f"{kernel.name} [M={M}, N={N}, K={K}]"
    A, B1, B2, Out = args["A"], args["B1"], args["B2"], args["Out"]
    ret["bytes"] = (
        A.numel() * A.element_size()
        + B1.numel() * B1.element_size()
        + B2.numel() * B2.element_size()
        + Out.numel() * Out.element_size()
    )
    return ret


# -----------------------------------------------------------------------------
# Dual GEMM + SwiGLU kernel (swiglu.py structure)
# -----------------------------------------------------------------------------
@triton.jit(repr=lambda _: "_dual_gemm_swiglu", launch_metadata=dual_gemm_launch_metadata)
def _dual_gemm_swiglu(
    Out,
    A,
    B1,
    B2,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_b1k,
    stride_b1n,
    stride_b2k,
    stride_b2n,
    stride_outm,
    stride_outn,
    alpha: tl.constexpr,
    limit,
    NTokens,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    EVEN_K: tl.constexpr,
    EVEN_M: tl.constexpr,
    EVEN_N: tl.constexpr,
):
    if NTokens is not None:
        M = tl.load(NTokens)
    M_BLOCKS = tl.cdiv(M, BLOCK_M)
    N_BLOCKS = tl.cdiv(N, BLOCK_N)
    num_tiles = M_BLOCKS * N_BLOCKS

    # Persistent kernel: each program handles multiple tiles
    grid_size = tl.num_programs(0)
    for pid in range(tl.program_id(0), num_tiles, grid_size):
        pid_m = pid // N_BLOCKS
        pid_n = pid % N_BLOCKS

        offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
        offs_k = tl.arange(0, BLOCK_K)

        a_ptrs = A + (offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak)
        b1_ptrs = B1 + (offs_k[:, None] * stride_b1k + offs_n[None, :] * stride_b1n)
        b2_ptrs = B2 + (offs_k[:, None] * stride_b2k + offs_n[None, :] * stride_b2n)

        acc1 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        acc2 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        m_mask = offs_m[:, None] < M
        n_mask = offs_n[None, :] < N

        for k in range(0, tl.cdiv(K, BLOCK_K)):
            if EVEN_K:
                a = tl.load(a_ptrs, mask=m_mask, other=0.0) if not EVEN_M else tl.load(a_ptrs)
                b1 = tl.load(b1_ptrs, mask=n_mask, other=0.0) if not EVEN_N else tl.load(b1_ptrs)
                b2 = tl.load(b2_ptrs, mask=n_mask, other=0.0) if not EVEN_N else tl.load(b2_ptrs)
            else:
                k_rem = K - k * BLOCK_K
                k_mask_m = offs_k[None, :] < k_rem
                k_mask_n = offs_k[:, None] < k_rem
                a = tl.load(a_ptrs, mask=m_mask & k_mask_m, other=0.0)
                b1 = tl.load(b1_ptrs, mask=k_mask_n & n_mask, other=0.0)
                b2 = tl.load(b2_ptrs, mask=k_mask_n & n_mask, other=0.0)

            tl.multiple_of(a_ptrs, [16, 16])
            tl.multiple_of(b1_ptrs, [16, 16])
            tl.multiple_of(b2_ptrs, [16, 16])
            acc1 += tl.dot(a, b1)
            acc2 += tl.dot(a, b2)

            a_ptrs += BLOCK_K * stride_ak
            b1_ptrs += BLOCK_K * stride_b1k
            b2_ptrs += BLOCK_K * stride_b2k

        out = compute_swiglu(acc1, acc2, 1.0, alpha, limit)
        out = out.to(tl.float16)

        out_ptrs = Out + (offs_m[:, None] * stride_outm + offs_n[None, :] * stride_outn)
        c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
        tl.store(out_ptrs, out, mask=c_mask)


# -----------------------------------------------------------------------------
# Autotuned wrapper (backward compatible, uses simpler kernel for reliability)
# -----------------------------------------------------------------------------
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "GROUP_SIZE_M": 8}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "GROUP_SIZE_M": 8}, num_warps=8, num_stages=3),
        triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "GROUP_SIZE_M": 8}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64, "GROUP_SIZE_M": 8}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 64, "GROUP_SIZE_M": 4}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 64, "GROUP_SIZE_M": 4}, num_warps=4, num_stages=3),
    ],
    key=["M", "N", "K"],
)
@triton.heuristics(
    {
        "EVEN_K": lambda args: args["K"] % args["BLOCK_K"] == 0,
        "EVEN_M": lambda args: args["M"] % args["BLOCK_M"] == 0,
        "EVEN_N": lambda args: args["N"] % args["BLOCK_N"] == 0,
    }
)
@triton.jit
def _dual_gemm_swiglu_autotuned(
    a_ptr, b1_ptr, b2_ptr, c_ptr,
    M, N, K,
    stride_am, stride_ak, stride_b1k, stride_b1n, stride_b2k, stride_b2n, stride_cm, stride_cn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr,
    EVEN_K: tl.constexpr, EVEN_M: tl.constexpr, EVEN_N: tl.constexpr,
    alpha: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = tl.minimum(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_in_group = pid % num_pid_in_group
    pid_m = first_pid_m + (pid_in_group % group_size_m)
    pid_n = pid_in_group // group_size_m

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)

    a_ptrs = a_ptr + (offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b1_ptrs = b1_ptr + (offs_k[:, None] * stride_b1k + offs_n[None, :] * stride_b1n)
    b2_ptrs = b2_ptr + (offs_k[:, None] * stride_b2k + offs_n[None, :] * stride_b2n)

    acc1 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    acc2 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    m_mask = offs_m[:, None] < M
    n_mask = offs_n[None, :] < N

    for k in range(0, tl.cdiv(K, BLOCK_K)):
        if EVEN_K:
            a = tl.load(a_ptrs) if EVEN_M else tl.load(a_ptrs, mask=m_mask, other=0.0)
            b1 = tl.load(b1_ptrs) if EVEN_N else tl.load(b1_ptrs, mask=n_mask, other=0.0)
            b2 = tl.load(b2_ptrs) if EVEN_N else tl.load(b2_ptrs, mask=n_mask, other=0.0)
        else:
            k_rem = K - k * BLOCK_K
            k_mask_m = offs_k[None, :] < k_rem
            k_mask_n = offs_k[:, None] < k_rem
            a = tl.load(a_ptrs, mask=m_mask & k_mask_m, other=0.0)
            b1 = tl.load(b1_ptrs, mask=k_mask_n & n_mask, other=0.0)
            b2 = tl.load(b2_ptrs, mask=k_mask_n & n_mask, other=0.0)

        tl.multiple_of(a_ptrs, [16, 16])
        tl.multiple_of(b1_ptrs, [16, 16])
        tl.multiple_of(b2_ptrs, [16, 16])
        acc1 += tl.dot(a, b1)
        acc2 += tl.dot(a, b2)
        a_ptrs += BLOCK_K * stride_ak
        b1_ptrs += BLOCK_K * stride_b1k
        b2_ptrs += BLOCK_K * stride_b2k

    out = compute_swiglu(acc1, acc2, 1.0, alpha, 0.0)  # 0.0 = no clip
    c_ptrs = c_ptr + (offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn)
    c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tl.store(c_ptrs, out.to(tl.float16), mask=c_mask)


def dual_gemm_swiglu(
    a: torch.Tensor,
    b1: torch.Tensor,
    b2: torch.Tensor,
    n_tokens: Optional[torch.Tensor] = None,
    alpha: float = 1.0,
    limit: Optional[float] = None,
) -> torch.Tensor:
    """Fused Dual GEMM + SwiGLU: silu(A @ B1) * (A @ B2)."""
    if a.ndim != 2 or b1.ndim != 2 or b2.ndim != 2:
        raise ValueError("Expected 2D tensors: a[M,K], b1[K,N], b2[K,N].")
    if a.shape[1] != b1.shape[0] or a.shape[1] != b2.shape[0]:
        raise ValueError("Incompatible shapes for dual GEMM.")
    if b1.shape[1] != b2.shape[1]:
        raise ValueError("b1 and b2 must have same N dimension.")
    if not (a.is_cuda and b1.is_cuda and b2.is_cuda):
        raise ValueError("All tensors must be on a CUDA/ROCm device.")
    if a.dtype != torch.float16 or b1.dtype != torch.float16 or b2.dtype != torch.float16:
        raise ValueError("This kernel currently expects float16 inputs.")

    a = a.contiguous()
    b1 = b1.contiguous()
    b2 = b2.contiguous()

    M, K = a.shape
    _, N = b1.shape
    c = torch.empty((M, N), device=a.device, dtype=torch.float16)

    grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)

    _dual_gemm_swiglu_autotuned[grid](
        a, b1, b2, c,
        M=M, N=N, K=K,
        stride_am=a.stride(0), stride_ak=a.stride(1),
        stride_b1k=b1.stride(0), stride_b1n=b1.stride(1),
        stride_b2k=b2.stride(0), stride_b2n=b2.stride(1),
        stride_cm=c.stride(0), stride_cn=c.stride(1),
        alpha=alpha,
    )
    return c


def reference_dual_gemm_swiglu(a: torch.Tensor, b1: torch.Tensor, b2: torch.Tensor) -> torch.Tensor:
    x1 = a @ b1
    x2 = a @ b2
    return torch.nn.functional.silu(x1) * x2


def test_correctness(device: str = "cuda", maxtol: float = 2e-2, rmstol: float = 4e-3) -> bool:
    from testing import assert_close

    torch.manual_seed(42)
    shapes = [
        (128, 64, 128),
        (256, 256, 512),
        (1024, 512, 1024),
        (4096, 3648, 8192),
        (7, 13, 17),
        (100, 200, 150),
    ]
    input_scale = 0.125
    all_pass = True
    for m, n, k in shapes:
        a = torch.randn((m, k), device=device, dtype=torch.float16) * input_scale
        b1 = torch.randn((k, n), device=device, dtype=torch.float16) * input_scale
        b2 = torch.randn((k, n), device=device, dtype=torch.float16) * input_scale
        with torch.no_grad():
            ref = reference_dual_gemm_swiglu(a.float(), b1.float(), b2.float()).to(torch.float16)
            out = dual_gemm_swiglu(a, b1, b2)
        desc = f"[shape ({m},{n},{k})]"
        try:
            assert_close(ref, out, maxtol=maxtol, rmstol=rmstol, description=desc, verbose=True)
            print(f"  {desc} PASS")
        except AssertionError as e:
            print(f"  {desc} FAIL: {e}")
            all_pass = False
    return all_pass


def benchmark(m: int, n: int, k: int, warmup: int, iters: int, input_scale: float) -> None:
    device = "cuda"
    a = torch.randn((m, k), device=device, dtype=torch.float16) * input_scale
    b1 = torch.randn((k, n), device=device, dtype=torch.float16) * input_scale
    b2 = torch.randn((k, n), device=device, dtype=torch.float16) * input_scale
    for _ in range(warmup):
        _ = dual_gemm_swiglu(a, b1, b2)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        _ = dual_gemm_swiglu(a, b1, b2)
    end.record()
    torch.cuda.synchronize()
    avg_ms = start.elapsed_time(end) / iters
    total_flops = 4 * m * n * k
    tflops = (total_flops / (avg_ms * 1e-3)) / 1e12
    print(f"[kernel] shape=({m}, {n}, {k}) avg={avg_ms:.3f} ms, ~{tflops:.2f} TFLOP/s")


def main() -> None:
    parser = argparse.ArgumentParser(description="Dual GEMM + SwiGLU (swiglu.py build pattern)")
    parser.add_argument("--m", type=int, default=4096)
    parser.add_argument("--n", type=int, default=3648)
    parser.add_argument("--k", type=int, default=8192)
    parser.add_argument("--warmup", type=int, default=10)
    parser.add_argument("--iters", type=int, default=50)
    parser.add_argument("--input-scale", type=float, default=0.125)
    parser.add_argument("--test-only", action="store_true")
    parser.add_argument("--bench-only", action="store_true")
    args = parser.parse_args()

    if not torch.cuda.is_available():
        print("ERROR: No CUDA/ROCm GPU detected.")
        raise SystemExit(1)

    if not args.bench_only:
        print("Running correctness tests...")
        t0 = time.time()
        ok = test_correctness()
        print(f"Correctness: {'PASS' if ok else 'FAIL'} ({time.time()-t0:.2f}s)")

    if not args.test_only:
        print("\nRunning benchmark...")
        t0 = time.time()
        benchmark(args.m, args.n, args.k, args.warmup, args.iters, args.input_scale)
        print(f"[done] elapsed={time.time()-t0:.2f}s")


if __name__ == "__main__":
    main()
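For inspecting the epilogue in isolation, a rough host-side PyTorch mirror of the compute_swiglu helper can be useful: the gate is clipped from above only, the linear term symmetrically, and alpha sets the sigmoid slope. This is a sketch of the same math in fp32, not the kernel's exact numerics.

import torch

def compute_swiglu_reference(gate: torch.Tensor, linear: torch.Tensor,
                             scale: float = 1.0, alpha: float = 1.0,
                             limit: float = 0.0) -> torch.Tensor:
    # Mirrors compute_swiglu above: optional clipping, then SiLU(gate) * linear.
    gate = gate.float() * scale
    linear = linear.float() * scale
    if limit > 0:
        gate = gate.clamp(max=limit)                   # clip_lower=False: upper clip only
        linear = linear.clamp(min=-limit, max=limit)   # clip_lower=True: symmetric clip
    s = gate / (1 + torch.exp(-alpha * gate))          # SiLU(gate) with slope alpha
    return s * linear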
dual_gemm_swiglu_mxfp.py
ADDED
"""
Option 1.5 / 2: Dual GEMM + SwiGLU with MXFP4 weights.

- Option 1 (pre-dequant): use dual_gemm_swiglu from amd_dual_gemm_swiglu with (mx_tensor, mx_scale)
- Option 1.5 (tiled pre-dequant): upcast B in K-blocks, never materialize full fp16 B. Saves memory.
- Option 2 (fused): would decode MXFP in-kernel; blocked by ROCm Triton limitations (tl.cat, indexing).
  Currently falls back to option 1.5.

Usage:
    from dual_gemm_swiglu_mxfp import dual_gemm_swiglu_mxfp_tiled
    out = dual_gemm_swiglu_mxfp_tiled(a, b1_mx, b1_scale, b2_mx, b2_scale)
"""

import os
import sys

_script_dir = os.path.dirname(os.path.abspath(__file__))
if _script_dir not in sys.path:
    sys.path.insert(0, _script_dir)

import torch

try:
    from numerics_details.mxfp_details import upcast_mxfp4_to_fp16
    from amd_dual_gemm_swiglu import reference_dual_gemm_swiglu
    _HAS_DEPS = True
except ImportError:
    _HAS_DEPS = False

MXFP_BLOCK = 32  # N must be a multiple of 32 for the scale tensor


def _upcast_slice(mx_tensor: torch.Tensor, mx_scale: torch.Tensor, k_start: int, k_end: int) -> torch.Tensor:
    """Upcast MXFP4 slice [k_start:k_end, :] to fp16 [k_end-k_start, N]."""
    return upcast_mxfp4_to_fp16(
        mx_tensor[k_start:k_end, :],
        mx_scale[k_start:k_end, :],
        block_m=k_end - k_start,
        block_k=64,  # N must be divisible by block_k
        verbose=False,
    )


def dual_gemm_swiglu_mxfp_tiled(
    a: torch.Tensor,
    b1_mx: torch.Tensor,
    b1_scale: torch.Tensor,
    b2_mx: torch.Tensor,
    b2_scale: torch.Tensor,
    block_k: int = 64,
) -> torch.Tensor:
    """
    Dual GEMM + SwiGLU with MXFP4 B1, B2 using tiled pre-dequant (Option 1.5).
    Upcasts B in K-blocks; never materializes full fp16 B. Saves memory vs full pre-dequant.
    """
    if not _HAS_DEPS:
        raise ImportError("Requires numerics_details and amd_dual_gemm_swiglu")
    M, K = a.shape
    N = b1_mx.shape[1] * 2
    assert b1_mx.shape == (K, N // 2) and b1_scale.shape == (K, N // 32)
    assert b2_mx.shape == (K, N // 2) and b2_scale.shape == (K, N // 32)
    assert K % block_k == 0 and N % MXFP_BLOCK == 0
    assert block_k % MXFP_BLOCK == 0

    a = a.contiguous()
    # Option 1.5: accumulate acc1 = A@B1 and acc2 = A@B2 in K-blocks, then out = silu(acc1)*acc2.
    # Never materialize full fp16 B; upcast slice by slice. Saves O(K*N) -> O(block_k*N) memory.
    acc1 = torch.zeros((M, N), device=a.device, dtype=torch.float32)
    acc2 = torch.zeros((M, N), device=a.device, dtype=torch.float32)
    for k_start in range(0, K, block_k):
        k_end = k_start + block_k
        b1_slice = _upcast_slice(b1_mx, b1_scale, k_start, k_end)
        b2_slice = _upcast_slice(b2_mx, b2_scale, k_start, k_end)
        # Partial GEMMs in fp32: acc += A[:, k_start:k_end] @ b_slice
        acc1 += (a[:, k_start:k_end].float() @ b1_slice.float())
        acc2 += (a[:, k_start:k_end].float() @ b2_slice.float())
    # SwiGLU epilogue
    silu = torch.nn.functional.silu(acc1.to(torch.float16))
    out = (silu * acc2.to(torch.float16)).to(torch.float16)
    return out


def dual_gemm_swiglu_mxfp_predequant(a, b1_mx, b1_scale, b2_mx, b2_scale):
    """Option 1: full pre-dequant, then standard dual GEMM."""
    from amd_dual_gemm_swiglu import dual_gemm_swiglu
    b1 = upcast_mxfp4_to_fp16(b1_mx, b1_scale, verbose=False)
    b2 = upcast_mxfp4_to_fp16(b2_mx, b2_scale, verbose=False)
    return dual_gemm_swiglu(a, b1, b2)


if __name__ == "__main__":
    if not _HAS_DEPS:
        print("Missing deps")
        sys.exit(1)
    if not torch.cuda.is_available():
        print("No GPU")
        sys.exit(1)
    device = "cuda"
    M, N, K = 256, 128, 512
    torch.manual_seed(42)
    a = torch.randn((M, K), device=device, dtype=torch.float16) * 0.1
    from example_dual_gemm_mxfp import downcast_fp16_to_mxfp4
    b1_fp = torch.randn((K, N), device=device, dtype=torch.float16) * 0.1
    b2_fp = torch.randn((K, N), device=device, dtype=torch.float16) * 0.1
    b1_mx, b1_scale = downcast_fp16_to_mxfp4(b1_fp, block_k=64)
    b2_mx, b2_scale = downcast_fp16_to_mxfp4(b2_fp, block_k=64)
    print("Option 1 (pre-dequant):")
    out1 = dual_gemm_swiglu_mxfp_predequant(a, b1_mx, b1_scale, b2_mx, b2_scale)
    print("Option 1.5 (tiled pre-dequant):")
    out15 = dual_gemm_swiglu_mxfp_tiled(a, b1_mx, b1_scale, b2_mx, b2_scale)
    ref = reference_dual_gemm_swiglu(a.float(), b1_fp.float(), b2_fp.float()).to(torch.float16)
    err1 = (out1.float() - ref.float()).abs().max().item()
    err15 = (out15.float() - ref.float()).abs().max().item()
    print(f"  Option 1 err: {err1:.2e}")
    print(f"  Option 1.5 err: {err15:.2e}")
    print(f"  Option 1 vs 1.5 diff: {(out1.float() - out15.float()).abs().max().item():.2e}")
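To make the memory claim in the docstring concrete: with option 1.5, the dequantized-weight scratch per GEMM drops from the full fp16 matrix to a single K-slice. For the benchmark shape used elsewhere in this PR (K=8192, N=3648, block_k=64), a quick back-of-the-envelope check:

K, N, block_k = 8192, 3648, 64
full_fp16_bytes = K * N * 2           # ~59.8 MB per weight matrix (option 1 pre-dequant)
slice_fp16_bytes = block_k * N * 2    # ~0.47 MB per upcast slice (option 1.5)
print(full_fp16_bytes / 1e6, slice_fp16_bytes / 1e6)

Note that the fp32 accumulators acc1/acc2 are still M*N each, so the saving applies to the weight side only.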
example_dual_gemm_mxfp.py
ADDED
#!/usr/bin/env python3
"""
Example: Dual GEMM + SwiGLU with MXFP4 weights (pre-dequant option 1).
Quantizes B1, B2 to MXFP4, then upcasts and runs the fused GEMM.
"""

import os
import sys
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

import torch
from amd_dual_gemm_swiglu import dual_gemm_swiglu, reference_dual_gemm_swiglu
from numerics_details.mxfp_details._downcast_to_mxfp import _downcast_to_mxfp
from numerics_details.mxfp_details import upcast_mxfp4_to_fp16

MXFP_BLOCK_SIZE_PY = 32


def downcast_fp16_to_mxfp4(src: torch.Tensor, block_m: int = 128, block_k: int = 64):
    """fp16 [M,K] -> (mx_tensor [M,K//2], mx_scale [M,K//32])."""
    assert block_k % MXFP_BLOCK_SIZE_PY == 0
    M, K = src.shape
    mx_tensor = torch.empty((M, K // 2), device=src.device, dtype=torch.uint8)
    mx_scale = torch.empty((M, K // 32), device=src.device, dtype=torch.uint8)
    grid = ((M + block_m - 1) // block_m, (K + block_k - 1) // block_k)
    _downcast_to_mxfp[grid](
        mx_tensor, mx_tensor.stride(0), 1,
        mx_scale, mx_scale.stride(0), mx_scale.stride(1),
        src, src.stride(0), src.stride(1),
        M, K,
        BLOCK_SIZE_OUT_DIM=block_m,
        BLOCK_SIZE_QUANT_DIM=block_k,
        DEQUANT_SCALE_ROUNDING_MODE=0,
    )
    return mx_tensor, mx_scale


def main():
    if not torch.cuda.is_available():
        print("No CUDA/ROCm device.")
        return
    device = "cuda"
    print("Device:", torch.cuda.get_device_name(0))

    M, N, K = 256, 128, 512
    torch.manual_seed(42)
    a = torch.randn((M, K), device=device, dtype=torch.float16) * 0.1
    b1_fp16 = torch.randn((K, N), device=device, dtype=torch.float16) * 0.1
    b2_fp16 = torch.randn((K, N), device=device, dtype=torch.float16) * 0.1

    # Quantize B1, B2 to MXFP4 (K, N must be multiples of 32 for block_k)
    block_k = 64
    b1_mx, b1_scale = downcast_fp16_to_mxfp4(b1_fp16, block_k=block_k)
    b2_mx, b2_scale = downcast_fp16_to_mxfp4(b2_fp16, block_k=block_k)
    print(f"Quantized B1: {b1_mx.shape}, {b1_scale.shape}")

    # Run dual GEMM with MXFP4 weights (pre-dequant)
    out_mxfp = dual_gemm_swiglu(a, (b1_mx, b1_scale), (b2_mx, b2_scale))
    print(f"Output (MXFP4 path): {out_mxfp.shape}")

    # Reference with fp16
    out_ref = reference_dual_gemm_swiglu(a.float(), b1_fp16.float(), b2_fp16.float()).to(torch.float16)
    err = (out_mxfp.float() - out_ref.float()).abs().max().item()
    rel = err / (out_ref.float().abs().max().item() + 1e-6)
    print(f"vs fp16 ref: max abs err={err:.2e}, rel={rel:.2e}")
    print("Done.")


if __name__ == "__main__":
    main()
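For orientation: MXFP4 in the OCP MX sense stores FP4 (e2m1) values two per byte, with one shared power-of-two (e8m0) scale byte per 32 values, which is why mx_tensor is [M, K//2] and mx_scale is [M, K//32]. The decoder sketch below only illustrates that format; the nibble order (low nibble = even index) and the scale bias of 127 are assumptions and may not match _upcast_from_mxfp exactly.

import torch

# FP4 (e2m1) magnitude table for codes 0..7; the sign is the top bit of the nibble.
_E2M1 = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def decode_mxfp4_rowwise(mx_tensor: torch.Tensor, mx_scale: torch.Tensor) -> torch.Tensor:
    """Illustrative MXFP4 -> fp32 decode (assumed layout, not the library's)."""
    lo = mx_tensor & 0xF
    hi = (mx_tensor >> 4) & 0xF
    nibbles = torch.stack([lo, hi], dim=-1).flatten(-2)          # [M, K]
    mag = _E2M1.to(mx_tensor.device)[(nibbles & 0x7).long()]
    sign = torch.where((nibbles & 0x8) != 0, -1.0, 1.0)
    vals = sign * mag
    scale = torch.exp2(mx_scale.float() - 127.0)                 # e8m0, one scale per 32 values
    return vals * scale.repeat_interleave(32, dim=-1)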
example_moe_compaction_gemm.py
ADDED
@@ -0,0 +1,68 @@
#!/usr/bin/env python3
"""
Example: Mask compaction + Dual GEMM integration (MoE-style).
Before dual GEMM: compact (Yv, Yi) per row based on BitMask.
Then use compacted tensors for routing into expert weights.

ROCm note: tl.store with dynamic write_indx may fail on ROCm Triton.
If so, use the PyTorch fallback in mask_compaction.py.
"""

import os
import sys
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

import torch

from mask_compaction import masked_compaction, masked_compaction_torch_fallback
from amd_dual_gemm_swiglu import dual_gemm_swiglu


def example_integration():
    """Sketch: compact routing outputs, then run dual GEMM on routed experts."""
    device = "cuda"
    if not torch.cuda.is_available():
        print("No GPU")
        return

    M, K, N = 256, 64, 128  # tokens, hidden, expert dim
    top_k = 8
    num_experts = 4

    # Simulate routing: Yv [M, K] values, Yi [M, K] expert indices (0..num_experts-1)
    torch.manual_seed(42)
    Yv = torch.randn(M, K, device=device, dtype=torch.float16) * 0.1
    Yi = torch.randint(0, num_experts, (M, K), device=device, dtype=torch.int32)

    # BitMask [M, ceil(K/32)]: 1 = use, 0 = discard (e.g. from load balance)
    BitMask = torch.ones(M, (K + 31) // 32, device=device, dtype=torch.int32)
    BitMask[:, 0] = 0x55555555  # example: alternating bits

    # 1) Compact (Yv, Yi) per row based on BitMask
    try:
        RetYv, RetYi = masked_compaction(Yv, Yi, BitMask, sentinel=float("nan"))
        print("Compaction: Triton kernel OK")
    except Exception as e:
        print(f"Compaction: Triton failed ({e}), using PyTorch fallback")
        RetYv, RetYi = masked_compaction_torch_fallback(Yv, Yi, BitMask, sentinel=float("nan"))

    # 2) Use compacted indices for routing into expert weights
    # Expert weights: B1[E,K,N], B2[E,K,N] or similar. For simplicity, flat GEMM:
    # A = routed activations [M, K], B1/B2 = expert weights [K, N]
    # This is a simplified sketch; real MoE has per-expert B.
    B1 = torch.randn(K, N, device=device, dtype=torch.float16) * 0.1
    B2 = torch.randn(K, N, device=device, dtype=torch.float16) * 0.1

    # Use RetYv as activations (compacted); pad/truncate to [M, K] if needed
    A = RetYv[:, :K].contiguous()
    if A.shape[1] < K:
        A = torch.nn.functional.pad(A, (0, K - A.shape[1]), value=0)

    # 3) Dual GEMM
    out = dual_gemm_swiglu(A, B1, B2)
    print(f"Dual GEMM output: {out.shape}")
    print("Done.")


if __name__ == "__main__":
    example_integration()
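The example above hard-codes an alternating BitMask word. If the keep/drop decision starts out as a boolean mask, a small packing helper (a sketch only; pack_bitmask is not part of the committed files) produces the int32 word layout that masked_compaction expects, one bit per position and 32 positions per word:

import torch

def pack_bitmask(keep: torch.Tensor) -> torch.Tensor:
    """Pack a boolean keep-mask [M, K] into int32 words [M, ceil(K/32)]; bit k lands in word k // 32 at position k % 32."""
    M, K = keep.shape
    out = torch.zeros(M, (K + 31) // 32, device=keep.device, dtype=torch.int32)
    for k in range(K):
        out[:, k // 32] |= keep[:, k].to(torch.int32) << (k % 32)
    return out

# e.g.: keep = torch.rand(256, 64, device="cuda") > 0.5
#       BitMask = pack_bitmask(keep)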
example_mxfp_roundtrip.py
ADDED
@@ -0,0 +1,72 @@
#!/usr/bin/env python3
"""
Example: use _downcast_to_mxfp (fp16 -> MXFP4) and _upcast_from_mxfp (MXFP4 -> fp16)
for a round-trip on the remote server.

Usage on remote:
    cd /root/kernels
    python example_mxfp_roundtrip.py
"""

import os
import sys
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

import torch
from numerics_details.mxfp_details._downcast_to_mxfp import _downcast_to_mxfp
from numerics_details.mxfp_details import upcast_mxfp4_to_fp16

MXFP_BLOCK_SIZE_PY = 32  # Python int for checks (tl.constexpr in kernels)


def downcast_fp16_to_mxfp4(src: torch.Tensor, block_m: int = 128, block_k: int = 64):
    """Convert fp16 tensor [M, K] to MXFP4 (uint8 mx_tensor + uint8 mx_scale)."""
    assert src.dim() == 2 and src.dtype in (torch.float16, torch.bfloat16)
    assert block_k % MXFP_BLOCK_SIZE_PY == 0, f"block_k must be multiple of {MXFP_BLOCK_SIZE_PY}"
    M, K = src.shape

    # Outputs: mx_tensor [M, K//2] uint8, mx_scale [M, K//32] uint8
    mx_tensor = torch.empty((M, K // 2), device=src.device, dtype=torch.uint8)
    mx_scale = torch.empty((M, K // 32), device=src.device, dtype=torch.uint8)

    grid = ((M + block_m - 1) // block_m, (K + block_k - 1) // block_k)
    _downcast_to_mxfp[grid](
        mx_tensor, mx_tensor.stride(0), 1,
        mx_scale, mx_scale.stride(0), mx_scale.stride(1),
        src, src.stride(0), src.stride(1),
        M, K,
        BLOCK_SIZE_OUT_DIM=block_m,
        BLOCK_SIZE_QUANT_DIM=block_k,
        DEQUANT_SCALE_ROUNDING_MODE=0,
    )
    return mx_tensor, mx_scale


def main():
    if not torch.cuda.is_available():
        print("No CUDA/ROCm device.")
        return
    device = "cuda"
    print("Device:", torch.cuda.get_device_name(0))

    # Create random fp16 tensor
    M, K = 256, 128
    src = torch.randn(M, K, device=device, dtype=torch.float16) * 0.1

    # Downcast fp16 -> MXFP4
    mx_tensor, mx_scale = downcast_fp16_to_mxfp4(src)
    print(f"Downcast OK: mx_tensor {mx_tensor.shape}, mx_scale {mx_scale.shape}")

    # Upcast MXFP4 -> fp16
    recovered = upcast_mxfp4_to_fp16(mx_tensor, mx_scale)
    print(f"Upcast OK: recovered {recovered.shape}")

    # Compare
    err = (src.float() - recovered.float()).abs().max().item()
    rel = err / (src.float().abs().max().item() + 1e-6)
    print(f"Round-trip max abs err: {err:.2e}, rel: {rel:.2e}")
    print("Done.")


if __name__ == "__main__":
    main()
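For orientation, the storage accounting implied by the shapes printed above (assuming K is a multiple of 32, as the block-size assert requires): each row stores K/2 packed value bytes plus K/32 scale bytes, i.e. about 4.25 bits per element versus 16 for fp16.

M, K = 256, 128
fp16_bytes = M * K * 2                        # source fp16 tensor
mxfp4_bytes = M * (K // 2) + M * (K // 32)    # packed e2m1 values + one uint8 scale per 32 values
print(fp16_bytes, mxfp4_bytes, fp16_bytes / mxfp4_bytes)   # 65536 17408 ~3.76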
hardware_submission.py
ADDED
@@ -0,0 +1,121 @@
import torch
import triton
import triton.language as tl
import os

# 1. HARDWARE DIAGNOSTICS & OS PREP
def check_environment():
    cuda_avail = torch.cuda.is_available()
    if cuda_avail:
        # Optimization: Force kernel arguments to device to save PCIe latency
        os.environ["HIP_FORCE_DEV_KERNARG"] = "1"
        # Disable compiler cache for benchmarking clean runs
        os.environ["TRITON_CACHE_DIR"] = ""

check_environment()

# 2. THE SOL KERNEL
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=8),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=8),
    ],
    key=['M', 'N', 'K'],
)
@triton.jit
def dual_gemm_hardware_kernel(
    a_ptr, b1_ptr, b2_ptr, c_ptr,
    sfa_ptr, sfb1_ptr, sfb2_ptr,
    M, N, K, L,
    stride_am, stride_ak, stride_al,
    stride_bn, stride_bk, stride_bl,
    stride_cm, stride_cn, stride_cl,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    # Persistent grid logic
    pid = tl.program_id(0)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    total_tiles = num_pid_m * num_pid_n * L

    for tile_idx in tl.range(pid, total_tiles, tl.num_programs(0)):
        l_idx = tile_idx // (num_pid_m * num_pid_n)
        tile_rem = tile_idx % (num_pid_m * num_pid_n)

        # Swizzle for L2 locality (decompose the flat tile index into row/col first;
        # tl.swizzle2d takes i, j, size_i, size_j, size_g)
        pid_m, pid_n = tl.swizzle2d(tile_rem // num_pid_n, tile_rem % num_pid_n,
                                    num_pid_m, num_pid_n, GROUP_SIZE_M)

        # Base offsets
        rm = pid_m * BLOCK_M
        rn = pid_n * BLOCK_N

        # Ranges
        offs_m = rm + tl.arange(0, BLOCK_M)
        offs_n = rn + tl.arange(0, BLOCK_N)
        offs_k = tl.arange(0, BLOCK_K)

        # Memory pointers
        a_ptrs = a_ptr + l_idx * stride_al + (offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak)
        b1_ptrs = b1_ptr + l_idx * stride_bl + (offs_n[None, :] * stride_bn + offs_k[:, None] * stride_bk)
        b2_ptrs = b2_ptr + l_idx * stride_bl + (offs_n[None, :] * stride_bn + offs_k[:, None] * stride_bk)

        # Accumulators
        acc1 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        acc2 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

        # Scale factor load offsets (assuming OCP: 1 scale per 16 elements)
        # sfa: (M, K/16, L), sfb: (N, K/16, L)
        sfa_base = sfa_ptr + l_idx * (M * (K // 16)) + (offs_m[:, None] * (K // 16))
        sfb1_base = sfb1_ptr + l_idx * (N * (K // 16)) + (offs_n[None, :] * (K // 16))
        sfb2_base = sfb2_ptr + l_idx * (N * (K // 16)) + (offs_n[None, :] * (K // 16))

        for k in range(0, tl.cdiv(K, BLOCK_K)):
            # 1. Load data
            a = tl.load(a_ptrs, mask=(offs_m[:, None] < M) & (offs_k[None, :] < K - k * BLOCK_K), other=0.0)
            b1 = tl.load(b1_ptrs, mask=(offs_n[None, :] < N) & (offs_k[:, None] < K - k * BLOCK_K), other=0.0)
            b2 = tl.load(b2_ptrs, mask=(offs_n[None, :] < N) & (offs_k[:, None] < K - k * BLOCK_K), other=0.0)

            # 2. Load scales for the current K-block
            # Blackwell uses a 32x4 atom, but for pointers, we load the K-slice
            curr_sfa = tl.load(sfa_base + (k * (BLOCK_K // 16)), mask=(offs_m[:, None] < M), other=1.0)
            curr_sfb1 = tl.load(sfb1_base + (k * (BLOCK_K // 16)), mask=(offs_n[None, :] < N), other=1.0)
            curr_sfb2 = tl.load(sfb2_base + (k * (BLOCK_K // 16)), mask=(offs_n[None, :] < N), other=1.0)

            # 3. Hardware scaled dot
            acc1 = tl.dot_scaled(a, curr_sfa, "e2m1", b1, curr_sfb1, "e2m1", acc1)
            acc2 = tl.dot_scaled(a, curr_sfa, "e2m1", b2, curr_sfb2, "e2m1", acc2)

            # Advance K
            a_ptrs += BLOCK_K * stride_ak
            b1_ptrs += BLOCK_K * stride_bk
            b2_ptrs += BLOCK_K * stride_bk

        # 4. Epilogue (fused SiLU + gating)
        res1 = acc1.to(tl.float16)
        activated = res1 * tl.sigmoid(res1)
        final_out = activated * acc2.to(tl.float16)

        # 5. Masked store
        c_ptrs = c_ptr + l_idx * stride_cl + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
        tl.store(c_ptrs, final_out, mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))

# 3. SUBMISSION INTERFACE
def dual_gemm_submission(data):
    a, b1, b2, sfa, sfb1, sfb2, c = data
    M, K_packed, L = a.shape
    N = b1.shape[0]
    K = K_packed * 2  # FP4: two e2m1 values per packed byte

    # Saturate the device (148 for B200, 304 for MI300X)
    num_sms = torch.cuda.get_device_properties(0).multi_processor_count
    grid = (num_sms,)

    dual_gemm_hardware_kernel[grid](
        a, b1, b2, c, sfa, sfb1, sfb2,
        M, N, K, L,
        a.stride(0), a.stride(1), a.stride(2),
        b1.stride(0), b1.stride(1), b1.stride(2),
        c.stride(0), c.stride(1), c.stride(2)
    )
    return c
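A host-side sketch of the tuple dual_gemm_submission unpacks, using the layout the kernel itself assumes (packed e2m1 bytes, one uint8 e8m0 scale per 16 elements, batch dimension last). The shapes and the scale value 127 (= 2**0) are assumptions for illustration; the actual benchmark harness may pass a different layout.

import torch
from hardware_submission import dual_gemm_submission

M, N, K, L = 256, 256, 128, 1
a    = torch.randint(0, 256, (M, K // 2, L), device="cuda", dtype=torch.uint8)  # packed fp4 activations
b1   = torch.randint(0, 256, (N, K // 2, L), device="cuda", dtype=torch.uint8)  # packed fp4 gate weights
b2   = torch.randint(0, 256, (N, K // 2, L), device="cuda", dtype=torch.uint8)  # packed fp4 up weights
sfa  = torch.full((M, K // 16, L), 127, device="cuda", dtype=torch.uint8)       # e8m0 scales of 1.0
sfb1 = torch.full((N, K // 16, L), 127, device="cuda", dtype=torch.uint8)
sfb2 = torch.full((N, K // 16, L), 127, device="cuda", dtype=torch.uint8)
c    = torch.empty((M, N, L), device="cuda", dtype=torch.float16)

out = dual_gemm_submission((a, b1, b2, sfa, sfb1, sfb2, c))
print(out.shape)  # torch.Size([256, 256, 1])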
mask_compaction.py
ADDED
@@ -0,0 +1,75 @@
"""
Masked compaction kernel: compact (Yv, Yi) per row based on BitMask.
Active elements (bit=1) move to front, inactive (bit=0) move to back with sentinel.
For MoE: use before dual GEMM to get dense top-k for routing into expert weights.

ROCm note: tl.store with dynamic write_indx may have limitations. If it fails,
fall back to PyTorch compaction.
"""

import torch
import triton
import triton.language as tl


@triton.jit
def _masked_compaction(
    Yv, Yi, BitMask, stride_bm, stride_bn,
    RetYv, RetYi, sentinel, K: tl.constexpr
):
    pid_m = tl.program_id(0)
    yv = tl.load(Yv + pid_m * K + tl.arange(0, K))
    yi = tl.load(Yi + pid_m * K + tl.arange(0, K))
    div = yi // 32
    rem = yi % 32
    active_bits = (tl.load(BitMask + pid_m * stride_bm + div * stride_bn) >> rem) & 1
    exc_cumsum = tl.cumsum(active_bits, 0) - active_bits
    active_flags = active_bits.to(tl.int1)
    rev_arange = tl.where(active_flags, 0, K - 1 - tl.arange(0, K))
    write_indx = exc_cumsum + rev_arange
    yv = tl.where(active_flags, yv, sentinel)
    yi = tl.where(active_flags, yi, sentinel)
    tl.store(RetYv + pid_m * K + write_indx, yv)
    tl.store(RetYi + pid_m * K + write_indx, yi)


def masked_compaction(
    Yv: torch.Tensor,        # [M, K] values
    Yi: torch.Tensor,        # [M, K] indices (int32)
    BitMask: torch.Tensor,   # [M, ceil(K/32)] or similar - 1 bit per position
    sentinel: float = float("nan"),
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Compact Yv, Yi per row: active (BitMask=1) to front, inactive to back with sentinel.
    Returns (RetYv, RetYi) same shape as (Yv, Yi).
    """
    M, K = Yv.shape
    assert Yi.shape == (M, K)
    RetYv = torch.empty_like(Yv)
    RetYi = torch.empty_like(Yi)
    grid = (M,)
    _masked_compaction[grid](
        Yv, Yi, BitMask,
        BitMask.stride(0), BitMask.stride(1),
        RetYv, RetYi, sentinel, K=K,
    )
    return RetYv, RetYi


def masked_compaction_torch_fallback(Yv, Yi, BitMask, sentinel=float("nan")):
    """PyTorch fallback if Triton kernel fails on ROCm."""
    M, K = Yv.shape
    RetYv = torch.full_like(Yv, sentinel)
    RetYi = torch.full_like(Yi, -1)
    for m in range(M):
        # Bit per position k: div=k//32, rem=k%32
        div = torch.arange(K, device=Yv.device) // 32
        rem = torch.arange(K, device=Yv.device) % 32
        active = ((BitMask[m, div] >> rem) & 1).bool()
        n_active = active.sum().item()
        RetYv[m, :n_active] = Yv[m, active]
        RetYi[m, :n_active] = Yi[m, active]
    return RetYv, RetYi


masked_compaction_pytorch = masked_compaction_torch_fallback  # alias for import
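A tiny worked example of the compaction semantics, using the PyTorch fallback, which indexes the BitMask by position k (note that the Triton kernel above derives the bit position from the Yi values instead, via div = yi // 32, so the two paths only coincide when Yi[m, k] == k): with K = 4 and mask bits 0b0101, positions 0 and 2 survive and move to the front.

import torch
from mask_compaction import masked_compaction_torch_fallback

Yv = torch.tensor([[10., 11., 12., 13.]])
Yi = torch.tensor([[0, 1, 2, 3]], dtype=torch.int32)
BitMask = torch.tensor([[0b0101]], dtype=torch.int32)

RetYv, RetYi = masked_compaction_torch_fallback(Yv, Yi, BitMask, sentinel=float("nan"))
print(RetYv)  # tensor([[10., 12., nan, nan]])
print(RetYi)  # tensor([[ 0,  2, -1, -1]], dtype=torch.int32)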
numerics_details/__init__.py
ADDED
@@ -0,0 +1,4 @@
# numerics_details: MXFP only (use mxfp_details)
from .mxfp_details import upcast_mxfp4_to_fp16

__all__ = ["upcast_mxfp4_to_fp16"]
numerics_details/mxfp_details/__init__.py
ADDED
@@ -0,0 +1,4 @@
# mxfp_details: MXFP quantize/dequantize kernels
from .upcast_mxfp4 import upcast_mxfp4_to_fp16

__all__ = ["upcast_mxfp4_to_fp16"]
numerics_details/mxfp_details/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (240 Bytes).
numerics_details/mxfp_details/__pycache__/_downcast_to_mxfp.cpython-312.pyc
ADDED
Binary file (8.91 kB).
numerics_details/mxfp_details/__pycache__/_upcast_from_mxfp.cpython-312.pyc
ADDED
Binary file (7.88 kB).
numerics_details/mxfp_details/__pycache__/upcast_mxfp4.cpython-312.pyc
ADDED
Binary file (5.19 kB).
numerics_details/mxfp_details/_downcast_to_mxfp.py
ADDED
@@ -0,0 +1,163 @@
# From https://huggingface.co/kernels-community/triton-kernels/blob/main/build/torch-cuda/numerics_details/mxfp_details/_downcast_to_mxfp.py

import triton
import triton.language as tl

# fmt: off


MXFP_BLOCK_SIZE = tl.constexpr(32)


@triton.jit
def _get_max_quant_val(dtype: tl.constexpr):
    if dtype == tl.uint8:
        return 6.0
    elif dtype == tl.float8e5:
        return 57344.0
    elif dtype == tl.float8e4nv:
        return 448.0
    else:
        tl.static_assert(False, f"Invalid {dtype=}")


@triton.jit
def _compute_quant_and_scale(src_tensor, valid_src_mask, mx_tensor_dtype: tl.constexpr,
                             DEQUANT_SCALE_ROUNDING_MODE: tl.constexpr = 0):
    is_fp8: tl.constexpr = mx_tensor_dtype == tl.float8e4nv or mx_tensor_dtype == tl.float8e5
    BLOCK_SIZE_OUT_DIM: tl.constexpr = src_tensor.shape[0]
    BLOCK_SIZE_QUANT_DIM: tl.constexpr = src_tensor.shape[1]
    BLOCK_SIZE_QUANT_MX_SCALE: tl.constexpr = src_tensor.shape[1] // MXFP_BLOCK_SIZE

    # Explicit cast to fp32 since most ops are not supported on bfloat16. We avoid needless conversions to and from bf16
    f32_tensor = src_tensor.to(tl.float32)
    abs_tensor = tl.abs(f32_tensor)
    abs_tensor = tl.where(valid_src_mask, abs_tensor, -1.0)  # Don't consider padding tensors in scale computation
    abs_tensor = tl.reshape(abs_tensor, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, MXFP_BLOCK_SIZE])
    max_val = tl.max(abs_tensor, axis=2, keep_dims=True)
    dequant_scale = max_val / _get_max_quant_val(mx_tensor_dtype)
    if DEQUANT_SCALE_ROUNDING_MODE == 0:
        # DequantScaleRoundingMode.ROUND_UP
        # compute 2 ** ceil(log2(dequant_scale))
        # Adding 0x007FFFFF adds exponent by 1 unless mantissa is all zeros
        # A corner case: exponent is 0xFF that will overflow but that's already
        # NaN so assume we don't care.
        dequant_scale_exponent = (dequant_scale.to(tl.uint32, bitcast=True) + 0x007FFFFF) & 0x7F800000
    else:
        # DequantScaleRoundingMode.ROUND_DOWN
        # compute 2 ** floor(log2(dequant_scale))
        assert DEQUANT_SCALE_ROUNDING_MODE == 1
        dequant_scale_exponent = dequant_scale.to(tl.uint32, bitcast=True) & 0x7F800000
    dequant_scale_rounded = dequant_scale_exponent.to(tl.float32, bitcast=True)
    quant_scale = tl.where(dequant_scale_rounded == 0, 0, 1.0 / dequant_scale_rounded)

    f32_tensor = tl.reshape(f32_tensor, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, MXFP_BLOCK_SIZE])
    quant_tensor = f32_tensor * quant_scale

    # Reshape the tensors after scaling
    quant_tensor = quant_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM])
    # Set the invalid portions of the tensor to 0. This will ensure that any padding tensors are 0 in the mx format.
    quant_tensor = tl.where(valid_src_mask, quant_tensor, 0)
    dequant_scale_exponent = dequant_scale_exponent.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE])

    # First, we simply extract the exponent part of the scales and store the result
    dequant_scale_exponent = (dequant_scale_exponent >> 23).to(tl.uint8)
    # Now we must convert the tensors to the mx format.
    if is_fp8:
        out_tensor = quant_tensor.to(mx_tensor_dtype)
    else:
        quant_tensor = quant_tensor.to(tl.uint32, bitcast=True)
        signs = quant_tensor & 0x80000000
        exponents = (quant_tensor >> 23) & 0xFF
        mantissas = (quant_tensor & 0x7FFFFF)

        # 0.25 <= x < 0.75 maps to 0.5, a denormal number
        E8_BIAS = 127
        E2_BIAS = 1
        # Move implicit bit 1 at the beginning to mantissa for denormals
        # tl.core.sub not available in Triton ROCm; use plain subtraction
        adjusted_exponents = E8_BIAS - (exponents + 1)
        mantissas = tl.where(exponents < E8_BIAS, (0x400000 | (mantissas >> 1)) >> adjusted_exponents, mantissas)

        # For normal numbers, we change the bias from 127 to 1, and for subnormals, we keep exponent as 0.
        exponents = tl.maximum(exponents, E8_BIAS - E2_BIAS) - (E8_BIAS - E2_BIAS)

        # Combine sign, exponent, and mantissa, while saturating
        # rounding nearest with tie breaking up by adding +1 to one bit right of the LSB, then shift right
        e2m1_tmp = tl.minimum((((exponents << 2) | (mantissas >> 21)) + 1) >> 1, 0x7)
        e2m1_value = ((signs >> 28) | e2m1_tmp).to(tl.uint8)

        e2m1_value = tl.reshape(e2m1_value, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM // 2, 2])
        evens, odds = tl.split(e2m1_value)
        out_tensor = evens | (odds << 4)

    return out_tensor, dequant_scale_exponent


@triton.jit
def _downcast_to_mxfp(mx_tensor_ptr, stride_mxt_outer, stride_mxt_quant: tl.constexpr,
                      mx_scale_ptr, stride_mx_scale_outer, stride_mx_scale_quant,
                      src_ptr, stride_src_outer, stride_src_quant,
                      outer_dim, quant_dim,
                      BLOCK_SIZE_OUT_DIM: tl.constexpr, BLOCK_SIZE_QUANT_DIM: tl.constexpr,
                      DEQUANT_SCALE_ROUNDING_MODE: tl.constexpr):

    tl.static_assert(stride_mxt_quant == 1, f"Output stride, {stride_mxt_quant=} must be 1.")
    tl.static_assert(BLOCK_SIZE_QUANT_DIM % MXFP_BLOCK_SIZE == 0, f"{BLOCK_SIZE_QUANT_DIM=} must be a multiple of 32")

    # uint8 signifies two fp4 e2m1 values packed into a single byte
    mx_tensor_dtype: tl.constexpr = mx_tensor_ptr.dtype.element_ty
    tl.static_assert(mx_tensor_dtype == tl.uint8 or (mx_tensor_dtype == tl.float8e4nv or mx_tensor_dtype == tl.float8e5),
                     f"Invalid {mx_tensor_dtype=}. Must be uint8 or float8.")

    src_dtype: tl.constexpr = src_ptr.dtype.element_ty
    tl.static_assert(mx_scale_ptr.dtype.element_ty == tl.uint8, f"{mx_scale_ptr.dtype.element_ty=} must be uint8")
    tl.static_assert((src_dtype == tl.bfloat16) or (src_dtype == tl.float16), f"{src_dtype=} must be bfloat16 or float16")
    is_fp4: tl.constexpr = mx_tensor_dtype == tl.uint8

    outer_block = tl.program_id(0).to(tl.int64)
    quant_block = tl.program_id(1).to(tl.int64)

    K_DIVISOR: tl.constexpr = 2 if is_fp4 else 1
    BLOCK_SIZE_QUANT_MX_SCALE: tl.constexpr = BLOCK_SIZE_QUANT_DIM // MXFP_BLOCK_SIZE
    BLOCK_SIZE_QUANT_MX_TENSOR: tl.constexpr = BLOCK_SIZE_QUANT_DIM // K_DIVISOR

    start_src_quant = quant_block * BLOCK_SIZE_QUANT_DIM
    start_mx_scale_quant = quant_block * BLOCK_SIZE_QUANT_MX_SCALE
    start_mx_quant = quant_block * BLOCK_SIZE_QUANT_MX_TENSOR
    start_out = outer_block * BLOCK_SIZE_OUT_DIM

    src_ptr += start_src_quant * stride_src_quant + start_out * stride_src_outer
    mx_scale_ptr += start_mx_scale_quant * stride_mx_scale_quant + start_out * stride_mx_scale_outer
    mx_tensor_ptr += start_mx_quant * stride_mxt_quant + start_out * stride_mxt_outer

    offs_src_quant = tl.arange(0, BLOCK_SIZE_QUANT_DIM)[None, :].to(tl.int64)
    offs_mxt_quant = tl.arange(0, BLOCK_SIZE_QUANT_MX_TENSOR)[None, :].to(tl.int64)
    offs_scale_quant = tl.arange(0, BLOCK_SIZE_QUANT_MX_SCALE)[None, :].to(tl.int64)
    offs_outer = tl.arange(0, BLOCK_SIZE_OUT_DIM)[:, None].to(tl.int64)

    mask_src_quant = start_src_quant + offs_src_quant < quant_dim
    mask_n = start_out + offs_outer < outer_dim
    full_mask_src = mask_src_quant & mask_n

    mask_mxt_quant = start_mx_quant + offs_mxt_quant < tl.cdiv(quant_dim, K_DIVISOR)
    full_mask_mxt = mask_mxt_quant & mask_n

    scale_mask_k = start_mx_scale_quant + offs_scale_quant < tl.cdiv(quant_dim, MXFP_BLOCK_SIZE)
    full_scale_mask = scale_mask_k & mask_n

    src_tensor_offsets = offs_src_quant * stride_src_quant + offs_outer * stride_src_outer
    mx_scale_offsets = offs_scale_quant * stride_mx_scale_quant + offs_outer * stride_mx_scale_outer
    mx_tensor_offsets = offs_mxt_quant * stride_mxt_quant + offs_outer * stride_mxt_outer
    src_tensor = tl.load(src_ptr + src_tensor_offsets, mask=full_mask_src)

    out_tensor, scale_tensor = _compute_quant_and_scale(src_tensor, full_mask_src, mx_tensor_dtype,
                                                        DEQUANT_SCALE_ROUNDING_MODE)

    tl.store(mx_scale_ptr + mx_scale_offsets, scale_tensor, mask=full_scale_mask)
    tl.store(mx_tensor_ptr + mx_tensor_offsets, out_tensor, mask=full_mask_mxt)


@triton.jit(repr=lambda _: "_dequantize_mxfp8")
def _dequantize_mxfp8_fn(input, mask, pid=None):
    return _compute_quant_and_scale(input, mask, tl.float8e4nv)
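For reference, a pure-Python decode of one e2m1 nibble (1 sign bit, 2 exponent bits, 1 mantissa bit, exponent bias 1), the format the fp4 packing branch above produces. This is illustration only and is not used by the kernels.

def decode_e2m1(nibble: int) -> float:
    """Decode one 4-bit e2m1 value to a float."""
    sign = -1.0 if (nibble & 0x8) else 1.0
    exp = (nibble >> 1) & 0x3
    man = nibble & 0x1
    if exp == 0:                       # subnormal: 0 or +-0.5
        return sign * 0.5 * man
    return sign * (1.0 + 0.5 * man) * 2.0 ** (exp - 1)

# Representable magnitudes: 0, 0.5, 1, 1.5, 2, 3, 4, 6 -- consistent with _get_max_quant_val() == 6.0
print(sorted({abs(decode_e2m1(n)) for n in range(16)}))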
numerics_details/mxfp_details/_upcast_from_mxfp.py
ADDED
@@ -0,0 +1,126 @@
import triton
import triton.language as tl
from ._downcast_to_mxfp import MXFP_BLOCK_SIZE


# fmt: off
@triton.jit
def _upcast_from_mxfp(out_ptr, stride_o_outer, stride_o_quant: tl.constexpr, mx_scale_ptr, stride_scale_outer,
                      stride_scale_quant, mx_tensor_ptr, stride_tensor_outer, stride_tensor_quant: tl.constexpr,
                      outer_dim, quant_dim, BLOCK_SIZE_OUT_DIM: tl.constexpr, BLOCK_SIZE_QUANT_DIM: tl.constexpr):

    tl.static_assert(stride_o_quant == 1, "the weight must be contiguous in the k dimension for mx")
    tl.static_assert(BLOCK_SIZE_QUANT_DIM % MXFP_BLOCK_SIZE == 0, "BLOCK_SIZE_K must be a multiple of 32")
    # uint8 signifies two fp4 e2m1 values packed into a single byte
    mx_tensor_dtype: tl.constexpr = mx_tensor_ptr.dtype.element_ty
    dst_dtype: tl.constexpr = out_ptr.dtype.element_ty
    tl.static_assert(dst_dtype == tl.float16 or dst_dtype == tl.bfloat16)
    tl.static_assert(
        mx_tensor_dtype == tl.uint8
        or ((mx_tensor_dtype == tl.float8e4nv or mx_tensor_dtype == tl.float8e5) or mx_tensor_dtype == dst_dtype),
        "mx_tensor_ptr must be uint8 or float8 or dst_dtype")
    tl.static_assert(mx_scale_ptr.dtype.element_ty == tl.uint8, "mx_scale_ptr must be uint8")

    # Determine if we are dealing with fp8 types.
    is_fp4: tl.constexpr = mx_tensor_dtype == tl.uint8
    is_fp8: tl.constexpr = mx_tensor_dtype == tl.float8e4nv or mx_tensor_dtype == tl.float8e5
    K_DIVISOR: tl.constexpr = 2 if is_fp4 else 1
    BLOCK_SIZE_QUANT_MX_SCALE: tl.constexpr = BLOCK_SIZE_QUANT_DIM // MXFP_BLOCK_SIZE
    BLOCK_SIZE_QUANT_MX_TENSOR: tl.constexpr = BLOCK_SIZE_QUANT_DIM // K_DIVISOR

    # Compute starting indices for the quantized (packed) dimension and the outer dimension.
    outer_block = tl.program_id(0).to(tl.int64)
    quant_block = tl.program_id(1).to(tl.int64)

    start_mxt_quant = quant_block * BLOCK_SIZE_QUANT_MX_TENSOR
    start_out_quant = quant_block * BLOCK_SIZE_QUANT_DIM
    start_mx_scale_quant = quant_block * BLOCK_SIZE_QUANT_MX_SCALE
    start_out = outer_block * BLOCK_SIZE_OUT_DIM

    mx_tensor_ptr += start_mxt_quant * stride_tensor_quant + start_out * stride_tensor_outer
    mx_scale_ptr += start_mx_scale_quant * stride_scale_quant + start_out * stride_scale_outer
    out_ptr += start_out * stride_o_outer + start_out_quant * stride_o_quant

    # Compute offsets and masks.
    offs_src_quant = tl.arange(0, BLOCK_SIZE_QUANT_MX_TENSOR)[None, :].to(tl.int64)
    offs_out_quant = tl.arange(0, BLOCK_SIZE_QUANT_DIM)[None, :].to(tl.int64)
    offs_outer = tl.arange(0, BLOCK_SIZE_OUT_DIM)[:, None].to(tl.int64)
    offs_scale = tl.arange(0, BLOCK_SIZE_QUANT_MX_SCALE)[None, :].to(tl.int64)

    mask_outer = start_out + offs_outer < outer_dim
    mask_out_quant = start_out_quant + offs_out_quant < quant_dim
    full_mask_out = mask_out_quant & mask_outer

    mask_src_quant = start_mxt_quant + offs_src_quant < tl.cdiv(quant_dim, K_DIVISOR)
    full_mask_src = mask_src_quant & mask_outer

    mask_scale = start_mx_scale_quant + offs_scale < tl.cdiv(quant_dim, MXFP_BLOCK_SIZE)
    full_scale_mask = mask_scale & mask_outer

    tensor_offsets = offs_src_quant * stride_tensor_quant + offs_outer * stride_tensor_outer
    scale_offsets = offs_scale * stride_scale_quant + offs_outer * stride_scale_outer
    out_offsets = offs_out_quant * stride_o_quant + offs_outer * stride_o_outer

    # Load the packed tensor and scale.
    tensor = tl.load(mx_tensor_ptr + tensor_offsets, mask=full_mask_src)
    scale = tl.load(mx_scale_ptr + scale_offsets, mask=full_scale_mask)

    # Upcast the scale to the destination type.
    if dst_dtype == tl.bfloat16:
        dst_scale = (scale.to(tl.uint16) << 7).to(dst_dtype, bitcast=True)
    else:
        tl.static_assert(dst_dtype == tl.float16)
        dst_scale = (scale.to(tl.uint32) << 23).to(tl.float32, bitcast=True)
        dst_scale = dst_scale.to(tl.float16)

    # Now upcast the tensor.
    if is_fp8:
        dst_tensor = tensor.to(dst_dtype)
        if mx_tensor_dtype == tl.float8e5:
            from_e_bits: tl.constexpr = 5
            from_m_bits: tl.constexpr = 2
            to_e_bits: tl.constexpr = 8 if dst_dtype == tl.bfloat16 else 5
            to_m_bits: tl.constexpr = 7 if dst_dtype == tl.bfloat16 else 10

            # Preserve infs and nans. FIXME Fp8E5M2_to_Bf16 doesn't preserve them!
            non_finite_mask_src: tl.constexpr = ((1 << from_e_bits) - 1) << from_m_bits
            non_finite_mask_dst: tl.constexpr = ((1 << to_e_bits) - 1) << to_m_bits
            dst_tensor = tl.where(
                (tensor.to(tl.uint8, bitcast=True) & non_finite_mask_src) == non_finite_mask_src,
                (dst_tensor.to(tl.uint16, bitcast=True) | non_finite_mask_dst).to(dst_dtype, bitcast=True),
                dst_tensor,
            )
    else:
        tl.static_assert(is_fp4)
        dst_bias: tl.constexpr = 127 if dst_dtype == tl.bfloat16 else 15
        dst_0p5: tl.constexpr = 16128 if dst_dtype == tl.bfloat16 else 0x3800
        dst_m_bits: tl.constexpr = 7 if dst_dtype == tl.bfloat16 else 10
        # e2m1
        em0 = tensor & 0x07
        em1 = tensor & 0x70
        x0 = (em0.to(tl.uint16) << (dst_m_bits - 1)) | ((tensor & 0x08).to(tl.uint16) << 12)
        x1 = (em1.to(tl.uint16) << (dst_m_bits - 5)) | ((tensor & 0x80).to(tl.uint16) << 8)
        # Three cases:
        # 1) x is normal and non-zero: Correct bias
        x0 = tl.where((em0 & 0x06) != 0, x0 + ((dst_bias - 1) << dst_m_bits), x0)
        x1 = tl.where((em1 & 0x60) != 0, x1 + ((dst_bias - 1) << dst_m_bits), x1)
        # 2) x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in the dst type
        x0 = tl.where(em0 == 0x01, dst_0p5 | (x0 & 0x8000), x0)
        x1 = tl.where(em1 == 0x10, dst_0p5 | (x1 & 0x8000), x1)
        # 3) x is zero, do nothing
        # Interleave x0,x1: use tl.where (ROCm tl.cat only supports 1D)
        idx_k = tl.arange(0, BLOCK_SIZE_QUANT_DIM)
        is_even = (idx_k % 2) == 0
        val_x0 = x0[:, idx_k // 2]
        val_x1 = x1[:, idx_k // 2]
        dst_tensor = tl.where(is_even[None, :], val_x0, val_x1).to(dst_dtype, bitcast=True)

    # dst_tensor already [M, K/32, 32] for fp4; scale was stored with 32-sized "inner" grouping
    dst_scale = dst_scale.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, 1])
    scale = scale.reshape(dst_scale.shape)

    out_tensor = dst_tensor * dst_scale
    # Correct any NaNs encoded via the scale.
    out_tensor = tl.where(scale == 0xFF, float("nan"), out_tensor)
    out_tensor = out_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM])
    tl.store(out_ptr + out_offsets, out_tensor, mask=full_mask_out)
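The uint8 mx_scale byte is an e8m0 shared exponent: shifting it into the float32 exponent field (scale << 23, then a bitcast), as the fp16 branch above does, is the same as computing 2**(scale - 127). A quick PyTorch check of that equivalence (0 is skipped because the bitcast gives 0.0 there, and 255 is the NaN marker handled separately in the kernel):

import torch

scale = torch.arange(1, 255, dtype=torch.uint8)
via_bits = (scale.to(torch.int32) << 23).view(torch.float32)
via_pow = torch.pow(2.0, scale.to(torch.float32) - 127.0)
print(torch.allclose(via_bits, via_pow))  # True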
numerics_details/mxfp_details/upcast_mxfp4.py
ADDED
@@ -0,0 +1,88 @@
"""
Reusable MXFP4 upcast: MXFP4 (uint8 mx_tensor + uint8 mx_scale) -> fp16/bf16.
Uses Triton kernel when available; falls back to PyTorch on ROCm (tl.cat limitation).
"""

import torch

from ._upcast_from_mxfp import _upcast_from_mxfp

try:
    from triton.compiler.errors import CompilationError
except ImportError:
    CompilationError = Exception

MXFP_BLOCK_SIZE_PY = 32


def _upcast_mxfp4_to_fp16_pytorch(
    mx_tensor: torch.Tensor, mx_scale: torch.Tensor, dtype: torch.dtype = torch.float16
) -> torch.Tensor:
    """PyTorch fallback (used when Triton kernel fails on ROCm)."""
    M, K_half = mx_tensor.shape
    K = K_half * 2
    dst_bias = 15
    dst_0p5 = 0x3800
    dst_m_bits = 10

    tensor = mx_tensor.to(torch.int32)
    em0 = tensor & 0x07
    em1 = tensor & 0x70
    x0 = (em0 << (dst_m_bits - 1)) | ((tensor & 0x08) << 12)
    x1 = (em1 << (dst_m_bits - 5)) | ((tensor & 0x80) << 8)

    x0 = torch.where((em0 & 0x06) != 0, x0 + ((dst_bias - 1) << dst_m_bits), x0)
    x1 = torch.where((em1 & 0x60) != 0, x1 + ((dst_bias - 1) << dst_m_bits), x1)
    x0 = torch.where(em0 == 0x01, torch.full_like(x0, dst_0p5) | (x0 & 0x8000), x0)
    x1 = torch.where(em1 == 0x10, torch.full_like(x1, dst_0p5) | (x1 & 0x8000), x1)

    out_u16 = torch.empty((M, K), device=mx_tensor.device, dtype=torch.uint16)
    out_u16[:, 0::2] = (x0 & 0xFFFF).to(torch.uint16)
    out_u16[:, 1::2] = (x1 & 0xFFFF).to(torch.uint16)
    dst_tensor = out_u16.view(dtype)

    scale_u32 = mx_scale.to(torch.int32) << 23
    dst_scale = scale_u32.view(torch.float32).to(dtype)
    dst_scale = dst_scale.unsqueeze(-1).repeat(1, 1, 32).reshape(M, K)

    out_tensor = dst_tensor * dst_scale
    out_tensor = torch.where(
        mx_scale.unsqueeze(-1).expand(-1, -1, 32).reshape(M, K) == 0xFF,
        float("nan"),
        out_tensor,
    )
    return out_tensor


def upcast_mxfp4_to_fp16(
    mx_tensor: torch.Tensor,
    mx_scale: torch.Tensor,
    block_m: int = 128,
    block_k: int = 64,
    dtype: torch.dtype = torch.float16,
    verbose: bool = False,
) -> torch.Tensor:
    """Convert MXFP4 [M,K/2]+[M,K/32] -> fp16/bf16 [M,K]. Falls back to PyTorch if Triton fails."""
    assert mx_tensor.dim() == 2 and mx_tensor.dtype == torch.uint8
    assert mx_scale.dim() == 2 and mx_scale.dtype == torch.uint8
    M = mx_tensor.shape[0]
    K = mx_tensor.shape[1] * 2
    assert mx_scale.shape == (M, K // 32)
    assert block_k % MXFP_BLOCK_SIZE_PY == 0

    try:
        out = torch.empty((M, K), device=mx_tensor.device, dtype=dtype)
        grid = ((M + block_m - 1) // block_m, (K + block_k - 1) // block_k)
        _upcast_from_mxfp[grid](
            out, out.stride(0), 1,
            mx_scale, mx_scale.stride(0), mx_scale.stride(1),
            mx_tensor, mx_tensor.stride(0), 1,
            M, K,
            BLOCK_SIZE_OUT_DIM=block_m,
            BLOCK_SIZE_QUANT_DIM=block_k,
        )
        return out
    except CompilationError:
        if verbose:
            print("Triton upcast failed (e.g. ROCm), using PyTorch fallback.")
        return _upcast_mxfp4_to_fp16_pytorch(mx_tensor, mx_scale, dtype)
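A CPU-friendly sanity check of the fallback path (calling the private helper directly, since the public wrapper would try the Triton kernel first): the byte 0x22 packs two e2m1 values of 1.0, and a scale byte of 127 is 2**0, so every output element should come out as 1.0. K is kept at 64 so the fallback's per-32-element scale expansion lines up.

import torch
from numerics_details.mxfp_details.upcast_mxfp4 import _upcast_mxfp4_to_fp16_pytorch

M, K = 1, 64
mx_tensor = torch.full((M, K // 2), 0x22, dtype=torch.uint8)   # every nibble encodes 1.0
mx_scale = torch.full((M, K // 32), 127, dtype=torch.uint8)    # e8m0 scale of 2**0
out = _upcast_mxfp4_to_fp16_pytorch(mx_tensor, mx_scale)
print(out.unique())  # tensor([1.], dtype=torch.float16)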
submission.py
ADDED
@@ -0,0 +1,37 @@
import torch
import triton
import triton.language as tl

# 1. HARDWARE DIAGNOSTICS
def check_environment():
    print(f"--- Environment Check ---")
    cuda_avail = torch.cuda.is_available()
    print(f"Is CUDA/ROCm available? {cuda_avail}")

    if cuda_avail:
        device_name = torch.cuda.get_device_name(0)
        print(f"GPU Detected: {device_name}")

        # Check for Blackwell (SM 10.0) or MI300X (gfx942)
        prop = torch.cuda.get_device_properties(0)
        if hasattr(prop, 'major'):
            print(f"Compute Capability: {prop.major}.{prop.minor}")
    else:
        print("No NVIDIA/AMD GPU detected. Triton kernels will not run on this hardware.")
    print(f"-------------------------\n")

# Call diagnostic immediately on import
check_environment()

# 2. PLACEHOLDER KERNEL (Logic from previous steps)
@triton.jit
def dual_gemm_kernel(a_ptr, b1_ptr, b2_ptr, c_ptr, M, N, K):
    # Kernel code here... (add BLOCK_M/BLOCK_N/BLOCK_K as tl.constexpr parameters;
    # @triton.jit does not accept a **kwargs catch-all)
    pass

# 3. HARNESS INTERFACE
def dual_gemm_submission(data):
    # This is what the leaderboard/benchmark calls
    a, b1, b2, sfa, sfb1, sfb2, c = data
    # ... launch logic ...
    return c
test_mxfp.py
ADDED
@@ -0,0 +1,87 @@
#!/usr/bin/env python3
"""
Minimal test for MXFP _downcast_to_mxfp on MI300X (ROCm).
Tests fp16 -> uint8 (fp4 packed) path; float8 path may not work on ROCm yet.
Run on remote: cd /root/kernels && python test_mxfp.py
"""

import sys
import os

# Allow imports from /root/kernels
_script_dir = os.path.dirname(os.path.abspath(__file__))
if _script_dir not in sys.path:
    sys.path.insert(0, _script_dir)

def main():
    print("=== MXFP Import Test ===")
    try:
        from numerics_details.mxfp_details._downcast_to_mxfp import (
            _downcast_to_mxfp,
            _compute_quant_and_scale,
            MXFP_BLOCK_SIZE,
        )
        print(" Import OK: _downcast_to_mxfp, _compute_quant_and_scale, MXFP_BLOCK_SIZE")
    except Exception as e:
        print(f" Import FAILED: {e}")
        return 1

    print("\n=== Triton + CUDA/ROCm Check ===")
    import torch
    if not torch.cuda.is_available():
        print(" No GPU available. Skipping kernel test.")
        return 0
    print(f" Device: {torch.cuda.get_device_name(0)}")

    import triton
    import triton.language as tl

    print("\n=== MXFP Downcast Test (fp16 -> fp4 uint8) ===")
    # Use fp4 path (uint8 output) - avoids float8 dtypes which may lack ROCm support
    BLOCK_SIZE_OUT_DIM = 64
    BLOCK_SIZE_QUANT_DIM = 64  # must be multiple of 32
    outer_dim = 128
    quant_dim = 128
    DEQUANT_SCALE_ROUNDING_MODE = 0

    device = "cuda"
    src = torch.randn(outer_dim, quant_dim, device=device, dtype=torch.float16) * 0.1

    # Output shapes for fp4 (uint8): mx_tensor [outer, quant//2], mx_scale [outer, quant//32]
    mx_tensor = torch.empty(outer_dim, quant_dim // 2, device=device, dtype=torch.uint8)
    mx_scale = torch.empty(outer_dim, quant_dim // 32, device=device, dtype=torch.uint8)

    num_outer_blocks = (outer_dim + BLOCK_SIZE_OUT_DIM - 1) // BLOCK_SIZE_OUT_DIM
    num_quant_blocks = (quant_dim + BLOCK_SIZE_QUANT_DIM - 1) // BLOCK_SIZE_QUANT_DIM
    grid = (num_outer_blocks, num_quant_blocks)

    try:
        _downcast_to_mxfp[grid](
            mx_tensor,
            mx_tensor.stride(0), 1,  # packed-output strides (not the source strides)
            mx_scale,
            mx_scale.stride(0), mx_scale.stride(1),
            src,
            src.stride(0), src.stride(1),
            outer_dim, quant_dim,
            BLOCK_SIZE_OUT_DIM=BLOCK_SIZE_OUT_DIM,
            BLOCK_SIZE_QUANT_DIM=BLOCK_SIZE_QUANT_DIM,
            DEQUANT_SCALE_ROUNDING_MODE=DEQUANT_SCALE_ROUNDING_MODE,
        )
        torch.cuda.synchronize()
        print(" Kernel launch OK")
        print(f" mx_tensor shape: {mx_tensor.shape}, dtype: {mx_tensor.dtype}")
        print(f" mx_scale shape: {mx_scale.shape}")
        print(f" mx_tensor sample (first row): {mx_tensor[0, :8].tolist()}")
    except Exception as e:
        print(f" Kernel FAILED: {e}")
        import traceback
        traceback.print_exc()
        return 1

    print("\n=== Done ===")
    return 0


if __name__ == "__main__":
    sys.exit(main())
testing.py
ADDED
@@ -0,0 +1,206 @@
"""
Testing utilities matching triton-kernels/build/torch-cuda/testing.py.
https://huggingface.co/kernels-community/triton-kernels/blob/main/build/torch-cuda/testing.py
"""

import enum
import functools
import os
import subprocess
import sys

import torch

# Numerics constants - use triton_kernels.numerics when available
try:
    from .numerics import MAX_FINITE_FLOAT8E4B8, MAX_FINITE_FLOAT8E4NV, MAX_FINITE_FLOAT8E5
except ImportError:
    # Standalone fallback: standard float8 max finite values
    MAX_FINITE_FLOAT8E5 = 57344.0    # float8 e5m2
    MAX_FINITE_FLOAT8E4NV = 448.0    # float8 e4m3fn
    MAX_FINITE_FLOAT8E4B8 = 240.0    # float8 e4m3fnuz


def assert_equal(ref, tri):
    if isinstance(ref, torch.Tensor):
        assert torch.all(ref == tri)
    else:
        assert ref == tri


def assert_close(ref, tri, maxtol=None, rmstol=None, description="--", verbose=True):
    """
    Compare reference values against obtained values.
    """
    if tri.dtype.itemsize == 1:
        ref_as_type = ref.to(tri.dtype)
        if ref.dtype == tri.dtype:
            assert torch.all(ref_as_type == tri)
            return
        ref = ref_as_type

    if maxtol is None:
        maxtol = 2e-2
    if rmstol is None:
        rmstol = 4e-3

    # cast to float32:
    ref = ref.to(torch.float32).detach()
    tri = tri.to(torch.float32).detach()
    assert ref.shape == tri.shape, f"Tensors must have same size {ref.shape=} {tri.shape=}"

    # deal with infinite elements:
    inf_mask_ref = torch.isinf(ref)
    inf_mask_tri = torch.isinf(tri)
    assert torch.equal(inf_mask_ref, inf_mask_tri), "Tensor must have same infinite elements"
    refn = torch.where(inf_mask_ref, 0, ref)
    trin = torch.where(inf_mask_tri, 0, tri)

    # normalise so that RMS calculation doesn't overflow:
    eps = 1.0e-30
    multiplier = 1.0 / (torch.max(torch.abs(refn)) + eps)
    refn *= multiplier
    trin *= multiplier

    ref_rms = torch.sqrt(torch.square(refn).mean()) + eps

    rel_err = torch.abs(refn - trin) / torch.maximum(ref_rms, torch.abs(refn))
    max_err = torch.max(rel_err).item()
    rms_err = torch.sqrt(torch.square(rel_err).mean()).item()

    if verbose:
        print("%s maximum relative error = %s (threshold = %s)" % (description, max_err, maxtol))
        print("%s RMS relative error = %s (threshold = %s)" % (description, rms_err, rmstol))

    if max_err > maxtol:
        bad_idxs = torch.nonzero(rel_err > maxtol)
        num_nonzero = bad_idxs.size(0)
        bad_idxs = bad_idxs[:1000]
        print("%d / %d mismatched elements (shape = %s) at coords %s" %
              (num_nonzero, rel_err.numel(), tuple(rel_err.shape), bad_idxs.tolist()))

        bad_idxs = bad_idxs.unbind(-1)
        print("ref values: ", ref[tuple(bad_idxs)].cpu())
        print("tri values: ", tri[tuple(bad_idxs)].cpu())

    assert max_err <= maxtol
    assert rms_err <= rmstol


class ComputeSanitizerTool(enum.Enum):
    MEMCHECK = "memcheck"
    RACECHECK = "racecheck"
    SYNCCHECK = "synccheck"
    INITCHECK = "initcheck"


def compute_sanitizer(**target_kwargs):
    """
    Decorator to run a test with compute sanitizer enabled and pytorch caching allocator disabled,
    to expose potential memory access errors.
    This decorator requires the `request` fixture to be present.
    If `run_sanitizer` argument is present and set to False, the sanitizer is not run.
    Running tests under compute sanitizer requires launching a subprocess and is slow,
    so use sparingly.
    """

    def decorator(test_fn):

        @functools.wraps(test_fn)
        def wrapper(*args, **kwargs):
            if os.environ.get("SKIP_COMPUTE_SANITIZER") == "1":
                test_fn(*args, **kwargs)
                return

            import psutil

            if target_kwargs.pop("clear_torch_cache", False):
                # If we don't pop clear_torch_cache, it won't pass the
                # target_kwargs.items() <= kwargs.items() condition below.
                torch.cuda.empty_cache()
            tools_to_check = target_kwargs.pop("tools_to_check", [ComputeSanitizerTool.MEMCHECK])
            assert isinstance(tools_to_check, list), f"{tools_to_check=}"
            assert all(tool in ComputeSanitizerTool for tool in tools_to_check), (
                f"{(tool for tool in tools_to_check if tool not in ComputeSanitizerTool)=}")

            ppid_name = psutil.Process(os.getppid()).exe()
            run_compute_sanitizer = target_kwargs.items() <= kwargs.items()
            if "run_sanitizer" in kwargs:
                run_compute_sanitizer &= kwargs["run_sanitizer"]
            if run_compute_sanitizer and "compute-sanitizer" not in ppid_name:
                for tool in tools_to_check:
                    path = os.path.realpath(test_fn.__globals__["__file__"])
                    # get path of current file
                    env = {
                        "PATH": os.environ["PATH"],
                        "PYTORCH_NO_CUDA_MEMORY_CACHING": "1",
                        "TORCH_SHOW_CPP_STACKTRACES": "1",
                        "CUDA_LAUNCH_BLOCKING": "1",
                    }
                    if "CUDA_VISIBLE_DEVICES" in os.environ:
                        env["CUDA_VISIBLE_DEVICES"] = os.environ["CUDA_VISIBLE_DEVICES"]
                    assert "request_fixture" in kwargs, (
                        "memcheck'ed test must have a (possibly unused) `request` fixture")
                    test_id = kwargs["request_fixture"].node.callspec.id
                    cmd = f"{path}::{test_fn.__name__}[{test_id}]"
                    cmd = [
                        "compute-sanitizer",
                        "--target-processes=application-only",
                        "--destroy-on-device-error=context",
                        f"--tool={tool.value}",
                        sys.executable,
                        "-m",
                        "pytest",
                        "-vsx",
                        cmd,
                    ]
                    for opt in ["--update_checksum", "--ignore_checksum_error"]:
                        if opt in sys.argv:
                            cmd.append(opt)
                    out = subprocess.run(
                        cmd,
                        stdout=subprocess.PIPE,
                        stderr=subprocess.STDOUT,
                        env=env,
                    )
                    sanitizer_ok = "ERROR SUMMARY: 0 errors" in str(
                        out.stdout) or "RACECHECK SUMMARY: 0 hazards displayed" in str(out.stdout)
                    test_output = out.stdout
                    if type(test_output) is bytes:
                        test_output = test_output.decode()

                    fail = False
                    if not sanitizer_ok:
                        print("compute-sanitizer returned an error")
                        fail = True
                    elif out.returncode != 0:
                        print(
                            "The test failed due to some other reason: consider running without compute-sanitizer to verify."
                        )
                        print(f"{out.returncode=}")
                        fail = True

                    if fail:
                        print("*****************************************************")
                        print("******************** TEST OUTPUT ********************")
                        print("*****************************************************")
                        print(test_output)
                        print("*****************************************************")
                        print("****************** TEST OUTPUT END ******************")
                        print("*****************************************************")
                        assert None
            else:
                test_fn(*args, **kwargs)

        return wrapper

    return decorator


def compute_actual_scale(x, dtype):
    max_finite = {
        torch.float8_e5m2: MAX_FINITE_FLOAT8E5,
        torch.float8_e4m3fn: MAX_FINITE_FLOAT8E4NV,
        torch.float8_e4m3fnuz: MAX_FINITE_FLOAT8E4B8,
    }[dtype]
    return x.abs().max() / max_finite
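A minimal usage sketch for assert_close() above, comparing a kernel output against a PyTorch reference with the default tolerances (max relative error 2e-2, RMS 4e-3):

import torch
from testing import assert_close

ref = torch.randn(128, 128)
tri = ref + 1e-4 * torch.randn_like(ref)   # stand-in for a Triton kernel's output
assert_close(ref, tri, description="dual_gemm vs torch reference")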