Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-30 19:54:53 +08:00)
Compare commits (1 commit): ciflow/ind ... revert-cpp
| Author | SHA1 | Date |
|---|---|---|
|  | 2eacbe792a |  |
@@ -100,8 +100,6 @@ COPY ./common/common_utils.sh common_utils.sh
COPY ci_commit_pins/huggingface-requirements.txt huggingface-requirements.txt
COPY ci_commit_pins/timm.txt timm.txt
COPY ci_commit_pins/torchbench.txt torchbench.txt
# Only build aoti cpp tests when INDUCTOR_BENCHMARKS is set to True
ENV BUILD_AOT_INDUCTOR_TEST ${INDUCTOR_BENCHMARKS}
RUN if [ -n "${INDUCTOR_BENCHMARKS}" ]; then bash ./install_inductor_benchmark_deps.sh; fi
RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface-requirements.txt torchbench.txt
@@ -460,18 +460,28 @@ test_inductor_shard() {
--verbose
}

test_inductor_aoti_cpp() {
test_inductor_aoti() {
# docker build uses bdist_wheel which does not work with test_aot_inductor
# TODO: need a faster way to build
if [[ "$BUILD_ENVIRONMENT" == *rocm* ]]; then
# We need to hipify before building again
python3 tools/amd_build/build_amd.py
fi
if [[ "$BUILD_ENVIRONMENT" == *sm86* ]]; then
BUILD_COMMAND=(TORCH_CUDA_ARCH_LIST=8.6 USE_FLASH_ATTENTION=OFF python -m pip install --no-build-isolation -v -e .)
# TODO: Replace me completely, as one should not use conda libstdc++, nor need special path to TORCH_LIB
TEST_ENVS=(CPP_TESTS_DIR="${BUILD_BIN_DIR}" LD_LIBRARY_PATH="/opt/conda/envs/py_3.10/lib:${TORCH_LIB_DIR}:${LD_LIBRARY_PATH}")
else
BUILD_COMMAND=(python -m pip install --no-build-isolation -v -e .)
TEST_ENVS=(CPP_TESTS_DIR="${BUILD_BIN_DIR}" LD_LIBRARY_PATH="${TORCH_LIB_DIR}")
fi

# aoti cmake custom command requires `torch` to be installed
# initialize the cmake build cache and install torch
/usr/bin/env "${BUILD_COMMAND[@]}"
# rebuild with the build cache with `BUILD_AOT_INDUCTOR_TEST` enabled
/usr/bin/env CMAKE_FRESH=1 BUILD_AOT_INDUCTOR_TEST=1 "${BUILD_COMMAND[@]}"

/usr/bin/env "${TEST_ENVS[@]}" python test/run_test.py --cpp --verbose -i cpp/test_aoti_abi_check cpp/test_aoti_inference cpp/test_vec_half_AVX2 -dist=loadfile
}
@@ -1766,7 +1776,7 @@ elif [[ "${TEST_CONFIG}" == *inductor_cpp_wrapper* ]]; then
install_torchvision
PYTHONPATH=/torchbench test_inductor_cpp_wrapper_shard "$SHARD_NUMBER"
if [[ "$SHARD_NUMBER" -eq "1" ]]; then
test_inductor_aoti_cpp
test_inductor_aoti
fi
elif [[ "${TEST_CONFIG}" == *inductor* ]]; then
install_torchvision
@@ -1358,15 +1358,9 @@ if(BUILD_TEST)
)
else()
add_subdirectory(${TORCH_ROOT}/test/cpp/jit ${CMAKE_BINARY_DIR}/test_jit)
add_subdirectory(${TORCH_ROOT}/test/cpp/lazy ${CMAKE_BINARY_DIR}/test_lazy)
# NativeRT is disabled
# add_subdirectory(${TORCH_ROOT}/test/cpp/nativert ${CMAKE_BINARY_DIR}/test_nativert)
add_subdirectory(${TORCH_ROOT}/test/inductor ${CMAKE_BINARY_DIR}/test_inductor)
add_subdirectory(${TORCH_ROOT}/test/cpp/aoti_abi_check ${CMAKE_BINARY_DIR}/test_aoti_abi_check)
if(BUILD_AOT_INDUCTOR_TEST)
add_subdirectory(${TORCH_ROOT}/test/cpp/aoti_inference ${CMAKE_BINARY_DIR}/test_aoti_inference)
endif()

if(USE_DISTRIBUTED)
add_subdirectory(${TORCH_ROOT}/test/cpp/c10d ${CMAKE_BINARY_DIR}/test_cpp_c10d)
if(NOT WIN32)
@@ -1384,6 +1378,16 @@ if(BUILD_TEST)
${CMAKE_BINARY_DIR}/test_mobile_nnc
)
endif()
add_subdirectory(${TORCH_ROOT}/test/cpp/lazy
${CMAKE_BINARY_DIR}/test_lazy)
endif()
if(BUILD_AOT_INDUCTOR_TEST)
add_subdirectory(
${TORCH_ROOT}/test/cpp/aoti_abi_check
${CMAKE_BINARY_DIR}/test_aoti_abi_check)
add_subdirectory(
${TORCH_ROOT}/test/cpp/aoti_inference
${CMAKE_BINARY_DIR}/test_aoti_inference)
endif()
endif()
@@ -1,8 +1,3 @@
# Skip on windows
if(WIN32)
return()
endif()

set(AOTI_ABI_CHECK_TEST_ROOT ${TORCH_ROOT}/test/cpp/aoti_abi_check)

# Build the cpp gtest binary containing the cpp-only tests.
@@ -35,15 +30,8 @@ target_compile_definitions(test_aoti_abi_check PRIVATE USE_GTEST)

# WARNING: DO NOT LINK torch!!!
# The purpose is to check if the used aten/c10 headers are written in a header-only way
target_link_libraries(test_aoti_abi_check PRIVATE gtest_main sleef)
target_link_libraries(test_aoti_abi_check PRIVATE gtest_main)
target_include_directories(test_aoti_abi_check PRIVATE ${ATen_CPU_INCLUDE})
if(NOT USE_SYSTEM_SLEEF)
target_include_directories(test_aoti_abi_check PRIVATE ${CMAKE_BINARY_DIR}/include)
endif()

# Disable unused-variable warnings for variables that are only used to test compilation
target_compile_options_if_supported(test_aoti_abi_check -Wno-unused-variable)
target_compile_options_if_supported(test_aoti_abi_check -Wno-unused-but-set-variable)

foreach(test_src ${AOTI_ABI_CHECK_VEC_TEST_SRCS})
foreach(i RANGE ${NUM_CPU_CAPABILITY_NAMES})
@@ -53,17 +41,12 @@ foreach(test_src ${AOTI_ABI_CHECK_VEC_TEST_SRCS})
separate_arguments(FLAGS UNIX_COMMAND "${FLAGS}")
add_executable(${test_name}_${CPU_CAPABILITY} "${test_src}")

target_link_libraries(${test_name}_${CPU_CAPABILITY} PRIVATE gtest_main sleef)
target_link_libraries(${test_name}_${CPU_CAPABILITY} PRIVATE gtest_main)
target_include_directories(${test_name}_${CPU_CAPABILITY} PRIVATE ${ATen_CPU_INCLUDE})
if(NOT USE_SYSTEM_SLEEF)
target_include_directories(${test_name}_${CPU_CAPABILITY} PRIVATE ${CMAKE_BINARY_DIR}/include)
endif()

# Define CPU_CAPABILITY and CPU_CAPABILITY_XXX macros for conditional compilation
target_compile_definitions(${test_name}_${CPU_CAPABILITY} PRIVATE CPU_CAPABILITY=${CPU_CAPABILITY} CPU_CAPABILITY_${CPU_CAPABILITY})
target_compile_options(${test_name}_${CPU_CAPABILITY} PRIVATE ${FLAGS})
target_compile_options_if_supported(${test_name}_${CPU_CAPABILITY} -Wno-unused-variable)
target_compile_options_if_supported(${test_name}_${CPU_CAPABILITY} -Wno-unused-but-set-variable)
endforeach()
endforeach()
@@ -2,27 +2,10 @@

#include <ATen/cpu/vec/vec.h>

#include <iostream>
namespace torch {
namespace aot_inductor {

template <typename T>
void ExpectVecEqual(
const at::vec::Vectorized<T>& expected,
const at::vec::Vectorized<T>& actual) {
using Vec = at::vec::Vectorized<T>;
// Have to use std::vector for comparison because at::vec::Vectorized doesn't
// support operator[] on aarch64
std::vector<T> expected_data(Vec::size());
std::vector<T> actual_data(Vec::size());

expected.store(expected_data.data());
actual.store(actual_data.data());

for (int i = 0; i < Vec::size(); i++) {
EXPECT_EQ(expected_data[i], actual_data[i]);
}
}

TEST(TestVec, TestAdd) {
using Vec = at::vec::Vectorized<int>;
std::vector<int> a(1024, 1);
@@ -33,7 +16,9 @@ TEST(TestVec, TestAdd) {
std::vector<int> expected(1024, 3);
Vec expected_vec = Vec::loadu(expected.data());

ExpectVecEqual(expected_vec, actual_vec);
for (int i = 0; i < Vec::size(); i++) {
EXPECT_EQ(expected_vec[i], actual_vec[i]);
}
}

TEST(TestVec, TestMax) {
@@ -45,7 +30,9 @@ TEST(TestVec, TestMax) {
Vec actual_vec = at::vec::maximum(a_vec, b_vec);
Vec expected_vec = b_vec;

ExpectVecEqual(expected_vec, actual_vec);
for (int i = 0; i < Vec::size(); i++) {
EXPECT_EQ(expected_vec[i], actual_vec[i]);
}
}

TEST(TestVec, TestMin) {
@@ -57,7 +44,9 @@ TEST(TestVec, TestMin) {
Vec actual_vec = at::vec::minimum(a_vec, b_vec);
Vec expected_vec = a_vec;

ExpectVecEqual(expected_vec, actual_vec);
for (int i = 0; i < Vec::size(); i++) {
EXPECT_EQ(expected_vec[i], actual_vec[i]);
}
}

TEST(TestVec, TestConvert) {
@@ -69,7 +58,9 @@ TEST(TestVec, TestConvert) {
auto actual_vec = at::vec::convert<float>(a_vec);
auto expected_vec = b_vec;

ExpectVecEqual(expected_vec, actual_vec);
for (int i = 0; i < at::vec::Vectorized<int>::size(); i++) {
EXPECT_EQ(expected_vec[i], actual_vec[i]);
}
}

TEST(TestVec, TestClampMin) {
@@ -81,7 +72,9 @@ TEST(TestVec, TestClampMin) {
Vec actual_vec = at::vec::clamp_min(a_vec, min_vec);
Vec expected_vec = min_vec;

ExpectVecEqual(expected_vec, actual_vec);
for (int i = 0; i < Vec::size(); i++) {
EXPECT_EQ(expected_vec[i], actual_vec[i]);
}
}

} // namespace aot_inductor
@@ -1,3 +1,4 @@

set(AOT_INDUCTOR_TEST_ROOT ${TORCH_ROOT}/test/cpp/aoti_inference)

# Build custom TorchScript op for AOTInductor
@@ -7,12 +8,27 @@ set_target_properties(aoti_custom_class PROPERTIES
if(USE_CUDA)
target_compile_definitions(aoti_custom_class PRIVATE USE_CUDA)
elseif(USE_ROCM)
target_compile_definitions(aoti_custom_class PRIVATE USE_ROCM)
target_compile_definitions(aoti_custom_class PRIVATE USE_ROCM)
endif()

# Link against LibTorch
target_link_libraries(aoti_custom_class torch)

# the custom command that generates the TorchScript module
add_custom_command(
OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/script_data.pt
${CMAKE_CURRENT_BINARY_DIR}/script_model_cpu.pt
${CMAKE_CURRENT_BINARY_DIR}/script_model_cuda.pt
# This script requires the torch package to be installed.
COMMAND python ${AOT_INDUCTOR_TEST_ROOT}/compile_model.py
DEPENDS torch torch_python aoti_custom_class ${AOT_INDUCTOR_TEST_ROOT}/compile_model.py
)
add_custom_target(aoti_script_model ALL
DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/script_data.pt
DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/script_model_cpu.pt
DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/script_model_cuda.pt
)
add_dependencies(aoti_script_model aoti_custom_class)

# Build the cpp gtest binary containing the cpp-only tests.
set(INDUCTOR_TEST_SRCS
${AOT_INDUCTOR_TEST_ROOT}/test.cpp
@@ -21,12 +37,23 @@ set(INDUCTOR_TEST_SRCS
add_executable(test_aoti_inference
${TORCH_ROOT}/test/cpp/common/main.cpp
${INDUCTOR_TEST_SRCS}
data.pt
script_data.pt
script_model_cpu.pt
script_model_cuda.pt
)
add_dependencies(test_aoti_inference aoti_custom_class)
add_dependencies(test_aoti_inference aoti_custom_class aoti_script_model)

# TODO temporary until we can delete the old gtest polyfills.
target_compile_definitions(test_aoti_inference PRIVATE USE_GTEST)

# Define a custom command to generate the library
add_custom_command(
OUTPUT data.pt
COMMAND python ${AOT_INDUCTOR_TEST_ROOT}/test.py
DEPENDS ${AOT_INDUCTOR_TEST_ROOT}/test.py
)

target_link_libraries(test_aoti_inference PRIVATE
torch
gtest_main
@@ -44,10 +71,6 @@ target_compile_definitions(test_aoti_inference PRIVATE
CMAKE_CURRENT_BINARY_DIR=${CMAKE_CURRENT_BINARY_DIR}
)

target_compile_options_if_supported(test_aoti_inference -Wno-unused-variable)
target_compile_options_if_supported(test_aoti_inference -Wno-unused-but-set-variable)
target_compile_options_if_supported(test_aoti_inference -Wno-unused-function)

if(INSTALL_TEST)
install(TARGETS test_aoti_inference DESTINATION bin)
# Install PDB files for MSVC builds
@@ -2,9 +2,7 @@
#include <gtest/gtest.h>
#include <atomic>
#include <condition_variable>
#include <cstdlib>
#include <filesystem>
#include <fstream>
#include <functional>
#include <mutex>
#include <queue>
@@ -30,64 +28,6 @@

namespace {

// Function to check if test data files exist and are valid
bool testDataFilesExist() {
std::string bindir = STRINGIZE(CMAKE_CURRENT_BINARY_DIR);
std::array<std::string, 4> required_files = {
"data.pt",
"script_data.pt",
"script_model_cpu.pt",
"script_model_cuda.pt"};

for (const auto& filename : required_files) {
std::string filepath = bindir + "/" + filename;
std::ifstream file(filepath);
if (!file.good()) {
return false;
}
}
return true;
}

// Function to ensure test data files are generated at runtime
void ensureTestDataGenerated() {
static std::once_flag generated_flag;
std::call_once(generated_flag, []() {
// Only generate if files don't exist or are placeholders
if (testDataFilesExist()) {
return;
}

std::string bindir = STRINGIZE(CMAKE_CURRENT_BINARY_DIR);

// Calculate path to source directory: build/test_aoti_inference -> build ->
// pytorch
std::string pytorch_root = bindir.substr(0, bindir.find_last_of("/"));
pytorch_root = pytorch_root.substr(0, pytorch_root.find_last_of("/"));
std::string source_dir = pytorch_root + "/test/cpp/aoti_inference";

// Generate test data files (data.pt, etc.) by running test.py directly
std::string test_script = source_dir + "/test.py";
std::string test_data_cmd = "cd " + bindir + " && python " + test_script;
std::cout << "Generating test data: " << test_data_cmd << std::endl;
int result1 = std::system(test_data_cmd.c_str());
if (result1 != 0) {
std::cerr << "Warning: Test data generation failed with code " << result1
<< std::endl;
}

// Generate model files (script_*.pt) by running compile_model.py directly
std::string compile_script = source_dir + "/compile_model.py";
std::string models_cmd = "cd " + bindir + " && python " + compile_script;
std::cout << "Generating model files: " << models_cmd << std::endl;
int result2 = std::system(models_cmd.c_str());
if (result2 != 0) {
std::cerr << "Warning: Model generation failed with code " << result2
<< std::endl;
}
});
}

const std::unordered_map<std::string, at::Tensor> derefTensorConstantMap(
torch::inductor::TensorConstantMap tensor_constant_map) {
std::unordered_map<std::string, at::Tensor> ret;
@@ -915,6 +855,7 @@ void test_aoti_free_buffer(bool use_runtime_constant_folding) {
}
}

#if defined(USE_CUDA) || defined(USE_ROCM)
void test_cuda_alloc_test() {
torch::NoGradGuard no_grad;

@@ -954,8 +895,8 @@ void test_cuda_alloc_test() {
runner->run(data_loader.attr(inputs_attr.c_str()).toTensorList().vec());
ASSERT_TRUE(torch::allclose(ref_output_tensors[0], actual_output_tensors[0]));
}
#endif

#ifdef USE_CUDA
class ThreadPool {
private:
struct Task {
@@ -1096,96 +1037,86 @@ void test_multi_cuda_streams(const std::string& device) {
ASSERT_TRUE(torch::allclose(ref_output_tensors[0], all_outputs[i][0]));
}
}
#endif // USE_CUDA
#endif // USE_CUDA || USE_ROCM
#endif
} // namespace

namespace torch::aot_inductor {

// Test fixture that ensures test data is generated once for all tests
class AotInductorTest : public ::testing::Test {
public:
// This runs once before all tests in this test suite
static void SetUpTestSuite() {
ensureTestDataGenerated();
}
};

TEST_F(AotInductorTest, BasicTestCpu) {
TEST(AotInductorTest, BasicTestCpu) {
test_aoti("cpu", false);
}

TEST_F(AotInductorTest, BasicScriptTestCpu) {
TEST(AotInductorTest, BasicScriptTestCpu) {
test_aoti_script("cpu");
}

TEST_F(AotInductorTest, BasicPackageLoaderTestCpu) {
TEST(AotInductorTest, BasicPackageLoaderTestCpu) {
test_aoti_package_loader("cpu", false);
}

TEST_F(AotInductorTest, ExtractConstantsMapCpu) {
TEST(AotInductorTest, ExtractConstantsMapCpu) {
test_aoti_extract_constants_map("cpu");
}

#ifdef USE_CUDA
TEST_F(AotInductorTest, BasicTestCuda) {
TEST(AotInductorTest, BasicTestCuda) {
test_aoti("cuda", true);
test_aoti("cuda", false);
}

TEST_F(AotInductorTest, BasicScriptTestCuda) {
TEST(AotInductorTest, BasicScriptTestCuda) {
test_aoti_script("cuda");
}

TEST_F(AotInductorTest, BasicPackageLoaderTestCuda) {
TEST(AotInductorTest, BasicPackageLoaderTestCuda) {
test_aoti_package_loader("cuda", false);
}

TEST_F(AotInductorTest, BasicPackageLoaderTestMultiGpuCuda) {
TEST(AotInductorTest, BasicPackageLoaderTestMultiGpuCuda) {
test_aoti_package_loader_multi_gpu("cuda", false);
}

TEST_F(AotInductorTest, UpdateUserManagedConstantsCuda) {
TEST(AotInductorTest, UpdateUserManagedConstantsCuda) {
test_aoti_user_managed_buffer();
}

TEST_F(AotInductorTest, RuntimeUpdateConstantsCuda) {
TEST(AotInductorTest, RuntimeUpdateConstantsCuda) {
test_aoti_constants_update("cuda", true);
}

TEST_F(AotInductorTest, UpdateConstantsCuda) {
TEST(AotInductorTest, UpdateConstantsCuda) {
test_aoti_constants_update("cuda", false);
}

TEST_F(AotInductorTest, ExtractConstantsMapCuda) {
TEST(AotInductorTest, ExtractConstantsMapCuda) {
test_aoti_extract_constants_map("cuda");
}

TEST_F(AotInductorTest, RuntimeUpdateInactiveConstantsCuda) {
TEST(AotInductorTest, RuntimeUpdateInactiveConstantsCuda) {
test_aoti_double_buffering("cuda", true);
}

TEST_F(AotInductorTest, UpdateInactiveConstantsCuda) {
TEST(AotInductorTest, UpdateInactiveConstantsCuda) {
test_aoti_double_buffering("cuda", false);
}

TEST_F(AotInductorTest, UpdateInactiveConstantsWithTensorConstantsCuda) {
TEST(AotInductorTest, UpdateInactiveConstantsWithTensorConstantsCuda) {
test_aoti_double_buffering_with_tensor_constants();
}

TEST_F(AotInductorTest, FreeInactiveConstantBufferCuda) {
TEST(AotInductorTest, FreeInactiveConstantBufferCuda) {
test_aoti_free_buffer(false);
}

TEST_F(AotInductorTest, FreeInactiveConstantBufferRuntimeConstantFoldingCuda) {
TEST(AotInductorTest, FreeInactiveConstantBufferRuntimeConstantFoldingCuda) {
test_aoti_free_buffer(true);
}

TEST_F(AotInductorTest, MultiStreamTestCuda) {
TEST(AotInductorTest, MultiStreamTestCuda) {
test_multi_cuda_streams("cuda");
}

TEST_F(AotInductorTest, CudaAllocTestCuda) {
TEST(AotInductorTest, CudaAllocTestCuda) {
test_cuda_alloc_test();
}
#endif
@@ -892,16 +892,10 @@ fn(torch.randn(5))
os.remove(
file_path
) # Delete temp file manually, due to setup NamedTemporaryFile as delete=False.
orig_maxDiff = unittest.TestCase.maxDiff
unittest.TestCase.maxDiff = None
try:
self.assertEqual( # process wrap difference: /r/n on Windows, /n on posix.
empty_line_normalizer(lines),
empty_line_normalizer(stderr.decode("utf-8")),
)
except Exception:
unittest.TestCase.maxDiff = orig_maxDiff
raise
self.assertEqual( # process wrap difference: /r/n on Windows, /n on posix.
empty_line_normalizer(lines),
empty_line_normalizer(stderr.decode("utf-8")),
)

@make_settings_test("torch._dynamo.eval_frame")
def test_log_traced_frames(self, records):
@@ -122,52 +122,16 @@ def cuda_kernel_profiler(kernel_pattern="flash_attncute"):
result["found"] = any(kernel_pattern in name for name in kernel_names)


def flash_vs_triton(q, k, v, score_mod=None, block_mask=None, rtol=2):
def flash_vs_triton(q, k, v, score_mod=None, rtol=5e-3, atol=5e-3):
compiled_fn = torch.compile(flex_attention)

out_ref_fp32 = flex_attention(
q.to(torch.float32),
k.to(torch.float32),
v.to(torch.float32),
score_mod=score_mod,
block_mask=block_mask,
).to(q.dtype)

out_flash = compiled_fn(
q,
k,
v,
score_mod=score_mod,
block_mask=block_mask,
kernel_options={"force_flash": True},
q, k, v, score_mod=score_mod, kernel_options={"force_flash": True}
)
out_triton = compiled_fn(
q,
k,
v,
score_mod=score_mod,
block_mask=block_mask,
kernel_options={"force_flash": False},
out_no_flash = compiled_fn(
q, k, v, score_mod=score_mod, kernel_options={"force_flash": False}
)

assert out_flash.shape == out_ref_fp32.shape == out_triton.shape
assert not torch.isnan(out_flash).any()
assert not torch.isnan(out_triton).any()
assert not torch.isnan(out_ref_fp32).any()
assert torch.isfinite(out_flash).all()
assert torch.isfinite(out_triton).all()
assert torch.isfinite(out_ref_fp32).all()

fwd_atol = 2 * (out_ref_fp32 + 0.3 - 0.3 - out_ref_fp32).abs().max().item()

triton_error = (out_triton - out_ref_fp32).abs().max().item()
flash_error = (out_flash - out_ref_fp32).abs().max().item()

assert flash_error <= rtol * triton_error + fwd_atol, (
f"Flash error {flash_error:.2e} exceeds {rtol}x Triton error {triton_error:.2e} + {fwd_atol:.2e}"
)

return out_flash, out_triton, out_ref_fp32
torch.testing.assert_close(out_flash, out_no_flash, rtol=rtol, atol=atol)
return out_flash, out_no_flash
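For orientation, here is a minimal sketch of how the comparison helper above is typically driven. The tensor shapes, dtype, device, and the causal score_mod are illustrative assumptions, not values taken from this diff; the test file itself uses its own `create_test_tensors` and `_causal` helpers.

```python
import torch

# Hypothetical inputs: (batch, heads, seq_len, head_dim) half-precision CUDA tensors.
q, k, v = (
    torch.randn(2, 4, 512, 64, device="cuda", dtype=torch.float16) for _ in range(3)
)

def causal(score, b, h, q_idx, kv_idx):
    # A standard causal score_mod: mask out future key/value positions.
    return torch.where(q_idx >= kv_idx, score, float("-inf"))

# flash_vs_triton (defined above) compiles flex_attention and compares the
# force_flash=True and force_flash=False paths against each other.
out_flash, out_no_flash = flash_vs_triton(q, k, v, score_mod=causal)
```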
def name_fn(score_mod):
@@ -198,6 +162,26 @@ class TestFlexFlash(InductorTestCase):
q, k, v = create_test_tensors(seq_len=seq_len, dtype=dtype, device=device)
flash_vs_triton(q, k, v, score_mod=_causal)

@dtypes(torch.float16, torch.bfloat16)
def test_force_flash_error_with_block_mask(self, device, dtype):
"""Test that force_flash=True raises error when BlockMask is provided."""
q, k, v = create_test_tensors(dtype=dtype, device=device)

# Create a causal block mask
def causal_mask(b, h, q_idx, kv_idx):
return q_idx >= kv_idx

block_mask = create_block_mask(causal_mask, 2, 4, 512, 512, device=device)

compiled_fn = torch.compile(flex_attention)
with self.assertRaisesRegex(
RuntimeError,
r"force_flash=True but flash attention cannot be used.*BlockMask.*not supported",
):
compiled_fn(
q, k, v, block_mask=block_mask, kernel_options={"force_flash": True}
)

@dtypes(torch.float16, torch.bfloat16)
def test_flash_attention_kernel_called(self, device, dtype):
"""Test that flash attention kernel is actually called when force_flash=True."""
@@ -273,6 +257,7 @@ class TestFlexFlash(InductorTestCase):
"""Test that force_flash=True raises error when tensor requires gradients."""
q, k, v = create_test_tensors(dtype=dtype, device=device)

# Create a score mod with requires_grad tensor
bias = torch.randn(4, device=device, dtype=dtype, requires_grad=True)

def score_mod_with_grad(score, b, h, q_idx, kv_idx):
@@ -291,108 +276,6 @@ class TestFlexFlash(InductorTestCase):
kernel_options={"force_flash": True},
)

@dtypes(torch.float16, torch.bfloat16)
def test_flash_attention_with_block_mask(self, device, dtype):
"""Test flash attention with block mask and mask_mod."""
q, k, v = create_test_tensors(dtype=dtype, device=device)

def causal_mask(b, h, q_idx, kv_idx):
return q_idx >= kv_idx

block_mask = create_block_mask(causal_mask, 2, 4, 512, 512, device=device)
flash_vs_triton(q, k, v, block_mask=block_mask)

@dtypes(torch.float16, torch.bfloat16)
def test_flash_attention_block_mask_with_score_mod(self, device, dtype):
"""Test flash attention with both block mask and score_mod."""
q, k, v = create_test_tensors(dtype=dtype, device=device)

def causal_mask(b, h, q_idx, kv_idx):
return q_idx >= kv_idx

block_mask = create_block_mask(causal_mask, 2, 4, 512, 512, device=device)
flash_vs_triton(q, k, v, score_mod=_times_two, block_mask=block_mask)

@dtypes(torch.float16, torch.bfloat16)
def test_flash_attention_with_mask_mod_buffer(self, device, dtype):
"""Test flash attention with mask_mod that loads from buffer."""
q, k, v = create_test_tensors(
batch_size=2, num_heads=4, dtype=dtype, device=device
)

mask_bias = torch.randn(4, device=device, dtype=dtype) * 0.1

def custom_mask(b, h, q_idx, kv_idx):
bias_value = mask_bias[h]
return (q_idx >= kv_idx) | (bias_value > 0)

block_mask = create_block_mask(custom_mask, 2, 4, 512, 512, device=device)
flash_vs_triton(q, k, v, block_mask=block_mask)

@dtypes(torch.float16, torch.bfloat16)
def test_flash_attention_mask_mod_with_dual_buffers(self, device, dtype):
"""Mask modifier should support multiple captured buffers."""
batch_size, num_heads, seq_len = 2, 4, 512
q, k, v = create_test_tensors(
batch_size=batch_size, num_heads=num_heads, dtype=dtype, device=device
)

head_bias = torch.randn(num_heads, device=device, dtype=dtype) * 0.2
batch_bias = torch.randn(batch_size, device=device, dtype=dtype) * 0.2

def dual_buffer_mask(b, h, q_idx, kv_idx):
head_term = head_bias[h]
batch_term = batch_bias[b]
causal = q_idx >= kv_idx
bias_cond = (head_term + batch_term).to(torch.float32) > 0
return causal | bias_cond

block_mask = create_block_mask(
dual_buffer_mask, batch_size, num_heads, seq_len, seq_len, device=device
)
flash_vs_triton(q, k, v, block_mask=block_mask)

@dtypes(torch.float16, torch.bfloat16)
def test_flash_attention_score_mod_with_many_buffer_indexing(self, device, dtype):
batch_size, num_heads, seq_len = 2, 4, 512
q, k, v = create_test_tensors(
batch_size=batch_size, num_heads=num_heads, dtype=dtype, device=device
)

head_bias = torch.randn(num_heads, device=device, dtype=dtype) * 0.15
query_scale = torch.randn(seq_len, device=device, dtype=dtype) * 0.05
kv_scale = torch.randn(seq_len, device=device, dtype=dtype) * 0.05
batch_bias = torch.randn(batch_size, device=device, dtype=dtype) * 0.1

def complex_score(score, b, h, q_idx, kv_idx):
head_term = head_bias[h]
query_term = query_scale[q_idx]
kv_term = kv_scale[kv_idx]
batch_term = batch_bias[b]
return score + head_term + query_term - kv_term + batch_term

flash_vs_triton(q, k, v, score_mod=complex_score)

@dtypes(torch.float16, torch.bfloat16)
def test_flash_attention_with_score_and_mask_buffers(self, device, dtype):
"""Test flash attention with both score_mod and mask_mod using buffers."""
q, k, v = create_test_tensors(
batch_size=2, num_heads=4, dtype=dtype, device=device
)

score_bias = torch.randn(4, device=device, dtype=dtype) * 0.2
mask_bias = torch.randn(4, device=device, dtype=dtype) * 0.1

def score_with_buffer(score, b, h, q_idx, kv_idx):
return score + score_bias[h]

def mask_with_buffer(b, h, q_idx, kv_idx):
bias_value = mask_bias[h]
return (q_idx >= kv_idx) | (bias_value > 0)

block_mask = create_block_mask(mask_with_buffer, 2, 4, 512, 512, device=device)
flash_vs_triton(q, k, v, score_mod=score_with_buffer, block_mask=block_mask)


instantiate_device_type_tests(TestFlexFlash, globals(), only_for="cuda")
@@ -529,7 +529,7 @@ class TestProfiler(TestCase):
found_mm = True
if "gemm" in e.name.lower() or "Cijk" in e.name:
found_gemm = True
if "memcpy" in e.name.lower() or "__amd_rocclr_copyBuffer" in e.name:
if "memcpy" in e.name.lower():
found_memcpy = True
if use_cuda:
self.assertTrue(found_gemm)
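As a standalone illustration of the pattern this test relies on (not the test itself), one can profile a matmul and scan the recorded kernel names. The shapes, the CUDA device, and running outside a test harness are assumptions for the sketch.

```python
import torch
from torch.profiler import profile, ProfilerActivity

x = torch.randn(512, 512, device="cuda")
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    torch.mm(x, x)
    torch.cuda.synchronize()

events = prof.events()
# Same string checks as in the test above: CUDA GEMM kernels usually contain
# "gemm" ("Cijk" on some ROCm builds), and device copies contain "memcpy".
found_gemm = any("gemm" in e.name.lower() or "Cijk" in e.name for e in events)
found_memcpy = any("memcpy" in e.name.lower() for e in events)
print(found_gemm, found_memcpy)
```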
@@ -445,7 +445,7 @@ use_numpy_random_stream = False
enable_cpp_guard_manager = True

# Use C++ guard manager for symbolic shapes
enable_cpp_symbolic_shape_guards = not is_fbcode()
enable_cpp_symbolic_shape_guards = False

# Enable tracing through contextlib.contextmanager
enable_trace_contextlib = True
@@ -65,10 +65,6 @@ class CuteDSLSubgraphInfo:
body: IndentedBuffer
template_mask: Optional[str] = None
template_out: Optional[str] = None
cse: Optional[CSE[Any]] = None

def __post_init__(self):
self.only_copy_if_non_none_fields = ("cse",)

def to_dict(self):
return {
@@ -195,15 +191,10 @@ class CuteDSLTemplateKernel(Kernel):
body=IndentedBuffer(),
template_mask=None,
template_out=None,
cse=None,
)

subgraph = self.subgraph_bodies[body_name]
for key, value in subgraph.to_dict().items():
if value is None and key in getattr(
subgraph, "only_copy_if_non_none_fields", ()
):
continue
setattr(self, key, value)

try:
@@ -221,17 +212,15 @@ class CuteDSLTemplateKernel(Kernel):
setattr(self, key, value)

@contextlib.contextmanager
def create_subgraph_body(self, body_name: str, *, clear_cse: bool = False):
def create_subgraph_body(self, body_name: str):
"""Create a new subgraph body for template processing."""
assert body_name not in self.subgraph_bodies, (
f"Subgraph body '{body_name}' already exists"
)
new_cse = self.cse.clone() if clear_cse else None
self.subgraph_bodies[body_name] = CuteDSLSubgraphInfo(
body=IndentedBuffer(),
template_mask=None,
template_out=None,
cse=new_cse,
)
with self.set_subgraph_body(body_name):
yield
@@ -305,8 +294,7 @@ class CuteDSLTemplateKernel(Kernel):

# Register the hook and return placeholder
placeholder = "<UNPACK_BUFFERS>"
# TODO: I think double invoking is fine for this specific hook
# assert placeholder not in self.render_hooks
assert placeholder not in self.render_hooks
self.render_hooks[placeholder] = hook
return placeholder

@@ -342,7 +330,7 @@ class CuteDSLTemplateKernel(Kernel):
while f"mod_{subgraph_number}_{num}" in self.subgraph_bodies:
num += 1

with self.create_subgraph_body(f"mod_{subgraph_number}_{num}", clear_cse=True):
with self.create_subgraph_body(f"mod_{subgraph_number}_{num}"):
subgraph = self._get_subgraph(subgraph_number)
modification_handler = ModificationWrapperCuteDSL(
self, subgraph_number, fixed_inputs, mask
@@ -441,20 +429,40 @@ class ModificationWrapperCuteDSL(V.WrapperHandler): # type: ignore[name-defined
# val_frag[0] = tensor[index]
# result = val_frag.load()

index_frag = self.kernel.cse.newvar(dtype=torch.int32)
self.kernel.body.writeline(
f"{index_frag} = cute.make_fragment(1, cutlass.Int32)"
)
self.kernel.body.writeline(f"{index_frag}.store({index_str})")

val_frag = self.kernel.cse.newvar(dtype=var_dtype)
self.kernel.body.writeline(
f"{val_frag} = cute.make_fragment(1, {cute_dtype})"
index_frag = self.kernel.cse.generate(
self.kernel.body,
"cute.make_fragment(1, cutlass.Int32)",
dtype=torch.int32,
bounds=ValueRanges.unknown(),
)

index_var = self.kernel.cse.newvar(dtype=torch.int32)
self.kernel.body.writeline(f"{index_var} = {index_frag}[0]")
self.kernel.body.writeline(f"{val_frag}[0] = ({var}[{index_var}])")
self.kernel.cse.generate(
self.kernel.body,
f"{index_frag}.store({index_str})",
dtype=torch.int32,
bounds=ValueRanges.unknown(),
)

val_frag = self.kernel.cse.generate(
self.kernel.body,
f"cute.make_fragment(1, {cute_dtype})",
dtype=var_dtype,
bounds=ValueRanges.unknown(),
)

index_var = self.kernel.cse.generate(
self.kernel.body,
f"{index_frag}[0]",
dtype=torch.int32,
bounds=ValueRanges.unknown(),
)

self.kernel.cse.generate(
self.kernel.body,
f"{val_frag}[0] = ({var}[{index_var}])",
dtype=var_dtype,
bounds=ValueRanges.unknown(),
)

final_expr = f"{val_frag}.load()"
@@ -193,6 +193,24 @@ def flex_attention(
score_mod_other_buffers,
mask_mod_other_buffers,
)
if _use_flex_flash_attention(
subgraph,
mask_graph,
kernel_options,
num_score_mod_placeholders=len(placeholder_inps),
):
return create_flex_flash_attention_kernel(
query,
key,
value,
block_mask,
scale,
kernel_options,
subgraph_buffer,
mask_graph_buffer,
score_mod_other_buffers,
mask_mod_other_buffers,
)

(
query,
@@ -222,30 +240,6 @@ def flex_attention(
]
)

if _use_flex_flash_attention(
subgraph,
mask_graph,
kernel_options,
num_score_mod_placeholders=len(placeholder_inps),
):
return create_flex_flash_attention_kernel(
query,
key,
value,
block_mask,
scale,
kernel_options,
subgraph_buffer,
mask_graph_buffer,
score_mod_other_buffers,
mask_mod_other_buffers,
kv_num_blocks,
kv_indices,
full_kv_num_blocks,
full_kv_indices,
mask_graph=mask_graph,
)

score_mod_other_buffers = maybe_realize(score_mod_other_buffers)
mask_mod_other_buffers = maybe_realize(mask_mod_other_buffers)
@@ -56,8 +56,10 @@ def input_buffers_require_grads(graph_module, num_score_mod_placeholders: int):
return any(requires_grad(n) for n in inputs[num_score_mod_placeholders:])


def is_trivial_mask_graph(graph_module: GraphModule) -> bool:
"""Mask graph is trivial when it only gates via the default full op."""
def is_trivial_graph(
graph_module: GraphModule, is_score_graph: bool, num_score_mod_placeholders: int
):
"""Check if the flex graphs are compatible with Flash Attention."""
graph = graph_module.graph
nodes = list(graph.nodes)
placeholders = [n for n in nodes if n.op == "placeholder"]
@@ -65,16 +67,14 @@ def is_trivial_mask_graph(graph_module: GraphModule) -> bool:
assert len(output) == 1, "Got graph w/ multiple outputs"
output_val = output[0].args[0]

if is_score_graph:
if input_buffers_require_grads(graph_module, num_score_mod_placeholders):
return False
return True  # party on garth
# mask mod graph is empty if we have 4 inputs and full_default output
return len(placeholders) == 4 and output_val.target == torch.ops.aten.full.default


@functools.lru_cache(maxsize=1)
def _supports_nontrivial_mask_graphs() -> bool:
"""Currently only supported on Hopper (SM90) GPUs."""
return torch.cuda.get_device_capability()[0] == 9


def _can_use_flex_flash_attention(
subgraph: Subgraph, mask_graph: Subgraph, num_score_mod_placeholders: int
) -> tuple[bool, str]:
@@ -91,15 +91,32 @@ def _can_use_flex_flash_attention(
False,
"Input buffers require gradients (not supported by flash attention)",
)
mask_trivial = is_trivial_mask_graph(mask_graph.graph_module)

if mask_trivial:
return True, ""
score_trivial = is_trivial_graph(
subgraph.graph_module,
is_score_graph=True,
num_score_mod_placeholders=num_score_mod_placeholders,
)
mask_trivial = is_trivial_graph(
mask_graph.graph_module,
is_score_graph=False,
num_score_mod_placeholders=num_score_mod_placeholders,
)

if not _supports_nontrivial_mask_graphs():
if not score_trivial and not mask_trivial:
return (
False,
"NYI: Non-trivial mask graphs only supported on Hopper (SM90) for flash attention",
"Both score and mask graphs are too complex for flash attention (require simple operations only)",
)
elif not score_trivial:
return (
False,
"Score modification captured tensors that require gradients (not supported by flash attention)",
)
elif not mask_trivial:
return (
False,
"A non None BlockMask was passed to flex attention (not supported by flash attention yet)",
)

return True, ""
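To make the gating above concrete, here is a hedged sketch of what "trivial" versus "non-trivial" mask graphs correspond to at the flex_attention call site. Shapes, device, and dtype are placeholder assumptions; whether the non-trivial case is accepted depends on which side of this diff is in effect (one side rejects any real BlockMask for flash attention, the other only allows non-trivial mask graphs on Hopper/SM90 via block-sparse tensors).

```python
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

compiled_fn = torch.compile(flex_attention)
q, k, v = (
    torch.randn(2, 4, 512, 64, device="cuda", dtype=torch.float16) for _ in range(3)
)

# Trivial mask graph: no block_mask, so the mask graph is the default full op
# and the checks above allow the flash path when force_flash is requested.
out = compiled_fn(q, k, v, kernel_options={"force_flash": True})

# Non-trivial mask graph: a real BlockMask built from a causal mask_mod.
def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

block_mask = create_block_mask(causal_mask, 2, 4, 512, 512, device="cuda")
# This call exercises the non-trivial mask-graph branch inspected by
# _can_use_flex_flash_attention above; it may raise or fall back depending on
# the variant of the gating code and the GPU architecture.
out_masked = compiled_fn(
    q, k, v, block_mask=block_mask, kernel_options={"force_flash": True}
)
```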
@@ -137,11 +154,6 @@ def create_flex_flash_attention_kernel(
mask_graph_buffer: SubgraphResults,
score_mod_other_buffers: list[TensorBox],
mask_mod_other_buffers: list[TensorBox],
kv_num_blocks: TensorBox | None,
kv_indices: TensorBox | None,
full_kv_num_blocks: TensorBox | None,
full_kv_indices: TensorBox | None,
mask_graph: Subgraph,
) -> tuple[TensorBox | ShapeAsConstantBuffer, TensorBox | ShapeAsConstantBuffer]:
"""Create a flex flash attention kernel using CuteDSL template."""
if not ensure_flash_available():
@@ -181,34 +193,17 @@ def create_flex_flash_attention_kernel(
stride=[sympy.sympify(s) for s in output.get_stride()],
)

# Used to check if we can skip block sparse impl
mask_graph_is_trivial = is_trivial_mask_graph(mask_graph.graph_module)

needs_block_mask = not mask_graph_is_trivial
has_full_blocks = full_kv_num_blocks is not None

choices: list[Any] = []
causal = kernel_options.get("causal", False)
assert flash_attention_cutedsl_template is not None

input_nodes = [query, key, value, lse]
if has_full_blocks:
input_nodes.extend(
[kv_num_blocks, kv_indices, full_kv_num_blocks, full_kv_indices]
)

if needs_block_mask and not has_full_blocks:
raise NotImplementedError(
"Flash attention with block mask but without full blocks is not supported yet"
)

error = flash_attention_cutedsl_template.maybe_append_choice(
choices,
input_nodes=input_nodes,
input_nodes=[query, key, value, lse],
layout=output_layout,
mutated_inputs=[lse],
subgraphs=[subgraph_buffer, mask_graph_buffer],
SM_SCALE=scale,
NEEDS_BLOCK_MASK=needs_block_mask,
CAUSAL=causal,
)

if error or not choices:
@@ -1,10 +1,6 @@
{% if NEEDS_BLOCK_MASK %}
{{def_kernel("Q", "K", "V", "LOGSUMEXP", "KV_NUM_BLKS", "KV_IDX", "FULL_KV_NUM_BLKS", "FULL_KV_IDX")}}
{% else %}

{{def_kernel("Q", "K", "V", "LOGSUMEXP")}}
{% endif %}
from flash_attn.cute.interface import _flash_attn_fwd
from flash_attn.cute.block_sparsity import BlockSparseTensorsTorch

# Transpose tensors for _flash_attn_fwd compatibility (B,H,M,D) -> (B,M,H,D)
q_transposed = Q.transpose(1, 2)
@@ -30,25 +26,6 @@
output = {{get_output()}}
output_transposed = output.transpose(1, 2)

{% if NEEDS_BLOCK_MASK %}
@cute.jit
def mask_mod(b_idx, h_idx, q_idx, kv_idx, aux_tensors):
{{unpack_buffers("aux_tensors", indent_width=8)}}
{{ modification(
subgraph_number=1,
output_name="mask_mod_output",
b="b_idx",
h="h_idx",
m="q_idx",
n="kv_idx",
) | indent_except_first(2) }}
return mask_mod_output
block_sparse_tensors = BlockSparseTensorsTorch(KV_NUM_BLKS, KV_IDX, FULL_KV_NUM_BLKS, FULL_KV_IDX)
{% else %}
block_sparse_tensors = None
mask_mod = None
{% endif %}

# Collect any additional tensor buffers that were added during modifications
{% set tensor_buffers = get_tensor_buffers() -%}
{% if tensor_buffers -%}
@@ -64,11 +41,10 @@
k_transposed,
v_transposed,
softmax_scale={{SM_SCALE}},
causal={{CAUSAL}},
return_lse=True,
score_mod=score_mod,
mask_mod=mask_mod,
out=output_transposed,
lse=LOGSUMEXP,
block_sparse_tensors=block_sparse_tensors,
aux_tensors=buffers
)
)
@@ -409,10 +409,9 @@ class SchedulerDonatedBuffer(SchedulerBuffer):


class BaseSchedulerNode:
ancestors: OrderedSet[str]
debug_device_str: Callable[[BaseSchedulerNode], list[str]]
group: tuple[torch.device, tuple[tuple[sympy.Expr, ...], ...]]
last_usage: OrderedSet[str]
read_writes: dependencies.ReadWrites
unmet_dependencies: OrderedSet[Dep]
# .min_order and .max_order are only relevant for "grouped" nodes such as FusedSchedulerNode.
# e.g. if the FusedSchedulerNode includes nodes (op_1, op_2, op_3), and op_X is X-th node
# in `self.scheduler.nodes`, then for this FusedSchedulerNode, .min_order is 1 and .max_order is 3.
@@ -421,24 +420,22 @@ class BaseSchedulerNode:
min_order: int
max_order: int
mpi_node: MemoryPlanningInfoForNode
mutation_renames: dict[str, str]
node: Optional[ir.Operation]
outputs: list[SchedulerBuffer]
outputs_by_name: dict[str, SchedulerBuffer]
override_estimated_runtime: Optional[float] = None
read_writes: dependencies.ReadWrites
unmet_dependencies: OrderedSet[Dep]

def __init__(self, scheduler: Scheduler) -> None:
self.scheduler = scheduler
self.debug_device_str = lambda *args, **kwargs: []
self.scheduler: Scheduler = scheduler
self.debug_device_str: Callable[[BaseSchedulerNode], list[str]] = (
lambda *args, **kwargs: []
)

def _init_from_node(self, node: ir.Operation) -> None:
self.node = node
self.ancestors = OrderedSet()
self.last_usage = OrderedSet() # buffers that won't be used after this kernel
self.node: Optional[ir.Operation] = node
self.ancestors: OrderedSet[str] = OrderedSet()
self.last_usage = OrderedSet[
str
]() # buffers that won't be used after this kernel
self.written = False
self.outputs = [
self.outputs: list[SchedulerBuffer] = [
SchedulerBuffer(
scheduler=self.scheduler,
node=output,
@@ -446,14 +443,16 @@ class BaseSchedulerNode:
)
for output in node.get_outputs()
]
self.outputs_by_name = {buf.get_name(): buf for buf in self.outputs}
self.outputs_by_name: dict[str, SchedulerBuffer] = {
buf.get_name(): buf for buf in self.outputs
}

# mutation_renames for the current node. Due to potential
# more mutations happening later, this can be different
# to Scheduler.mutation_renames. Also this dict should be small
# since only mutation information relevant to the deps for this
# node is stored here.
self.mutation_renames = {}
self.mutation_renames: dict[str, str] = {}

def __repr__(self) -> str:
return f"{type(self).__name__}(name={self.get_name()!r})"
@@ -2436,34 +2435,6 @@ def pick_loop_order(
return order


def _replace_operation_buffer(
orig_node: ir.MultiTemplateBuffer, new_node: ir.OperationBuffer
) -> None:
replaced_buf_name = new_node.get_name()
orig_buf_name = orig_node.get_name()
assert isinstance(orig_buf_name, str) and isinstance(replaced_buf_name, str)

replaced_op_name = new_node.get_operation_name()
orig_op_name = orig_node.get_operation_name()
assert isinstance(orig_op_name, str) and isinstance(replaced_op_name, str)

del V.graph.name_to_buffer[replaced_buf_name]
new_node.name = orig_buf_name

del V.graph.name_to_op[replaced_op_name]
new_node.operation_name = orig_op_name

orig = V.graph.buffers.index(orig_node)
V.graph.buffers.remove(new_node)
V.graph.buffers[orig] = new_node
V.graph.name_to_buffer[orig_buf_name] = new_node

orig = V.graph.operations.index(orig_node)
V.graph.operations.remove(new_node)
V.graph.operations[orig] = new_node
V.graph.name_to_op[orig_op_name] = new_node


@dataclasses.dataclass
class NodeUser:
node: Union[BaseSchedulerNode, OutputNode]
@@ -3365,6 +3336,33 @@ class Scheduler:
will force completion of compilation and benchmarking.
"""

def replace_operation_buffer(
orig_node: ir.MultiTemplateBuffer, new_node: ir.OperationBuffer
) -> None:
replaced_buf_name = new_node.get_name()
orig_buf_name = orig_node.get_name()
assert isinstance(orig_buf_name, str) and isinstance(replaced_buf_name, str)

replaced_op_name = new_node.get_operation_name()
orig_op_name = orig_node.get_operation_name()
assert isinstance(orig_op_name, str) and isinstance(replaced_op_name, str)

del V.graph.name_to_buffer[replaced_buf_name]
new_node.name = orig_buf_name

del V.graph.name_to_op[replaced_op_name]
new_node.operation_name = orig_op_name

orig = V.graph.buffers.index(orig_node)
V.graph.buffers.remove(new_node)
V.graph.buffers[orig] = new_node
V.graph.name_to_buffer[orig_buf_name] = new_node

orig = V.graph.operations.index(orig_node)
V.graph.operations.remove(new_node)
V.graph.operations[orig] = new_node
V.graph.name_to_op[orig_op_name] = new_node

for i, node in enumerate(self.nodes):
if isinstance(node, SchedulerNode) and isinstance(
node.node, ir.MultiTemplateBuffer
@@ -3418,47 +3416,40 @@ class Scheduler:
assign_origin_node(out_tensorbox, multi_node.origin_node)

out_buffer.layout = multi_node.layout
self._replace_node(out_buffer, multi_node, i, node)
replace_operation_buffer(multi_node, out_buffer)
new_scheduler_node = self.create_scheduler_node(out_buffer)

def _replace_node(
self,
out_buffer: ir.OperationBuffer,
multi_node: ir.MultiTemplateBuffer,
i: int,
node: SchedulerNode,
) -> None:
_replace_operation_buffer(multi_node, out_buffer)
new_scheduler_node = self.create_scheduler_node(out_buffer)
self.nodes[i] = new_scheduler_node
self.name_to_node[node.get_name()] = new_scheduler_node
self.name_to_fused_node[node.get_name()] = new_scheduler_node

self.nodes[i] = new_scheduler_node
self.name_to_node[node.get_name()] = new_scheduler_node
self.name_to_fused_node[node.get_name()] = new_scheduler_node
# We need to reflect the mutation renames that were recorded in the original node
mutation_renames = {}
for dep in itertools.chain(
node.read_writes.reads, node.unmet_dependencies
):
if real_name := self.mutation_real_name.get(dep.name, None):
mutation_renames[real_name] = dep.name

# We need to reflect the mutation renames that were recorded in the original node
mutation_renames = {}
for dep in itertools.chain(node.read_writes.reads, node.unmet_dependencies):
if real_name := self.mutation_real_name.get(dep.name, None):
mutation_renames[real_name] = dep.name
def rename_deps(deps: OrderedSet[Dep]) -> OrderedSet[Dep]:
return OrderedSet(dep.rename(mutation_renames) for dep in deps)

def rename_deps(deps: OrderedSet[Dep]) -> OrderedSet[Dep]:
return OrderedSet(dep.rename(mutation_renames) for dep in deps)
new_scheduler_node.unmet_dependencies = rename_deps(
new_scheduler_node.unmet_dependencies
)
new_scheduler_node.read_writes.reads = rename_deps(
new_scheduler_node.read_writes.reads
)

new_scheduler_node.unmet_dependencies = rename_deps(
new_scheduler_node.unmet_dependencies
)
new_scheduler_node.read_writes.reads = rename_deps(
new_scheduler_node.read_writes.reads
)
for new_out, old_out in zip(
new_scheduler_node.get_outputs(), node.get_outputs()
):
self.name_to_buf[old_out.get_name()] = new_out
new_out.users = old_out.users

for new_out, old_out in zip(
new_scheduler_node.get_outputs(), node.get_outputs()
):
self.name_to_buf[old_out.get_name()] = new_out
new_out.users = old_out.users

new_scheduler_node.min_order = node.min_order
new_scheduler_node.max_order = node.max_order
new_scheduler_node.last_usage = node.last_usage
new_scheduler_node.min_order = node.min_order
new_scheduler_node.max_order = node.max_order
new_scheduler_node.last_usage = node.last_usage

def _any_atomic_add(self, node_list: Sequence[BaseSchedulerNode]) -> bool:
return any(