Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-07 18:04:58 +08:00)

Compare commits: revert-cpp ... csl/remove (16 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 4e3e4c0940 | |
| | 551921d484 | |
| | b5189e269e | |
| | 3895ce093f | |
| | 8aa087a29d | |
| | 7379972cc0 | |
| | b903018c26 | |
| | 21b48f8dfa | |
| | 009ea77234 | |
| | 0e46a10aa7 | |
| | a25818cf7e | |
| | e3e93c7107 | |
| | 1abfa5f70b | |
| | 687c15c0b3 | |
| | 895795f07c | |
| | 2dc56456cb | |
@@ -100,6 +100,8 @@ COPY ./common/common_utils.sh common_utils.sh
COPY ci_commit_pins/huggingface-requirements.txt huggingface-requirements.txt
COPY ci_commit_pins/timm.txt timm.txt
COPY ci_commit_pins/torchbench.txt torchbench.txt
# Only build aoti cpp tests when INDUCTOR_BENCHMARKS is set to True
ENV BUILD_AOT_INDUCTOR_TEST ${INDUCTOR_BENCHMARKS}
RUN if [ -n "${INDUCTOR_BENCHMARKS}" ]; then bash ./install_inductor_benchmark_deps.sh; fi
RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface-requirements.txt torchbench.txt
@@ -460,28 +460,18 @@ test_inductor_shard() {
    --verbose
}

test_inductor_aoti() {
  # docker build uses bdist_wheel which does not work with test_aot_inductor
  # TODO: need a faster way to build
test_inductor_aoti_cpp() {
  if [[ "$BUILD_ENVIRONMENT" == *rocm* ]]; then
    # We need to hipify before building again
    python3 tools/amd_build/build_amd.py
  fi
  if [[ "$BUILD_ENVIRONMENT" == *sm86* ]]; then
    BUILD_COMMAND=(TORCH_CUDA_ARCH_LIST=8.6 USE_FLASH_ATTENTION=OFF python -m pip install --no-build-isolation -v -e .)
    # TODO: Replace me completely, as one should not use conda libstdc++, nor need special path to TORCH_LIB
    TEST_ENVS=(CPP_TESTS_DIR="${BUILD_BIN_DIR}" LD_LIBRARY_PATH="/opt/conda/envs/py_3.10/lib:${TORCH_LIB_DIR}:${LD_LIBRARY_PATH}")
  else
    BUILD_COMMAND=(python -m pip install --no-build-isolation -v -e .)
    TEST_ENVS=(CPP_TESTS_DIR="${BUILD_BIN_DIR}" LD_LIBRARY_PATH="${TORCH_LIB_DIR}")
  fi

  # aoti cmake custom command requires `torch` to be installed
  # initialize the cmake build cache and install torch
  /usr/bin/env "${BUILD_COMMAND[@]}"
  # rebuild with the build cache with `BUILD_AOT_INDUCTOR_TEST` enabled
  /usr/bin/env CMAKE_FRESH=1 BUILD_AOT_INDUCTOR_TEST=1 "${BUILD_COMMAND[@]}"

  /usr/bin/env "${TEST_ENVS[@]}" python test/run_test.py --cpp --verbose -i cpp/test_aoti_abi_check cpp/test_aoti_inference cpp/test_vec_half_AVX2 -dist=loadfile
}

@@ -1776,7 +1766,7 @@ elif [[ "${TEST_CONFIG}" == *inductor_cpp_wrapper* ]]; then
  install_torchvision
  PYTHONPATH=/torchbench test_inductor_cpp_wrapper_shard "$SHARD_NUMBER"
  if [[ "$SHARD_NUMBER" -eq "1" ]]; then
    test_inductor_aoti
    test_inductor_aoti_cpp
  fi
elif [[ "${TEST_CONFIG}" == *inductor* ]]; then
  install_torchvision
@@ -7,12 +7,9 @@ if "%DESIRED_PYTHON%" == "3.13t" (
    set "PYTHON_INSTALLER_URL=https://www.python.org/ftp/python/3.13.0/python-3.13.0-amd64.exe"
    set ADDITIONAL_OPTIONS="Include_freethreaded=1"
    set PYTHON_EXEC="python3.13t"
) else if "%DESIRED_PYTHON%"=="3.14" (
    echo Python version is set to 3.14 or 3.14t
    set "PYTHON_INSTALLER_URL=https://www.python.org/ftp/python/3.14.0/python-3.14.0rc1-amd64.exe"
) else if "%DESIRED_PYTHON%"=="3.14t" (
    echo Python version is set to 3.14 or 3.14t
    set "PYTHON_INSTALLER_URL=https://www.python.org/ftp/python/3.14.0/python-3.14.0rc1-amd64.exe"
    set "PYTHON_INSTALLER_URL=https://www.python.org/ftp/python/3.14.0/python-3.14.0-amd64.exe"
    set ADDITIONAL_OPTIONS="Include_freethreaded=1"
    set PYTHON_EXEC="python3.14t"
) else (
@@ -1,4 +1,4 @@



--------------------------------------------------------------------------------

@@ -72,7 +72,7 @@ Elaborating Further:

If you use NumPy, then you have used Tensors (a.k.a. ndarray).




PyTorch provides Tensors that can live either on the CPU or the GPU and accelerates the
computation by a huge amount.

@@ -99,7 +99,7 @@ from several research papers on this topic, as well as current and past work suc
While this technique is not unique to PyTorch, it's one of the fastest implementations of it to date.
You get the best of speed and flexibility for your crazy research.




### Python First
@@ -260,7 +260,7 @@ IF(USE_FBGEMM_GENAI)
  if(USE_CUDA)
    # To avoid increasing the build time/binary size unnecessarily, use an allow-list of kernels to build.
    # If you want to integrate a kernel from FBGEMM into torch, you have to add it here.
    set(FBGEMM_CUTLASS_KERNELS_REGEX ".*mx8mx8bf16_grouped.*")
    set(FBGEMM_CUTLASS_KERNELS_REGEX ".*(mx8mx8bf16_grouped|f4f4bf16_grouped).*")
    file(GLOB_RECURSE fbgemm_genai_native_cuda_cu
      "${FBGEMM_GENAI_SRCS}/cutlass_extensions/*.cu"
      "${FBGEMM_GENAI_SRCS}/cutlass_extensions/**/*.cu")

@@ -291,6 +291,7 @@ IF(USE_FBGEMM_GENAI)

    set(fbgemm_genai_cuh
      "${FBGEMM_GENAI_SRCS}/cutlass_extensions/mx8mx8bf16_grouped/"
      "${FBGEMM_GENAI_SRCS}/cutlass_extensions/f4f4bf16_grouped/"
      "${FBGEMM_GENAI_SRCS}/"
    )
@@ -208,6 +208,48 @@ _f8_f8_bf16_rowwise_grouped_mm(
#endif
}

Tensor&
_f4_f4_bf16_grouped_mm_fbgemm(
    const Tensor& mat_a,
    const Tensor& mat_b,
    const Tensor& scale_a,
    const Tensor& global_scale_a,
    const Tensor& scale_b,
    const Tensor& global_scale_b,
    const std::optional<Tensor>& offs,
    const std::optional<Tensor>& bias,
    Tensor& out) {
#if !defined(USE_ROCM) && defined(USE_FBGEMM_GENAI)
  // Typing checks
  TORCH_CHECK_VALUE(mat_a.scalar_type() == at::kFloat4_e2m1fn_x2,
      "mat_a must be Float4_e2n1fn_2, got: ", mat_a.scalar_type());
  TORCH_CHECK_VALUE(mat_b.scalar_type() == at::kFloat4_e2m1fn_x2,
      "mat_b must be Float4_e2n1fn_2, got: ", mat_b.scalar_type());
  TORCH_CHECK_VALUE(scale_a.scalar_type() == at::kFloat8_e4m3fn,
      "scale_a must be Float8_e4m3fn, got: ", scale_a.scalar_type());
  TORCH_CHECK_VALUE(scale_b.scalar_type() == at::kFloat8_e4m3fn,
      "scale_b must be Float8_e4m3fn, got: ", scale_b.scalar_type());
  TORCH_CHECK_VALUE(global_scale_a.scalar_type() == at::kFloat,
      "global_scale_a must be Float, got: ", global_scale_a.scalar_type());
  TORCH_CHECK_VALUE(global_scale_b.scalar_type() == at::kFloat,
      "global_scale_b must be Float, got: ", global_scale_b.scalar_type());

  auto o = fbgemm_gpu::f4f4bf16_grouped_mm(
      mat_a,
      mat_b,
      scale_a,
      scale_b,
      offs.value(),
      out,
      global_scale_a.mul(global_scale_b)
  );
#else
  TORCH_CHECK_NOT_IMPLEMENTED(false, "nvfp4 grouped gemm is not supported without USE_FBGEMM_GENAI, and only for CUDA")
#endif

  return out;
}

void _check_scales_fp8_rowwise(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx, const int scale_multiplier=1) {
  // Checks scales for 2d or 3d target tensors (`mat`).
  if (mat.dim() == 2) {
@@ -245,7 +287,15 @@ void _check_scales_fp8_rowwise(const Tensor& mat, const Tensor& scale, const int
  }
}

void _check_scales_mxfp8(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx) {
void _check_scales_blocked(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx) {
  // if {mx,nv}fp4, will need to modify K later
  bool is_fp4 = (mat.scalar_type() == kFloat4_e2m1fn_x2);
  int blocksize = 32;
  // check for nvfp4 vs. mxfp4 to fix blocksize
  if (is_fp4 && scale.scalar_type() == kFloat8_e4m3fn) {
    blocksize = 16;
  }

  // Checks scales for 2d or 3d target tensors (`mat`).
  if (mat.dim() == 2) {
    // For MXFP8, 2d tensors have variable size groups represented as subtensors,
@@ -253,17 +303,19 @@ void _check_scales_mxfp8(const Tensor& mat, const Tensor& scale, const int dim,
    // so we can't check the scale sizes without doing a d2h sync to get the group sizes here.
    TORCH_CHECK(
        scale.dim() == mat.dim(),
        "for mxfp8, scale must have same number of dimensions as parent tensor, but got mat.dim() = ", mat.dim(), " and scale.dim() = ", scale.dim(), " for arg ", arg_idx);
        "for block-scaled, scale must have same number of dimensions as parent tensor, but got mat.dim() = ", mat.dim(),
        " and scale.dim() = ", scale.dim(), " for arg ", arg_idx
    );

    // LHS mat shape (M, total_K) -> scale shape (rounded_up(M, 128), rounded_up_per_group(K/32, 4))
    // RHS mat shape (total_K, N) -> scale shape (rounded_up(N, 128), rounded_up_per_group(K/32, 4))
    // LHS mat shape (M, total_K) -> scale shape (rounded_up(M, 128), rounded_up_per_group(K/blocksize, 4))
    // RHS mat shape (total_K, N) -> scale shape (rounded_up(N, 128), rounded_up_per_group(K/blocksize, 4))
    // * weight is transposed prior to the call, scale stays non-transposed.
    bool LHS = arg_idx == 0;
    int scale_dim_to_check = 0;
    int mat_dim_to_check = LHS ? 0 : 1;
    TORCH_CHECK(
        scale.size(scale_dim_to_check) >= mat.size(mat_dim_to_check),
        "for mxfp8, arg ", arg_idx, " tensor shape (", mat.size(0), ", ", mat.size(1), ") ",
        "for block-scaled, arg ", arg_idx, " tensor shape (", mat.size(0), ", ", mat.size(1), ") ",
        "must have scale.shape[", scale_dim_to_check, "] >= ", mat.size(mat_dim_to_check), " but got scale.shape=(", scale.size(0), ", ", scale.size(1), ")");
  } else {
    // For MXFP8, 3d tensors have static group sizes (stack of 2d tensors),
@@ -273,32 +325,40 @@ void _check_scales_mxfp8(const Tensor& mat, const Tensor& scale, const int dim,
    };

    // TODO: this is for 3d tensor in 2d-3d case specifically.
    // We'll need to support 3d-3d and 3d-2d cases once mxfp8 grouped gemm supports them.
    // We'll need to support 3d-3d and 3d-2d cases once mxfp8/nvfp4 grouped gemm supports them.
    int64_t G = mat.size(0);
    int64_t K = mat.size(1);
    if (is_fp4) {
      // FP4 packs 2 values into a single 8b word - the "real" K is 2x the
      // reported K. Reverse that adjustment.
      const int fp4_elems_per_byte = 2;
      K *= fp4_elems_per_byte;
    }
    int64_t N = mat.size(2);
    int64_t blocked_scale_K = round_up(K/32, 4);
    int64_t blocked_scale_K = round_up(K/blocksize, 4);
    int64_t blocked_scale_N = round_up(N, 128);

    // fbgemm expects stack of flattened blocked scales for 3d tensor, shape (G, blocked_scale_K * blocked_scale_N).
    TORCH_CHECK(
        scale.dim() == mat.dim() - 1,
        "for mxfp8 2d-3d grouped GEMM, the 3d tensor of shape (G,K,N) must have a 2d scale of shape (G, blocked_scale_K * blocked_scale_N), but scale is ", scale.dim(), "D for arg ", arg_idx
        "for block-scaled 2d-3d grouped GEMM, the 3d tensor of shape (G,K,N) must have a 2d scale of shape (G, blocked_scale_K * blocked_scale_N),",
        "but scale is ", scale.dim(), "D for arg ", arg_idx
    );
    TORCH_CHECK(
        scale.size(0) == G && scale.size(1) == blocked_scale_K * blocked_scale_N,
        "for mxfp8, the tensor shape (", G, ", ", K, ", ", N, ") must have scale shape (", G, ",", blocked_scale_K, ",", blocked_scale_N, ") for arg ", arg_idx
        "for block-scaled grouped GEMM, the tensor shape (", G, ", ", K, ", ", N, ") must have scale shape (", G, ",", blocked_scale_K, ",", blocked_scale_N, ")",
        " for arg ", arg_idx, ", got: ", scale.size(0), ", ", scale.size(1)
    );
  }
}

void check_scale(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx, const int scale_multiplier=1) {
  bool using_fp8_rowwise = scale.scalar_type() == kFloat;
  bool using_mxfp8 = scale.scalar_type() == at::kFloat8_e8m0fnu;
  bool using_mx = scale.scalar_type() == at::kFloat8_e8m0fnu;
  if (using_fp8_rowwise) {
    _check_scales_fp8_rowwise(mat, scale, dim, arg_idx, scale_multiplier);
  } else if (using_mxfp8) {
    _check_scales_mxfp8(mat, scale, dim, arg_idx);
  } else if (using_mx) {
    _check_scales_blocked(mat, scale, dim, arg_idx);
  } else {
    TORCH_CHECK(false, "scale must be float32 or float8_e8m0fnu, but got ", scale.dtype());
  }
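The renamed `_check_scales_blocked` folds the former mxfp8-only check and the new nvfp4 path into one rule: the block size is 16 when the data is fp4 and the scales are e4m3 (nvfp4) and 32 otherwise, and fp4's two-values-per-byte packing doubles the effective K. A small illustrative sketch of the resulting scale-shape rule for the 3d (G, K, N) case; this is plain Python, not PyTorch API, with the constants taken from the checks above:

```python
def round_up(x: int, y: int) -> int:
    return ((x + y - 1) // y) * y

def expected_blocked_scale_shape(G: int, K: int, N: int, is_fp4: bool, scale_is_e4m3: bool):
    # nvfp4 uses 16-element blocks with e4m3 scales; mxfp8/mxfp4 use 32-element blocks.
    blocksize = 16 if (is_fp4 and scale_is_e4m3) else 32
    if is_fp4:
        K *= 2  # fp4 packs two values per byte, so the reported K is halved
    blocked_scale_K = round_up(K // blocksize, 4)
    blocked_scale_N = round_up(N, 128)
    # fbgemm expects a stack of flattened blocked scales: (G, blocked_scale_K * blocked_scale_N)
    return (G, blocked_scale_K * blocked_scale_N)

# e.g. a (4, 2048, 8192) nvfp4 weight -> (4, round_up(4096 // 16, 4) * round_up(8192, 128))
print(expected_blocked_scale_shape(4, 2048, 8192, is_fp4=True, scale_is_e4m3=True))
```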
@@ -411,9 +471,10 @@ namespace {

using acceptance_fn = std::function<bool(c10::ScalarType, std::vector<ScalingType>&, ArrayRef<Tensor>&, c10::ScalarType, std::vector<ScalingType>&, ArrayRef<Tensor>&)>;

std::array<std::tuple<std::string, acceptance_fn, ScaledGemmImplementation>, 2> scale_grouped_kernel_dispatch = {{
std::array<std::tuple<std::string, acceptance_fn, ScaledGemmImplementation>, 3> scale_grouped_kernel_dispatch = {{
  { "rowwise_rowwise", scaled_blas::check_rowwise_recipe, ScaledGemmImplementation::ROWWISE_ROWWISE},
  { "mxfp8_mxfp8", scaled_blas::check_mxfp8_recipe, ScaledGemmImplementation::MXFP8_MXFP8}}};
  { "mxfp8_mxfp8", scaled_blas::check_mxfp8_recipe, ScaledGemmImplementation::MXFP8_MXFP8},
  { "nvfp4_nvfp4", scaled_blas::check_nvfp4_recipe, ScaledGemmImplementation::NVFP4_NVFP4}}};

} // anonymous namespace

@@ -525,8 +586,9 @@ _scaled_grouped_mm_cuda_v2(
          out);
    }
    case ScaledGemmImplementation::MXFP8_MXFP8: {
      _check_scales_mxfp8(mat_a, scale_a[0], 0 /* dim */, 0 /* arg_idx */);
      _check_scales_mxfp8(mat_b, scale_b[0], 1 /* dim */, 1 /* arg_idx */);
      // scale shape checks
      _check_scales_blocked(mat_a, scale_a[0], 0 /* dim */, 0 /* arg_idx */);
      _check_scales_blocked(mat_b, scale_b[0], 1 /* dim */, 1 /* arg_idx */);
      return _mx8_mx8_bf16_grouped_mm_fbgemm(
          mat_a,
          mat_b,
@@ -537,6 +599,21 @@ _scaled_grouped_mm_cuda_v2(
          offs.value(),
          out);
    }
    case ScaledGemmImplementation::NVFP4_NVFP4: {
      // scale shape checks
      _check_scales_blocked(mat_a, scale_a[0], 0 /* dim */, 0 /* arg_idx */);
      _check_scales_blocked(mat_b, scale_b[0], 1 /* dim */, 1 /* arg_idx */);
      return _f4_f4_bf16_grouped_mm_fbgemm(
          mat_a,
          mat_b,
          scale_a[0], /* block-scale A */
          scale_a[1], /* global-scale A */
          scale_b[0], /* block-scale B */
          scale_b[1], /* global-scale B */
          offs.value(),
          std::nullopt, /* bias */
          out);
    }
    default:
      TORCH_CHECK_NOT_IMPLEMENTED(false,
          "_scaled_grouped_mm_cuda_v2 is in an inconsistent state - should never reach here");
@@ -57,6 +57,7 @@ Tensor& random_mps_impl(Tensor& self,
  if (self.numel() == 0) {
    return self;
  }
  at::assert_no_internal_overlap(self);
  // MPS random is broken for 5D+ tensors, see https://github.com/pytorch/pytorch/issues/147624
  const auto need_reshape = self.ndimension() > 4;
  auto mps_gen = get_generator_or_default<MPSGeneratorImpl>(gen, at::mps::detail::getDefaultMPSGenerator());
@@ -153,8 +154,16 @@ Tensor& random_mps_impl(Tensor& self,
      feeds[meanPlaceholder.getMPSGraphTensor()] = meanPlaceholder.getMPSGraphTensorData();
    }

    Placeholder outputPlaceholder = Placeholder(cachedGraph->resultTensor, self);
    // Handle non-contiguous output tensors by creating a contiguous temporary
    const auto needs_gather = needsGather(self);
    Tensor self_ = needs_gather ? at::empty_like(self, MemoryFormat::Contiguous) : self;
    Placeholder outputPlaceholder = Placeholder(cachedGraph->resultTensor, self_);
    runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder);

    // Copy results back to original non-contiguous output
    if (needs_gather) {
      self.copy_(self_);
    }
  }

  return self;
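A minimal Python-level check of the behavior this gather/copy-back path enables (requires an MPS build; the new test_random_ops_noncontiguous test later in this compare covers the full set of in-place RNG ops and the referenced issues #165257 and #124029):

```python
import torch

t = torch.zeros(50, 50, device="mps").T   # transposed view -> non-contiguous
assert not t.is_contiguous()
t.normal_(0, 1)                           # now written via a contiguous temporary and copied back
assert t.abs().max().item() > 0
```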
@@ -1358,9 +1358,15 @@ if(BUILD_TEST)
    )
  else()
    add_subdirectory(${TORCH_ROOT}/test/cpp/jit ${CMAKE_BINARY_DIR}/test_jit)
    add_subdirectory(${TORCH_ROOT}/test/cpp/lazy ${CMAKE_BINARY_DIR}/test_lazy)
    # NativeRT is disabled
    # add_subdirectory(${TORCH_ROOT}/test/cpp/nativert ${CMAKE_BINARY_DIR}/test_nativert)
    add_subdirectory(${TORCH_ROOT}/test/inductor ${CMAKE_BINARY_DIR}/test_inductor)
    add_subdirectory(${TORCH_ROOT}/test/cpp/aoti_abi_check ${CMAKE_BINARY_DIR}/test_aoti_abi_check)
    if(BUILD_AOT_INDUCTOR_TEST)
      add_subdirectory(${TORCH_ROOT}/test/cpp/aoti_inference ${CMAKE_BINARY_DIR}/test_aoti_inference)
    endif()

    if(USE_DISTRIBUTED)
      add_subdirectory(${TORCH_ROOT}/test/cpp/c10d ${CMAKE_BINARY_DIR}/test_cpp_c10d)
      if(NOT WIN32)
@@ -1378,16 +1384,6 @@ if(BUILD_TEST)
        ${CMAKE_BINARY_DIR}/test_mobile_nnc
      )
    endif()
    add_subdirectory(${TORCH_ROOT}/test/cpp/lazy
      ${CMAKE_BINARY_DIR}/test_lazy)
  endif()
  if(BUILD_AOT_INDUCTOR_TEST)
    add_subdirectory(
      ${TORCH_ROOT}/test/cpp/aoti_abi_check
      ${CMAKE_BINARY_DIR}/test_aoti_abi_check)
    add_subdirectory(
      ${TORCH_ROOT}/test/cpp/aoti_inference
      ${CMAKE_BINARY_DIR}/test_aoti_inference)
  endif()
endif()

@@ -1,3 +1,8 @@
# Skip on windows
if(WIN32)
  return()
endif()

set(AOTI_ABI_CHECK_TEST_ROOT ${TORCH_ROOT}/test/cpp/aoti_abi_check)

# Build the cpp gtest binary containing the cpp-only tests.
@@ -30,8 +35,15 @@ target_compile_definitions(test_aoti_abi_check PRIVATE USE_GTEST)

# WARNING: DO NOT LINK torch!!!
# The purpose is to check if the used aten/c10 headers are written in a header-only way
target_link_libraries(test_aoti_abi_check PRIVATE gtest_main)
target_link_libraries(test_aoti_abi_check PRIVATE gtest_main sleef)
target_include_directories(test_aoti_abi_check PRIVATE ${ATen_CPU_INCLUDE})
if(NOT USE_SYSTEM_SLEEF)
  target_include_directories(test_aoti_abi_check PRIVATE ${CMAKE_BINARY_DIR}/include)
endif()

# Disable unused-variable warnings for variables that are only used to test compilation
target_compile_options_if_supported(test_aoti_abi_check -Wno-unused-variable)
target_compile_options_if_supported(test_aoti_abi_check -Wno-unused-but-set-variable)

foreach(test_src ${AOTI_ABI_CHECK_VEC_TEST_SRCS})
  foreach(i RANGE ${NUM_CPU_CAPABILITY_NAMES})
@@ -41,12 +53,17 @@ foreach(test_src ${AOTI_ABI_CHECK_VEC_TEST_SRCS})
    separate_arguments(FLAGS UNIX_COMMAND "${FLAGS}")
    add_executable(${test_name}_${CPU_CAPABILITY} "${test_src}")

    target_link_libraries(${test_name}_${CPU_CAPABILITY} PRIVATE gtest_main)
    target_link_libraries(${test_name}_${CPU_CAPABILITY} PRIVATE gtest_main sleef)
    target_include_directories(${test_name}_${CPU_CAPABILITY} PRIVATE ${ATen_CPU_INCLUDE})
    if(NOT USE_SYSTEM_SLEEF)
      target_include_directories(${test_name}_${CPU_CAPABILITY} PRIVATE ${CMAKE_BINARY_DIR}/include)
    endif()

    # Define CPU_CAPABILITY and CPU_CAPABILITY_XXX macros for conditional compilation
    target_compile_definitions(${test_name}_${CPU_CAPABILITY} PRIVATE CPU_CAPABILITY=${CPU_CAPABILITY} CPU_CAPABILITY_${CPU_CAPABILITY})
    target_compile_options(${test_name}_${CPU_CAPABILITY} PRIVATE ${FLAGS})
    target_compile_options_if_supported(${test_name}_${CPU_CAPABILITY} -Wno-unused-variable)
    target_compile_options_if_supported(${test_name}_${CPU_CAPABILITY} -Wno-unused-but-set-variable)
  endforeach()
endforeach()

@@ -2,10 +2,27 @@

#include <ATen/cpu/vec/vec.h>

#include <iostream>
namespace torch {
namespace aot_inductor {

template <typename T>
void ExpectVecEqual(
    const at::vec::Vectorized<T>& expected,
    const at::vec::Vectorized<T>& actual) {
  using Vec = at::vec::Vectorized<T>;
  // Have to use std::vector for comparison because at::vec::Vectorized doesn't
  // support operator[] on aarch64
  std::vector<T> expected_data(Vec::size());
  std::vector<T> actual_data(Vec::size());

  expected.store(expected_data.data());
  actual.store(actual_data.data());

  for (int i = 0; i < Vec::size(); i++) {
    EXPECT_EQ(expected_data[i], actual_data[i]);
  }
}

TEST(TestVec, TestAdd) {
  using Vec = at::vec::Vectorized<int>;
  std::vector<int> a(1024, 1);
@@ -16,9 +33,7 @@ TEST(TestVec, TestAdd) {
  std::vector<int> expected(1024, 3);
  Vec expected_vec = Vec::loadu(expected.data());

  for (int i = 0; i < Vec::size(); i++) {
    EXPECT_EQ(expected_vec[i], actual_vec[i]);
  }
  ExpectVecEqual(expected_vec, actual_vec);
}

TEST(TestVec, TestMax) {

@@ -30,9 +45,7 @@ TEST(TestVec, TestMax) {
  Vec actual_vec = at::vec::maximum(a_vec, b_vec);
  Vec expected_vec = b_vec;

  for (int i = 0; i < Vec::size(); i++) {
    EXPECT_EQ(expected_vec[i], actual_vec[i]);
  }
  ExpectVecEqual(expected_vec, actual_vec);
}

TEST(TestVec, TestMin) {

@@ -44,9 +57,7 @@ TEST(TestVec, TestMin) {
  Vec actual_vec = at::vec::minimum(a_vec, b_vec);
  Vec expected_vec = a_vec;

  for (int i = 0; i < Vec::size(); i++) {
    EXPECT_EQ(expected_vec[i], actual_vec[i]);
  }
  ExpectVecEqual(expected_vec, actual_vec);
}

TEST(TestVec, TestConvert) {

@@ -58,9 +69,7 @@ TEST(TestVec, TestConvert) {
  auto actual_vec = at::vec::convert<float>(a_vec);
  auto expected_vec = b_vec;

  for (int i = 0; i < at::vec::Vectorized<int>::size(); i++) {
    EXPECT_EQ(expected_vec[i], actual_vec[i]);
  }
  ExpectVecEqual(expected_vec, actual_vec);
}

TEST(TestVec, TestClampMin) {

@@ -72,9 +81,7 @@ TEST(TestVec, TestClampMin) {
  Vec actual_vec = at::vec::clamp_min(a_vec, min_vec);
  Vec expected_vec = min_vec;

  for (int i = 0; i < Vec::size(); i++) {
    EXPECT_EQ(expected_vec[i], actual_vec[i]);
  }
  ExpectVecEqual(expected_vec, actual_vec);
}

} // namespace aot_inductor

@@ -1,4 +1,3 @@

set(AOT_INDUCTOR_TEST_ROOT ${TORCH_ROOT}/test/cpp/aoti_inference)

# Build custom TorchScript op for AOTInductor
@@ -8,27 +7,12 @@ set_target_properties(aoti_custom_class PROPERTIES
if(USE_CUDA)
  target_compile_definitions(aoti_custom_class PRIVATE USE_CUDA)
elseif(USE_ROCM)
  target_compile_definitions(aoti_custom_class PRIVATE USE_ROCM)
  target_compile_definitions(aoti_custom_class PRIVATE USE_ROCM)
endif()

# Link against LibTorch
target_link_libraries(aoti_custom_class torch)

# the custom command that generates the TorchScript module
add_custom_command(
  OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/script_data.pt
    ${CMAKE_CURRENT_BINARY_DIR}/script_model_cpu.pt
    ${CMAKE_CURRENT_BINARY_DIR}/script_model_cuda.pt
  # This script requires the torch package to be installed.
  COMMAND python ${AOT_INDUCTOR_TEST_ROOT}/compile_model.py
  DEPENDS torch torch_python aoti_custom_class ${AOT_INDUCTOR_TEST_ROOT}/compile_model.py
)
add_custom_target(aoti_script_model ALL
  DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/script_data.pt
  DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/script_model_cpu.pt
  DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/script_model_cuda.pt
)
add_dependencies(aoti_script_model aoti_custom_class)

# Build the cpp gtest binary containing the cpp-only tests.
set(INDUCTOR_TEST_SRCS
  ${AOT_INDUCTOR_TEST_ROOT}/test.cpp
@@ -37,23 +21,12 @@ set(INDUCTOR_TEST_SRCS
add_executable(test_aoti_inference
  ${TORCH_ROOT}/test/cpp/common/main.cpp
  ${INDUCTOR_TEST_SRCS}
  data.pt
  script_data.pt
  script_model_cpu.pt
  script_model_cuda.pt
)
add_dependencies(test_aoti_inference aoti_custom_class aoti_script_model)
add_dependencies(test_aoti_inference aoti_custom_class)

# TODO temporary until we can delete the old gtest polyfills.
target_compile_definitions(test_aoti_inference PRIVATE USE_GTEST)

# Define a custom command to generate the library
add_custom_command(
  OUTPUT data.pt
  COMMAND python ${AOT_INDUCTOR_TEST_ROOT}/test.py
  DEPENDS ${AOT_INDUCTOR_TEST_ROOT}/test.py
)

target_link_libraries(test_aoti_inference PRIVATE
  torch
  gtest_main
@@ -71,6 +44,10 @@ target_compile_definitions(test_aoti_inference PRIVATE
  CMAKE_CURRENT_BINARY_DIR=${CMAKE_CURRENT_BINARY_DIR}
)

target_compile_options_if_supported(test_aoti_inference -Wno-unused-variable)
target_compile_options_if_supported(test_aoti_inference -Wno-unused-but-set-variable)
target_compile_options_if_supported(test_aoti_inference -Wno-unused-function)

if(INSTALL_TEST)
  install(TARGETS test_aoti_inference DESTINATION bin)
  # Install PDB files for MSVC builds
@@ -2,7 +2,9 @@
#include <gtest/gtest.h>
#include <atomic>
#include <condition_variable>
#include <cstdlib>
#include <filesystem>
#include <fstream>
#include <functional>
#include <mutex>
#include <queue>
@@ -28,6 +30,64 @@

namespace {

// Function to check if test data files exist and are valid
bool testDataFilesExist() {
  std::string bindir = STRINGIZE(CMAKE_CURRENT_BINARY_DIR);
  std::array<std::string, 4> required_files = {
      "data.pt",
      "script_data.pt",
      "script_model_cpu.pt",
      "script_model_cuda.pt"};

  for (const auto& filename : required_files) {
    std::string filepath = bindir + "/" + filename;
    std::ifstream file(filepath);
    if (!file.good()) {
      return false;
    }
  }
  return true;
}

// Function to ensure test data files are generated at runtime
void ensureTestDataGenerated() {
  static std::once_flag generated_flag;
  std::call_once(generated_flag, []() {
    // Only generate if files don't exist or are placeholders
    if (testDataFilesExist()) {
      return;
    }

    std::string bindir = STRINGIZE(CMAKE_CURRENT_BINARY_DIR);

    // Calculate path to source directory: build/test_aoti_inference -> build ->
    // pytorch
    std::string pytorch_root = bindir.substr(0, bindir.find_last_of("/"));
    pytorch_root = pytorch_root.substr(0, pytorch_root.find_last_of("/"));
    std::string source_dir = pytorch_root + "/test/cpp/aoti_inference";

    // Generate test data files (data.pt, etc.) by running test.py directly
    std::string test_script = source_dir + "/test.py";
    std::string test_data_cmd = "cd " + bindir + " && python " + test_script;
    std::cout << "Generating test data: " << test_data_cmd << std::endl;
    int result1 = std::system(test_data_cmd.c_str());
    if (result1 != 0) {
      std::cerr << "Warning: Test data generation failed with code " << result1
                << std::endl;
    }

    // Generate model files (script_*.pt) by running compile_model.py directly
    std::string compile_script = source_dir + "/compile_model.py";
    std::string models_cmd = "cd " + bindir + " && python " + compile_script;
    std::cout << "Generating model files: " << models_cmd << std::endl;
    int result2 = std::system(models_cmd.c_str());
    if (result2 != 0) {
      std::cerr << "Warning: Model generation failed with code " << result2
                << std::endl;
    }
  });
}

const std::unordered_map<std::string, at::Tensor> derefTensorConstantMap(
    torch::inductor::TensorConstantMap tensor_constant_map) {
  std::unordered_map<std::string, at::Tensor> ret;
@@ -855,7 +915,6 @@ void test_aoti_free_buffer(bool use_runtime_constant_folding) {
  }
}

#if defined(USE_CUDA) || defined(USE_ROCM)
void test_cuda_alloc_test() {
  torch::NoGradGuard no_grad;

@@ -895,8 +954,8 @@ void test_cuda_alloc_test() {
      runner->run(data_loader.attr(inputs_attr.c_str()).toTensorList().vec());
  ASSERT_TRUE(torch::allclose(ref_output_tensors[0], actual_output_tensors[0]));
}
#endif

#ifdef USE_CUDA
class ThreadPool {
 private:
  struct Task {
@@ -1037,86 +1096,96 @@ void test_multi_cuda_streams(const std::string& device) {
    ASSERT_TRUE(torch::allclose(ref_output_tensors[0], all_outputs[i][0]));
  }
}
#endif
#endif // USE_CUDA
#endif // USE_CUDA || USE_ROCM
} // namespace

namespace torch::aot_inductor {

TEST(AotInductorTest, BasicTestCpu) {
// Test fixture that ensures test data is generated once for all tests
class AotInductorTest : public ::testing::Test {
 public:
  // This runs once before all tests in this test suite
  static void SetUpTestSuite() {
    ensureTestDataGenerated();
  }
};

TEST_F(AotInductorTest, BasicTestCpu) {
  test_aoti("cpu", false);
}

TEST(AotInductorTest, BasicScriptTestCpu) {
TEST_F(AotInductorTest, BasicScriptTestCpu) {
  test_aoti_script("cpu");
}

TEST(AotInductorTest, BasicPackageLoaderTestCpu) {
TEST_F(AotInductorTest, BasicPackageLoaderTestCpu) {
  test_aoti_package_loader("cpu", false);
}

TEST(AotInductorTest, ExtractConstantsMapCpu) {
TEST_F(AotInductorTest, ExtractConstantsMapCpu) {
  test_aoti_extract_constants_map("cpu");
}

#ifdef USE_CUDA
TEST(AotInductorTest, BasicTestCuda) {
TEST_F(AotInductorTest, BasicTestCuda) {
  test_aoti("cuda", true);
  test_aoti("cuda", false);
}

TEST(AotInductorTest, BasicScriptTestCuda) {
TEST_F(AotInductorTest, BasicScriptTestCuda) {
  test_aoti_script("cuda");
}

TEST(AotInductorTest, BasicPackageLoaderTestCuda) {
TEST_F(AotInductorTest, BasicPackageLoaderTestCuda) {
  test_aoti_package_loader("cuda", false);
}

TEST(AotInductorTest, BasicPackageLoaderTestMultiGpuCuda) {
TEST_F(AotInductorTest, BasicPackageLoaderTestMultiGpuCuda) {
  test_aoti_package_loader_multi_gpu("cuda", false);
}

TEST(AotInductorTest, UpdateUserManagedConstantsCuda) {
TEST_F(AotInductorTest, UpdateUserManagedConstantsCuda) {
  test_aoti_user_managed_buffer();
}

TEST(AotInductorTest, RuntimeUpdateConstantsCuda) {
TEST_F(AotInductorTest, RuntimeUpdateConstantsCuda) {
  test_aoti_constants_update("cuda", true);
}

TEST(AotInductorTest, UpdateConstantsCuda) {
TEST_F(AotInductorTest, UpdateConstantsCuda) {
  test_aoti_constants_update("cuda", false);
}

TEST(AotInductorTest, ExtractConstantsMapCuda) {
TEST_F(AotInductorTest, ExtractConstantsMapCuda) {
  test_aoti_extract_constants_map("cuda");
}

TEST(AotInductorTest, RuntimeUpdateInactiveConstantsCuda) {
TEST_F(AotInductorTest, RuntimeUpdateInactiveConstantsCuda) {
  test_aoti_double_buffering("cuda", true);
}

TEST(AotInductorTest, UpdateInactiveConstantsCuda) {
TEST_F(AotInductorTest, UpdateInactiveConstantsCuda) {
  test_aoti_double_buffering("cuda", false);
}

TEST(AotInductorTest, UpdateInactiveConstantsWithTensorConstantsCuda) {
TEST_F(AotInductorTest, UpdateInactiveConstantsWithTensorConstantsCuda) {
  test_aoti_double_buffering_with_tensor_constants();
}

TEST(AotInductorTest, FreeInactiveConstantBufferCuda) {
TEST_F(AotInductorTest, FreeInactiveConstantBufferCuda) {
  test_aoti_free_buffer(false);
}

TEST(AotInductorTest, FreeInactiveConstantBufferRuntimeConstantFoldingCuda) {
TEST_F(AotInductorTest, FreeInactiveConstantBufferRuntimeConstantFoldingCuda) {
  test_aoti_free_buffer(true);
}

TEST(AotInductorTest, MultiStreamTestCuda) {
TEST_F(AotInductorTest, MultiStreamTestCuda) {
  test_multi_cuda_streams("cuda");
}

TEST(AotInductorTest, CudaAllocTestCuda) {
TEST_F(AotInductorTest, CudaAllocTestCuda) {
  test_cuda_alloc_test();
}
#endif
@@ -892,10 +892,16 @@ fn(torch.randn(5))
        os.remove(
            file_path
        )  # Delete temp file manually, due to setup NamedTemporaryFile as delete=False.
        self.assertEqual(  # process wrap difference: /r/n on Windows, /n on posix.
            empty_line_normalizer(lines),
            empty_line_normalizer(stderr.decode("utf-8")),
        )
        orig_maxDiff = unittest.TestCase.maxDiff
        unittest.TestCase.maxDiff = None
        try:
            self.assertEqual(  # process wrap difference: /r/n on Windows, /n on posix.
                empty_line_normalizer(lines),
                empty_line_normalizer(stderr.decode("utf-8")),
            )
        except Exception:
            unittest.TestCase.maxDiff = orig_maxDiff
            raise

    @make_settings_test("torch._dynamo.eval_frame")
    def test_log_traced_frames(self, records):

@@ -1000,6 +1000,18 @@ class ReproTests(torch._dynamo.test_case.TestCase):
        self.exit_stack.close()
        super().tearDown()

    def test_compiled_module_truthiness(self):
        # Test with empty ModuleList
        original_empty = nn.ModuleList()
        compiled_empty = torch.compile(original_empty)
        self.assertEqual(bool(original_empty), bool(compiled_empty))
        self.assertFalse(bool(compiled_empty))
        # Test with non-empty ModuleList
        original_filled = nn.ModuleList([nn.Linear(10, 5)])
        compiled_filled = torch.compile(original_filled)
        self.assertEqual(bool(original_filled), bool(compiled_filled))
        self.assertTrue(bool(compiled_filled))

    def guard_manager_clone_hook_fn(self, guard_manager_wrapper, f_locals, builder):
        root = guard_manager_wrapper.root
        cloned_root = root.clone_manager(lambda x: True)
@@ -14269,6 +14269,22 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
        self.assertTrue("'enable_fp_fusion': False" in code)
        torch.testing.assert_close(out, fn(a, b), atol=0, rtol=0)

    @requires_cuda_and_triton
    @config.patch(runtime_triton_nan_asserts=True)
    def test_nan_assert_inside_triton_kernel(self):
        def fn(x):
            x = x - 1
            # Uncomment the following line can trigger the failure of
            # the device size assertion
            # x = torch.log(x)
            return torch.where(x.isnan(), 3.14, x)

        compiled = torch.compile(fn)
        x = torch.randn(4096, device=GPU_TYPE)
        out, (code,) = run_and_get_code(compiled, x)
        self.assertTrue("'NaN or Inf found'" in code)
        torch.testing.assert_close(out, fn(x))

    @skip_if_cpp_wrapper("skip cpp wrapper")
    @requires_cuda_and_triton
    def test_repeat_interleave_decomposition_has_clamp(self):
@@ -12,7 +12,6 @@ from torch.testing._internal.common_device_type import (
    dtypes,
    dtypesIfMPS,
    expectedFailureMPS,
    expectedFailureMPSPre15,
    expectedFailureXLA,
    instantiate_device_type_tests,
)
@@ -173,7 +172,6 @@ class TestDropoutNNDeviceType(NNTestCase):
        else:
            self.assertNotEqual(permuted_inp, out)

    @expectedFailureMPSPre15
    def test_Dropout(self, device):
        input = torch.empty(1000)
        self._test_dropout(nn.Dropout, device, input)

@@ -529,7 +529,7 @@ class TestProfiler(TestCase):
                found_mm = True
            if "gemm" in e.name.lower() or "Cijk" in e.name:
                found_gemm = True
            if "memcpy" in e.name.lower():
            if "memcpy" in e.name.lower() or "__amd_rocclr_copyBuffer" in e.name:
                found_memcpy = True
        if use_cuda:
            self.assertTrue(found_gemm)

@@ -755,6 +755,8 @@ def run_test_retries(
                REPO_ROOT / ".pytest_cache/v/cache/stepcurrent" / stepcurrent_key
            ) as f:
                current_failure = f.read()
                if current_failure == "null":
                    current_failure = f"'{test_file}'"
        except FileNotFoundError:
            print_to_file(
                "No stepcurrent file found. Either pytest didn't get to run (e.g. import error)"
@@ -791,8 +793,6 @@ def run_test_retries(
            print_to_file("Retrying single test...")
        print_items = []  # do not continue printing them, massive waste of space

    if "null" in num_failures:
        num_failures[f"'{test_file}'"] = num_failures.pop("null")
    consistent_failures = [x[1:-1] for x in num_failures.keys() if num_failures[x] >= 3]
    flaky_failures = [x[1:-1] for x in num_failures.keys() if 0 < num_failures[x] < 3]
    if len(flaky_failures) > 0:

@@ -7846,6 +7846,45 @@ class TestMPS(TestCaseMPS):
        y = torch.normal(torch.zeros(shape, device="mps"), torch.ones(shape, device="mps"))
        self.assertNotEqual(y[0], y[1])

    def test_random_ops_noncontiguous(self):
        """Test random in-place operations on non-contiguous tensors.

        All random in-place operations should work on non-contiguous tensors.
        See issues #165257 and #124029.
        """
        # Test each random in-place operation
        ops = [
            ("normal_", lambda t: t.normal_(0, 1)),
            ("uniform_", lambda t: t.uniform_(0, 1)),
            ("exponential_", lambda t: t.exponential_(1.0)),
            ("bernoulli_", lambda t: t.bernoulli_(0.5)),
            ("random_", lambda t: t.random_()),
            ("random_with_to", lambda t: t.random_(10)),
            ("random_with_range", lambda t: t.random_(0, 10)),
        ]

        for name, op_func in ops:
            with self.subTest(operation=name):
                # Create non-contiguous tensor via transpose
                t_mps = torch.zeros(50, 50, device='mps').T.clone()
                self.assertFalse(t_mps.is_contiguous(),
                                 f"{name}: tensor should be non-contiguous")

                # Apply operation
                op_func(t_mps)

                # Verify tensor was modified (not all zeros)
                max_val = t_mps.max().item()
                self.assertNotEqual(max_val, 0.0,
                                    f"{name}: operation failed to modify non-contiguous tensor")

        # Test rand_like specifically (issue #124029)
        t = torch.ones((3, 2, 2), device='mps').permute(2, 0, 1)
        self.assertFalse(t.is_contiguous(), "rand_like input should be non-contiguous")
        result = torch.rand_like(t)
        self.assertFalse(result.is_contiguous(), "rand_like result should be non-contiguous")
        self.assertNotEqual(result.max().item(), 0.0, "rand_like should generate non-zero values")

    # Test exponential
    @unittest.skip("This does not test anything")
    def test_exponential(self):
@@ -46,6 +46,7 @@ from torch.testing._internal.common_quantized import (
    _floatx_unpacked_to_f32,
    ceil_div, to_blocked,
    to_mxfp8,
    from_blocked_format,
    generate_jagged_offs,
)
@@ -462,6 +463,24 @@ def pack_uint4(uint8_data) -> torch.Tensor:
    uint8_data = uint8_data.contiguous().view(-1)
    return (uint8_data[1::2] << 4 | uint8_data[::2]).view(down_size(shape))

def unpack_uint4(uint8_data) -> torch.Tensor:
    # Take a packed uint8 tensor (i.e. nvfp4) and unpack into
    # a tensor twice as wide. Useful for dequant operations.
    shape = list(uint8_data.shape)
    # 2x packed elements -> single non-packed => adjust shape
    shape[-1] *= 2
    out = torch.empty(
        *shape,
        device=uint8_data.device,
        dtype=torch.uint8
    ).view(-1)

    uint8_data_as_uint8 = uint8_data.view(torch.uint8).view(-1)

    out[1::2] = uint8_data_as_uint8[:] >> 4
    out[::2] = uint8_data_as_uint8 & 15

    return out.view(shape)

def _bfloat16_to_float4_e2m1fn_x2(x):
    assert x.dtype == torch.bfloat16
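For orientation, the nibble layout used by `pack_uint4`/`unpack_uint4` above can be checked with a tiny round trip using only plain tensor ops (an illustrative sketch, not part of the diff):

```python
import torch

vals = torch.tensor([1, 2, 3, 4, 5, 6], dtype=torch.uint8)   # six 4-bit values, one per byte
packed = (vals[1::2] << 4) | vals[::2]                        # three bytes, two nibbles each
low, high = packed & 15, packed >> 4                          # unpack the two nibbles again
restored = torch.stack([low, high], dim=1).reshape(-1)        # back to [1, 2, 3, 4, 5, 6]
assert torch.equal(restored, vals)
```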
@@ -470,6 +489,119 @@ def _bfloat16_to_float4_e2m1fn_x2(x):
    x = x.view(torch.float4_e2m1fn_x2)
    return x

def _convert_to_nvfp4_with_hp_ref(t):
    # Convert a tensor to nvfp4, returning:
    #   t_hp          : reconstructed bf16 version of t_lp
    #   t_lp          : nvfp4 tensor (2x elements packed into uint8)
    #   t_scale       : e4m3 block-wise scaling factors (non-swizzled)
    #   t_global_scale: fp32 tensor-wise global scaling factor
    t_lp, t_scale, t_global_scale = data_to_nvfp4_with_global_scale(
        t,
        16,
    )
    t_hp = from_blocked_format(
        _floatx_unpacked_to_f32(
            unpack_uint4(t_lp),
            FP4_EBITS,
            FP4_MBITS),
        t_scale,
        blocksize=16) * t_global_scale

    return t_hp, t_lp, t_scale, t_global_scale

def _convert_to_mxfp8_with_hp_ref(t):
    # Convert a tensor to mxfp8, returning:
    #   t_hp   : reconstructed bf16 version of t_lp
    #   t_lp   : fp8_e4m3 tensor
    #   t_scale: fp8_e8m0 block-wise scaling factors (non-swizzled)
    t_scale, t_lp = to_mxfp8(t)
    t_hp = from_blocked_format(t_lp, t_scale, blocksize=32)

    return t_hp, t_lp, t_scale

def _2d_grouped_tensor_to_mxfp8_blocked_scaled(t, MN, G, offs, format='mxfp8'):
    # Convert scales to blocked format. either mxfp8 or nvfp4
    th_list = []
    t_list = []
    t_blocked_scale_list = []
    t_global_scale_list = []

    def round_up(x: int, y: int) -> int:
        return ((x + y - 1) // y) * y

    for group_idx in range(G):
        # to_mxfp8 per group
        prev_group_end_offset = (
            0 if group_idx == 0 else offs[group_idx - 1]
        )
        curr_group_end_offset = offs[group_idx]
        group_size = curr_group_end_offset - prev_group_end_offset
        if group_size > 0:
            t_slice = t[
                :, prev_group_end_offset:curr_group_end_offset
            ].contiguous()  # (M, K_group)
            if format == 'mxfp8':
                th_slice, tq_slice, t_scale_slice = _convert_to_mxfp8_with_hp_ref(t_slice)
            elif format == 'nvfp4':
                th_slice, tq_slice, t_scale_slice, tq_global = _convert_to_nvfp4_with_hp_ref(
                    t_slice,
                )
                t_global_scale_list.append(tq_global)
            else:
                raise ValueError(f'format must be mxfp8|nvfp4, got "{format}"')
            t_list.append(tq_slice)
            th_list.append(th_slice)

            # Convert scales to blocked format.
            t_scale_slice_blocked = to_blocked(
                t_scale_slice
            )  # (round_up(M, 128), round_up(K_group//32, 4))
            t_blocked_scale_list.append(t_scale_slice_blocked)

    # Assemble the full XQ and WQ
    tq = torch.cat(t_list, dim=1).contiguous()
    th = torch.cat(th_list, dim=1).contiguous()

    # Combine all XQ groups blocked scales into one tensor.
    t_blocked_scales = torch.cat(t_blocked_scale_list, dim=0)
    MN_rounded = round_up(MN, 128)
    t_blocked_scales = t_blocked_scales.reshape(MN_rounded, -1)

    # Global scales only exist for nvfp4
    t_global_scales = None
    if len(t_global_scale_list) > 0:
        t_global_scales = torch.stack(t_global_scale_list)

    return th, tq, t_blocked_scales, t_global_scales

def _build_scaled_grouped_mm_kwargs(scale_a, scale_b, offs, format):
    # Build some standard args that are wordy
    # Note: if/when ROCm support added, need to change swizzle handling
    kwargs = {
        'mxfp8': {
            'scale_a': scale_a,
            'scale_b': scale_b,
            'scale_recipe_a': ScalingType.BlockWise1x32,
            'scale_recipe_b': ScalingType.BlockWise1x32,
            'swizzle_a': SwizzleType.SWIZZLE_32_4_4,
            'swizzle_b': SwizzleType.SWIZZLE_32_4_4,
            'offs': offs,  # (G,)
            'out_dtype': torch.bfloat16,
            'wrap_v2': True,
        },
        'nvfp4': {
            'scale_a': scale_a,
            'scale_b': scale_b,
            'scale_recipe_a': [ScalingType.BlockWise1x16, ScalingType.TensorWise],
            'scale_recipe_b': [ScalingType.BlockWise1x16, ScalingType.TensorWise],
            'swizzle_a': SwizzleType.SWIZZLE_32_4_4,
            'swizzle_b': SwizzleType.SWIZZLE_32_4_4,
            'offs': offs,  # (G,)
            'out_dtype': torch.bfloat16,
            'wrap_v2': True,
        },
    }
    return kwargs[format]

class TestFP8Matmul(TestCase):

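The `_convert_to_*_with_hp_ref` helpers above lean on the internal test utilities imported earlier in this file. A minimal sketch of the mxfp8 round trip they perform, mirroring those call signatures (assumes a CUDA build; `to_mxfp8`, `to_blocked`, and `from_blocked_format` are internal testing helpers, not public API):

```python
import torch
from torch.testing._internal.common_quantized import from_blocked_format, to_blocked, to_mxfp8

t = torch.randn(256, 512, dtype=torch.bfloat16, device="cuda")
t_scale, t_lp = to_mxfp8(t)                              # e8m0 scales (one per 32 elements) + fp8 data
t_hp = from_blocked_format(t_lp, t_scale, blocksize=32)  # reconstructed bf16 used as the reference operand
t_scale_swizzled = to_blocked(t_scale)                   # blocked/swizzled layout the grouped GEMM expects
```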
@@ -526,13 +658,15 @@ class TestFP8Matmul(TestCase):
            out_fp8_s = scaled_mm_wrap(x, y, scale_a=scale_a, scale_b=scale_b)
            self.assertEqual(out_fp8, out_fp8_s)


    @unittest.skipIf(not PLATFORM_SUPPORTS_MXFP8_GROUPED_GEMM, mxfp8_grouped_mm_skip_msg)
    @parametrize("G", [1, 4, 16])
    @parametrize("M", [2048, 2049])
    @parametrize("N", [8192])
    @parametrize("K", [16640])
    @parametrize("wrap_v2", [True, False])
    def test_mxfp8_scaled_grouped_mm_2d_2d(self, G, M, N, K, wrap_v2):
    @parametrize("format", ["mxfp8"] + (["nvfp4"] if torch.version.cuda else []))
    def test_mxfp8_nvfp4_scaled_grouped_mm_2d_2d(self, G, M, N, K, format):
        torch.manual_seed(42)
        total_K = K  # Alias for clarity, communicating this consists of several groups along this dim
        input_group_end_offsets = generate_jagged_offs(
@@ -541,95 +675,61 @@ class TestFP8Matmul(TestCase):
        X = torch.randn((M, total_K), dtype=torch.bfloat16, device="cuda") * 0.1
        W = torch.randn((N, total_K), dtype=torch.bfloat16, device="cuda") * 0.01

        # Convert scales to blocked format.
        x_list = []
        w_list = []
        x_blocked_scale_list = []
        w_blocked_scale_list = []
        xh, xq, x_blocked_scales, x_global_scales = _2d_grouped_tensor_to_mxfp8_blocked_scaled(
            X, M, G, input_group_end_offsets, format=format
        )
        wh, wq, w_blocked_scales, w_global_scales = _2d_grouped_tensor_to_mxfp8_blocked_scaled(
            W, N, G, input_group_end_offsets, format=format
        )

        def round_up(x: int, y: int) -> int:
            return ((x + y - 1) // y) * y

        for group_idx in range(G):
            # to_mxfp8 per group
            prev_group_end_offset = (
                0 if group_idx == 0 else input_group_end_offsets[group_idx - 1]
        if format == "mxfp8":
            kwargs = _build_scaled_grouped_mm_kwargs(
                x_blocked_scales,
                w_blocked_scales,
                input_group_end_offsets,
                format,
            )
            curr_group_end_offset = input_group_end_offsets[group_idx]
            group_size = curr_group_end_offset - prev_group_end_offset
            if group_size > 0:
                x_slice = X[
                    :, prev_group_end_offset:curr_group_end_offset
                ].contiguous()  # (M, K_group)
                w_slice = W[
                    :, prev_group_end_offset:curr_group_end_offset
                ].contiguous()  # (N, K_group)
                x_scale_slice, xq_slice = to_mxfp8(
                    x_slice
                )  # scale shape -> (M, K_group // 32)
                w_scale_slice, wq_slice = to_mxfp8(
                    w_slice
                )  # scale shape -> (N, K_group // 32)
                x_list.append(xq_slice)
                w_list.append(wq_slice)
        elif format == "nvfp4":
            kwargs = _build_scaled_grouped_mm_kwargs(
                [x_blocked_scales, x_global_scales],
                [w_blocked_scales, w_global_scales],
                input_group_end_offsets,
                format,
            )
        else:
            raise ValueError(f'format must be mxfp8|nvfp4, got "{format}"')

                # Convert scales to blocked format.
                x_scale_slice_blocked = to_blocked(
                    x_scale_slice
                )  # (round_up(M, 128), round_up(K_group//32, 4))
                w_scale_slice_blocked = to_blocked(
                    w_scale_slice
                )  # (round_up(N, 128), round_up(K_group//32, 4))
                x_blocked_scale_list.append(x_scale_slice_blocked)
                w_blocked_scale_list.append(w_scale_slice_blocked)

        # Assemble the full XQ and WQ
        xq = torch.cat(x_list, dim=1).contiguous()
        wq = torch.cat(w_list, dim=1).contiguous()

        # Combine all XQ groups blocked scales into one tensor.
        x_blocked_scales = torch.cat(x_blocked_scale_list, dim=0)
        M_rounded = round_up(M, 128)
        x_blocked_scales = x_blocked_scales.reshape(M_rounded, -1)

        # Combine all WQ groups blocked scales into one tensor.
        w_blocked_scales = torch.cat(w_blocked_scale_list, dim=0)
        N_rounded = round_up(N, 128)
        w_blocked_scales = w_blocked_scales.reshape(N_rounded, -1)
        if format == 'nvfp4':
            assert x_global_scales.numel() == w_global_scales.numel()
            assert x_global_scales.numel() == G

        # Compute mxfp8 grouped mm output
        y_mxfp8 = scaled_grouped_mm_wrap(
            xq,  # (M, total_K)
            wq.transpose(-2, -1),  # (total_K, N)
            x_blocked_scales,  # to_blocked_per_group(M, total_K//32)
            w_blocked_scales,  # to_blocked_per_group(N, total_K//32)
            scale_recipe_a=ScalingType.BlockWise1x32,
            scale_recipe_b=ScalingType.BlockWise1x32,
            swizzle_a=SwizzleType.SWIZZLE_32_4_4,
            swizzle_b=SwizzleType.SWIZZLE_32_4_4,
            offs=input_group_end_offsets,  # (G,)
            out_dtype=torch.bfloat16,
            wrap_v2=wrap_v2
        y_lp = scaled_grouped_mm_wrap(
            xq,
            wq.transpose(-2, -1),
            **kwargs,
        )

        # bf16 reference output
        y_bf16 = torch._grouped_mm(
            X, W.t(), offs=input_group_end_offsets, out_dtype=torch.bfloat16
            # Note: Reference result should be on reconstructed, not original values.
            # as-in float(fp4(t)) not t itself.
            xh, wh.t(), offs=input_group_end_offsets, out_dtype=torch.bfloat16
        )

        # Assert no NaNs
        assert not y_mxfp8.isnan().any(), "mxfp8 output contains NaN"
        assert not y_lp.isnan().any(), "mxfp8 output contains NaN"

        # Assert outputs are close
        torch.testing.assert_close(y_mxfp8, y_bf16, atol=8.0e-2, rtol=8.0e-2)
        torch.testing.assert_close(y_lp, y_bf16, atol=8.0e-2, rtol=8.0e-2)

    @unittest.skipIf(not PLATFORM_SUPPORTS_MXFP8_GROUPED_GEMM, mxfp8_grouped_mm_skip_msg)
    @parametrize("G", [1, 4, 16])
    @parametrize("M", [16640])
    @parametrize("N", [8192])
    @parametrize("K", [4096])
    @parametrize("wrap_v2", [True, False])
    def test_mxfp8_scaled_grouped_mm_2d_3d(self, G, M, N, K, wrap_v2):
    @parametrize("format", ["mxfp8"] + (["nvfp4"] if torch.version.cuda else []))
    def test_mxfp8_scaled_grouped_mm_2d_3d(self, G, M, N, K, format):
        torch.manual_seed(42)
        # Simulate 2d-3d grouped gemm `out = input @ weight.t()`
        # 2D inputs with groups along M, 3D weights.
@@ -643,60 +743,120 @@ class TestFP8Matmul(TestCase):

        # For each constituent 2d subtensor in the 3d weights, quantize and convert scale to blocked format separately,
        # as they each used for independent gemm in the grouped gemm.
        wq_list = []
        w_scale_list = []
        for i in range(G):
            w_scale, wq = to_mxfp8(W[i])
            w_scale = to_blocked(w_scale)
            wq_list.append(wq)
            w_scale_list.append(w_scale)
        wq = torch.stack(wq_list, dim=0).contiguous()
        w_scale = torch.stack(w_scale_list, dim=0).contiguous()
        def _3d_to_blocked_scaled(W, G, format):
            wh_list = []
            wq_list = []
            w_scale_list = []
            w_global_scale_list = []
            for i in range(G):
                if format == "mxfp8":
                    wh, wq, w_scale = _convert_to_mxfp8_with_hp_ref(W[i])
                elif format == "nvfp4":
                    w_scale, wq = to_mxfp8(W[i])
                    wh, wq, w_scale, w_global_scale = _convert_to_nvfp4_with_hp_ref(W[i])
                    w_global_scale_list.append(w_global_scale)
                else:
                    raise ValueError(f'format must be mxfp8|nvfp4, got "{format}"')

                # Swizzle scaled
                # TODO(slayton): gate on cuda/hip
                w_scale = to_blocked(w_scale)

                wh_list.append(wh)
                wq_list.append(wq)
                w_scale_list.append(w_scale)
            wh = torch.stack(wh_list, dim=0).contiguous()
            wq = torch.stack(wq_list, dim=0).contiguous()
            w_scale = torch.stack(w_scale_list, dim=0).contiguous()
            # Global scales only exist for nvfp4
            if len(w_global_scale_list) > 0:
                w_global_scales = torch.stack(w_global_scale_list)
            else:
                w_global_scales = None
            return wh, wq, w_scale, w_global_scales

        wh, wq, w_blocked_scales, w_global_scales = _3d_to_blocked_scaled(W, G, format)

        # For each group along `total_M` in the 2D tensor, quantize and convert scale to blocked format separately,
        # as they each used for independent gemm in the grouped gemm.
        xq_list = []
        x_scale_list = []
        for i in range(G):
            prev_group_end = 0 if i == 0 else input_group_end_offsets[i - 1]
            curr_group_end = input_group_end_offsets[i]
            group_size = curr_group_end - prev_group_end
            if group_size > 0:
                x_slice = X[prev_group_end:curr_group_end, :]
                x_scale, xq = to_mxfp8(x_slice)
                x_scale = to_blocked(x_scale)
                xq_list.append(xq)
                x_scale_list.append(x_scale)
        xq = torch.cat(xq_list, dim=0).contiguous()
        x_scale = torch.cat(x_scale_list, dim=0).contiguous()
        x_scale = x_scale.reshape(-1, K // block_size)
        xq = xq.view(-1, xq.shape[-1])
        def _2d_to_blocked_scaled(X, K, G, offs, format):
            xh_list = []
            xq_list = []
            x_scale_list = []
            x_global_scale_list = []
            for i in range(G):
                prev_group_end = 0 if i == 0 else input_group_end_offsets[i - 1]
                curr_group_end = input_group_end_offsets[i]
                group_size = curr_group_end - prev_group_end
                if group_size > 0:
                    x_slice = X[prev_group_end:curr_group_end, :]
                    if format == "mxfp8":
                        xh, xq, x_scale = _convert_to_mxfp8_with_hp_ref(x_slice)
                    elif format == "nvfp4":
                        xh, xq, x_scale, x_global_scale = _convert_to_nvfp4_with_hp_ref(x_slice)
                        x_global_scale_list.append(x_global_scale)
                    else:
                        raise ValueError(f'format must be mxfp8|nvfp4, got "{format}"')

        # Compute mxfp8 grouped gemm.
        y_mxfp8 = scaled_grouped_mm_wrap(
                    x_scale = to_blocked(x_scale)
                    xh_list.append(xh)
                    xq_list.append(xq)
                    x_scale_list.append(x_scale)
            xh = torch.cat(xh_list, dim=0).contiguous()
            xq = torch.cat(xq_list, dim=0).contiguous()
            x_scale = torch.cat(x_scale_list, dim=0).contiguous()
            x_scale = x_scale.reshape(-1, K // block_size)
            xq = xq.view(-1, xq.shape[-1])
            xh = xh.view(-1, xh.shape[-1])

            x_global_scales = None
            if len(x_global_scale_list) > 0:
                x_global_scales = torch.stack(x_global_scale_list)

            return xh, xq, x_scale, x_global_scales

        xh, xq, x_blocked_scales, x_global_scales = _2d_to_blocked_scaled(X, K, G, input_group_end_offsets, format)

        if format == "mxfp8":
            kwargs = _build_scaled_grouped_mm_kwargs(
                x_blocked_scales,
                w_blocked_scales,
                input_group_end_offsets,
                format,
            )
        elif format == "nvfp4":
            kwargs = _build_scaled_grouped_mm_kwargs(
                [x_blocked_scales, x_global_scales],
                [w_blocked_scales, w_global_scales],
                input_group_end_offsets,
                format,
            )
        else:
            raise ValueError(f'format must be mxfp8|nvfp4, got "{format}"')

        if format == 'nvfp4':
            assert x_global_scales.numel() == w_global_scales.numel()
            assert x_global_scales.numel() == G

        # Compute low-precision grouped gemm.
        y_lp = scaled_grouped_mm_wrap(
            xq,
            wq.transpose(-2, -1),
            x_scale,
            w_scale,
            offs=input_group_end_offsets,
            out_dtype=torch.bfloat16,
            scale_recipe_a=ScalingType.BlockWise1x32,
            scale_recipe_b=ScalingType.BlockWise1x32,
            swizzle_a=SwizzleType.SWIZZLE_32_4_4,
            swizzle_b=SwizzleType.SWIZZLE_32_4_4,
            wrap_v2=wrap_v2)

            **kwargs
        )

        # Compute reference bf16 grouped gemm.
        # Note: Reference result should be on reconstructed, not original values.
        # as-in float(fp4(t)) not t itself.
        y_bf16 = torch._grouped_mm(
            X,
            W.transpose(-2, -1),
            xh,
            wh.transpose(-2, -1),
            offs=input_group_end_offsets,
            out_dtype=torch.bfloat16,
        )

        # Assert outputs are close.
        torch.testing.assert_close(y_mxfp8, y_bf16, atol=8.0e-2, rtol=8.0e-2)
        torch.testing.assert_close(y_lp, y_bf16, atol=8.0e-2, rtol=8.0e-2)


    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
@ -1704,6 +1864,7 @@ class TestFP8Matmul(TestCase):
|
||||
@parametrize("fast_accum", [False, True])
|
||||
# AMD does not support non-contiguous inputs yet
|
||||
@parametrize("strided", [False] + ([True] if torch.version.cuda else []))
|
||||
# AMD does not support NVFP4
|
||||
@parametrize("wrap_v2", [True, False])
|
||||
def test_scaled_grouped_gemm_2d_2d(self, fast_accum, strided, wrap_v2):
|
||||
device = "cuda"
|
||||
|
@ -445,7 +445,7 @@ use_numpy_random_stream = False
enable_cpp_guard_manager = True

# Use C++ guard manager for symbolic shapes
enable_cpp_symbolic_shape_guards = False
enable_cpp_symbolic_shape_guards = not is_fbcode()

# Enable tracing through contextlib.contextmanager
enable_trace_contextlib = True

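If the previous behavior is wanted regardless of environment, the flag can still be forced explicitly; a minimal sketch, assuming this hunk is torch/_dynamo/config.py:

import torch

torch._dynamo.config.enable_cpp_symbolic_shape_guards = True   # default is now `not is_fbcode()`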
@ -42,7 +42,7 @@ import weakref
from dataclasses import dataclass
from enum import Enum
from os.path import dirname, join
from typing import Any, NamedTuple, Optional, TYPE_CHECKING, Union
from typing import Any, NamedTuple, Optional, Sized, TYPE_CHECKING, Union
from unittest.mock import patch

import sympy
@ -395,6 +395,13 @@ class OptimizedModule(torch.nn.Module):
        self._initialize()
        self.training = self._orig_mod.training

    def __len__(self) -> int:
        # Proxy the len call to the original module
        if isinstance(self._orig_mod, Sized):
            return len(self._orig_mod)
        # Mimic python's default behavior for objects without a length
        raise TypeError(f"{type(self._orig_mod).__name__} does not support len()")

    def _initialize(self) -> None:
        # Do this stuff in constructor to lower overhead slightly
        if isinstance(self.dynamo_ctx, DisableContext):

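A quick illustration (not from the diff) of the new `__len__` proxying, using a sized container module:

import torch
import torch.nn as nn

seq = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
compiled = torch.compile(seq)                 # returns an OptimizedModule wrapper
assert len(compiled) == len(seq) == 2         # len() is proxied to the original module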
@ -1793,14 +1793,6 @@ def _aot_stage2b_bw_compile(
            # tensor which is wrong.

            ph_size = ph_arg.size()
            # pyrefly: ignore # bad-argument-type
            if len(ph_size) == 0 and len(real_stride) > 0:
                # Fix for 0-dimensional tensors: When a tensor becomes 0-d
                # (e.g., via squeeze), its stride should be () not (1,).
                # This mismatch can occur when dynamic shape operations produce
                # tensors that are later squeezed to 0-d. The stride metadata
                # may get preserved causing a dimension mismatch (#164814)
                real_stride = ()

            # pyrefly: ignore # bad-argument-type
            placeholder_list[i] = ph_arg.as_strided(ph_size, real_stride)

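Background for the 0-d fix above, as a small illustration (not from the diff): a 0-d tensor's stride is the empty tuple, so pairing it with a stale `(1,)` stride cannot be fed to `as_strided`.

import torch

t = torch.randn(1).squeeze()     # 0-d tensor
print(t.size(), t.stride())      # torch.Size([]) ()
t.as_strided((), ())             # fine: size and stride lengths match
# t.as_strided((), (1,))         # would fail: stride length must match size length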
@ -720,13 +720,22 @@ def check_shape(
) -> None:
    backend = get_current_backend()
    assert shape is not None
    if config.test_configs.runtime_triton_dtype_assert and backend == "triton":
    if config.test_configs.runtime_triton_shape_assert and backend == "triton":
        shape_str = (
            ", ".join(str(d) for d in shape) if len(shape) != 1 else f"{shape[0]},"
        )
        buffer.writeline(f"tl.static_assert({var}.shape == ({shape_str}))")


def check_nan(buffer: IndentedBuffer, var: CSEVariableType) -> None:
    backend = get_current_backend()
    if backend == "triton":
        msg = "NaN or Inf found"
        buffer.writeline(
            f"tl.device_assert(({var} == {var}) & ({var} != float('inf')) & ({var} != float('-inf')), '{msg}')"
        )

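For reference, with an illustrative CSE variable name (`tmp3` is made up), the two helpers above emit Triton-level assertions of the following form:

# check_shape, shape (XBLOCK, RBLOCK):
#     tl.static_assert(tmp3.shape == (XBLOCK, RBLOCK))
# check_shape, single-element shape (16,) keeps the trailing comma:
#     tl.static_assert(tmp3.shape == (16,))
# check_nan:
#     tl.device_assert((tmp3 == tmp3) & (tmp3 != float('inf')) & (tmp3 != float('-inf')), 'NaN or Inf found')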

class DataTypePropagation:
    def __init__(self, body: LoopBody) -> None:
        self.body = body
@ -2623,6 +2632,9 @@ class CSEProxy(DefaultHandler):
            assert output_shape is not None
            check_shape(V.kernel.compute, csevar, output_shape)

            if config.runtime_triton_nan_asserts:
                check_nan(V.kernel.compute, csevar)

            return csevar

        return pytree.tree_map(do_cse, value)

@ -626,7 +626,7 @@ class ComboKernel(Kernel):
        if heuristics == "foreach":
            heuristics_line = f"""
@triton_heuristics.foreach(
    filename=__file__,
    num_warps={self.num_warps},
    triton_meta={triton_meta!r},
    inductor_meta={inductor_meta!r},
)

@ -206,6 +206,9 @@ static_weight_shapes = True
# put correctness assertions in generated code
size_asserts = os.environ.get("TORCHINDUCTOR_SIZE_ASSERTS", "1") == "1"
nan_asserts = os.environ.get("TORCHINDUCTOR_NAN_ASSERTS") == "1"
runtime_triton_nan_asserts = (
    os.environ.get("TORCHINDUCTOR_RUNTIME_TRITON_NAN_ASSERTS") == "1"
)
scalar_asserts = os.environ.get("TORCHINDUCTOR_SCALAR_ASSERTS", "1") == "1"

# Disable by default in fbcode

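A minimal way to exercise the new flag (a sketch, assuming this hunk is torch/_inductor/config.py; the toy function is a placeholder and a CUDA device is assumed):

import os
os.environ["TORCHINDUCTOR_RUNTIME_TRITON_NAN_ASSERTS"] = "1"   # set before Inductor's config module is first imported

import torch
# alternatively, after import: torch._inductor.config.runtime_triton_nan_asserts = True

@torch.compile
def f(x):
    return x * 2 + 1

f(torch.randn(8, device="cuda"))   # generated Triton kernels now carry the tl.device_assert NaN/Inf check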
@ -3550,24 +3550,13 @@ def user_autotune(
    )


def foreach(triton_meta, filename=None, inductor_meta=None):
def foreach(triton_meta, num_warps, filename=None, inductor_meta=None):
    """
    Compile a triton foreach kernel
    """
    configs = []

    # Naive autotuning path for num_warps
    if not inductor_meta.get("autotune_pointwise", True) and not (
        inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise")
    ):
        configs.append(triton.Config({}, num_stages=1, num_warps=8))
    else:
        for warps in [1, 2, 4, 8]:
            configs.append(triton.Config({}, num_stages=1, num_warps=warps))

    return cached_autotune(
        None,
        configs,
        [triton.Config({}, num_stages=1, num_warps=num_warps)],
        triton_meta=triton_meta,
        inductor_meta=inductor_meta,
        heuristic_type=HeuristicType.TEMPLATE,

@ -409,9 +409,10 @@ class SchedulerDonatedBuffer(SchedulerBuffer):


class BaseSchedulerNode:
    ancestors: OrderedSet[str]
    debug_device_str: Callable[[BaseSchedulerNode], list[str]]
    group: tuple[torch.device, tuple[tuple[sympy.Expr, ...], ...]]
    read_writes: dependencies.ReadWrites
    unmet_dependencies: OrderedSet[Dep]
    last_usage: OrderedSet[str]
    # .min_order and .max_order are only relevant for "grouped" nodes such as FusedSchedulerNode.
    # e.g. if the FusedSchedulerNode includes nodes (op_1, op_2, op_3), and op_X is X-th node
    # in `self.scheduler.nodes`, then for this FusedSchedulerNode, .min_order is 1 and .max_order is 3.
@ -420,22 +421,24 @@ class BaseSchedulerNode:
    min_order: int
    max_order: int
    mpi_node: MemoryPlanningInfoForNode
    mutation_renames: dict[str, str]
    node: Optional[ir.Operation]
    outputs: list[SchedulerBuffer]
    outputs_by_name: dict[str, SchedulerBuffer]
    override_estimated_runtime: Optional[float] = None
    read_writes: dependencies.ReadWrites
    unmet_dependencies: OrderedSet[Dep]

    def __init__(self, scheduler: Scheduler) -> None:
        self.scheduler: Scheduler = scheduler
        self.debug_device_str: Callable[[BaseSchedulerNode], list[str]] = (
            lambda *args, **kwargs: []
        )
        self.scheduler = scheduler
        self.debug_device_str = lambda *args, **kwargs: []

    def _init_from_node(self, node: ir.Operation) -> None:
        self.node: Optional[ir.Operation] = node
        self.ancestors: OrderedSet[str] = OrderedSet()
        self.last_usage = OrderedSet[
            str
        ]() # buffers that won't be used after this kernel
        self.node = node
        self.ancestors = OrderedSet()
        self.last_usage = OrderedSet() # buffers that won't be used after this kernel
        self.written = False
        self.outputs: list[SchedulerBuffer] = [
        self.outputs = [
            SchedulerBuffer(
                scheduler=self.scheduler,
                node=output,
@ -443,16 +446,14 @@ class BaseSchedulerNode:
            )
            for output in node.get_outputs()
        ]
        self.outputs_by_name: dict[str, SchedulerBuffer] = {
            buf.get_name(): buf for buf in self.outputs
        }
        self.outputs_by_name = {buf.get_name(): buf for buf in self.outputs}

        # mutation_renames for the current node. Due to potential
        # more mutations happening later, this can be different
        # to Scheduler.mutation_renames. Also this dict should be small
        # since only mutation information relevant to the deps for this
        # node is stored here.
        self.mutation_renames: dict[str, str] = {}
        self.mutation_renames = {}

    def __repr__(self) -> str:
        return f"{type(self).__name__}(name={self.get_name()!r})"
@ -2435,6 +2436,34 @@ def pick_loop_order(
    return order


def _replace_operation_buffer(
    orig_node: ir.MultiTemplateBuffer, new_node: ir.OperationBuffer
) -> None:
    replaced_buf_name = new_node.get_name()
    orig_buf_name = orig_node.get_name()
    assert isinstance(orig_buf_name, str) and isinstance(replaced_buf_name, str)

    replaced_op_name = new_node.get_operation_name()
    orig_op_name = orig_node.get_operation_name()
    assert isinstance(orig_op_name, str) and isinstance(replaced_op_name, str)

    del V.graph.name_to_buffer[replaced_buf_name]
    new_node.name = orig_buf_name

    del V.graph.name_to_op[replaced_op_name]
    new_node.operation_name = orig_op_name

    orig = V.graph.buffers.index(orig_node)
    V.graph.buffers.remove(new_node)
    V.graph.buffers[orig] = new_node
    V.graph.name_to_buffer[orig_buf_name] = new_node

    orig = V.graph.operations.index(orig_node)
    V.graph.operations.remove(new_node)
    V.graph.operations[orig] = new_node
    V.graph.name_to_op[orig_op_name] = new_node


@dataclasses.dataclass
class NodeUser:
    node: Union[BaseSchedulerNode, OutputNode]
@ -3336,33 +3365,6 @@ class Scheduler:
        will force completion of compilation and benchmarking.
        """

        def replace_operation_buffer(
            orig_node: ir.MultiTemplateBuffer, new_node: ir.OperationBuffer
        ) -> None:
            replaced_buf_name = new_node.get_name()
            orig_buf_name = orig_node.get_name()
            assert isinstance(orig_buf_name, str) and isinstance(replaced_buf_name, str)

            replaced_op_name = new_node.get_operation_name()
            orig_op_name = orig_node.get_operation_name()
            assert isinstance(orig_op_name, str) and isinstance(replaced_op_name, str)

            del V.graph.name_to_buffer[replaced_buf_name]
            new_node.name = orig_buf_name

            del V.graph.name_to_op[replaced_op_name]
            new_node.operation_name = orig_op_name

            orig = V.graph.buffers.index(orig_node)
            V.graph.buffers.remove(new_node)
            V.graph.buffers[orig] = new_node
            V.graph.name_to_buffer[orig_buf_name] = new_node

            orig = V.graph.operations.index(orig_node)
            V.graph.operations.remove(new_node)
            V.graph.operations[orig] = new_node
            V.graph.name_to_op[orig_op_name] = new_node

        for i, node in enumerate(self.nodes):
            if isinstance(node, SchedulerNode) and isinstance(
                node.node, ir.MultiTemplateBuffer
@ -3416,40 +3418,47 @@ class Scheduler:
                    assign_origin_node(out_tensorbox, multi_node.origin_node)

                out_buffer.layout = multi_node.layout
                replace_operation_buffer(multi_node, out_buffer)
                new_scheduler_node = self.create_scheduler_node(out_buffer)
                self._replace_node(out_buffer, multi_node, i, node)

                self.nodes[i] = new_scheduler_node
                self.name_to_node[node.get_name()] = new_scheduler_node
                self.name_to_fused_node[node.get_name()] = new_scheduler_node
    def _replace_node(
        self,
        out_buffer: ir.OperationBuffer,
        multi_node: ir.MultiTemplateBuffer,
        i: int,
        node: SchedulerNode,
    ) -> None:
        _replace_operation_buffer(multi_node, out_buffer)
        new_scheduler_node = self.create_scheduler_node(out_buffer)

        # We need to reflect the mutation renames that were recorded in the original node
        mutation_renames = {}
        for dep in itertools.chain(
            node.read_writes.reads, node.unmet_dependencies
        ):
            if real_name := self.mutation_real_name.get(dep.name, None):
                mutation_renames[real_name] = dep.name
        self.nodes[i] = new_scheduler_node
        self.name_to_node[node.get_name()] = new_scheduler_node
        self.name_to_fused_node[node.get_name()] = new_scheduler_node

        def rename_deps(deps: OrderedSet[Dep]) -> OrderedSet[Dep]:
            return OrderedSet(dep.rename(mutation_renames) for dep in deps)
        # We need to reflect the mutation renames that were recorded in the original node
        mutation_renames = {}
        for dep in itertools.chain(node.read_writes.reads, node.unmet_dependencies):
            if real_name := self.mutation_real_name.get(dep.name, None):
                mutation_renames[real_name] = dep.name

        new_scheduler_node.unmet_dependencies = rename_deps(
            new_scheduler_node.unmet_dependencies
        )
        new_scheduler_node.read_writes.reads = rename_deps(
            new_scheduler_node.read_writes.reads
        )
        def rename_deps(deps: OrderedSet[Dep]) -> OrderedSet[Dep]:
            return OrderedSet(dep.rename(mutation_renames) for dep in deps)

        for new_out, old_out in zip(
            new_scheduler_node.get_outputs(), node.get_outputs()
        ):
            self.name_to_buf[old_out.get_name()] = new_out
            new_out.users = old_out.users
        new_scheduler_node.unmet_dependencies = rename_deps(
            new_scheduler_node.unmet_dependencies
        )
        new_scheduler_node.read_writes.reads = rename_deps(
            new_scheduler_node.read_writes.reads
        )

        new_scheduler_node.min_order = node.min_order
        new_scheduler_node.max_order = node.max_order
        new_scheduler_node.last_usage = node.last_usage
        for new_out, old_out in zip(
            new_scheduler_node.get_outputs(), node.get_outputs()
        ):
            self.name_to_buf[old_out.get_name()] = new_out
            new_out.users = old_out.users

        new_scheduler_node.min_order = node.min_order
        new_scheduler_node.max_order = node.max_order
        new_scheduler_node.last_usage = node.last_usage

    def _any_atomic_add(self, node_list: Sequence[BaseSchedulerNode]) -> bool:
        return any(

@ -74,6 +74,17 @@ def export_compat(
    if opset_version is None:
        opset_version = onnx_constants.ONNX_DEFAULT_OPSET

    if isinstance(model, torch.nn.Module):
        if model.training:
            warnings.warn(
                "Exporting a model while it is in training mode. "
                "Please ensure that this is intended, as it may lead to "
                "different behavior during inference. "
                "Calling model.eval() before export is recommended.",
                UserWarning,
                stacklevel=2,
            )

    if isinstance(model, torch.export.ExportedProgram):
        # We know the model is already exported program, so the args, kwargs, and dynamic_shapes
        # are not used

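To avoid the new warning, export from eval mode; a minimal sketch (the tiny model and the `dynamo=True` export path are illustrative):

import torch

model = torch.nn.Linear(4, 2).eval()          # leave training mode before exporting
example_input = torch.randn(1, 4)
onnx_program = torch.onnx.export(model, (example_input,), dynamo=True)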
@ -812,7 +812,6 @@ if torch.backends.mps.is_available():
        "__rmod__",
        "__rsub__",
        "__rpow__",
        "bernoulli",
        "clamp_max",
        "clamp_min",
        "masked_scatter",

@ -447,6 +447,56 @@ def _floatx_unpacked_to_f32(x: Tensor, ebits: int, mbits: int) -> Tensor:
def ceil_div(a, b):
    return (a + b - 1) // b

# NVIDIA Blackwell HW requires scales for MX/NV blocked formats to be in a 128x4 tile layout,
# with a weird 32x4x4 internal layout of that tile. If we want to take swizzled scales and use them
# for non-gemm purposes (like testing), we need to de-swizzle them, then they can be applied much
# more naturally.
def from_blocked(input, input_scales, blocksize) -> torch.Tensor:
    # Matrix is in a 128x4 pattern, internally blocked as 32x4x4 nonsense.
    # Output should be [input.size(0), input.size(1) // blocksize] scales
    output_scales = torch.zeros(
        (input.size(0), input.size(1) // blocksize),
        device=input.device,
        dtype=input_scales.dtype,
    )

    # Swizzled scales are padded to tiles of 128x4, we need to replicate how that padding
    # happened for offset purposes.
    # There are K//blocksize scales, padded to groups of 4.
    num_col_tiles = ceil_div(ceil_div(input.size(1), blocksize), 4)

    # (Very) slow reference implementation using horrifying loops.
    for i in range(input.size(0)):
        for j in range(input.size(1) // blocksize):
            # which 128x4 tile of scaling factors am I in
            scale_tile_h = i // 128
            scale_tile_w = j // 4

            # There are (padded) input_scales.size(1) // 4 tiles along the w dim.
            # So offset is 512 * (h_tile * tiles_per_row + tile_in_row)
            tile_offset = 512 * (scale_tile_h * num_col_tiles + scale_tile_w)

            # indices within the tile - use nomenclature directly from cublas docs
            outer = i % 128  # "outer" in cublas docs
            inner = j % 4  # "inner" in cublas docs

            # Note: "offset" is given in terms of bytes in cublas docs, but our scales are e8m0
            # anyway, and so 1B == 1 value => use offset directly.
            # Formula directly from cublas docs in 3.1.4.3.2
            offset = tile_offset + (outer % 32) * 16 + (outer // 32) * 4 + inner

            output_scales[i, j] = input_scales[offset]

    return output_scales

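A worked instance of the offset formula above (the sizes are chosen for illustration only):

# Suppose the quantized matrix is 256 x 512 (so K = 512) and blocksize = 32:
#   16 scale columns, padded to num_col_tiles = ceil_div(16, 4) = 4 tiles of width 4.
# For the scale at (i, j) = (40, 5):
#   scale_tile_h = 40 // 128 = 0,  scale_tile_w = 5 // 4 = 1
#   tile_offset  = 512 * (0 * 4 + 1) = 512
#   outer = 40 % 128 = 40,  inner = 5 % 4 = 1
#   offset = 512 + (40 % 32) * 16 + (40 // 32) * 4 + 1 = 512 + 128 + 4 + 1 = 645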
def from_blocked_format(x_mxfp8, scales_unswizzled, blocksize=32):
    # expand scales
    scales = torch.repeat_interleave(scales_unswizzled, blocksize, dim=1)

    # de-scale and convert
    x_f32 = x_mxfp8.to(torch.float) * scales.to(torch.float)
    return x_f32.to(torch.bfloat16)

def to_blocked(input_matrix) -> torch.Tensor:
    """
    Rearrange a large matrix by breaking it into blocks and applying the rearrangement pattern.

@ -114,8 +114,6 @@ class ProfilingMode(Enum):
    PROFILING = 3

# Set by parse_cmd_line_args() if called
CI_FUNCTORCH_ROOT = ""
CI_PT_ROOT = ""
CI_TEST_PREFIX = ""
DISABLED_TESTS_FILE = ""
GRAPH_EXECUTOR : Optional[ProfilingMode] = None
@ -959,8 +957,6 @@ def _get_test_report_path():
    return os.path.join('test-reports', test_source)

def parse_cmd_line_args():
    global CI_FUNCTORCH_ROOT
    global CI_PT_ROOT
    global CI_TEST_PREFIX
    global DISABLED_TESTS_FILE
    global GRAPH_EXECUTOR
@ -1039,10 +1035,8 @@ def parse_cmd_line_args():

    set_rng_seed()

    # CI Prefix path used only on CI environment
    # CI Prefix path used only on CI environment
    CI_TEST_PREFIX = str(Path(os.getcwd()))
    CI_PT_ROOT = str(Path(os.getcwd()).parent)
    CI_FUNCTORCH_ROOT = str(os.path.join(Path(os.getcwd()).parent, "functorch"))

def wait_for_process(p, timeout=None):
    try:

@ -311,7 +311,11 @@ def escape(n):


def is_cuda_tensor(obj):
    return isinstance(obj, torch.Tensor) and obj.is_cuda and not isinstance(obj, torch._subclasses.FakeTensor)
    return (
        isinstance(obj, torch.Tensor) and
        obj.device.type == "cuda" and
        not isinstance(obj, torch._subclasses.FakeTensor)
    )

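Behavior sketch for the rewritten predicate above (illustrative; requires a CUDA device and the `is_cuda_tensor` defined above):

import torch
from torch._subclasses import FakeTensorMode

real = torch.ones(2, device="cuda")
assert is_cuda_tensor(real)            # a real CUDA tensor still matches

with FakeTensorMode():
    fake = torch.ones(2, device="cuda")
assert not is_cuda_tensor(fake)        # FakeTensor is excluded even though device.type == "cuda"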
def cuda_allocation_context():
    snapshot = torch.cuda.memory._snapshot()
