# MXFP8 grouped GEMM support for torch._scaled_grouped_mm + submodule bump (#162209)
## Summary

- We just landed 2d-2d support for mxfp8 grouped gemm in FBGEMM: https://github.com/pytorch/FBGEMM/pull/4816
- This is needed for the backward pass of mxfp8 MoE training with grouped gemms
- Changes:
  - Add dispatching + input validation for mxfp8 grouped gemm in `torch._scaled_grouped_mm`
  - Add meta registration input validation for mxfp8 grouped gemm, for composability with compile
  - Add unit tests exercising `torch._scaled_grouped_mm` with mxfp8 inputs
  - Bump the FBGEMM third-party submodule to include:
    - https://github.com/pytorch/FBGEMM/pull/4816
    - https://github.com/pytorch/FBGEMM/pull/4820
    - https://github.com/pytorch/FBGEMM/pull/4821
    - https://github.com/pytorch/FBGEMM/pull/4823

#### How the fbgemm dependency was bumped

Documenting this since I haven't found it documented elsewhere:
- `cd ~/pytorch/third_party/fbgemm`
- `git fetch`
- `git checkout <hash>`
- `cd ~/pytorch`
- `git add third_party/fbgemm`

## Test plan

#### Test build
```
USE_FBGEMM_GENAI=1 python -m pip install --no-build-isolation -v -e .
...
Successfully installed torch-2.9.0a0+gitf5070f3
```
[full build log](https://www.internalfb.com/phabricator/paste/view/P1933787581)

#### Unit tests
```
pytest test/test_matmul_cuda.py -k test_mxfp8_scaled_grouped_mm_
...
test/test_matmul_cuda.py ......... [100%]
9 passed, 1668 deselected in 5.34s
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162209
Approved by: https://github.com/ngimel
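For orientation before the diff, here is a minimal sketch of what the new MXFP8 2d-3d call path looks like from the Python side, condensed from the unit tests added in this PR. The helpers `to_mxfp8`, `to_blocked`, and `generate_jagged_offs` are the test utilities added to `torch.testing._internal.common_quantized` below; the shapes are illustrative only (the tests use larger ones), and running this requires a CUDA SM100 build with `USE_FBGEMM_GENAI=1`.

```python
# Minimal sketch (not part of the PR): MXFP8 2d-3d grouped GEMM via torch._scaled_grouped_mm,
# condensed from test_mxfp8_scaled_grouped_mm_2d_3d below. Shapes are illustrative only.
import torch
from torch.testing._internal.common_quantized import to_mxfp8, to_blocked, generate_jagged_offs

G, M, N, K = 4, 1024, 512, 256
X = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")        # 2d input, groups along M
W = torch.randn(G, N, K, dtype=torch.bfloat16, device="cuda")     # 3d weights, one (N, K) matrix per group
offs = generate_jagged_offs(G, M, multiple_of=32, device="cuda")  # group *end* offsets along M

# Quantize each weight subtensor and convert its scale to blocked format individually.
w_scales, wqs = zip(*(to_mxfp8(W[g]) for g in range(G)))
wq = torch.stack(wqs).contiguous()
w_scale = torch.stack([to_blocked(s) for s in w_scales]).contiguous()

# Quantize the 2d input group by group along M, then reassemble.
xq_parts, x_scale_parts = [], []
start = 0
for g in range(G):
    end = int(offs[g])
    s, q = to_mxfp8(X[start:end, :])
    xq_parts.append(q)
    x_scale_parts.append(to_blocked(s))
    start = end
xq = torch.cat(xq_parts).contiguous()
x_scale = torch.cat(x_scale_parts).contiguous().reshape(-1, K // 32)

y = torch._scaled_grouped_mm(
    xq,                    # (M, K) float8_e4m3fn
    wq.transpose(-2, -1),  # (G, K, N) float8_e4m3fn, column-major in the last two dims
    x_scale,               # blocked float8_e8m0fnu scales for the input groups
    w_scale,               # (G, blocked_K * blocked_N) float8_e8m0fnu scales
    offs=offs,
    out_dtype=torch.bfloat16,
)
```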
Commit `b6d0a9ea90` (parent `5985e28912`), committed by PyTorch MergeBot.
`@@ -889,6 +889,12 @@ IF(USE_FBGEMM_GENAI AND USE_ROCM AND NOT "gfx942" IN_LIST PYTORCH_ROCM_ARCH)`

```cmake
  set(USE_FBGEMM_GENAI off)
endif()

# Set USE_FBGEMM_GENAI to ON for CUDA build on SM100
if(USE_CUDA AND "$ENV{TORCH_CUDA_ARCH_LIST}" MATCHES "10.0a")
  message(WARNING "Setting USE_FBGEMM_GENAI to ON for CUDA build on SM100")
  set(USE_FBGEMM_GENAI ON)
endif()

# CAVEAT: Again, Flash Attention2 will error while building for sm52 while Mem
# Eff Attention won't
cmake_dependent_option(
```
`@@ -255,48 +255,77 @@ endif()`

```cmake
# FBGEMM GenAI
IF(USE_FBGEMM_GENAI)
  set(FBGEMM_THIRD_PARTY ${PROJECT_SOURCE_DIR}/third_party/fbgemm/external/)
  set(FBGEMM_GENAI_DIR ${PROJECT_SOURCE_DIR}/third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize)
  set(FBGEMM_GENAI_SRCS ${PROJECT_SOURCE_DIR}/third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize)
  if(USE_CUDA)
    # To avoid increasing the build time/binary size unnecessarily, use an allow-list of kernels to build.
    # If you want to integrate a kernel from FBGEMM into torch, you have to add it here.
    set(FBGEMM_CUTLASS_KERNELS_REGEX ".*mx8mx8bf16_grouped.*")
    file(GLOB_RECURSE fbgemm_genai_native_cuda_cu
      "${FBGEMM_GENAI_SRCS}/cutlass_extensions/*.cu"
      "${FBGEMM_GENAI_SRCS}/cutlass_extensions/**/*.cu")
    list(FILTER fbgemm_genai_native_cuda_cu INCLUDE REGEX ${FBGEMM_CUTLASS_KERNELS_REGEX})

  if(USE_ROCM)
    # Only include the kernels we want to build to avoid increasing binary size.
    file(GLOB_RECURSE fbgemm_genai_native_rocm_hip
      "${FBGEMM_GENAI_DIR}/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped*.hip"
      "${FBGEMM_GENAI_DIR}/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip")
    set_source_files_properties(${fbgemm_genai_native_rocm_hip} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1)
    file(GLOB_RECURSE fbgemm_genai_native_cuda_cpp
      "${FBGEMM_GENAI_SRCS}/common/*.cpp"
    )

    # Add additional HIPCC compiler flags for performance
    set(FBGEMM_GENAI_EXTRA_HIPCC_FLAGS
      -mllvm
      -amdgpu-coerce-illegal-types=1
      -mllvm
      -enable-post-misched=0
      -mllvm
      -greedy-reverse-local-assignment=1
      -fhip-new-launch-api)
    # Combine all source files into a single list
    list(APPEND fbgemm_genai_all_sources
      ${fbgemm_genai_native_cuda_cu}
      ${fbgemm_genai_native_cuda_cpp}
    )

    # Only compile for gfx942 for now.
    # This is rather hacky, I could not figure out a clean solution :(
    set(HIP_CLANG_FLAGS_ORIGINAL ${HIP_CLANG_FLAGS})
    string(REGEX REPLACE "--offload-arch=[^ ]*" "" FILTERED_HIP_CLANG_FLAGS "${HIP_CLANG_FLAGS}")
    list(APPEND FILTERED_HIP_CLANG_FLAGS --offload-arch=gfx942;)
    set(HIP_CLANG_FLAGS ${FILTERED_HIP_CLANG_FLAGS})

    hip_add_library(
      fbgemm_genai STATIC
      ${fbgemm_genai_native_rocm_hip}
      HIPCC_OPTIONS ${FBGEMM_GENAI_EXTRA_HIPCC_FLAGS})
    set(HIP_CLANG_FLAGS ${HIP_CLANG_FLAGS_ORIGINAL})
    # Now, create the library and provide the sources at the same time
    add_library(fbgemm_genai OBJECT ${fbgemm_genai_all_sources})

    set_target_properties(fbgemm_genai PROPERTIES POSITION_INDEPENDENT_CODE ON)
    target_compile_definitions(fbgemm_genai PRIVATE FBGEMM_GENAI_NO_EXTENDED_SHAPES)

    set(fbgemm_genai_mx8mx8bf16_grouped
      "${FBGEMM_GENAI_SRCS}/cutlass_extensions/mx8mx8bf16_grouped/"
    )

    target_include_directories(fbgemm_genai PUBLIC
      # FBGEMM version of Composable Kernel is used due to some customizations
      ${FBGEMM_THIRD_PARTY}/composable_kernel/include
      ${FBGEMM_THIRD_PARTY}/composable_kernel/library/include
      ${FBGEMM_GENAI_DIR}/include/
      ${FBGEMM_GENAI_DIR}/common/include/
      ${FBGEMM_THIRD_PARTY}/cutlass/include
      ${FBGEMM_THIRD_PARTY}/cutlass/tools/util/include
      ${fbgemm_genai_mx8mx8bf16_grouped}
      ${FBGEMM_GENAI_SRCS}/common/include/ # includes fbgemm_gpu/quantize/utils.h, fbgemm_gpu/quantize/tuning_cache.hpp
      ${FBGEMM_GENAI_SRCS}/include/ # includes fbgemm_gpu/torch_ops.h
    )
  else()
    if(USE_ROCM)
      # Only include the kernels we want to build to avoid increasing binary size.
      file(GLOB_RECURSE fbgemm_genai_native_rocm_hip
        "${FBGEMM_GENAI_SRCS}/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped*.hip"
        "${FBGEMM_GENAI_SRCS}/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip")
      set_source_files_properties(${fbgemm_genai_native_rocm_hip} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1)

      # Add additional HIPCC compiler flags for performance
      set(FBGEMM_GENAI_EXTRA_HIPCC_FLAGS
        -mllvm
        -amdgpu-coerce-illegal-types=1
        -mllvm
        -enable-post-misched=0
        -mllvm
        -greedy-reverse-local-assignment=1
        -fhip-new-launch-api)

      hip_add_library(
        fbgemm_genai STATIC
        ${fbgemm_genai_native_rocm_hip}
        HIPCC_OPTIONS ${HIP_HCC_FLAGS} ${FBGEMM_GENAI_EXTRA_HIPCC_FLAGS})
      set_target_properties(fbgemm_genai PROPERTIES POSITION_INDEPENDENT_CODE ON)
      target_compile_definitions(fbgemm_genai PRIVATE FBGEMM_GENAI_NO_EXTENDED_SHAPES)

      target_include_directories(fbgemm_genai PUBLIC
        # FBGEMM version of Composable Kernel is used due to some customizations
        ${FBGEMM_THIRD_PARTY}/composable_kernel/include
        ${FBGEMM_THIRD_PARTY}/composable_kernel/library/include
        ${FBGEMM_THIRD_PARTY}/cutlass/include
        ${FBGEMM_THIRD_PARTY}/cutlass/tools/util/include
        ${FBGEMM_GENAI_SRCS}/common/include/ # includes fbgemm_gpu/quantize/utils.h, fbgemm_gpu/quantize/tuning_cache.hpp
        ${FBGEMM_GENAI_SRCS}/include/ # includes fbgemm_gpu/torch_ops.h
      )
    endif()
  endif()
endif()
```
`@@ -639,6 +668,13 @@ if(USE_CUDA AND NOT USE_ROCM)`

```cmake
  add_definitions(-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
  list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/include)
  list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/tools/util/include)

  # Add FBGEMM_GENAI include directories for torch_ops.h
  if(USE_FBGEMM_GENAI)
    list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/include)
    list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/common/include)
  endif()

  if($ENV{ATEN_STATIC_CUDA})
    if(CUDA_VERSION VERSION_LESS_EQUAL 12.9)
      list(APPEND ATen_CUDA_DEPENDENCY_LIBS
```
`@@ -1551,7 +1551,8 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,`

```cpp
}

namespace {
void check_scale(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx, const int scale_multiplier=1) {
void _check_scales_fp8_rowwise(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx, const int scale_multiplier=1) {
  // Checks scales for 2d or 3d target tensors (`mat`).
  if (mat.dim() == 2) {
    TORCH_CHECK(
        scale.dim() == 1,
```
`@@ -1585,9 +1586,66 @@ namespace {`

```cpp
        "scale must have the same first dimension as mat for arg ",
        arg_idx);
    }
  }
}

void _check_scales_mxfp8(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx) {
  // Checks scales for 2d or 3d target tensors (`mat`).
  if (mat.dim() == 2) {
    // For MXFP8, 2d tensors have variable size groups represented as subtensors,
    // that are converted to blocked padded format individually,
    // so we can't check the scale sizes without doing a d2h sync to get the group sizes here.
    TORCH_CHECK(
        scale.dim() == mat.dim(),
        "for mxfp8, scale must have same number of dimensions as parent tensor, but got mat.dim() = ", mat.dim(), " and scale.dim() = ", scale.dim(), " for arg ", arg_idx);

    // LHS mat shape (M, total_K) -> scale shape (rounded_up(M, 128), rounded_up_per_group(K/32, 4))
    // RHS mat shape (total_K, N) -> scale shape (rounded_up(N, 128), rounded_up_per_group(K/32, 4))
    // * weight is transposed prior to the call, scale stays non-transposed.
    bool LHS = arg_idx == 0;
    int scale_dim_to_check = 0;
    int mat_dim_to_check = LHS ? 0 : 1;
    TORCH_CHECK(
        scale.size(scale_dim_to_check) >= mat.size(mat_dim_to_check),
        "for mxfp8, arg ", arg_idx, " tensor shape (", mat.size(0), ", ", mat.size(1), ") ",
        "must have scale.shape[", scale_dim_to_check, "] >= ", mat.size(mat_dim_to_check), " but got scale.shape=(", scale.size(0), ", ", scale.size(1), ")");
  } else {
    // For MXFP8, 3d tensors have static group sizes (stack of 2d tensors),
    // so we can check the exact expected scale sizes here without a d2h sync.
    auto round_up = [](auto x, auto y) {
      return ((x + y - 1) / y) * y;
    };

    // TODO: this is for 3d tensor in 2d-3d case specifically.
    // We'll need to support 3d-3d and 3d-2d cases once mxfp8 grouped gemm supports them.
    int64_t G = mat.size(0);
    int64_t K = mat.size(1);
    int64_t N = mat.size(2);
    int64_t blocked_scale_K = round_up(K/32, 4);
    int64_t blocked_scale_N = round_up(N, 128);

    // fbgemm expects stack of flattened blocked scales for 3d tensor, shape (G, blocked_scale_K * blocked_scale_N).
    TORCH_CHECK(
        scale.dim() == mat.dim() - 1,
        "for mxfp8 2d-3d grouped GEMM, the 3d tensor of shape (G,K,N) must have a 2d scale of shape (G, blocked_scale_K * blocked_scale_N), but scale is ", scale.dim(), "D for arg ", arg_idx
    );
    TORCH_CHECK(
        scale.size(0) == G && scale.size(1) == blocked_scale_K * blocked_scale_N,
        "for mxfp8, the tensor shape (", G, ", ", K, ", ", N, ") must have scale shape (", G, ",", blocked_scale_K, ",", blocked_scale_N, ") for arg ", arg_idx
    );
  }
}

void check_scale(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx, const int scale_multiplier=1) {
  bool using_fp8_rowwise = scale.scalar_type() == kFloat;
  bool using_mxfp8 = scale.scalar_type() == at::kFloat8_e8m0fnu;
  if (using_fp8_rowwise) {
    _check_scales_fp8_rowwise(mat, scale, dim, arg_idx, scale_multiplier);
  } else if (using_mxfp8) {
    _check_scales_mxfp8(mat, scale, dim, arg_idx);
  } else {
    TORCH_CHECK(false, "scale must be float32 or float8_e8m0fnu, but got ", scale.dtype());
  }
}
}

Tensor
```
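As a quick reference for the shape math enforced above in the 3d (2d-3d) case, here is a small standalone sketch in Python; the helper name is made up for illustration, and the constants 32, 4, and 128 come directly from the checks above.

```python
# Standalone sketch (illustration only) of the blocked-scale shape expected by the 3d branch
# of _check_scales_mxfp8 above: for a 3d tensor of shape (G, K, N), fbgemm expects a stack of
# flattened blocked scales of shape (G, blocked_scale_K * blocked_scale_N).
def round_up(x: int, y: int) -> int:
    return ((x + y - 1) // y) * y

def expected_mxfp8_blocked_scale_shape(G: int, K: int, N: int) -> tuple[int, int]:
    blocked_scale_K = round_up(K // 32, 4)  # one e8m0 scale per 32 elements along K, padded to a multiple of 4
    blocked_scale_N = round_up(N, 128)      # N padded to a multiple of 128
    return (G, blocked_scale_K * blocked_scale_N)

# e.g. for shapes used by the 2d-3d unit test in this PR (G=4, K=4096, N=8192):
# round_up(4096 // 32, 4) = 128 and round_up(8192, 128) = 8192
assert expected_mxfp8_blocked_scale_shape(4, 4096, 8192) == (4, 128 * 8192)
```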
`@@ -1612,8 +1670,8 @@ const std::optional<at::Tensor>& bias,`

```cpp
const std::optional<at::Tensor>& scale_result,
std::optional<c10::ScalarType> out_dtype,
bool use_fast_accum) {
  bool allowed_device = _scaled_mm_allowed_device(/*sm90_only*/true, /*sm100_only*/false);
  TORCH_CHECK(allowed_device, "torch._scaled_grouped_mm is only supported on CUDA devices with compute capability = 9.0, or ROCm MI300+");
  bool allowed_device = _scaled_mm_allowed_device(/*sm90_only*/true, /*sm100_only*/true);
  TORCH_CHECK(allowed_device, "torch._scaled_grouped_mm is only supported on CUDA devices with compute capability = [9.0, 10.0], or ROCm MI300+");

  TORCH_CHECK(!check_valid_strides_and_return_transposed(mat_a), "Expected mat1 to not be transposed");
  TORCH_CHECK(check_valid_strides_and_return_transposed(mat_b), "Expected mat2 to be transposed");
```
`@@ -1646,10 +1704,12 @@ bool use_fast_accum) {`

```cpp
    TORCH_CHECK(offs->dtype() == at::kInt, "Offsets have to be int32");
  }

  // Both Per-Tensor and Row-wise scaling expect fp32 tensors
  // FP8 per-tensor and per-row scaling expect fp32 scales.
  // MXFP8 expects float8_e8m0fnu scales.
  TORCH_CHECK(
      scale_a.scalar_type() == kFloat && scale_b.scalar_type() == kFloat,
      "Both scale_a and scale_b must be float (fp32) tensors.");
      (scale_a.scalar_type() == kFloat && scale_b.scalar_type() == kFloat) ||
      (scale_a.scalar_type() == at::kFloat8_e8m0fnu && scale_b.scalar_type() == at::kFloat8_e8m0fnu),
      "For FP8 tensorwise and rowwise, both scales must both be float32 tensors. For MXFP8, scales must both be float8_e8m0fnu tensors.");

  const int scale_multiplier = (mat_a.dim() == 2 && mat_b.dim() == 2) ? offs->size(0) : 1;
  check_scale(mat_a, scale_a, 0 ,0, scale_multiplier);
```
`@@ -1660,6 +1720,32 @@ bool use_fast_accum) {`

```cpp
  Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);

#if defined(USE_FBGEMM_GENAI) && defined(USE_CUDA) && !defined(USE_ROCM)
  // MXFP8 grouped GEMM dispatching
  bool is_mx8mx8bf16 = (
    mat_a.scalar_type() == at::kFloat8_e4m3fn && mat_b.scalar_type() == at::kFloat8_e4m3fn &&
    scale_a.scalar_type() == at::kFloat8_e8m0fnu && scale_b.scalar_type() == at::kFloat8_e8m0fnu
  );
  TORCH_CHECK(out_dtype == at::kBFloat16, "Only bf16 out_dtype is supported for MXFP8 grouped gemm");

  if (is_mx8mx8bf16) {
    bool b_is_3d = mat_b.dim() == 3;
    bool is_2d_2d = a_is_2d && b_is_2d;
    bool is_2d_3d = a_is_2d && b_is_3d;
    TORCH_CHECK(is_2d_2d || is_2d_3d, "MXFP8 grouped GEMM currently only supports 2d-2d and 2d-3d cases");
    TORCH_CHECK(offs.has_value(), "MXFP8 2d-2d and 2d-3d grouped GEMMs requires offsets");

    fbgemm_gpu::mx8mx8bf16_grouped_mm(
        mat_a,
        mat_b,
        scale_a,
        scale_b,
        offs.value(),
        out);
    return out;
  }
#endif

#ifndef USE_ROCM
  TORCH_CHECK(mat_a.dtype() == at::kFloat8_e4m3fn, "Expected mat_a to be Float8_e4m3 matrix got ", mat_a.scalar_type());
  TORCH_CHECK(mat_b.dtype() == at::kFloat8_e4m3fn, "Expected mat_a to be Float8_e4m3 matrix got ", mat_b.scalar_type());
```
`@@ -1691,6 +1777,7 @@ bool use_fast_accum) {`

```cpp
#else
  TORCH_CHECK(false, "grouped gemm is not supported without USE_FBGEMM_GENAI on ROCM")
#endif

#endif

}
```
`@@ -1638,6 +1638,10 @@ if(USE_CUDA)`

```cmake
  # order of the libraries in the linker call matters here when statically
  # linking; libculibos and cublas must be last.
  target_link_libraries(torch_cuda PUBLIC torch_cpu_library ${Caffe2_PUBLIC_CUDA_DEPENDENCY_LIBS})
  if(USE_FBGEMM_GENAI)
    # Link fbgemm_genai to torch_cuda (only for (1) CUDA build for SM100).
    target_link_libraries(torch_cuda PRIVATE fbgemm_genai)
  endif()
endif()

# ---[ XPU library.
```
`@@ -1759,9 +1763,10 @@ if(USE_ROCM)`

```cmake
  target_link_libraries(torch_hip PRIVATE ${Caffe2_HIP_DEPENDENCY_LIBS})

  if(USE_FBGEMM_GENAI)
    target_link_libraries(torch_hip PRIVATE fbgemm_genai)
    if(USE_ROCM)
      target_link_libraries(torch_hip PRIVATE fbgemm_genai)
    endif()
  endif()

  # Since PyTorch files contain HIP headers, this is also needed to capture the includes.
  # ROCM_INCLUDE_DIRS is defined in LoadHIP.cmake
  target_include_directories(torch_hip PRIVATE ${Caffe2_HIP_INCLUDE} ${ROCM_INCLUDE_DIRS})
```
`@@ -30,6 +30,7 @@ from torch.testing._internal.common_cuda import (`

```python
    PLATFORM_SUPPORTS_FP8,
    PLATFORM_SUPPORTS_FP8_GROUPED_GEMM,
    PLATFORM_SUPPORTS_MX_GEMM,
    PLATFORM_SUPPORTS_MXFP8_GROUPED_GEMM,
    IS_SM90,
)
from torch.testing._internal.common_device_type import (
```
`@@ -55,7 +56,13 @@ from torch.testing._internal.common_utils import (`

```python
    TEST_WITH_ROCM,
    TestCase,
)
from torch.testing._internal.common_quantized import _f32_to_floatx_unpacked, _floatx_unpacked_to_f32, ceil_div, to_blocked
from torch.testing._internal.common_quantized import (
    _f32_to_floatx_unpacked,
    _floatx_unpacked_to_f32,
    ceil_div, to_blocked,
    to_mxfp8,
    generate_jagged_offs,
)

_IS_SM8X = False
if TEST_CUDA:
```
`@@ -771,6 +778,7 @@ class TestMatmulCuda(TestCase):`

```python
f8_msg = "FP8 is only supported on H100+, SM 8.9 and MI300+ devices"
f8_grouped_msg = "FP8 grouped is only supported on SM90 and MI300+ devices"
mx_skip_msg = "MX gemm is only supported on CUDA capability 10.0+"
mxfp8_grouped_mm_skip_msg = "MXFP8 grouped GEMM is only supported when PyTorch is built with USE_FBGEMM_GENAI=1 on SM100+"

# avoid division by zero when calculating scale
EPS = 1e-12
```
`@@ -901,6 +909,8 @@ def to_fp8_saturated(`

```python

    return x.to(fp8_dtype)


def compute_error(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Computes the error between two tensors in dB.
```
`@@ -1045,6 +1055,167 @@ class TestFP8Matmul(TestCase):`

```python
        out_fp8_s = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b)
        self.assertEqual(out_fp8, out_fp8_s)

    @unittest.skipIf(not PLATFORM_SUPPORTS_MXFP8_GROUPED_GEMM, mxfp8_grouped_mm_skip_msg)
    @parametrize("G", [1, 4, 16])
    @parametrize("M", [2048, 2049])
    @parametrize("N", [8192])
    @parametrize("K", [16640])
    def test_mxfp8_scaled_grouped_mm_2d_2d(self, G, M, N, K):
        torch.manual_seed(42)
        total_K = K  # Alias for clarity, communicating this consists of several groups along this dim
        input_group_end_offsets = generate_jagged_offs(
            G, total_K, multiple_of=32, device="cuda"
        )
        X = torch.randn((M, total_K), dtype=torch.bfloat16, device="cuda") * 0.1
        W = torch.randn((N, total_K), dtype=torch.bfloat16, device="cuda") * 0.01

        # Convert scales to blocked format.
        x_list = []
        w_list = []
        x_blocked_scale_list = []
        w_blocked_scale_list = []

        def round_up(x: int, y: int) -> int:
            return ((x + y - 1) // y) * y

        for group_idx in range(G):
            # to_mxfp8 per group
            prev_group_end_offset = (
                0 if group_idx == 0 else input_group_end_offsets[group_idx - 1]
            )
            curr_group_end_offset = input_group_end_offsets[group_idx]
            group_size = curr_group_end_offset - prev_group_end_offset
            if group_size > 0:
                x_slice = X[
                    :, prev_group_end_offset:curr_group_end_offset
                ].contiguous()  # (M, K_group)
                w_slice = W[
                    :, prev_group_end_offset:curr_group_end_offset
                ].contiguous()  # (N, K_group)
                x_scale_slice, xq_slice = to_mxfp8(
                    x_slice
                )  # scale shape -> (M, K_group // 32)
                w_scale_slice, wq_slice = to_mxfp8(
                    w_slice
                )  # scale shape -> (N, K_group // 32)
                x_list.append(xq_slice)
                w_list.append(wq_slice)

                # Convert scales to blocked format.
                x_scale_slice_blocked = to_blocked(
                    x_scale_slice
                )  # (round_up(M, 128), round_up(K_group//32, 4))
                w_scale_slice_blocked = to_blocked(
                    w_scale_slice
                )  # (round_up(N, 128), round_up(K_group//32, 4))
                x_blocked_scale_list.append(x_scale_slice_blocked)
                w_blocked_scale_list.append(w_scale_slice_blocked)

        # Assemble the full XQ and WQ
        xq = torch.cat(x_list, dim=1).contiguous()
        wq = torch.cat(w_list, dim=1).contiguous()

        # Combine all XQ groups blocked scales into one tensor.
        x_blocked_scales = torch.cat(x_blocked_scale_list, dim=0)
        M_rounded = round_up(M, 128)
        x_blocked_scales = x_blocked_scales.reshape(M_rounded, -1)

        # Combine all WQ groups blocked scales into one tensor.
        w_blocked_scales = torch.cat(w_blocked_scale_list, dim=0)
        N_rounded = round_up(N, 128)
        w_blocked_scales = w_blocked_scales.reshape(N_rounded, -1)

        # Compute mxfp8 grouped mm output
        y_mxfp8 = torch._scaled_grouped_mm(
            xq,  # (M, total_K)
            wq.transpose(-2, -1),  # (total_K, N)
            x_blocked_scales,  # to_blocked_per_group(M, total_K//32)
            w_blocked_scales,  # to_blocked_per_group(N, total_K//32)
            offs=input_group_end_offsets,  # (G,)
            out_dtype=torch.bfloat16,
        )

        # bf16 reference output
        y_bf16 = torch._grouped_mm(
            X, W.t(), offs=input_group_end_offsets, out_dtype=torch.bfloat16
        )

        # Assert no NaNs
        assert not y_mxfp8.isnan().any(), "mxfp8 output contains NaN"

        # Assert outputs are close
        torch.testing.assert_close(y_mxfp8, y_bf16, atol=8.0e-2, rtol=8.0e-2)

    @unittest.skipIf(not PLATFORM_SUPPORTS_MXFP8_GROUPED_GEMM, mxfp8_grouped_mm_skip_msg)
    @parametrize("G", [1, 4, 16])
    @parametrize("M", [16640])
    @parametrize("N", [8192])
    @parametrize("K", [4096])
    def test_mxfp8_scaled_grouped_mm_2d_3d(self, G, M, N, K):
        torch.manual_seed(42)
        # Simulate 2d-3d grouped gemm `out = input @ weight.t()`
        # 2D inputs with groups along M, 3D weights.
        block_size = 32
        total_M = M  # Alias for clarity that M dim contains groups.
        X = torch.randn((total_M, K), dtype=torch.bfloat16, device="cuda") * 0.1
        W = torch.randn((G, N, K), dtype=torch.bfloat16, device="cuda") * 0.01
        input_group_end_offsets = generate_jagged_offs(
            G, total_M, multiple_of=32, device="cuda"
        )

        # For each constituent 2d subtensor in the 3d weights, quantize and convert scale to blocked format separately,
        # as they each used for independent gemm in the grouped gemm.
        wq_list = []
        w_scale_list = []
        for i in range(G):
            w_scale, wq = to_mxfp8(W[i])
            w_scale = to_blocked(w_scale)
            wq_list.append(wq)
            w_scale_list.append(w_scale)
        wq = torch.stack(wq_list, dim=0).contiguous()
        w_scale = torch.stack(w_scale_list, dim=0).contiguous()

        # For each group along `total_M` in the 2D tensor, quantize and convert scale to blocked format separately,
        # as they each used for independent gemm in the grouped gemm.
        xq_list = []
        x_scale_list = []
        for i in range(G):
            prev_group_end = 0 if i == 0 else input_group_end_offsets[i - 1]
            curr_group_end = input_group_end_offsets[i]
            group_size = curr_group_end - prev_group_end
            if group_size > 0:
                x_slice = X[prev_group_end:curr_group_end, :]
                x_scale, xq = to_mxfp8(x_slice)
                x_scale = to_blocked(x_scale)
                xq_list.append(xq)
                x_scale_list.append(x_scale)
        xq = torch.cat(xq_list, dim=0).contiguous()
        x_scale = torch.cat(x_scale_list, dim=0).contiguous()
        x_scale = x_scale.reshape(-1, K // block_size)
        xq = xq.view(-1, xq.shape[-1])

        # Compute mxfp8 grouped gemm.
        y_mxfp8 = torch._scaled_grouped_mm(
            xq,
            wq.transpose(-2, -1),
            x_scale,
            w_scale,
            offs=input_group_end_offsets,
            out_dtype=torch.bfloat16,
        )

        # Compute reference bf16 grouped gemm.
        y_bf16 = torch._grouped_mm(
            X,
            W.transpose(-2, -1),
            offs=input_group_end_offsets,
            out_dtype=torch.bfloat16,
        )

        # Assert outputs are close.
        torch.testing.assert_close(y_mxfp8, y_bf16, atol=8.0e-2, rtol=8.0e-2)


    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    @parametrize("base_dtype", [torch.float16, torch.bfloat16, torch.float32])
    def test_scaled_mm_vs_emulated(self, base_dtype):
```
Submodule third_party/fbgemm updated: 21c7d30c52...4b39c551ef
`@@ -7424,17 +7424,17 @@ def _meta_grouped_mm_common(`

```python
        fp8_dtype = torch.float8_e4m3fnuz if torch.version.hip else torch.float8_e4m3fn
        torch._check(
            mat_a.dtype == fp8_dtype and mat_b.dtype == fp8_dtype,
            lambda: f"Expected inputs of E4M3 FP8 type but got mat_a.dtype={mat_a.dtype} and mat_b.dtype={mat_b.dtype}.",
            lambda: f"Expected inputs of E4M3 FP8 type but got mat_a.dtype={mat_a.dtype} and mat_b.dtype={mat_b.dtype}.",  # noqa: B950
        )
    else:
        torch._check(
            mat_a.dtype == torch.bfloat16 and mat_b.dtype == torch.bfloat16,
            lambda: f"Expected inputs of BF16 type but got mat_a.dtype={mat_a.dtype} and mat_b.dtype={mat_b.dtype}.",
            lambda: f"Expected inputs of BF16 type but got mat_a.dtype={mat_a.dtype} and mat_b.dtype={mat_b.dtype}.",  # noqa: B950
        )

    torch._check(
        mat_a.dim() in [2, 3] and mat_b.dim() in [2, 3],
        lambda: f"Multiplicands must be 2D or 3D but got mat_a.dim()={mat_a.dim()} and mat_b.dim()={mat_b.dim()}",
        lambda: f"Multiplicands must be 2D or 3D but got mat_a.dim()={mat_a.dim()} and mat_b.dim()={mat_b.dim()}",  # noqa: B950
    )

    mat_a_is_2d = mat_a.dim() == 2
```
`@@ -7458,11 +7458,11 @@ def _meta_grouped_mm_common(`

```python

    torch._check(
        is_row_major(mat_a),
        lambda: f"Expected mat_a tensor to be row major in the last two dimensions, got strides {mat_a.stride()[-2:]}",
        lambda: f"Expected mat_a tensor to be row major in the last two dimensions, got strides {mat_a.stride()[-2:]}",  # noqa: B950
    )
    torch._check(
        is_col_major(mat_b),
        lambda: f"Expected mat_b tensor to be column major in the last two dimensions, got strides {mat_b.stride()[-2:]}",
        lambda: f"Expected mat_b tensor to be column major in the last two dimensions, got strides {mat_b.stride()[-2:]}",  # noqa: B950
    )

    def check_valid_strides(mat_name, mat):
```
`@@ -7474,7 +7474,7 @@ def _meta_grouped_mm_common(`

```python
        ):
            torch._check(
                mat_stride[end_dim] % alignment == 0,
                lambda: f"Expected {mat_name} stride along {end_dim} dim to be multiple of 16 bytes, got {mat_stride[end_dim]}.",
                lambda: f"Expected {mat_name} stride along {end_dim} dim to be multiple of 16 bytes, got {mat_stride[end_dim]}.",  # noqa: B950
            )
        elif mat_stride[end_dim] == 1 and mat_stride[end_dim - 1] >= max(
            1, mat.shape[end_dim]
```
`@@ -7494,41 +7494,81 @@ def _meta_grouped_mm_common(`

```python

    if scale_a is not None and scale_b is not None:
        torch._check(
            scale_a.dtype == torch.float32 and scale_b.dtype == torch.float32,
            lambda: "Both scale_a and scale_b must be float (fp32) tensors, but got scale_a.dtype={scale_a.dtype} and scale_b.dtype={scale_b.dtype}.",  # noqa: B950
            (scale_a.dtype == torch.float32 and scale_b.dtype == torch.float32)
            or (
                scale_a.dtype == torch.float8_e8m0fnu
                and scale_b.dtype == torch.float8_e8m0fnu
            ),
            lambda: f"For FP8 scales must both be float32, or for MXFP8 both scales must be float8_e8m0fnu. Got scale_a.dtype={scale_a.dtype} and scale_b.dtype={scale_b.dtype}.",  # noqa: B950
        )
        is_mxfp8 = (
            scale_a.dtype == torch.float8_e8m0fnu
            and scale_b.dtype == torch.float8_e8m0fnu
        )

        def round_up(x, y):
            """Rounds up x to nearest multiple of y"""
            return ((x + y - 1) // y) * y

        def check_scale(scale_name, scale, mat, scaled_dim, scale_multiplier=1):
            if mat.dim() == 2:
                torch._check(
                    scale.dim() == 1,
                    lambda: f"Expected {scale_name} to be 1D tensor, but got {scale.dim()}D tensor.",
                )
                torch._check(
                    scale.is_contiguous(),
                    lambda: f"Expected {scale_name} to be contiguous.",
                )
                torch._check(
                    scale.shape[0] == mat.shape[scaled_dim] * scale_multiplier,
                    lambda: f"Expected {scale_name} to have {mat.shape[scaled_dim] * scale_multiplier} elements, got {scale.shape[0]} elements.",  # noqa: B950
                )
                # For MXFP8, 2d tensors have variable size groups represented as subtensors,
                # that are converted to blocked padded format individually. At compile time we don't know
                # the group sizes yet, so we don't know the expect size of the blocked format scale.
                # This limits what we can check here.
                if is_mxfp8:
                    torch._check(
                        scale.dim() == mat.dim(),
                        lambda: f"For MXFP8, scale must have same number of dimensions as target tensor, but {scale_name} has mat.ndim={mat.ndim} and scale.ndim={scale.ndim}",  # noqa: B950
                    )
                else:
                    torch._check(
                        scale.dim() == 1,
                        lambda: f"Expected {scale_name} to be 1D tensor, but got {scale.dim()}D tensor.",
                    )
                    torch._check(
                        scale.shape[0] == mat.shape[scaled_dim] * scale_multiplier,
                        lambda: f"Expected {scale_name} to have {mat.shape[scaled_dim] * scale_multiplier} elements, got {scale.shape[0]} elements.",  # noqa: B950
                    )
            else:
                torch._check(
                    scale.dim() == 2,
                    lambda: f"Expected {scale_name} to be 2D tensor, but got {scale.dim()}D tensor.",
                )
                torch._check(
                    scale.stride(1) == 1,
                    scale.stride(-1) == 1,
                    lambda: f"Expected {scale_name} to be contiguous in the last dimension.",
                )
                torch._check(
                    scale.shape[0] == mat.shape[0],
                    lambda: f"Expected {scale_name} batch dimension to be {mat.shape[0]}, got {scale.shape[0]}.",
                )
                torch._check(
                    scale.shape[1] == mat.shape[1 + scaled_dim],
                    lambda: f"Expected {scale_name} non-batch dimension to be {mat.shape[1 + scaled_dim]}, got {scale.shape[1]}.",
                )
                # For MXFP8, 3d tensors have static 'groups' (stack of 2d tensors) so we can know the expected blocked
                # scale sizes at compile time.
                if is_mxfp8:
                    torch._check(
                        mat.ndim == scale.ndim,
                        lambda: f"For MXFP8, scale should have same number of dimensions as target tensor, but {scale_name} has mat.ndim={mat.ndim} and scale.ndim={scale.ndim}",  # noqa: B950
                    )
                    # TODO: This logic only holds for RHS tensor in 2d-3d case.
                    # We'll need to update it to handle LHS 3d tensor in 3d-2d and 3d-3d cases.
                    G, K, N = scale.shape
                    block_size = 32
                    blocked_K = round_up(K / block_size, 4)
                    blocked_N = round_up(N, 128)
                    torch._check(
                        mat.shape[-2] == blocked_K and mat.shape[-1] == blocked_N,
                        lambda: f"For MXFP8, expected mat.shape={mat.shape} to have scale shape of ({G},{blocked_K},{blocked_N}), but got {scale.shape}",  # noqa: B950
                    )
                else:
                    torch._check(
                        scale.dim() == 2,
                        lambda: f"Expected {scale_name} to be 2D tensor, but got {scale.dim()}D tensor.",
                    )
                    torch._check(
                        scale.shape[1] == mat.shape[1 + scaled_dim],
                        lambda: f"Expected {scale_name} non-batch dimension to be {mat.shape[1 + scaled_dim]}, got {scale.shape[1]}.",  # noqa: B950
                    )

        scale_multiplier = (
            offs.shape[0] if offs is not None and mat_a_is_2d and mat_b_is_2d else 1
```
`@@ -41,6 +41,7 @@ IS_THOR = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_ca`

```python
IS_JETSON = LazyVal(lambda: torch.cuda.is_available() and (torch.cuda.get_device_capability() in [(7, 2), (8, 7)] or IS_THOR))
IS_SM89 = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() == (8, 9))
IS_SM90 = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() == (9, 0))
IS_SM100 = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() == (10, 0))

def evaluate_gfx_arch_within(arch_list):
    if not torch.cuda.is_available():
```
`@@ -129,9 +130,17 @@ def evaluate_platform_supports_mx_gemm():`

```python
        return SM100OrLater
    return False

def evaluate_platform_supports_mxfp8_grouped_gemm():
    if torch.cuda.is_available() and not torch.version.hip:
        built_with_fbgemm_genai = "USE_FBGEMM_GENAI" in torch.__config__.show()
        return built_with_fbgemm_genai and IS_SM100
    return False

PLATFORM_SUPPORTS_MX_GEMM: bool = LazyVal(lambda: evaluate_platform_supports_mx_gemm())
PLATFORM_SUPPORTS_FP8: bool = LazyVal(lambda: evaluate_platform_supports_fp8())
PLATFORM_SUPPORTS_FP8_GROUPED_GEMM: bool = LazyVal(lambda: evaluate_platform_supports_fp8_grouped_gemm())
PLATFORM_SUPPORTS_MX_GEMM: bool = LazyVal(lambda: TEST_CUDA and SM100OrLater)
PLATFORM_SUPPORTS_MXFP8_GROUPED_GEMM: bool = LazyVal(lambda: evaluate_platform_supports_mxfp8_grouped_gemm())

if TEST_NUMBA:
    try:
```
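The same gate can be reproduced in user code if needed; below is a rough sketch mirroring `evaluate_platform_supports_mxfp8_grouped_gemm` above (this is not an official API, just the conditions the helper checks).

```python
# Rough sketch mirroring the gate above: MXFP8 grouped GEMM needs a CUDA (non-ROCm) build
# compiled with USE_FBGEMM_GENAI and an SM100 device (compute capability 10.0).
import torch

def supports_mxfp8_grouped_gemm() -> bool:
    if not torch.cuda.is_available() or torch.version.hip:
        return False
    built_with_fbgemm_genai = "USE_FBGEMM_GENAI" in torch.__config__.show()
    return built_with_fbgemm_genai and torch.cuda.get_device_capability() == (10, 0)

print(supports_mxfp8_grouped_gemm())
```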
`@@ -479,3 +479,110 @@ def to_blocked(input_matrix) -> torch.Tensor:`

```python
    rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)

    return rearranged.flatten()

# This function is extracted from https://github.com/pytorch/ao/blob/v0.12.0/torchao/prototype/mx_formats/mx_tensor.py#L142
def to_mxfp8(
    data_hp: torch.Tensor,
    block_size: int = 32,
):
    assert data_hp.dtype in (
        torch.bfloat16,
        torch.float,
    ), f"{data_hp.dtype} is not supported yet"
    assert (
        data_hp.shape[-1] % block_size == 0
    ), f"the last dimension of shape {data_hp.shape} must be divisible by block_size {block_size}"
    assert data_hp.is_contiguous(), "unsupported"

    orig_shape = data_hp.shape
    data_hp = data_hp.reshape(
        *orig_shape[:-1], orig_shape[-1] // block_size, block_size
    )

    max_abs = torch.amax(torch.abs(data_hp), -1).unsqueeze(-1)

    data_hp = data_hp.to(torch.float32)
    max_abs = max_abs.to(torch.float32)

    F8E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0
    max_pos = F8E4M3_MAX

    # RCEIL
    def _to_mx_rceil(
        data_hp: torch.Tensor,
        max_abs: torch.Tensor,
        max_pos: float,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        E8M0_EXPONENT_BIAS = 127
        descale = max_abs / max_pos
        exponent = torch.where(
            torch.isnan(descale),
            0xFF,  # Handle biased exponent for nan
            # NOTE: descale < (torch.finfo(torch.float32).smallest_normal / 2) is handled through clamping
            (
                torch.clamp(
                    torch.ceil(torch.log2(descale)),
                    min=-E8M0_EXPONENT_BIAS,
                    max=E8M0_EXPONENT_BIAS,
                )
                + E8M0_EXPONENT_BIAS
            ).to(torch.uint8),
        )

        descale_fp = torch.where(
            exponent == 0,
            1.0,
            torch.exp2(E8M0_EXPONENT_BIAS - exponent.to(torch.float32)),
        )

        # scale and saturated cast the data elements to max of target dtype
        data_lp = torch.clamp(data_hp * descale_fp, min=-1 * max_pos, max=max_pos)
        return exponent, data_lp

    scale_e8m0_biased, data_lp = _to_mx_rceil(data_hp, max_abs, max_pos)

    # cast to target dtype
    data_lp = data_lp.to(torch.float8_e4m3fn)
    # need to reshape at the end to help inductor fuse things
    data_lp = data_lp.reshape(orig_shape)

    scale_e8m0_biased = scale_e8m0_biased.view(torch.float8_e8m0fnu)
    scale_e8m0_biased = scale_e8m0_biased.squeeze(-1)
    return scale_e8m0_biased, data_lp

# Source: https://github.com/pytorch/ao/blob/568c1932a16ae9f30d48da214a88dc0013e98ed8/torchao/prototype/moe_training/utils.py#L310
def generate_jagged_offs(E, M, multiple_of=16, dtype=torch.int32, device="cuda"):
    """
    Utility function for tests and benchmarks.

    Generates a tensor of length E, containing random values divisible by `multiple_of`,
    from 0 to M, in sorted order, and where the final value in the tensor is always M.
    Args:
        E (int): The length of the tensor.
        M (int): The maximum value in the tensor.
    Returns:
        torch.Tensor: A tensor of length E with the specified properties.
    """
    import random

    # Ensure M is divisible by 16
    if M % multiple_of != 0:
        raise ValueError(f"M must be divisible by {multiple_of}")

    # Generate a list of possible values
    possible_values = list(range(multiple_of, M + 1, multiple_of))

    # If E is larger than the number of possible values, raise an error
    if E > len(possible_values):
        raise ValueError("E cannot be larger than the number of possible values")

    # Randomly select E - 1 values from the possible values (excluding M)
    selected_values = torch.tensor(random.sample(possible_values[:-1], E - 1))

    # Append M to the selected values
    selected_values = torch.cat((selected_values, torch.tensor([M])))

    # Sort the selected values
    selected_values, _ = torch.sort(selected_values)

    return selected_values.to(dtype).to(device)
```
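A short usage sketch of these helpers follows; shapes and dtypes follow from the code above, and the blocked-scale element count assumes the row/column padding described in the test comments earlier in this diff.

```python
# Usage sketch for the helpers above: quantize a bf16 tensor to MXFP8 and build the
# blocked scale layout used by the grouped GEMM tests.
import torch
from torch.testing._internal.common_quantized import to_mxfp8, to_blocked

x = torch.randn(256, 64, dtype=torch.bfloat16, device="cuda")
scale, xq = to_mxfp8(x)      # one e8m0 scale per 32 elements along the last dim
blocked = to_blocked(scale)  # flattened blocked layout, padded to (multiple of 128) x (multiple of 4)

assert xq.dtype == torch.float8_e4m3fn and xq.shape == (256, 64)
assert scale.dtype == torch.float8_e8m0fnu and scale.shape == (256, 2)
assert blocked.numel() == 256 * 4  # round_up(256, 128) * round_up(2, 4)
```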