mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-03 07:24:58 +08:00
Compare commits
22 Commits
revert-cpp
...
trunk/f36f
| Author | SHA1 | Date | |
|---|---|---|---|
| f36f372acc | |||
| d9483d4c8d | |||
| fea819ed08 | |||
| 84a2715d34 | |||
| 572cc12b42 | |||
| 1fdef664a5 | |||
| 08ae55021e | |||
| 551921d484 | |||
| b5189e269e | |||
| 3895ce093f | |||
| 8aa087a29d | |||
| 7379972cc0 | |||
| b903018c26 | |||
| 21b48f8dfa | |||
| 009ea77234 | |||
| 0e46a10aa7 | |||
| a25818cf7e | |||
| e3e93c7107 | |||
| 1abfa5f70b | |||
| 687c15c0b3 | |||
| 895795f07c | |||
| 2dc56456cb |
@ -100,6 +100,8 @@ COPY ./common/common_utils.sh common_utils.sh
|
||||
COPY ci_commit_pins/huggingface-requirements.txt huggingface-requirements.txt
|
||||
COPY ci_commit_pins/timm.txt timm.txt
|
||||
COPY ci_commit_pins/torchbench.txt torchbench.txt
|
||||
# Only build aoti cpp tests when INDUCTOR_BENCHMARKS is set to True
|
||||
ENV BUILD_AOT_INDUCTOR_TEST ${INDUCTOR_BENCHMARKS}
|
||||
RUN if [ -n "${INDUCTOR_BENCHMARKS}" ]; then bash ./install_inductor_benchmark_deps.sh; fi
|
||||
RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface-requirements.txt torchbench.txt
|
||||
|
||||
|
||||
@ -460,28 +460,18 @@ test_inductor_shard() {
|
||||
--verbose
|
||||
}
|
||||
|
||||
test_inductor_aoti() {
|
||||
# docker build uses bdist_wheel which does not work with test_aot_inductor
|
||||
# TODO: need a faster way to build
|
||||
test_inductor_aoti_cpp() {
|
||||
if [[ "$BUILD_ENVIRONMENT" == *rocm* ]]; then
|
||||
# We need to hipify before building again
|
||||
python3 tools/amd_build/build_amd.py
|
||||
fi
|
||||
if [[ "$BUILD_ENVIRONMENT" == *sm86* ]]; then
|
||||
BUILD_COMMAND=(TORCH_CUDA_ARCH_LIST=8.6 USE_FLASH_ATTENTION=OFF python -m pip install --no-build-isolation -v -e .)
|
||||
# TODO: Replace me completely, as one should not use conda libstdc++, nor need special path to TORCH_LIB
|
||||
TEST_ENVS=(CPP_TESTS_DIR="${BUILD_BIN_DIR}" LD_LIBRARY_PATH="/opt/conda/envs/py_3.10/lib:${TORCH_LIB_DIR}:${LD_LIBRARY_PATH}")
|
||||
else
|
||||
BUILD_COMMAND=(python -m pip install --no-build-isolation -v -e .)
|
||||
TEST_ENVS=(CPP_TESTS_DIR="${BUILD_BIN_DIR}" LD_LIBRARY_PATH="${TORCH_LIB_DIR}")
|
||||
fi
|
||||
|
||||
# aoti cmake custom command requires `torch` to be installed
|
||||
# initialize the cmake build cache and install torch
|
||||
/usr/bin/env "${BUILD_COMMAND[@]}"
|
||||
# rebuild with the build cache with `BUILD_AOT_INDUCTOR_TEST` enabled
|
||||
/usr/bin/env CMAKE_FRESH=1 BUILD_AOT_INDUCTOR_TEST=1 "${BUILD_COMMAND[@]}"
|
||||
|
||||
/usr/bin/env "${TEST_ENVS[@]}" python test/run_test.py --cpp --verbose -i cpp/test_aoti_abi_check cpp/test_aoti_inference cpp/test_vec_half_AVX2 -dist=loadfile
|
||||
}
|
||||
|
||||
@ -1776,7 +1766,7 @@ elif [[ "${TEST_CONFIG}" == *inductor_cpp_wrapper* ]]; then
|
||||
install_torchvision
|
||||
PYTHONPATH=/torchbench test_inductor_cpp_wrapper_shard "$SHARD_NUMBER"
|
||||
if [[ "$SHARD_NUMBER" -eq "1" ]]; then
|
||||
test_inductor_aoti
|
||||
test_inductor_aoti_cpp
|
||||
fi
|
||||
elif [[ "${TEST_CONFIG}" == *inductor* ]]; then
|
||||
install_torchvision
|
||||
|
||||
@ -7,12 +7,9 @@ if "%DESIRED_PYTHON%" == "3.13t" (
|
||||
set "PYTHON_INSTALLER_URL=https://www.python.org/ftp/python/3.13.0/python-3.13.0-amd64.exe"
|
||||
set ADDITIONAL_OPTIONS="Include_freethreaded=1"
|
||||
set PYTHON_EXEC="python3.13t"
|
||||
) else if "%DESIRED_PYTHON%"=="3.14" (
|
||||
echo Python version is set to 3.14 or 3.14t
|
||||
set "PYTHON_INSTALLER_URL=https://www.python.org/ftp/python/3.14.0/python-3.14.0rc1-amd64.exe"
|
||||
) else if "%DESIRED_PYTHON%"=="3.14t" (
|
||||
echo Python version is set to 3.14 or 3.14t
|
||||
set "PYTHON_INSTALLER_URL=https://www.python.org/ftp/python/3.14.0/python-3.14.0rc1-amd64.exe"
|
||||
set "PYTHON_INSTALLER_URL=https://www.python.org/ftp/python/3.14.0/python-3.14.0-amd64.exe"
|
||||
set ADDITIONAL_OPTIONS="Include_freethreaded=1"
|
||||
set PYTHON_EXEC="python3.14t"
|
||||
) else (
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||

|
||||

|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
@ -72,7 +72,7 @@ Elaborating Further:
|
||||
|
||||
If you use NumPy, then you have used Tensors (a.k.a. ndarray).
|
||||
|
||||

|
||||

|
||||
|
||||
PyTorch provides Tensors that can live either on the CPU or the GPU and accelerates the
|
||||
computation by a huge amount.
|
||||
@ -99,7 +99,7 @@ from several research papers on this topic, as well as current and past work suc
|
||||
While this technique is not unique to PyTorch, it's one of the fastest implementations of it to date.
|
||||
You get the best of speed and flexibility for your crazy research.
|
||||
|
||||

|
||||

|
||||
|
||||
### Python First
|
||||
|
||||
|
||||
@ -260,7 +260,7 @@ IF(USE_FBGEMM_GENAI)
|
||||
if(USE_CUDA)
|
||||
# To avoid increasing the build time/binary size unnecessarily, use an allow-list of kernels to build.
|
||||
# If you want to integrate a kernel from FBGEMM into torch, you have to add it here.
|
||||
set(FBGEMM_CUTLASS_KERNELS_REGEX ".*mx8mx8bf16_grouped.*")
|
||||
set(FBGEMM_CUTLASS_KERNELS_REGEX ".*(mx8mx8bf16_grouped|f4f4bf16_grouped).*")
|
||||
file(GLOB_RECURSE fbgemm_genai_native_cuda_cu
|
||||
"${FBGEMM_GENAI_SRCS}/cutlass_extensions/*.cu"
|
||||
"${FBGEMM_GENAI_SRCS}/cutlass_extensions/**/*.cu")
|
||||
@ -291,6 +291,7 @@ IF(USE_FBGEMM_GENAI)
|
||||
|
||||
set(fbgemm_genai_cuh
|
||||
"${FBGEMM_GENAI_SRCS}/cutlass_extensions/mx8mx8bf16_grouped/"
|
||||
"${FBGEMM_GENAI_SRCS}/cutlass_extensions/f4f4bf16_grouped/"
|
||||
"${FBGEMM_GENAI_SRCS}/"
|
||||
)
|
||||
|
||||
|
||||
@ -208,6 +208,48 @@ _f8_f8_bf16_rowwise_grouped_mm(
|
||||
#endif
|
||||
}
|
||||
|
||||
Tensor&
|
||||
_f4_f4_bf16_grouped_mm_fbgemm(
|
||||
const Tensor& mat_a,
|
||||
const Tensor& mat_b,
|
||||
const Tensor& scale_a,
|
||||
const Tensor& global_scale_a,
|
||||
const Tensor& scale_b,
|
||||
const Tensor& global_scale_b,
|
||||
const std::optional<Tensor>& offs,
|
||||
const std::optional<Tensor>& bias,
|
||||
Tensor& out) {
|
||||
#if !defined(USE_ROCM) && defined(USE_FBGEMM_GENAI)
|
||||
// Typing checks
|
||||
TORCH_CHECK_VALUE(mat_a.scalar_type() == at::kFloat4_e2m1fn_x2,
|
||||
"mat_a must be Float4_e2n1fn_2, got: ", mat_a.scalar_type());
|
||||
TORCH_CHECK_VALUE(mat_b.scalar_type() == at::kFloat4_e2m1fn_x2,
|
||||
"mat_b must be Float4_e2n1fn_2, got: ", mat_b.scalar_type());
|
||||
TORCH_CHECK_VALUE(scale_a.scalar_type() == at::kFloat8_e4m3fn,
|
||||
"scale_a must be Float8_e4m3fn, got: ", scale_a.scalar_type());
|
||||
TORCH_CHECK_VALUE(scale_b.scalar_type() == at::kFloat8_e4m3fn,
|
||||
"scale_b must be Float8_e4m3fn, got: ", scale_b.scalar_type());
|
||||
TORCH_CHECK_VALUE(global_scale_a.scalar_type() == at::kFloat,
|
||||
"global_scale_a must be Float, got: ", global_scale_a.scalar_type());
|
||||
TORCH_CHECK_VALUE(global_scale_b.scalar_type() == at::kFloat,
|
||||
"global_scale_b must be Float, got: ", global_scale_b.scalar_type());
|
||||
|
||||
auto o = fbgemm_gpu::f4f4bf16_grouped_mm(
|
||||
mat_a,
|
||||
mat_b,
|
||||
scale_a,
|
||||
scale_b,
|
||||
offs.value(),
|
||||
out,
|
||||
global_scale_a.mul(global_scale_b)
|
||||
);
|
||||
#else
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "nvfp4 grouped gemm is not supported without USE_FBGEMM_GENAI, and only for CUDA")
|
||||
#endif
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
void _check_scales_fp8_rowwise(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx, const int scale_multiplier=1) {
|
||||
// Checks scales for 2d or 3d target tensors (`mat`).
|
||||
if (mat.dim() == 2) {
|
||||
@ -245,7 +287,15 @@ void _check_scales_fp8_rowwise(const Tensor& mat, const Tensor& scale, const int
|
||||
}
|
||||
}
|
||||
|
||||
void _check_scales_mxfp8(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx) {
|
||||
void _check_scales_blocked(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx) {
|
||||
// if {mx,nv}fp4, will need to modify K later
|
||||
bool is_fp4 = (mat.scalar_type() == kFloat4_e2m1fn_x2);
|
||||
int blocksize = 32;
|
||||
// check for nvfp4 vs. mxfp4 to fix blocksize
|
||||
if (is_fp4 && scale.scalar_type() == kFloat8_e4m3fn) {
|
||||
blocksize = 16;
|
||||
}
|
||||
|
||||
// Checks scales for 2d or 3d target tensors (`mat`).
|
||||
if (mat.dim() == 2) {
|
||||
// For MXFP8, 2d tensors have variable size groups represented as subtensors,
|
||||
@ -253,17 +303,19 @@ void _check_scales_mxfp8(const Tensor& mat, const Tensor& scale, const int dim,
|
||||
// so we can't check the scale sizes without doing a d2h sync to get the group sizes here.
|
||||
TORCH_CHECK(
|
||||
scale.dim() == mat.dim(),
|
||||
"for mxfp8, scale must have same number of dimensions as parent tensor, but got mat.dim() = ", mat.dim(), " and scale.dim() = ", scale.dim(), " for arg ", arg_idx);
|
||||
"for block-scaled, scale must have same number of dimensions as parent tensor, but got mat.dim() = ", mat.dim(),
|
||||
" and scale.dim() = ", scale.dim(), " for arg ", arg_idx
|
||||
);
|
||||
|
||||
// LHS mat shape (M, total_K) -> scale shape (rounded_up(M, 128), rounded_up_per_group(K/32, 4))
|
||||
// RHS mat shape (total_K, N) -> scale shape (rounded_up(N, 128), rounded_up_per_group(K/32, 4))
|
||||
// LHS mat shape (M, total_K) -> scale shape (rounded_up(M, 128), rounded_up_per_group(K/blocksize, 4))
|
||||
// RHS mat shape (total_K, N) -> scale shape (rounded_up(N, 128), rounded_up_per_group(K/blocksize, 4))
|
||||
// * weight is transposed prior to the call, scale stays non-transposed.
|
||||
bool LHS = arg_idx == 0;
|
||||
int scale_dim_to_check = 0;
|
||||
int mat_dim_to_check = LHS ? 0 : 1;
|
||||
TORCH_CHECK(
|
||||
scale.size(scale_dim_to_check) >= mat.size(mat_dim_to_check),
|
||||
"for mxfp8, arg ", arg_idx, " tensor shape (", mat.size(0), ", ", mat.size(1), ") ",
|
||||
"for block-scaled, arg ", arg_idx, " tensor shape (", mat.size(0), ", ", mat.size(1), ") ",
|
||||
"must have scale.shape[", scale_dim_to_check, "] >= ", mat.size(mat_dim_to_check), " but got scale.shape=(", scale.size(0), ", ", scale.size(1), ")");
|
||||
} else {
|
||||
// For MXFP8, 3d tensors have static group sizes (stack of 2d tensors),
|
||||
@ -273,32 +325,40 @@ void _check_scales_mxfp8(const Tensor& mat, const Tensor& scale, const int dim,
|
||||
};
|
||||
|
||||
// TODO: this is for 3d tensor in 2d-3d case specifically.
|
||||
// We'll need to support 3d-3d and 3d-2d cases once mxfp8 grouped gemm supports them.
|
||||
// We'll need to support 3d-3d and 3d-2d cases once mxfp8/nvfp4 grouped gemm supports them.
|
||||
int64_t G = mat.size(0);
|
||||
int64_t K = mat.size(1);
|
||||
if (is_fp4) {
|
||||
// FP4 packs 2 values into a single 8b word - the "real" K is 2x the
|
||||
// reported K. Reverse that adjustment.
|
||||
const int fp4_elems_per_byte = 2;
|
||||
K *= fp4_elems_per_byte;
|
||||
}
|
||||
int64_t N = mat.size(2);
|
||||
int64_t blocked_scale_K = round_up(K/32, 4);
|
||||
int64_t blocked_scale_K = round_up(K/blocksize, 4);
|
||||
int64_t blocked_scale_N = round_up(N, 128);
|
||||
|
||||
// fbgemm expects stack of flattened blocked scales for 3d tensor, shape (G, blocked_scale_K * blocked_scale_N).
|
||||
TORCH_CHECK(
|
||||
scale.dim() == mat.dim() - 1,
|
||||
"for mxfp8 2d-3d grouped GEMM, the 3d tensor of shape (G,K,N) must have a 2d scale of shape (G, blocked_scale_K * blocked_scale_N), but scale is ", scale.dim(), "D for arg ", arg_idx
|
||||
"for block-scaled 2d-3d grouped GEMM, the 3d tensor of shape (G,K,N) must have a 2d scale of shape (G, blocked_scale_K * blocked_scale_N),",
|
||||
"but scale is ", scale.dim(), "D for arg ", arg_idx
|
||||
);
|
||||
TORCH_CHECK(
|
||||
scale.size(0) == G && scale.size(1) == blocked_scale_K * blocked_scale_N,
|
||||
"for mxfp8, the tensor shape (", G, ", ", K, ", ", N, ") must have scale shape (", G, ",", blocked_scale_K, ",", blocked_scale_N, ") for arg ", arg_idx
|
||||
"for block-scaled grouped GEMM, the tensor shape (", G, ", ", K, ", ", N, ") must have scale shape (", G, ",", blocked_scale_K, ",", blocked_scale_N, ")",
|
||||
" for arg ", arg_idx, ", got: ", scale.size(0), ", ", scale.size(1)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
void check_scale(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx, const int scale_multiplier=1) {
|
||||
bool using_fp8_rowwise = scale.scalar_type() == kFloat;
|
||||
bool using_mxfp8 = scale.scalar_type() == at::kFloat8_e8m0fnu;
|
||||
bool using_mx = scale.scalar_type() == at::kFloat8_e8m0fnu;
|
||||
if (using_fp8_rowwise) {
|
||||
_check_scales_fp8_rowwise(mat, scale, dim, arg_idx, scale_multiplier);
|
||||
} else if (using_mxfp8) {
|
||||
_check_scales_mxfp8(mat, scale, dim, arg_idx);
|
||||
} else if (using_mx) {
|
||||
_check_scales_blocked(mat, scale, dim, arg_idx);
|
||||
} else {
|
||||
TORCH_CHECK(false, "scale must be float32 or float8_e8m0fnu, but got ", scale.dtype());
|
||||
}
|
||||
@ -411,9 +471,10 @@ namespace {
|
||||
|
||||
using acceptance_fn = std::function<bool(c10::ScalarType, std::vector<ScalingType>&, ArrayRef<Tensor>&, c10::ScalarType, std::vector<ScalingType>&, ArrayRef<Tensor>&)>;
|
||||
|
||||
std::array<std::tuple<std::string, acceptance_fn, ScaledGemmImplementation>, 2> scale_grouped_kernel_dispatch = {{
|
||||
std::array<std::tuple<std::string, acceptance_fn, ScaledGemmImplementation>, 3> scale_grouped_kernel_dispatch = {{
|
||||
{ "rowwise_rowwise", scaled_blas::check_rowwise_recipe, ScaledGemmImplementation::ROWWISE_ROWWISE},
|
||||
{ "mxfp8_mxfp8", scaled_blas::check_mxfp8_recipe, ScaledGemmImplementation::MXFP8_MXFP8}}};
|
||||
{ "mxfp8_mxfp8", scaled_blas::check_mxfp8_recipe, ScaledGemmImplementation::MXFP8_MXFP8},
|
||||
{ "nvfp4_nvfp4", scaled_blas::check_nvfp4_recipe, ScaledGemmImplementation::NVFP4_NVFP4}}};
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
@ -525,8 +586,9 @@ _scaled_grouped_mm_cuda_v2(
|
||||
out);
|
||||
}
|
||||
case ScaledGemmImplementation::MXFP8_MXFP8: {
|
||||
_check_scales_mxfp8(mat_a, scale_a[0], 0 /* dim */, 0 /* arg_idx */);
|
||||
_check_scales_mxfp8(mat_b, scale_b[0], 1 /* dim */, 1 /* arg_idx */);
|
||||
// scale shape checks
|
||||
_check_scales_blocked(mat_a, scale_a[0], 0 /* dim */, 0 /* arg_idx */);
|
||||
_check_scales_blocked(mat_b, scale_b[0], 1 /* dim */, 1 /* arg_idx */);
|
||||
return _mx8_mx8_bf16_grouped_mm_fbgemm(
|
||||
mat_a,
|
||||
mat_b,
|
||||
@ -537,6 +599,21 @@ _scaled_grouped_mm_cuda_v2(
|
||||
offs.value(),
|
||||
out);
|
||||
}
|
||||
case ScaledGemmImplementation::NVFP4_NVFP4: {
|
||||
// scale shape checks
|
||||
_check_scales_blocked(mat_a, scale_a[0], 0 /* dim */, 0 /* arg_idx */);
|
||||
_check_scales_blocked(mat_b, scale_b[0], 1 /* dim */, 1 /* arg_idx */);
|
||||
return _f4_f4_bf16_grouped_mm_fbgemm(
|
||||
mat_a,
|
||||
mat_b,
|
||||
scale_a[0], /* block-scale A */
|
||||
scale_a[1], /* global-scale A */
|
||||
scale_b[0], /* block-scale B */
|
||||
scale_b[1], /* global-scale B */
|
||||
offs.value(),
|
||||
std::nullopt, /* bias */
|
||||
out);
|
||||
}
|
||||
default:
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false,
|
||||
"_scaled_grouped_mm_cuda_v2 is in an inconsistent state - should never reach here");
|
||||
|
||||
@ -57,6 +57,7 @@ Tensor& random_mps_impl(Tensor& self,
|
||||
if (self.numel() == 0) {
|
||||
return self;
|
||||
}
|
||||
at::assert_no_internal_overlap(self);
|
||||
// MPS random is broken for 5D+ tensors, see https://github.com/pytorch/pytorch/issues/147624
|
||||
const auto need_reshape = self.ndimension() > 4;
|
||||
auto mps_gen = get_generator_or_default<MPSGeneratorImpl>(gen, at::mps::detail::getDefaultMPSGenerator());
|
||||
@ -153,8 +154,16 @@ Tensor& random_mps_impl(Tensor& self,
|
||||
feeds[meanPlaceholder.getMPSGraphTensor()] = meanPlaceholder.getMPSGraphTensorData();
|
||||
}
|
||||
|
||||
Placeholder outputPlaceholder = Placeholder(cachedGraph->resultTensor, self);
|
||||
// Handle non-contiguous output tensors by creating a contiguous temporary
|
||||
const auto needs_gather = needsGather(self);
|
||||
Tensor self_ = needs_gather ? at::empty_like(self, MemoryFormat::Contiguous) : self;
|
||||
Placeholder outputPlaceholder = Placeholder(cachedGraph->resultTensor, self_);
|
||||
runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder);
|
||||
|
||||
// Copy results back to original non-contiguous output
|
||||
if (needs_gather) {
|
||||
self.copy_(self_);
|
||||
}
|
||||
}
|
||||
|
||||
return self;
|
||||
|
||||
@ -22,6 +22,7 @@
|
||||
#else
|
||||
#include <ATen/ops/empty.h>
|
||||
#include <ATen/ops/empty_like.h>
|
||||
#include <ATen/ops/zeros_like.h>
|
||||
#include <ATen/ops/reshape.h>
|
||||
#include <ATen/ops/scalar_tensor.h>
|
||||
#include <ATen/ops/sum.h>
|
||||
@ -42,7 +43,6 @@ C10_DIAGNOSTIC_POP()
|
||||
#include <static_switch.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_api.h>
|
||||
|
||||
|
||||
#include <c10/util/Exception.h>
|
||||
|
||||
namespace FLASH_NAMESPACE {
|
||||
@ -417,6 +417,26 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
|
||||
const int head_size_og = sizes[3];
|
||||
const int seqlen_k = k.size(1);
|
||||
const int num_heads_k = k.size(2);
|
||||
|
||||
if (batch_size == 0) {
|
||||
auto opts = q.options();
|
||||
at::Tensor out = at::empty({0, seqlen_q, num_heads, head_size_og}, opts);
|
||||
at::Tensor q_padded = at::empty({0, seqlen_q, num_heads, head_size_og}, opts);
|
||||
at::Tensor k_padded = at::empty({0, seqlen_k, num_heads_k, head_size_og}, opts);
|
||||
at::Tensor v_padded = at::empty({0, seqlen_k, num_heads_k, head_size_og}, opts);
|
||||
at::Tensor softmax_lse = at::empty({0, num_heads, seqlen_q}, opts.dtype(at::kFloat));
|
||||
at::Tensor rng_state = at::empty({2}, at::dtype(c10::kUInt64).device(at::kCUDA));
|
||||
at::Tensor _unused = at::empty({}, at::dtype(c10::kUInt64).device(at::kCUDA));
|
||||
at::Tensor p = at::empty({0}, opts);
|
||||
if (return_softmax) {
|
||||
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
|
||||
const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
|
||||
const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
|
||||
p = at::empty({0, num_heads, seqlen_q_rounded, seqlen_k_rounded}, opts);
|
||||
}
|
||||
return {std::move(out), std::move(q_padded), std::move(k_padded), std::move(v_padded), std::move(softmax_lse), std::move(rng_state), _unused, std::move(p)};
|
||||
}
|
||||
|
||||
TORCH_CHECK(batch_size > 0, "batch size must be positive");
|
||||
TORCH_CHECK(head_size_og % 8 == 0, "head_size must be a multiple of 8, this is ensured by padding!");
|
||||
TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
|
||||
@ -547,7 +567,7 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
|
||||
q_padded = q_padded.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
|
||||
softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1});
|
||||
}
|
||||
return {out, q_padded, k_padded, v_padded, softmax_lse, rng_state, _unused, p};
|
||||
return {std::move(out), std::move(q_padded), std::move(k_padded), std::move(v_padded), std::move(softmax_lse), std::move(rng_state), std::move(_unused), std::move(p)};
|
||||
}
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
|
||||
@ -852,7 +872,6 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
|
||||
TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
||||
TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
||||
TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
|
||||
TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
|
||||
|
||||
const auto sizes = q.sizes();
|
||||
|
||||
@ -863,6 +882,20 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
|
||||
const int head_size = sizes[3];
|
||||
const int seqlen_k = k.size(1);
|
||||
const int num_heads_k = k.size(2);
|
||||
|
||||
if (batch_size == 0) {
|
||||
auto opts = q.options();
|
||||
at::Tensor dq = at::empty_like(q);
|
||||
at::Tensor dk = at::empty_like(k);
|
||||
at::Tensor dv = at::empty_like(v);
|
||||
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
|
||||
const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
|
||||
at::Tensor softmax_d = at::empty({0, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
|
||||
return {dq, dk, dv, softmax_d};
|
||||
}
|
||||
|
||||
TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
|
||||
|
||||
TORCH_CHECK(batch_size > 0, "batch size must be positive");
|
||||
TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
|
||||
TORCH_CHECK(head_size_og % 8 == 0, "head_size_og should be a multiple of 8, this is ensured by padding!");
|
||||
|
||||
@ -1358,9 +1358,15 @@ if(BUILD_TEST)
|
||||
)
|
||||
else()
|
||||
add_subdirectory(${TORCH_ROOT}/test/cpp/jit ${CMAKE_BINARY_DIR}/test_jit)
|
||||
add_subdirectory(${TORCH_ROOT}/test/cpp/lazy ${CMAKE_BINARY_DIR}/test_lazy)
|
||||
# NativeRT is disabled
|
||||
# add_subdirectory(${TORCH_ROOT}/test/cpp/nativert ${CMAKE_BINARY_DIR}/test_nativert)
|
||||
add_subdirectory(${TORCH_ROOT}/test/inductor ${CMAKE_BINARY_DIR}/test_inductor)
|
||||
add_subdirectory(${TORCH_ROOT}/test/cpp/aoti_abi_check ${CMAKE_BINARY_DIR}/test_aoti_abi_check)
|
||||
if(BUILD_AOT_INDUCTOR_TEST)
|
||||
add_subdirectory(${TORCH_ROOT}/test/cpp/aoti_inference ${CMAKE_BINARY_DIR}/test_aoti_inference)
|
||||
endif()
|
||||
|
||||
if(USE_DISTRIBUTED)
|
||||
add_subdirectory(${TORCH_ROOT}/test/cpp/c10d ${CMAKE_BINARY_DIR}/test_cpp_c10d)
|
||||
if(NOT WIN32)
|
||||
@ -1378,16 +1384,6 @@ if(BUILD_TEST)
|
||||
${CMAKE_BINARY_DIR}/test_mobile_nnc
|
||||
)
|
||||
endif()
|
||||
add_subdirectory(${TORCH_ROOT}/test/cpp/lazy
|
||||
${CMAKE_BINARY_DIR}/test_lazy)
|
||||
endif()
|
||||
if(BUILD_AOT_INDUCTOR_TEST)
|
||||
add_subdirectory(
|
||||
${TORCH_ROOT}/test/cpp/aoti_abi_check
|
||||
${CMAKE_BINARY_DIR}/test_aoti_abi_check)
|
||||
add_subdirectory(
|
||||
${TORCH_ROOT}/test/cpp/aoti_inference
|
||||
${CMAKE_BINARY_DIR}/test_aoti_inference)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
|
||||
@ -99,6 +99,12 @@ DTensor supports the following types of {class}`Placement` on each {class}`Devic
|
||||
:undoc-members:
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: MaskPartial
|
||||
:members:
|
||||
:undoc-members:
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: Placement
|
||||
:members:
|
||||
|
||||
@ -1,3 +1,8 @@
|
||||
# Skip on windows
|
||||
if(WIN32)
|
||||
return()
|
||||
endif()
|
||||
|
||||
set(AOTI_ABI_CHECK_TEST_ROOT ${TORCH_ROOT}/test/cpp/aoti_abi_check)
|
||||
|
||||
# Build the cpp gtest binary containing the cpp-only tests.
|
||||
@ -30,8 +35,15 @@ target_compile_definitions(test_aoti_abi_check PRIVATE USE_GTEST)
|
||||
|
||||
# WARNING: DO NOT LINK torch!!!
|
||||
# The purpose is to check if the used aten/c10 headers are written in a header-only way
|
||||
target_link_libraries(test_aoti_abi_check PRIVATE gtest_main)
|
||||
target_link_libraries(test_aoti_abi_check PRIVATE gtest_main sleef)
|
||||
target_include_directories(test_aoti_abi_check PRIVATE ${ATen_CPU_INCLUDE})
|
||||
if(NOT USE_SYSTEM_SLEEF)
|
||||
target_include_directories(test_aoti_abi_check PRIVATE ${CMAKE_BINARY_DIR}/include)
|
||||
endif()
|
||||
|
||||
# Disable unused-variable warnings for variables that are only used to test compilation
|
||||
target_compile_options_if_supported(test_aoti_abi_check -Wno-unused-variable)
|
||||
target_compile_options_if_supported(test_aoti_abi_check -Wno-unused-but-set-variable)
|
||||
|
||||
foreach(test_src ${AOTI_ABI_CHECK_VEC_TEST_SRCS})
|
||||
foreach(i RANGE ${NUM_CPU_CAPABILITY_NAMES})
|
||||
@ -41,12 +53,17 @@ foreach(test_src ${AOTI_ABI_CHECK_VEC_TEST_SRCS})
|
||||
separate_arguments(FLAGS UNIX_COMMAND "${FLAGS}")
|
||||
add_executable(${test_name}_${CPU_CAPABILITY} "${test_src}")
|
||||
|
||||
target_link_libraries(${test_name}_${CPU_CAPABILITY} PRIVATE gtest_main)
|
||||
target_link_libraries(${test_name}_${CPU_CAPABILITY} PRIVATE gtest_main sleef)
|
||||
target_include_directories(${test_name}_${CPU_CAPABILITY} PRIVATE ${ATen_CPU_INCLUDE})
|
||||
if(NOT USE_SYSTEM_SLEEF)
|
||||
target_include_directories(${test_name}_${CPU_CAPABILITY} PRIVATE ${CMAKE_BINARY_DIR}/include)
|
||||
endif()
|
||||
|
||||
# Define CPU_CAPABILITY and CPU_CAPABILITY_XXX macros for conditional compilation
|
||||
target_compile_definitions(${test_name}_${CPU_CAPABILITY} PRIVATE CPU_CAPABILITY=${CPU_CAPABILITY} CPU_CAPABILITY_${CPU_CAPABILITY})
|
||||
target_compile_options(${test_name}_${CPU_CAPABILITY} PRIVATE ${FLAGS})
|
||||
target_compile_options_if_supported(${test_name}_${CPU_CAPABILITY} -Wno-unused-variable)
|
||||
target_compile_options_if_supported(${test_name}_${CPU_CAPABILITY} -Wno-unused-but-set-variable)
|
||||
endforeach()
|
||||
endforeach()
|
||||
|
||||
|
||||
@ -2,10 +2,27 @@
|
||||
|
||||
#include <ATen/cpu/vec/vec.h>
|
||||
|
||||
#include <iostream>
|
||||
namespace torch {
|
||||
namespace aot_inductor {
|
||||
|
||||
template <typename T>
|
||||
void ExpectVecEqual(
|
||||
const at::vec::Vectorized<T>& expected,
|
||||
const at::vec::Vectorized<T>& actual) {
|
||||
using Vec = at::vec::Vectorized<T>;
|
||||
// Have to use std::vector for comparison because at::vec::Vectorized doesn't
|
||||
// support operator[] on aarch64
|
||||
std::vector<T> expected_data(Vec::size());
|
||||
std::vector<T> actual_data(Vec::size());
|
||||
|
||||
expected.store(expected_data.data());
|
||||
actual.store(actual_data.data());
|
||||
|
||||
for (int i = 0; i < Vec::size(); i++) {
|
||||
EXPECT_EQ(expected_data[i], actual_data[i]);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(TestVec, TestAdd) {
|
||||
using Vec = at::vec::Vectorized<int>;
|
||||
std::vector<int> a(1024, 1);
|
||||
@ -16,9 +33,7 @@ TEST(TestVec, TestAdd) {
|
||||
std::vector<int> expected(1024, 3);
|
||||
Vec expected_vec = Vec::loadu(expected.data());
|
||||
|
||||
for (int i = 0; i < Vec::size(); i++) {
|
||||
EXPECT_EQ(expected_vec[i], actual_vec[i]);
|
||||
}
|
||||
ExpectVecEqual(expected_vec, actual_vec);
|
||||
}
|
||||
|
||||
TEST(TestVec, TestMax) {
|
||||
@ -30,9 +45,7 @@ TEST(TestVec, TestMax) {
|
||||
Vec actual_vec = at::vec::maximum(a_vec, b_vec);
|
||||
Vec expected_vec = b_vec;
|
||||
|
||||
for (int i = 0; i < Vec::size(); i++) {
|
||||
EXPECT_EQ(expected_vec[i], actual_vec[i]);
|
||||
}
|
||||
ExpectVecEqual(expected_vec, actual_vec);
|
||||
}
|
||||
|
||||
TEST(TestVec, TestMin) {
|
||||
@ -44,9 +57,7 @@ TEST(TestVec, TestMin) {
|
||||
Vec actual_vec = at::vec::minimum(a_vec, b_vec);
|
||||
Vec expected_vec = a_vec;
|
||||
|
||||
for (int i = 0; i < Vec::size(); i++) {
|
||||
EXPECT_EQ(expected_vec[i], actual_vec[i]);
|
||||
}
|
||||
ExpectVecEqual(expected_vec, actual_vec);
|
||||
}
|
||||
|
||||
TEST(TestVec, TestConvert) {
|
||||
@ -58,9 +69,7 @@ TEST(TestVec, TestConvert) {
|
||||
auto actual_vec = at::vec::convert<float>(a_vec);
|
||||
auto expected_vec = b_vec;
|
||||
|
||||
for (int i = 0; i < at::vec::Vectorized<int>::size(); i++) {
|
||||
EXPECT_EQ(expected_vec[i], actual_vec[i]);
|
||||
}
|
||||
ExpectVecEqual(expected_vec, actual_vec);
|
||||
}
|
||||
|
||||
TEST(TestVec, TestClampMin) {
|
||||
@ -72,9 +81,7 @@ TEST(TestVec, TestClampMin) {
|
||||
Vec actual_vec = at::vec::clamp_min(a_vec, min_vec);
|
||||
Vec expected_vec = min_vec;
|
||||
|
||||
for (int i = 0; i < Vec::size(); i++) {
|
||||
EXPECT_EQ(expected_vec[i], actual_vec[i]);
|
||||
}
|
||||
ExpectVecEqual(expected_vec, actual_vec);
|
||||
}
|
||||
|
||||
} // namespace aot_inductor
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
|
||||
set(AOT_INDUCTOR_TEST_ROOT ${TORCH_ROOT}/test/cpp/aoti_inference)
|
||||
|
||||
# Build custom TorchScript op for AOTInductor
|
||||
@ -8,27 +7,12 @@ set_target_properties(aoti_custom_class PROPERTIES
|
||||
if(USE_CUDA)
|
||||
target_compile_definitions(aoti_custom_class PRIVATE USE_CUDA)
|
||||
elseif(USE_ROCM)
|
||||
target_compile_definitions(aoti_custom_class PRIVATE USE_ROCM)
|
||||
target_compile_definitions(aoti_custom_class PRIVATE USE_ROCM)
|
||||
endif()
|
||||
|
||||
# Link against LibTorch
|
||||
target_link_libraries(aoti_custom_class torch)
|
||||
|
||||
# the custom command that generates the TorchScript module
|
||||
add_custom_command(
|
||||
OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/script_data.pt
|
||||
${CMAKE_CURRENT_BINARY_DIR}/script_model_cpu.pt
|
||||
${CMAKE_CURRENT_BINARY_DIR}/script_model_cuda.pt
|
||||
# This script requires the torch package to be installed.
|
||||
COMMAND python ${AOT_INDUCTOR_TEST_ROOT}/compile_model.py
|
||||
DEPENDS torch torch_python aoti_custom_class ${AOT_INDUCTOR_TEST_ROOT}/compile_model.py
|
||||
)
|
||||
add_custom_target(aoti_script_model ALL
|
||||
DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/script_data.pt
|
||||
DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/script_model_cpu.pt
|
||||
DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/script_model_cuda.pt
|
||||
)
|
||||
add_dependencies(aoti_script_model aoti_custom_class)
|
||||
|
||||
# Build the cpp gtest binary containing the cpp-only tests.
|
||||
set(INDUCTOR_TEST_SRCS
|
||||
${AOT_INDUCTOR_TEST_ROOT}/test.cpp
|
||||
@ -37,23 +21,12 @@ set(INDUCTOR_TEST_SRCS
|
||||
add_executable(test_aoti_inference
|
||||
${TORCH_ROOT}/test/cpp/common/main.cpp
|
||||
${INDUCTOR_TEST_SRCS}
|
||||
data.pt
|
||||
script_data.pt
|
||||
script_model_cpu.pt
|
||||
script_model_cuda.pt
|
||||
)
|
||||
add_dependencies(test_aoti_inference aoti_custom_class aoti_script_model)
|
||||
add_dependencies(test_aoti_inference aoti_custom_class)
|
||||
|
||||
# TODO temporary until we can delete the old gtest polyfills.
|
||||
target_compile_definitions(test_aoti_inference PRIVATE USE_GTEST)
|
||||
|
||||
# Define a custom command to generate the library
|
||||
add_custom_command(
|
||||
OUTPUT data.pt
|
||||
COMMAND python ${AOT_INDUCTOR_TEST_ROOT}/test.py
|
||||
DEPENDS ${AOT_INDUCTOR_TEST_ROOT}/test.py
|
||||
)
|
||||
|
||||
target_link_libraries(test_aoti_inference PRIVATE
|
||||
torch
|
||||
gtest_main
|
||||
@ -71,6 +44,10 @@ target_compile_definitions(test_aoti_inference PRIVATE
|
||||
CMAKE_CURRENT_BINARY_DIR=${CMAKE_CURRENT_BINARY_DIR}
|
||||
)
|
||||
|
||||
target_compile_options_if_supported(test_aoti_inference -Wno-unused-variable)
|
||||
target_compile_options_if_supported(test_aoti_inference -Wno-unused-but-set-variable)
|
||||
target_compile_options_if_supported(test_aoti_inference -Wno-unused-function)
|
||||
|
||||
if(INSTALL_TEST)
|
||||
install(TARGETS test_aoti_inference DESTINATION bin)
|
||||
# Install PDB files for MSVC builds
|
||||
|
||||
@ -2,7 +2,9 @@
|
||||
#include <gtest/gtest.h>
|
||||
#include <atomic>
|
||||
#include <condition_variable>
|
||||
#include <cstdlib>
|
||||
#include <filesystem>
|
||||
#include <fstream>
|
||||
#include <functional>
|
||||
#include <mutex>
|
||||
#include <queue>
|
||||
@ -28,6 +30,64 @@
|
||||
|
||||
namespace {
|
||||
|
||||
// Function to check if test data files exist and are valid
|
||||
bool testDataFilesExist() {
|
||||
std::string bindir = STRINGIZE(CMAKE_CURRENT_BINARY_DIR);
|
||||
std::array<std::string, 4> required_files = {
|
||||
"data.pt",
|
||||
"script_data.pt",
|
||||
"script_model_cpu.pt",
|
||||
"script_model_cuda.pt"};
|
||||
|
||||
for (const auto& filename : required_files) {
|
||||
std::string filepath = bindir + "/" + filename;
|
||||
std::ifstream file(filepath);
|
||||
if (!file.good()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Function to ensure test data files are generated at runtime
|
||||
void ensureTestDataGenerated() {
|
||||
static std::once_flag generated_flag;
|
||||
std::call_once(generated_flag, []() {
|
||||
// Only generate if files don't exist or are placeholders
|
||||
if (testDataFilesExist()) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::string bindir = STRINGIZE(CMAKE_CURRENT_BINARY_DIR);
|
||||
|
||||
// Calculate path to source directory: build/test_aoti_inference -> build ->
|
||||
// pytorch
|
||||
std::string pytorch_root = bindir.substr(0, bindir.find_last_of("/"));
|
||||
pytorch_root = pytorch_root.substr(0, pytorch_root.find_last_of("/"));
|
||||
std::string source_dir = pytorch_root + "/test/cpp/aoti_inference";
|
||||
|
||||
// Generate test data files (data.pt, etc.) by running test.py directly
|
||||
std::string test_script = source_dir + "/test.py";
|
||||
std::string test_data_cmd = "cd " + bindir + " && python " + test_script;
|
||||
std::cout << "Generating test data: " << test_data_cmd << std::endl;
|
||||
int result1 = std::system(test_data_cmd.c_str());
|
||||
if (result1 != 0) {
|
||||
std::cerr << "Warning: Test data generation failed with code " << result1
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
// Generate model files (script_*.pt) by running compile_model.py directly
|
||||
std::string compile_script = source_dir + "/compile_model.py";
|
||||
std::string models_cmd = "cd " + bindir + " && python " + compile_script;
|
||||
std::cout << "Generating model files: " << models_cmd << std::endl;
|
||||
int result2 = std::system(models_cmd.c_str());
|
||||
if (result2 != 0) {
|
||||
std::cerr << "Warning: Model generation failed with code " << result2
|
||||
<< std::endl;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
const std::unordered_map<std::string, at::Tensor> derefTensorConstantMap(
|
||||
torch::inductor::TensorConstantMap tensor_constant_map) {
|
||||
std::unordered_map<std::string, at::Tensor> ret;
|
||||
@ -855,7 +915,6 @@ void test_aoti_free_buffer(bool use_runtime_constant_folding) {
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(USE_CUDA) || defined(USE_ROCM)
|
||||
void test_cuda_alloc_test() {
|
||||
torch::NoGradGuard no_grad;
|
||||
|
||||
@ -895,8 +954,8 @@ void test_cuda_alloc_test() {
|
||||
runner->run(data_loader.attr(inputs_attr.c_str()).toTensorList().vec());
|
||||
ASSERT_TRUE(torch::allclose(ref_output_tensors[0], actual_output_tensors[0]));
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef USE_CUDA
|
||||
class ThreadPool {
|
||||
private:
|
||||
struct Task {
|
||||
@ -1037,86 +1096,96 @@ void test_multi_cuda_streams(const std::string& device) {
|
||||
ASSERT_TRUE(torch::allclose(ref_output_tensors[0], all_outputs[i][0]));
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#endif // USE_CUDA
|
||||
#endif // USE_CUDA || USE_ROCM
|
||||
} // namespace
|
||||
|
||||
namespace torch::aot_inductor {
|
||||
|
||||
TEST(AotInductorTest, BasicTestCpu) {
|
||||
// Test fixture that ensures test data is generated once for all tests
|
||||
class AotInductorTest : public ::testing::Test {
|
||||
public:
|
||||
// This runs once before all tests in this test suite
|
||||
static void SetUpTestSuite() {
|
||||
ensureTestDataGenerated();
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(AotInductorTest, BasicTestCpu) {
|
||||
test_aoti("cpu", false);
|
||||
}
|
||||
|
||||
TEST(AotInductorTest, BasicScriptTestCpu) {
|
||||
TEST_F(AotInductorTest, BasicScriptTestCpu) {
|
||||
test_aoti_script("cpu");
|
||||
}
|
||||
|
||||
TEST(AotInductorTest, BasicPackageLoaderTestCpu) {
|
||||
TEST_F(AotInductorTest, BasicPackageLoaderTestCpu) {
|
||||
test_aoti_package_loader("cpu", false);
|
||||
}
|
||||
|
||||
TEST(AotInductorTest, ExtractConstantsMapCpu) {
|
||||
TEST_F(AotInductorTest, ExtractConstantsMapCpu) {
|
||||
test_aoti_extract_constants_map("cpu");
|
||||
}
|
||||
|
||||
#ifdef USE_CUDA
|
||||
TEST(AotInductorTest, BasicTestCuda) {
|
||||
TEST_F(AotInductorTest, BasicTestCuda) {
|
||||
test_aoti("cuda", true);
|
||||
test_aoti("cuda", false);
|
||||
}
|
||||
|
||||
TEST(AotInductorTest, BasicScriptTestCuda) {
|
||||
TEST_F(AotInductorTest, BasicScriptTestCuda) {
|
||||
test_aoti_script("cuda");
|
||||
}
|
||||
|
||||
TEST(AotInductorTest, BasicPackageLoaderTestCuda) {
|
||||
TEST_F(AotInductorTest, BasicPackageLoaderTestCuda) {
|
||||
test_aoti_package_loader("cuda", false);
|
||||
}
|
||||
|
||||
TEST(AotInductorTest, BasicPackageLoaderTestMultiGpuCuda) {
|
||||
TEST_F(AotInductorTest, BasicPackageLoaderTestMultiGpuCuda) {
|
||||
test_aoti_package_loader_multi_gpu("cuda", false);
|
||||
}
|
||||
|
||||
TEST(AotInductorTest, UpdateUserManagedConstantsCuda) {
|
||||
TEST_F(AotInductorTest, UpdateUserManagedConstantsCuda) {
|
||||
test_aoti_user_managed_buffer();
|
||||
}
|
||||
|
||||
TEST(AotInductorTest, RuntimeUpdateConstantsCuda) {
|
||||
TEST_F(AotInductorTest, RuntimeUpdateConstantsCuda) {
|
||||
test_aoti_constants_update("cuda", true);
|
||||
}
|
||||
|
||||
TEST(AotInductorTest, UpdateConstantsCuda) {
|
||||
TEST_F(AotInductorTest, UpdateConstantsCuda) {
|
||||
test_aoti_constants_update("cuda", false);
|
||||
}
|
||||
|
||||
TEST(AotInductorTest, ExtractConstantsMapCuda) {
|
||||
TEST_F(AotInductorTest, ExtractConstantsMapCuda) {
|
||||
test_aoti_extract_constants_map("cuda");
|
||||
}
|
||||
|
||||
TEST(AotInductorTest, RuntimeUpdateInactiveConstantsCuda) {
|
||||
TEST_F(AotInductorTest, RuntimeUpdateInactiveConstantsCuda) {
|
||||
test_aoti_double_buffering("cuda", true);
|
||||
}
|
||||
|
||||
TEST(AotInductorTest, UpdateInactiveConstantsCuda) {
|
||||
TEST_F(AotInductorTest, UpdateInactiveConstantsCuda) {
|
||||
test_aoti_double_buffering("cuda", false);
|
||||
}
|
||||
|
||||
TEST(AotInductorTest, UpdateInactiveConstantsWithTensorConstantsCuda) {
|
||||
TEST_F(AotInductorTest, UpdateInactiveConstantsWithTensorConstantsCuda) {
|
||||
test_aoti_double_buffering_with_tensor_constants();
|
||||
}
|
||||
|
||||
TEST(AotInductorTest, FreeInactiveConstantBufferCuda) {
|
||||
TEST_F(AotInductorTest, FreeInactiveConstantBufferCuda) {
|
||||
test_aoti_free_buffer(false);
|
||||
}
|
||||
|
||||
TEST(AotInductorTest, FreeInactiveConstantBufferRuntimeConstantFoldingCuda) {
|
||||
TEST_F(AotInductorTest, FreeInactiveConstantBufferRuntimeConstantFoldingCuda) {
|
||||
test_aoti_free_buffer(true);
|
||||
}
|
||||
|
||||
TEST(AotInductorTest, MultiStreamTestCuda) {
|
||||
TEST_F(AotInductorTest, MultiStreamTestCuda) {
|
||||
test_multi_cuda_streams("cuda");
|
||||
}
|
||||
|
||||
TEST(AotInductorTest, CudaAllocTestCuda) {
|
||||
TEST_F(AotInductorTest, CudaAllocTestCuda) {
|
||||
test_cuda_alloc_test();
|
||||
}
|
||||
#endif
|
||||
|
||||
@ -168,7 +168,7 @@ class TestEmbeddingOp(DTensorTestBase):
|
||||
self._run_embedding_op_test(mesh, 0, [6, 7, 6], 13, 22)
|
||||
self._run_embedding_op_test(mesh, 0, [34], 15, 14, padding_idx=10)
|
||||
|
||||
from torch.distributed.tensor._ops._embedding_ops import _MaskPartial
|
||||
from torch.distributed.tensor.placement_types import MaskPartial
|
||||
|
||||
# test collectives
|
||||
embedding_mod = torch.nn.Embedding(10, 20, device=self.device_type)
|
||||
@ -176,7 +176,7 @@ class TestEmbeddingOp(DTensorTestBase):
|
||||
inp = torch.randint(0, 10, (8, 8), device=self.device_type)
|
||||
replicated_inp = DTensor.from_local(inp, mesh, [Replicate()], run_check=False)
|
||||
output = sharded_embedding(replicated_inp)
|
||||
self.assertIsInstance(output.placements[0], _MaskPartial)
|
||||
self.assertIsInstance(output.placements[0], MaskPartial)
|
||||
|
||||
comm_mode = CommDebugMode()
|
||||
|
||||
@ -192,9 +192,9 @@ class TestEmbeddingOp(DTensorTestBase):
|
||||
inp = torch.randint(0, 10, (4, 4), device=self.device_type)
|
||||
replicated_inp = DTensor.from_local(inp, mesh, [Replicate()], run_check=False)
|
||||
|
||||
from torch.distributed.tensor._ops._embedding_ops import _MaskPartial
|
||||
from torch.distributed.tensor.placement_types import MaskPartial
|
||||
|
||||
# case 1: two embeddings with the same shape, thus sharing the underlying _MaskPartial
|
||||
# case 1: two embeddings with the same shape, thus sharing the underlying MaskPartial
|
||||
# and MaskBuffer, because of cache hit from sharding propagation
|
||||
|
||||
emb1 = torch.nn.Embedding(10, 23, device=self.device_type)
|
||||
@ -206,23 +206,23 @@ class TestEmbeddingOp(DTensorTestBase):
|
||||
output2 = sharded_emb2(replicated_inp)
|
||||
|
||||
partial_placement1 = output1.placements[0]
|
||||
self.assertIsInstance(partial_placement1, _MaskPartial)
|
||||
self.assertIsInstance(partial_placement1, MaskPartial)
|
||||
output1.full_tensor()
|
||||
|
||||
partial_placement2 = output2.placements[0]
|
||||
self.assertIsInstance(partial_placement2, _MaskPartial)
|
||||
self.assertIsInstance(partial_placement2, MaskPartial)
|
||||
output2.full_tensor()
|
||||
|
||||
self.assertTrue(id(partial_placement1), id(partial_placement2))
|
||||
|
||||
# case 2: two embeddings with the same logical_dim_size, but different logical_shape
|
||||
# thus they will have different _MaskPartial placements (with no cache hit)
|
||||
# thus they will have different MaskPartial placements (with no cache hit)
|
||||
|
||||
emb3 = torch.nn.Embedding(10, 29, device=self.device_type)
|
||||
sharded_emb3 = self._apply_sharding(emb3, 0, mesh)
|
||||
output3 = sharded_emb3(replicated_inp)
|
||||
partial_placement3 = output3.placements[0]
|
||||
self.assertIsInstance(partial_placement3, _MaskPartial)
|
||||
self.assertIsInstance(partial_placement3, MaskPartial)
|
||||
output2.full_tensor()
|
||||
|
||||
# not equal because of different logical_shape, despite of same logical_dim_size
|
||||
|
||||
@ -511,7 +511,7 @@ class DistTensorOpsTest(DTensorTestBase):
|
||||
# case 2 input sharding: input sharded, index replicated, output mask partial
|
||||
# only works when index has size 1 on the gather dimension and
|
||||
# input is sharded on the gather dimension
|
||||
from torch.distributed.tensor._ops._embedding_ops import _MaskPartial
|
||||
from torch.distributed.tensor.placement_types import MaskPartial
|
||||
|
||||
gather_dim = 1
|
||||
global_input = torch.randn(12, 8, 16)
|
||||
@ -522,7 +522,7 @@ class DistTensorOpsTest(DTensorTestBase):
|
||||
with comm_mode:
|
||||
output_dt = torch.gather(input_dt, gather_dim, index_dt)
|
||||
self.assertEqual(comm_mode.get_total_counts(), 0)
|
||||
self.assertIsInstance(output_dt.placements[0], _MaskPartial)
|
||||
self.assertIsInstance(output_dt.placements[0], MaskPartial)
|
||||
self.assertEqual(output_dt.full_tensor(), global_output)
|
||||
|
||||
# case 3 index sharding: input replicated, index sharded, output sharded
|
||||
|
||||
@ -892,10 +892,16 @@ fn(torch.randn(5))
|
||||
os.remove(
|
||||
file_path
|
||||
) # Delete temp file manually, due to setup NamedTemporaryFile as delete=False.
|
||||
self.assertEqual( # process wrap difference: /r/n on Windows, /n on posix.
|
||||
empty_line_normalizer(lines),
|
||||
empty_line_normalizer(stderr.decode("utf-8")),
|
||||
)
|
||||
orig_maxDiff = unittest.TestCase.maxDiff
|
||||
unittest.TestCase.maxDiff = None
|
||||
try:
|
||||
self.assertEqual( # process wrap difference: /r/n on Windows, /n on posix.
|
||||
empty_line_normalizer(lines),
|
||||
empty_line_normalizer(stderr.decode("utf-8")),
|
||||
)
|
||||
except Exception:
|
||||
unittest.TestCase.maxDiff = orig_maxDiff
|
||||
raise
|
||||
|
||||
@make_settings_test("torch._dynamo.eval_frame")
|
||||
def test_log_traced_frames(self, records):
|
||||
|
||||
@ -1000,6 +1000,18 @@ class ReproTests(torch._dynamo.test_case.TestCase):
|
||||
self.exit_stack.close()
|
||||
super().tearDown()
|
||||
|
||||
def test_compiled_module_truthiness(self):
|
||||
# Test with empty ModuleList
|
||||
original_empty = nn.ModuleList()
|
||||
compiled_empty = torch.compile(original_empty)
|
||||
self.assertEqual(bool(original_empty), bool(compiled_empty))
|
||||
self.assertFalse(bool(compiled_empty))
|
||||
# Test with non-empty ModuleList
|
||||
original_filled = nn.ModuleList([nn.Linear(10, 5)])
|
||||
compiled_filled = torch.compile(original_filled)
|
||||
self.assertEqual(bool(original_filled), bool(compiled_filled))
|
||||
self.assertTrue(bool(compiled_filled))
|
||||
|
||||
def guard_manager_clone_hook_fn(self, guard_manager_wrapper, f_locals, builder):
|
||||
root = guard_manager_wrapper.root
|
||||
cloned_root = root.clone_manager(lambda x: True)
|
||||
|
||||
@ -14269,6 +14269,22 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
|
||||
self.assertTrue("'enable_fp_fusion': False" in code)
|
||||
torch.testing.assert_close(out, fn(a, b), atol=0, rtol=0)
|
||||
|
||||
@requires_cuda_and_triton
|
||||
@config.patch(runtime_triton_nan_asserts=True)
|
||||
def test_nan_assert_inside_triton_kernel(self):
|
||||
def fn(x):
|
||||
x = x - 1
|
||||
# Uncomment the following line can trigger the failure of
|
||||
# the device size assertion
|
||||
# x = torch.log(x)
|
||||
return torch.where(x.isnan(), 3.14, x)
|
||||
|
||||
compiled = torch.compile(fn)
|
||||
x = torch.randn(4096, device=GPU_TYPE)
|
||||
out, (code,) = run_and_get_code(compiled, x)
|
||||
self.assertTrue("'NaN or Inf found'" in code)
|
||||
torch.testing.assert_close(out, fn(x))
|
||||
|
||||
@skip_if_cpp_wrapper("skip cpp wrapper")
|
||||
@requires_cuda_and_triton
|
||||
def test_repeat_interleave_decomposition_has_clamp(self):
|
||||
|
||||
@ -12,7 +12,6 @@ from torch.testing._internal.common_device_type import (
|
||||
dtypes,
|
||||
dtypesIfMPS,
|
||||
expectedFailureMPS,
|
||||
expectedFailureMPSPre15,
|
||||
expectedFailureXLA,
|
||||
instantiate_device_type_tests,
|
||||
)
|
||||
@ -173,7 +172,6 @@ class TestDropoutNNDeviceType(NNTestCase):
|
||||
else:
|
||||
self.assertNotEqual(permuted_inp, out)
|
||||
|
||||
@expectedFailureMPSPre15
|
||||
def test_Dropout(self, device):
|
||||
input = torch.empty(1000)
|
||||
self._test_dropout(nn.Dropout, device, input)
|
||||
|
||||
@ -529,7 +529,7 @@ class TestProfiler(TestCase):
|
||||
found_mm = True
|
||||
if "gemm" in e.name.lower() or "Cijk" in e.name:
|
||||
found_gemm = True
|
||||
if "memcpy" in e.name.lower():
|
||||
if "memcpy" in e.name.lower() or "__amd_rocclr_copyBuffer" in e.name:
|
||||
found_memcpy = True
|
||||
if use_cuda:
|
||||
self.assertTrue(found_gemm)
|
||||
|
||||
@ -755,6 +755,8 @@ def run_test_retries(
|
||||
REPO_ROOT / ".pytest_cache/v/cache/stepcurrent" / stepcurrent_key
|
||||
) as f:
|
||||
current_failure = f.read()
|
||||
if current_failure == "null":
|
||||
current_failure = f"'{test_file}'"
|
||||
except FileNotFoundError:
|
||||
print_to_file(
|
||||
"No stepcurrent file found. Either pytest didn't get to run (e.g. import error)"
|
||||
@ -791,8 +793,6 @@ def run_test_retries(
|
||||
print_to_file("Retrying single test...")
|
||||
print_items = [] # do not continue printing them, massive waste of space
|
||||
|
||||
if "null" in num_failures:
|
||||
num_failures[f"'{test_file}'"] = num_failures.pop("null")
|
||||
consistent_failures = [x[1:-1] for x in num_failures.keys() if num_failures[x] >= 3]
|
||||
flaky_failures = [x[1:-1] for x in num_failures.keys() if 0 < num_failures[x] < 3]
|
||||
if len(flaky_failures) > 0:
|
||||
|
||||
@ -7846,6 +7846,45 @@ class TestMPS(TestCaseMPS):
|
||||
y = torch.normal(torch.zeros(shape, device="mps"), torch.ones(shape, device="mps"))
|
||||
self.assertNotEqual(y[0], y[1])
|
||||
|
||||
def test_random_ops_noncontiguous(self):
|
||||
"""Test random in-place operations on non-contiguous tensors.
|
||||
|
||||
All random in-place operations should work on non-contiguous tensors.
|
||||
See issues #165257 and #124029.
|
||||
"""
|
||||
# Test each random in-place operation
|
||||
ops = [
|
||||
("normal_", lambda t: t.normal_(0, 1)),
|
||||
("uniform_", lambda t: t.uniform_(0, 1)),
|
||||
("exponential_", lambda t: t.exponential_(1.0)),
|
||||
("bernoulli_", lambda t: t.bernoulli_(0.5)),
|
||||
("random_", lambda t: t.random_()),
|
||||
("random_with_to", lambda t: t.random_(10)),
|
||||
("random_with_range", lambda t: t.random_(0, 10)),
|
||||
]
|
||||
|
||||
for name, op_func in ops:
|
||||
with self.subTest(operation=name):
|
||||
# Create non-contiguous tensor via transpose
|
||||
t_mps = torch.zeros(50, 50, device='mps').T.clone()
|
||||
self.assertFalse(t_mps.is_contiguous(),
|
||||
f"{name}: tensor should be non-contiguous")
|
||||
|
||||
# Apply operation
|
||||
op_func(t_mps)
|
||||
|
||||
# Verify tensor was modified (not all zeros)
|
||||
max_val = t_mps.max().item()
|
||||
self.assertNotEqual(max_val, 0.0,
|
||||
f"{name}: operation failed to modify non-contiguous tensor")
|
||||
|
||||
# Test rand_like specifically (issue #124029)
|
||||
t = torch.ones((3, 2, 2), device='mps').permute(2, 0, 1)
|
||||
self.assertFalse(t.is_contiguous(), "rand_like input should be non-contiguous")
|
||||
result = torch.rand_like(t)
|
||||
self.assertFalse(result.is_contiguous(), "rand_like result should be non-contiguous")
|
||||
self.assertNotEqual(result.max().item(), 0.0, "rand_like should generate non-zero values")
|
||||
|
||||
# Test exponential
|
||||
@unittest.skip("This does not test anything")
|
||||
def test_exponential(self):
|
||||
|
||||
@ -46,6 +46,7 @@ from torch.testing._internal.common_quantized import (
|
||||
_floatx_unpacked_to_f32,
|
||||
ceil_div, to_blocked,
|
||||
to_mxfp8,
|
||||
from_blocked_format,
|
||||
generate_jagged_offs,
|
||||
)
|
||||
|
||||
@ -462,6 +463,24 @@ def pack_uint4(uint8_data) -> torch.Tensor:
|
||||
uint8_data = uint8_data.contiguous().view(-1)
|
||||
return (uint8_data[1::2] << 4 | uint8_data[::2]).view(down_size(shape))
|
||||
|
||||
def unpack_uint4(uint8_data) -> torch.Tensor:
|
||||
# Take a packed uint8 tensor (i.e. nvfp4) and unpack into
|
||||
# a tensor twice as wide. Useful for dequant operations.
|
||||
shape = list(uint8_data.shape)
|
||||
# 2x packed elements -> single non-packed => adjust shape
|
||||
shape[-1] *= 2
|
||||
out = torch.empty(
|
||||
*shape,
|
||||
device=uint8_data.device,
|
||||
dtype=torch.uint8
|
||||
).view(-1)
|
||||
|
||||
uint8_data_as_uint8 = uint8_data.view(torch.uint8).view(-1)
|
||||
|
||||
out[1::2] = uint8_data_as_uint8[:] >> 4
|
||||
out[::2] = uint8_data_as_uint8 & 15
|
||||
|
||||
return out.view(shape)
|
||||
|
||||
def _bfloat16_to_float4_e2m1fn_x2(x):
|
||||
assert x.dtype == torch.bfloat16
|
||||
@ -470,6 +489,119 @@ def _bfloat16_to_float4_e2m1fn_x2(x):
|
||||
x = x.view(torch.float4_e2m1fn_x2)
|
||||
return x
|
||||
|
||||
def _convert_to_nvfp4_with_hp_ref(t):
|
||||
# Convert a tensor to nvfp4, returning:
|
||||
# t_hp : reconstructed bf16 version of t_lp
|
||||
# t_lp : nvfp4 tensor (2x elements packed into uint8)
|
||||
# t_scale: e4m3 block-wise scaling factors (non-swizzled)
|
||||
# t_global_scale: fp32 tensor-wise global scaling factor
|
||||
t_lp, t_scale, t_global_scale = data_to_nvfp4_with_global_scale(
|
||||
t,
|
||||
16,
|
||||
)
|
||||
t_hp = from_blocked_format(
|
||||
_floatx_unpacked_to_f32(
|
||||
unpack_uint4(t_lp),
|
||||
FP4_EBITS,
|
||||
FP4_MBITS),
|
||||
t_scale,
|
||||
blocksize=16) * t_global_scale
|
||||
|
||||
return t_hp, t_lp, t_scale, t_global_scale
|
||||
|
||||
def _convert_to_mxfp8_with_hp_ref(t):
|
||||
# Convert a tensor to mxfp8, returning:
|
||||
# t_hp : reconstructed bf16 version of t_lp
|
||||
# t_lp : fp8_e4m3 tensor
|
||||
# t_scale: fp8_e8m0 block-wise scaling factors (non-swizzled)
|
||||
t_scale, t_lp = to_mxfp8(t)
|
||||
t_hp = from_blocked_format(t_lp, t_scale, blocksize=32)
|
||||
|
||||
return t_hp, t_lp, t_scale
|
||||
|
||||
def _2d_grouped_tensor_to_mxfp8_blocked_scaled(t, MN, G, offs, format='mxfp8'):
|
||||
# Convert scales to blocked format. either mxfp8 or nvfp4
|
||||
th_list = []
|
||||
t_list = []
|
||||
t_blocked_scale_list = []
|
||||
t_global_scale_list = []
|
||||
|
||||
def round_up(x: int, y: int) -> int:
|
||||
return ((x + y - 1) // y) * y
|
||||
|
||||
for group_idx in range(G):
|
||||
# to_mxfp8 per group
|
||||
prev_group_end_offset = (
|
||||
0 if group_idx == 0 else offs[group_idx - 1]
|
||||
)
|
||||
curr_group_end_offset = offs[group_idx]
|
||||
group_size = curr_group_end_offset - prev_group_end_offset
|
||||
if group_size > 0:
|
||||
t_slice = t[
|
||||
:, prev_group_end_offset:curr_group_end_offset
|
||||
].contiguous() # (M, K_group)
|
||||
if format == 'mxfp8':
|
||||
th_slice, tq_slice, t_scale_slice = _convert_to_mxfp8_with_hp_ref(t_slice)
|
||||
elif format == 'nvfp4':
|
||||
th_slice, tq_slice, t_scale_slice, tq_global = _convert_to_nvfp4_with_hp_ref(
|
||||
t_slice,
|
||||
)
|
||||
t_global_scale_list.append(tq_global)
|
||||
else:
|
||||
raise ValueError(f'format must be mxfp8|nvfp4, got "{format}"')
|
||||
t_list.append(tq_slice)
|
||||
th_list.append(th_slice)
|
||||
|
||||
# Convert scales to blocked format.
|
||||
t_scale_slice_blocked = to_blocked(
|
||||
t_scale_slice
|
||||
) # (round_up(M, 128), round_up(K_group//32, 4))
|
||||
t_blocked_scale_list.append(t_scale_slice_blocked)
|
||||
|
||||
# Assemble the full XQ and WQ
|
||||
tq = torch.cat(t_list, dim=1).contiguous()
|
||||
th = torch.cat(th_list, dim=1).contiguous()
|
||||
|
||||
# Combine all XQ groups blocked scales into one tensor.
|
||||
t_blocked_scales = torch.cat(t_blocked_scale_list, dim=0)
|
||||
MN_rounded = round_up(MN, 128)
|
||||
t_blocked_scales = t_blocked_scales.reshape(MN_rounded, -1)
|
||||
|
||||
# Global scales only exist for nvfp4
|
||||
t_global_scales = None
|
||||
if len(t_global_scale_list) > 0:
|
||||
t_global_scales = torch.stack(t_global_scale_list)
|
||||
|
||||
return th, tq, t_blocked_scales, t_global_scales
|
||||
|
||||
def _build_scaled_grouped_mm_kwargs(scale_a, scale_b, offs, format):
|
||||
# Build some standard args that are wordy
|
||||
# Note: if/when ROCm support added, need to change swizzle handling
|
||||
kwargs = {
|
||||
'mxfp8': {
|
||||
'scale_a': scale_a,
|
||||
'scale_b': scale_b,
|
||||
'scale_recipe_a': ScalingType.BlockWise1x32,
|
||||
'scale_recipe_b': ScalingType.BlockWise1x32,
|
||||
'swizzle_a': SwizzleType.SWIZZLE_32_4_4,
|
||||
'swizzle_b': SwizzleType.SWIZZLE_32_4_4,
|
||||
'offs': offs, # (G,)
|
||||
'out_dtype': torch.bfloat16,
|
||||
'wrap_v2': True,
|
||||
},
|
||||
'nvfp4': {
|
||||
'scale_a': scale_a,
|
||||
'scale_b': scale_b,
|
||||
'scale_recipe_a': [ScalingType.BlockWise1x16, ScalingType.TensorWise],
|
||||
'scale_recipe_b': [ScalingType.BlockWise1x16, ScalingType.TensorWise],
|
||||
'swizzle_a': SwizzleType.SWIZZLE_32_4_4,
|
||||
'swizzle_b': SwizzleType.SWIZZLE_32_4_4,
|
||||
'offs': offs, # (G,)
|
||||
'out_dtype': torch.bfloat16,
|
||||
'wrap_v2': True,
|
||||
},
|
||||
}
|
||||
return kwargs[format]
|
||||
|
||||
class TestFP8Matmul(TestCase):
|
||||
|
||||
@ -526,13 +658,15 @@ class TestFP8Matmul(TestCase):
|
||||
out_fp8_s = scaled_mm_wrap(x, y, scale_a=scale_a, scale_b=scale_b)
|
||||
self.assertEqual(out_fp8, out_fp8_s)
|
||||
|
||||
|
||||
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_MXFP8_GROUPED_GEMM, mxfp8_grouped_mm_skip_msg)
|
||||
@parametrize("G", [1, 4, 16])
|
||||
@parametrize("M", [2048, 2049])
|
||||
@parametrize("N", [8192])
|
||||
@parametrize("K", [16640])
|
||||
@parametrize("wrap_v2", [True, False])
|
||||
def test_mxfp8_scaled_grouped_mm_2d_2d(self, G, M, N, K, wrap_v2):
|
||||
@parametrize("format", ["mxfp8"] + (["nvfp4"] if torch.version.cuda else []))
|
||||
def test_mxfp8_nvfp4_scaled_grouped_mm_2d_2d(self, G, M, N, K, format):
|
||||
torch.manual_seed(42)
|
||||
total_K = K # Alias for clarity, communicating this consists of several groups along this dim
|
||||
input_group_end_offsets = generate_jagged_offs(
|
||||
@ -541,95 +675,61 @@ class TestFP8Matmul(TestCase):
|
||||
X = torch.randn((M, total_K), dtype=torch.bfloat16, device="cuda") * 0.1
|
||||
W = torch.randn((N, total_K), dtype=torch.bfloat16, device="cuda") * 0.01
|
||||
|
||||
# Convert scales to blocked format.
|
||||
x_list = []
|
||||
w_list = []
|
||||
x_blocked_scale_list = []
|
||||
w_blocked_scale_list = []
|
||||
xh, xq, x_blocked_scales, x_global_scales = _2d_grouped_tensor_to_mxfp8_blocked_scaled(
|
||||
X, M, G, input_group_end_offsets, format=format
|
||||
)
|
||||
wh, wq, w_blocked_scales, w_global_scales = _2d_grouped_tensor_to_mxfp8_blocked_scaled(
|
||||
W, N, G, input_group_end_offsets, format=format
|
||||
)
|
||||
|
||||
def round_up(x: int, y: int) -> int:
|
||||
return ((x + y - 1) // y) * y
|
||||
|
||||
for group_idx in range(G):
|
||||
# to_mxfp8 per group
|
||||
prev_group_end_offset = (
|
||||
0 if group_idx == 0 else input_group_end_offsets[group_idx - 1]
|
||||
if format == "mxfp8":
|
||||
kwargs = _build_scaled_grouped_mm_kwargs(
|
||||
x_blocked_scales,
|
||||
w_blocked_scales,
|
||||
input_group_end_offsets,
|
||||
format,
|
||||
)
|
||||
curr_group_end_offset = input_group_end_offsets[group_idx]
|
||||
group_size = curr_group_end_offset - prev_group_end_offset
|
||||
if group_size > 0:
|
||||
x_slice = X[
|
||||
:, prev_group_end_offset:curr_group_end_offset
|
||||
].contiguous() # (M, K_group)
|
||||
w_slice = W[
|
||||
:, prev_group_end_offset:curr_group_end_offset
|
||||
].contiguous() # (N, K_group)
|
||||
x_scale_slice, xq_slice = to_mxfp8(
|
||||
x_slice
|
||||
) # scale shape -> (M, K_group // 32)
|
||||
w_scale_slice, wq_slice = to_mxfp8(
|
||||
w_slice
|
||||
) # scale shape -> (N, K_group // 32)
|
||||
x_list.append(xq_slice)
|
||||
w_list.append(wq_slice)
|
||||
elif format == "nvfp4":
|
||||
kwargs = _build_scaled_grouped_mm_kwargs(
|
||||
[x_blocked_scales, x_global_scales],
|
||||
[w_blocked_scales, w_global_scales],
|
||||
input_group_end_offsets,
|
||||
format,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f'format must be mxfp8|nvfp4, got "{format}"')
|
||||
|
||||
# Convert scales to blocked format.
|
||||
x_scale_slice_blocked = to_blocked(
|
||||
x_scale_slice
|
||||
) # (round_up(M, 128), round_up(K_group//32, 4))
|
||||
w_scale_slice_blocked = to_blocked(
|
||||
w_scale_slice
|
||||
) # (round_up(N, 128), round_up(K_group//32, 4))
|
||||
x_blocked_scale_list.append(x_scale_slice_blocked)
|
||||
w_blocked_scale_list.append(w_scale_slice_blocked)
|
||||
|
||||
# Assemble the full XQ and WQ
|
||||
xq = torch.cat(x_list, dim=1).contiguous()
|
||||
wq = torch.cat(w_list, dim=1).contiguous()
|
||||
|
||||
# Combine all XQ groups blocked scales into one tensor.
|
||||
x_blocked_scales = torch.cat(x_blocked_scale_list, dim=0)
|
||||
M_rounded = round_up(M, 128)
|
||||
x_blocked_scales = x_blocked_scales.reshape(M_rounded, -1)
|
||||
|
||||
# Combine all WQ groups blocked scales into one tensor.
|
||||
w_blocked_scales = torch.cat(w_blocked_scale_list, dim=0)
|
||||
N_rounded = round_up(N, 128)
|
||||
w_blocked_scales = w_blocked_scales.reshape(N_rounded, -1)
|
||||
if format == 'nvfp4':
|
||||
assert x_global_scales.numel() == w_global_scales.numel()
|
||||
assert x_global_scales.numel() == G
|
||||
|
||||
# Compute mxfp8 grouped mm output
|
||||
y_mxfp8 = scaled_grouped_mm_wrap(
|
||||
xq, # (M, total_K)
|
||||
wq.transpose(-2, -1), # (total_K, N)
|
||||
x_blocked_scales, # to_blocked_per_group(M, total_K//32)
|
||||
w_blocked_scales, # to_blocked_per_group(N, total_K//32)
|
||||
scale_recipe_a=ScalingType.BlockWise1x32,
|
||||
scale_recipe_b=ScalingType.BlockWise1x32,
|
||||
swizzle_a=SwizzleType.SWIZZLE_32_4_4,
|
||||
swizzle_b=SwizzleType.SWIZZLE_32_4_4,
|
||||
offs=input_group_end_offsets, # (G,)
|
||||
out_dtype=torch.bfloat16,
|
||||
wrap_v2=wrap_v2
|
||||
y_lp = scaled_grouped_mm_wrap(
|
||||
xq,
|
||||
wq.transpose(-2, -1),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# bf16 reference output
|
||||
y_bf16 = torch._grouped_mm(
|
||||
X, W.t(), offs=input_group_end_offsets, out_dtype=torch.bfloat16
|
||||
# Note: the reference result should be computed on the reconstructed values,
# not the originals, i.e. float(fp4(t)), not t itself.
|
||||
xh, wh.t(), offs=input_group_end_offsets, out_dtype=torch.bfloat16
|
||||
)
|
||||
|
||||
# Assert no NaNs
|
||||
assert not y_mxfp8.isnan().any(), "mxfp8 output contains NaN"
|
||||
assert not y_lp.isnan().any(), "low-precision output contains NaN"
|
||||
|
||||
# Assert outputs are close
|
||||
torch.testing.assert_close(y_mxfp8, y_bf16, atol=8.0e-2, rtol=8.0e-2)
|
||||
torch.testing.assert_close(y_lp, y_bf16, atol=8.0e-2, rtol=8.0e-2)
|
||||
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_MXFP8_GROUPED_GEMM, mxfp8_grouped_mm_skip_msg)
|
||||
@parametrize("G", [1, 4, 16])
|
||||
@parametrize("M", [16640])
|
||||
@parametrize("N", [8192])
|
||||
@parametrize("K", [4096])
|
||||
@parametrize("wrap_v2", [True, False])
|
||||
def test_mxfp8_scaled_grouped_mm_2d_3d(self, G, M, N, K, wrap_v2):
|
||||
@parametrize("format", ["mxfp8"] + (["nvfp4"] if torch.version.cuda else []))
|
||||
def test_mxfp8_scaled_grouped_mm_2d_3d(self, G, M, N, K, format):
|
||||
torch.manual_seed(42)
|
||||
# Simulate 2d-3d grouped gemm `out = input @ weight.t()`
|
||||
# 2D inputs with groups along M, 3D weights.
|
||||
@ -643,60 +743,120 @@ class TestFP8Matmul(TestCase):
|
||||
|
||||
# For each constituent 2d subtensor in the 3d weights, quantize and convert scale to blocked format separately,
|
||||
# as they are each used for an independent gemm in the grouped gemm.
|
||||
wq_list = []
|
||||
w_scale_list = []
|
||||
for i in range(G):
|
||||
w_scale, wq = to_mxfp8(W[i])
|
||||
w_scale = to_blocked(w_scale)
|
||||
wq_list.append(wq)
|
||||
w_scale_list.append(w_scale)
|
||||
wq = torch.stack(wq_list, dim=0).contiguous()
|
||||
w_scale = torch.stack(w_scale_list, dim=0).contiguous()
|
||||
def _3d_to_blocked_scaled(W, G, format):
|
||||
wh_list = []
|
||||
wq_list = []
|
||||
w_scale_list = []
|
||||
w_global_scale_list = []
|
||||
for i in range(G):
|
||||
if format == "mxfp8":
|
||||
wh, wq, w_scale = _convert_to_mxfp8_with_hp_ref(W[i])
|
||||
elif format == "nvfp4":
|
||||
w_scale, wq = to_mxfp8(W[i])
|
||||
wh, wq, w_scale, w_global_scale = _convert_to_nvfp4_with_hp_ref(W[i])
|
||||
w_global_scale_list.append(w_global_scale)
|
||||
else:
|
||||
raise ValueError(f'format must be mxfp8|nvfp4, got "{format}"')
|
||||
|
||||
# Swizzle scales
|
||||
# TODO(slayton): gate on cuda/hip
|
||||
w_scale = to_blocked(w_scale)
|
||||
|
||||
wh_list.append(wh)
|
||||
wq_list.append(wq)
|
||||
w_scale_list.append(w_scale)
|
||||
wh = torch.stack(wh_list, dim=0).contiguous()
|
||||
wq = torch.stack(wq_list, dim=0).contiguous()
|
||||
w_scale = torch.stack(w_scale_list, dim=0).contiguous()
|
||||
# Global scales only exist for nvfp4
|
||||
if len(w_global_scale_list) > 0:
|
||||
w_global_scales = torch.stack(w_global_scale_list)
|
||||
else:
|
||||
w_global_scales = None
|
||||
return wh, wq, w_scale, w_global_scales
|
||||
|
||||
wh, wq, w_blocked_scales, w_global_scales = _3d_to_blocked_scaled(W, G, format)
|
||||
|
||||
# For each group along `total_M` in the 2D tensor, quantize and convert scale to blocked format separately,
|
||||
# as they are each used for an independent gemm in the grouped gemm.
|
||||
xq_list = []
|
||||
x_scale_list = []
|
||||
for i in range(G):
|
||||
prev_group_end = 0 if i == 0 else input_group_end_offsets[i - 1]
|
||||
curr_group_end = input_group_end_offsets[i]
|
||||
group_size = curr_group_end - prev_group_end
|
||||
if group_size > 0:
|
||||
x_slice = X[prev_group_end:curr_group_end, :]
|
||||
x_scale, xq = to_mxfp8(x_slice)
|
||||
x_scale = to_blocked(x_scale)
|
||||
xq_list.append(xq)
|
||||
x_scale_list.append(x_scale)
|
||||
xq = torch.cat(xq_list, dim=0).contiguous()
|
||||
x_scale = torch.cat(x_scale_list, dim=0).contiguous()
|
||||
x_scale = x_scale.reshape(-1, K // block_size)
|
||||
xq = xq.view(-1, xq.shape[-1])
|
||||
def _2d_to_blocked_scaled(X, K, G, offs, format):
|
||||
xh_list = []
|
||||
xq_list = []
|
||||
x_scale_list = []
|
||||
x_global_scale_list = []
|
||||
for i in range(G):
|
||||
prev_group_end = 0 if i == 0 else input_group_end_offsets[i - 1]
|
||||
curr_group_end = input_group_end_offsets[i]
|
||||
group_size = curr_group_end - prev_group_end
|
||||
if group_size > 0:
|
||||
x_slice = X[prev_group_end:curr_group_end, :]
|
||||
if format == "mxfp8":
|
||||
xh, xq, x_scale = _convert_to_mxfp8_with_hp_ref(x_slice)
|
||||
elif format == "nvfp4":
|
||||
xh, xq, x_scale, x_global_scale = _convert_to_nvfp4_with_hp_ref(x_slice)
|
||||
x_global_scale_list.append(x_global_scale)
|
||||
else:
|
||||
raise ValueError(f'format must be mxfp8|nvfp4, got "{format}"')
|
||||
|
||||
# Compute mxfp8 grouped gemm.
|
||||
y_mxfp8 = scaled_grouped_mm_wrap(
|
||||
x_scale = to_blocked(x_scale)
|
||||
xh_list.append(xh)
|
||||
xq_list.append(xq)
|
||||
x_scale_list.append(x_scale)
|
||||
xh = torch.cat(xh_list, dim=0).contiguous()
|
||||
xq = torch.cat(xq_list, dim=0).contiguous()
|
||||
x_scale = torch.cat(x_scale_list, dim=0).contiguous()
|
||||
x_scale = x_scale.reshape(-1, K // block_size)
|
||||
xq = xq.view(-1, xq.shape[-1])
|
||||
xh = xh.view(-1, xh.shape[-1])
|
||||
|
||||
x_global_scales = None
|
||||
if len(x_global_scale_list) > 0:
|
||||
x_global_scales = torch.stack(x_global_scale_list)
|
||||
|
||||
return xh, xq, x_scale, x_global_scales
|
||||
|
||||
xh, xq, x_blocked_scales, x_global_scales = _2d_to_blocked_scaled(X, K, G, input_group_end_offsets, format)
|
||||
|
||||
if format == "mxfp8":
|
||||
kwargs = _build_scaled_grouped_mm_kwargs(
|
||||
x_blocked_scales,
|
||||
w_blocked_scales,
|
||||
input_group_end_offsets,
|
||||
format,
|
||||
)
|
||||
elif format == "nvfp4":
|
||||
kwargs = _build_scaled_grouped_mm_kwargs(
|
||||
[x_blocked_scales, x_global_scales],
|
||||
[w_blocked_scales, w_global_scales],
|
||||
input_group_end_offsets,
|
||||
format,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f'format must be mxfp8|nvfp4, got "{format}"')
|
||||
|
||||
if format == 'nvfp4':
|
||||
assert x_global_scales.numel() == w_global_scales.numel()
|
||||
assert x_global_scales.numel() == G
|
||||
|
||||
# Compute low-precision grouped gemm.
|
||||
y_lp = scaled_grouped_mm_wrap(
|
||||
xq,
|
||||
wq.transpose(-2, -1),
|
||||
x_scale,
|
||||
w_scale,
|
||||
offs=input_group_end_offsets,
|
||||
out_dtype=torch.bfloat16,
|
||||
scale_recipe_a=ScalingType.BlockWise1x32,
|
||||
scale_recipe_b=ScalingType.BlockWise1x32,
|
||||
swizzle_a=SwizzleType.SWIZZLE_32_4_4,
|
||||
swizzle_b=SwizzleType.SWIZZLE_32_4_4,
|
||||
wrap_v2=wrap_v2)
|
||||
|
||||
**kwargs
|
||||
)
|
||||
|
||||
# Compute reference bf16 grouped gemm.
|
||||
# Note: the reference result should be computed on the reconstructed values,
# not the originals, i.e. float(fp4(t)), not t itself.
|
||||
y_bf16 = torch._grouped_mm(
|
||||
X,
|
||||
W.transpose(-2, -1),
|
||||
xh,
|
||||
wh.transpose(-2, -1),
|
||||
offs=input_group_end_offsets,
|
||||
out_dtype=torch.bfloat16,
|
||||
)
|
||||
|
||||
# Assert outputs are close.
|
||||
torch.testing.assert_close(y_mxfp8, y_bf16, atol=8.0e-2, rtol=8.0e-2)
|
||||
torch.testing.assert_close(y_lp, y_bf16, atol=8.0e-2, rtol=8.0e-2)
|
||||
|
||||
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
|
||||
@ -1704,6 +1864,7 @@ class TestFP8Matmul(TestCase):
|
||||
@parametrize("fast_accum", [False, True])
|
||||
# AMD does not support non-contiguous inputs yet
|
||||
@parametrize("strided", [False] + ([True] if torch.version.cuda else []))
|
||||
# AMD does not support NVFP4
|
||||
@parametrize("wrap_v2", [True, False])
|
||||
def test_scaled_grouped_gemm_2d_2d(self, fast_accum, strided, wrap_v2):
|
||||
device = "cuda"
|
||||
|
||||
@ -1107,6 +1107,7 @@ class TestTransformers(NNTestCase):
|
||||
)[0]
|
||||
|
||||
@tf32_on_and_off(0.003)
|
||||
@parametrize("batch_size", [0, 5])
|
||||
@parametrize("input_dim,attn_mask_dim,is_causal",
|
||||
[(3, None, False), (3, 2, False), (3, 2, True), (3, 3, False), (3, 3, True),
|
||||
(4, None, False), (4, 2, False), (4, 2, True), (4, 4, False), (4, 4, True)],
|
||||
@ -1116,7 +1117,7 @@ class TestTransformers(NNTestCase):
|
||||
if attn_dim is not None else "no_attn_mask")))
|
||||
@parametrize("dropout_p", [0.0, 0.2, 0.5])
|
||||
@sdpa_kernel(backends=[SDPBackend.MATH])
|
||||
def test_scaled_dot_product_attention(self, device, input_dim, attn_mask_dim, is_causal, dropout_p):
|
||||
def test_scaled_dot_product_attention(self, device, batch_size, input_dim, attn_mask_dim, is_causal, dropout_p):
|
||||
def sdp_ref(
|
||||
q,
|
||||
k,
|
||||
@ -1140,12 +1141,13 @@ class TestTransformers(NNTestCase):
|
||||
# TODO: Support cross-device / dtype testing properly when instantiate_device_type_tests() is used.
|
||||
dtypes = [torch.double, torch.float]
|
||||
for dtype in dtypes:
|
||||
N = batch_size
|
||||
|
||||
def rand_tensor(*shape):
|
||||
return torch.randn(shape, device=device, dtype=dtype)
|
||||
|
||||
# This test compares python and C++ implementations of SDP.
|
||||
N, N_prime, L, S, E = 5, 2, 4, 3, 6
|
||||
N_prime, L, S, E = 2, 4, 3, 6
|
||||
if input_dim == 3:
|
||||
query = rand_tensor(N, L, E)
|
||||
key = rand_tensor(N, S, E)
|
||||
|
||||
@ -5,11 +5,12 @@ from collections import namedtuple
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.attention import varlen_attn
|
||||
from torch.nn.attention.varlen import varlen_attn
|
||||
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION
|
||||
from torch.testing._internal.common_device_type import instantiate_device_type_tests
|
||||
from torch.testing._internal.common_nn import NNTestCase
|
||||
from torch.testing._internal.common_utils import parametrize, run_tests
|
||||
from torch.utils._python_dispatch import TorchDispatchMode
|
||||
|
||||
|
||||
VarlenShape = namedtuple(
|
||||
@ -23,6 +24,18 @@ default_tolerances = {
|
||||
}
|
||||

class OpLoggingMode(TorchDispatchMode):
    """Logging mode that captures all dispatched operations"""

    def __init__(self):
        self.called_ops = []

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        op_name = str(func)
        self.called_ops.append(op_name)
        return func(*args, **(kwargs or {}))

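OpLoggingMode is a thin TorchDispatchMode; a self-contained sketch of how such a logger behaves on a couple of eager ops (public torch.utils._python_dispatch API only, the exact op names are indicative):

# Minimal dispatch-logging sketch, independent of this test file.
import torch
from torch.utils._python_dispatch import TorchDispatchMode

class OpLogger(TorchDispatchMode):
    def __init__(self):
        self.called_ops = []

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        self.called_ops.append(str(func))           # record every dispatched aten op
        return func(*args, **(kwargs or {}))

with OpLogger() as mode:
    x = torch.randn(4)
    (x * 2).sum()

print(mode.called_ops)  # e.g. ['aten.randn.default', 'aten.mul.Tensor', 'aten.sum.default']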
class AttentionBlock(nn.Module):
|
||||
def __init__(
|
||||
self, embed_dim: int, num_heads: int, device: torch.device, dtype: torch.dtype
|
||||
@ -39,12 +52,9 @@ class AttentionBlock(nn.Module):
|
||||
embed_dim, embed_dim, bias=False, device=device, dtype=dtype
|
||||
)
|
||||
|
||||
def forward_varlen(
|
||||
def get_varlen_qkv(
|
||||
self,
|
||||
x_packed: torch.Tensor,
|
||||
cu_seq: torch.Tensor,
|
||||
max_len: int,
|
||||
is_causal: bool = False,
|
||||
):
|
||||
qkv = self.qkv_proj(x_packed)
|
||||
q, k, v = qkv.chunk(3, dim=-1)
|
||||
@ -53,24 +63,51 @@ class AttentionBlock(nn.Module):
|
||||
k = k.view(-1, self.num_heads, self.head_dim)
|
||||
v = v.view(-1, self.num_heads, self.head_dim)
|
||||
|
||||
attn_out = varlen_attn(
|
||||
q, k, v, cu_seq, cu_seq, max_len, max_len, is_causal=is_causal
|
||||
)
|
||||
return q, k, v
|
||||
|
||||
def forward_varlen(
|
||||
self,
|
||||
x_packed: torch.Tensor,
|
||||
cu_seq: torch.Tensor,
|
||||
max_len: int,
|
||||
is_causal: bool = False,
|
||||
):
|
||||
q, k, v = self.get_varlen_qkv(x_packed)
|
||||
|
||||
attn_out = varlen_attn(q, k, v, cu_seq, cu_seq, max_len, max_len, is_causal)
|
||||
attn_out = attn_out.view(-1, self.embed_dim)
|
||||
|
||||
return self.out_proj(attn_out)
|
||||
|
||||
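For reference, a hedged standalone sketch of the varlen_attn call pattern used above, assuming a CUDA device with Flash Attention support and the torch.nn.attention.varlen module this diff imports from; the packed layout is (total_tokens, num_heads, head_dim):

# Two sequences of lengths 128 and 64 packed into one tensor, with cumulative
# offsets in cu_seq (sketch, not taken from the PR).
import torch
from torch.nn.attention.varlen import varlen_attn

num_heads, head_dim = 8, 64
cu_seq = torch.tensor([0, 128, 192], device="cuda", dtype=torch.int32)
total_tokens, max_len = 192, 128

q = torch.randn(total_tokens, num_heads, head_dim, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

out = varlen_attn(q, k, v, cu_seq, cu_seq, max_len, max_len, is_causal=False)
print(out.shape)  # expected: torch.Size([192, 8, 64]), the packed layout of the inputs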
def forward_sdpa(self, x_padded: torch.Tensor, is_causal: bool = False):
|
||||
def forward_sdpa(
|
||||
self,
|
||||
x_padded: torch.Tensor,
|
||||
seq_lengths: torch.Tensor,
|
||||
dtype: torch.dtype,
|
||||
is_causal: bool = False,
|
||||
):
|
||||
batch_size, seq_len, _ = x_padded.shape
|
||||
|
||||
qkv = self.qkv_proj(x_padded)
|
||||
q, k, v = qkv.chunk(3, dim=-1)
|
||||
|
||||
mask = (
|
||||
torch.arange(seq_len, device=x_padded.device)[None, :]
|
||||
< seq_lengths[:, None]
|
||||
)
|
||||
|
||||
attn_mask = mask[:, None, None, :].expand(
|
||||
batch_size, self.num_heads, seq_len, seq_len
|
||||
)
|
||||
|
||||
q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
k = k.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
v = v.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
attn_out = F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)
|
||||
attn_out = F.scaled_dot_product_attention(
|
||||
q, k, v, attn_mask=attn_mask, is_causal=is_causal
|
||||
)
|
||||
|
||||
attn_out = (
|
||||
attn_out.transpose(1, 2)
|
||||
.contiguous()
|
||||
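The key-padding mask built in forward_sdpa above is a general SDPA pattern; a minimal self-contained version using only public APIs (shapes are illustrative):

# Padding-aware SDPA: keys beyond each sequence's true length are masked out.
import torch
import torch.nn.functional as F

batch, heads, seq_len, head_dim = 2, 2, 4, 8
seq_lengths = torch.tensor([4, 2])  # second sequence is padded from 2 up to 4

q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)
v = torch.randn(batch, heads, seq_len, head_dim)

# True where a key position is a real token, False where it is padding
keep = torch.arange(seq_len)[None, :] < seq_lengths[:, None]           # (B, S)
attn_mask = keep[:, None, None, :].expand(batch, heads, seq_len, seq_len)

out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
print(out.shape)  # torch.Size([2, 2, 4, 8])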
@ -91,7 +128,9 @@ def create_variable_length_batch(
|
||||
seq_lengths = torch.tensor(seq_lengths, device=device)
|
||||
total_tokens = seq_lengths.sum().item()
|
||||
|
||||
x_packed = torch.randn(total_tokens, shape.embed_dim, device=device, dtype=dtype)
|
||||
x_packed = torch.randn(
|
||||
total_tokens, shape.embed_dim, device=device, dtype=dtype, requires_grad=True
|
||||
)
|
||||
|
||||
cu_seq = torch.zeros(shape.batch_size + 1, device=device, dtype=torch.int32)
|
||||
cu_seq[1:] = seq_lengths.cumsum(0)
|
||||
@ -106,6 +145,7 @@ def create_variable_length_batch(
|
||||
end_idx = start_idx + seq_len
|
||||
x_padded[i, :seq_len] = x_packed[start_idx:end_idx]
|
||||
start_idx = end_idx
|
||||
x_padded = x_padded.clone().detach().requires_grad_()
|
||||
|
||||
return {
|
||||
"seq_lengths": seq_lengths,
|
||||
@ -133,7 +173,11 @@ class TestVarlenAttention(NNTestCase):
|
||||
|
||||
total_tokens = shape.batch_size * shape.max_seq_len
|
||||
x_packed = torch.randn(
|
||||
total_tokens, shape.embed_dim, device=device, dtype=dtype
|
||||
total_tokens,
|
||||
shape.embed_dim,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
requires_grad=True,
|
||||
)
|
||||
cu_seq = torch.tensor(
|
||||
[0, shape.max_seq_len, total_tokens], device=device, dtype=torch.int32
|
||||
@ -147,6 +191,128 @@ class TestVarlenAttention(NNTestCase):
|
||||
self.assertEqual(output.device, torch.device(device))
|
||||
self.assertEqual(output.dtype, dtype)
|
||||
|
||||
varlen_grad_out = torch.ones_like(output)
|
||||
|
||||
varlen_grad = torch.autograd.grad(
|
||||
outputs=output,
|
||||
inputs=x_packed,
|
||||
grad_outputs=varlen_grad_out,
|
||||
retain_graph=True,
|
||||
create_graph=False,
|
||||
allow_unused=False,
|
||||
)[0]
|
||||
|
||||
self.assertIsNotNone(varlen_grad)
|
||||
self.assertEqual(varlen_grad.shape, x_packed.shape)
|
||||
self.assertEqual(varlen_grad.dtype, x_packed.dtype)
|
||||
|
||||
@unittest.skipIf(
|
||||
not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention not supported"
|
||||
)
|
||||
@parametrize("dtype", [torch.bfloat16, torch.float16])
|
||||
def test_custom_op_compliance(self, device, dtype):
|
||||
torch.manual_seed(42)
|
||||
|
||||
shape = VarlenShape(batch_size=2, max_seq_len=512, embed_dim=1024, num_heads=16)
|
||||
|
||||
attention_block = AttentionBlock(
|
||||
shape.embed_dim, shape.num_heads, device, dtype
|
||||
)
|
||||
|
||||
total_tokens = shape.batch_size * shape.max_seq_len
|
||||
x_packed = torch.randn(
|
||||
total_tokens,
|
||||
shape.embed_dim,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
cu_seq = torch.tensor(
|
||||
[0, shape.max_seq_len, total_tokens], device=device, dtype=torch.int32
|
||||
)
|
||||
|
||||
q, k, v = attention_block.get_varlen_qkv(x_packed)
|
||||
|
||||
torch.library.opcheck(
|
||||
torch.ops.torch_attn._varlen_attn,
|
||||
(q, k, v, cu_seq, cu_seq, shape.max_seq_len, shape.max_seq_len, False),
|
||||
)
|
||||
|
||||
out, lse, rng_state = torch.ops.torch_attn._varlen_attn(
|
||||
q, k, v, cu_seq, cu_seq, shape.max_seq_len, shape.max_seq_len, False
|
||||
)
|
||||
grad_out = torch.randn_like(out)
|
||||
|
||||
# we don't support double backward
|
||||
# skipping test_autograd_registration, test_aot_dispatch_dynamic, test_aot_dispatch_static
|
||||
torch.library.opcheck(
|
||||
torch.ops.torch_attn._varlen_attn_backward,
|
||||
(
|
||||
grad_out,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
out,
|
||||
lse,
|
||||
cu_seq,
|
||||
cu_seq,
|
||||
shape.max_seq_len,
|
||||
shape.max_seq_len,
|
||||
False,
|
||||
rng_state,
|
||||
),
|
||||
test_utils=["test_schema", "test_faketensor"],
|
||||
)
|
||||
|
||||
@unittest.skipIf(
|
||||
not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention not supported"
|
||||
)
|
||||
@parametrize("dtype", [torch.bfloat16, torch.float16])
|
||||
def test_custom_op_registration(self, device, dtype):
|
||||
torch.manual_seed(42)
|
||||
|
||||
shape = VarlenShape(batch_size=2, max_seq_len=512, embed_dim=1024, num_heads=16)
|
||||
|
||||
attention_block = AttentionBlock(
|
||||
shape.embed_dim, shape.num_heads, device, dtype
|
||||
)
|
||||
|
||||
total_tokens = shape.batch_size * shape.max_seq_len
|
||||
x_packed = torch.randn(
|
||||
total_tokens,
|
||||
shape.embed_dim,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
requires_grad=True,
|
||||
)
|
||||
cu_seq = torch.tensor(
|
||||
[0, shape.max_seq_len, total_tokens], device=device, dtype=torch.int32
|
||||
)
|
||||
|
||||
compiled_forward = torch.compile(
|
||||
attention_block.forward_varlen, backend="eager", fullgraph=True
|
||||
)
|
||||
with OpLoggingMode() as mode:
|
||||
output = compiled_forward(
|
||||
x_packed, cu_seq, shape.max_seq_len, is_causal=False
|
||||
)
|
||||
|
||||
varlen_grad_out = torch.ones_like(output)
|
||||
_ = torch.autograd.grad(
|
||||
outputs=output,
|
||||
inputs=x_packed,
|
||||
grad_outputs=varlen_grad_out,
|
||||
retain_graph=True,
|
||||
create_graph=False,
|
||||
allow_unused=False,
|
||||
)[0]
|
||||
|
||||
called_ops = mode.called_ops
|
||||
|
||||
custom_ops_called = any(
|
||||
"torch_attn._varlen_attn" in op for op in called_ops
|
||||
) and any("torch_attn._varlen_attn_backward" in op for op in called_ops)
|
||||
assert custom_ops_called
|
||||
|
||||
@unittest.skipIf(
|
||||
not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention not supported"
|
||||
)
|
||||
@ -172,7 +338,10 @@ class TestVarlenAttention(NNTestCase):
|
||||
is_causal=is_causal,
|
||||
)
|
||||
sdpa_output = attention_block.forward_sdpa(
|
||||
variable_length_batch_data["x_padded"], is_causal=is_causal
|
||||
variable_length_batch_data["x_padded"],
|
||||
variable_length_batch_data["seq_lengths"],
|
||||
dtype=dtype,
|
||||
is_causal=is_causal,
|
||||
)
|
||||
|
||||
tolerances = default_tolerances[dtype]
|
||||
@ -186,6 +355,44 @@ class TestVarlenAttention(NNTestCase):
|
||||
torch.testing.assert_close(varlen_seq, sdpa_seq, **tolerances)
|
||||
start_idx = end_idx
|
||||
|
||||
varlen_grad_out = torch.ones_like(varlen_output)
|
||||
|
||||
sdpa_grad_out = torch.zeros_like(sdpa_output)
|
||||
|
||||
start_idx = 0
|
||||
for i, seq_len in enumerate(variable_length_batch_data["seq_lengths"]):
|
||||
end_idx = start_idx + seq_len
|
||||
sdpa_grad_out[i, :seq_len] = varlen_grad_out[start_idx:end_idx]
|
||||
start_idx = end_idx
|
||||
|
||||
varlen_grad = torch.autograd.grad(
|
||||
outputs=varlen_output,
|
||||
inputs=variable_length_batch_data["x_packed"],
|
||||
grad_outputs=varlen_grad_out,
|
||||
retain_graph=True,
|
||||
create_graph=False,
|
||||
allow_unused=False,
|
||||
)[0]
|
||||
|
||||
sdpa_grad = torch.autograd.grad(
|
||||
outputs=sdpa_output,
|
||||
inputs=variable_length_batch_data["x_padded"],
|
||||
grad_outputs=sdpa_grad_out,
|
||||
retain_graph=True,
|
||||
create_graph=False,
|
||||
allow_unused=False,
|
||||
)[0]
|
||||
|
||||
start_idx = 0
|
||||
for i, seq_len in enumerate(variable_length_batch_data["seq_lengths"]):
|
||||
end_idx = start_idx + seq_len
|
||||
|
||||
varlen_grad_seq = varlen_grad[start_idx:end_idx]
|
||||
sdpa_grad_seq = sdpa_grad[i, :seq_len]
|
||||
|
||||
torch.testing.assert_close(varlen_grad_seq, sdpa_grad_seq, **tolerances)
|
||||
start_idx = end_idx
|
||||
|
||||
|
||||
device_types = ("cuda",)
|
||||
|
||||
|
||||
third_party/kineto (vendored)
Submodule third_party/kineto updated: a6b2477b88...6fcbc53d33
@ -445,7 +445,7 @@ use_numpy_random_stream = False
enable_cpp_guard_manager = True

# Use C++ guard manager for symbolic shapes
enable_cpp_symbolic_shape_guards = not is_fbcode()
enable_cpp_symbolic_shape_guards = False

# Enable tracing through contextlib.contextmanager
enable_trace_contextlib = True

@ -42,7 +42,7 @@ import weakref
from dataclasses import dataclass
from enum import Enum
from os.path import dirname, join
from typing import Any, NamedTuple, Optional, TYPE_CHECKING, Union
from typing import Any, NamedTuple, Optional, Sized, TYPE_CHECKING, Union
from unittest.mock import patch

import sympy

@ -395,6 +395,13 @@ class OptimizedModule(torch.nn.Module):
        self._initialize()
        self.training = self._orig_mod.training

    def __len__(self) -> int:
        # Proxy the len call to the original module
        if isinstance(self._orig_mod, Sized):
            return len(self._orig_mod)
        # Mimic python's default behavior for objects without a length
        raise TypeError(f"{type(self._orig_mod).__name__} does not support len()")

    def _initialize(self) -> None:
        # Do this stuff in constructor to lower overhead slightly
        if isinstance(self.dynamo_ctx, DisableContext):

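A small sketch of what the new __len__ proxy enables, assuming a build that includes this change:

# len() on the compiled wrapper is forwarded to the original module.
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 1))
compiled = torch.compile(model)

print(len(model))     # 3
print(len(compiled))  # 3 as well, proxied to the original nn.Sequential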
@ -2820,5 +2820,36 @@
        "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues."
      ]
    }
  ],
  "GB0280": [
    {
      "Gb_type": "1-arg super not implemented",
      "Context": "",
      "Explanation": "Dynamo failed to trace attribute `{name}` accessed via `super()` (for type `{self.typevar}` and object `{self.objvar}`) because one-argument of super() is not supported.",
      "Hints": [
        "Use two-argument super(type, object_or_type)."
      ]
    }
  ],
  "GB0281": [
    {
      "Gb_type": "Invalid or non-const argument in nn.Module __getitem__",
      "Context": "call_method: {self} {name} {args} {kwargs}",
      "Explanation": "Dynamo does not support calling method `{name}` of ``nn.Module`` {module} with a non-constant or non-(str, int) key.",
      "Hints": [
        "Use constant arguments of type str or int for __getitem__"
      ]
    }
  ],
  "GB0282": [
    {
      "Gb_type": "Placement with custom __getattr__ not supported",
      "Context": "{value_type.__name__} with custom __getattr__",
      "Explanation": "Dynamo does not support Placement types with custom __getattr__ methods",
      "Hints": [
        "Use Placement types without custom __getattr__ methods",
        "Move the Placement usage outside the compiled region"
      ]
    }
  ]
}

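To make the GB0280 hint concrete, a minimal sketch that already uses the recommended two-argument form and therefore traces cleanly (illustrative module, not taken from the PR):

import torch
import torch.nn as nn

class Base(nn.Module):
    def forward(self, x):
        return x + 1

class Child(Base):
    def forward(self, x):
        # Two-argument super(), per the GB0280 hint; equivalent to super().forward(x)
        return super(Child, self).forward(x) * 2

out = torch.compile(Child(), fullgraph=True)(torch.ones(3))
print(out)  # tensor([4., 4., 4.])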
@ -210,9 +210,16 @@ class PlacementVariable(DistributedVariable):
|
||||
if name in constant_fold_functions:
|
||||
try:
|
||||
value_type = type(self.value)
|
||||
assert (
|
||||
inspect.getattr_static(value_type, "__getattr__", None) is None
|
||||
), "no custom getattr allowed!"
|
||||
if inspect.getattr_static(value_type, "__getattr__", None) is not None:
|
||||
unimplemented_v2(
|
||||
gb_type="Placement with custom __getattr__ not supported",
|
||||
context=f"{value_type.__name__} with custom __getattr__",
|
||||
explanation="Dynamo does not support Placement types with custom __getattr__ methods",
|
||||
hints=[
|
||||
"Use Placement types without custom __getattr__ methods",
|
||||
"Move the Placement usage outside the compiled region",
|
||||
],
|
||||
)
|
||||
method = inspect.getattr_static(value_type, name)
|
||||
except AttributeError:
|
||||
method = None
|
||||
|
||||
@ -103,7 +103,17 @@ class SuperVariable(VariableTracker):
|
||||
codegen.extend_output(create_call_function(1, False))
|
||||
|
||||
def _resolved_getattr_and_source(self, tx: "InstructionTranslator", name):
|
||||
assert self.objvar, "1-arg super not implemented"
|
||||
if not self.objvar:
|
||||
unimplemented_v2(
|
||||
gb_type="1-arg super not implemented",
|
||||
context="",
|
||||
explanation=f"Dynamo failed to trace attribute `{name}` accessed "
|
||||
f"via `super()` (for type `{self.typevar}` and object `{self.objvar}`) "
|
||||
"because one-argument of super() is not supported.",
|
||||
hints=[
|
||||
"Use two-argument super(type, object_or_type).",
|
||||
],
|
||||
)
|
||||
search_type = self.typevar.as_python_constant()
|
||||
|
||||
# The rest of this function does two things:
|
||||
|
||||
@ -822,9 +822,19 @@ class NNModuleVariable(VariableTracker):
|
||||
)
|
||||
|
||||
if type(module).__getitem__ not in builtin_supported:
|
||||
assert isinstance(args[0], variables.ConstantVariable), typestr(args[0])
|
||||
key = args[0].as_python_constant()
|
||||
assert isinstance(key, (str, int))
|
||||
if not (
|
||||
isinstance(args[0], variables.ConstantVariable)
|
||||
and isinstance(args[0].as_python_constant(), (str, int))
|
||||
):
|
||||
unimplemented_v2(
|
||||
gb_type="Invalid or non-const argument in nn.Module __getitem__",
|
||||
context=f"call_method: {self} {name} {args} {kwargs}",
|
||||
explanation="Dynamo does not support calling "
|
||||
f"method `{name}` of ``nn.Module`` {module} with a non-constant or non-(str, int) key.",
|
||||
hints=[
|
||||
"Use constant arguments of type str or int for __getitem__"
|
||||
],
|
||||
)
|
||||
fn = getattr(module, name).__func__
|
||||
|
||||
assert isinstance(fn, types.FunctionType)
|
||||
|
||||
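Likewise for GB0281, a minimal sketch of the supported pattern: indexing an nn.ModuleDict with a constant string key inside compiled code (illustrative module, not from the PR):

import torch
import torch.nn as nn

class Blocks(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleDict({"a": nn.Linear(4, 4), "b": nn.Linear(4, 4)})

    def forward(self, x):
        # Constant str key: fine for Dynamo; a non-constant or non-(str, int) key
        # would hit the GB0281 graph break above.
        return self.layers["a"](x)

print(torch.compile(Blocks())(torch.randn(2, 4)))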
@ -1793,14 +1793,6 @@ def _aot_stage2b_bw_compile(
|
||||
# tensor which is wrong.
|
||||
|
||||
ph_size = ph_arg.size()
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
if len(ph_size) == 0 and len(real_stride) > 0:
|
||||
# Fix for 0-dimensional tensors: When a tensor becomes 0-d
|
||||
# (e.g., via squeeze), its stride should be () not (1,).
|
||||
# This mismatch can occur when dynamic shape operations produce
|
||||
# tensors that are later squeezed to 0-d. The stride metadata
|
||||
# may get preserved causing a dimension mismatch (#164814)
|
||||
real_stride = ()
|
||||
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
placeholder_list[i] = ph_arg.as_strided(ph_size, real_stride)
|
||||
|
||||
@ -720,13 +720,22 @@ def check_shape(
) -> None:
    backend = get_current_backend()
    assert shape is not None
    if config.test_configs.runtime_triton_dtype_assert and backend == "triton":
    if config.test_configs.runtime_triton_shape_assert and backend == "triton":
        shape_str = (
            ", ".join(str(d) for d in shape) if len(shape) != 1 else f"{shape[0]},"
        )
        buffer.writeline(f"tl.static_assert({var}.shape == ({shape_str}))")


def check_nan(buffer: IndentedBuffer, var: CSEVariableType) -> None:
    backend = get_current_backend()
    if backend == "triton":
        msg = "NaN or Inf found"
        buffer.writeline(
            f"tl.device_assert(({var} == {var}) & ({var} != float('inf')) & ({var} != float('-inf')), '{msg}')"
        )


class DataTypePropagation:
    def __init__(self, body: LoopBody) -> None:
        self.body = body

@ -2623,6 +2632,9 @@ class CSEProxy(DefaultHandler):
            assert output_shape is not None
            check_shape(V.kernel.compute, csevar, output_shape)

            if config.runtime_triton_nan_asserts:
                check_nan(V.kernel.compute, csevar)

            return csevar

        return pytree.tree_map(do_cse, value)

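The predicate emitted by check_nan() above can be read in plain torch terms; a self-contained sketch of its semantics (not the generated Triton code):

# NaN fails x == x, and the two inequality checks rule out +/-inf; together the
# predicate is equivalent to torch.isfinite.
import torch

x = torch.tensor([1.0, float("nan"), float("inf"), -2.0])
finite = (x == x) & (x != float("inf")) & (x != float("-inf"))
print(finite)             # tensor([ True, False, False,  True])
print(torch.isfinite(x))  # same result via the built-in helper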
@ -626,7 +626,7 @@ class ComboKernel(Kernel):
|
||||
if heuristics == "foreach":
|
||||
heuristics_line = f"""
|
||||
@triton_heuristics.foreach(
|
||||
filename=__file__,
|
||||
num_warps={self.num_warps},
|
||||
triton_meta={triton_meta!r},
|
||||
inductor_meta={inductor_meta!r},
|
||||
)
|
||||
|
||||
@ -206,6 +206,9 @@ static_weight_shapes = True
# put correctness assertions in generated code
size_asserts = os.environ.get("TORCHINDUCTOR_SIZE_ASSERTS", "1") == "1"
nan_asserts = os.environ.get("TORCHINDUCTOR_NAN_ASSERTS") == "1"
runtime_triton_nan_asserts = (
    os.environ.get("TORCHINDUCTOR_RUNTIME_TRITON_NAN_ASSERTS") == "1"
)
scalar_asserts = os.environ.get("TORCHINDUCTOR_SCALAR_ASSERTS", "1") == "1"

# Disable by default in fbcode

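A sketch of turning the new runtime NaN assert on from Python, assuming this patch; the config attribute mirrors the TORCHINDUCTOR_RUNTIME_TRITON_NAN_ASSERTS environment variable added above:

import torch
import torch._inductor.config as inductor_config

inductor_config.runtime_triton_nan_asserts = True  # flag added in this change

@torch.compile
def f(x):
    return torch.log(x)  # negative inputs produce NaN

# On a CUDA build with this change, the generated Triton kernel carries the
# tl.device_assert checks from check_nan() and should trap on the NaNs here.
f(torch.full((4,), -1.0, device="cuda"))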
@ -3550,24 +3550,13 @@ def user_autotune(
|
||||
)
|
||||
|
||||
|
||||
def foreach(triton_meta, filename=None, inductor_meta=None):
|
||||
def foreach(triton_meta, num_warps, filename=None, inductor_meta=None):
|
||||
"""
|
||||
Compile a triton foreach kernel
|
||||
"""
|
||||
configs = []
|
||||
|
||||
# Naive autotuning path for num_warps
|
||||
if not inductor_meta.get("autotune_pointwise", True) and not (
|
||||
inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise")
|
||||
):
|
||||
configs.append(triton.Config({}, num_stages=1, num_warps=8))
|
||||
else:
|
||||
for warps in [1, 2, 4, 8]:
|
||||
configs.append(triton.Config({}, num_stages=1, num_warps=warps))
|
||||
|
||||
return cached_autotune(
|
||||
None,
|
||||
configs,
|
||||
[triton.Config({}, num_stages=1, num_warps=num_warps)],
|
||||
triton_meta=triton_meta,
|
||||
inductor_meta=inductor_meta,
|
||||
heuristic_type=HeuristicType.TEMPLATE,
|
||||
|
||||
@ -409,9 +409,10 @@ class SchedulerDonatedBuffer(SchedulerBuffer):
|
||||
|
||||
|
||||
class BaseSchedulerNode:
|
||||
ancestors: OrderedSet[str]
|
||||
debug_device_str: Callable[[BaseSchedulerNode], list[str]]
|
||||
group: tuple[torch.device, tuple[tuple[sympy.Expr, ...], ...]]
|
||||
read_writes: dependencies.ReadWrites
|
||||
unmet_dependencies: OrderedSet[Dep]
|
||||
last_usage: OrderedSet[str]
|
||||
# .min_order and .max_order are only relevant for "grouped" nodes such as FusedSchedulerNode.
|
||||
# e.g. if the FusedSchedulerNode includes nodes (op_1, op_2, op_3), and op_X is X-th node
|
||||
# in `self.scheduler.nodes`, then for this FusedSchedulerNode, .min_order is 1 and .max_order is 3.
|
||||
@ -420,22 +421,24 @@ class BaseSchedulerNode:
|
||||
min_order: int
|
||||
max_order: int
|
||||
mpi_node: MemoryPlanningInfoForNode
|
||||
mutation_renames: dict[str, str]
|
||||
node: Optional[ir.Operation]
|
||||
outputs: list[SchedulerBuffer]
|
||||
outputs_by_name: dict[str, SchedulerBuffer]
|
||||
override_estimated_runtime: Optional[float] = None
|
||||
read_writes: dependencies.ReadWrites
|
||||
unmet_dependencies: OrderedSet[Dep]
|
||||
|
||||
def __init__(self, scheduler: Scheduler) -> None:
|
||||
self.scheduler: Scheduler = scheduler
|
||||
self.debug_device_str: Callable[[BaseSchedulerNode], list[str]] = (
|
||||
lambda *args, **kwargs: []
|
||||
)
|
||||
self.scheduler = scheduler
|
||||
self.debug_device_str = lambda *args, **kwargs: []
|
||||
|
||||
def _init_from_node(self, node: ir.Operation) -> None:
|
||||
self.node: Optional[ir.Operation] = node
|
||||
self.ancestors: OrderedSet[str] = OrderedSet()
|
||||
self.last_usage = OrderedSet[
|
||||
str
|
||||
]() # buffers that won't be used after this kernel
|
||||
self.node = node
|
||||
self.ancestors = OrderedSet()
|
||||
self.last_usage = OrderedSet() # buffers that won't be used after this kernel
|
||||
self.written = False
|
||||
self.outputs: list[SchedulerBuffer] = [
|
||||
self.outputs = [
|
||||
SchedulerBuffer(
|
||||
scheduler=self.scheduler,
|
||||
node=output,
|
||||
@ -443,16 +446,14 @@ class BaseSchedulerNode:
|
||||
)
|
||||
for output in node.get_outputs()
|
||||
]
|
||||
self.outputs_by_name: dict[str, SchedulerBuffer] = {
|
||||
buf.get_name(): buf for buf in self.outputs
|
||||
}
|
||||
self.outputs_by_name = {buf.get_name(): buf for buf in self.outputs}
|
||||
|
||||
# mutation_renames for the current node. Due to potential
|
||||
# more mutations happening later, this can be different
|
||||
# to Scheduler.mutation_renames. Also this dict should be small
|
||||
# since only mutation information relevant to the deps for this
|
||||
# node is stored here.
|
||||
self.mutation_renames: dict[str, str] = {}
|
||||
self.mutation_renames = {}
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{type(self).__name__}(name={self.get_name()!r})"
|
||||
@ -2435,6 +2436,34 @@ def pick_loop_order(
|
||||
return order
|
||||
|
||||
|
||||
def _replace_operation_buffer(
|
||||
orig_node: ir.MultiTemplateBuffer, new_node: ir.OperationBuffer
|
||||
) -> None:
|
||||
replaced_buf_name = new_node.get_name()
|
||||
orig_buf_name = orig_node.get_name()
|
||||
assert isinstance(orig_buf_name, str) and isinstance(replaced_buf_name, str)
|
||||
|
||||
replaced_op_name = new_node.get_operation_name()
|
||||
orig_op_name = orig_node.get_operation_name()
|
||||
assert isinstance(orig_op_name, str) and isinstance(replaced_op_name, str)
|
||||
|
||||
del V.graph.name_to_buffer[replaced_buf_name]
|
||||
new_node.name = orig_buf_name
|
||||
|
||||
del V.graph.name_to_op[replaced_op_name]
|
||||
new_node.operation_name = orig_op_name
|
||||
|
||||
orig = V.graph.buffers.index(orig_node)
|
||||
V.graph.buffers.remove(new_node)
|
||||
V.graph.buffers[orig] = new_node
|
||||
V.graph.name_to_buffer[orig_buf_name] = new_node
|
||||
|
||||
orig = V.graph.operations.index(orig_node)
|
||||
V.graph.operations.remove(new_node)
|
||||
V.graph.operations[orig] = new_node
|
||||
V.graph.name_to_op[orig_op_name] = new_node
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class NodeUser:
|
||||
node: Union[BaseSchedulerNode, OutputNode]
|
||||
@ -3336,33 +3365,6 @@ class Scheduler:
|
||||
will force completion of compilation and benchmarking.
|
||||
"""
|
||||
|
||||
def replace_operation_buffer(
|
||||
orig_node: ir.MultiTemplateBuffer, new_node: ir.OperationBuffer
|
||||
) -> None:
|
||||
replaced_buf_name = new_node.get_name()
|
||||
orig_buf_name = orig_node.get_name()
|
||||
assert isinstance(orig_buf_name, str) and isinstance(replaced_buf_name, str)
|
||||
|
||||
replaced_op_name = new_node.get_operation_name()
|
||||
orig_op_name = orig_node.get_operation_name()
|
||||
assert isinstance(orig_op_name, str) and isinstance(replaced_op_name, str)
|
||||
|
||||
del V.graph.name_to_buffer[replaced_buf_name]
|
||||
new_node.name = orig_buf_name
|
||||
|
||||
del V.graph.name_to_op[replaced_op_name]
|
||||
new_node.operation_name = orig_op_name
|
||||
|
||||
orig = V.graph.buffers.index(orig_node)
|
||||
V.graph.buffers.remove(new_node)
|
||||
V.graph.buffers[orig] = new_node
|
||||
V.graph.name_to_buffer[orig_buf_name] = new_node
|
||||
|
||||
orig = V.graph.operations.index(orig_node)
|
||||
V.graph.operations.remove(new_node)
|
||||
V.graph.operations[orig] = new_node
|
||||
V.graph.name_to_op[orig_op_name] = new_node
|
||||
|
||||
for i, node in enumerate(self.nodes):
|
||||
if isinstance(node, SchedulerNode) and isinstance(
|
||||
node.node, ir.MultiTemplateBuffer
|
||||
@ -3416,40 +3418,47 @@ class Scheduler:
|
||||
assign_origin_node(out_tensorbox, multi_node.origin_node)
|
||||
|
||||
out_buffer.layout = multi_node.layout
|
||||
replace_operation_buffer(multi_node, out_buffer)
|
||||
new_scheduler_node = self.create_scheduler_node(out_buffer)
|
||||
self._replace_node(out_buffer, multi_node, i, node)
|
||||
|
||||
self.nodes[i] = new_scheduler_node
|
||||
self.name_to_node[node.get_name()] = new_scheduler_node
|
||||
self.name_to_fused_node[node.get_name()] = new_scheduler_node
|
||||
def _replace_node(
|
||||
self,
|
||||
out_buffer: ir.OperationBuffer,
|
||||
multi_node: ir.MultiTemplateBuffer,
|
||||
i: int,
|
||||
node: SchedulerNode,
|
||||
) -> None:
|
||||
_replace_operation_buffer(multi_node, out_buffer)
|
||||
new_scheduler_node = self.create_scheduler_node(out_buffer)
|
||||
|
||||
# We need to reflect the mutation renames that were recorded in the original node
|
||||
mutation_renames = {}
|
||||
for dep in itertools.chain(
|
||||
node.read_writes.reads, node.unmet_dependencies
|
||||
):
|
||||
if real_name := self.mutation_real_name.get(dep.name, None):
|
||||
mutation_renames[real_name] = dep.name
|
||||
self.nodes[i] = new_scheduler_node
|
||||
self.name_to_node[node.get_name()] = new_scheduler_node
|
||||
self.name_to_fused_node[node.get_name()] = new_scheduler_node
|
||||
|
||||
def rename_deps(deps: OrderedSet[Dep]) -> OrderedSet[Dep]:
|
||||
return OrderedSet(dep.rename(mutation_renames) for dep in deps)
|
||||
# We need to reflect the mutation renames that were recorded in the original node
|
||||
mutation_renames = {}
|
||||
for dep in itertools.chain(node.read_writes.reads, node.unmet_dependencies):
|
||||
if real_name := self.mutation_real_name.get(dep.name, None):
|
||||
mutation_renames[real_name] = dep.name
|
||||
|
||||
new_scheduler_node.unmet_dependencies = rename_deps(
|
||||
new_scheduler_node.unmet_dependencies
|
||||
)
|
||||
new_scheduler_node.read_writes.reads = rename_deps(
|
||||
new_scheduler_node.read_writes.reads
|
||||
)
|
||||
def rename_deps(deps: OrderedSet[Dep]) -> OrderedSet[Dep]:
|
||||
return OrderedSet(dep.rename(mutation_renames) for dep in deps)
|
||||
|
||||
for new_out, old_out in zip(
|
||||
new_scheduler_node.get_outputs(), node.get_outputs()
|
||||
):
|
||||
self.name_to_buf[old_out.get_name()] = new_out
|
||||
new_out.users = old_out.users
|
||||
new_scheduler_node.unmet_dependencies = rename_deps(
|
||||
new_scheduler_node.unmet_dependencies
|
||||
)
|
||||
new_scheduler_node.read_writes.reads = rename_deps(
|
||||
new_scheduler_node.read_writes.reads
|
||||
)
|
||||
|
||||
new_scheduler_node.min_order = node.min_order
|
||||
new_scheduler_node.max_order = node.max_order
|
||||
new_scheduler_node.last_usage = node.last_usage
|
||||
for new_out, old_out in zip(
|
||||
new_scheduler_node.get_outputs(), node.get_outputs()
|
||||
):
|
||||
self.name_to_buf[old_out.get_name()] = new_out
|
||||
new_out.users = old_out.users
|
||||
|
||||
new_scheduler_node.min_order = node.min_order
|
||||
new_scheduler_node.max_order = node.max_order
|
||||
new_scheduler_node.last_usage = node.last_usage
|
||||
|
||||
def _any_atomic_add(self, node_list: Sequence[BaseSchedulerNode]) -> bool:
|
||||
return any(
|
||||
|
||||
@ -1,13 +1,9 @@
|
||||
# mypy: allow-untyped-defs
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates
|
||||
# implement matrix related ops for distributed tensor
|
||||
from dataclasses import dataclass, field
|
||||
from typing import cast, Optional
|
||||
from typing import cast
|
||||
|
||||
import torch
|
||||
import torch.distributed._functional_collectives as funcol
|
||||
from torch.distributed._local_tensor import maybe_run_for_local_tensor
|
||||
from torch.distributed.device_mesh import DeviceMesh
|
||||
from torch.distributed.tensor._op_schema import (
|
||||
OpSchema,
|
||||
OpStrategy,
|
||||
@ -19,8 +15,8 @@ from torch.distributed.tensor._ops.utils import (
|
||||
register_op_strategy,
|
||||
)
|
||||
from torch.distributed.tensor.placement_types import (
|
||||
MaskPartial,
|
||||
Partial,
|
||||
Placement,
|
||||
Replicate,
|
||||
Shard,
|
||||
)
|
||||
@ -29,190 +25,6 @@ from torch.distributed.tensor.placement_types import (
|
||||
aten = torch.ops.aten
|
||||
|
||||
|
||||
@dataclass
|
||||
class MaskBuffer:
|
||||
data: Optional[torch.Tensor] = None
|
||||
# refcount allows shared usage of the MaskBuffer, as long as all users have the same data
|
||||
refcount: int = 0
|
||||
|
||||
def materialize_mask(self, mask):
|
||||
if self.refcount == 0:
|
||||
self.data = mask
|
||||
else:
|
||||
assert self.data is not None
|
||||
if not torch.equal(self.data, mask):
|
||||
raise RuntimeError(
|
||||
"MaskBuffer has been materialized with conflicting data"
|
||||
)
|
||||
self.refcount += 1
|
||||
|
||||
def release_mask(self):
|
||||
if self.refcount == 0 or self.data is None:
|
||||
raise RuntimeError("MaskBuffer has not been materialized")
|
||||
self.refcount -= 1
|
||||
if self.refcount == 0:
|
||||
self.data = None
|
||||
|
||||
def apply_mask(self, tensor):
|
||||
if self.refcount == 0 or self.data is None:
|
||||
raise RuntimeError("MaskBuffer has not been materialized")
|
||||
|
||||
# NOTE: _MaskPartial is being used by the embedding op and the gather op.
|
||||
# For gather, the mask has the same dimension as the output tensor, whereas
|
||||
# the output of the embedding op has an additional dimension compare to the input,
|
||||
# hence the output masking logic below having two different cases.
|
||||
if tensor.ndim == self.data.ndim:
|
||||
tensor[self.data] = 0.0
|
||||
else:
|
||||
tensor[self.data, :] = 0.0
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _MaskPartial(Partial):
|
||||
"""
|
||||
A partial mask placement devised for rowwise sharded embedding op, where we need
|
||||
to mask and adjust the indices to the local embedding shard, embedding masking
|
||||
is a special type of the Partial placement
|
||||
|
||||
NOTE: the lifecycle of this MaskPartial placement follows the corresponding DTensor
|
||||
lifecycle, i.e. the indices_mask would only be alive during the lifetime of the DTensor.
|
||||
"""
|
||||
|
||||
mask_buffer: MaskBuffer = field(default_factory=MaskBuffer)
|
||||
|
||||
# required fields for computing the local offset and deriving the mask
|
||||
offset_shape: Optional[torch.Size] = None
|
||||
offset_dim: int = 0
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
reduce_op=None,
|
||||
mask_buffer=None,
|
||||
offset_shape=None,
|
||||
offset_dim=0,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(reduce_op)
|
||||
if mask_buffer is None:
|
||||
mask_buffer = MaskBuffer()
|
||||
object.__setattr__(self, "mask_buffer", mask_buffer)
|
||||
object.__setattr__(self, "offset_shape", offset_shape)
|
||||
object.__setattr__(self, "offset_dim", offset_dim)
|
||||
|
||||
@staticmethod
|
||||
@maybe_run_for_local_tensor
|
||||
def _mask_tensor(
|
||||
tensor: torch.Tensor, local_offset_on_dim: int, local_shard_size: int
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# Build the input mask and save it for the current partial placement
|
||||
# this is so that the output of embedding op can reuse the same partial
|
||||
# placement saved mask to perform mask + reduction
|
||||
mask = (tensor < local_offset_on_dim) | (
|
||||
tensor >= local_offset_on_dim + local_shard_size
|
||||
)
|
||||
# mask the input tensor
|
||||
masked_tensor = tensor.clone() - local_offset_on_dim
|
||||
masked_tensor[mask] = 0
|
||||
return mask, masked_tensor
|
||||
|
||||
def _partition_value(
|
||||
self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
|
||||
) -> torch.Tensor:
|
||||
my_coordinate = mesh.get_coordinate()
|
||||
assert my_coordinate is not None, "my_coordinate should not be None"
|
||||
# override parent logic to perform partial mask for embedding
|
||||
num_chunks = mesh.size(mesh_dim)
|
||||
# get local shard size and offset on the embedding_dim
|
||||
assert self.offset_shape is not None, (
|
||||
"offset_shape needs to be set for _MaskPartial"
|
||||
)
|
||||
local_shard_size, local_offset_on_dim = Shard.local_shard_size_and_offset(
|
||||
self.offset_shape[self.offset_dim],
|
||||
num_chunks,
|
||||
my_coordinate[mesh_dim],
|
||||
)
|
||||
mask, masked_tensor = _MaskPartial._mask_tensor(
|
||||
tensor, local_offset_on_dim, local_shard_size
|
||||
)
|
||||
# materialize the mask buffer to be used for reduction
|
||||
self.mask_buffer.materialize_mask(mask)
|
||||
return masked_tensor
|
||||
|
||||
def _reduce_value(
|
||||
self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
|
||||
) -> torch.Tensor:
|
||||
# by the time we need reduction, we should have already saved the mask
|
||||
assert self.mask_buffer.data is not None
|
||||
|
||||
# apply the mask to the tensor that pending reduction
|
||||
self.mask_buffer.apply_mask(tensor)
|
||||
|
||||
# clear the mask buffer
|
||||
self.mask_buffer.release_mask()
|
||||
|
||||
# perform sum reduction
|
||||
return funcol.all_reduce(
|
||||
tensor, reduceOp=self.reduce_op, group=(mesh, mesh_dim)
|
||||
)
|
||||
|
||||
def _reduce_shard_value(
|
||||
self,
|
||||
tensor: torch.Tensor,
|
||||
mesh: DeviceMesh,
|
||||
mesh_dim: int,
|
||||
shard_spec: Placement,
|
||||
) -> torch.Tensor:
|
||||
# by the time we need reduction, we should have already saved the mask
|
||||
assert self.mask_buffer.data is not None
|
||||
|
||||
# apply the mask to the tensor that pending reduction
|
||||
self.mask_buffer.apply_mask(tensor)
|
||||
|
||||
# clear the mask buffer
|
||||
self.mask_buffer.release_mask()
|
||||
|
||||
# call reduce_shard_tensor of the shard_spec.
|
||||
shard_spec = cast(Shard, shard_spec)
|
||||
return shard_spec._reduce_shard_tensor(tensor, mesh, self.reduce_op, mesh_dim)
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, _MaskPartial):
|
||||
return False
|
||||
|
||||
# if either data is not None, we invalidate the sharding cache, as this indicates
|
||||
# the current MaskPartial placement is still in use and should not be used for cache hit.
|
||||
if self.mask_buffer.data is not None or other.mask_buffer.data is not None:
|
||||
return False
|
||||
|
||||
return (
|
||||
self.reduce_op == other.reduce_op
|
||||
and self.offset_shape == other.offset_shape
|
||||
and self.offset_dim == other.offset_dim
|
||||
)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return 1 + hash(
|
||||
(
|
||||
self.reduce_op,
|
||||
self.offset_shape,
|
||||
self.offset_dim,
|
||||
)
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""
|
||||
machine readable representation of the MaskPartial placement
|
||||
"""
|
||||
return f"_MaskPartial(offset_shape={self.offset_shape}, offset_dim={self.offset_dim})"
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""
|
||||
human readable representation of the MaskPartial placement
|
||||
"""
|
||||
return "MaskP"
|
||||
|
||||
|
||||
@register_op_strategy(aten.embedding.default)
|
||||
def embedding_strategy(op_schema: OpSchema) -> StrategyType:
|
||||
"""
|
||||
@ -239,7 +51,7 @@ def embedding_strategy(op_schema: OpSchema) -> StrategyType:
|
||||
single_mesh_dim_strategies.append(colwise_sharding)
|
||||
|
||||
# rowwise sharding, output is embedding partial, weight shard on dim 0, input accepts embedding partial
|
||||
embedding_partial_placement = _MaskPartial(offset_shape=weight_shape, offset_dim=0)
|
||||
embedding_partial_placement = MaskPartial(offset_shape=weight_shape, offset_dim=0)
|
||||
|
||||
# NOTE we want to reuse the same mask partial placement so that we can reuse the same mask that generates
|
||||
# from the input indices and use it for output reduction
|
||||
|
||||
torch/distributed/tensor/_ops/_mask_buffer.py (new file, 44 lines)
@ -0,0 +1,44 @@
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class MaskBuffer:
    data: Optional[torch.Tensor] = None
    # refcount allows shared usage of the MaskBuffer, as long as all users have the same data
    refcount: int = 0

    def materialize_mask(self, mask):
        if self.refcount == 0:
            self.data = mask
        else:
            assert self.data is not None
            if not torch.equal(self.data, mask):
                raise RuntimeError(
                    "MaskBuffer has been materialized with conflicting data"
                )
        self.refcount += 1

    def release_mask(self):
        if self.refcount == 0 or self.data is None:
            raise RuntimeError("MaskBuffer has not been materialized")
        self.refcount -= 1
        if self.refcount == 0:
            self.data = None

    def apply_mask(self, tensor):
        if self.refcount == 0 or self.data is None:
            raise RuntimeError("MaskBuffer has not been materialized")

        # NOTE: MaskPartial is being used by the embedding op and the gather op.
        # For gather, the mask has the same dimension as the output tensor, whereas
        # the output of the embedding op has an additional dimension compared to the input,
        # hence the output masking logic below having two different cases.
        if tensor.ndim == self.data.ndim:
            tensor[self.data] = 0.0
        else:
            tensor[self.data, :] = 0.0
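A sketch of the MaskBuffer lifecycle, assuming the module path added in this diff:

import torch
from torch.distributed.tensor._ops._mask_buffer import MaskBuffer

buf = MaskBuffer()
mask = torch.tensor([True, False, True])

buf.materialize_mask(mask)   # refcount 0 -> 1, stores the mask
buf.materialize_mask(mask)   # shared use is allowed when the data matches, refcount -> 2

out = torch.ones(3)
buf.apply_mask(out)          # zeroes the masked positions -> tensor([0., 1., 0.])

buf.release_mask()           # refcount 2 -> 1
buf.release_mask()           # refcount 1 -> 0, data is dropped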
@ -17,7 +17,7 @@ from torch.distributed.tensor._op_schema import (
|
||||
TupleStrategy,
|
||||
)
|
||||
from torch.distributed.tensor._ops._common_rules import pointwise_rule
|
||||
from torch.distributed.tensor._ops._embedding_ops import _MaskPartial
|
||||
from torch.distributed.tensor._ops._embedding_ops import MaskPartial
|
||||
from torch.distributed.tensor._ops.utils import (
|
||||
expand_to_full_mesh_op_strategy,
|
||||
generate_redistribute_costs,
|
||||
@ -646,7 +646,7 @@ def gather_strategy(op_schema: OpSchema) -> StrategyType:
|
||||
# this only works when the input is sharded on the gather dimension, and
|
||||
# index has size 1 on the gather dimension
|
||||
if dim < len(index_shape) and index_shape[dim] == 1:
|
||||
index_partial_placement = _MaskPartial(offset_shape=input_shape, offset_dim=dim)
|
||||
index_partial_placement = MaskPartial(offset_shape=input_shape, offset_dim=dim)
|
||||
input_sharding: PlacementList = [
|
||||
index_partial_placement,
|
||||
Shard(dim),
|
||||
|
||||
@ -11,7 +11,7 @@ from torch import Tensor
|
||||
from torch.distributed.device_mesh import DeviceMesh
|
||||
from torch.distributed.tensor import DTensor, Replicate, Shard
|
||||
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
|
||||
from torch.distributed.tensor._ops._embedding_ops import _MaskPartial
|
||||
from torch.distributed.tensor._ops._embedding_ops import MaskPartial
|
||||
from torch.distributed.tensor._ops._math_ops import (
|
||||
_skip_dim,
|
||||
Reduction,
|
||||
@ -236,7 +236,7 @@ def _nll_loss_forward(
|
||||
|
||||
# The following code block is a distributed version of
|
||||
# result = -torch.gather(self, channel_dim, safe_target_).squeeze(channel_dim)
|
||||
partial_placement = _MaskPartial(offset_shape=input_shape, offset_dim=channel_dim)
|
||||
partial_placement = MaskPartial(offset_shape=input_shape, offset_dim=channel_dim)
|
||||
safe_target_partial_ = partial_placement._partition_value(
|
||||
safe_target_, mesh, mesh_dim
|
||||
)
|
||||
@ -375,7 +375,7 @@ def _nll_loss_and_log_softmax_backward(
|
||||
|
||||
# The following code block is a distributed version of
|
||||
# grad_input = torch.scatter(grad_input, channel_dim, safe_target, -1.0)
|
||||
partial_placement = _MaskPartial(offset_shape=input_shape, offset_dim=channel_dim)
|
||||
partial_placement = MaskPartial(offset_shape=input_shape, offset_dim=channel_dim)
|
||||
safe_target = safe_target.squeeze(channel_dim).flatten()
|
||||
masked_safe_target = partial_placement._partition_value(safe_target, mesh, mesh_dim)
|
||||
# only update grad_input to -1 if not masked
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
# mypy: allow-untyped-defs
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import cast, Optional
|
||||
|
||||
import torch
|
||||
@ -17,9 +18,10 @@ from torch.distributed.tensor._collective_utils import (
|
||||
shard_dim_alltoall,
|
||||
unpad_tensor,
|
||||
)
|
||||
from torch.distributed.tensor._ops._mask_buffer import MaskBuffer
|
||||
|
||||
|
||||
__all__ = ["Placement", "Shard", "Replicate", "Partial"]
|
||||
__all__ = ["Placement", "Shard", "Replicate", "Partial", "MaskPartial"]
|
||||
|
||||
|
||||
# Appease TestPublicBindings.test_correct_module_names
|
||||
@ -841,3 +843,149 @@ class Partial(torch._C._distributed.Partial):

# We keep the old _Partial name for a while for BC reason
_Partial = Partial


@dataclass(frozen=True)
class MaskPartial(Partial):
    """
    A partial mask placement devised for the rowwise sharded embedding op, where we need
    to mask and adjust the indices to the local embedding shard; embedding masking
    is a special type of the Partial placement.

    NOTE: the lifecycle of this MaskPartial placement follows the corresponding DTensor
    lifecycle, i.e. the indices_mask would only be alive during the lifetime of the DTensor.
    """

    mask_buffer: MaskBuffer = field(default_factory=MaskBuffer)

    # required fields for computing the local offset and deriving the mask
    offset_shape: Optional[torch.Size] = None
    offset_dim: int = 0

    def __init__(
        self,
        reduce_op=None,
        mask_buffer=None,
        offset_shape=None,
        offset_dim=0,
        *args,
        **kwargs,
    ):
        super().__init__(reduce_op)
        if mask_buffer is None:
            mask_buffer = MaskBuffer()
        object.__setattr__(self, "mask_buffer", mask_buffer)
        object.__setattr__(self, "offset_shape", offset_shape)
        object.__setattr__(self, "offset_dim", offset_dim)

    @staticmethod
    @maybe_run_for_local_tensor
    def _mask_tensor(
        tensor: torch.Tensor, local_offset_on_dim: int, local_shard_size: int
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Build the input mask and save it for the current partial placement
        # this is so that the output of embedding op can reuse the same partial
        # placement saved mask to perform mask + reduction
        mask = (tensor < local_offset_on_dim) | (
            tensor >= local_offset_on_dim + local_shard_size
        )
        # mask the input tensor
        masked_tensor = tensor.clone() - local_offset_on_dim
        masked_tensor[mask] = 0
        return mask, masked_tensor

    def _partition_value(
        self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
    ) -> torch.Tensor:
        my_coordinate = mesh.get_coordinate()
        assert my_coordinate is not None, "my_coordinate should not be None"
        # override parent logic to perform partial mask for embedding
        num_chunks = mesh.size(mesh_dim)
        # get local shard size and offset on the embedding_dim
        assert self.offset_shape is not None, (
            "offset_shape needs to be set for MaskPartial"
        )
        local_shard_size, local_offset_on_dim = Shard.local_shard_size_and_offset(
            self.offset_shape[self.offset_dim],
            num_chunks,
            my_coordinate[mesh_dim],
        )
        mask, masked_tensor = MaskPartial._mask_tensor(
            tensor, local_offset_on_dim, local_shard_size
        )
        # materialize the mask buffer to be used for reduction
        self.mask_buffer.materialize_mask(mask)
        return masked_tensor

    def _reduce_value(
        self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
    ) -> torch.Tensor:
        # by the time we need reduction, we should have already saved the mask
        assert self.mask_buffer.data is not None

        # apply the mask to the tensor that is pending reduction
        self.mask_buffer.apply_mask(tensor)

        # clear the mask buffer
        self.mask_buffer.release_mask()

        # perform sum reduction
        return funcol.all_reduce(
            tensor, reduceOp=self.reduce_op, group=(mesh, mesh_dim)
        )

    def _reduce_shard_value(
        self,
        tensor: torch.Tensor,
        mesh: DeviceMesh,
        mesh_dim: int,
        shard_spec: Placement,
    ) -> torch.Tensor:
        # by the time we need reduction, we should have already saved the mask
        assert self.mask_buffer.data is not None

        # apply the mask to the tensor that is pending reduction
        self.mask_buffer.apply_mask(tensor)

        # clear the mask buffer
        self.mask_buffer.release_mask()

        # call reduce_shard_tensor of the shard_spec.
        shard_spec = cast(Shard, shard_spec)
        return shard_spec._reduce_shard_tensor(tensor, mesh, self.reduce_op, mesh_dim)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, MaskPartial):
            return False

        # if either data is not None, we invalidate the sharding cache, as this indicates
        # the current MaskPartial placement is still in use and should not be used for cache hit.
        if self.mask_buffer.data is not None or other.mask_buffer.data is not None:
            return False

        return (
            self.reduce_op == other.reduce_op
            and self.offset_shape == other.offset_shape
            and self.offset_dim == other.offset_dim
        )

    def __hash__(self) -> int:
        return 1 + hash(
            (
                self.reduce_op,
                self.offset_shape,
                self.offset_dim,
            )
        )

    def __repr__(self) -> str:
        """
        machine readable representation of the MaskPartial placement
        """
        return f"MaskPartial(offset_shape={self.offset_shape}, offset_dim={self.offset_dim})"

    def __str__(self) -> str:
        """
        human readable representation of the MaskPartial placement
        """
        return "MaskP"

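For readers unfamiliar with the masking trick in `_mask_tensor` above, the following standalone sketch mirrors its logic on a single rank, with made-up shard sizes; no process group or DeviceMesh is needed, and all names here are illustrative only.

import torch

# Pretend this rank owns rows [local_offset, local_offset + local_size) of the embedding table.
local_offset, local_size = 4, 4  # hypothetical rowwise shard: rows 4..7
indices = torch.tensor([1, 5, 6, 11])  # global embedding indices seen on this rank

# Same logic as MaskPartial._mask_tensor: out-of-shard indices are remembered in `mask`
# and rewritten to a valid local index (0) so the local lookup never goes out of bounds.
mask = (indices < local_offset) | (indices >= local_offset + local_size)
local_indices = indices.clone() - local_offset
local_indices[mask] = 0

print(mask)           # tensor([ True, False, False,  True])
print(local_indices)  # tensor([0, 1, 2, 0])
# After the local embedding lookup, rows selected by `mask` are zeroed out and the
# partial results are summed across ranks, which is exactly the Partial reduction above.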
@ -14,14 +14,11 @@ from torch.backends.cuda import (
    SDPAParams,
)

from .varlen import varlen_attn


__all__: list[str] = [
    "SDPBackend",
    "sdpa_kernel",
    "WARN_FOR_UNFUSED_KERNELS",
    "varlen_attn",
]

# Note: [SDPA warnings]

@ -7,7 +7,7 @@ that calls into the optimized Flash Attention kernels.

import logging
from functools import lru_cache
from typing import NamedTuple, Optional, Union
from typing import Any, NamedTuple, Optional, Union

import torch

@ -33,8 +33,7 @@ class AuxRequest(NamedTuple):
    lse: bool = False


# import failures when I try to register as custom op
# @torch.library.custom_op("torch_nn_attention::_varlen_attn", mutates_args={})
@torch.library.custom_op("torch_attn::_varlen_attn", mutates_args={})
def _varlen_attn(
    query: torch.Tensor,
    key: torch.Tensor,

@ -44,7 +43,7 @@ def _varlen_attn(
    max_q: int,
    max_k: int,
    is_causal: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Private custom op for variable-length attention.

@ -70,7 +69,7 @@ def _varlen_attn(
            False,  # return_debug_mask
        )
        # cuDNN returns: (output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask)
        output, softmax_lse = result[0], result[1]
        output, softmax_lse, rng_state = result[0], result[1], result[6]
    else:
        log.info("Using Flash Attention backend for varlen_attn")
        output, softmax_lse, rng_state, _, _ = torch.ops.aten._flash_attention_forward(
@ -86,10 +85,13 @@ def _varlen_attn(
            return_debug_mask=False,
        )

    return output, softmax_lse
    rng_state_ = torch.zeros(
        (2,), dtype=torch.uint64, device=query.device
    )  # hardcoded since dropout is hardcoded to 0
    return output, softmax_lse, rng_state_


# @_varlen_attn.register_fake
@_varlen_attn.register_fake
def _varlen_attn_fake(
    query: torch.Tensor,
    key: torch.Tensor,

@ -99,7 +101,7 @@ def _varlen_attn_fake(
    max_q: int,
    max_k: int,
    is_causal: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Fake implementation for meta tensor computation and tracing.

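These hunks register `_varlen_attn` as a real `torch.library.custom_op` and attach a fake kernel via `register_fake`. As a reminder of how that API fits together, here is a minimal, self-contained sketch; the `mylib::scaled_add` op and its body are made up for illustration and are not part of this diff.

import torch


@torch.library.custom_op("mylib::scaled_add", mutates_args=())
def scaled_add(x: torch.Tensor, y: torch.Tensor, alpha: float = 1.0) -> torch.Tensor:
    # Real kernel: runs on actual tensors.
    return x + alpha * y


@scaled_add.register_fake
def _(x: torch.Tensor, y: torch.Tensor, alpha: float = 1.0) -> torch.Tensor:
    # Fake kernel: only describes output metadata, used for meta tensors and tracing,
    # playing the same role as _varlen_attn_fake in this file.
    return torch.empty_like(x)


out = torch.ops.mylib.scaled_add(torch.ones(4), torch.ones(4), alpha=2.0)
print(out)  # tensor([3., 3., 3., 3.])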
@ -117,7 +119,9 @@ def _varlen_attn_fake(
        (num_heads, total_q), dtype=torch.float, device=query.device
    )

    return output, logsumexp
    rng_state = torch.empty((2,), dtype=torch.uint64, device=query.device)

    return output, logsumexp, rng_state


def varlen_attn(
@ -191,9 +195,145 @@ def varlen_attn(
        ...     query, key, value, cu_seq, cu_seq, max_len, max_len, is_causal=False
        ... )
    """
    out, lse = _varlen_attn(
    out, lse, _ = torch.ops.torch_attn._varlen_attn(
        query, key, value, cu_seq_q, cu_seq_k, max_q, max_k, is_causal
    )
    if return_aux is not None and return_aux.lse:
        return out, lse
    return out


def _setup_context(ctx: Any, inputs: tuple[Any, ...], output: Any) -> None:
    query, key, value, cu_seq_q, cu_seq_k, max_q, max_k, is_causal = inputs
    out, lse, rng_state = output
    ctx.query = query
    ctx.key = key
    ctx.value = value
    ctx.cu_seq_q = cu_seq_q
    ctx.cu_seq_k = cu_seq_k
    ctx.max_q = max_q
    ctx.max_k = max_k
    ctx.is_causal = is_causal
    ctx.output = out
    ctx.lse = lse
    ctx.rng_state = rng_state


@torch.library.custom_op("torch_attn::_varlen_attn_backward", mutates_args={})
def _varlen_attn_backward(
    grad_out: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    out: torch.Tensor,
    lse: torch.Tensor,
    cu_seq_q: torch.Tensor,
    cu_seq_k: torch.Tensor,
    max_q: int,
    max_k: int,
    is_causal: bool,
    rng_state: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    unused = torch.empty(0, device=query.device)

    use_cudnn = query.is_cuda and _should_use_cudnn(query.device.index)
    if use_cudnn:
        log.info("Using cuDNN backend for varlen_attn")
        dq, dk, dv = torch.ops.aten._cudnn_attention_backward(
            grad_out,
            query,
            key,
            value,
            out,
            lse,
            cu_seq_q,
            cu_seq_k,
            max_q,
            max_k,
            0.0,
            is_causal,
            rng_state,
            unused,
        )
    else:
        log.info("Using Flash Attention backend for varlen_attn")
        dq, dk, dv = torch.ops.aten._flash_attention_backward(
            grad_out,
            query,
            key,
            value,
            out,
            lse,
            cu_seq_q,
            cu_seq_k,
            max_q,
            max_k,
            0.0,
            is_causal,
            rng_state,
            unused,
        )
    return dq, dk, dv


@_varlen_attn_backward.register_fake
def _varlen_attn_backward_fake(
    grad_out: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    out: torch.Tensor,
    lse: torch.Tensor,
    cu_seq_q: torch.Tensor,
    cu_seq_k: torch.Tensor,
    max_q: int,
    max_k: int,
    is_causal: bool,
    rng_state: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Fake implementation for meta tensor computation and tracing.
    """

    grad_query = torch.empty_like(query)
    grad_key = torch.empty_like(key)
    grad_value = torch.empty_like(value)

    return grad_query, grad_key, grad_value


def _backward(
    ctx: Any, grad_out: torch.Tensor, grad_lse: torch.Tensor, grad_rng: torch.Tensor
) -> tuple[Optional[torch.Tensor], ...]:
    query = ctx.query
    key = ctx.key
    value = ctx.value
    cu_seq_q = ctx.cu_seq_q
    cu_seq_k = ctx.cu_seq_k
    max_q = ctx.max_q
    max_k = ctx.max_k
    is_causal = ctx.is_causal
    out = ctx.output
    lse = ctx.lse
    rng_state = ctx.rng_state

    dq, dk, dv = torch.ops.torch_attn._varlen_attn_backward(
        grad_out,
        query,
        key,
        value,
        out,
        lse,
        cu_seq_q,
        cu_seq_k,
        max_q,
        max_k,
        is_causal,
        rng_state,
    )
    return dq, dk, dv, None, None, None, None, None, None


_varlen_attn.register_autograd(_backward, setup_context=_setup_context)

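Taken together, these hunks make `varlen_attn` differentiable through the custom-op autograd registration above. A rough usage sketch, under the assumption that this diff is applied and a CUDA device with a flash or cuDNN backend is available; the shapes and sequence lengths below are made up.

import torch
from torch.nn.attention import varlen_attn

# Two packed sequences of lengths 3 and 5: 8 tokens total, 2 heads, head_dim 64.
cu_seq = torch.tensor([0, 3, 8], dtype=torch.int32, device="cuda")  # cumulative lengths
max_len = 5

q, k, v = (
    torch.randn(8, 2, 64, device="cuda", dtype=torch.float16, requires_grad=True)
    for _ in range(3)
)
out = varlen_attn(q, k, v, cu_seq, cu_seq, max_len, max_len, is_causal=False)

# The registered autograd rule routes the backward pass through
# torch_attn::_varlen_attn_backward defined above.
out.sum().backward()
print(q.grad.shape)  # torch.Size([8, 2, 64])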
@ -74,6 +74,17 @@ def export_compat(
    if opset_version is None:
        opset_version = onnx_constants.ONNX_DEFAULT_OPSET

    if isinstance(model, torch.nn.Module):
        if model.training:
            warnings.warn(
                "Exporting a model while it is in training mode. "
                "Please ensure that this is intended, as it may lead to "
                "different behavior during inference. "
                "Calling model.eval() before export is recommended.",
                UserWarning,
                stacklevel=2,
            )

    if isinstance(model, torch.export.ExportedProgram):
        # We know the model is already an exported program, so the args, kwargs, and dynamic_shapes
        # are not used

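The new warning nudges users toward exporting in eval mode. A minimal sketch of the recommended call pattern; the toy module is illustrative, and only the `model.eval()` call before `torch.onnx.export` matters here.

import torch

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Dropout(p=0.5))
example_inputs = (torch.randn(1, 4),)

model.eval()  # switch dropout/batchnorm to inference behavior before exporting
onnx_program = torch.onnx.export(model, example_inputs, dynamo=True)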
@ -812,7 +812,6 @@ if torch.backends.mps.is_available():
    "__rmod__",
    "__rsub__",
    "__rpow__",
    "bernoulli",
    "clamp_max",
    "clamp_min",
    "masked_scatter",

@ -447,6 +447,56 @@ def _floatx_unpacked_to_f32(x: Tensor, ebits: int, mbits: int) -> Tensor:
def ceil_div(a, b):
    return (a + b - 1) // b

# NVIDIA Blackwell HW requires scales for MX/NV blocked formats to be in a 128x4 tile layout,
# with a weird 32x4x4 internal layout of that tile. If we want to take swizzled scales and use them
# for non-gemm purposes (like testing), we need to de-swizzle them, then they can be applied much
# more naturally.
def from_blocked(input, input_scales, blocksize) -> torch.Tensor:
    # Matrix is in a 128x4 pattern, internally blocked as 32x4x4 nonsense.
    # Output should be [input.size(0), input.size(1) // blocksize] scales
    output_scales = torch.zeros(
        (input.size(0), input.size(1) // blocksize),
        device=input.device,
        dtype=input_scales.dtype,
    )

    # Swizzled scales are padded to tiles of 128x4, we need to replicate how that padding
    # happened for offset purposes.
    # There are K//blocksize scales, padded to groups of 4.
    num_col_tiles = ceil_div(ceil_div(input.size(1), blocksize), 4)

    # (Very) slow reference implementation using horrifying loops.
    for i in range(input.size(0)):
        for j in range(input.size(1) // blocksize):
            # which 128x4 tile of scaling factors am I in
            scale_tile_h = i // 128
            scale_tile_w = j // 4

            # There are (padded) input_scales.size(1) // 4 tiles along the w dim.
            # So offset is 512 * (h_tile * tiles_per_row + tile_in_row)
            tile_offset = 512 * (scale_tile_h * num_col_tiles + scale_tile_w)

            # indices within the tile - use nomenclature directly from cublas docs
            outer = i % 128  # "outer" in cublas docs
            inner = j % 4  # "inner" in cublas docs

            # Note: "offset" is given in terms of bytes in the cublas docs, but our scales are e8m0
            # anyway, so 1B == 1 value => use the offset directly.
            # Formula directly from cublas docs in 3.1.4.3.2
            offset = tile_offset + (outer % 32) * 16 + (outer // 32) * 4 + inner

            output_scales[i, j] = input_scales[offset]

    return output_scales

def from_blocked_format(x_mxfp8, scales_unswizzled, blocksize=32):
    # expand scales
    scales = torch.repeat_interleave(scales_unswizzled, blocksize, dim=1)

    # de-scale and convert
    x_f32 = x_mxfp8.to(torch.float) * scales.to(torch.float)
    return x_f32.to(torch.bfloat16)

def to_blocked(input_matrix) -> torch.Tensor:
    """
    Rearrange a large matrix by breaking it into blocks and applying the rearrangement pattern.

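To make the swizzle formula concrete, here is a small worked example that maps one scale coordinate (i, j) to its position in the flattened buffer of 128x4 tiles; the matrix size and blocksize chosen here are arbitrary.

# Suppose the scale matrix has 256 rows and K // blocksize == 8 scale columns (blocksize 32).
num_col_tiles = (8 + 3) // 4  # = 2 padded tiles of 4 scale columns each

i, j = 130, 5                                           # row 130, scale column 5
scale_tile_h, scale_tile_w = i // 128, j // 4           # tile (1, 1)
tile_offset = 512 * (scale_tile_h * num_col_tiles + scale_tile_w)  # 512 * 3 = 1536
outer, inner = i % 128, j % 4                           # 2, 1
offset = tile_offset + (outer % 32) * 16 + (outer // 32) * 4 + inner  # 1536 + 32 + 0 + 1
print(offset)  # 1569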
@ -7,7 +7,8 @@ import inspect
import sys
import warnings
from collections.abc import Callable
from typing import Any, cast, TypeVar
from typing import Any, cast, overload, TypeVar
from typing_extensions import Self


# Used for annotating the decorator usage of _DecoratorContextManager (e.g.,

@ -158,7 +159,12 @@ class _DecoratorContextManager:
class _NoParamDecoratorContextManager(_DecoratorContextManager):
    """Allow a context manager to be used as a decorator without parentheses."""

    def __new__(cls, orig_func=None):
    @overload
    def __new__(cls, orig_func: F) -> F: ...  # type: ignore[misc]
    @overload
    def __new__(cls, orig_func: None = None) -> Self: ...

    def __new__(cls, orig_func: F | None = None) -> Self | F:  # type: ignore[misc]
        if orig_func is None:
            return super().__new__(cls)
        return cls()(orig_func)

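The overloads above type-annotate the two ways these context managers are already used. For example, with `torch.no_grad` (a `_NoParamDecoratorContextManager` subclass) both spellings behave identically:

import torch


@torch.no_grad          # without parentheses: __new__ receives the function and returns the wrapped fn
def f(x: torch.Tensor) -> torch.Tensor:
    return x * 2


@torch.no_grad()        # with parentheses: __new__ returns an instance, which then decorates the fn
def g(x: torch.Tensor) -> torch.Tensor:
    return x * 2


x = torch.ones(3, requires_grad=True)
print(f(x).requires_grad, g(x).requires_grad)  # False False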
@ -311,7 +311,11 @@ def escape(n):


def is_cuda_tensor(obj):
    return isinstance(obj, torch.Tensor) and obj.is_cuda and not isinstance(obj, torch._subclasses.FakeTensor)
    return (
        isinstance(obj, torch.Tensor) and
        obj.device.type == "cuda" and
        not isinstance(obj, torch._subclasses.FakeTensor)
    )


def cuda_allocation_context():
    snapshot = torch.cuda.memory._snapshot()
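The rewritten predicate keeps real CUDA tensors and screens out fake ones. A quick sanity check, assuming `is_cuda_tensor` as defined above is in scope and a GPU is present:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

print(is_cuda_tensor(torch.ones(2)))                     # False: CPU tensor
if torch.cuda.is_available():
    print(is_cuda_tensor(torch.ones(2, device="cuda")))  # True: real CUDA tensor
    with FakeTensorMode():
        fake = torch.ones(2, device="cuda")
        print(is_cuda_tensor(fake))                      # False: FakeTensor is excluded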