Compare commits


6 Commits

Author SHA1 Message Date
32632d57aa tc 2025-10-28 14:51:22 -07:00
551921d484 Change t.is_cuda to t.device.type == 'cuda' in torch/utils/viz (#156418)
Fixes #156417

Unlike `.is_cuda`, the `.device` property is supported by `ShardedTensor` (a minimal sketch of the pattern follows the commit list).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156418
Approved by: https://github.com/mikaylagawarecki

Co-authored-by: Alexander Zhipa <azzhipa@amazon.com>
2025-10-28 20:34:14 +00:00
b5189e269e NVFP4 grouped gemm support via FBGEMM kernels (#166308)
Summary:

* Add NVFP4 (1x16 block e4m3, tensor-wise fp32) scaled grouped gemm
* Extend testing to add nvfp4 support

Test Plan:

```
pytest -svv -k grouped test/test_scaled_matmul_cuda.py
```

Signed-off-by: Simon Layton <simonlayton@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166308
Approved by: https://github.com/ngimel
2025-10-28 20:32:53 +00:00
3895ce093f [inductor] add in-kernel nan-check (#166008)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166008
Approved by: https://github.com/eellison
2025-10-28 20:19:10 +00:00
8aa087a29d [ez] Fix print for failing test when entire file fails (#166420)
Was previously printing "FAILED CONSISTENTLY: ul" since the value was null.
This changes it to print the test_file instead, by moving some of the logic that checks for this earlier.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166420
Approved by: https://github.com/Skylion007
2025-10-28 20:13:58 +00:00
7379972cc0 Revert "[Inductor] Naive foreach autotune support (#162053)"
This reverts commit cdb60e44eb528bf02c6bb2d7e384298283e755ca.

Reverted https://github.com/pytorch/pytorch/pull/162053 on behalf of https://github.com/xmfan due to Compile time regression ([comment](https://github.com/pytorch/pytorch/pull/162053#issuecomment-3458252331))
2025-10-28 20:01:54 +00:00
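
A minimal sketch of the device-check pattern adopted in commit 551921d484 (the actual change is in the torch/utils/viz diff at the bottom of this compare). The helper name and tensors here are illustrative only; the point is that `t.device.type == "cuda"` works on wrappers such as `ShardedTensor` that expose `.device` but not `.is_cuda`.

```
import torch

def is_cuda_like(t) -> bool:
    # Illustrative helper: prefer `.device.type` over `.is_cuda`, since some
    # tensor wrappers (e.g. ShardedTensor) support `.device` but not `.is_cuda`.
    return isinstance(t, torch.Tensor) and t.device.type == "cuda"

x = torch.ones(2)                      # CPU tensor
print(is_cuda_like(x))                 # False
if torch.cuda.is_available():
    print(is_cuda_like(x.to("cuda")))  # True
```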
12 changed files with 491 additions and 211 deletions

View File

@ -260,7 +260,7 @@ IF(USE_FBGEMM_GENAI)
if(USE_CUDA)
# To avoid increasing the build time/binary size unnecessarily, use an allow-list of kernels to build.
# If you want to integrate a kernel from FBGEMM into torch, you have to add it here.
set(FBGEMM_CUTLASS_KERNELS_REGEX ".*mx8mx8bf16_grouped.*")
set(FBGEMM_CUTLASS_KERNELS_REGEX ".*(mx8mx8bf16_grouped|f4f4bf16_grouped).*")
file(GLOB_RECURSE fbgemm_genai_native_cuda_cu
"${FBGEMM_GENAI_SRCS}/cutlass_extensions/*.cu"
"${FBGEMM_GENAI_SRCS}/cutlass_extensions/**/*.cu")
@ -291,6 +291,7 @@ IF(USE_FBGEMM_GENAI)
set(fbgemm_genai_cuh
"${FBGEMM_GENAI_SRCS}/cutlass_extensions/mx8mx8bf16_grouped/"
"${FBGEMM_GENAI_SRCS}/cutlass_extensions/f4f4bf16_grouped/"
"${FBGEMM_GENAI_SRCS}/"
)

View File

@ -208,6 +208,48 @@ _f8_f8_bf16_rowwise_grouped_mm(
#endif
}
Tensor&
_f4_f4_bf16_grouped_mm_fbgemm(
const Tensor& mat_a,
const Tensor& mat_b,
const Tensor& scale_a,
const Tensor& global_scale_a,
const Tensor& scale_b,
const Tensor& global_scale_b,
const std::optional<Tensor>& offs,
const std::optional<Tensor>& bias,
Tensor& out) {
#if !defined(USE_ROCM) && defined(USE_FBGEMM_GENAI)
// Typing checks
TORCH_CHECK_VALUE(mat_a.scalar_type() == at::kFloat4_e2m1fn_x2,
"mat_a must be Float4_e2n1fn_2, got: ", mat_a.scalar_type());
TORCH_CHECK_VALUE(mat_b.scalar_type() == at::kFloat4_e2m1fn_x2,
"mat_b must be Float4_e2n1fn_2, got: ", mat_b.scalar_type());
TORCH_CHECK_VALUE(scale_a.scalar_type() == at::kFloat8_e4m3fn,
"scale_a must be Float8_e4m3fn, got: ", scale_a.scalar_type());
TORCH_CHECK_VALUE(scale_b.scalar_type() == at::kFloat8_e4m3fn,
"scale_b must be Float8_e4m3fn, got: ", scale_b.scalar_type());
TORCH_CHECK_VALUE(global_scale_a.scalar_type() == at::kFloat,
"global_scale_a must be Float, got: ", global_scale_a.scalar_type());
TORCH_CHECK_VALUE(global_scale_b.scalar_type() == at::kFloat,
"global_scale_b must be Float, got: ", global_scale_b.scalar_type());
auto o = fbgemm_gpu::f4f4bf16_grouped_mm(
mat_a,
mat_b,
scale_a,
scale_b,
offs.value(),
out,
global_scale_a.mul(global_scale_b)
);
#else
TORCH_CHECK_NOT_IMPLEMENTED(false, "nvfp4 grouped gemm requires USE_FBGEMM_GENAI and is only supported on CUDA");
#endif
return out;
}
void _check_scales_fp8_rowwise(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx, const int scale_multiplier=1) {
// Checks scales for 2d or 3d target tensors (`mat`).
if (mat.dim() == 2) {
@ -245,7 +287,15 @@ void _check_scales_fp8_rowwise(const Tensor& mat, const Tensor& scale, const int
}
}
void _check_scales_mxfp8(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx) {
void _check_scales_blocked(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx) {
// if {mx,nv}fp4, will need to modify K later
bool is_fp4 = (mat.scalar_type() == kFloat4_e2m1fn_x2);
int blocksize = 32;
// check for nvfp4 vs. mxfp4 to fix blocksize
if (is_fp4 && scale.scalar_type() == kFloat8_e4m3fn) {
blocksize = 16;
}
// Checks scales for 2d or 3d target tensors (`mat`).
if (mat.dim() == 2) {
// For MXFP8, 2d tensors have variable size groups represented as subtensors,
@ -253,17 +303,19 @@ void _check_scales_mxfp8(const Tensor& mat, const Tensor& scale, const int dim,
// so we can't check the scale sizes without doing a d2h sync to get the group sizes here.
TORCH_CHECK(
scale.dim() == mat.dim(),
"for mxfp8, scale must have same number of dimensions as parent tensor, but got mat.dim() = ", mat.dim(), " and scale.dim() = ", scale.dim(), " for arg ", arg_idx);
"for block-scaled, scale must have same number of dimensions as parent tensor, but got mat.dim() = ", mat.dim(),
" and scale.dim() = ", scale.dim(), " for arg ", arg_idx
);
// LHS mat shape (M, total_K) -> scale shape (rounded_up(M, 128), rounded_up_per_group(K/32, 4))
// RHS mat shape (total_K, N) -> scale shape (rounded_up(N, 128), rounded_up_per_group(K/32, 4))
// LHS mat shape (M, total_K) -> scale shape (rounded_up(M, 128), rounded_up_per_group(K/blocksize, 4))
// RHS mat shape (total_K, N) -> scale shape (rounded_up(N, 128), rounded_up_per_group(K/blocksize, 4))
// * weight is transposed prior to the call, scale stays non-transposed.
bool LHS = arg_idx == 0;
int scale_dim_to_check = 0;
int mat_dim_to_check = LHS ? 0 : 1;
TORCH_CHECK(
scale.size(scale_dim_to_check) >= mat.size(mat_dim_to_check),
"for mxfp8, arg ", arg_idx, " tensor shape (", mat.size(0), ", ", mat.size(1), ") ",
"for block-scaled, arg ", arg_idx, " tensor shape (", mat.size(0), ", ", mat.size(1), ") ",
"must have scale.shape[", scale_dim_to_check, "] >= ", mat.size(mat_dim_to_check), " but got scale.shape=(", scale.size(0), ", ", scale.size(1), ")");
} else {
// For MXFP8, 3d tensors have static group sizes (stack of 2d tensors),
@ -273,32 +325,40 @@ void _check_scales_mxfp8(const Tensor& mat, const Tensor& scale, const int dim,
};
// TODO: this is for 3d tensor in 2d-3d case specifically.
// We'll need to support 3d-3d and 3d-2d cases once mxfp8 grouped gemm supports them.
// We'll need to support 3d-3d and 3d-2d cases once mxfp8/nvfp4 grouped gemm supports them.
int64_t G = mat.size(0);
int64_t K = mat.size(1);
if (is_fp4) {
// FP4 packs 2 values into a single 8b word - the "real" K is 2x the
// reported K. Reverse that adjustment.
const int fp4_elems_per_byte = 2;
K *= fp4_elems_per_byte;
}
int64_t N = mat.size(2);
int64_t blocked_scale_K = round_up(K/32, 4);
int64_t blocked_scale_K = round_up(K/blocksize, 4);
int64_t blocked_scale_N = round_up(N, 128);
// fbgemm expects stack of flattened blocked scales for 3d tensor, shape (G, blocked_scale_K * blocked_scale_N).
TORCH_CHECK(
scale.dim() == mat.dim() - 1,
"for mxfp8 2d-3d grouped GEMM, the 3d tensor of shape (G,K,N) must have a 2d scale of shape (G, blocked_scale_K * blocked_scale_N), but scale is ", scale.dim(), "D for arg ", arg_idx
"for block-scaled 2d-3d grouped GEMM, the 3d tensor of shape (G,K,N) must have a 2d scale of shape (G, blocked_scale_K * blocked_scale_N),",
"but scale is ", scale.dim(), "D for arg ", arg_idx
);
TORCH_CHECK(
scale.size(0) == G && scale.size(1) == blocked_scale_K * blocked_scale_N,
"for mxfp8, the tensor shape (", G, ", ", K, ", ", N, ") must have scale shape (", G, ",", blocked_scale_K, ",", blocked_scale_N, ") for arg ", arg_idx
"for block-scaled grouped GEMM, the tensor shape (", G, ", ", K, ", ", N, ") must have scale shape (", G, ",", blocked_scale_K, ",", blocked_scale_N, ")",
" for arg ", arg_idx, ", got: ", scale.size(0), ", ", scale.size(1)
);
}
}
void check_scale(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx, const int scale_multiplier=1) {
bool using_fp8_rowwise = scale.scalar_type() == kFloat;
bool using_mxfp8 = scale.scalar_type() == at::kFloat8_e8m0fnu;
bool using_mx = scale.scalar_type() == at::kFloat8_e8m0fnu;
if (using_fp8_rowwise) {
_check_scales_fp8_rowwise(mat, scale, dim, arg_idx, scale_multiplier);
} else if (using_mxfp8) {
_check_scales_mxfp8(mat, scale, dim, arg_idx);
} else if (using_mx) {
_check_scales_blocked(mat, scale, dim, arg_idx);
} else {
TORCH_CHECK(false, "scale must be float32 or float8_e8m0fnu, but got ", scale.dtype());
}
@ -411,9 +471,10 @@ namespace {
using acceptance_fn = std::function<bool(c10::ScalarType, std::vector<ScalingType>&, ArrayRef<Tensor>&, c10::ScalarType, std::vector<ScalingType>&, ArrayRef<Tensor>&)>;
std::array<std::tuple<std::string, acceptance_fn, ScaledGemmImplementation>, 2> scale_grouped_kernel_dispatch = {{
std::array<std::tuple<std::string, acceptance_fn, ScaledGemmImplementation>, 3> scale_grouped_kernel_dispatch = {{
{ "rowwise_rowwise", scaled_blas::check_rowwise_recipe, ScaledGemmImplementation::ROWWISE_ROWWISE},
{ "mxfp8_mxfp8", scaled_blas::check_mxfp8_recipe, ScaledGemmImplementation::MXFP8_MXFP8}}};
{ "mxfp8_mxfp8", scaled_blas::check_mxfp8_recipe, ScaledGemmImplementation::MXFP8_MXFP8},
{ "nvfp4_nvfp4", scaled_blas::check_nvfp4_recipe, ScaledGemmImplementation::NVFP4_NVFP4}}};
} // anonymous namespace
@ -525,8 +586,9 @@ _scaled_grouped_mm_cuda_v2(
out);
}
case ScaledGemmImplementation::MXFP8_MXFP8: {
_check_scales_mxfp8(mat_a, scale_a[0], 0 /* dim */, 0 /* arg_idx */);
_check_scales_mxfp8(mat_b, scale_b[0], 1 /* dim */, 1 /* arg_idx */);
// scale shape checks
_check_scales_blocked(mat_a, scale_a[0], 0 /* dim */, 0 /* arg_idx */);
_check_scales_blocked(mat_b, scale_b[0], 1 /* dim */, 1 /* arg_idx */);
return _mx8_mx8_bf16_grouped_mm_fbgemm(
mat_a,
mat_b,
@ -537,6 +599,21 @@ _scaled_grouped_mm_cuda_v2(
offs.value(),
out);
}
case ScaledGemmImplementation::NVFP4_NVFP4: {
// scale shape checks
_check_scales_blocked(mat_a, scale_a[0], 0 /* dim */, 0 /* arg_idx */);
_check_scales_blocked(mat_b, scale_b[0], 1 /* dim */, 1 /* arg_idx */);
return _f4_f4_bf16_grouped_mm_fbgemm(
mat_a,
mat_b,
scale_a[0], /* block-scale A */
scale_a[1], /* global-scale A */
scale_b[0], /* block-scale B */
scale_b[1], /* global-scale B */
offs.value(),
std::nullopt, /* bias */
out);
}
default:
TORCH_CHECK_NOT_IMPLEMENTED(false,
"_scaled_grouped_mm_cuda_v2 is in an inconsistent state - should never reach here");

View File

@ -14269,6 +14269,22 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
self.assertTrue("'enable_fp_fusion': False" in code)
torch.testing.assert_close(out, fn(a, b), atol=0, rtol=0)
@requires_cuda_and_triton
@config.patch(runtime_triton_nan_asserts=True)
def test_nan_assert_inside_triton_kernel(self):
def fn(x):
x = x - 1
# Uncommenting the following line triggers the failure of
# the device-side assertion
# x = torch.log(x)
return torch.where(x.isnan(), 3.14, x)
compiled = torch.compile(fn)
x = torch.randn(4096, device=GPU_TYPE)
out, (code,) = run_and_get_code(compiled, x)
self.assertTrue("'NaN or Inf found'" in code)
torch.testing.assert_close(out, fn(x))
@skip_if_cpp_wrapper("skip cpp wrapper")
@requires_cuda_and_triton
def test_repeat_interleave_decomposition_has_clamp(self):

View File

@ -27,7 +27,6 @@ import torch
import torch.distributed as dist
from torch.multiprocessing import current_process, get_context
from torch.testing._internal.common_utils import (
get_report_dir,
get_report_path,
IS_CI,
IS_MACOS,
@ -35,6 +34,7 @@ from torch.testing._internal.common_utils import (
set_cwd,
shell,
TEST_CUDA,
TEST_SAVE_XML,
TEST_WITH_ASAN,
TEST_WITH_ROCM,
TEST_WITH_SLOW_GRADCHECK,
@ -529,14 +529,6 @@ def run_test(
replacement = {"-f": "-x", "-dist=loadfile": "--dist=loadfile"}
unittest_args = [replacement.get(arg, arg) for arg in unittest_args]
xml_report_dir = get_report_dir(test_file, None, options.pytest)
if is_cpp_test:
unittest_args.append(
f"--junit-xml-reruns={get_report_path(xml_report_dir, test_file)}"
)
else:
unittest_args.append(f"--save-xml={xml_report_dir}")
if options.showlocals:
if options.pytest:
unittest_args.extend(["--showlocals", "--tb=long", "--color=yes"])
@ -763,6 +755,8 @@ def run_test_retries(
REPO_ROOT / ".pytest_cache/v/cache/stepcurrent" / stepcurrent_key
) as f:
current_failure = f.read()
if current_failure == "null":
current_failure = f"'{test_file}'"
except FileNotFoundError:
print_to_file(
"No stepcurrent file found. Either pytest didn't get to run (e.g. import error)"
@ -799,8 +793,6 @@ def run_test_retries(
print_to_file("Retrying single test...")
print_items = [] # do not continue printing them, massive waste of space
if "null" in num_failures:
num_failures[f"'{test_file}'"] = num_failures.pop("null")
consistent_failures = [x[1:-1] for x in num_failures.keys() if num_failures[x] >= 3]
flaky_failures = [x[1:-1] for x in num_failures.keys() if 0 < num_failures[x] < 3]
if len(flaky_failures) > 0:
@ -1234,6 +1226,12 @@ def get_pytest_args(options, is_cpp_test=False, is_distributed_test=False):
# is much slower than running them directly
pytest_args.extend(["-n", str(NUM_PROCS)])
if TEST_SAVE_XML:
# Add the option to generate XML test report here as C++ tests
# won't go into common_utils
test_report_path = get_report_path(pytest=True)
pytest_args.extend(["--junit-xml-reruns", test_report_path])
if options.pytest_k_expr:
pytest_args.extend(["-k", options.pytest_k_expr])

View File

@ -46,6 +46,7 @@ from torch.testing._internal.common_quantized import (
_floatx_unpacked_to_f32,
ceil_div, to_blocked,
to_mxfp8,
from_blocked_format,
generate_jagged_offs,
)
@ -462,6 +463,24 @@ def pack_uint4(uint8_data) -> torch.Tensor:
uint8_data = uint8_data.contiguous().view(-1)
return (uint8_data[1::2] << 4 | uint8_data[::2]).view(down_size(shape))
def unpack_uint4(uint8_data) -> torch.Tensor:
# Take a packed uint8 tensor (i.e. nvfp4) and unpack into
# a tensor twice as wide. Useful for dequant operations.
shape = list(uint8_data.shape)
# 2x packed elements -> single non-packed => adjust shape
shape[-1] *= 2
out = torch.empty(
*shape,
device=uint8_data.device,
dtype=torch.uint8
).view(-1)
uint8_data_as_uint8 = uint8_data.view(torch.uint8).view(-1)
out[1::2] = uint8_data_as_uint8[:] >> 4
out[::2] = uint8_data_as_uint8 & 15
return out.view(shape)
def _bfloat16_to_float4_e2m1fn_x2(x):
assert x.dtype == torch.bfloat16
@ -470,6 +489,119 @@ def _bfloat16_to_float4_e2m1fn_x2(x):
x = x.view(torch.float4_e2m1fn_x2)
return x
def _convert_to_nvfp4_with_hp_ref(t):
# Convert a tensor to nvfp4, returning:
# t_hp : reconstructed bf16 version of t_lp
# t_lp : nvfp4 tensor (2x elements packed into uint8)
# t_scale: e4m3 block-wise scaling factors (non-swizzled)
# t_global_scale: fp32 tensor-wise global scaling factor
t_lp, t_scale, t_global_scale = data_to_nvfp4_with_global_scale(
t,
16,
)
t_hp = from_blocked_format(
_floatx_unpacked_to_f32(
unpack_uint4(t_lp),
FP4_EBITS,
FP4_MBITS),
t_scale,
blocksize=16) * t_global_scale
return t_hp, t_lp, t_scale, t_global_scale
def _convert_to_mxfp8_with_hp_ref(t):
# Convert a tensor to mxfp8, returning:
# t_hp : reconstructed bf16 version of t_lp
# t_lp : fp8_e4m3 tensor
# t_scale: fp8_e8m0 block-wise scaling factors (non-swizzled)
t_scale, t_lp = to_mxfp8(t)
t_hp = from_blocked_format(t_lp, t_scale, blocksize=32)
return t_hp, t_lp, t_scale
def _2d_grouped_tensor_to_mxfp8_blocked_scaled(t, MN, G, offs, format='mxfp8'):
# Convert scales to blocked format. either mxfp8 or nvfp4
th_list = []
t_list = []
t_blocked_scale_list = []
t_global_scale_list = []
def round_up(x: int, y: int) -> int:
return ((x + y - 1) // y) * y
for group_idx in range(G):
# to_mxfp8 per group
prev_group_end_offset = (
0 if group_idx == 0 else offs[group_idx - 1]
)
curr_group_end_offset = offs[group_idx]
group_size = curr_group_end_offset - prev_group_end_offset
if group_size > 0:
t_slice = t[
:, prev_group_end_offset:curr_group_end_offset
].contiguous() # (M, K_group)
if format == 'mxfp8':
th_slice, tq_slice, t_scale_slice = _convert_to_mxfp8_with_hp_ref(t_slice)
elif format == 'nvfp4':
th_slice, tq_slice, t_scale_slice, tq_global = _convert_to_nvfp4_with_hp_ref(
t_slice,
)
t_global_scale_list.append(tq_global)
else:
raise ValueError(f'format must be mxfp8|nvfp4, got "{format}"')
t_list.append(tq_slice)
th_list.append(th_slice)
# Convert scales to blocked format.
t_scale_slice_blocked = to_blocked(
t_scale_slice
) # (round_up(M, 128), round_up(K_group//32, 4))
t_blocked_scale_list.append(t_scale_slice_blocked)
# Assemble the full XQ and WQ
tq = torch.cat(t_list, dim=1).contiguous()
th = torch.cat(th_list, dim=1).contiguous()
# Combine all XQ groups blocked scales into one tensor.
t_blocked_scales = torch.cat(t_blocked_scale_list, dim=0)
MN_rounded = round_up(MN, 128)
t_blocked_scales = t_blocked_scales.reshape(MN_rounded, -1)
# Global scales only exist for nvfp4
t_global_scales = None
if len(t_global_scale_list) > 0:
t_global_scales = torch.stack(t_global_scale_list)
return th, tq, t_blocked_scales, t_global_scales
def _build_scaled_grouped_mm_kwargs(scale_a, scale_b, offs, format):
# Build some standard args that are wordy
# Note: if/when ROCm support added, need to change swizzle handling
kwargs = {
'mxfp8': {
'scale_a': scale_a,
'scale_b': scale_b,
'scale_recipe_a': ScalingType.BlockWise1x32,
'scale_recipe_b': ScalingType.BlockWise1x32,
'swizzle_a': SwizzleType.SWIZZLE_32_4_4,
'swizzle_b': SwizzleType.SWIZZLE_32_4_4,
'offs': offs, # (G,)
'out_dtype': torch.bfloat16,
'wrap_v2': True,
},
'nvfp4': {
'scale_a': scale_a,
'scale_b': scale_b,
'scale_recipe_a': [ScalingType.BlockWise1x16, ScalingType.TensorWise],
'scale_recipe_b': [ScalingType.BlockWise1x16, ScalingType.TensorWise],
'swizzle_a': SwizzleType.SWIZZLE_32_4_4,
'swizzle_b': SwizzleType.SWIZZLE_32_4_4,
'offs': offs, # (G,)
'out_dtype': torch.bfloat16,
'wrap_v2': True,
},
}
return kwargs[format]
class TestFP8Matmul(TestCase):
@ -526,13 +658,15 @@ class TestFP8Matmul(TestCase):
out_fp8_s = scaled_mm_wrap(x, y, scale_a=scale_a, scale_b=scale_b)
self.assertEqual(out_fp8, out_fp8_s)
@unittest.skipIf(not PLATFORM_SUPPORTS_MXFP8_GROUPED_GEMM, mxfp8_grouped_mm_skip_msg)
@parametrize("G", [1, 4, 16])
@parametrize("M", [2048, 2049])
@parametrize("N", [8192])
@parametrize("K", [16640])
@parametrize("wrap_v2", [True, False])
def test_mxfp8_scaled_grouped_mm_2d_2d(self, G, M, N, K, wrap_v2):
@parametrize("format", ["mxfp8"] + (["nvfp4"] if torch.version.cuda else []))
def test_mxfp8_nvfp4_scaled_grouped_mm_2d_2d(self, G, M, N, K, format):
torch.manual_seed(42)
total_K = K # Alias for clarity, communicating this consists of several groups along this dim
input_group_end_offsets = generate_jagged_offs(
@ -541,95 +675,61 @@ class TestFP8Matmul(TestCase):
X = torch.randn((M, total_K), dtype=torch.bfloat16, device="cuda") * 0.1
W = torch.randn((N, total_K), dtype=torch.bfloat16, device="cuda") * 0.01
# Convert scales to blocked format.
x_list = []
w_list = []
x_blocked_scale_list = []
w_blocked_scale_list = []
xh, xq, x_blocked_scales, x_global_scales = _2d_grouped_tensor_to_mxfp8_blocked_scaled(
X, M, G, input_group_end_offsets, format=format
)
wh, wq, w_blocked_scales, w_global_scales = _2d_grouped_tensor_to_mxfp8_blocked_scaled(
W, N, G, input_group_end_offsets, format=format
)
def round_up(x: int, y: int) -> int:
return ((x + y - 1) // y) * y
for group_idx in range(G):
# to_mxfp8 per group
prev_group_end_offset = (
0 if group_idx == 0 else input_group_end_offsets[group_idx - 1]
if format == "mxfp8":
kwargs = _build_scaled_grouped_mm_kwargs(
x_blocked_scales,
w_blocked_scales,
input_group_end_offsets,
format,
)
curr_group_end_offset = input_group_end_offsets[group_idx]
group_size = curr_group_end_offset - prev_group_end_offset
if group_size > 0:
x_slice = X[
:, prev_group_end_offset:curr_group_end_offset
].contiguous() # (M, K_group)
w_slice = W[
:, prev_group_end_offset:curr_group_end_offset
].contiguous() # (N, K_group)
x_scale_slice, xq_slice = to_mxfp8(
x_slice
) # scale shape -> (M, K_group // 32)
w_scale_slice, wq_slice = to_mxfp8(
w_slice
) # scale shape -> (N, K_group // 32)
x_list.append(xq_slice)
w_list.append(wq_slice)
elif format == "nvfp4":
kwargs = _build_scaled_grouped_mm_kwargs(
[x_blocked_scales, x_global_scales],
[w_blocked_scales, w_global_scales],
input_group_end_offsets,
format,
)
else:
raise ValueError(f'format must be mxfp8|nvfp4, got "{format}"')
# Convert scales to blocked format.
x_scale_slice_blocked = to_blocked(
x_scale_slice
) # (round_up(M, 128), round_up(K_group//32, 4))
w_scale_slice_blocked = to_blocked(
w_scale_slice
) # (round_up(N, 128), round_up(K_group//32, 4))
x_blocked_scale_list.append(x_scale_slice_blocked)
w_blocked_scale_list.append(w_scale_slice_blocked)
# Assemble the full XQ and WQ
xq = torch.cat(x_list, dim=1).contiguous()
wq = torch.cat(w_list, dim=1).contiguous()
# Combine all XQ groups blocked scales into one tensor.
x_blocked_scales = torch.cat(x_blocked_scale_list, dim=0)
M_rounded = round_up(M, 128)
x_blocked_scales = x_blocked_scales.reshape(M_rounded, -1)
# Combine all WQ groups blocked scales into one tensor.
w_blocked_scales = torch.cat(w_blocked_scale_list, dim=0)
N_rounded = round_up(N, 128)
w_blocked_scales = w_blocked_scales.reshape(N_rounded, -1)
if format == 'nvfp4':
assert x_global_scales.numel() == w_global_scales.numel()
assert x_global_scales.numel() == G
# Compute mxfp8 grouped mm output
y_mxfp8 = scaled_grouped_mm_wrap(
xq, # (M, total_K)
wq.transpose(-2, -1), # (total_K, N)
x_blocked_scales, # to_blocked_per_group(M, total_K//32)
w_blocked_scales, # to_blocked_per_group(N, total_K//32)
scale_recipe_a=ScalingType.BlockWise1x32,
scale_recipe_b=ScalingType.BlockWise1x32,
swizzle_a=SwizzleType.SWIZZLE_32_4_4,
swizzle_b=SwizzleType.SWIZZLE_32_4_4,
offs=input_group_end_offsets, # (G,)
out_dtype=torch.bfloat16,
wrap_v2=wrap_v2
y_lp = scaled_grouped_mm_wrap(
xq,
wq.transpose(-2, -1),
**kwargs,
)
# bf16 reference output
y_bf16 = torch._grouped_mm(
X, W.t(), offs=input_group_end_offsets, out_dtype=torch.bfloat16
# Note: the reference result should be computed on the reconstructed values,
# i.e. float(fp4(t)), not on t itself.
xh, wh.t(), offs=input_group_end_offsets, out_dtype=torch.bfloat16
)
# Assert no NaNs
assert not y_mxfp8.isnan().any(), "mxfp8 output contains NaN"
assert not y_lp.isnan().any(), "low-precision output contains NaN"
# Assert outputs are close
torch.testing.assert_close(y_mxfp8, y_bf16, atol=8.0e-2, rtol=8.0e-2)
torch.testing.assert_close(y_lp, y_bf16, atol=8.0e-2, rtol=8.0e-2)
@unittest.skipIf(not PLATFORM_SUPPORTS_MXFP8_GROUPED_GEMM, mxfp8_grouped_mm_skip_msg)
@parametrize("G", [1, 4, 16])
@parametrize("M", [16640])
@parametrize("N", [8192])
@parametrize("K", [4096])
@parametrize("wrap_v2", [True, False])
def test_mxfp8_scaled_grouped_mm_2d_3d(self, G, M, N, K, wrap_v2):
@parametrize("format", ["mxfp8"] + (["nvfp4"] if torch.version.cuda else []))
def test_mxfp8_scaled_grouped_mm_2d_3d(self, G, M, N, K, format):
torch.manual_seed(42)
# Simulate 2d-3d grouped gemm `out = input @ weight.t()`
# 2D inputs with groups along M, 3D weights.
@ -643,60 +743,120 @@ class TestFP8Matmul(TestCase):
# For each constituent 2d subtensor in the 3d weights, quantize and convert scale to blocked format separately,
# as they each used for independent gemm in the grouped gemm.
wq_list = []
w_scale_list = []
for i in range(G):
w_scale, wq = to_mxfp8(W[i])
w_scale = to_blocked(w_scale)
wq_list.append(wq)
w_scale_list.append(w_scale)
wq = torch.stack(wq_list, dim=0).contiguous()
w_scale = torch.stack(w_scale_list, dim=0).contiguous()
def _3d_to_blocked_scaled(W, G, format):
wh_list = []
wq_list = []
w_scale_list = []
w_global_scale_list = []
for i in range(G):
if format == "mxfp8":
wh, wq, w_scale = _convert_to_mxfp8_with_hp_ref(W[i])
elif format == "nvfp4":
w_scale, wq = to_mxfp8(W[i])
wh, wq, w_scale, w_global_scale = _convert_to_nvfp4_with_hp_ref(W[i])
w_global_scale_list.append(w_global_scale)
else:
raise ValueError(f'format must be mxfp8|nvfp4, got "{format}"')
# Swizzle scales
# TODO(slayton): gate on cuda/hip
w_scale = to_blocked(w_scale)
wh_list.append(wh)
wq_list.append(wq)
w_scale_list.append(w_scale)
wh = torch.stack(wh_list, dim=0).contiguous()
wq = torch.stack(wq_list, dim=0).contiguous()
w_scale = torch.stack(w_scale_list, dim=0).contiguous()
# Global scales only exist for nvfp4
if len(w_global_scale_list) > 0:
w_global_scales = torch.stack(w_global_scale_list)
else:
w_global_scales = None
return wh, wq, w_scale, w_global_scales
wh, wq, w_blocked_scales, w_global_scales = _3d_to_blocked_scaled(W, G, format)
# For each group along `total_M` in the 2D tensor, quantize and convert scale to blocked format separately,
# as they each used for independent gemm in the grouped gemm.
xq_list = []
x_scale_list = []
for i in range(G):
prev_group_end = 0 if i == 0 else input_group_end_offsets[i - 1]
curr_group_end = input_group_end_offsets[i]
group_size = curr_group_end - prev_group_end
if group_size > 0:
x_slice = X[prev_group_end:curr_group_end, :]
x_scale, xq = to_mxfp8(x_slice)
x_scale = to_blocked(x_scale)
xq_list.append(xq)
x_scale_list.append(x_scale)
xq = torch.cat(xq_list, dim=0).contiguous()
x_scale = torch.cat(x_scale_list, dim=0).contiguous()
x_scale = x_scale.reshape(-1, K // block_size)
xq = xq.view(-1, xq.shape[-1])
def _2d_to_blocked_scaled(X, K, G, offs, format):
xh_list = []
xq_list = []
x_scale_list = []
x_global_scale_list = []
for i in range(G):
prev_group_end = 0 if i == 0 else input_group_end_offsets[i - 1]
curr_group_end = input_group_end_offsets[i]
group_size = curr_group_end - prev_group_end
if group_size > 0:
x_slice = X[prev_group_end:curr_group_end, :]
if format == "mxfp8":
xh, xq, x_scale = _convert_to_mxfp8_with_hp_ref(x_slice)
elif format == "nvfp4":
xh, xq, x_scale, x_global_scale = _convert_to_nvfp4_with_hp_ref(x_slice)
x_global_scale_list.append(x_global_scale)
else:
raise ValueError(f'format must be mxfp8|nvfp4, got "{format}"')
# Compute mxfp8 grouped gemm.
y_mxfp8 = scaled_grouped_mm_wrap(
x_scale = to_blocked(x_scale)
xh_list.append(xh)
xq_list.append(xq)
x_scale_list.append(x_scale)
xh = torch.cat(xh_list, dim=0).contiguous()
xq = torch.cat(xq_list, dim=0).contiguous()
x_scale = torch.cat(x_scale_list, dim=0).contiguous()
x_scale = x_scale.reshape(-1, K // block_size)
xq = xq.view(-1, xq.shape[-1])
xh = xh.view(-1, xh.shape[-1])
x_global_scales = None
if len(x_global_scale_list) > 0:
x_global_scales = torch.stack(x_global_scale_list)
return xh, xq, x_scale, x_global_scales
xh, xq, x_blocked_scales, x_global_scales = _2d_to_blocked_scaled(X, K, G, input_group_end_offsets, format)
if format == "mxfp8":
kwargs = _build_scaled_grouped_mm_kwargs(
x_blocked_scales,
w_blocked_scales,
input_group_end_offsets,
format,
)
elif format == "nvfp4":
kwargs = _build_scaled_grouped_mm_kwargs(
[x_blocked_scales, x_global_scales],
[w_blocked_scales, w_global_scales],
input_group_end_offsets,
format,
)
else:
raise ValueError(f'format must be mxfp8|nvfp4, got "{format}"')
if format == 'nvfp4':
assert x_global_scales.numel() == w_global_scales.numel()
assert x_global_scales.numel() == G
# Compute low-precision grouped gemm.
y_lp = scaled_grouped_mm_wrap(
xq,
wq.transpose(-2, -1),
x_scale,
w_scale,
offs=input_group_end_offsets,
out_dtype=torch.bfloat16,
scale_recipe_a=ScalingType.BlockWise1x32,
scale_recipe_b=ScalingType.BlockWise1x32,
swizzle_a=SwizzleType.SWIZZLE_32_4_4,
swizzle_b=SwizzleType.SWIZZLE_32_4_4,
wrap_v2=wrap_v2)
**kwargs
)
# Compute reference bf16 grouped gemm.
# Note: the reference result should be computed on the reconstructed values,
# i.e. float(fp4(t)), not on t itself.
y_bf16 = torch._grouped_mm(
X,
W.transpose(-2, -1),
xh,
wh.transpose(-2, -1),
offs=input_group_end_offsets,
out_dtype=torch.bfloat16,
)
# Assert outputs are close.
torch.testing.assert_close(y_mxfp8, y_bf16, atol=8.0e-2, rtol=8.0e-2)
torch.testing.assert_close(y_lp, y_bf16, atol=8.0e-2, rtol=8.0e-2)
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
@ -1704,6 +1864,7 @@ class TestFP8Matmul(TestCase):
@parametrize("fast_accum", [False, True])
# AMD does not support non-contiguous inputs yet
@parametrize("strided", [False] + ([True] if torch.version.cuda else []))
# AMD does not support NVFP4
@parametrize("wrap_v2", [True, False])
def test_scaled_grouped_gemm_2d_2d(self, fast_accum, strided, wrap_v2):
device = "cuda"

View File

@ -720,13 +720,22 @@ def check_shape(
) -> None:
backend = get_current_backend()
assert shape is not None
if config.test_configs.runtime_triton_dtype_assert and backend == "triton":
if config.test_configs.runtime_triton_shape_assert and backend == "triton":
shape_str = (
", ".join(str(d) for d in shape) if len(shape) != 1 else f"{shape[0]},"
)
buffer.writeline(f"tl.static_assert({var}.shape == ({shape_str}))")
def check_nan(buffer: IndentedBuffer, var: CSEVariableType) -> None:
backend = get_current_backend()
if backend == "triton":
msg = "NaN or Inf found"
buffer.writeline(
f"tl.device_assert(({var} == {var}) & ({var} != float('inf')) & ({var} != float('-inf')), '{msg}')"
)
class DataTypePropagation:
def __init__(self, body: LoopBody) -> None:
self.body = body
@ -2623,6 +2632,9 @@ class CSEProxy(DefaultHandler):
assert output_shape is not None
check_shape(V.kernel.compute, csevar, output_shape)
if config.runtime_triton_nan_asserts:
check_nan(V.kernel.compute, csevar)
return csevar
return pytree.tree_map(do_cse, value)
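
A usage sketch of the new in-kernel NaN check, mirroring `test_nan_assert_inside_triton_kernel` above (assumes a CUDA device with Triton; the flag can also be set via `TORCHINDUCTOR_RUNTIME_TRITON_NAN_ASSERTS=1`, per the config diff further down):

```
import torch
from torch._inductor import config

def fn(x):
    x = x - 1
    return torch.where(x.isnan(), 3.14, x)

# Enable the runtime NaN/Inf assertions for this compilation only.
with config.patch(runtime_triton_nan_asserts=True):
    compiled = torch.compile(fn)
    out = compiled(torch.randn(4096, device="cuda"))

# Each intermediate in the generated Triton kernel is now guarded by
# tl.device_assert((v == v) & (v != float('inf')) & (v != float('-inf')),
#                  'NaN or Inf found')
```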

View File

@ -626,7 +626,7 @@ class ComboKernel(Kernel):
if heuristics == "foreach":
heuristics_line = f"""
@triton_heuristics.foreach(
filename=__file__,
num_warps={self.num_warps},
triton_meta={triton_meta!r},
inductor_meta={inductor_meta!r},
)

View File

@ -206,6 +206,9 @@ static_weight_shapes = True
# put correctness assertions in generated code
size_asserts = os.environ.get("TORCHINDUCTOR_SIZE_ASSERTS", "1") == "1"
nan_asserts = os.environ.get("TORCHINDUCTOR_NAN_ASSERTS") == "1"
runtime_triton_nan_asserts = (
os.environ.get("TORCHINDUCTOR_RUNTIME_TRITON_NAN_ASSERTS") == "1"
)
scalar_asserts = os.environ.get("TORCHINDUCTOR_SCALAR_ASSERTS", "1") == "1"
# Disable by default in fbcode

View File

@ -3550,24 +3550,13 @@ def user_autotune(
)
def foreach(triton_meta, filename=None, inductor_meta=None):
def foreach(triton_meta, num_warps, filename=None, inductor_meta=None):
"""
Compile a triton foreach kernel
"""
configs = []
# Naive autotuning path for num_warps
if not inductor_meta.get("autotune_pointwise", True) and not (
inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise")
):
configs.append(triton.Config({}, num_stages=1, num_warps=8))
else:
for warps in [1, 2, 4, 8]:
configs.append(triton.Config({}, num_stages=1, num_warps=warps))
return cached_autotune(
None,
configs,
[triton.Config({}, num_stages=1, num_warps=num_warps)],
triton_meta=triton_meta,
inductor_meta=inductor_meta,
heuristic_type=HeuristicType.TEMPLATE,

View File

@ -447,6 +447,56 @@ def _floatx_unpacked_to_f32(x: Tensor, ebits: int, mbits: int) -> Tensor:
def ceil_div(a, b):
return (a + b - 1) // b
# NVIDIA Blackwell HW requires scales for MX/NV blocked formats to be in a 128x4 tile layout,
# with a weird 32x4x4 internal layout of that tile. If we want to take swizzled scales and use them
# for non-gemm purposes (like testing), we need to de-swizzle them, then they can be applied much
# more naturally.
def from_blocked(input, input_scales, blocksize) -> torch.Tensor:
# Matrix is in a 128x4 pattern, internally blocked as 32x4x4 nonsense.
# Output should be [input.size(0), input.size(1) // blocksize] scales
output_scales = torch.zeros(
(input.size(0), input.size(1) // blocksize),
device=input.device,
dtype=input_scales.dtype,
)
# Swizzled scales are padded to tiles of 128x4, we need to replicate how that padding
# happened for offset purposes.
# There are K//blocksize scales, padded to groups of 4.
num_col_tiles = ceil_div(ceil_div(input.size(1), blocksize), 4)
# (Very) slow reference implementation using horrifying loops.
for i in range(input.size(0)):
for j in range(input.size(1) // blocksize):
# which 128x4 tile of scaling factors am I in
scale_tile_h = i // 128
scale_tile_w = j // 4
# There are (padded) input_scales.size(1) // 4 tiles along the w dim.
# So offset is 512 * (h_tile * tiles_per_row + tile_in_row)
tile_offset = 512 * (scale_tile_h * num_col_tiles + scale_tile_w)
# indices within the tile - use nomenclature directly from cublas docs
outer = i % 128 # "outer" in cublas docs
inner = j % 4 # "inner" in cublas docs
# Note: "offset" is given in terms of bytes, in cublas docs, but our scales are e8m0,
# anyway, and so 1B == 1 value => use offset directly.
# Formula directly from cublas docs in 3.1.4.3.2
offset = tile_offset + (outer % 32) * 16 + (outer // 32) * 4 + inner
output_scales[i, j] = input_scales[offset]
return output_scales
def from_blocked_format(x_mxfp8, scales_unswizzled, blocksize=32):
# expand scales
scales = torch.repeat_interleave(scales_unswizzled, blocksize, dim=1)
# de-scale and convert
x_f32 = x_mxfp8.to(torch.float) * scales.to(torch.float)
return x_f32.to(torch.bfloat16)
def to_blocked(input_matrix) -> torch.Tensor:
"""
Rearrange a large matrix by breaking it into blocks and applying the rearrangement pattern.
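
A quick sketch of how the reconstruction helpers above are meant to be used in tests (assumes a CUDA device and K divisible by the 32-element block size): quantize to mxfp8, then rebuild a bf16 reference from the low-precision data and the non-swizzled scales via `from_blocked_format`.

```
import torch
from torch.testing._internal.common_quantized import from_blocked_format, to_mxfp8

# Quantize a bf16 tensor to mxfp8 (e8m0 scales over 1x32 blocks), then
# reconstruct a bf16 reference from the fp8 data and the unswizzled scales.
t = torch.randn(128, 256, dtype=torch.bfloat16, device="cuda")
t_scale, t_lp = to_mxfp8(t)                        # scales: (128, 256 // 32)
t_ref = from_blocked_format(t_lp, t_scale, blocksize=32)
print((t_ref - t).abs().max())                     # bounded by quantization error
```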

View File

@ -119,11 +119,9 @@ CI_PT_ROOT = ""
CI_TEST_PREFIX = ""
DISABLED_TESTS_FILE = ""
GRAPH_EXECUTOR : Optional[ProfilingMode] = None
LOG_SUFFIX = ""
PYTEST_SINGLE_TEST = ""
REPEAT_COUNT = 0
RERUN_DISABLED_TESTS = False
RUN_PARALLEL = 0
SHOWLOCALS = False
SLOW_TESTS_FILE = ""
TEST_BAILOUTS = False
@ -950,6 +948,13 @@ def prof_meth_call(*args, **kwargs):
torch._C.ScriptFunction.__call__ = prof_func_call # type: ignore[method-assign]
torch._C.ScriptMethod.__call__ = prof_meth_call # type: ignore[method-assign]
def _get_test_report_path():
# allow users to override the test file location. We need this
# because the distributed tests run the same test file multiple
# times with different configurations.
override = os.environ.get('TEST_REPORT_SOURCE_OVERRIDE')
test_source = override if override is not None else 'python-unittest'
return os.path.join('test-reports', test_source)
def parse_cmd_line_args():
global CI_FUNCTORCH_ROOT
@ -957,11 +962,9 @@ def parse_cmd_line_args():
global CI_TEST_PREFIX
global DISABLED_TESTS_FILE
global GRAPH_EXECUTOR
global LOG_SUFFIX
global PYTEST_SINGLE_TEST
global REPEAT_COUNT
global RERUN_DISABLED_TESTS
global RUN_PARALLEL
global SHOWLOCALS
global SLOW_TESTS_FILE
global TEST_BAILOUTS
@ -980,10 +983,10 @@ def parse_cmd_line_args():
parser.add_argument('--repeat', type=int, default=1)
parser.add_argument('--test-bailouts', '--test_bailouts', action='store_true')
parser.add_argument('--use-pytest', action='store_true')
parser.add_argument('--save-xml', type=str)
parser.add_argument('--save-xml', nargs='?', type=str,
const=_get_test_report_path(),
default=_get_test_report_path() if IS_CI else None)
parser.add_argument('--discover-tests', action='store_true')
parser.add_argument('--log-suffix', type=str, default="")
parser.add_argument('--run-parallel', type=int, default=1)
parser.add_argument('--import-slow-tests', type=str, nargs='?', const=DEFAULT_SLOW_TESTS_FILE)
parser.add_argument('--import-disabled-tests', type=str, nargs='?', const=DEFAULT_DISABLED_TESTS_FILE)
parser.add_argument('--rerun-disabled-tests', action='store_true')
@ -1010,15 +1013,10 @@ def parse_cmd_line_args():
# infer flags based on the default settings
GRAPH_EXECUTOR = cppProfilingFlagsToProfilingMode()
if args.save_xml is None and IS_CI:
args.xml_dir = get_report_dir(sys.argv[0], args.log_suffix, args.use_pytest)
RERUN_DISABLED_TESTS = args.rerun_disabled_tests
SLOW_TESTS_FILE = args.import_slow_tests
DISABLED_TESTS_FILE = args.import_disabled_tests
LOG_SUFFIX = args.log_suffix
RUN_PARALLEL = args.run_parallel
TEST_BAILOUTS = args.test_bailouts
USE_PYTEST = args.use_pytest
PYTEST_SINGLE_TEST = args.pytest_single_test
@ -1185,37 +1183,19 @@ def lint_test_case_extension(suite):
return succeed
def get_report_dir(test_name: str, log_suffix: Optional[str], is_pytest: bool) -> str:
"""Generates a test report directory path. Test name does not need to be
sanitized."""
# total path = test-reports/test_source+log_suffix/test_filename
# Base path
test_source = "python-unittest"
if is_pytest:
test_source = "python-pytest"
# allow users to override the test file location. We need this
# because the distributed tests run the same test file multiple
# times with different configurations.
override = os.environ.get('TEST_REPORT_SOURCE_OVERRIDE')
if override is not None:
test_source = override
# Add log suffix to if provided
if log_suffix and log_suffix != "":
test_source = test_source + log_suffix
test_report_dir = os.path.join('test-reports', test_source)
# Add test file name to path
test_filename = sanitize_test_filename(test_name)
test_report_dir = os.path.join(test_report_dir, test_filename)
os.makedirs(test_report_dir, exist_ok=True)
return test_report_dir
def get_report_path(report_dir: str, test_filename: str) -> str:
return os.path.join(report_dir, f"{sanitize_test_filename(test_filename)}-{os.urandom(8).hex()}.xml")
def get_report_path(argv=None, pytest=False):
if argv is None:
argv = UNITTEST_ARGS
test_filename = sanitize_test_filename(argv[0])
test_report_path = TEST_SAVE_XML
test_report_path = os.path.join(test_report_path, test_filename)
if pytest:
test_report_path = test_report_path.replace('python-unittest', 'python-pytest')
os.makedirs(test_report_path, exist_ok=True)
test_report_path = os.path.join(test_report_path, f"{test_filename}-{os.urandom(8).hex()}.xml")
return test_report_path
os.makedirs(test_report_path, exist_ok=True)
return test_report_path
def sanitize_pytest_xml(xml_file: str):
@ -1343,22 +1323,11 @@ def run_tests(argv=None):
assert len(failed_tests) == 0, "{} unit test(s) failed:\n\t{}".format(
len(failed_tests), '\n\t'.join(failed_tests))
elif RUN_PARALLEL > 1:
test_cases = discover_test_cases_recursively(suite)
test_batches = chunk_list(get_test_names(test_cases), RUN_PARALLEL)
processes = []
for i in range(RUN_PARALLEL):
command = [sys.executable] + argv + [f'--log-suffix=-shard-{i + 1}'] + test_batches[i]
processes.append(subprocess.Popen(command, universal_newlines=True))
failed = False
for p in processes:
failed |= wait_for_process(p) != 0
assert not failed, "Some test shards have failed"
elif USE_PYTEST:
pytest_args = argv + ["--use-main-module"]
test_report_path = ""
if TEST_SAVE_XML:
test_report_path = get_report_path(TEST_SAVE_XML, argv[0])
test_report_path = get_report_path(pytest=True)
print(f'Test results will be stored in {test_report_path}')
pytest_args.append(f'--junit-xml-reruns={test_report_path}')
if PYTEST_SINGLE_TEST:
@ -1402,7 +1371,7 @@ def run_tests(argv=None):
def printErrors(self) -> None:
super().printErrors()
self.printErrorList("XPASS", self.unexpectedSuccesses)
test_report_path = get_report_path(TEST_SAVE_XML, argv[0])
test_report_path = get_report_path()
verbose = '--verbose' in argv or '-v' in argv
if verbose:
print(f'Test results will be stored in {test_report_path}')

View File

@ -311,7 +311,11 @@ def escape(n):
def is_cuda_tensor(obj):
return isinstance(obj, torch.Tensor) and obj.is_cuda and not isinstance(obj, torch._subclasses.FakeTensor)
return (
isinstance(obj, torch.Tensor) and
obj.device.type == "cuda" and
not isinstance(obj, torch._subclasses.FakeTensor)
)
def cuda_allocation_context():
snapshot = torch.cuda.memory._snapshot()