Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-06 09:17:11 +08:00)

Compare commits: 2 commits, bf/bug-sta... vs upload-tes...

| Author | SHA1 | Date |
|---|---|---|
| | 03cddd3c9c | |
| | 190797db14 | |
@@ -337,7 +337,7 @@ test_python() {

test_python_smoke() {
# Smoke tests for H100/B200
time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 inductor/test_max_autotune inductor/test_cutedsl_grouped_mm $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 inductor/test_max_autotune $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
assert_git_not_dirty
}
@@ -28,7 +28,7 @@ CUDA_ARCHES_FULL_VERSION = {
"12.6": "12.6.3",
"12.8": "12.8.1",
"12.9": "12.9.1",
"13.0": "13.0.0",
"13.0": "13.0.2",
}
CUDA_ARCHES_CUDNN_VERSION = {
"12.6": "9",
1 .github/workflows/docker-release.yml vendored
@@ -8,7 +8,6 @@ on:
- docker.Makefile
- .github/workflows/docker-release.yml
- .github/scripts/generate_docker_release_matrix.py
- .github/scripts/generate_binary_build_matrix.py
push:
branches:
- nightly
8 .github/workflows/inductor-unittest.yml vendored
@@ -115,10 +115,10 @@ jobs:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
test-matrix: |
{ include: [
{ config: "inductor_amx", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
{ config: "inductor_amx", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
{ config: "inductor_avx2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.avx2" },
{ config: "inductor_avx2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.avx2" },
{ config: "inductor_amx", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
{ config: "inductor_amx", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
{ config: "inductor_avx2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.10xlarge.avx2" },
{ config: "inductor_avx2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.10xlarge.avx2" },
]}
secrets: inherit
14 .github/workflows/inductor.yml vendored
@@ -84,13 +84,13 @@ jobs:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
test-matrix: |
{ include: [
{ config: "cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
{ config: "cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
{ config: "dynamic_cpu_inductor_huggingface", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
{ config: "dynamic_cpu_inductor_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
{ config: "dynamic_cpu_inductor_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
{ config: "dynamic_cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
{ config: "dynamic_cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
{ config: "cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
{ config: "cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
{ config: "dynamic_cpu_inductor_huggingface", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
{ config: "dynamic_cpu_inductor_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
{ config: "dynamic_cpu_inductor_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
{ config: "dynamic_cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
{ config: "dynamic_cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
{ config: "inductor_torchbench_cpu_smoketest_perf", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.24xl.spr-metal" },
]}
build-additional-packages: "vision audio torchao"
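The `shard`/`num_shards` fields in the test matrices above split each config's test list across several runners. As a rough, generic illustration only (not PyTorch's actual `run_test.py` sharding logic, which also balances by recorded test times), a 1-indexed shard can be selected like this:

```python
def select_shard(tests, shard, num_shards):
    """Illustrative round-robin split; `shard` is 1-indexed as in the test-matrix entries above."""
    return [t for i, t in enumerate(sorted(tests)) if i % num_shards == shard - 1]

# Hypothetical test names, two shards as in the "inductor_amx" entries.
tests = ["test_a", "test_b", "test_c", "test_d"]
print(select_shard(tests, shard=1, num_shards=2))  # ['test_a', 'test_c']
print(select_shard(tests, shard=2, num_shards=2))  # ['test_b', 'test_d']
```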
3 .github/workflows/trunk.yml vendored
@@ -204,7 +204,6 @@ jobs:
{ include: [
{ config: "default", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" },
{ config: "default", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" },
{ config: "distributed", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.gfx942.4" },
]}
secrets: inherit

@@ -222,7 +221,7 @@ jobs:
build-environment: linux-jammy-rocm-py3.10
docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }}
tests-to-include: "test_nn test_torch test_cuda test_ops test_unary_ufuncs test_binary_ufuncs test_autograd inductor/test_torchinductor distributed/test_c10d_common distributed/test_c10d_nccl"
tests-to-include: "test_nn test_torch test_cuda test_ops test_unary_ufuncs test_binary_ufuncs test_autograd inductor/test_torchinductor"
secrets: inherit

inductor-build:
1 .gitignore vendored
@@ -127,7 +127,6 @@ torch/test/
torch/utils/benchmark/utils/valgrind_wrapper/callgrind.h
torch/utils/benchmark/utils/valgrind_wrapper/valgrind.h
torch/version.py
torch/_inductor/kernel/vendored_templates/*
minifier_launcher.py
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_fwd_d*
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_bwd_d*
@@ -211,6 +211,7 @@ exclude_patterns = [
'**/*pb.h',
'**/*inl.h',
'aten/src/ATen/cpu/FlushDenormal.cpp',
'aten/src/ATen/cpu/Utils.cpp',
'aten/src/ATen/cpu/vml.h',
'aten/src/ATen/CPUFixedAllocator.h',
'aten/src/ATen/Parallel*.h',
@@ -229,6 +230,8 @@ exclude_patterns = [
'c10/util/win32-headers.h',
'c10/test/**/*.h',
'third_party/**/*',
'torch/csrc/api/include/torch/nn/modules/common.h',
'torch/csrc/api/include/torch/linalg.h',
'torch/csrc/autograd/generated/**',
'torch/csrc/distributed/**/*.cu',
'torch/csrc/distributed/c10d/WinSockUtils.hpp',
@@ -240,6 +243,7 @@ exclude_patterns = [
'torch/csrc/utils/generated_serialization_types.h',
'torch/csrc/utils/pythoncapi_compat.h',
'torch/csrc/inductor/aoti_runtime/sycl_runtime_wrappers.h',
'aten/src/ATen/ExpandBase.h',
]
init_command = [
'python3',
@@ -260,7 +260,7 @@ IF(USE_FBGEMM_GENAI)
if(USE_CUDA)
# To avoid increasing the build time/binary size unnecessarily, use an allow-list of kernels to build.
# If you want to integrate a kernel from FBGEMM into torch, you have to add it here.
set(FBGEMM_CUTLASS_KERNELS_REGEX ".*(mx8mx8bf16_grouped|f4f4bf16_grouped|f4f4bf16).*")
set(FBGEMM_CUTLASS_KERNELS_REGEX ".*(mx8mx8bf16_grouped|f4f4bf16_grouped).*")
file(GLOB_RECURSE fbgemm_genai_native_cuda_cu
"${FBGEMM_GENAI_SRCS}/cutlass_extensions/*.cu"
"${FBGEMM_GENAI_SRCS}/cutlass_extensions/**/*.cu")
@@ -23,6 +23,8 @@ C10_DIAGNOSTIC_POP()
#endif
namespace at {

namespace {

/*
These const variables defined the fp32 precisions for different backend
We have "generic", "cuda", "mkldnn" backend now and we can choose fp32
@@ -39,6 +41,16 @@ namespace at {
->rnn
*/

C10_ALWAYS_INLINE void warn_deprecated_fp32_precision_api(){
TORCH_WARN_ONCE(
"Please use the new API settings to control TF32 behavior, such as torch.backends.cudnn.conv.fp32_precision = 'tf32' "
"or torch.backends.cuda.matmul.fp32_precision = 'ieee'. Old settings, e.g, torch.backends.cuda.matmul.allow_tf32 = True, "
"torch.backends.cudnn.allow_tf32 = True, allowTF32CuDNN() and allowTF32CuBLAS() will be deprecated after Pytorch 2.9. Please see "
"https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices"
);
}
} // namespace

Float32Backend str2backend(const std::string& name) {
if (name == "generic")
return Float32Backend::GENERIC;
@@ -194,6 +206,7 @@ bool Context::allowTF32CuDNN(std::optional<Float32Op> op) const {
} else {
return float32Precision(Float32Backend::CUDA, op.value()) == Float32Precision::TF32;
}
warn_deprecated_fp32_precision_api();
return allow_tf32_cudnn;
}

@@ -201,6 +214,7 @@ void Context::setAllowTF32CuDNN(bool b) {
setFloat32Precision(Float32Backend::CUDA, Float32Op::RNN, b ? Float32Precision::TF32 : Float32Precision::NONE);
setFloat32Precision(Float32Backend::CUDA, Float32Op::CONV, b ? Float32Precision::TF32 : Float32Precision::NONE);
allow_tf32_cudnn = b;
warn_deprecated_fp32_precision_api();
}

void Context::setSDPPriorityOrder(const std::vector<int64_t>& order) {
@@ -311,6 +325,7 @@ bool Context::allowTF32CuBLAS() const {
"Current status indicate that you have used mix of the legacy and new APIs to set the TF32 status for cublas matmul. ",
"We suggest only using the new API to set the TF32 flag. See also: ",
"https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices");
warn_deprecated_fp32_precision_api();
return allow_tf32_new;
}

@@ -334,6 +349,7 @@ Float32MatmulPrecision Context::float32MatmulPrecision() const {
"Current status indicate that you have used mix of the legacy and new APIs to set the matmul precision. ",
"We suggest only using the new API for matmul precision. See also: ",
"https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices");
warn_deprecated_fp32_precision_api();
return float32_matmul_precision;
}

@@ -361,6 +377,7 @@ Float32Precision Context::float32Precision(Float32Backend backend, Float32Op op)

void Context::setFloat32MatmulPrecision(const std::string &s) {
auto match = [this](const std::string & s_) {
warn_deprecated_fp32_precision_api();
// TODO: consider if CuDNN field needs to also be set for potential future CuDNN ops like multi-headed attention
if (s_ == "highest") {
float32_matmul_precision = at::Float32MatmulPrecision::HIGHEST;
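The deprecation warning added above points users at newer per-backend precision knobs. A minimal Python sketch of the two styles it contrasts (the attribute names are taken from the warning text; availability depends on the PyTorch version actually built from this branch):

```python
import torch

# Legacy flags, deprecated after PyTorch 2.9 according to the warning above.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Newer per-backend / per-op controls named in the warning message.
torch.backends.cuda.matmul.fp32_precision = "tf32"   # allow TF32 for CUDA matmuls
torch.backends.cudnn.conv.fp32_precision = "tf32"    # allow TF32 for cuDNN convolutions
torch.backends.cuda.matmul.fp32_precision = "ieee"   # strict FP32 matmuls
```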
@ -191,37 +191,22 @@ inline void convert(const at::Half* src, bool* dst, int64_t n) {
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
template <typename to_type>
|
||||
inline void convertFromBf16Impl(
|
||||
const c10::BFloat16* __restrict src,
|
||||
to_type* __restrict dst,
|
||||
int64_t n) {
|
||||
const uint16_t* srcPtr = reinterpret_cast<const uint16_t*>(src);
|
||||
uint64_t len = static_cast<uint64_t>(n);
|
||||
for (uint64_t i = 0; i < len; i++) {
|
||||
uint32_t tmp = static_cast<uint32_t>(srcPtr[i]) << 16;
|
||||
float tmpF;
|
||||
__builtin_memcpy(&tmpF, &tmp, sizeof(float));
|
||||
dst[i] = static_cast<to_type>(tmpF);
|
||||
}
|
||||
}
|
||||
#define CONVERT_FROM_BF16_TEMPLATE(to_type) \
|
||||
template <> \
|
||||
inline void convert(const c10::BFloat16* src, to_type* dst, int64_t n) { \
|
||||
return convertFromBf16Impl<to_type>(src, dst, n); \
|
||||
}
|
||||
|
||||
CONVERT_FROM_BF16_TEMPLATE(uint8_t)
|
||||
CONVERT_FROM_BF16_TEMPLATE(int8_t)
|
||||
CONVERT_FROM_BF16_TEMPLATE(int16_t)
|
||||
CONVERT_FROM_BF16_TEMPLATE(int32_t)
|
||||
CONVERT_FROM_BF16_TEMPLATE(int64_t)
|
||||
CONVERT_FROM_BF16_TEMPLATE(float)
|
||||
CONVERT_FROM_BF16_TEMPLATE(double)
|
||||
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
||||
CONVERT_FROM_BF16_TEMPLATE(float16_t)
|
||||
#endif
|
||||
#ifdef __ARM_FEATURE_BF16
|
||||
CONVERT_TEMPLATE(bfloat16_t, uint8_t)
|
||||
CONVERT_TEMPLATE(bfloat16_t, int8_t)
|
||||
CONVERT_TEMPLATE(bfloat16_t, int16_t)
|
||||
CONVERT_TEMPLATE(bfloat16_t, int32_t)
|
||||
CONVERT_TEMPLATE(bfloat16_t, int64_t)
|
||||
CONVERT_TEMPLATE(bfloat16_t, bfloat16_t)
|
||||
CONVERT_TEMPLATE(bfloat16_t, float)
|
||||
CONVERT_TEMPLATE(bfloat16_t, double)
|
||||
CONVERT_TEMPLATE(uint8_t, bfloat16_t)
|
||||
CONVERT_TEMPLATE(int8_t, bfloat16_t)
|
||||
CONVERT_TEMPLATE(int16_t, bfloat16_t)
|
||||
CONVERT_TEMPLATE(int32_t, bfloat16_t)
|
||||
CONVERT_TEMPLATE(int64_t, bfloat16_t)
|
||||
CONVERT_TEMPLATE(float, bfloat16_t)
|
||||
CONVERT_TEMPLATE(double, bfloat16_t)
|
||||
|
||||
inline void convertBoolToBfloat16Impl(
|
||||
const bool* __restrict src,
|
||||
@ -262,6 +247,8 @@ inline void convert(const c10::BFloat16* src, bool* dst, int64_t n) {
|
||||
|
||||
#endif
|
||||
|
||||
#endif
|
||||
|
||||
template <typename src_t>
|
||||
struct VecConvert<
|
||||
float,
|
||||
|
||||
@ -92,8 +92,7 @@ void addcdiv_cpu_kernel(TensorIteratorBase& iter, const Scalar& value) {
|
||||
|
||||
void smooth_l1_backward_cpu_kernel(TensorIterator& iter, const Scalar& norm, double beta) {
|
||||
ScalarType dtype = iter.dtype(0);
|
||||
if (at::isReducedFloatingType(dtype)) {
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(dtype, "smooth_l1_backward_cpu_out", [&]() {
|
||||
if (dtype == kBFloat16) {
|
||||
auto norm_val = norm.to<float>();
|
||||
float beta_val(beta);
|
||||
auto norm_val_vec = Vectorized<float>(norm_val);
|
||||
@ -102,9 +101,9 @@ void smooth_l1_backward_cpu_kernel(TensorIterator& iter, const Scalar& norm, dou
|
||||
const auto zero_vec = Vectorized<float>(0);
|
||||
const auto pos_1_vec = Vectorized<float>(1);
|
||||
cpu_kernel_vec(iter,
|
||||
[=](scalar_t input, scalar_t target, scalar_t grad_output) -> scalar_t {
|
||||
[=](BFloat16 input, BFloat16 target, BFloat16 grad_output) -> BFloat16 {
|
||||
const auto x = float(input) - float(target);
|
||||
if (x <= -beta) {
|
||||
if (x <= -beta){
|
||||
return -norm_val * float(grad_output);
|
||||
}else if (x >= beta){
|
||||
return norm_val * float(grad_output);
|
||||
@ -113,14 +112,14 @@ void smooth_l1_backward_cpu_kernel(TensorIterator& iter, const Scalar& norm, dou
|
||||
}
|
||||
},
|
||||
[norm_val_vec, beta_val_vec, neg_1_vec, zero_vec, pos_1_vec](
|
||||
Vectorized<scalar_t> input, Vectorized<scalar_t> target, Vectorized<scalar_t> grad_output) -> Vectorized<scalar_t> {
|
||||
Vectorized<BFloat16> input, Vectorized<BFloat16> target, Vectorized<BFloat16> grad_output) -> Vectorized<BFloat16> {
|
||||
// using two blendv calls to simulate the 3 cases
|
||||
// 1 if x >= beta
|
||||
// -1 if x <= -beta
|
||||
// x / beta if |x| < beta
|
||||
auto [input0, input1] = convert_to_float(input);
|
||||
auto [target0, target1] = convert_to_float(target);
|
||||
auto [grad_output0, grad_output1] = convert_to_float(grad_output);
|
||||
auto [input0, input1] = convert_bfloat16_float(input);
|
||||
auto [target0, target1] = convert_bfloat16_float(target);
|
||||
auto [grad_output0, grad_output1] = convert_bfloat16_float(grad_output);
|
||||
auto x = input0 - target0;
|
||||
auto pos_or_neg_1_vec = Vectorized<float>::blendv(
|
||||
neg_1_vec, pos_1_vec, x > zero_vec);
|
||||
@ -136,10 +135,9 @@ void smooth_l1_backward_cpu_kernel(TensorIterator& iter, const Scalar& norm, dou
|
||||
output = Vectorized<float>::blendv(
|
||||
x / beta_val_vec, pos_or_neg_1_vec, x_abs >= beta_val_vec);
|
||||
input1 = norm_val_vec * output * grad_output1;
|
||||
return convert_from_float<scalar_t>(input0, input1);
|
||||
return convert_float_bfloat16(input0, input1);
|
||||
}
|
||||
);
|
||||
});
|
||||
} else {
|
||||
AT_DISPATCH_ALL_TYPES(dtype, "smooth_l1_backward_cpu_out", [&] {
|
||||
auto norm_val = norm.to<scalar_t>();
|
||||
|
||||
@@ -59,24 +59,6 @@
// forward declare
class cublasCommonArgs;

#ifndef _WIN32
namespace fbgemm_gpu {

// NOTE(slayton58): FBGemm_GPU kernels come from <fbgemm_gpu/torch_ops.h> within the FBGemm repo.
// To update supported ops means a submodule bump, which is.. painful. Instead, we
// can simply forward-declare the methods we want to use.. Works at least as a short-term
// thing, but should still be fixed somewhere/somehow.
at::Tensor f4f4bf16(
at::Tensor,
at::Tensor,
at::Tensor,
at::Tensor,
std::optional<at::Tensor>,
bool use_mx);

} // namespace fbgemm_gpu
#endif

using at::blas::ScalingType;
using at::blas::SwizzleType;
@@ -785,6 +767,33 @@ _scaled_rowwise_rowwise(
return out;
}

// Check the shapes & sizes of scales for deepseek-style (1x128, 128x128) scaling.
// Wraps check_size_stride for easier integration, correctly handles cases where a dimension of the scale == 1,
// and strides become somewhat meaningless
void _check_deepseek_scale_stride(const Tensor& scale, const Tensor& t, const ScalingType scale_type) {
if (scale_type == ScalingType::BlockWise1x128) {
TORCH_CHECK_VALUE(check_size_stride(scale, 0, t.size(0), 1),
"at dim=0 scale should have ", t.size(0), "elements and stride(0) ", 1, "if ", t.size(0), " > 1 - Got: ",
"shape=", scale.sizes(), ", stride=", scale.strides());
auto expected_size = ceil_div<int64_t>(t.size(1), 128);
TORCH_CHECK_VALUE(check_size_stride(scale, 1, expected_size, t.size(0)),
"at dim=1 scale should have ", expected_size, "elements and stride ", t.size(0), "if ", expected_size, " > 1 - Got: ",
"shape=", scale.sizes(), ", stride=", scale.strides());
} else if (scale_type == ScalingType::BlockWise128x128) {
TORCH_CHECK_VALUE(check_size_stride(
scale,
0,
ceil_div<int64_t>(t.size(0), 128),
ceil_div<int64_t>(t.size(1), 128)),
"at dim=0 scale should have ", ceil_div<int64_t>(t.size(0), 128), "elements and stride(0) ", ceil_div<int64_t>(t.size(1), 128), "if ", ceil_div<int64_t>(t.size(0), 128), " > 1 - Got: ",
"shape=", scale.sizes(), ", stride=", scale.strides());
TORCH_CHECK(check_size_stride(
scale, 1, ceil_div<int64_t>(t.size(1), 128), 1),
"at dim=1 scale should have ", ceil_div<int64_t>(t.size(1), 128), "elements and stride(1) ", 1, "if ", ceil_div<int64_t>(t.size(1), 128), " > 1 - Got: ",
"shape=", scale.sizes(), ", stride=", scale.strides());
}
}

void
_check_deepseek_support() {
#ifndef USE_ROCM
@@ -797,7 +806,7 @@ _check_deepseek_support() {
}
// Only in cublasLt >= 12.9
TORCH_CHECK_NOT_IMPLEMENTED(
CUBLAS_VERSION >= 120900 && cublasLtGetVersion() >= 120900,
CUBLAS_VERSION < 120900 || cublasLtGetVersion() < 120900,
"DeepSeek style (1x128, 128x128) scaling requires cublasLt >= 12.9"
);
#endif
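For orientation, a small Python sketch of the scale shapes the checks above expect for an fp8 matmul with DeepSeek-style scaling. It mirrors only the `ceil_div`-by-128 size arithmetic in `_check_deepseek_scale_stride`, not the stride checks:

```python
from math import ceil

def expected_scale_shape(rows, cols, scale_type):
    """Expected scale shape for a [rows, cols] operand, per the size checks above."""
    if scale_type == "BlockWise1x128":      # one scale per row, per 128-column block
        return (rows, ceil(cols / 128))
    if scale_type == "BlockWise128x128":    # one scale per 128x128 tile
        return (ceil(rows / 128), ceil(cols / 128))
    raise ValueError(scale_type)

M, K = 256, 512
print(expected_scale_shape(M, K, "BlockWise1x128"))    # (256, 4)
print(expected_scale_shape(M, K, "BlockWise128x128"))  # (2, 4)
```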
@ -814,61 +823,23 @@ _scaled_block1x128_block1x128(
|
||||
#ifndef USE_ROCM
|
||||
// Restrictions:
|
||||
// A, B are FP8, scales are fp32, shape K//128
|
||||
// As: [M x K // 128], stride: [1, M]
|
||||
// Bs: [N x K // 128], stride: [1, N]
|
||||
// CUDA: Only Hopper GPUs
|
||||
_check_deepseek_support();
|
||||
|
||||
// check types
|
||||
TORCH_CHECK_VALUE(
|
||||
isFloat8Type(mat_a.scalar_type()) &&
|
||||
isFloat8Type(mat_b.scalar_type()),
|
||||
"mat_a and mat_b must be fp8 types, got: ", mat_a.scalar_type(), mat_b.scalar_type()
|
||||
);
|
||||
|
||||
const int64_t M = mat_a.sizes()[0];
|
||||
const int64_t K = mat_a.sizes()[1];
|
||||
const int64_t N = mat_b.sizes()[1];
|
||||
|
||||
// scale_a shape
|
||||
TORCH_CHECK_VALUE(
|
||||
scale_a.size(0) == M &&
|
||||
scale_a.size(1) == ceil_div<int64_t>(K, 128) &&
|
||||
scale_a.scalar_type() == kFloat,
|
||||
"scale_a must have shape ", M, " x ", ceil_div<int64_t>(K, 128), " Float elements, got ", scale_a.sizes()
|
||||
);
|
||||
// scale_a stride
|
||||
TORCH_CHECK_VALUE(
|
||||
scale_a.stride(0) == 1 &&
|
||||
(
|
||||
scale_a.stride(1) == M ||
|
||||
(scale_a.size(1) == 1 && scale_b.stride(1) == 1)
|
||||
),
|
||||
"scale_a strides must be (", 1, ", ", M, "); got: ", scale_a.strides()
|
||||
);
|
||||
|
||||
// scale_b shape
|
||||
TORCH_CHECK_VALUE(
|
||||
scale_b.size(0) == N &&
|
||||
scale_b.size(1) == ceil_div<int64_t>(K, 128) &&
|
||||
scale_b.scalar_type() == kFloat,
|
||||
"scale_b must have shape ", N, " x ", ceil_div<int64_t>(K, 128), " Float elements, got ", scale_b.sizes()
|
||||
);
|
||||
// scale_b stride
|
||||
TORCH_CHECK_VALUE(
|
||||
scale_b.stride(0) == 1 &&
|
||||
(
|
||||
scale_b.stride(1) == N ||
|
||||
(
|
||||
scale_b.size(1) == 1 &&
|
||||
scale_b.stride(1) == 1
|
||||
)
|
||||
),
|
||||
"scale_b strides must be (", 1, ", ", N, "); got: ", scale_a.strides()
|
||||
);
|
||||
TORCH_CHECK_VALUE(isFloat8Type(mat_a.scalar_type()) && isFloat8Type(mat_b.scalar_type()), "mat_a and mat_b must be fp8 types, got: ",
|
||||
mat_a.scalar_type(), mat_b.scalar_type());
|
||||
TORCH_CHECK_VALUE(scale_a.sizes()[0] == mat_a.sizes()[0] && scale_a.sizes()[1] == mat_a.sizes()[1] / 128 && scale_a.scalar_type() == kFloat,
|
||||
"scale_a must have shape ", mat_a.sizes()[0], " x ", mat_a.sizes()[1] / 128, " Float elements, got ", scale_a.sizes())
|
||||
TORCH_CHECK_VALUE(scale_b.sizes()[0] == ceil_div<int64_t>(mat_b.sizes()[0], 128) && scale_b.sizes()[1] == mat_b.sizes()[1] && scale_b.scalar_type() == kFloat,
|
||||
"scale_b must have shape ", ceil_div<int64_t>(mat_b.sizes()[0], 128), " x ", mat_b.sizes()[1], " Float elements, got ", scale_b.sizes())
|
||||
|
||||
auto scaling_choice_a = ScalingType::BlockWise1x128;
|
||||
auto scaling_choice_b = ScalingType::BlockWise1x128;
|
||||
|
||||
// Check scale strides (including stride=1 small cases)
|
||||
_check_deepseek_scale_stride(scale_a, mat_a, scaling_choice_a);
|
||||
_check_deepseek_scale_stride(scale_b.t(), mat_b.t(), scaling_choice_b);
|
||||
|
||||
_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
|
||||
|
||||
return out;
|
||||
@ -890,65 +861,24 @@ _scaled_block128x128_block1x128(
|
||||
Tensor& out) {
|
||||
#ifndef USE_ROCM
|
||||
// Restrictions:
|
||||
// A, B are FP8, scales are fp32, shape K//128
|
||||
// CUDA: Only Hopper GPUs
|
||||
_check_deepseek_support();
|
||||
|
||||
// A: [M, K], B: [K, N] are FP8, scales are fp32
|
||||
// As: [round_up(K // 128, 4), M // 128], stride: [M // 128, 1]
|
||||
// Bs: [N x K // 128], stride: [1, N]
|
||||
TORCH_CHECK_VALUE(
|
||||
isFloat8Type(mat_a.scalar_type()) &&
|
||||
isFloat8Type(mat_b.scalar_type()),
|
||||
"mat_a and mat_b must be fp8 types, got: ", mat_a.scalar_type(), mat_b.scalar_type()
|
||||
);
|
||||
|
||||
const int64_t M = mat_a.sizes()[0];
|
||||
const int64_t K = mat_a.sizes()[1];
|
||||
const int64_t N = mat_b.sizes()[1];
|
||||
|
||||
// scale_a shape
|
||||
TORCH_CHECK_VALUE(
|
||||
scale_a.size(0) == round_up<int64_t>(ceil_div<int64_t>(K, 128), 4) &&
|
||||
scale_a.size(1) == ceil_div<int64_t>(M, 128) &&
|
||||
scale_a.scalar_type() == kFloat,
|
||||
"scale_a must have shape ", round_up<int64_t>(ceil_div<int64_t>(K, 128), 4), " x ",
|
||||
ceil_div<int64_t>(M, 128), " Float elements, got ", scale_a.sizes()
|
||||
);
|
||||
// scale_a stride
|
||||
TORCH_CHECK_VALUE(
|
||||
scale_a.stride(0) == 1 &&
|
||||
(
|
||||
scale_a.stride(1) == round_up<int64_t>(ceil_div<int64_t>(K, 128), 4) ||
|
||||
(
|
||||
scale_a.size(1) == 1 &&
|
||||
scale_a.stride(1) == 1
|
||||
)
|
||||
),
|
||||
"scale_a must have strides (1, ", round_up<int64_t>(ceil_div<int64_t>(K, 128), 4), "); got ", scale_b.strides()
|
||||
);
|
||||
|
||||
// scale_b shape
|
||||
TORCH_CHECK_VALUE(
|
||||
scale_b.size(0) == N &&
|
||||
scale_b.size(1) == ceil_div<int64_t>(K, 128) &&
|
||||
scale_b.scalar_type() == kFloat,
|
||||
"scale_b must have shape ", N, " x ", ceil_div<int64_t>(K, 128), " Float elements, got ", scale_b.sizes()
|
||||
);
|
||||
// scale_b stride
|
||||
TORCH_CHECK_VALUE(
|
||||
scale_b.stride(0) == 1 &&
|
||||
(
|
||||
scale_b.stride(1) == N ||
|
||||
(
|
||||
scale_b.size(1) == 1 &&
|
||||
scale_b.stride(1) == 1
|
||||
)
|
||||
),
|
||||
"scale_b must have strides (1, ", N, "); got ", scale_b.strides()
|
||||
);
|
||||
TORCH_CHECK_VALUE(isFloat8Type(mat_a.scalar_type()) && isFloat8Type(mat_b.scalar_type()), "mat_a and mat_b must be fp8 types, got: ",
|
||||
mat_a.scalar_type(), mat_b.scalar_type());
|
||||
TORCH_CHECK_VALUE(scale_a.sizes()[0] == ceil_div<int64_t>(mat_a.sizes()[0], 128) && scale_a.sizes()[1] == ceil_div<int64_t>(mat_a.sizes()[1], 128) && scale_a.scalar_type() == kFloat,
|
||||
"scale_a must have shape ", ceil_div<int64_t>(mat_a.sizes()[0], 128), " x ", ceil_div<int64_t>(mat_a.sizes()[1], 128), " Float elements, got ", scale_a.sizes())
|
||||
TORCH_CHECK_VALUE(scale_b.sizes()[0] == ceil_div<int64_t>(mat_b.sizes()[0], 128) && scale_b.sizes()[1] == mat_b.sizes()[1] && scale_b.scalar_type() == kFloat,
|
||||
"scale_b must have shape ", ceil_div<int64_t>(mat_b.sizes()[0], 128), " x ", mat_b.sizes()[1], " Float elements, got ", scale_b.sizes())
|
||||
|
||||
auto scaling_choice_a = ScalingType::BlockWise128x128;
|
||||
auto scaling_choice_b = ScalingType::BlockWise1x128;
|
||||
|
||||
// Check scale strides (including stride=1 small cases)
|
||||
_check_deepseek_scale_stride(scale_a, mat_a, scaling_choice_a);
|
||||
_check_deepseek_scale_stride(scale_b.t(), mat_b.t(), scaling_choice_b);
|
||||
|
||||
_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
|
||||
|
||||
return out;
|
||||
@ -970,62 +900,24 @@ _scaled_block1x128_block128x128(
|
||||
Tensor& out) {
|
||||
#ifndef USE_ROCM
|
||||
// Restrictions:
|
||||
// A, B are FP8, scales are fp32, A: shape K//128, B: K//128, N//128
|
||||
// CUDA: Only Hopper GPUs
|
||||
_check_deepseek_support();
|
||||
// A: [M, K], B: [K, N] are FP8, scales are fp32
|
||||
// As: [M x K // 128], stride: [1, M]
|
||||
// Bs: [round_up(K // 128, 4) x N // 128], stride: [1, N // 128]
|
||||
TORCH_CHECK_VALUE(
|
||||
isFloat8Type(mat_a.scalar_type()) &&
|
||||
isFloat8Type(mat_b.scalar_type()),
|
||||
"mat_a and mat_b must be fp8 types, got: ", mat_a.scalar_type(), mat_b.scalar_type()
|
||||
);
|
||||
|
||||
int64_t M = mat_a.size(0);
|
||||
int64_t K = mat_a.size(1);
|
||||
int64_t N = mat_b.size(1);
|
||||
|
||||
// scale_a shape
|
||||
TORCH_CHECK_VALUE(
|
||||
scale_a.size(0) == M &&
|
||||
scale_a.size(1) == ceil_div<int64_t>(K, 128) &&
|
||||
scale_a.scalar_type() == kFloat,
|
||||
"scale_a must have shape ", M, " x ", ceil_div<int64_t>(K, 128), " Float elements, got ", scale_a.sizes()
|
||||
);
|
||||
// scale_a stride
|
||||
TORCH_CHECK_VALUE(
|
||||
scale_a.stride(0) == 1 &&
|
||||
(
|
||||
scale_a.stride(1) == M ||
|
||||
(
|
||||
scale_a.size(1) == 1 &&
|
||||
scale_a.stride(1) == 1
|
||||
)
|
||||
),
|
||||
"scale_a must have strides (1, ", M, "); got ", scale_b.strides()
|
||||
);
|
||||
// scale_b shape
|
||||
TORCH_CHECK_VALUE(
|
||||
scale_b.size(0) == round_up<int64_t>(ceil_div<int64_t>(K, 128), 4) &&
|
||||
scale_b.size(1) == ceil_div<int64_t>(N, 128) &&
|
||||
scale_b.scalar_type() == kFloat,
|
||||
"scale_b must have shape ", round_up<int64_t>(ceil_div<int64_t>(K, 128), 4), " x ", ceil_div<int64_t>(N, 128), " Float elements, got ", scale_b.sizes()
|
||||
);
|
||||
// scale_b stride
|
||||
TORCH_CHECK_VALUE(
|
||||
scale_b.stride(0) == 1 &&
|
||||
(
|
||||
scale_b.stride(1) == round_up<int64_t>(ceil_div<int64_t>(K, 128), 4) ||
|
||||
(
|
||||
scale_b.size(1) == 1 &&
|
||||
scale_b.stride(1) == 1
|
||||
)
|
||||
),
|
||||
"scale_b must have strides (1, ", round_up<int64_t>(ceil_div<int64_t>(K, 128), 4), "); got ", scale_b.strides()
|
||||
);
|
||||
TORCH_CHECK_VALUE(isFloat8Type(mat_a.scalar_type()) && isFloat8Type(mat_b.scalar_type()), "mat_a and mat_b must be fp8 types, got: ",
|
||||
mat_a.scalar_type(), mat_b.scalar_type());
|
||||
TORCH_CHECK_VALUE(scale_a.sizes()[0] == mat_a.sizes()[0] && scale_a.sizes()[1] == mat_a.sizes()[1] / 128 && scale_a.scalar_type() == kFloat,
|
||||
"scale_a must have shape ", mat_a.sizes()[0], " x ", mat_a.sizes()[1] / 128, " Float elements, got ", scale_a.sizes())
|
||||
TORCH_CHECK_VALUE(scale_b.sizes()[0] == mat_b.sizes()[0] / 128 && scale_b.sizes()[1] == mat_b.sizes()[1] / 128 && scale_b.scalar_type() == kFloat,
|
||||
"scale_b must have shape ", mat_b.sizes()[0] / 128, " x ", mat_b.sizes()[1] / 128, " Float elements, got ", scale_b.sizes())
|
||||
|
||||
auto scaling_choice_a = ScalingType::BlockWise1x128;
|
||||
auto scaling_choice_b = ScalingType::BlockWise128x128;
|
||||
|
||||
// Check scale strides (including stride=1 small cases)
|
||||
_check_deepseek_scale_stride(scale_a, mat_a, scaling_choice_a);
|
||||
_check_deepseek_scale_stride(scale_b.t(), mat_b.t(), scaling_choice_b);
|
||||
|
||||
_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
|
||||
|
||||
return out;
|
||||
@ -1105,47 +997,26 @@ _scaled_mxfp4_mxfp4(
|
||||
const std::optional<Tensor>& bias,
|
||||
const c10::ScalarType out_dtype,
|
||||
Tensor& out) {
|
||||
#if defined(_WIN32) || (!defined(USE_ROCM) && !defined(USE_FBGEMM_GENAI))
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "MXFP4 scaling supported on ROCM and CUDA+FBGEMM_GENAI only");
|
||||
#else
|
||||
#ifndef USE_ROCM
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "MXFP4 scaling supported on ROCM only");
|
||||
#endif
|
||||
// Restrictions:
|
||||
// A, B are FP4, scales are e8m0, A: shape K//32, B: K, N//32
|
||||
TORCH_CHECK_VALUE(mat_a.scalar_type() == at::kFloat4_e2m1fn_x2 && mat_b.scalar_type() == at::kFloat4_e2m1fn_x2, "mat_a and mat_b must be fp4 types, got: ",
|
||||
mat_a.scalar_type(), mat_b.scalar_type());
|
||||
|
||||
// Packed FP4 format means actual-K = 2 * reported-K -- adjust
|
||||
auto K_multiplier = 2;
|
||||
#ifdef USE_ROCM
|
||||
// AMD
|
||||
auto scale_a_elems = ceil_div<int64_t>(K_multiplier * mat_a.size(0), 32) * mat_a.size(1);
|
||||
auto scale_b_elems = ceil_div<int64_t>(K_multiplier * mat_b.size(1), 32) * mat_b.size(0);
|
||||
#else
|
||||
// NVIDIA
|
||||
auto scale_a_elems = round_up<int64_t>(mat_a.size(0), 128) * round_up<int64_t>(ceil_div<int64_t>(K_multiplier * mat_a.size(1), 32), 4);
|
||||
auto scale_b_elems = round_up<int64_t>(mat_b.size(1), 128) * round_up<int64_t>(ceil_div<int64_t>(K_multiplier * mat_b.size(0), 32), 4);
|
||||
#endif
|
||||
auto scale_a_elems = ceil_div<int64_t>(2 * mat_a.size(0), 32) * mat_a.size(1);
|
||||
auto scale_b_elems = ceil_div<int64_t>(2 * mat_b.size(1), 32) * mat_b.size(0);
|
||||
TORCH_CHECK_VALUE(scale_a_elems == scale_a.numel(),
|
||||
"For Blockwise scaling scale_a should have ", scale_a_elems, " elements, got: ", scale_a.numel());
|
||||
TORCH_CHECK_VALUE(scale_b_elems == scale_b.numel(),
|
||||
"For Blockwise scaling scale_b should have ", scale_b_elems, " elements, got: ", scale_b.numel());
|
||||
|
||||
#ifdef USE_ROCM
|
||||
// AMD
|
||||
TORCH_CHECK_VALUE(swizzle_a == SwizzleType::NO_SWIZZLE, "scale_a must not be swizzled (NO_SWIZZLE format)");
|
||||
TORCH_CHECK_VALUE(swizzle_b == SwizzleType::NO_SWIZZLE, "scale_b must not be swizzled (NO_SWIZZLE format)");
|
||||
#else
|
||||
// NVIDIA
|
||||
TORCH_CHECK_VALUE(swizzle_a == SwizzleType::SWIZZLE_32_4_4, "scale_a must be swizzled to SWIZZLE_32_4_4 format");
|
||||
TORCH_CHECK_VALUE(swizzle_b == SwizzleType::SWIZZLE_32_4_4, "scale_b must be swizzled to SWIZZLE_32_4_4 format");
|
||||
#endif
|
||||
|
||||
TORCH_CHECK_VALUE(scale_a.is_contiguous() && scale_b.is_contiguous(),
|
||||
"For Blockwise scaling both scales should be contiguous");
|
||||
|
||||
TORCH_CHECK_VALUE(out.scalar_type() == out_dtype, "expected out.scalar_type() to be ", out_dtype, ", but got ", out_dtype);
|
||||
|
||||
#ifdef USE_ROCM
|
||||
// AMD
|
||||
auto scaling_choice_a = ScalingType::BlockWise1x32;
|
||||
auto scaling_choice_b = ScalingType::BlockWise1x32;
|
||||
|
||||
@ -1160,30 +1031,11 @@ _scaled_mxfp4_mxfp4(
|
||||
TORCH_CHECK_VALUE(out.scalar_type() == ScalarType::BFloat16 ||
|
||||
out.scalar_type() == ScalarType::Half,
|
||||
"Block-wise scaling only supports BFloat16 or Half output types");
|
||||
#else
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "Block-wise scaling for Float8_e8m0fnu requires ROCm 7.0 or later");
|
||||
#endif
|
||||
|
||||
return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out);
|
||||
#else
|
||||
// NVIDIA
|
||||
// NOTE(slayton58): fbgemm_gpu::f4f4bf16 does *not* allow passing an output tensor,
|
||||
// but we have one we need to use. Two clear options are to copy into
|
||||
// our output (slow), or use a move-assignment-operator (faster).
|
||||
// However, the compiler can complain about the explicit move preventing
|
||||
// copy elision because the return from f4f4bf16 is a temporary object.
|
||||
// So we don't explicitly move, and trust the compiler here...
|
||||
// In the longer term this should be fixed on the FBGemm side.
|
||||
out = fbgemm_gpu::f4f4bf16(
|
||||
mat_a,
|
||||
mat_b.transpose(-2, -1),
|
||||
scale_a,
|
||||
scale_b,
|
||||
std::nullopt, /* global_scale */
|
||||
true /* use_mx */
|
||||
);
|
||||
|
||||
return out;
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
|
||||
Tensor&
@@ -1308,20 +1160,17 @@ _scaled_mm_cuda_v2_out(
mat_a.size(0), "x", mat_a.size(1), " and ", mat_b.size(0), "x", mat_b.size(1), ")");
}

// Handle fp4 packed-K dimension
int K_multiplier = (mat_a.scalar_type() == ScalarType::Float4_e2m1fn_x2) ? 2 : 1;

TORCH_CHECK_VALUE(!bias || bias->numel() == mat_b.sizes()[1], "Bias must be size ", mat_b.sizes()[1],
" but got ", bias->numel());
TORCH_CHECK_VALUE(
K_multiplier * mat_a.sizes()[1] % 16 == 0,
mat_a.sizes()[1] % 16 == 0,
"Expected trailing dimension of mat1 to be divisible by 16 ",
"but got mat1 shape: (",
mat_a.sizes()[0],
"x",
K_multiplier * mat_a.sizes()[1],
mat_a.sizes()[1],
").");
TORCH_CHECK_VALUE(K_multiplier * mat_b.sizes()[0] % 16 == 0 && mat_b.sizes()[1] % 16 == 0, "mat2 shape (", mat_b.sizes()[0], "x",
TORCH_CHECK_VALUE(mat_b.sizes()[0] % 16 == 0 && mat_b.sizes()[1] % 16 == 0, "mat2 shape (", mat_b.sizes()[0], "x",
mat_b.sizes()[1], ") must be divisible by 16");

// TODO(slayton): Existing checks, not sure if they should really be here.
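The `K_multiplier` in the hunk above accounts for the packed FP4 format (`Float4_e2m1fn_x2`), where each stored element holds two values, so the logical K is twice the reported trailing dimension. A quick sketch of the divisibility check under that assumption:

```python
def trailing_dim_ok(reported_k, packed_fp4):
    """Logical K must be divisible by 16; packed fp4 stores two values per element."""
    k_multiplier = 2 if packed_fp4 else 1
    return (k_multiplier * reported_k) % 16 == 0

print(trailing_dim_ok(64, packed_fp4=False))  # True: 64 % 16 == 0
print(trailing_dim_ok(8, packed_fp4=True))    # True: logical K = 2 * 8 = 16
print(trailing_dim_ok(8, packed_fp4=False))   # False: 8 % 16 != 0
```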
@ -157,10 +157,10 @@ bool onednn_strides_check(const Tensor& src) {
|
||||
return true;
|
||||
|
||||
dnnl_dims_t blocks = {0};
|
||||
std::array<int, DNNL_MAX_NDIMS> perm = {0};
|
||||
int perm[DNNL_MAX_NDIMS] = {0};
|
||||
for (int d = 0; d < md_ndims; ++d) {
|
||||
// no strides check needed for empty tensor
|
||||
if ((*md_padded_dims)[d] == 0)
|
||||
if (md_padded_dims[d] == nullptr)
|
||||
return true;
|
||||
|
||||
// no strides verification for runtime dims
|
||||
@ -178,15 +178,14 @@ bool onednn_strides_check(const Tensor& src) {
|
||||
|
||||
// A custom comparator to yield linear order on perm
|
||||
auto idx_sorter = [&](const int a, const int b) -> bool {
|
||||
if (strides[a] == strides[b] &&
|
||||
(*md_padded_dims)[a] == (*md_padded_dims)[b])
|
||||
if (strides[a] == strides[b] && md_padded_dims[a] == md_padded_dims[b])
|
||||
return a < b;
|
||||
else if (strides[a] == strides[b])
|
||||
return (*md_padded_dims)[a] < (*md_padded_dims)[b];
|
||||
return md_padded_dims[a] < md_padded_dims[b];
|
||||
else
|
||||
return strides[a] < strides[b];
|
||||
};
|
||||
std::sort(perm.begin(), perm.begin() + md_ndims, idx_sorter);
|
||||
std::sort(perm, perm + md_ndims, idx_sorter);
|
||||
|
||||
auto min_stride = block_size;
|
||||
for (int idx = 0; idx < md_ndims; ++idx) {
|
||||
@ -200,10 +199,9 @@ bool onednn_strides_check(const Tensor& src) {
|
||||
return false;
|
||||
|
||||
// update min_stride for next iteration
|
||||
const auto padded_dim = (*md_padded_dims)[d];
|
||||
const auto padded_dim = *md_padded_dims[d];
|
||||
min_stride = block_size * strides[d] * (padded_dim / blocks[d]);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
@ -370,7 +370,7 @@ static void nllnd_loss_backward_impl(Tensor& grad_input_arg,
|
||||
onValue:-1.0f
|
||||
offValue:0.0f
|
||||
name:nil];
|
||||
oneHotTensor = castMPSTensor(mpsGraph, oneHotTensor, [inputTensor dataType]);
|
||||
oneHotTensor = castMPSTensor(mpsGraph, oneHotTensor, inputTensor.dataType);
|
||||
if (isWeightsArrayValid) {
|
||||
oneHotTensor = [mpsGraph multiplicationWithPrimaryTensor:oneHotTensor
|
||||
secondaryTensor:weightTensor
|
||||
@ -705,7 +705,6 @@ static void smooth_l1_loss_template(const Tensor& input,
|
||||
TORCH_CHECK(beta >= 0, "smooth_l1_loss does not support negative values for beta.");
|
||||
TORCH_CHECK(input.is_mps());
|
||||
TORCH_CHECK(target.is_mps());
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(input.scalar_type() != kLong, "MPS doesn't know how to do square_i64");
|
||||
if ((input.numel() == 0) || (target.numel() == 0)) {
|
||||
reduction == Reduction::Mean ? output.fill_(std::numeric_limits<float>::quiet_NaN()) : output.zero_();
|
||||
return;
|
||||
@ -772,7 +771,7 @@ static void smooth_l1_loss_backward_impl(const Tensor& grad_output,
|
||||
MPSGraphTensor* targetTensor = mpsGraphRankedPlaceHolder(mpsGraph, target);
|
||||
MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output);
|
||||
|
||||
MPSGraphTensor* betaTensor = [mpsGraph constantWithScalar:beta dataType:[inputTensor dataType]];
|
||||
MPSGraphTensor* betaTensor = [mpsGraph constantWithScalar:beta dataType:MPSDataTypeFloat32];
|
||||
// xn - yn
|
||||
MPSGraphTensor* diffTensor = [mpsGraph subtractionWithPrimaryTensor:inputTensor
|
||||
secondaryTensor:targetTensor
|
||||
@ -798,8 +797,7 @@ static void smooth_l1_loss_backward_impl(const Tensor& grad_output,
|
||||
name:@"lossTensor"];
|
||||
MPSGraphTensor* outputTensor = lossTensor;
|
||||
if (reduction == Reduction::Mean) {
|
||||
MPSGraphTensor* numelTensor = [mpsGraph constantWithScalar:(double)input.numel()
|
||||
dataType:[lossTensor dataType]];
|
||||
MPSGraphTensor* numelTensor = [mpsGraph constantWithScalar:(double)input.numel() dataType:MPSDataTypeFloat32];
|
||||
outputTensor = [mpsGraph divisionWithPrimaryTensor:lossTensor secondaryTensor:numelTensor name:nil];
|
||||
}
|
||||
MPSGraphTensor* gradInputTensor = [mpsGraph multiplicationWithPrimaryTensor:outputTensor
|
||||
|
||||
@@ -1,63 +0,0 @@
#include <ATen/xpu/PeerToPeerAccess.h>
#include <ATen/xpu/XPUContext.h>

#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <c10/xpu/XPUCachingAllocator.h>

namespace at::xpu {

// p2pAccessEnabled_ is a flattened 2D matrix of size [num_devices x
// num_devices].
// Each element represents whether device[i] can access device[j]:
// 1 -> access allowed
// 0 -> access not allowed
// -1 -> unknown (not yet queried)
static std::vector<int8_t> p2pAccessEnabled_;

namespace detail {

// Initializes the peer-to-peer (P2P) access capability cache.
void init_p2p_access_cache(c10::DeviceIndex num_devices) {
// By default, each device can always access itself (diagonal entries = 1).
// For simplicity, all entries are initialized to -1 except the diagonal.
static bool once [[maybe_unused]] = [num_devices]() {
p2pAccessEnabled_.clear();
p2pAccessEnabled_.resize(num_devices * num_devices, -1);

for (const auto i : c10::irange(num_devices)) {
p2pAccessEnabled_[i * num_devices + i] = 1;
}
return true;
}();
}

} // namespace detail

bool get_p2p_access(c10::DeviceIndex dev, c10::DeviceIndex dev_to_access) {
at::globalContext().lazyInitDevice(c10::DeviceType::XPU);

check_device_index(dev);
check_device_index(dev_to_access);

auto& cache =
p2pAccessEnabled_[dev * c10::xpu::device_count() + dev_to_access];

if (cache != -1) {
return static_cast<bool>(cache);
}

// Query the hardware to determine if P2P access is supported
cache = static_cast<int8_t>(
c10::xpu::get_raw_device(dev).ext_oneapi_can_access_peer(
c10::xpu::get_raw_device(dev_to_access),
sycl::ext::oneapi::peer_access::access_supported));

if (cache) {
XPUCachingAllocator::enablePeerAccess(dev, dev_to_access);
}

return static_cast<bool>(cache);
}

} // namespace at::xpu
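The XPU peer-to-peer file above stores the N x N access matrix as a flat vector indexed by `i * num_devices + j`, with -1 meaning "not yet queried". A small Python sketch of that lazy caching pattern (the `can_access(i, j)` callable stands in for the SYCL `ext_oneapi_can_access_peer` query):

```python
def make_p2p_cache(num_devices):
    cache = [-1] * (num_devices * num_devices)   # -1: unknown, 0: denied, 1: allowed
    for i in range(num_devices):
        cache[i * num_devices + i] = 1           # every device can access itself
    return cache

def get_p2p_access(cache, num_devices, dev, peer, can_access):
    idx = dev * num_devices + peer
    if cache[idx] == -1:                         # query hardware only on first use
        cache[idx] = int(can_access(dev, peer))
    return bool(cache[idx])

cache = make_p2p_cache(2)
print(get_p2p_access(cache, 2, 0, 1, can_access=lambda a, b: True))  # True, now cached
```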
@@ -1,15 +0,0 @@
#pragma once

#include <c10/core/Device.h>
#include <c10/macros/Macros.h>

namespace at::xpu {
namespace detail {
void init_p2p_access_cache(c10::DeviceIndex num_devices);
} // namespace detail

TORCH_XPU_API bool get_p2p_access(
c10::DeviceIndex dev,
c10::DeviceIndex dev_to_access);

} // namespace at::xpu
@@ -1,4 +1,3 @@
#include <ATen/xpu/PeerToPeerAccess.h>
#include <ATen/xpu/PinnedMemoryAllocator.h>
#include <ATen/xpu/XPUContext.h>
#include <ATen/xpu/XPUDevice.h>
@@ -13,7 +12,6 @@ void XPUHooks::init() const {
C10_LOG_API_USAGE_ONCE("aten.init.xpu");
const auto device_count = c10::xpu::device_count_ensure_non_zero();
c10::xpu::XPUCachingAllocator::init(device_count);
at::xpu::detail::init_p2p_access_cache(device_count);
}

bool XPUHooks::hasXPU() const {
@@ -929,7 +929,6 @@ libtorch_python_core_sources = [
"torch/csrc/dynamo/guards.cpp",
"torch/csrc/dynamo/utils.cpp",
"torch/csrc/dynamo/init.cpp",
"torch/csrc/dynamo/stackref_bridge.c",
"torch/csrc/functorch/init.cpp",
"torch/csrc/fx/node.cpp",
"torch/csrc/mps/Module.cpp",
@ -21,20 +21,13 @@ using stream_set = ska::flat_hash_set<xpu::XPUStream>;
|
||||
struct Block;
|
||||
typedef bool (*Comparison)(const Block*, const Block*);
|
||||
bool BlockComparatorSize(const Block* a, const Block* b);
|
||||
bool BlockComparatorAddress(const Block* a, const Block* b);
|
||||
|
||||
struct BlockPool {
|
||||
BlockPool(bool small)
|
||||
: blocks(BlockComparatorSize),
|
||||
unmapped(BlockComparatorAddress),
|
||||
is_small(small) {}
|
||||
BlockPool(bool small) : blocks(BlockComparatorSize), is_small(small) {}
|
||||
std::set<Block*, Comparison> blocks;
|
||||
std::set<Block*, Comparison> unmapped;
|
||||
const bool is_small;
|
||||
};
|
||||
|
||||
struct ExpandableSegment;
|
||||
|
||||
struct Block {
|
||||
DeviceIndex device;
|
||||
sycl::queue* queue{nullptr}; // underlying queue of the allocation stream
|
||||
@ -44,11 +37,9 @@ struct Block {
|
||||
BlockPool* pool{nullptr}; // owning memory pool
|
||||
void* ptr{nullptr}; // memory address
|
||||
bool allocated{false}; // in-use flag
|
||||
bool mapped{true}; // True if this Block is backed by physical pages
|
||||
Block* prev{nullptr}; // prev block if split from a larger allocation
|
||||
Block* next{nullptr}; // next block if split from a larger allocation
|
||||
int event_count{0}; // number of outstanding XPU events
|
||||
ExpandableSegment* expandable_segment{nullptr}; // owning expandable segment
|
||||
|
||||
Block(
|
||||
DeviceIndex device,
|
||||
@ -75,20 +66,6 @@ struct Block {
|
||||
bool is_split() const {
|
||||
return (prev != nullptr) || (next != nullptr);
|
||||
}
|
||||
|
||||
// Inserts this block between two existing blocks with [before, this, after].
|
||||
void splice(Block* before, Block* after) {
|
||||
if (before) {
|
||||
TORCH_INTERNAL_ASSERT(before->next == after);
|
||||
before->next = this;
|
||||
}
|
||||
prev = before;
|
||||
if (after) {
|
||||
TORCH_INTERNAL_ASSERT(after->prev == before);
|
||||
after->prev = this;
|
||||
}
|
||||
next = after;
|
||||
}
|
||||
};
|
||||
|
||||
bool BlockComparatorSize(const Block* a, const Block* b) {
|
||||
@ -103,221 +80,6 @@ bool BlockComparatorSize(const Block* a, const Block* b) {
|
||||
reinterpret_cast<uintptr_t>(b->ptr);
|
||||
}
|
||||
|
||||
bool BlockComparatorAddress(const Block* a, const Block* b) {
|
||||
if (a->queue != b->queue) {
|
||||
return reinterpret_cast<uintptr_t>(a->queue) <
|
||||
reinterpret_cast<uintptr_t>(b->queue);
|
||||
}
|
||||
return reinterpret_cast<uintptr_t>(a->ptr) <
|
||||
reinterpret_cast<uintptr_t>(b->ptr);
|
||||
}
|
||||
|
||||
// Represents a contiguous virtual memory segment mapped for allocation.
|
||||
struct SegmentRange {
|
||||
SegmentRange(void* addr, size_t bytes)
|
||||
: ptr(static_cast<char*>(addr)), size(bytes) {}
|
||||
char* ptr; // Starting address of the mapped range.
|
||||
size_t size; // Size in bytes of the mapped range.
|
||||
};
|
||||
|
||||
struct ExpandableSegment {
|
||||
ExpandableSegment(
|
||||
c10::DeviceIndex device,
|
||||
std::optional<sycl::queue*> queue,
|
||||
size_t segment_size,
|
||||
std::vector<c10::DeviceIndex> peers)
|
||||
: device_(device),
|
||||
queue_(queue),
|
||||
// 2MB for small pool, 20MB for large pool
|
||||
segment_size_(segment_size),
|
||||
peers_(std::move(peers)) {
|
||||
const auto device_total =
|
||||
c10::xpu::get_raw_device(device)
|
||||
.get_info<sycl::info::device::global_mem_size>();
|
||||
// The extra 1/8 allows flexibility for remapping or moving pages within the
|
||||
// segment when unmapping earlier regions.
|
||||
constexpr float kVirtualMemOversubscriptFactor = 1.125f; // 1 + 1/8
|
||||
max_handles_ = numSegments(device_total * kVirtualMemOversubscriptFactor);
|
||||
ptr_ = sycl::ext::oneapi::experimental::reserve_virtual_mem(
|
||||
segment_size_ * max_handles_, xpu::get_device_context());
|
||||
}
|
||||
|
||||
C10_DISABLE_COPY_AND_ASSIGN(ExpandableSegment);
|
||||
ExpandableSegment(ExpandableSegment&&) = delete;
|
||||
ExpandableSegment& operator=(ExpandableSegment&&) = delete;
|
||||
|
||||
// Maps a virtual memory range to physical memory.
|
||||
SegmentRange map(SegmentRange range) {
|
||||
auto begin = segmentLeft(range.ptr);
|
||||
auto end = segmentRight(range.ptr + range.size);
|
||||
TORCH_INTERNAL_ASSERT(ptr() + begin * segment_size_ == range.ptr);
|
||||
if (begin == end) {
|
||||
return rangeFromHandles(begin, end);
|
||||
}
|
||||
|
||||
// Ensure handles_ vector is large enough to hold all segments.
|
||||
if (end > handles_.size()) {
|
||||
handles_.resize(end, std::nullopt);
|
||||
}
|
||||
|
||||
// Allocate and map physical memory for each segment.
|
||||
for (const auto i : c10::irange(begin, end)) {
|
||||
TORCH_INTERNAL_ASSERT(!handles_.at(i));
|
||||
try {
|
||||
// Allocate physical memory for each segment. Construct the physical_mem
|
||||
// in-place to avoid copies.
|
||||
handles_.at(i).emplace(
|
||||
xpu::get_raw_device(device_),
|
||||
xpu::get_device_context(),
|
||||
segment_size_);
|
||||
// Map the allocated physical memory into the virtual address space.
|
||||
handles_.at(i).value().map(
|
||||
ptr_ + i * segment_size_,
|
||||
segment_size_,
|
||||
sycl::ext::oneapi::experimental::address_access_mode::read_write);
|
||||
} catch (const sycl::exception& e) {
|
||||
// Allocation failure: typically sycl::errc::memory_allocation.
|
||||
// Mapping failure: typically sycl::errc::runtime (e.g., OOM due to
|
||||
// over-subscription).
|
||||
// Note: constructing physical_mem may over-subscribe device memory but
|
||||
// not immediately trigger OOM. The actual OOM can occur during map().
|
||||
// Roll back all segments allocated or mapped in this operation.
|
||||
handles_.at(i) = std::nullopt;
|
||||
for (const auto j : c10::irange(begin, i)) {
|
||||
sycl::ext::oneapi::experimental::unmap(
|
||||
reinterpret_cast<void*>(ptr_ + segment_size_ * j),
|
||||
segment_size_,
|
||||
xpu::get_device_context());
|
||||
handles_.at(j) = std::nullopt;
|
||||
}
|
||||
trimHandles();
|
||||
return rangeFromHandles(begin, begin);
|
||||
}
|
||||
}
|
||||
return rangeFromHandles(begin, end);
|
||||
}
|
||||
|
||||
// Unmap a virtual memory range from physical memory.
|
||||
SegmentRange unmap(SegmentRange range) {
|
||||
auto begin = segmentRight(range.ptr);
|
||||
auto end = segmentLeft(range.ptr + range.size);
|
||||
if (begin >= end) {
|
||||
return SegmentRange{range.ptr, 0};
|
||||
}
|
||||
unmapHandles(begin, end);
|
||||
return rangeFromHandles(begin, end);
|
||||
}
|
||||
|
||||
// Returns the base pointer of the virtual memory segment.
|
||||
char* ptr() const {
|
||||
// NOLINTNEXTLINE(performance-no-int-to-ptr)
|
||||
return reinterpret_cast<char*>(ptr_);
|
||||
}
|
||||
|
||||
// Returns the total size of the virtual memory segment.
|
||||
size_t size() const {
|
||||
return max_handles_ * segment_size_;
|
||||
}
|
||||
|
||||
~ExpandableSegment() {
|
||||
forEachAllocatedRange(
|
||||
[&](size_t begin, size_t end) { unmapHandles(begin, end); });
|
||||
sycl::ext::oneapi::experimental::free_virtual_mem(
|
||||
ptr_, segment_size_ * max_handles_, xpu::get_device_context());
|
||||
}
|
||||
|
||||
private:
|
||||
// Unmaps the physical memory handles in the range [begin, end) from the
|
||||
// segment.
|
||||
void unmapHandles(size_t begin, size_t end) {
|
||||
// Currently, we don't support IPC shared memory with expandable segments.
|
||||
TORCH_INTERNAL_ASSERT(queue_);
|
||||
// As explained in Note [Safe to Free Blocks on BlockPool], additional
|
||||
// synchronization is unnecessary here because the memory is already safe to
|
||||
// release.
|
||||
for (const auto i : c10::irange(begin, end)) {
|
||||
// Note: physical_mem's destructor does NOT automatically unmap any mapped
|
||||
// ranges. Users must explicitly call unmap on all ranges before
|
||||
// destroying the physical_mem object.
|
||||
sycl::ext::oneapi::experimental::unmap(
|
||||
reinterpret_cast<void*>(ptr_ + segment_size_ * i),
|
||||
segment_size_,
|
||||
xpu::get_device_context());
|
||||
// Here physical_mem object is being destructed.
|
||||
handles_.at(i) = std::nullopt;
|
||||
}
|
||||
trimHandles();
|
||||
}
|
||||
|
||||
// Remove trailing unused handles from the end of handles_.
|
||||
void trimHandles() {
|
||||
while (!handles_.empty() && !handles_.back()) {
|
||||
handles_.pop_back();
|
||||
}
|
||||
}
|
||||
|
||||
// Iterates over all contiguous ranges of allocated segments in `handles_`,
|
||||
// and invokes the provided function `fn(start, end)` for each range.
|
||||
// Each range is defined as a half-open interval [start, end).
|
||||
void forEachAllocatedRange(const std::function<void(size_t, size_t)>& fn) {
|
||||
size_t start = 0;
|
||||
for (const auto i : c10::irange(handles_.size())) {
|
||||
if (handles_.at(i) && (i == 0 || !handles_.at(i - 1))) {
|
||||
start = i;
|
||||
}
|
||||
if (handles_.at(i) && (i + 1 == handles_.size() || !handles_.at(i + 1))) {
|
||||
fn(start, i + 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Returns the number of full segments required to cover `size` bytes.
|
||||
// Rounds up to ensure partial segments are counted.
|
||||
size_t numSegments(size_t size) const {
|
||||
return (size + segment_size_ - 1) / segment_size_;
|
||||
}
|
||||
|
||||
// Returns the index of the segment that contains the pointer `p`,
|
||||
// relative to the base pointer `ptr_`. This is the *inclusive* lower bound
|
||||
// of the segment that includes `p`.
|
||||
size_t segmentLeft(char* p) const {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(p >= ptr() && p < ptr() + size());
|
||||
size_t offset = p - ptr();
|
||||
return offset / segment_size_;
|
||||
}
|
||||
|
||||
// Returns the index of the segment just *past* the one containing pointer
|
||||
// `p`, relative to the base pointer `ptr_`. This is the *exclusive* upper
|
||||
// bound, useful for [begin, end) style ranges.
|
||||
// If `p` lies exactly on a segment boundary, this is equal to segmentLeft(p).
|
||||
// Otherwise, it rounds up and returns segmentLeft(p) + 1.
|
||||
size_t segmentRight(char* p) const {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(p >= ptr() && p < ptr() + size());
|
||||
size_t offset = p - ptr();
|
||||
return numSegments(offset);
|
||||
}
|
||||
|
||||
// Constructs a SegmentRange spanning indices [start, end).
|
||||
SegmentRange rangeFromHandles(size_t begin, size_t end) {
|
||||
return SegmentRange(
|
||||
ptr() + segment_size_ * begin, segment_size_ * (end - begin));
|
||||
}
|
||||
|
||||
c10::DeviceIndex device_{-1};
|
||||
std::optional<sycl::queue*> queue_;
|
||||
// Virtual memory address used for reservation.
|
||||
uintptr_t ptr_{0};
|
||||
// Size of each segment in bytes.
|
||||
size_t segment_size_{0};
|
||||
// Maximum number of segments that can be allocated in this segment.
|
||||
size_t max_handles_{0};
|
||||
// Physical memory handles for the segments.
|
||||
std::vector<std::optional<sycl::ext::oneapi::experimental::physical_mem>>
|
||||
handles_{};
|
||||
// Peer devices on which this memory could be accessible, reserved.
|
||||
std::vector<c10::DeviceIndex> peers_{};
|
||||
};
|
||||
|
||||
struct AllocParams {
|
||||
AllocParams(
|
||||
DeviceIndex device,
|
||||
@ -363,12 +125,10 @@ class DeviceCachingAllocator {
|
||||
DeviceIndex device_index;
|
||||
size_t allowed_memory_maximum = 0;
|
||||
bool set_fraction = false;
|
||||
std::vector<ExpandableSegment*> expandable_segments;
|
||||
std::vector<c10::DeviceIndex> devices_with_peer_access; // reserved
|
||||
|
||||
size_t try_merge_blocks(Block* dst, Block* src, BlockPool& pool) {
|
||||
if (!src || src->allocated || src->event_count > 0 ||
|
||||
!src->stream_uses.empty() || dst->mapped != src->mapped) {
|
||||
!src->stream_uses.empty()) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@ -387,8 +147,7 @@ class DeviceCachingAllocator {
|
||||
}
|
||||
const size_t subsumed_size = src->size;
|
||||
dst->size += subsumed_size;
|
||||
auto erased =
|
||||
src->mapped ? pool.blocks.erase(src) : pool.unmapped.erase(src);
|
||||
auto erased = pool.blocks.erase(src);
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(erased == 1);
|
||||
delete src;
|
||||
|
||||
@ -471,175 +230,12 @@ class DeviceCachingAllocator {
|
||||
}
|
||||
}

  // Finds the first (lowest-address) block in any segment that has sufficient
  // contiguous free virtual address space to satisfy `size`. The available
  // space may span multiple adjacent blocks, which can include both free and
  // unmapped segments.
  Block* find_expandable_block(
      c10::DeviceIndex device,
      sycl::queue* queue,
      BlockPool* pool,
      size_t size) {
    Block key(device, queue, 0);

    auto allocatable = [](Block* b) {
      return b && !b->allocated && b->event_count == 0 &&
          b->stream_uses.empty();
    };
    auto has_available_address_space = [&](Block* b) {
      size_t bytes = 0;
      while (bytes < size && allocatable(b)) {
        bytes += b->size;
        b = b->next;
      }
      return bytes >= size;
    };
    for (auto it = pool->unmapped.lower_bound(&key);
         it != pool->unmapped.end() && (*it)->queue == queue;
         ++it) {
      Block* c = *it;
      // The unmapped block might have a free mapped block right before it.
      // By starting from the previous block, we can use both:
      // [Free Mapped Block] + [Unmapped Block] = More contiguous space
      if (allocatable(c->prev)) {
        c = c->prev;
      }
      if (has_available_address_space(c)) {
        return c;
      }
    }
    auto segment_size = pool->is_small ? kSmallBuffer : kLargeBuffer;
    expandable_segments.emplace_back(new ExpandableSegment(
        device, queue, segment_size, devices_with_peer_access));

    ExpandableSegment* es = expandable_segments.back();
    Block* candidate = new Block(device, queue, es->size(), pool, es->ptr());
    candidate->mapped = false;
    candidate->expandable_segment = es;
    pool->unmapped.insert(candidate);
    return candidate;
  }

  bool map_block(Block* to_map, size_t size) {
    TORCH_INTERNAL_ASSERT(!to_map->mapped && size <= to_map->size);
    auto mapped_range =
        to_map->expandable_segment->map(SegmentRange{to_map->ptr, size});
    // Failed to map the memory
    if (mapped_range.size == 0) {
      return false;
    }
    TORCH_INTERNAL_ASSERT(
        mapped_range.ptr == to_map->ptr && mapped_range.size >= size);

    BlockPool& pool = *to_map->pool;
    pool.unmapped.erase(to_map);
    to_map->mapped = true;

    if (mapped_range.size < to_map->size) {
      // to_map -> remaining -> to_map->next(?)
      Block* remaining = new Block(
          to_map->device,
          to_map->queue,
          to_map->size - mapped_range.size,
          &pool,
          static_cast<char*>(to_map->ptr) + mapped_range.size);
      remaining->mapped = false;
      remaining->expandable_segment = to_map->expandable_segment;
      remaining->splice(to_map, to_map->next);
      pool.unmapped.insert(remaining);
      to_map->size = mapped_range.size;
    }

    try_merge_blocks(to_map, to_map->prev, pool);
    try_merge_blocks(to_map, to_map->next, pool);

    pool.blocks.insert(to_map);

    StatTypes stat_types = get_stat_types_for_pool(*to_map->pool);
    for_each_selected_stat_type(stat_types, [&](size_t stat_type) {
      stats.reserved_bytes[stat_type].increase(mapped_range.size);
    });

    return true;
  }

  Block* try_allocate_expandable_block(
      c10::DeviceIndex device,
      sycl::queue* queue,
      BlockPool* pool,
      size_t size) {
    // Candidate points to the start of a chain of contiguous blocks with
    // sufficient virtual address space (>= size). The chain may consist of:
    // Case 1: [Unmapped Block] -> null
    // Case 2: [Unmapped Block] -> [Free Mapped Block]
    // Case 3: [Free Mapped Block] -> [Unmapped Block]
    Block* candidate = find_expandable_block(device, queue, pool, size);

    // Map the first block if it is unmapped (Cases 1 & 2); use std::min to
    // avoid over-mapping.
    if (!candidate->mapped &&
        !map_block(candidate, std::min(candidate->size, size))) {
      return nullptr;
    }
    TORCH_INTERNAL_ASSERT(candidate->mapped);

    // Map additional blocks until we have enough contiguous space (Case 3).
    // Each map_block() call merges newly mapped blocks with adjacent free
    // blocks.
    while (candidate->size < size) {
      auto remaining = size - candidate->size;
      auto new_candidate = candidate->next;
      // Map only what we need from the `new_candidate` block.
      if (!map_block(new_candidate, std::min(remaining, new_candidate->size))) {
        return nullptr;
      }
      candidate = new_candidate;
    }

    // Remove from the free pool; the block will be marked as `allocated` in
    // alloc_found_block().
    pool->blocks.erase(candidate);
    return candidate;
  }
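
A minimal Python sketch of the mapping loop above, assuming a hypothetical map_block(block, nbytes) callback that merges the newly mapped space into the successor the way try_merge_blocks does in the C++ code; this is an editor's illustration, not code from the diff:

from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class FakeBlock:
    size: int
    mapped: bool
    next: Optional["FakeBlock"] = None


def grow_until(
    candidate: FakeBlock,
    size: int,
    map_block: Callable[[FakeBlock, int], bool],
) -> Optional[FakeBlock]:
    # Map the first block if it is still unmapped (Cases 1 and 2).
    if not candidate.mapped and not map_block(candidate, min(candidate.size, size)):
        return None
    # Keep mapping successors until the chain covers `size` (Case 3);
    # map_block is assumed to fold the mapped space into the successor.
    while candidate.size < size:
        nxt = candidate.next
        if nxt is None or not map_block(nxt, min(size - candidate.size, nxt.size)):
            return None
        candidate = nxt
    return candidate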

  bool get_free_block(AllocParams& p) {
    BlockPool& pool = *p.pool;
    auto it = pool.blocks.lower_bound(&p.search_key);
    if (it == pool.blocks.end() || (*it)->queue != p.queue()) {
      return false;
    }
    if ((*it)->expandable_segment) {
      if (AcceleratorAllocatorConfig::use_expandable_segments()) {
        // When expandable segments are enabled, consider both the current block
        // and any immediately adjacent unmapped region as a single expandable
        // area. For "best fit" allocation, we use the total expandable size
        // instead of just the block's current size, so that blocks which can
        // grow into a larger contiguous range are preferred.
        auto expandable_size = [](Block* b) {
          // b->next may belong to pool.unmapped (reserved but not mapped)
          return b->size + (b->next && !b->next->mapped ? b->next->size : 0);
        };
        auto next = it;
        next++;
        // Looks for the best fit block with expandable size.
        while ((*it)->expandable_segment && next != pool.blocks.end() &&
               (*next)->queue == p.queue() &&
               expandable_size(*next) < expandable_size(*it)) {
          it = next++;
        }
      } else {
        // Expandable segments were previously enabled, but are now disabled
        // (e.g. to avoid IPC issues). Skip any expandable blocks and only
        // find from regular non-expandable segments.
        do {
          it++;
        } while (it != pool.blocks.end() && (*it)->expandable_segment &&
                 (*it)->queue == p.queue());
        if (it == pool.blocks.end() || (*it)->queue != p.queue()) {
          return false;
        }
      }
    }
    p.block = *it;
    pool.blocks.erase(it);
    return true;
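
The best-fit rule in get_free_block() can be sketched in Python as follows; the dict-based blocks and the pick_best_fit helper are hypothetical, used only to show how an adjacent unmapped region biases the choice:

def pick_best_fit(blocks, request):
    # Effective size counts an immediately adjacent unmapped region the block
    # could grow into, mirroring expandable_size() above.
    def expandable_size(b):
        nxt = b.get("next")
        return b["size"] + (nxt["size"] if nxt and not nxt["mapped"] else 0)

    candidates = [b for b in blocks if expandable_size(b) >= request]
    return min(candidates, key=expandable_size, default=None)


# Prefers the 4+8 expandable block over the standalone 16-byte block.
print(pick_best_fit(
    [{"size": 4, "mapped": True, "next": {"size": 8, "mapped": False}},
     {"size": 16, "mapped": True, "next": None}],
    10))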
@@ -656,10 +252,6 @@ class DeviceCachingAllocator {
            size >
        allowed_memory_maximum) {
      return false;
    } else if (AcceleratorAllocatorConfig::use_expandable_segments()) {
      p.block =
          try_allocate_expandable_block(device, p.queue(), p.pool, p.size());
      return bool(p.block);
    }
    void* ptr = sycl::aligned_alloc_device(
        kDeviceAlignment,
@@ -673,7 +265,6 @@ class DeviceCachingAllocator {
    for_each_selected_stat_type(p.stat_types, [&](size_t stat_type) {
      stats.reserved_bytes[stat_type].increase(size);
    });
    TORCH_INTERNAL_ASSERT(p.block != nullptr && p.block->ptr != nullptr);
    return true;
  }

@@ -692,27 +283,6 @@ class DeviceCachingAllocator {
    xpu_events.clear();
  }

  void release_expandable_segment(Block* block) {
    // See Note [Safe to Free Blocks on BlockPool]; additional synchronization
    // is unnecessary here because this function is only called by
    // release_cached_blocks().
    TORCH_INTERNAL_ASSERT(
        block->size == block->expandable_segment->size(),
        "block disagrees with segment");
    TORCH_INTERNAL_ASSERT(!block->mapped);

    auto it = std::find(
        expandable_segments.begin(),
        expandable_segments.end(),
        block->expandable_segment);
    TORCH_INTERNAL_ASSERT(it != expandable_segments.end());

    expandable_segments.erase(it);
    block->pool->unmapped.erase(block);
    delete block->expandable_segment;
    delete block;
  }

  void release_block(Block* block) {
    /*
     * Note [Safe to Free Blocks on BlockPool]
@@ -723,7 +293,6 @@ class DeviceCachingAllocator {
     * We have to do a device-level synchronization before freeing these blocks
     * to guarantee that all kernels accessing the blocks have finished.
     */
    TORCH_INTERNAL_ASSERT(!block->expandable_segment);
    sycl::free(block->ptr, xpu::get_device_context());
    auto* pool = block->pool;
    pool->blocks.erase(block);
@@ -736,78 +305,13 @@ class DeviceCachingAllocator {
    delete block;
  }

  void unmap_block(Block* block) {
    auto unmapped =
        block->expandable_segment->unmap(SegmentRange{block->ptr, block->size});
    if (unmapped.size == 0) {
      return;
    }
    block->pool->blocks.erase(block);

    ptrdiff_t before_size = unmapped.ptr - static_cast<char*>(block->ptr);
    if (before_size > 0) {
      // If the actual unmapped region starts after block->ptr due to alignment,
      // the region before unmapped.ptr is still mapped.
      // [Prev Block?] -> [Before Block] -> [Unmapped Block]
      Block* before_free = new Block(
          block->device, block->queue, before_size, block->pool, block->ptr);
      before_free->expandable_segment = block->expandable_segment;
      before_free->splice(block->prev, block);
      block->pool->blocks.insert(before_free);
    }

    auto after_size = block->size - (before_size + unmapped.size);
    if (after_size > 0) {
      // If the actual unmapped region ends before block->ptr + block->size,
      // the region after (unmapped.ptr + unmapped.size) is still mapped.
      // [Unmapped Block] -> [After Block] -> [Next Block?]
      Block* after_free = new Block(
          block->device,
          block->queue,
          after_size,
          block->pool,
          unmapped.ptr + unmapped.size);
      after_free->expandable_segment = block->expandable_segment;
      after_free->splice(block, block->next);
      block->pool->blocks.insert(after_free);
    }

    // [Before Mapped Block?] -> [Unmapped Block] -> [After Mapped Block?]
    block->ptr = unmapped.ptr;
    block->size = unmapped.size;
    block->mapped = false;

    try_merge_blocks(block, block->prev, *block->pool);
    try_merge_blocks(block, block->next, *block->pool);
    block->pool->unmapped.insert(block);

    StatTypes stat_types = get_stat_types_for_pool(*block->pool);
    for_each_selected_stat_type(stat_types, [&](size_t stat_type) {
      stats.reserved_bytes[stat_type].decrease(unmapped.size);
    });
  }

  void release_blocks(BlockPool& pool) {
    std::vector<Block*> to_unmap;
    // Frees all non-split blocks in the given pool.
    auto it = pool.blocks.begin();
    while (it != pool.blocks.end()) {
      Block* block = *it;
      ++it;
      if (block->expandable_segment) {
        // unmap_block() modifies the free pool, so collect items to free first
        // to avoid iterator invalidation.
        to_unmap.push_back(block);
      } else if (!block->prev && !block->next) {
        release_block(block);
      }
    }
    for (Block* block : to_unmap) {
      unmap_block(block);
      // After unmap_block(), expandable segment blocks with no neighbors are
      // also released.
      if (!block->prev && !block->next) {
        release_expandable_segment(block);
        release_block(block);
      }
    }
  }
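
The collect-then-mutate pattern used by release_blocks() is the usual way to avoid invalidating an iterator while freeing; a hypothetical Python rendering of the same shape:

def release_all(pool_blocks, is_expandable, release, unmap):
    to_unmap = []
    for block in list(pool_blocks):   # iterate over a snapshot of the pool
        if is_expandable(block):
            to_unmap.append(block)    # defer: unmap() mutates pool_blocks
        else:
            release(block)
    for block in to_unmap:
        unmap(block)                  # safe now; the snapshot is no longer used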
@@ -824,8 +328,7 @@ class DeviceCachingAllocator {

  bool should_split(const Block* block, size_t size) {
    size_t remaining = block->size - size;
    if (block->pool->is_small ||
        AcceleratorAllocatorConfig::use_expandable_segments()) {
    if (block->pool->is_small) {
      return remaining >= kMinBlockSize;
    } else {
      return remaining > kSmallSize;
@@ -858,7 +361,6 @@ class DeviceCachingAllocator {
      remaining = block;

      block = new Block(device, queue, size, pool, block->ptr);
      block->expandable_segment = remaining->expandable_segment;
      block->prev = remaining->prev;
      if (block->prev) {
        block->prev->next = block;
@@ -1097,15 +599,6 @@ class XPUAllocator : public DeviceAllocator {
    return block;
  }

  void assertValidDevice(DeviceIndex device) {
    const auto device_num = device_allocators.size();
    TORCH_CHECK(
        0 <= device && device < static_cast<int64_t>(device_num),
        "Invalid device argument ",
        device,
        ": did you call init?");
  }

 public:
  std::vector<std::unique_ptr<DeviceCachingAllocator>> device_allocators;

@@ -1218,6 +711,15 @@ class XPUAllocator : public DeviceAllocator {
    xpu::getCurrentXPUStream().queue().memcpy(dest, src, count);
  }

  void assertValidDevice(DeviceIndex device) {
    const auto device_num = device_allocators.size();
    TORCH_CHECK(
        0 <= device && device < static_cast<int64_t>(device_num),
        "Invalid device argument ",
        device,
        ": did you call init?");
  }

  DeviceStats getDeviceStats(DeviceIndex device) override {
    assertValidDevice(device);
    return device_allocators[device]->getStats();
@@ -1233,13 +735,6 @@ class XPUAllocator : public DeviceAllocator {
    device_allocators[device]->resetAccumulatedStats();
  }

  void enablePeerAccess(c10::DeviceIndex dev, c10::DeviceIndex dev_to_access) {
    assertValidDevice(dev);
    assertValidDevice(dev_to_access);
    c10::xpu::get_raw_device(dev).ext_oneapi_enable_peer_access(
        c10::xpu::get_raw_device(dev_to_access));
  }

  double getMemoryFraction(DeviceIndex device) {
    assertValidDevice(device);
    return device_allocators[device]->getMemoryFraction();
@@ -1298,10 +793,6 @@ void recordStream(const DataPtr& dataPtr, XPUStream stream) {
  return allocator.recordStream(dataPtr, stream);
}

void enablePeerAccess(c10::DeviceIndex dev, c10::DeviceIndex dev_to_access) {
  return allocator.enablePeerAccess(dev, dev_to_access);
}

double getMemoryFraction(DeviceIndex device) {
  return allocator.getMemoryFraction(device);
}

@@ -25,10 +25,6 @@ C10_XPU_API void raw_delete(void* ptr);

C10_XPU_API void recordStream(const DataPtr& dataPtr, XPUStream stream);

C10_XPU_API void enablePeerAccess(
    c10::DeviceIndex dev,
    c10::DeviceIndex dev_to_access);

C10_XPU_API double getMemoryFraction(DeviceIndex device);

C10_XPU_API void setMemoryFraction(double fraction, DeviceIndex device);
34  setup.py
@@ -630,37 +630,6 @@ def mirror_files_into_torchgen() -> None:
        raise RuntimeError("Check the file paths in `mirror_files_into_torchgen()`")


def mirror_inductor_external_kernels() -> None:
    """
    Copy external kernels into Inductor so they are importable.
    """
    paths = [
        (
            CWD / "torch/_inductor/kernel/vendored_templates/cutedsl_grouped_gemm.py",
            CWD
            / "third_party/cutlass/examples/python/CuTeDSL/blackwell/grouped_gemm.py",
        ),
    ]
    for new_path, orig_path in paths:
        # Create the dirs involved in new_path if they don't exist
        if not new_path.exists():
            new_path.parent.mkdir(parents=True, exist_ok=True)

        # Copy the files from the orig location to the new location
        if orig_path.is_file():
            shutil.copyfile(orig_path, new_path)
            continue
        if orig_path.is_dir():
            if new_path.exists():
                # copytree fails if the tree exists already, so remove it.
                shutil.rmtree(new_path)
            shutil.copytree(orig_path, new_path)
            continue
        raise RuntimeError(
            "Check the file paths in `mirror_inductor_external_kernels()`"
        )


# ATTENTION: THIS IS AI SLOP
def extract_variant_from_version(version: str) -> str:
    """Extract variant from version string, defaulting to 'cpu'."""
@@ -1647,8 +1616,6 @@ def main() -> None:
    if RUN_BUILD_DEPS:
        build_deps()

    mirror_inductor_external_kernels()

    (
        ext_modules,
        cmdclass,
@@ -1682,7 +1649,6 @@ def main() -> None:
            "_inductor/codegen/aoti_runtime/*.cpp",
            "_inductor/script.ld",
            "_inductor/kernel/flex/templates/*.jinja",
            "_inductor/kernel/templates/*.jinja",
            "_export/serde/*.yaml",
            "_export/serde/*.thrift",
            "share/cmake/ATen/*.cmake",
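
The copy logic above follows a common mirror pattern: create the destination's parents, then copy either a single file or a whole tree. A stand-alone sketch of that pattern (the mirror() helper is hypothetical, not part of setup.py):

import shutil
from pathlib import Path


def mirror(orig_path: Path, new_path: Path) -> None:
    new_path.parent.mkdir(parents=True, exist_ok=True)
    if orig_path.is_file():
        shutil.copyfile(orig_path, new_path)
    elif orig_path.is_dir():
        if new_path.exists():
            shutil.rmtree(new_path)   # copytree refuses to overwrite a tree
        shutil.copytree(orig_path, new_path)
    else:
        raise RuntimeError(f"{orig_path} is neither a file nor a directory")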
@@ -256,25 +256,23 @@ class TestSDPA(NNTestCase):
        )
        rand_upward_privateuse1 = rand_upward.to("openreg")
        grad_input_mask = [True, True, True, True]
        _grad_q, _grad_k, _grad_v, _grad_attn_mask = (
            torch.ops.aten._scaled_dot_product_fused_attention_overrideable_backward(
                rand_upward_privateuse1,
                q_privateuse1,
                k_privateuse1,
                v_privateuse1,
                attn_mask_privateuse1,
                grad_input_mask,
                output,
                logsumexp,
                cum_seq_q,
                cum_seq_k,
                max_q,
                max_k,
                dropout_p=0.0,
                is_causal=False,
                philox_seed=philox_seed,
                philox_offset=philox_offset,
            )
        torch.ops.aten._scaled_dot_product_fused_attention_overrideable_backward(
            rand_upward_privateuse1,
            q_privateuse1,
            k_privateuse1,
            v_privateuse1,
            attn_mask_privateuse1,
            grad_input_mask,
            output,
            logsumexp,
            cum_seq_q,
            cum_seq_k,
            max_q,
            max_k,
            dropout_p=0.0,
            is_causal=False,
            philox_seed=philox_seed,
            philox_offset=philox_offset,
        )
@@ -392,11 +392,11 @@ class ComposabilityTest(MultiProcessTestCase):
        replicate_size = self.world_size // (pp_size)
        device_mesh = init_device_mesh(
            device_type,
            mesh_shape=(replicate_size, pp_size),
            mesh_dim_names=("replicate", "pp"),
            mesh_shape=(replicate_size, 1, pp_size),
            mesh_dim_names=("replicate", "shard", "pp"),
        )
        torch.manual_seed(42)
        dp_mesh = device_mesh["replicate"]
        dp_mesh = device_mesh["replicate", "shard"]
        pp_mesh = device_mesh["pp"]
        pp_group = device_mesh["pp"].get_group()

@@ -416,13 +416,15 @@ class ComposabilityTest(MultiProcessTestCase):
                param_dtype=MixedPrecisionParam,
                reduce_dtype=torch.float32,
            )
            replicate_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
            replicate_config = {"mp_policy": mp_policy}
            for layer_id in range(len(partial_model)):
                replicate(
                    partial_model[layer_id],
                    device_mesh=dp_mesh,
                    **replicate_config,
                    reshard_after_forward=False,
                )
            dp_model = replicate(partial_model, **replicate_config)
            dp_model = replicate(partial_model, device_mesh=dp_mesh, **replicate_config)
            return dp_model

        # Apply same precision to reference model (without replicate)
@@ -580,11 +582,11 @@ class ComposabilityTest(MultiProcessTestCase):
        replicate_size = self.world_size // (pp_size)
        device_mesh = init_device_mesh(
            device_type,
            mesh_shape=(replicate_size, pp_size),
            mesh_dim_names=("replicate", "pp"),
            mesh_shape=(replicate_size, 1, pp_size),
            mesh_dim_names=("replicate", "shard", "pp"),
        )
        torch.manual_seed(42)
        dp_mesh = device_mesh["replicate"]
        dp_mesh = device_mesh["replicate", "shard"]
        pp_mesh = device_mesh["pp"]
        pp_group = device_mesh["pp"].get_group()
        dp_group = device_mesh["replicate"].get_group()
@@ -646,9 +648,10 @@ class ComposabilityTest(MultiProcessTestCase):
            for layer_id in range(len(partial_model)):
                replicate(
                    partial_model[layer_id],
                    mesh=dp_mesh,
                    device_mesh=dp_mesh,
                    reshard_after_forward=False,
                )
            dp_model = replicate(partial_model, mesh=dp_mesh)
            dp_model = replicate(partial_model, device_mesh=dp_mesh)
            return dp_model

        def pipelined_models_parameters(start_layer, model):
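
The test changes above move from a 2-D mesh to a 3-D mesh with an explicit size-1 "shard" dimension, and pass a 2-D submesh to replicate() via device_mesh=. A minimal sketch of that convention using the public init_device_mesh API (the build_meshes helper is hypothetical and assumes an initialized default process group):

from torch.distributed.device_mesh import init_device_mesh


def build_meshes(device_type: str, world_size: int, pp_size: int):
    replicate_size = world_size // pp_size
    mesh = init_device_mesh(
        device_type,
        (replicate_size, 1, pp_size),
        mesh_dim_names=("replicate", "shard", "pp"),
    )
    dp_mesh = mesh["replicate", "shard"]  # 2-D submesh handed to replicate()
    pp_mesh = mesh["pp"]
    return dp_mesh, pp_mesh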
@ -3,7 +3,7 @@
|
||||
import copy
|
||||
import dataclasses
|
||||
import functools
|
||||
from typing import Optional
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@ -14,6 +14,7 @@ from torch.distributed.fsdp import MixedPrecisionPolicy
|
||||
from torch.distributed.fsdp._fully_shard._fsdp_collectives import (
|
||||
_get_gradient_divide_factors,
|
||||
)
|
||||
from torch.distributed.tensor import Shard
|
||||
from torch.testing._internal.common_distributed import (
|
||||
requires_nccl_version,
|
||||
SaveForwardInputsModel,
|
||||
@ -45,20 +46,35 @@ class TestReplicateMixedPrecisionTraining(FSDPTest):
|
||||
|
||||
def _init_models_and_optims(
|
||||
self,
|
||||
reshard_after_forward: Union[bool, int],
|
||||
param_dtype: Optional[torch.dtype],
|
||||
reduce_dtype: Optional[torch.dtype],
|
||||
use_shard_placement_fn,
|
||||
):
|
||||
torch.manual_seed(42)
|
||||
model = nn.Sequential(*[MLP(16, torch.device("cpu")) for _ in range(3)])
|
||||
ref_model = copy.deepcopy(model).to(device_type)
|
||||
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
|
||||
|
||||
def _shard_placement_fn(param: nn.Parameter) -> Optional[Shard]:
|
||||
largest_dim = -1
|
||||
largest_dim_size = -1
|
||||
for dim, dim_size in enumerate(param.shape):
|
||||
if dim_size > largest_dim_size:
|
||||
largest_dim = dim
|
||||
largest_dim_size = dim_size
|
||||
assert largest_dim >= 0, f"{param.shape}"
|
||||
return Shard(largest_dim)
|
||||
|
||||
mp_policy = MixedPrecisionPolicy(
|
||||
param_dtype=param_dtype, reduce_dtype=reduce_dtype
|
||||
)
|
||||
shard_placement_fn = _shard_placement_fn if use_shard_placement_fn else None
|
||||
replicate_fn = functools.partial(
|
||||
replicate,
|
||||
reshard_after_forward=reshard_after_forward,
|
||||
mp_policy=mp_policy,
|
||||
shard_placement_fn=shard_placement_fn,
|
||||
)
|
||||
for mlp in model:
|
||||
replicate_fn(mlp)
|
||||
@ -66,13 +82,27 @@ class TestReplicateMixedPrecisionTraining(FSDPTest):
|
||||
optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=True)
|
||||
return ref_model, ref_optim, model, optim
|
||||
|
||||
def _get_use_shard_placement_fn_vals_for_bf16_reduce(self):
|
||||
use_shard_placement_fn_vals = [False]
|
||||
if self.world_size == 2:
|
||||
# For world size >2, gradient elements get reduced in different
|
||||
# orders for the baseline vs. dim-1 sharding, leading to numeric
|
||||
# differences for bf16 reduction, so only test world size 2.
|
||||
use_shard_placement_fn_vals.append(True)
|
||||
return use_shard_placement_fn_vals
|
||||
|
||||
@skipIfRocmVersionLessThan((7, 0))
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@requires_nccl_version((2, 10), "Need NCCL 2.10+ for bf16 collectives")
|
||||
def test_compute_dtype(self):
|
||||
use_shard_placement_fn_vals = (
|
||||
self._get_use_shard_placement_fn_vals_for_bf16_reduce()
|
||||
)
|
||||
self.run_subtests(
|
||||
{
|
||||
"param_dtype": [torch.bfloat16, torch.float16],
|
||||
"reshard_after_forward": [False, True],
|
||||
"use_shard_placement_fn": use_shard_placement_fn_vals,
|
||||
},
|
||||
self._test_compute_dtype,
|
||||
)
|
||||
@ -80,10 +110,14 @@ class TestReplicateMixedPrecisionTraining(FSDPTest):
|
||||
def _test_compute_dtype(
|
||||
self,
|
||||
param_dtype: torch.dtype,
|
||||
reshard_after_forward: Union[bool, int],
|
||||
use_shard_placement_fn: bool,
|
||||
):
|
||||
ref_model, ref_optim, model, optim = self._init_models_and_optims(
|
||||
reshard_after_forward,
|
||||
param_dtype=param_dtype,
|
||||
reduce_dtype=None,
|
||||
use_shard_placement_fn=use_shard_placement_fn,
|
||||
)
|
||||
ref_model_bf16 = copy.deepcopy(ref_model).to(param_dtype)
|
||||
orig_reduce_scatter = dist.reduce_scatter_tensor
|
||||
@ -141,14 +175,39 @@ class TestReplicateMixedPrecisionTraining(FSDPTest):
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@requires_nccl_version((2, 10), "Need NCCL 2.10+ for bf16 collectives")
|
||||
def test_reduce_dtype(self):
|
||||
self._test_reduce_dtype_fp32_reduce()
|
||||
self._test_reduce_dtype_bf16_reduce()
|
||||
self.run_subtests(
|
||||
{
|
||||
"reshard_after_forward": [False, True],
|
||||
"use_shard_placement_fn": [False, True],
|
||||
},
|
||||
self._test_reduce_dtype_fp32_reduce,
|
||||
)
|
||||
use_shard_placement_fn_vals = (
|
||||
self._get_use_shard_placement_fn_vals_for_bf16_reduce()
|
||||
)
|
||||
self.run_subtests(
|
||||
{
|
||||
"reshard_after_forward": [False, True],
|
||||
"use_shard_placement_fn": use_shard_placement_fn_vals,
|
||||
},
|
||||
self._test_reduce_dtype_bf16_reduce,
|
||||
)
|
||||
|
||||
def _test_reduce_dtype_fp32_reduce(self):
|
||||
def _test_reduce_dtype_fp32_reduce(
|
||||
self, reshard_after_forward: Union[bool, int], use_shard_placement_fn: bool
|
||||
):
|
||||
if (
|
||||
self.world_size > 2
|
||||
and isinstance(reshard_after_forward, int)
|
||||
and use_shard_placement_fn
|
||||
):
|
||||
return
|
||||
param_dtype, reduce_dtype = torch.bfloat16, torch.float32
|
||||
ref_model, ref_optim, model, optim = self._init_models_and_optims(
|
||||
reshard_after_forward,
|
||||
param_dtype=param_dtype,
|
||||
reduce_dtype=reduce_dtype,
|
||||
use_shard_placement_fn=use_shard_placement_fn,
|
||||
)
|
||||
ref_model_bf16 = copy.deepcopy(ref_model).to(param_dtype)
|
||||
orig_reduce_scatter = dist.reduce_scatter_tensor
|
||||
@ -190,12 +249,14 @@ class TestReplicateMixedPrecisionTraining(FSDPTest):
|
||||
check_sharded_parity(self, ref_model, model)
|
||||
|
||||
def _test_reduce_dtype_bf16_reduce(
|
||||
self,
|
||||
self, reshard_after_forward: Union[bool, int], use_shard_placement_fn: bool
|
||||
):
|
||||
param_dtype, reduce_dtype = torch.float32, torch.bfloat16
|
||||
ref_model, ref_optim, model, optim = self._init_models_and_optims(
|
||||
reshard_after_forward,
|
||||
param_dtype=param_dtype,
|
||||
reduce_dtype=reduce_dtype,
|
||||
use_shard_placement_fn=use_shard_placement_fn,
|
||||
)
|
||||
group = dist.distributed_c10d._get_default_group()
|
||||
orig_reduce_scatter = dist.reduce_scatter_tensor
|
||||
@ -260,8 +321,12 @@ class TestReplicateMixedPrecisionTraining(FSDPTest):
|
||||
ref_model_compute = copy.deepcopy(ref_model).to(param_dtype)
|
||||
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
|
||||
for mlp in model:
|
||||
replicate(mlp, mp_policy=mp_policy)
|
||||
replicate(model, mp_policy=mp_policy)
|
||||
replicate(
|
||||
mlp, reshard_after_forward=reshard_after_forward, mp_policy=mp_policy
|
||||
)
|
||||
replicate(
|
||||
model, reshard_after_forward=reshard_after_forward, mp_policy=mp_policy
|
||||
)
|
||||
optim = torch.optim.Adam(model.parameters(), lr=1e-2)
|
||||
orig_reduce_scatter = dist.reduce_scatter_tensor
|
||||
|
||||
|
||||
@ -108,70 +108,84 @@ class TestReplicateRegisteredParams(FSDPTestMultiThread):
|
||||
"""Tests the parameter registration after forward."""
|
||||
device = torch.device(device_type.type, 0)
|
||||
# Single Replicate group
|
||||
torch.manual_seed(42)
|
||||
model = MLP(3, device)
|
||||
# Since seed is per process, not per thread, we broadcast to ensure
|
||||
# the same parameters across ranks
|
||||
for param in model.parameters():
|
||||
dist.broadcast(param, src=0)
|
||||
ref_model = copy.deepcopy(model)
|
||||
replicate(model) # root only
|
||||
inp = torch.randn((2, 3), device=device_type.type)
|
||||
self._assert_dtensor_params(model.parameters())
|
||||
self._assert_same_params(model.parameters(), ref_model.parameters())
|
||||
model(inp)
|
||||
self._assert_tensor_params(model.parameters())
|
||||
self._assert_same_params(model.parameters(), ref_model.parameters())
|
||||
model.reshard() # however, we can manually reshard
|
||||
self._assert_dtensor_params(model.parameters())
|
||||
self._assert_same_params(model.parameters(), ref_model.parameters())
|
||||
for reshard_after_forward in (True, False, None):
|
||||
torch.manual_seed(42)
|
||||
model = MLP(3, device)
|
||||
# Since seed is per process, not per thread, we broadcast to ensure
|
||||
# the same parameters across ranks
|
||||
for param in model.parameters():
|
||||
dist.broadcast(param, src=0)
|
||||
ref_model = copy.deepcopy(model)
|
||||
replicate(model, reshard_after_forward=reshard_after_forward) # root only
|
||||
inp = torch.randn((2, 3), device=device_type.type)
|
||||
self._assert_dtensor_params(model.parameters())
|
||||
self._assert_same_params(model.parameters(), ref_model.parameters())
|
||||
model(inp)
|
||||
if reshard_after_forward:
|
||||
self._assert_dtensor_params(model.parameters())
|
||||
else:
|
||||
self._assert_tensor_params(model.parameters())
|
||||
self._assert_same_params(model.parameters(), ref_model.parameters())
|
||||
model.reshard() # however, we can manually reshard
|
||||
self._assert_dtensor_params(model.parameters())
|
||||
self._assert_same_params(model.parameters(), ref_model.parameters())
|
||||
|
||||
# Multiple Replicate groups
|
||||
torch.manual_seed(42)
|
||||
model = nn.Sequential(MLP(3, device), MLP(3, device))
|
||||
for param in model.parameters():
|
||||
dist.broadcast(param, src=0)
|
||||
ref_model = copy.deepcopy(model)
|
||||
replicate(model[0].in_proj)
|
||||
replicate(model[0].out_proj)
|
||||
replicate(model)
|
||||
for reshard_after_forward in (True, False, None):
|
||||
torch.manual_seed(42)
|
||||
model = nn.Sequential(MLP(3, device), MLP(3, device))
|
||||
for param in model.parameters():
|
||||
dist.broadcast(param, src=0)
|
||||
ref_model = copy.deepcopy(model)
|
||||
replicate(model[0].in_proj, reshard_after_forward=reshard_after_forward)
|
||||
replicate(model[0].out_proj, reshard_after_forward=reshard_after_forward)
|
||||
replicate(model, reshard_after_forward=reshard_after_forward)
|
||||
|
||||
self._assert_dtensor_params(model.parameters())
|
||||
self._assert_same_params(model.parameters(), ref_model.parameters())
|
||||
model(inp)
|
||||
non_root_params = list(model[0].in_proj.parameters()) + list(
|
||||
model[0].out_proj.parameters()
|
||||
)
|
||||
root_params = list(set(model.parameters()) - set(non_root_params))
|
||||
self._assert_tensor_params(non_root_params)
|
||||
self._assert_tensor_params(root_params)
|
||||
self._assert_same_params(model.parameters(), ref_model.parameters())
|
||||
for module in model.modules():
|
||||
if isinstance(module, FSDPModule):
|
||||
module.reshard() # however, we can manually reshard
|
||||
self._assert_dtensor_params(model.parameters())
|
||||
self._assert_same_params(model.parameters(), ref_model.parameters())
|
||||
self._assert_dtensor_params(model.parameters())
|
||||
self._assert_same_params(model.parameters(), ref_model.parameters())
|
||||
model(inp)
|
||||
non_root_params = list(model[0].in_proj.parameters()) + list(
|
||||
model[0].out_proj.parameters()
|
||||
)
|
||||
root_params = list(set(model.parameters()) - set(non_root_params))
|
||||
if reshard_after_forward is None:
|
||||
self._assert_dtensor_params(non_root_params)
|
||||
self._assert_tensor_params(root_params)
|
||||
elif reshard_after_forward:
|
||||
self._assert_dtensor_params(non_root_params)
|
||||
self._assert_dtensor_params(root_params)
|
||||
else:
|
||||
self._assert_tensor_params(non_root_params)
|
||||
self._assert_tensor_params(root_params)
|
||||
self._assert_same_params(model.parameters(), ref_model.parameters())
|
||||
for module in model.modules():
|
||||
if isinstance(module, FSDPModule):
|
||||
module.reshard() # however, we can manually reshard
|
||||
self._assert_dtensor_params(model.parameters())
|
||||
self._assert_same_params(model.parameters(), ref_model.parameters())
|
||||
|
||||
@skip_if_lt_x_gpu(1)
|
||||
def test_param_registration_after_backward(self):
|
||||
"""Tests the parameter registration after backward."""
|
||||
device = torch.device(device_type.type, 0)
|
||||
# Single Replicate group
|
||||
model = MLP(8, device)
|
||||
replicate(model) # root only
|
||||
inp = torch.randn((2, 8), device=device_type.type)
|
||||
self._assert_dtensor_params(model.parameters())
|
||||
model(inp).sum().backward()
|
||||
self._assert_dtensor_params(model.parameters())
|
||||
for reshard_after_forward in (True, False):
|
||||
model = MLP(8, device)
|
||||
replicate(model, reshard_after_forward=reshard_after_forward) # root only
|
||||
inp = torch.randn((2, 8), device=device_type.type)
|
||||
self._assert_dtensor_params(model.parameters())
|
||||
model(inp).sum().backward()
|
||||
self._assert_dtensor_params(model.parameters())
|
||||
|
||||
# Multiple Replicate groups
|
||||
model = MLP(8, device)
|
||||
replicate(model.in_proj)
|
||||
replicate(model.out_proj)
|
||||
replicate(model)
|
||||
self._assert_dtensor_params(model.parameters())
|
||||
model(inp).sum().backward()
|
||||
self._assert_dtensor_params(model.parameters())
|
||||
for reshard_after_forward in (True, False):
|
||||
model = MLP(8, device)
|
||||
replicate(model.in_proj, reshard_after_forward=reshard_after_forward)
|
||||
replicate(model.out_proj, reshard_after_forward=reshard_after_forward)
|
||||
replicate(model, reshard_after_forward=reshard_after_forward)
|
||||
self._assert_dtensor_params(model.parameters())
|
||||
model(inp).sum().backward()
|
||||
self._assert_dtensor_params(model.parameters())
|
||||
|
||||
def _assert_tensor_params(self, params: Iterable[nn.Parameter]):
|
||||
# need to iterate over the list multiple times
|
||||
@ -273,11 +287,14 @@ class TestReplicate1DTrainingCore(FSDPTest):
|
||||
[(7, 15), (15, 3)],
|
||||
[(16, 17), (17, 8)],
|
||||
],
|
||||
"use_shard_placement_fn": [False],
|
||||
},
|
||||
self._test_train_parity_single_group,
|
||||
)
|
||||
|
||||
def _test_train_parity_single_group(self, lin_shapes: list[tuple[int, int]]):
|
||||
def _test_train_parity_single_group(
|
||||
self, lin_shapes: list[tuple[int, int]], use_shard_placement_fn: bool
|
||||
):
|
||||
torch.manual_seed(42)
|
||||
model = nn.Sequential(
|
||||
nn.Linear(*lin_shapes[0]), nn.ReLU(), nn.Linear(*lin_shapes[1])
|
||||
@ -316,6 +333,7 @@ class TestReplicate1DTrainingCore(FSDPTest):
|
||||
"""
|
||||
self.run_subtests(
|
||||
{
|
||||
"reshard_after_forward": [True, False],
|
||||
"test_device_type": [device_type.type],
|
||||
"offload_policy": [OffloadPolicy()],
|
||||
"delay_after_forward": [False, True],
|
||||
@ -336,6 +354,7 @@ class TestReplicate1DTrainingCore(FSDPTest):
|
||||
"""
|
||||
self.run_subtests(
|
||||
{
|
||||
"reshard_after_forward": [True], # save CI time
|
||||
"offload_policy": [
|
||||
CPUOffloadPolicy(pin_memory=True),
|
||||
CPUOffloadPolicy(pin_memory=False),
|
||||
@ -352,6 +371,7 @@ class TestReplicate1DTrainingCore(FSDPTest):
|
||||
|
||||
def _test_train_parity_multi_group(
|
||||
self,
|
||||
reshard_after_forward: Union[bool, int],
|
||||
offload_policy: OffloadPolicy,
|
||||
test_device_type: str,
|
||||
delay_after_forward: bool,
|
||||
@ -385,12 +405,13 @@ class TestReplicate1DTrainingCore(FSDPTest):
|
||||
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
|
||||
mesh = init_device_mesh(
|
||||
test_device_type,
|
||||
(self.world_size,),
|
||||
mesh_dim_names=("replicate",),
|
||||
(self.world_size, 1),
|
||||
mesh_dim_names=("replicate", "shard"),
|
||||
)
|
||||
fully_shard_fn = functools.partial(
|
||||
replicate,
|
||||
mesh=mesh,
|
||||
device_mesh=mesh,
|
||||
reshard_after_forward=reshard_after_forward,
|
||||
offload_policy=offload_policy,
|
||||
)
|
||||
for module in model.modules():
|
||||
@ -506,10 +527,12 @@ class TestReplicate1DTrainingCore(FSDPTest):
|
||||
Tests parity when running a module that participates multiple
|
||||
times in forward.
|
||||
"""
|
||||
self.run_subtests(
|
||||
{"reshard_after_forward": [True, False]},
|
||||
self._test_multi_forward_module,
|
||||
)
|
||||
|
||||
self._test_multi_forward_module()
|
||||
|
||||
def _test_multi_forward_module(self):
|
||||
def _test_multi_forward_module(self, reshard_after_forward: Union[bool, int]):
|
||||
class MultiForwardModule(nn.Module):
|
||||
def __init__(self, device: torch.device):
|
||||
super().__init__()
|
||||
@ -664,6 +687,7 @@ class TestReplicateTrainingCompose(FSDPTest):
|
||||
"""
|
||||
self.run_subtests(
|
||||
{
|
||||
"reshard_after_forward": [True, False],
|
||||
"checkpoint_impl": ["composable", "utils", "wrapper"],
|
||||
"module_grouping": ["block", "mem_eff", "mem_eff_weight_tied"],
|
||||
"test_device_type": [device_type.type],
|
||||
@ -673,6 +697,7 @@ class TestReplicateTrainingCompose(FSDPTest):
|
||||
|
||||
def _test_train_parity_with_activation_checkpointing(
|
||||
self,
|
||||
reshard_after_forward: Union[bool, int],
|
||||
checkpoint_impl: str,
|
||||
module_grouping: str,
|
||||
test_device_type: str,
|
||||
@ -715,11 +740,12 @@ class TestReplicateTrainingCompose(FSDPTest):
|
||||
# Apply Replicate
|
||||
device_mesh = init_device_mesh(
|
||||
test_device_type,
|
||||
(self.world_size,),
|
||||
mesh_dim_names=("replicate",),
|
||||
(self.world_size, 1),
|
||||
mesh_dim_names=("replicate", "shard"),
|
||||
)
|
||||
fsdp_kwargs = {
|
||||
"mesh": device_mesh,
|
||||
"reshard_after_forward": reshard_after_forward,
|
||||
"device_mesh": device_mesh,
|
||||
}
|
||||
if module_grouping == "mem_eff":
|
||||
assert model_args.n_layers == 3
|
||||
@ -783,6 +809,7 @@ class TestReplicateSharedParams(FSDPTest):
|
||||
def test_train_parity_with_shared_params(self):
|
||||
self.run_subtests(
|
||||
{
|
||||
"reshard_after_forward": [False, True],
|
||||
"use_activation_checkpointing": [False, True],
|
||||
},
|
||||
self._test_train_shared_params,
|
||||
@ -790,6 +817,7 @@ class TestReplicateSharedParams(FSDPTest):
|
||||
|
||||
def _test_train_shared_params(
|
||||
self,
|
||||
reshard_after_forward: bool,
|
||||
use_activation_checkpointing: bool,
|
||||
):
|
||||
torch.manual_seed(42)
|
||||
@ -802,8 +830,8 @@ class TestReplicateSharedParams(FSDPTest):
|
||||
if isinstance(module, TransformerBlock):
|
||||
if use_activation_checkpointing:
|
||||
checkpoint(module)
|
||||
replicate(module)
|
||||
replicate(model)
|
||||
replicate(module, reshard_after_forward=reshard_after_forward)
|
||||
replicate(model, reshard_after_forward=reshard_after_forward)
|
||||
optim = torch.optim.Adam(model.parameters(), lr=1e-2)
|
||||
|
||||
torch.manual_seed(42 + self.rank + 1)
|
||||
@ -840,11 +868,11 @@ class TestReplicateGradientAccumulation(FSDPTest):
|
||||
with/without resharding after backward.
|
||||
"""
|
||||
|
||||
replicate_size = self.world_size
|
||||
shard_size, replicate_size = 1, self.world_size
|
||||
meshes = init_device_mesh(
|
||||
device_type.type,
|
||||
(replicate_size,),
|
||||
mesh_dim_names=("replicate",),
|
||||
(replicate_size, shard_size),
|
||||
mesh_dim_names=("replicate", "shard"),
|
||||
)
|
||||
self.run_subtests(
|
||||
{
|
||||
@ -900,7 +928,8 @@ class TestReplicateGradientAccumulation(FSDPTest):
|
||||
ref_model = copy.deepcopy(model).to(device_type)
|
||||
replicate_fn = functools.partial(
|
||||
replicate,
|
||||
mesh=mesh,
|
||||
device_mesh=mesh,
|
||||
reshard_after_forward=reshard_after_forward,
|
||||
offload_policy=offload_policy,
|
||||
)
|
||||
for mlp in model[1:]:
|
||||
@ -1011,8 +1040,8 @@ class TestReplicateGradientAccumulation(FSDPTest):
|
||||
ref_optim = torch.optim.AdamW(ref_model.parameters(), lr=1e-2)
|
||||
for module in model.modules():
|
||||
if isinstance(module, TransformerBlock):
|
||||
replicate(module)
|
||||
replicate(model)
|
||||
replicate(module, reshard_after_forward=False)
|
||||
replicate(model, reshard_after_forward=False)
|
||||
optim = torch.optim.AdamW(model.parameters(), lr=1e-2)
|
||||
|
||||
num_microbatches = 3
|
||||
@ -1116,8 +1145,8 @@ class TestReplicateTPTraining(FSDPTest):
|
||||
def init_global_mesh(self) -> DeviceMesh:
|
||||
return init_device_mesh(
|
||||
device_type.type,
|
||||
(2, 2),
|
||||
mesh_dim_names=("dp_replicate", "tp"),
|
||||
(2, 1, 2),
|
||||
mesh_dim_names=("dp_replicate", "dp_shard", "tp"),
|
||||
)
|
||||
|
||||
@skip_if_lt_x_gpu(8)
|
||||
@ -1125,6 +1154,7 @@ class TestReplicateTPTraining(FSDPTest):
|
||||
global_mesh = self.init_global_mesh()
|
||||
self.run_subtests(
|
||||
{
|
||||
"reshard_after_forward": [False, True],
|
||||
"use_activation_checkpointing": [False, True],
|
||||
"mlp_dim": [3, 5, 16, 17],
|
||||
"foreach": [False],
|
||||
@ -1135,11 +1165,12 @@ class TestReplicateTPTraining(FSDPTest):
|
||||
def _test_replicate_tp(
|
||||
self,
|
||||
global_mesh: DeviceMesh,
|
||||
reshard_after_forward: bool,
|
||||
use_activation_checkpointing: bool,
|
||||
mlp_dim: int,
|
||||
foreach: bool,
|
||||
):
|
||||
dp_mesh, tp_mesh = global_mesh["dp_replicate"], global_mesh["tp"]
|
||||
dp_mesh, tp_mesh = global_mesh["dp_replicate", "dp_shard"], global_mesh["tp"]
|
||||
dp_pg = dp_mesh._flatten().get_group() # used for `replicate()`
|
||||
|
||||
torch.manual_seed(42)
|
||||
@ -1166,8 +1197,8 @@ class TestReplicateTPTraining(FSDPTest):
|
||||
continue
|
||||
if use_activation_checkpointing:
|
||||
checkpoint(module)
|
||||
replicate(module, mesh=dp_mesh)
|
||||
replicate(model, mesh=dp_mesh)
|
||||
replicate(module, device_mesh=dp_mesh)
|
||||
replicate(model, device_mesh=dp_mesh)
|
||||
|
||||
# Checking parameters match orig model is critical to validate .full_tensor correctly replicates the
|
||||
# strided-sharded layers.
|
||||
@ -1198,9 +1229,11 @@ class TestReplicateTPTraining(FSDPTest):
|
||||
|
||||
for _, p in model.named_parameters():
|
||||
self.assertIsInstance(p, DTensor)
|
||||
self.assertEqual(p.device_mesh.ndim, 2)
|
||||
self.assertEqual(len(p.placements), 2)
|
||||
self.assertEqual(p.device_mesh.mesh_dim_names, ("dp_replicate", "tp"))
|
||||
self.assertEqual(p.device_mesh.ndim, 3)
|
||||
self.assertEqual(len(p.placements), 3)
|
||||
self.assertEqual(
|
||||
p.device_mesh.mesh_dim_names, ("dp_replicate", "dp_shard", "tp")
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -120,7 +120,7 @@ class ReplicateTest(MultiProcessTestCase):
|
||||
if i % 2 == 0:
|
||||
self.assertTrue("replicate" in _get_registry(layer))
|
||||
for parameter in layer.parameters():
|
||||
self.assertEqual(parameter.placements, (Replicate(),))
|
||||
self.assertEqual(parameter.placements, (Replicate(), Shard(dim=0)))
|
||||
elif i % 2 == 1:
|
||||
self.assertTrue("fully_shard" in _get_registry(layer))
|
||||
for parameter in layer.parameters():
|
||||
@ -197,14 +197,14 @@ class ReplicateTest(MultiProcessTestCase):
|
||||
]
|
||||
|
||||
global_mesh = self.init_replicate_tp_mesh()
|
||||
replicate_mesh = global_mesh["replicate"]
|
||||
replicate_mesh = global_mesh["replicate", "shard"]
|
||||
|
||||
for layer in layers:
|
||||
replicate(layer, mesh=replicate_mesh)
|
||||
replicate(layer, device_mesh=replicate_mesh)
|
||||
|
||||
for parameter in layer.parameters():
|
||||
self.assertEqual(parameter.device_mesh.shape, (2,))
|
||||
self.assertEqual(parameter.placements, (Replicate(),))
|
||||
self.assertEqual(parameter.device_mesh.shape, (2, 1))
|
||||
self.assertEqual(parameter.placements, (Replicate(), Shard(dim=0)))
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_train_replicate_fsdp(self):
|
||||
@ -263,6 +263,7 @@ class ReplicateTest(MultiProcessTestCase):
|
||||
run_subtests(
|
||||
self,
|
||||
{
|
||||
"reshard_after_forward": [False, True],
|
||||
"use_activation_checkpointing": [False, True],
|
||||
"mlp_dim": [3, 16, 17],
|
||||
},
|
||||
@ -272,6 +273,7 @@ class ReplicateTest(MultiProcessTestCase):
|
||||
def _test_train_parity_2d_mlp(
|
||||
self,
|
||||
global_mesh: DeviceMesh,
|
||||
reshard_after_forward: bool,
|
||||
use_activation_checkpointing: bool,
|
||||
mlp_dim: int,
|
||||
):
|
||||
@ -285,12 +287,13 @@ class ReplicateTest(MultiProcessTestCase):
|
||||
torch.manual_seed(42)
|
||||
model = MLPStack(mlp_dim)
|
||||
ref_model = copy.deepcopy(model).cuda()
|
||||
replicate(ref_model, mesh=replicate_mesh)
|
||||
replicate(ref_model, device_mesh=replicate_shard_mesh)
|
||||
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2, foreach=False)
|
||||
model.parallelize(
|
||||
tp_mesh,
|
||||
replicate_shard_mesh,
|
||||
use_activation_checkpointing,
|
||||
reshard_after_forward=reshard_after_forward,
|
||||
)
|
||||
optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=False)
|
||||
|
||||
|
||||
@ -1,26 +1,16 @@
|
||||
# Owner(s): ["oncall: distributed checkpointing"]
|
||||
|
||||
import os
|
||||
import sys
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
import torch.testing._internal.common_utils as common
|
||||
from torch import distributed as dist
|
||||
from torch.distributed.checkpoint._async_process_executor import (
|
||||
_ProcessBasedAsyncCheckpointExecutor,
|
||||
_ProcessGroupInitInfo,
|
||||
)
|
||||
from torch.distributed.checkpoint.api import CheckpointException
|
||||
from torch.distributed.checkpoint.storage import StorageWriter
|
||||
from torch.distributed.elastic.utils.distributed import get_free_port
|
||||
from torch.testing._internal.common_distributed import skip_if_win32
|
||||
from torch.testing._internal.common_utils import (
|
||||
retry_on_connect_failures,
|
||||
run_tests,
|
||||
TEST_WITH_DEV_DBG_ASAN,
|
||||
TestCase,
|
||||
)
|
||||
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
|
||||
from torch.testing._internal.distributed._tensor.common_dtensor import (
|
||||
DTensorTestBase,
|
||||
with_comms,
|
||||
@ -120,184 +110,47 @@ class TestAsyncProcessExecutor(DTensorTestBase):
|
||||
"epoch": 5,
|
||||
}
|
||||
|
||||
with patch.dict(os.environ, {}, clear=False):
|
||||
os.environ.pop("DCP_USE_PREFIX_STORE", None)
|
||||
|
||||
# 1. Simulate a failure in creating PG in background process.
|
||||
with patch(
|
||||
"torch.distributed.checkpoint._async_process_executor.get_free_port",
|
||||
return_value=-1,
|
||||
):
|
||||
with self.assertRaises(ValueError) as _:
|
||||
proc_executor = _ProcessBasedAsyncCheckpointExecutor()
|
||||
fut = proc_executor.execute_save(
|
||||
staging_future_or_state_dict=test_state_dict,
|
||||
)
|
||||
fut.result()
|
||||
|
||||
# 2. Attempt save with failing storage writer
|
||||
with patch(
|
||||
"torch.distributed.checkpoint._async_process_executor.get_free_port",
|
||||
return_value=get_free_port(),
|
||||
) as mock_get_free_port:
|
||||
# 1. Simulate a failure in creating PG in background process.
|
||||
with patch(
|
||||
"torch.distributed.checkpoint._async_process_executor.get_free_port",
|
||||
return_value=-1,
|
||||
):
|
||||
with self.assertRaises(ValueError) as _:
|
||||
proc_executor = _ProcessBasedAsyncCheckpointExecutor()
|
||||
fut = proc_executor.execute_save(
|
||||
staging_future_or_state_dict=test_state_dict,
|
||||
storage_writer=TestStorageWriter(behavior="fail_once"),
|
||||
)
|
||||
self.assertIn(
|
||||
"fail_once policy triggered failure", str(fut.exception())
|
||||
)
|
||||
# Verify new process was created for this attempt
|
||||
if dist.get_rank() == 0:
|
||||
mock_get_free_port.assert_called_once()
|
||||
fut.result()
|
||||
|
||||
# 3. Second save attempt with successful storage writer - process should still be alive
|
||||
with patch(
|
||||
"torch.distributed.checkpoint._async_process_executor.get_free_port",
|
||||
) as mock_get_free_port:
|
||||
proc_executor = _ProcessBasedAsyncCheckpointExecutor()
|
||||
fut = proc_executor.execute_save(
|
||||
staging_future_or_state_dict=test_state_dict,
|
||||
storage_writer=TestStorageWriter(behavior="success"),
|
||||
)
|
||||
result = fut.result()
|
||||
# Verify process is still alive
|
||||
mock_get_free_port.assert_not_called()
|
||||
# Verify successful save
|
||||
self.assertIsNotNone(result)
|
||||
# 2. Attempt save with failing storage writer
|
||||
with patch(
|
||||
"torch.distributed.checkpoint._async_process_executor.get_free_port",
|
||||
return_value=get_free_port(),
|
||||
) as mock_get_free_port:
|
||||
proc_executor = _ProcessBasedAsyncCheckpointExecutor()
|
||||
fut = proc_executor.execute_save(
|
||||
staging_future_or_state_dict=test_state_dict,
|
||||
storage_writer=TestStorageWriter(behavior="fail_once"),
|
||||
)
|
||||
self.assertIn("fail_once policy triggered failure", str(fut.exception()))
|
||||
# Verify new process was created for this attempt
|
||||
if dist.get_rank() == 0:
|
||||
mock_get_free_port.assert_called_once()
|
||||
|
||||
|
||||
class TestAsyncProcessExecutorPrefixStore(TestCase):
|
||||
@skip_if_win32()
|
||||
@retry_on_connect_failures
|
||||
def test_checkpoint_save_with_prefix_store_enabled(self) -> None:
|
||||
"""Test that checkpoint save works when DCP_USE_PREFIX_STORE is enabled."""
|
||||
|
||||
test_state_dict = {
|
||||
"model": {"weight": torch.randn(4, 4), "bias": torch.randn(4)},
|
||||
"optimizer": {"param_groups": [{"lr": 0.01}]},
|
||||
"epoch": 5,
|
||||
}
|
||||
|
||||
master_addr = "localhost"
|
||||
master_port = str(common.find_free_port())
|
||||
|
||||
with patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"DCP_USE_PREFIX_STORE": "1",
|
||||
"MASTER_ADDR": master_addr,
|
||||
"MASTER_PORT": master_port,
|
||||
},
|
||||
):
|
||||
with patch(
|
||||
"torch.distributed.checkpoint._async_process_executor.get_free_port"
|
||||
) as mock_get_free_port:
|
||||
dist.init_process_group(
|
||||
backend=dist.Backend.GLOO,
|
||||
rank=0,
|
||||
world_size=1,
|
||||
)
|
||||
|
||||
proc_executor = _ProcessBasedAsyncCheckpointExecutor()
|
||||
fut = proc_executor.execute_save(
|
||||
staging_future_or_state_dict=test_state_dict,
|
||||
storage_writer=TestStorageWriter(behavior="success"),
|
||||
)
|
||||
result = fut.result()
|
||||
self.assertIsNotNone(result)
|
||||
mock_get_free_port.assert_not_called()
|
||||
|
||||
|
||||
class TestProcessGroupInitInfo(DTensorTestBase):
|
||||
"""Test suite for _ProcessGroupInitInfo."""
|
||||
|
||||
@with_comms
|
||||
def test_process_group_init_info_with_default_pg(self) -> None:
|
||||
"""Test that ProcessGroupInitInfo correctly initializes."""
|
||||
with patch.dict(os.environ, {}, clear=False):
|
||||
os.environ.pop("DCP_USE_PREFIX_STORE", None)
|
||||
|
||||
pg_init_info = _ProcessGroupInitInfo()
|
||||
|
||||
self.assertEqual(pg_init_info.global_rank, dist.get_rank())
|
||||
self.assertEqual(pg_init_info.world_size, dist.get_world_size())
|
||||
self.assertIsNotNone(pg_init_info.tcp_store_master_addr)
|
||||
self.assertGreater(pg_init_info.tcp_store_master_port, 0)
|
||||
self.assertEqual(pg_init_info.use_prefix_store, False)
|
||||
|
||||
@with_comms
|
||||
def test_process_group_init_info_with_prefix_store_env_var(self) -> None:
|
||||
"""Test that ProcessGroupInitInfo handles DCP_USE_PREFIX_STORE environment variable."""
|
||||
|
||||
# Flag enabled, addr/port correctly defined
|
||||
with patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"DCP_USE_PREFIX_STORE": "1",
|
||||
"MASTER_ADDR": "localhost",
|
||||
"MASTER_PORT": "12345",
|
||||
},
|
||||
):
|
||||
pg_init_info = _ProcessGroupInitInfo()
|
||||
self.assertTrue(pg_init_info.use_prefix_store)
|
||||
|
||||
# Missing port
|
||||
with patch.dict(
|
||||
os.environ, {"DCP_USE_PREFIX_STORE": "1", "MASTER_ADDR": "localhost"}
|
||||
):
|
||||
with self.assertRaises(CheckpointException):
|
||||
pg_init_info = _ProcessGroupInitInfo()
|
||||
# Missing addr
|
||||
with patch.dict(
|
||||
os.environ, {"DCP_USE_PREFIX_STORE": "1", "MASTER_PORT": "12345"}
|
||||
):
|
||||
with self.assertRaises(CheckpointException):
|
||||
pg_init_info = _ProcessGroupInitInfo()
|
||||
# Invalid port
|
||||
with patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"DCP_USE_PREFIX_STORE": "1",
|
||||
"MASTER_ADDR": "localhost",
|
||||
"MASTER_PORT": "a",
|
||||
},
|
||||
):
|
||||
with self.assertRaises(CheckpointException):
|
||||
pg_init_info = _ProcessGroupInitInfo()
|
||||
|
||||
@with_comms
|
||||
def test_process_group_init_info_without_prefix_store_env_var(self) -> None:
|
||||
"""Test that ProcessGroupInitInfo defaults to not using prefix store."""
|
||||
|
||||
# Env var set to 0
|
||||
with patch.dict(os.environ, {"DCP_USE_PREFIX_STORE": "0"}):
|
||||
pg_init_info = _ProcessGroupInitInfo()
|
||||
self.assertFalse(pg_init_info.use_prefix_store)
|
||||
|
||||
# Missing env var
|
||||
with patch.dict(os.environ, {}, clear=False):
|
||||
os.environ.pop("DCP_USE_PREFIX_STORE", None)
|
||||
pg_init_info = _ProcessGroupInitInfo()
|
||||
self.assertFalse(pg_init_info.use_prefix_store)
|
||||
|
||||
# Invalid env var
|
||||
with patch.dict(os.environ, {"DCP_USE_PREFIX_STORE": "2"}):
|
||||
pg_init_info = _ProcessGroupInitInfo()
|
||||
self.assertFalse(pg_init_info.use_prefix_store)
|
||||
|
||||
with patch.dict(os.environ, {"DCP_USE_PREFIX_STORE": "true"}):
|
||||
pg_init_info = _ProcessGroupInitInfo()
|
||||
self.assertFalse(pg_init_info.use_prefix_store)
|
||||
|
||||
with patch.dict(os.environ, {"DCP_USE_PREFIX_STORE": "false"}):
|
||||
pg_init_info = _ProcessGroupInitInfo()
|
||||
self.assertFalse(pg_init_info.use_prefix_store)
|
||||
|
||||
with patch.dict(os.environ, {"DCP_USE_PREFIX_STORE": ""}):
|
||||
pg_init_info = _ProcessGroupInitInfo()
|
||||
self.assertFalse(pg_init_info.use_prefix_store)
|
||||
# 3. Second save attempt with successful storage writer - process should still be alive
|
||||
with patch(
|
||||
"torch.distributed.checkpoint._async_process_executor.get_free_port",
|
||||
) as mock_get_free_port:
|
||||
proc_executor = _ProcessBasedAsyncCheckpointExecutor()
|
||||
fut = proc_executor.execute_save(
|
||||
staging_future_or_state_dict=test_state_dict,
|
||||
storage_writer=TestStorageWriter(behavior="success"),
|
||||
)
|
||||
result = fut.result()
|
||||
# Verify process is still alive
|
||||
mock_get_free_port.assert_not_called()
|
||||
# Verify successful save
|
||||
self.assertIsNotNone(result)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -415,15 +415,6 @@ class TestDTensorDebugMode(TestCase):
      aten::addmm(t: f32[4], t: f32[4, 4], t: f32[4, 4])""",
        )

        with DebugMode(record_stack_trace=True) as debug_mode:
            out = mod(inp).sum()
            out.backward()

        sum_op = [
            op for op in debug_mode.operators if str(op.op) == "aten.sum.dim_IntList"
        ][-1]
        self.assertTrue("self.l2(self.l1(x))" in sum_op.fwd_stack_trace)


instantiate_parametrized_tests(TestDTensorDebugMode)


@@ -1019,28 +1019,6 @@ class DTensorMeshTest(DTensorTestBase):
        except ValueError:
            self.fail("Unexpected ValueError raised with run_check=False")

    @with_comms
    def test_as_strided_identity(self):
        # Test calling as_strided with the same size/stride/offset as input tensor
        # This should be a no-op but currently fails
        device_mesh = self.build_device_mesh()
        placements = [Shard(0)]
        local_tensor = torch.randn(3, 4, device=self.device_type)
        dtensor = DTensor.from_local(local_tensor, device_mesh, placements)

        # Get the current size, stride, and storage_offset
        size = dtensor.size()
        stride = dtensor.stride()
        storage_offset = dtensor.storage_offset()

        # Call as_strided with the exact same parameters
        result = dtensor.as_strided(size, stride, storage_offset)

        # The result should be identical to the input
        self.assertEqual(result.size(), dtensor.size())
        self.assertEqual(result.stride(), dtensor.stride())
        self.assertEqual(result.to_local(), dtensor.to_local())


DTensorMeshTestWithLocalTensor = create_local_tensor_test_class(
    DTensorMeshTest,
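
The removed test above checked that as_strided called with a tensor's own size, stride, and storage offset is a no-op. The same invariant on a plain tensor, as a quick illustrative sanity check (not part of the diff):

import torch

t = torch.randn(3, 4)
same = t.as_strided(t.size(), t.stride(), t.storage_offset())
assert same.size() == t.size() and same.stride() == t.stride()
assert torch.equal(same, t)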
@ -1,5 +1,5 @@
|
||||
diff --git a/test/dynamo/cpython/3_13/test_heapq.py b/test/dynamo/cpython/3_13/test_heapq.py
|
||||
index 1aa8e4e2897..bc177c2943e 100644
|
||||
index 1aa8e4e2897..94315fa68b4 100644
|
||||
--- a/test/dynamo/cpython/3_13/test_heapq.py
|
||||
+++ b/test/dynamo/cpython/3_13/test_heapq.py
|
||||
@@ -1,3 +1,23 @@
|
||||
@ -35,7 +35,7 @@ index 1aa8e4e2897..bc177c2943e 100644
|
||||
def test_py_functions(self):
|
||||
for fname in func_names:
|
||||
self.assertEqual(getattr(py_heapq, fname).__module__, 'heapq')
|
||||
@@ -27,24 +47,12 @@ class TestModules(TestCase):
|
||||
@@ -27,24 +47,7 @@ class TestModules(TestCase):
|
||||
self.assertEqual(getattr(c_heapq, fname).__module__, '_heapq')
|
||||
|
||||
|
||||
@ -46,15 +46,12 @@ index 1aa8e4e2897..bc177c2943e 100644
|
||||
- # However, doctest can't easily find all docstrings in the module (loading
|
||||
- # it through import_fresh_module seems to confuse it), so we specifically
|
||||
- # create a finder which returns the doctests from the merge method.
|
||||
+@torch._dynamo.disable
|
||||
+def randrange(*args):
|
||||
+ return random.randrange(*args)
|
||||
|
||||
-
|
||||
- class HeapqMergeDocTestFinder:
|
||||
- def find(self, *args, **kwargs):
|
||||
- dtf = doctest.DocTestFinder()
|
||||
- return dtf.find(py_heapq.merge)
|
||||
|
||||
-
|
||||
- tests.addTests(doctest.DocTestSuite(py_heapq,
|
||||
- test_finder=HeapqMergeDocTestFinder()))
|
||||
- return tests
|
||||
@ -64,155 +61,7 @@ index 1aa8e4e2897..bc177c2943e 100644
|
||||
|
||||
def test_push_pop(self):
|
||||
# 1) Push 256 random numbers and pop them off, verifying all's OK.
|
||||
@@ -52,7 +60,8 @@ class TestHeap:
|
||||
data = []
|
||||
self.check_invariant(heap)
|
||||
for i in range(256):
|
||||
- item = random.random()
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ item = random.random()
|
||||
data.append(item)
|
||||
self.module.heappush(heap, item)
|
||||
self.check_invariant(heap)
|
||||
@@ -83,14 +92,16 @@ class TestHeap:
|
||||
|
||||
def test_heapify(self):
|
||||
for size in list(range(30)) + [20000]:
|
||||
- heap = [random.random() for dummy in range(size)]
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ heap = [random.random() for dummy in range(size)]
|
||||
self.module.heapify(heap)
|
||||
self.check_invariant(heap)
|
||||
|
||||
self.assertRaises(TypeError, self.module.heapify, None)
|
||||
|
||||
def test_naive_nbest(self):
|
||||
- data = [random.randrange(2000) for i in range(1000)]
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ data = [randrange(2000) for i in range(1000)]
|
||||
heap = []
|
||||
for item in data:
|
||||
self.module.heappush(heap, item)
|
||||
@@ -113,7 +124,8 @@ class TestHeap:
|
||||
# heap instead of a min heap, it could go faster still via
|
||||
# heapify'ing all of data (linear time), then doing 10 heappops
|
||||
# (10 log-time steps).
|
||||
- data = [random.randrange(2000) for i in range(1000)]
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ data = [randrange(2000) for i in range(1000)]
|
||||
heap = data[:10]
|
||||
self.module.heapify(heap)
|
||||
for item in data[10:]:
|
||||
@@ -126,7 +138,8 @@ class TestHeap:
|
||||
self.assertRaises(IndexError, self.module.heapreplace, [], None)
|
||||
|
||||
def test_nbest_with_pushpop(self):
|
||||
- data = [random.randrange(2000) for i in range(1000)]
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ data = [randrange(2000) for i in range(1000)]
|
||||
heap = data[:10]
|
||||
self.module.heapify(heap)
|
||||
for item in data[10:]:
|
||||
@@ -163,8 +176,9 @@ class TestHeap:
|
||||
def test_heapsort(self):
|
||||
# Exercise everything with repeated heapsort checks
|
||||
for trial in range(100):
|
||||
- size = random.randrange(50)
|
||||
- data = [random.randrange(25) for i in range(size)]
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ size = randrange(50)
|
||||
+ data = [randrange(25) for i in range(size)]
|
||||
if trial & 1: # Half of the time, use heapify
|
||||
heap = data[:]
|
||||
self.module.heapify(heap)
|
||||
@@ -177,12 +191,13 @@ class TestHeap:
|
||||
|
||||
def test_merge(self):
|
||||
inputs = []
|
||||
- for i in range(random.randrange(25)):
|
||||
- row = []
|
||||
- for j in range(random.randrange(100)):
|
||||
- tup = random.choice('ABC'), random.randrange(-500, 500)
|
||||
- row.append(tup)
|
||||
- inputs.append(row)
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ for i in range(randrange(25)):
|
||||
+ row = []
|
||||
+ for j in range(randrange(100)):
|
||||
+ tup = random.choice('ABC'), randrange(-500, 500)
|
||||
+ row.append(tup)
|
||||
+ inputs.append(row)
|
||||
|
||||
for key in [None, itemgetter(0), itemgetter(1), itemgetter(1, 0)]:
|
||||
for reverse in [False, True]:
|
||||
@@ -209,12 +224,14 @@ class TestHeap:
|
||||
list(self.module.merge(iterable(), iterable()))
|
||||
|
||||
def test_merge_stability(self):
|
||||
- class Int(int):
|
||||
- pass
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class Int(int):
|
||||
+ pass
|
||||
inputs = [[], [], [], []]
|
||||
for i in range(20000):
|
||||
- stream = random.randrange(4)
|
||||
- x = random.randrange(500)
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ stream = randrange(4)
|
||||
+ x = randrange(500)
|
||||
obj = Int(x)
|
||||
obj.pair = (x, stream)
|
||||
inputs[stream].append(obj)
|
||||
@@ -224,7 +241,8 @@ class TestHeap:
|
||||
self.assertEqual(result, sorted(result))
|
||||
|
||||
def test_nsmallest(self):
|
||||
- data = [(random.randrange(2000), i) for i in range(1000)]
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ data = [(randrange(2000), i) for i in range(1000)]
|
||||
for f in (None, lambda x: x[0] * 547 % 2000):
|
||||
for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100):
|
||||
self.assertEqual(list(self.module.nsmallest(n, data)),
|
||||
@@ -233,7 +251,8 @@ class TestHeap:
|
||||
sorted(data, key=f)[:n])
|
||||
|
||||
def test_nlargest(self):
|
||||
- data = [(random.randrange(2000), i) for i in range(1000)]
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ data = [(randrange(2000), i) for i in range(1000)]
|
||||
for f in (None, lambda x: x[0] * 547 % 2000):
|
||||
for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100):
|
||||
self.assertEqual(list(self.module.nlargest(n, data)),
|
||||
@@ -248,28 +267,29 @@ class TestHeap:
|
||||
data = [comp(x) for x in data]
|
||||
self.module.heapify(data)
|
||||
return [self.module.heappop(data).x for i in range(len(data))]
|
||||
- class LT:
|
||||
- def __init__(self, x):
|
||||
- self.x = x
|
||||
- def __lt__(self, other):
|
||||
- return self.x > other.x
|
||||
- class LE:
|
||||
- def __init__(self, x):
|
||||
- self.x = x
|
||||
- def __le__(self, other):
|
||||
- return self.x >= other.x
|
||||
- data = [random.random() for i in range(100)]
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class LT:
|
||||
+ def __init__(self, x):
|
||||
+ self.x = x
|
||||
+ def __lt__(self, other):
|
||||
+ return self.x > other.x
|
||||
+ class LE:
|
||||
+ def __init__(self, x):
|
||||
+ self.x = x
|
||||
+ def __le__(self, other):
|
||||
+ return self.x >= other.x
|
||||
+ data = [random.random() for i in range(100)]
|
||||
target = sorted(data, reverse=True)
|
||||
self.assertEqual(hsort(data, LT), target)
|
||||
@@ -264,12 +267,12 @@ class TestHeap:
|
||||
self.assertRaises(TypeError, data, LE)
|
||||
|
||||
|
||||
@ -227,7 +76,7 @@ index 1aa8e4e2897..bc177c2943e 100644
|
||||
module = c_heapq
|
||||
|
||||
|
||||
@@ -374,7 +394,7 @@ class SideEffectLT:
|
||||
@@ -374,7 +377,7 @@ class SideEffectLT:
|
||||
return self.value < other.value
|
||||
|
||||
|
||||
@ -236,48 +85,7 @@ index 1aa8e4e2897..bc177c2943e 100644
|
||||
|
||||
def test_non_sequence(self):
|
||||
for f in (self.module.heapify, self.module.heappop):
|
||||
@@ -435,10 +455,11 @@ class TestErrorHandling:
|
||||
def test_comparison_operator_modifiying_heap(self):
|
||||
# See bpo-39421: Strong references need to be taken
|
||||
# when comparing objects as they can alter the heap
|
||||
- class EvilClass(int):
|
||||
- def __lt__(self, o):
|
||||
- heap.clear()
|
||||
- return NotImplemented
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class EvilClass(int):
|
||||
+ def __lt__(self, o):
|
||||
+ heap.clear()
|
||||
+ return NotImplemented
|
||||
|
||||
heap = []
|
||||
self.module.heappush(heap, EvilClass(0))
|
||||
@@ -446,15 +467,16 @@ class TestErrorHandling:
|
||||
|
||||
def test_comparison_operator_modifiying_heap_two_heaps(self):
|
||||
|
||||
- class h(int):
|
||||
- def __lt__(self, o):
|
||||
- list2.clear()
|
||||
- return NotImplemented
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class h(int):
|
||||
+ def __lt__(self, o):
|
||||
+ list2.clear()
|
||||
+ return NotImplemented
|
||||
|
||||
- class g(int):
|
||||
- def __lt__(self, o):
|
||||
- list1.clear()
|
||||
- return NotImplemented
|
||||
+ class g(int):
|
||||
+ def __lt__(self, o):
|
||||
+ list1.clear()
|
||||
+ return NotImplemented
|
||||
|
||||
list1, list2 = [], []
|
||||
|
||||
@@ -464,13 +486,13 @@ class TestErrorHandling:
|
||||
@@ -464,13 +467,13 @@ class TestErrorHandling:
|
||||
self.assertRaises((IndexError, RuntimeError), self.module.heappush, list1, g(1))
|
||||
self.assertRaises((IndexError, RuntimeError), self.module.heappush, list2, h(1))
|
||||
|
||||
|
||||
@ -47,11 +47,6 @@ class TestModules(__TestCase):
|
||||
self.assertEqual(getattr(c_heapq, fname).__module__, '_heapq')
|
||||
|
||||
|
||||
@torch._dynamo.disable
|
||||
def randrange(*args):
|
||||
return random.randrange(*args)
|
||||
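As an aside, the effect of the @torch._dynamo.disable wrapper above can be reproduced with a small standalone sketch; pick and f below are hypothetical helpers, not part of the test suite:

import random
import torch

@torch._dynamo.disable           # never traced by Dynamo; always executed eagerly
def pick(n):
    return random.randrange(n)

@torch.compile(backend="eager")  # toy compiled function for illustration
def f(x):
    # Dynamo graph-breaks around pick() and resumes tracing afterwards.
    return x + pick(10)

print(f(torch.zeros(3)))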
|
||||
|
||||
class _TestHeap:
|
||||
|
||||
def test_push_pop(self):
|
||||
@ -60,8 +55,7 @@ class _TestHeap:
|
||||
data = []
|
||||
self.check_invariant(heap)
|
||||
for i in range(256):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
item = random.random()
|
||||
item = random.random()
|
||||
data.append(item)
|
||||
self.module.heappush(heap, item)
|
||||
self.check_invariant(heap)
|
||||
@ -92,16 +86,14 @@ class _TestHeap:
|
||||
|
||||
def test_heapify(self):
|
||||
for size in list(range(30)) + [20000]:
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
heap = [random.random() for dummy in range(size)]
|
||||
heap = [random.random() for dummy in range(size)]
|
||||
self.module.heapify(heap)
|
||||
self.check_invariant(heap)
|
||||
|
||||
self.assertRaises(TypeError, self.module.heapify, None)
|
||||
|
||||
def test_naive_nbest(self):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
data = [randrange(2000) for i in range(1000)]
|
||||
data = [random.randrange(2000) for i in range(1000)]
|
||||
heap = []
|
||||
for item in data:
|
||||
self.module.heappush(heap, item)
|
||||
@ -124,8 +116,7 @@ class _TestHeap:
|
||||
# heap instead of a min heap, it could go faster still via
|
||||
# heapify'ing all of data (linear time), then doing 10 heappops
|
||||
# (10 log-time steps).
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
data = [randrange(2000) for i in range(1000)]
|
||||
data = [random.randrange(2000) for i in range(1000)]
|
||||
heap = data[:10]
|
||||
self.module.heapify(heap)
|
||||
for item in data[10:]:
|
||||
@ -138,8 +129,7 @@ class _TestHeap:
|
||||
self.assertRaises(IndexError, self.module.heapreplace, [], None)
|
||||
|
||||
def test_nbest_with_pushpop(self):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
data = [randrange(2000) for i in range(1000)]
|
||||
data = [random.randrange(2000) for i in range(1000)]
|
||||
heap = data[:10]
|
||||
self.module.heapify(heap)
|
||||
for item in data[10:]:
|
||||
@ -176,9 +166,8 @@ class _TestHeap:
|
||||
def test_heapsort(self):
|
||||
# Exercise everything with repeated heapsort checks
|
||||
for trial in range(100):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
size = randrange(50)
|
||||
data = [randrange(25) for i in range(size)]
|
||||
size = random.randrange(50)
|
||||
data = [random.randrange(25) for i in range(size)]
|
||||
if trial & 1: # Half of the time, use heapify
|
||||
heap = data[:]
|
||||
self.module.heapify(heap)
|
||||
@ -191,13 +180,12 @@ class _TestHeap:
|
||||
|
||||
def test_merge(self):
|
||||
inputs = []
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
for i in range(randrange(25)):
|
||||
row = []
|
||||
for j in range(randrange(100)):
|
||||
tup = random.choice('ABC'), randrange(-500, 500)
|
||||
row.append(tup)
|
||||
inputs.append(row)
|
||||
for i in range(random.randrange(25)):
|
||||
row = []
|
||||
for j in range(random.randrange(100)):
|
||||
tup = random.choice('ABC'), random.randrange(-500, 500)
|
||||
row.append(tup)
|
||||
inputs.append(row)
|
||||
|
||||
for key in [None, itemgetter(0), itemgetter(1), itemgetter(1, 0)]:
|
||||
for reverse in [False, True]:
|
||||
@ -224,14 +212,12 @@ class _TestHeap:
|
||||
list(self.module.merge(iterable(), iterable()))
|
||||
|
||||
def test_merge_stability(self):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class Int(int):
|
||||
pass
|
||||
class Int(int):
|
||||
pass
|
||||
inputs = [[], [], [], []]
|
||||
for i in range(20000):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
stream = randrange(4)
|
||||
x = randrange(500)
|
||||
stream = random.randrange(4)
|
||||
x = random.randrange(500)
|
||||
obj = Int(x)
|
||||
obj.pair = (x, stream)
|
||||
inputs[stream].append(obj)
|
||||
@ -241,8 +227,7 @@ class _TestHeap:
|
||||
self.assertEqual(result, sorted(result))
|
||||
|
||||
def test_nsmallest(self):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
data = [(randrange(2000), i) for i in range(1000)]
|
||||
data = [(random.randrange(2000), i) for i in range(1000)]
|
||||
for f in (None, lambda x: x[0] * 547 % 2000):
|
||||
for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100):
|
||||
self.assertEqual(list(self.module.nsmallest(n, data)),
|
||||
@ -251,8 +236,7 @@ class _TestHeap:
|
||||
sorted(data, key=f)[:n])
|
||||
|
||||
def test_nlargest(self):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
data = [(randrange(2000), i) for i in range(1000)]
|
||||
data = [(random.randrange(2000), i) for i in range(1000)]
|
||||
for f in (None, lambda x: x[0] * 547 % 2000):
|
||||
for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100):
|
||||
self.assertEqual(list(self.module.nlargest(n, data)),
|
||||
@ -267,18 +251,17 @@ class _TestHeap:
|
||||
data = [comp(x) for x in data]
|
||||
self.module.heapify(data)
|
||||
return [self.module.heappop(data).x for i in range(len(data))]
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class LT:
|
||||
def __init__(self, x):
|
||||
self.x = x
|
||||
def __lt__(self, other):
|
||||
return self.x > other.x
|
||||
class LE:
|
||||
def __init__(self, x):
|
||||
self.x = x
|
||||
def __le__(self, other):
|
||||
return self.x >= other.x
|
||||
data = [random.random() for i in range(100)]
|
||||
class LT:
|
||||
def __init__(self, x):
|
||||
self.x = x
|
||||
def __lt__(self, other):
|
||||
return self.x > other.x
|
||||
class LE:
|
||||
def __init__(self, x):
|
||||
self.x = x
|
||||
def __le__(self, other):
|
||||
return self.x >= other.x
|
||||
data = [random.random() for i in range(100)]
|
||||
target = sorted(data, reverse=True)
|
||||
self.assertEqual(hsort(data, LT), target)
|
||||
self.assertRaises(TypeError, data, LE)
|
||||
@ -455,11 +438,10 @@ class _TestErrorHandling:
|
||||
def test_comparison_operator_modifiying_heap(self):
|
||||
# See bpo-39421: Strong references need to be taken
|
||||
# when comparing objects as they can alter the heap
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class EvilClass(int):
|
||||
def __lt__(self, o):
|
||||
heap.clear()
|
||||
return NotImplemented
|
||||
class EvilClass(int):
|
||||
def __lt__(self, o):
|
||||
heap.clear()
|
||||
return NotImplemented
|
||||
|
||||
heap = []
|
||||
self.module.heappush(heap, EvilClass(0))
|
||||
@ -467,16 +449,15 @@ class _TestErrorHandling:
|
||||
|
||||
def test_comparison_operator_modifiying_heap_two_heaps(self):
|
||||
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class h(int):
|
||||
def __lt__(self, o):
|
||||
list2.clear()
|
||||
return NotImplemented
|
||||
class h(int):
|
||||
def __lt__(self, o):
|
||||
list2.clear()
|
||||
return NotImplemented
|
||||
|
||||
class g(int):
|
||||
def __lt__(self, o):
|
||||
list1.clear()
|
||||
return NotImplemented
|
||||
class g(int):
|
||||
def __lt__(self, o):
|
||||
list1.clear()
|
||||
return NotImplemented
|
||||
|
||||
list1, list2 = [], []
|
||||
|
||||
|
||||
@ -427,29 +427,17 @@ from user code:
|
||||
optree.tree_flatten_with_path(d)
|
||||
return torch.sin(x)
|
||||
|
||||
def post_munge(s):
|
||||
s = re.sub(
|
||||
r"optree\.\S*\.flatten_with_path",
|
||||
"optree.<path>.flatten_with_path",
|
||||
s,
|
||||
)
|
||||
return re.sub(
|
||||
r"qualname: \S*flatten_with_path",
|
||||
"qualname: <path>.flatten_with_path",
|
||||
s,
|
||||
)
|
||||
|
||||
fn(torch.randn(4))
|
||||
self.assertEqual(len(counters["graph_break"]), 1)
|
||||
first_graph_break = next(iter(counters["graph_break"].keys()))
|
||||
self.assertExpectedInline(
|
||||
post_munge(first_graph_break),
|
||||
first_graph_break,
|
||||
"""\
|
||||
Attempted to call function marked as skipped
|
||||
Explanation: Dynamo cannot trace optree C/C++ function optree.<path>.flatten_with_path.
|
||||
Explanation: Dynamo cannot trace optree C/C++ function optree._C.PyCapsule.flatten_with_path.
|
||||
Hint: Consider using torch.utils._pytree - https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py
|
||||
|
||||
Developer debug context: module: optree._C, qualname: <path>.flatten_with_path, skip reason: <missing reason>
|
||||
Developer debug context: module: optree._C, qualname: PyCapsule.flatten_with_path, skip reason: <missing reason>
|
||||
|
||||
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0007.html""",
|
||||
)
|
||||
|
||||
@ -69,7 +69,6 @@ from torch.fx.experimental.symbolic_shapes import (
|
||||
constrain_unify,
|
||||
ConstraintViolationError,
|
||||
expect_true,
|
||||
guard_or_false,
|
||||
guard_size_oblivious,
|
||||
ShapeEnv,
|
||||
)
|
||||
@ -101,6 +100,7 @@ from torch.testing._internal.common_utils import (
|
||||
wrapDeterministicFlagAPITest,
|
||||
)
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
from torch.testing._internal.logging_utils import logs_to_string
|
||||
|
||||
|
||||
pytree_modules = {
|
||||
@ -13636,74 +13636,6 @@ instantiate_device_type_tests(
|
||||
)
|
||||
|
||||
|
||||
class DynamoOpPromotionTests(torch._dynamo.test_case.TestCase):
|
||||
@unittest.skipIf(not TEST_CUDA, "This test requires a CUDA device")
|
||||
def test_symbool_tensor_mul(self):
|
||||
def symbool_mul_fn(x_bool, sentinel):
|
||||
result = x_bool * sentinel
|
||||
return result
|
||||
|
||||
x_true = torch.tensor([True], device="cuda")
|
||||
x_false = torch.tensor([False], device="cuda")
|
||||
sentinel = torch.tensor(2.0, requires_grad=True, device="cuda")
|
||||
eager_result_true = symbool_mul_fn(x_true, sentinel)
|
||||
eager_result_false = symbool_mul_fn(x_false, sentinel)
|
||||
compiled_fn = torch.compile(symbool_mul_fn, fullgraph=True, dynamic=True)
|
||||
compiled_result_true = compiled_fn(x_true, sentinel)
|
||||
compiled_result_false = compiled_fn(x_false, sentinel)
|
||||
self.assertEqual(eager_result_true, compiled_result_true)
|
||||
self.assertEqual(eager_result_false, compiled_result_false)
|
||||
self.assertEqual(compiled_result_true.item(), 2.0)
|
||||
self.assertEqual(compiled_result_false.item(), 0.0)
|
||||
|
||||
@unittest.skipIf(not TEST_CUDA, "This test requires a CUDA device")
|
||||
def test_symbool_guard_or_false(self):
|
||||
def symbool_guard_fn(a_bool_tensor, b):
|
||||
u0 = a_bool_tensor.item()
|
||||
# Make sure guard_or_false still handles SymBool produced by .item()
|
||||
if guard_or_false(u0):
|
||||
return b * 10
|
||||
else:
|
||||
return b * 100
|
||||
|
||||
compiled_guard_fn = torch.compile(
|
||||
symbool_guard_fn, backend="eager", dynamic=True
|
||||
)
|
||||
a_true = torch.tensor(True, device="cuda")
|
||||
a_false = torch.tensor(False, device="cuda")
|
||||
b = torch.randn(6, device="cuda")
|
||||
eager_res_true = symbool_guard_fn(a_true, b)
|
||||
compiled_res_true = compiled_guard_fn(a_true, b)
|
||||
self.assertEqual(eager_res_true, compiled_res_true)
|
||||
eager_res_false = symbool_guard_fn(a_false, b)
|
||||
compiled_res_false = compiled_guard_fn(a_false, b)
|
||||
self.assertEqual(eager_res_false, compiled_res_false)
|
||||
self.assertEqual(compiled_res_true, b * 10)
|
||||
self.assertEqual(compiled_res_false, b * 100)
|
||||
|
||||
@unittest.skipIf(not TEST_CUDA, "This test requires a CUDA device")
|
||||
def test_symbool_tensor_mul_does_not_fail(self):
|
||||
def fuzzed_program(arg_0, sentinel):
|
||||
var_node_2 = arg_0
|
||||
var_node_1 = torch.squeeze(var_node_2)
|
||||
var_node_0 = var_node_1.item()
|
||||
result = var_node_0 * sentinel
|
||||
if result.is_complex():
|
||||
result = result.real
|
||||
return result
|
||||
|
||||
sentinel = torch.tensor(1.0, requires_grad=True, device="cuda")
|
||||
arg_0 = torch.tensor([True], dtype=torch.bool, device="cuda")
|
||||
args = (arg_0,) + (sentinel,)
|
||||
try:
|
||||
compiled_program = torch.compile(
|
||||
fuzzed_program, fullgraph=True, dynamic=True
|
||||
)
|
||||
compiled_program(*args)
|
||||
except Exception as e:
|
||||
self.fail(f"torch.compile failed with error: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
||||
|
||||
@ -1000,18 +1000,6 @@ class ReproTests(torch._dynamo.test_case.TestCase):
|
||||
self.exit_stack.close()
|
||||
super().tearDown()
|
||||
|
||||
def test_compiled_module_truthiness(self):
|
||||
# Test with empty ModuleList
|
||||
original_empty = nn.ModuleList()
|
||||
compiled_empty = torch.compile(original_empty)
|
||||
self.assertEqual(bool(original_empty), bool(compiled_empty))
|
||||
self.assertFalse(bool(compiled_empty))
|
||||
# Test with non-empty ModuleList
|
||||
original_filled = nn.ModuleList([nn.Linear(10, 5)])
|
||||
compiled_filled = torch.compile(original_filled)
|
||||
self.assertEqual(bool(original_filled), bool(compiled_filled))
|
||||
self.assertTrue(bool(compiled_filled))
|
||||
|
||||
def guard_manager_clone_hook_fn(self, guard_manager_wrapper, f_locals, builder):
|
||||
root = guard_manager_wrapper.root
|
||||
cloned_root = root.clone_manager(lambda x: True)
|
||||
|
||||
@ -751,29 +751,6 @@ class TestConstFold(TestCase):
|
||||
)
|
||||
self.assertIsNone(mod_folded.const_subgraph_module)
|
||||
|
||||
def test_const_fold_partial_graph(self):
|
||||
"""
|
||||
If a model graph is partially const folded,
|
||||
the non-const subgraph should be inlined back and erased.
|
||||
"""
|
||||
|
||||
class TestModule(torch.nn.Module):
|
||||
def __init__(self, p):
|
||||
super().__init__()
|
||||
self.p = p
|
||||
|
||||
def forward(self, x):
|
||||
probs = torch.empty_permuted(x.shape, [0, 1])
|
||||
mask = torch.bernoulli(probs, 1 - self.p)
|
||||
return x * mask / (1 - self.p)
|
||||
|
||||
ep = torch.export.export(TestModule(0.4), (torch.randn(5, 10),))
|
||||
|
||||
mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(
|
||||
ep.module(), device_for_folded_attrs="cpu"
|
||||
)
|
||||
self._verify_const_fold_mod(mod_folded)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_fx.py")
|
||||
|
||||
@ -20,14 +20,8 @@ from torch.testing._internal.common_device_type import (
|
||||
dtypes,
|
||||
instantiate_device_type_tests,
|
||||
skipIf,
|
||||
skipXPUIf,
|
||||
)
|
||||
from torch.testing._internal.common_utils import (
|
||||
parametrize,
|
||||
run_tests,
|
||||
TEST_WITH_SLOW,
|
||||
TestCase,
|
||||
)
|
||||
from torch.testing._internal.common_utils import parametrize, run_tests, TestCase
|
||||
from torch.testing._internal.inductor_utils import IS_BIG_GPU
|
||||
|
||||
|
||||
@ -388,11 +382,7 @@ class TestAnalysis(TestCase):
|
||||
|
||||
verify_triton(comp_omni)
|
||||
|
||||
@skipIf(
|
||||
(not torch.xpu.is_available()) and (not SM80OrLater),
|
||||
"Requires XPU or CUDA SM80",
|
||||
)
|
||||
@skipXPUIf(TEST_WITH_SLOW, "Skip because test too slow on XPU")
|
||||
@skipIf(not SM80OrLater, "Requires SM80")
|
||||
@dtypes(torch.float, torch.float16)
|
||||
@parametrize(
|
||||
"maxat",
|
||||
@ -477,7 +467,6 @@ class TestAnalysis(TestCase):
|
||||
"aten::cudnn_convolution",
|
||||
"aten::convolution",
|
||||
"aten::_convolution",
|
||||
"aten::convolution_overrideable",
|
||||
)
|
||||
)
|
||||
or "conv" in name
|
||||
|
||||
@ -4,7 +4,6 @@ import os
|
||||
import tempfile
|
||||
from threading import Event
|
||||
|
||||
import torch._inductor.config as config
|
||||
from torch._inductor.compile_worker.subproc_pool import (
|
||||
raise_testexc,
|
||||
SubprocException,
|
||||
@ -17,12 +16,9 @@ from torch.testing._internal.inductor_utils import HAS_CPU
|
||||
|
||||
|
||||
class TestCompileWorker(TestCase):
|
||||
def make_pool(self, size):
|
||||
return SubprocPool(size)
|
||||
|
||||
@skipIfWindows(msg="pass_fds not supported on Windows.")
|
||||
def test_basic_jobs(self):
|
||||
pool = self.make_pool(2)
|
||||
pool = SubprocPool(2)
|
||||
try:
|
||||
a = pool.submit(operator.add, 100, 1)
|
||||
b = pool.submit(operator.sub, 100, 1)
|
||||
@ -33,7 +29,7 @@ class TestCompileWorker(TestCase):
|
||||
|
||||
@skipIfWindows(msg="pass_fds not supported on Windows.")
|
||||
def test_exception(self):
|
||||
pool = self.make_pool(2)
|
||||
pool = SubprocPool(2)
|
||||
try:
|
||||
a = pool.submit(raise_testexc)
|
||||
with self.assertRaisesRegex(
|
||||
@ -46,7 +42,7 @@ class TestCompileWorker(TestCase):
|
||||
|
||||
@skipIfWindows(msg="pass_fds not supported on Windows.")
|
||||
def test_crash(self):
|
||||
pool = self.make_pool(2)
|
||||
pool = SubprocPool(2)
|
||||
try:
|
||||
with self.assertRaises(Exception):
|
||||
a = pool.submit(os._exit, 1)
|
||||
@ -62,7 +58,7 @@ class TestCompileWorker(TestCase):
|
||||
|
||||
@skipIfWindows(msg="pass_fds not supported on Windows.")
|
||||
def test_quiesce(self):
|
||||
pool = self.make_pool(2)
|
||||
pool = SubprocPool(2)
|
||||
try:
|
||||
a = pool.submit(operator.add, 100, 1)
|
||||
pool.quiesce()
|
||||
@ -79,7 +75,7 @@ class TestCompileWorker(TestCase):
|
||||
os.environ["ROLE_RANK"] = "0"
|
||||
with tempfile.NamedTemporaryFile(delete=True) as temp_log:
|
||||
os.environ["TORCHINDUCTOR_WORKER_LOGPATH"] = temp_log.name
|
||||
pool = self.make_pool(2)
|
||||
pool = SubprocPool(2)
|
||||
try:
|
||||
pool.submit(operator.add, 100, 1)
|
||||
self.assertEqual(os.path.exists(temp_log.name), True)
|
||||
@ -87,12 +83,6 @@ class TestCompileWorker(TestCase):
|
||||
pool.shutdown()
|
||||
|
||||
|
||||
@config.patch("quiesce_async_compile_time", 0.1)
|
||||
class TestCompileWorkerWithTimer(TestCompileWorker):
|
||||
def make_pool(self, size):
|
||||
return SubprocPool(size, quiesce=True)
|
||||
|
||||
|
||||
class TestTimer(TestCase):
|
||||
def test_basics(self):
|
||||
done = Event()
|
||||
|
||||
@ -1,154 +0,0 @@
|
||||
# Owner(s): ["module: inductor"]
|
||||
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch._inductor import config
|
||||
from torch._inductor.codegen.cuda.cuda_env import is_datacenter_blackwell_arch
|
||||
from torch._inductor.test_case import run_tests, TestCase as InductorTestCase
|
||||
from torch._inductor.utils import ensure_cute_available
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
parametrize,
|
||||
)
|
||||
|
||||
|
||||
@unittest.skipIf(
|
||||
not (ensure_cute_available() and is_datacenter_blackwell_arch()),
|
||||
"CuTeDSL library or Blackwell device not available",
|
||||
)
|
||||
@instantiate_parametrized_tests
|
||||
class TestCuTeDSLGroupedGemm(InductorTestCase):
|
||||
def _get_inputs(
|
||||
self,
|
||||
group_size: int,
|
||||
M_hint: int,
|
||||
K: int,
|
||||
N: int,
|
||||
device: str,
|
||||
dtype: torch.dtype,
|
||||
alignment: int = 16,
|
||||
) -> tuple[Tensor, Tensor, Tensor]:
|
||||
# --- Random, tile-aligned M sizes ---
|
||||
M_sizes = (
|
||||
torch.randint(1, (M_hint // alignment) + 1, (group_size,), dtype=torch.int)
|
||||
* alignment
|
||||
)
|
||||
|
||||
M_total = torch.sum(M_sizes).item()
|
||||
|
||||
# --- Construct input tensors ---
|
||||
A = torch.randn(int(M_total), K, dtype=dtype, device=device) * 0.1
|
||||
B = torch.randn((group_size, K, N), dtype=dtype, device=device) * 0.01
|
||||
|
||||
# --- Build offsets (no leading zero, strictly increasing) ---
|
||||
offsets = torch.cumsum(M_sizes, dim=0).to(dtype=torch.int32, device=device)
|
||||
|
||||
return (A, B, offsets)
|
||||
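For reference, the offsets convention built by _get_inputs (cumulative group sizes, no leading zero, strictly increasing) can be reproduced in isolation; the group sizes below are made-up example values:

import torch

M_sizes = torch.tensor([128, 256, 64], dtype=torch.int32)  # per-group row counts
offsets = torch.cumsum(M_sizes, dim=0).to(torch.int32)     # tensor([128, 384, 448])
assert int(offsets[-1]) == int(M_sizes.sum())               # last offset equals total rows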
|
||||
@parametrize("group_size", (2, 8))
|
||||
@parametrize("M_hint", (256, 1024))
|
||||
@parametrize("K", (64, 128))
|
||||
@parametrize("N", (128, 256))
|
||||
def test_grouped_gemm_basic(self, group_size: int, M_hint: int, K: int, N: int):
|
||||
device = "cuda"
|
||||
dtype = torch.bfloat16
|
||||
|
||||
A, B, offsets = self._get_inputs(group_size, M_hint, K, N, device, dtype)
|
||||
|
||||
def grouped_gemm_fn(A_packed, B_batched, offs):
|
||||
return torch._grouped_mm(A_packed, B_batched, offs=offs)
|
||||
|
||||
# Eager execution
|
||||
c_eager = grouped_gemm_fn(A, B, offsets)
|
||||
|
||||
# Test with Cute backend
|
||||
with config.patch(
|
||||
{
|
||||
"max_autotune": True,
|
||||
"max_autotune_gemm_backends": "CUTEDSL",
|
||||
"test_configs.autotune_choice_name_regex": "cutedsl",
|
||||
"autotune_fallback_to_aten": False,
|
||||
}
|
||||
):
|
||||
grouped_gemm_compiled = torch.compile(
|
||||
grouped_gemm_fn, backend="inductor", dynamic=False
|
||||
)
|
||||
c_compiled = grouped_gemm_compiled(A, B, offsets)
|
||||
|
||||
self.assertEqual(c_eager.dtype, dtype)
|
||||
self.assertEqual(c_compiled.dtype, dtype)
|
||||
torch.testing.assert_close(c_eager, c_compiled)
|
||||
|
||||
@parametrize("layout_A", ("contiguous", "offset", "padded", "view"))
|
||||
@parametrize("layout_B", ("contiguous", "broadcasted"))
|
||||
def test_grouped_gemm_assorted_layouts(
|
||||
self,
|
||||
layout_A: str,
|
||||
layout_B: str,
|
||||
):
|
||||
device = "cuda"
|
||||
dtype = torch.bfloat16
|
||||
|
||||
G, K, N = 8, 64, 128
|
||||
M_sizes = [128] * G
|
||||
sum_M = sum(M_sizes)
|
||||
offsets = torch.tensor(
|
||||
[sum(M_sizes[: i + 1]) for i in range(G)], dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
A_base = torch.randn(sum_M, K, device=device, dtype=dtype)
|
||||
A = A_base
|
||||
|
||||
if layout_A == "offset":
|
||||
# allocate bigger buffer than needed, use nonzero storage offset
|
||||
storage = torch.randn(sum_M * K + 512, device=device, dtype=dtype)
|
||||
offset = 128 # skip first 128 elements
|
||||
A = torch.as_strided(storage[offset:], (sum_M, K), (K, 1))
|
||||
elif layout_A == "padded":
|
||||
# simulate row pitch > K (row_stride = K + pad)
|
||||
row_pitch = K + 8
|
||||
storage = torch.randn(sum_M * row_pitch, device=device, dtype=dtype)
|
||||
A = torch.as_strided(storage, (sum_M, K), (row_pitch, 1))
|
||||
elif layout_A == "view":
|
||||
A_storage = torch.randn(sum_M * K, device=device, dtype=dtype)
|
||||
A = A_storage.view(sum_M, K)
|
||||
assert A._base is not None
|
||||
assert A.shape == (sum_M, K)
|
||||
|
||||
B = torch.randn((G, K, N), dtype=dtype, device=device) * 0.01
|
||||
|
||||
if layout_B == "broadcasted":
|
||||
# Broadcast B across groups (zero stride along G)
|
||||
B = B[0].expand(G, K, N)
|
||||
assert B.stride(0) == 0
|
||||
|
||||
def grouped_gemm_fn(A_packed, B_batched, offs):
|
||||
return torch._grouped_mm(A_packed, B_batched, offs=offs)
|
||||
|
||||
# --- eager ---
|
||||
c_eager = grouped_gemm_fn(A, B, offsets)
|
||||
|
||||
# --- compiled (CUTE backend) ---
|
||||
with config.patch(
|
||||
{
|
||||
"max_autotune": True,
|
||||
"max_autotune_gemm_backends": "CUTEDSL",
|
||||
"test_configs.autotune_choice_name_regex": "cutedsl",
|
||||
"autotune_fallback_to_aten": False,
|
||||
}
|
||||
):
|
||||
grouped_gemm_compiled = torch.compile(
|
||||
grouped_gemm_fn, backend="inductor", dynamic=False
|
||||
)
|
||||
c_compiled = grouped_gemm_compiled(A, B, offsets)
|
||||
|
||||
self.assertEqual(c_eager.dtype, dtype)
|
||||
self.assertEqual(c_compiled.dtype, dtype)
|
||||
torch.testing.assert_close(c_eager, c_compiled)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
@ -15,8 +15,9 @@ from torch.testing._internal.common_utils import (
|
||||
is_navi3_arch,
|
||||
parametrize,
|
||||
patch_test_members,
|
||||
TEST_XPU,
|
||||
)
|
||||
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU_AND_TRITON
|
||||
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CUDA_AND_TRITON
|
||||
from torch.testing._internal.triton_utils import requires_gpu
|
||||
|
||||
|
||||
@ -60,6 +61,11 @@ class TestDecomposeAddMM(torch.nn.Module):
|
||||
|
||||
|
||||
@requires_gpu
|
||||
@unittest.skipIf(
|
||||
TEST_XPU,
|
||||
"Intel GPU has not enabled decompose_mem_bound_mm PASS in "
|
||||
"torch/_inductor/fx_passes/decompose_mem_bound_mm.py",
|
||||
)
|
||||
@torch._inductor.config.patch(
|
||||
post_grad_fusion_options={
|
||||
"decompose_mm_pass": {},
|
||||
@ -138,7 +144,7 @@ class TestDecomposeMemMM(TestCase):
|
||||
|
||||
self.compare_pred(module, traced, input)
|
||||
|
||||
expected_val = 1 if should_decompose and HAS_GPU_AND_TRITON else 0
|
||||
expected_val = 1 if should_decompose and HAS_CUDA_AND_TRITON else 0
|
||||
self.assertEqual(
|
||||
counters["inductor"]["decompose_bmm"],
|
||||
expected_val,
|
||||
@ -149,7 +155,7 @@ class TestDecomposeMemMM(TestCase):
|
||||
self.compare_parameters(module, traced)
|
||||
self.compare_gradients(module, traced)
|
||||
|
||||
expected_val = 3 if should_decompose and HAS_GPU_AND_TRITON else 0
|
||||
expected_val = 3 if should_decompose and HAS_CUDA_AND_TRITON else 0
|
||||
self.assertEqual(
|
||||
counters["inductor"]["decompose_bmm"],
|
||||
expected_val,
|
||||
@ -198,7 +204,7 @@ class TestDecomposeMemMM(TestCase):
|
||||
|
||||
self.compare_pred(module, traced, input)
|
||||
|
||||
expected_val = 1 if should_decompose and HAS_GPU_AND_TRITON else 0
|
||||
expected_val = 1 if should_decompose and HAS_CUDA_AND_TRITON else 0
|
||||
if has_bias:
|
||||
self.assertEqual(
|
||||
counters["inductor"]["decompose_addmm"],
|
||||
@ -253,7 +259,7 @@ class TestDecomposeMemMM(TestCase):
|
||||
|
||||
self.compare_pred(module, traced, input)
|
||||
|
||||
expected_val = 1 if should_decompose and HAS_GPU_AND_TRITON else 0
|
||||
expected_val = 1 if should_decompose and HAS_CUDA_AND_TRITON else 0
|
||||
if has_bias:
|
||||
self.assertEqual(
|
||||
counters["inductor"]["decompose_addmm"],
|
||||
@ -298,7 +304,7 @@ class TestDecomposeMemMM(TestCase):
|
||||
|
||||
self.compare_pred(module, traced, input)
|
||||
|
||||
expected_val = 1 if should_decompose and HAS_GPU_AND_TRITON else 0
|
||||
expected_val = 1 if should_decompose and HAS_CUDA_AND_TRITON else 0
|
||||
self.assertEqual(
|
||||
counters["inductor"]["decompose_mm"],
|
||||
expected_val,
|
||||
@ -310,7 +316,7 @@ class TestDecomposeMemMM(TestCase):
|
||||
self.compare_parameters(module, traced)
|
||||
self.compare_gradients(module, traced)
|
||||
|
||||
expected_val = 1 if should_decompose and HAS_GPU_AND_TRITON else 0
|
||||
expected_val = 1 if should_decompose and HAS_CUDA_AND_TRITON else 0
|
||||
self.assertEqual(
|
||||
counters["inductor"]["decompose_mm"] - decompose_mm_fwd,
|
||||
expected_val,
|
||||
@ -368,7 +374,7 @@ class TestDecomposeMemMM(TestCase):
|
||||
|
||||
self.compare_pred(module, traced, input)
|
||||
|
||||
expected_val = 1 if should_decompose and HAS_GPU_AND_TRITON else 0
|
||||
expected_val = 1 if should_decompose and HAS_CUDA_AND_TRITON else 0
|
||||
self.assertEqual(
|
||||
counters["inductor"]["decompose_mm"],
|
||||
expected_val,
|
||||
@ -380,7 +386,7 @@ class TestDecomposeMemMM(TestCase):
|
||||
self.compare_parameters(module, traced)
|
||||
self.compare_gradients(module, traced)
|
||||
|
||||
expected_val = 1 if should_decompose and HAS_GPU_AND_TRITON else 0
|
||||
expected_val = 1 if should_decompose and HAS_CUDA_AND_TRITON else 0
|
||||
self.assertEqual(
|
||||
counters["inductor"]["decompose_mm"] - decompose_mm_fwd,
|
||||
expected_val,
|
||||
@ -404,7 +410,7 @@ class TestDecomposeMemMM(TestCase):
|
||||
|
||||
self.compare_pred(module, traced, input)
|
||||
|
||||
expected_val = 1 if should_decompose and HAS_GPU_AND_TRITON else 0
|
||||
expected_val = 1 if should_decompose and HAS_CUDA_AND_TRITON else 0
|
||||
if has_bias:
|
||||
self.assertEqual(
|
||||
counters["inductor"]["decompose_addmm"],
|
||||
@ -418,7 +424,7 @@ class TestDecomposeMemMM(TestCase):
|
||||
self.compare_gradients(module, traced)
|
||||
|
||||
expected_val = 0
|
||||
if HAS_GPU_AND_TRITON:
|
||||
if HAS_CUDA_AND_TRITON:
|
||||
expected_val = 1 if has_bias else 2
|
||||
|
||||
self.assertEqual(
|
||||
@ -441,8 +447,12 @@ class TestDecomposeMemMM(TestCase):
|
||||
|
||||
_, code = run_and_get_code(foo, input1, input2)
|
||||
|
||||
# two kernels generated
|
||||
FileCheck().check_count(".run(", 2, exactly=True).run(code[0])
|
||||
if GPU_TYPE == "xpu":
|
||||
# only 1 kernel generated on the XPU stack
|
||||
FileCheck().check_count(".run(", 1, exactly=True).run(code[0])
|
||||
else:
|
||||
# two kernels generated
|
||||
FileCheck().check_count(".run(", 2, exactly=True).run(code[0])
|
||||
|
||||
def test_check_device(self):
|
||||
m = 5
|
||||
@ -452,7 +462,7 @@ class TestDecomposeMemMM(TestCase):
|
||||
|
||||
input1 = torch.randn(m, k, device=GPU_TYPE)
|
||||
input2 = torch.randn(k, n, device=GPU_TYPE)
|
||||
self.assertTrue(check_device(input1, input2, device=GPU_TYPE))
|
||||
self.assertTrue(check_device(input1, input2))
|
||||
self.assertFalse(check_device(input1, input2, device="cpu"))
|
||||
|
||||
input1 = torch.randn(m, k)
|
||||
|
||||
@ -806,6 +806,8 @@ class AOTFxirTestCase(InductorTestCase):
|
||||
def check(
|
||||
self, model, inp, dynamic_shapes=None, strict=False
|
||||
) -> torch.fx.GraphModule:
|
||||
if self.device == "xpu":
|
||||
raise unittest.SkipTest("The feature AOTFxir not currently ready for XPU")
|
||||
with torch.no_grad():
|
||||
ep = torch.export.export(
|
||||
model, inp, dynamic_shapes=dynamic_shapes, strict=strict
|
||||
|
||||
@ -500,13 +500,8 @@ class PaddingTest(TestCaseBase):
|
||||
forward_wrapper = wrapper_codes[0]
|
||||
|
||||
# make sure the load for softmax is aligned
|
||||
if bias:
|
||||
# addmm -> mm + bias and bias is fused with softmax
|
||||
softmax_load_str = "tl.load(in_out_ptr0 + (r0_1 + 30528*x0)"
|
||||
else:
|
||||
softmax_load_str = "tl.load(in_ptr0 + (r0_1 + 30528*x0)"
|
||||
self.assertTrue(
|
||||
softmax_load_str in forward_wrapper,
|
||||
"tl.load(in_ptr0 + (r0_1 + 30528*x0)" in forward_wrapper,
|
||||
f"forward_wrapper: {forward_wrapper}",
|
||||
)
|
||||
|
||||
|
||||
@ -1826,14 +1826,9 @@ def run_test_module(
|
||||
test_name = test.name
|
||||
|
||||
# Printing the date here can help diagnose which tests are slow
|
||||
start = time.perf_counter()
|
||||
print_to_stderr(f"Running {str(test)} ... [{datetime.now()}][{start}]")
|
||||
print_to_stderr(f"Running {str(test)} ... [{datetime.now()}]")
|
||||
handler = CUSTOM_HANDLERS.get(test_name, run_test)
|
||||
return_code = handler(test, test_directory, options)
|
||||
end = time.perf_counter()
|
||||
print_to_stderr(
|
||||
f"Finished {str(test)} ... [{datetime.now()}][{end}], took {(end - start) / 60:.2f}min"
|
||||
)
|
||||
assert isinstance(return_code, int) and not isinstance(return_code, bool), (
|
||||
f"While running {str(test)} got non integer return code {return_code}"
|
||||
)
|
||||
|
||||
@ -35,6 +35,7 @@ from torch.cuda._memory_viz import (
|
||||
from torch.testing._internal.autocast_test_lists import AutocastTestLists, TestAutocast
|
||||
from torch.testing._internal.common_cuda import (
|
||||
_create_scaling_case,
|
||||
HAS_WORKING_NVML,
|
||||
SM70OrLater,
|
||||
TEST_CUDNN,
|
||||
TEST_MULTIGPU,
|
||||
@ -4803,6 +4804,7 @@ print(torch.cuda.get_allocator_backend())
|
||||
def test_temperature(self):
|
||||
self.assertTrue(0 <= torch.cuda.temperature() <= 150)
|
||||
|
||||
@unittest.skipIf(not HAS_WORKING_NVML, "pynvml available but broken")
|
||||
@unittest.skipIf(TEST_WITH_ROCM, "flaky for AMD gpu")
|
||||
@unittest.skipIf(not TEST_PYNVML, "pynvml/amdsmi is not available")
|
||||
def test_device_memory_used(self):
|
||||
@ -7413,140 +7415,6 @@ class TestCudaDeviceParametrized(TestCase):
|
||||
)
|
||||
|
||||
|
||||
class TestFXMemoryProfiler(TestCase):
|
||||
"""Tests for memory profiler augmentation with original stack traces."""
|
||||
|
||||
def collect_frames(
|
||||
self, augmented_snapshot, collect_device_traces=True, collect_segments=True
|
||||
):
|
||||
"""Collects all frames that has node metadata from a memory snapshot."""
|
||||
# Collect all frames with FX metadata
|
||||
fx_frames = []
|
||||
|
||||
# Check device traces for FX debug fields
|
||||
if collect_device_traces and "device_traces" in augmented_snapshot:
|
||||
for trace_list in augmented_snapshot["device_traces"]:
|
||||
for trace_entry in trace_list:
|
||||
if isinstance(trace_entry, dict) and "frames" in trace_entry:
|
||||
for frame in trace_entry["frames"]:
|
||||
if isinstance(frame, dict):
|
||||
# Check for FX debug fields
|
||||
if "fx_node_op" in frame or "fx_node_name" in frame:
|
||||
fx_frames.append(frame)
|
||||
|
||||
# Check segments/blocks for FX debug fields
|
||||
if collect_segments and "segments" in augmented_snapshot:
|
||||
for segment in augmented_snapshot["segments"]:
|
||||
if "blocks" in segment:
|
||||
for block in segment["blocks"]:
|
||||
if "frames" in block:
|
||||
for frame in block["frames"]:
|
||||
if isinstance(frame, dict):
|
||||
if "fx_node_op" in frame or "fx_node_name" in frame:
|
||||
fx_frames.append(frame)
|
||||
return fx_frames
|
||||
|
||||
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
|
||||
@torch._dynamo.config.patch("enrich_profiler_metadata", True)
|
||||
def test_fx_memory_profiler_augmentation(self):
|
||||
"""Test that memory snapshots are augmented with FX debug information."""
|
||||
|
||||
# Create a simple model
|
||||
class MLPModule(nn.Module):
|
||||
def __init__(self, device):
|
||||
super().__init__()
|
||||
torch.manual_seed(5)
|
||||
self.net1 = nn.Linear(10, 16, bias=True, device=device)
|
||||
self.relu = nn.ReLU()
|
||||
self.net2 = nn.Linear(16, 10, bias=True, device=device)
|
||||
|
||||
def forward(self, x):
|
||||
a = self.net1(x)
|
||||
b = self.relu(a)
|
||||
c = self.net2(b)
|
||||
return c
|
||||
|
||||
device = "cuda"
|
||||
mod = MLPModule(device)
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
torch.cuda.memory._record_memory_history()
|
||||
compiled = torch.compile(mod, backend="aot_eager", fullgraph=True)
|
||||
result = compiled(torch.randn(10, 10, device=device))
|
||||
augmented_snapshot = torch.cuda.memory._snapshot(
|
||||
augment_with_fx_traces=True
|
||||
)
|
||||
torch.cuda.memory._record_memory_history(enabled=None, clear_history=True)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
fx_frames = self.collect_frames(augmented_snapshot)
|
||||
if TEST_WITH_ROCM:
|
||||
self.assertGreater(len(fx_frames), 0)
|
||||
else:
|
||||
self.assertEqual(len(fx_frames), 12)
|
||||
|
||||
for frame in fx_frames:
|
||||
# Every FX frame should have both node_op and node_name
|
||||
self.assertIn("fx_node_op", frame)
|
||||
self.assertIn("fx_node_name", frame)
|
||||
self.assertIn("fx_node_target", frame)
|
||||
self.assertIn("fx_original_trace", frame)
|
||||
|
||||
self.assertIn(frame["fx_node_name"], ["addmm", "relu", "addmm_1"])
|
||||
fx_node_name = frame["fx_node_name"]
|
||||
if fx_node_name == "addmm":
|
||||
self.assertIn("a = self.net1(x)", frame["fx_original_trace"])
|
||||
elif fx_node_name == "addmm_1":
|
||||
self.assertIn("c = self.net2(b)", frame["fx_original_trace"])
|
||||
elif fx_node_name == "relu":
|
||||
self.assertIn("b = self.relu(a)", frame["fx_original_trace"])
|
||||
|
||||
# Test that when we have two graphs with the same src_code, they're not hashed
|
||||
# to the same metadata
|
||||
class MLPModule2(nn.Module):
|
||||
def __init__(self, device):
|
||||
super().__init__()
|
||||
torch.manual_seed(5)
|
||||
self.net1 = nn.Linear(10, 16, bias=True, device=device)
|
||||
self.relu = nn.ReLU()
|
||||
self.net2 = nn.Linear(16, 10, bias=True, device=device)
|
||||
|
||||
def forward(self, x):
|
||||
d = self.net1(x)
|
||||
e = self.relu(d)
|
||||
f = self.net2(e)
|
||||
return f
|
||||
|
||||
mod = MLPModule2(device)
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
torch.cuda.memory._record_memory_history()
|
||||
compiled = torch.compile(mod, backend="aot_eager", fullgraph=True)
|
||||
result = compiled(torch.randn(10, 10, device=device))
|
||||
augmented_snapshot = torch.cuda.memory._snapshot(
|
||||
augment_with_fx_traces=True
|
||||
)
|
||||
torch.cuda.memory._record_memory_history(enabled=None, clear_history=True)
|
||||
|
||||
# avoid collecting segments from previous run for unit test purpose
|
||||
fx_frames = self.collect_frames(augmented_snapshot, collect_segments=False)
|
||||
self.assertGreater(len(fx_frames), 0)
|
||||
|
||||
for frame in fx_frames:
|
||||
# Every FX frame should have both node_op and node_name
|
||||
self.assertIn("fx_node_op", frame)
|
||||
self.assertIn("fx_node_name", frame)
|
||||
self.assertIn("fx_node_target", frame)
|
||||
self.assertIn("fx_original_trace", frame)
|
||||
|
||||
self.assertIn(frame["fx_node_name"], ["addmm", "relu", "addmm_1"])
|
||||
fx_node_name = frame["fx_node_name"]
|
||||
if fx_node_name == "addmm":
|
||||
self.assertIn("d = self.net1(x)", frame["fx_original_trace"])
|
||||
elif fx_node_name == "addmm_1":
|
||||
self.assertIn("f = self.net2(e)", frame["fx_original_trace"])
|
||||
elif fx_node_name == "relu":
|
||||
self.assertIn("e = self.relu(d)", frame["fx_original_trace"])
|
||||
|
||||
|
||||
instantiate_parametrized_tests(TestCuda)
|
||||
instantiate_parametrized_tests(TestCudaMallocAsync)
|
||||
instantiate_parametrized_tests(TestCompileKernel)
|
||||
|
||||
@ -771,7 +771,6 @@ class TestFX(JitTestCase):
|
||||
gm = GraphModule(tracer.root, graph)
|
||||
expected = {1: 2, 2: 3, 3: 4, 4: 5}
|
||||
self.assertTrue(set(expected.items()).issubset(set(gm._lineno_map.items())))
|
||||
self.assertEqual(gm._prologue_start, 4)
|
||||
|
||||
# test custom codegen
|
||||
def transform_code(code):
|
||||
@ -781,7 +780,6 @@ class TestFX(JitTestCase):
|
||||
gm.recompile()
|
||||
expected = {2: 2, 3: 3, 4: 4, 5: 5}
|
||||
self.assertTrue(set(expected.items()).issubset(set(gm._lineno_map.items())))
|
||||
self.assertEqual(gm._prologue_start, 4)
|
||||
|
||||
def test_graph_unique_names_manual(self):
|
||||
graph: torch.fx.Graph = torch.fx.Graph()
|
||||
|
||||
@ -11,7 +11,7 @@ from typing import Optional
|
||||
import torch
|
||||
|
||||
|
||||
from torch.nn.functional import pad, scaled_mm, scaled_grouped_mm, ScalingType, SwizzleType
|
||||
from torch.nn.functional import scaled_mm, scaled_grouped_mm, ScalingType, SwizzleType
|
||||
from torch.testing._internal.common_cuda import (
|
||||
IS_SM90,
|
||||
_get_torch_cuda_version,
|
||||
@ -107,76 +107,11 @@ def tensor_to_scale_block(
|
||||
x = x.unflatten(1, (-1, block_inner)).unflatten(0, (-1, block_outer))
|
||||
amax = x.abs().amax(dim=[1, 3], keepdim=True).float()
|
||||
scale = torch.finfo(float8_dtype).max / amax
|
||||
# if amax == 0, entire block = 0, set scale = 0 to ensure elements are
|
||||
# zero'd out correctly (and remove bad effects of / 0)
|
||||
scale[amax == 0] = 0
|
||||
|
||||
# Scale x, noting that blocks where amax == 0 are explicitly 0 now.
|
||||
x = x.mul(scale).to(float8_dtype)
|
||||
|
||||
# if amax == 0, all values in the block are 0, scale=0
|
||||
# but we need scale.reciprocal later, which breaks when scale=0...
|
||||
# So. Replace 0 -> 1 in the scale so we don't break things later.
|
||||
# Elements are already zeroed, so don't actually care what the scale
|
||||
# is, as long as it's not inf/nan.
|
||||
scale[scale == 0] = 1.
|
||||
|
||||
x = x.flatten(2, 3).flatten(0, 1)
|
||||
scale = scale.flatten(2, 3).flatten(0, 1)
|
||||
return x, scale
|
||||
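A minimal sketch of the 1x128 block-scaling idea behind tensor_to_scale_block, on a toy CPU tensor (shapes chosen purely for illustration; the all-zero-block handling above is omitted here):

import torch

x = torch.randn(2, 256)
blocks = x.unflatten(1, (-1, 128))                    # [2, 2, 128]: one block per 128 columns
amax = blocks.abs().amax(dim=-1, keepdim=True)        # per-block max magnitude (assumed nonzero)
scale = torch.finfo(torch.float8_e4m3fn).max / amax   # stretch each block to the fp8 range
x_fp8 = (blocks * scale).to(torch.float8_e4m3fn).flatten(1, 2)
x_back = (x_fp8.float().unflatten(1, (-1, 128)) / scale).flatten(1, 2)  # approximate round-trip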
|
||||
def hp_from_128x128(x_lp, x_scale):
|
||||
orig_shape = x_lp.shape
|
||||
M, K = orig_shape
|
||||
x_lp = x_lp.view(M // 128, 128, K // 128, 128)
|
||||
x_scale = x_scale.unsqueeze(1).unsqueeze(-1)
|
||||
x_hp = x_lp.to(torch.float32)
|
||||
x_hp = x_hp / x_scale
|
||||
return x_hp.reshape(orig_shape).to(torch.bfloat16)
|
||||
|
||||
def hp_to_128x128(x, x_scale):
|
||||
orig_shape = x.shape
|
||||
M, K = orig_shape
|
||||
x = x.view(M // 128, 128, K // 128, 128)
|
||||
x_scale = x_scale.unsqueeze(1).unsqueeze(-1)
|
||||
x_lp = x * x_scale
|
||||
|
||||
return x_lp.reshape(orig_shape).to(torch.float8_e4m3fn)
|
||||
|
||||
def hp_from_1x128(x_lp, x_scale):
|
||||
orig_shape = x_lp.shape
|
||||
x_lp = x_lp.reshape(x_lp.shape[0], x_lp.shape[-1] // 128, 128)
|
||||
x_hp = x_lp.to(torch.float32)
|
||||
x_hp = x_hp / x_scale.unsqueeze(-1)
|
||||
return x_hp.reshape(orig_shape).to(torch.bfloat16)
|
||||
|
||||
def hp_to_1x128(x, x_scale):
|
||||
orig_shape = x.shape
|
||||
x = x.reshape(x.shape[0], x.shape[-1] // 128, 128)
|
||||
x_lp = x * x_scale.unsqueeze(-1)
|
||||
return x_lp.reshape(orig_shape).to(torch.float8_e4m3fn)
|
||||
|
||||
|
||||
# cublas requires specific padding for 128x128 scales, see:
|
||||
# https://docs.nvidia.com/cuda/cublas/#element-1d-and-128x128-2d-block-scaling-for-fp8-data-types
|
||||
# Notably L = ceil_div(K, 128),
|
||||
# L4 = round_up(L, 4),
|
||||
# and then for A/B the shape must be
|
||||
# scale: [L4, ceil_div({M,N}, 128) and K/L/L4-major in memory.
|
||||
#
|
||||
# This routine pads L -> L4
|
||||
def _pad_128x128_scales(scale: torch.Tensor) -> (torch.Tensor, int):
|
||||
# scale is either [L4, ceil_div(M, 128)] or [L4, ceil_div(N, 128)], stride: [1, L4]
|
||||
# However, we get passed it as [ceil_div(M, 128), L] or [ceil_div(N, 128), L]
|
||||
# so check inner dim % 4, and pad if necessary
|
||||
pad_amount = scale.shape[-1] % 4
|
||||
|
||||
if pad_amount == 0:
|
||||
return scale, 0
|
||||
else:
|
||||
pad_amount = 4 - pad_amount
|
||||
return pad(scale, (0, pad_amount), "constant", 0), pad_amount
|
||||
|
||||
|
||||
def round_up(x: int, y: int) -> int:
|
||||
return ((x + y - 1) // y) * y
|
||||
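To make the cuBLAS padding rule above concrete, a small worked sketch of the L -> L4 arithmetic (K is an arbitrary example value):

import math

K = 1280
L = math.ceil(K / 128)      # 10 blocks of scales along K
L4 = ((L + 3) // 4) * 4     # rounded up to a multiple of 4, as cuBLAS requires -> 12
pad_amount = L4 - L         # 2 zero columns appended by _pad_128x128_scales
print(L, L4, pad_amount)    # 10 12 2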
@ -209,36 +144,42 @@ def infer_scale_swizzle(mat, scale):
|
||||
] == math.ceil(mat.shape[1] // 128):
|
||||
return ScalingType.BlockWise128x128, SwizzleType.NO_SWIZZLE
|
||||
|
||||
# if we're checking for nvfp4, need to adjust for packed-K
|
||||
K_multiplier = 2 if mat.dtype == torch.float4_e2m1fn_x2 else 1
|
||||
# NVFP4
|
||||
if (
|
||||
(scale.numel()
|
||||
== round_up(mat.shape[0], 128) * round_up(math.ceil(K_multiplier * mat.shape[1] // 16), 4)
|
||||
== round_up(mat.shape[0], 128) * round_up(math.ceil(2 * mat.shape[1] // 16), 4)
|
||||
or scale.numel()
|
||||
== round_up(mat.shape[1], 128) * round_up(math.ceil(K_multiplier * mat.shape[0] // 16), 4))
|
||||
== round_up(mat.shape[1], 128) * round_up(math.ceil(2 * mat.shape[0] // 16), 4))
|
||||
and mat.dtype == torch.float4_e2m1fn_x2
|
||||
and scale.dtype == torch.float8_e4m3fn
|
||||
):
|
||||
return ScalingType.BlockWise1x16, SwizzleType.SWIZZLE_32_4_4
|
||||
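The K_multiplier adjustment above reflects that torch.float4_e2m1fn_x2 stores two FP4 values per element, so the logical K is twice the stored width; a quick arithmetic sketch with toy shapes:

import math

M, K_packed = 256, 64                             # storage shape of a float4_e2m1fn_x2 tensor
K_logical = 2 * K_packed                          # 128 FP4 values per row once unpacked
scales_per_row_1x16 = math.ceil(K_logical / 16)   # 8 scales per row for 1x16 block scaling
print(K_logical, scales_per_row_1x16)             # 128 8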
|
||||
# MX formats
|
||||
# MXFP4 w/o swizzle
|
||||
if (
|
||||
(scale.numel() == 2 * math.ceil(mat.shape[0] // 32) * mat.shape[1]
|
||||
or scale.numel() == 2 * math.ceil(mat.shape[1] // 32) * mat.shape[0])
|
||||
and mat.dtype == torch.float4_e2m1fn_x2
|
||||
and scale.dtype == torch.float8_e8m0fnu
|
||||
):
|
||||
return ScalingType.BlockWise1x32, SwizzleType.NO_SWIZZLE
|
||||
|
||||
if not torch.version.hip:
|
||||
# MX w/swizzle (NVIDIA)
|
||||
# MXFP8 w/ swizzle
|
||||
if (
|
||||
(scale.numel()
|
||||
== round_up(mat.shape[0], 128) * round_up(math.ceil(K_multiplier * mat.shape[1] // 32), 4)
|
||||
== round_up(mat.shape[0], 128) * round_up(math.ceil(mat.shape[1] // 32), 4)
|
||||
or scale.numel()
|
||||
== round_up(mat.shape[1], 128) * round_up(math.ceil(K_multiplier * mat.shape[0] // 32), 4))
|
||||
== round_up(mat.shape[1], 128) * round_up(math.ceil(mat.shape[0] // 32), 4))
|
||||
and scale.dtype == torch.float8_e8m0fnu
|
||||
):
|
||||
return ScalingType.BlockWise1x32, SwizzleType.SWIZZLE_32_4_4
|
||||
|
||||
else:
|
||||
# MX w/o swizzle (AMD)
|
||||
# MXFP8 w/o swizzle
|
||||
if (
|
||||
(scale.numel() == math.ceil(mat.shape[0] // 32) * K_multiplier * mat.shape[1]
|
||||
or scale.numel() == math.ceil(K_multiplier * mat.shape[1] // 32) * mat.shape[0])
|
||||
(scale.numel() == math.ceil(mat.shape[0] // 32) * mat.shape[1]
|
||||
or scale.numel() == math.ceil(mat.shape[1] // 32) * mat.shape[0])
|
||||
and scale.dtype == torch.float8_e8m0fnu
|
||||
):
|
||||
return ScalingType.BlockWise1x32, SwizzleType.NO_SWIZZLE
|
||||
@ -1311,6 +1252,7 @@ class TestFP8Matmul(TestCase):
|
||||
else:
|
||||
test()
|
||||
|
||||
# Note: Removed parameterization over M,N,K from #163829 as it failed tests as-is
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
|
||||
@unittest.skipIf(not IS_SM90, "cuBLAS blockwise scaling requires sm90+")
|
||||
@unittest.skipIf(
|
||||
@ -1319,224 +1261,59 @@ class TestFP8Matmul(TestCase):
|
||||
)
|
||||
@parametrize("output_dtype", [torch.bfloat16, torch.float32])
|
||||
@parametrize("lhs_block,rhs_block", [(1, 1), (128, 1), (1, 128)])
|
||||
@parametrize("M,N,K", [
|
||||
# Nice size
|
||||
(256, 768, 512),
|
||||
# Requires padding for 128x128 scale
|
||||
(384, 128, 1280),
|
||||
# M=N=K for eyes test
|
||||
(512, 512, 512),
|
||||
])
|
||||
@parametrize("test_case", [
|
||||
"x_eye_b_eye",
|
||||
"x_ones_y_ones_calc_scales",
|
||||
"x_ones_y_ones_set_scales",
|
||||
"x_ones_y_ones_modify_scales",
|
||||
"data_random_scales_one",
|
||||
"data_random_calc_scales",
|
||||
])
|
||||
def test_scaled_mm_block_wise_numerics(self, output_dtype, lhs_block, rhs_block, M, N, K, test_case):
|
||||
"""
|
||||
subsume test_scaled_mm_vs_emulated_block_wise for random inputs, random scales,
|
||||
do some other functional tests as well.
|
||||
|
||||
# Inputs (as generated are):
|
||||
# A: [M, K]
|
||||
# B: [N, K]
|
||||
# then scales are, for the 3 combinations:
|
||||
# 1x128 x 1x128:
|
||||
# As: [M, K // 128], stride: [1, M] -> scale.t().contiguous().t()
|
||||
# Bs: [N, K // 128], stride: [1, N] -> scale.t().contiguous().t()
|
||||
# 1x128 x 128x128
|
||||
# L4 = round_up(K // 128, 4)
|
||||
# As: [M, K // 128], stride: [1, M] -> scale.t().contiguous().t()
|
||||
# Bs: [L4, N // 128], stride: [1, L4] -> scale.t()
|
||||
# 128x128 x 1x128
|
||||
# L4 = round_up(K // 128, 4)
|
||||
# As: [L4, M // 128], stride: [1, L4]
|
||||
# Bs: [N, K // 128], stride: [1, N]
|
||||
"""
|
||||
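The "outer-dim-major" scale layout described in the docstring above (shape [M, K // 128] with stride [1, M]) is what the scale.t().contiguous().t() trick produces; a tiny sketch with illustrative sizes:

import torch

M, K = 256, 512
s = torch.rand(M, K // 128)          # row-major: stride (K // 128, 1)
s_outer = s.t().contiguous().t()     # same values, outer-dim-major: stride (1, M)
assert s_outer.shape == (M, K // 128) and s_outer.stride() == (1, M)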
@parametrize("M,N,K", [(256, 768, 512)])
|
||||
@with_tf32_off
|
||||
def test_scaled_mm_vs_emulated_block_wise(self, output_dtype, lhs_block, rhs_block, M, N, K):
|
||||
torch.manual_seed(42)
|
||||
|
||||
def _adjust_lhs_scale(x_fp8, x_scales, lhs_block):
|
||||
M, K = x_fp8.shape
|
||||
x_scales_original = x_scales.clone()
|
||||
# 1x128 blocks need scales to be outer-dim-major
|
||||
if lhs_block == 1:
|
||||
x_scales = x_scales.t().contiguous().t()
|
||||
lhs_recipe = ScalingType.BlockWise1x128
|
||||
assert (x_scales.shape[0] == M and x_scales.shape[1] == K // 128), f"{x_scales.shape=}"
|
||||
assert (x_scales.stride(0) == 1 and x_scales.stride(1) in [1, M]), f"{x_scales.stride=}"
|
||||
x_hp = hp_from_1x128(x_fp8, x_scales_original)
|
||||
else:
|
||||
lhs_recipe = ScalingType.BlockWise128x128
|
||||
x_scales, pad_amount = _pad_128x128_scales(x_scales)
|
||||
# scales in [M // 128, L4] -> [L4, M // 128]
|
||||
x_scales = x_scales.t()
|
||||
x_hp = hp_from_128x128(x_fp8, x_scales_original)
|
||||
x = torch.randn(M, K, device="cuda", dtype=output_dtype).pow(3)
|
||||
y = torch.randn(N, K, device="cuda", dtype=output_dtype).pow(3)
|
||||
|
||||
return x_hp, lhs_recipe, x_scales, x_scales_original
|
||||
x_fp8, x_scales = tensor_to_scale_block(x, e4m3_type, lhs_block, 128)
|
||||
y_fp8, y_scales = tensor_to_scale_block(y, e4m3_type, rhs_block, 128)
|
||||
|
||||
def _adjust_rhs_scale(y_fp8, y_scales, rhs_block):
|
||||
N, K = y_fp8.shape
|
||||
y_scales_original = y_scales.clone()
|
||||
|
||||
if rhs_block == 1:
|
||||
y_scales = y_scales.t().contiguous().t()
|
||||
rhs_recipe = ScalingType.BlockWise1x128
|
||||
assert (y_scales.shape[0] == N and y_scales.shape[1] == K // 128), f"{y_scales.shape=}"
|
||||
assert (y_scales.stride(0) == 1 and y_scales.stride(1) in [1, N]), f"{y_scales.stride=}"
|
||||
y_hp = hp_from_1x128(y_fp8, y_scales_original)
|
||||
else:
|
||||
rhs_recipe = ScalingType.BlockWise128x128
|
||||
y_scales, pad_amount = _pad_128x128_scales(y_scales)
|
||||
# Scale in [N // 128, L4] -> [L4, N // 128]
|
||||
y_scales = y_scales.t()
|
||||
y_hp = hp_from_128x128(y_fp8, y_scales_original)
|
||||
|
||||
return y_hp, rhs_recipe, y_scales, y_scales_original
|
||||
|
||||
def _build_lhs(x, lhs_block):
|
||||
M, K = x.shape
|
||||
|
||||
x_fp8, x_scales = tensor_to_scale_block(x, e4m3_type, lhs_block, 128)
|
||||
x_scales_original = x_scales
|
||||
|
||||
x_hp, x_recipe, x_scales, x_scales_original = _adjust_lhs_scale(x_fp8, x_scales, lhs_block)
|
||||
|
||||
return x_hp, x_recipe, x_fp8, x_scales, x_scales_original
|
||||
|
||||
def _build_rhs(y, rhs_block):
|
||||
N, K = y.shape
|
||||
|
||||
y_fp8, y_scales = tensor_to_scale_block(y, e4m3_type, rhs_block, 128)
|
||||
y_hp, y_recipe, y_scales, y_scales_original = _adjust_rhs_scale(y_fp8, y_scales, rhs_block)
|
||||
|
||||
return y_hp, y_recipe, y_fp8, y_scales, y_scales_original
|
||||
|
||||
def _run_test(x_hp, x_recipe, x_fp8, x_scales, x_scales_original,
              y_hp, y_recipe, y_fp8, y_scales, y_scales_original):

    # Calculate actual F8 mm
    out_scaled_mm = scaled_mm_wrap(
        x_fp8,
        y_fp8.t(),
        scale_a=x_scales.reciprocal(),
        scale_recipe_a=x_recipe,
        # Note: No more .t() on scale_b, not necessary.
        scale_b=y_scales.reciprocal(),
        scale_recipe_b=y_recipe,
        out_dtype=output_dtype,
    )

    # Calculate emulated F8 mm
    out_emulated = mm_float8_emulated_block(
        x_fp8,
        x_scales_original,
        y_fp8.t(),
        y_scales_original.t(),
        output_dtype
    )

    cosine_sim = torch.nn.functional.cosine_similarity(
        out_emulated.flatten().float(), (x @ y.t()).flatten().float(), dim=0
    )
    self.assertGreaterEqual(float(cosine_sim), 0.999)

    cosine_sim = torch.nn.functional.cosine_similarity(
        out_scaled_mm.flatten().float(), out_emulated.flatten().float(), dim=0
    )
    self.assertGreaterEqual(float(cosine_sim), 0.999)

    if output_dtype in {torch.bfloat16, torch.float16}:
        atol, rtol = 6e-1, 7e-2
    else:
        atol, rtol = 7e-1, 2e-3

    self.assertEqual(out_scaled_mm, out_emulated.to(output_dtype), atol=atol, rtol=rtol)

    # One last check against the full-precision reference, to ensure we
    # didn't mess up the scaling itself and made the test trivial.
    cosine_sim = torch.nn.functional.cosine_similarity(
        out_scaled_mm.flatten().float(), (x @ y.t()).flatten().float(), dim=0
    )
    self.assertGreaterEqual(float(cosine_sim), 0.999)

def _build_constant_scale(t, block, val):
    M, K = t.shape

    if block == 1:
        scale_shape = M, K // 128
    else:
        scale_shape = M // 128, K // 128

    scale = torch.full(scale_shape, val, device='cuda')

    return scale

def hp_to_scaled(t, scale, block):
    if block == 1:
        return hp_to_1x128(t, scale)
    else:
        return hp_to_128x128(t, scale)

e4m3_type = torch.float8_e4m3fn

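The scale shapes produced by `_build_constant_scale` follow directly from the block geometry: 1x128 blocks tile only the K dimension, while 128x128 blocks tile both. A small sketch under that assumption; the helper name and example shapes are illustrative, not part of the diff:

def blockwise_scale_shape(rows: int, cols: int, block: int, block_k: int = 128):
    """Number of scale entries for a (rows, cols) operand.

    block == 1   -> 1x128 blocks: one scale per row per 128-wide column block.
    block == 128 -> 128x128 blocks: one scale per 128x128 tile.
    """
    if block == 1:
        return rows, cols // block_k
    return rows // block, cols // block_k


print(blockwise_scale_shape(256, 512, 1))    # (256, 4)
print(blockwise_scale_shape(256, 512, 128))  # (2, 4)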
if test_case == "x_eye_b_eye":
    if M != K or M != N:
        return unittest.skip("a_eye_b_eye only defined for M = N = K")
    x = torch.eye(M, device='cuda')
    y = torch.eye(M, device='cuda')

    x_hp, x_recipe, x_fp8, x_scales, x_scales_original = _build_lhs(x, lhs_block)
    y_hp, y_recipe, y_fp8, y_scales, y_scales_original = _build_lhs(y, rhs_block)
elif test_case == "x_ones_y_ones_calc_scales":
    x = torch.full((M, K), 1.0, device='cuda')
    y = torch.full((N, K), 1.0, device='cuda')

    x_hp, x_recipe, x_fp8, x_scales, x_scales_original = _build_lhs(x, lhs_block)
    y_hp, y_recipe, y_fp8, y_scales, y_scales_original = _build_lhs(y, rhs_block)
elif test_case in ["x_ones_y_ones_set_scales", "x_ones_y_ones_modify_scales"]:
    x = torch.full((M, K), 1.0, device='cuda')
    y = torch.full((N, K), 1.0, device='cuda')

    x_scales = _build_constant_scale(x, lhs_block, 1.)
    y_scales = _build_constant_scale(y, rhs_block, 1.)

    if "modify" in test_case:
        x_scales[0, 0] = 4.
        y_scales[-1, -1] = 4.

    x_fp8 = hp_to_scaled(x, x_scales, lhs_block)
    y_fp8 = hp_to_scaled(y, y_scales, rhs_block)

    x_hp, x_recipe, x_scales, x_scales_original = _adjust_lhs_scale(x_fp8, x_scales, lhs_block)
    y_hp, y_recipe, y_scales, y_scales_original = _adjust_rhs_scale(y_fp8, y_scales, rhs_block)
elif test_case == "data_random_scales_one":
    x = torch.randint(0, 255, (M, K), device='cuda', dtype=torch.uint8).to(torch.bfloat16)
    y = torch.randint(0, 255, (N, K), device='cuda', dtype=torch.uint8).to(torch.bfloat16)

    x_scales = _build_constant_scale(x, lhs_block, 1.)
    y_scales = _build_constant_scale(y, rhs_block, 1.)

    x_fp8 = hp_to_scaled(x, x_scales, lhs_block)
    y_fp8 = hp_to_scaled(y, y_scales, rhs_block)

    x_hp, x_recipe, x_scales, x_scales_original = _adjust_lhs_scale(x_fp8, x_scales, lhs_block)
    y_hp, y_recipe, y_scales, y_scales_original = _adjust_rhs_scale(y_fp8, y_scales, rhs_block)
elif test_case == "data_random_calc_scales":
    # Note: Old test_scaled_mm_vs_emulated_block_wise test case
    x = torch.randn(M, K, device="cuda", dtype=output_dtype)
    y = torch.randn(N, K, device="cuda", dtype=output_dtype) * 1e-3

    x_hp, x_recipe, x_fp8, x_scales, x_scales_original = _build_lhs(x, lhs_block)
    y_hp, y_recipe, y_fp8, y_scales, y_scales_original = _build_lhs(y, rhs_block)
# 1x128 blocks need scales to be outer-dim-major
if lhs_block == 1:
    x_scales = x_scales.t().contiguous().t()
    lhs_recipe = ScalingType.BlockWise1x128
else:
    raise ValueError("Unknown test-case passed")
    lhs_recipe = ScalingType.BlockWise128x128
if rhs_block == 1:
    y_scales = y_scales.t().contiguous().t()
    rhs_recipe = ScalingType.BlockWise1x128
else:
    rhs_recipe = ScalingType.BlockWise128x128

_run_test(x_hp, x_recipe, x_fp8, x_scales, x_scales_original,
          y_hp, y_recipe, y_fp8, y_scales, y_scales_original)

# Calculate actual F8 mm
out_scaled_mm = scaled_mm_wrap(
    x_fp8, y_fp8.t(), scale_a=x_scales.reciprocal(), scale_b=y_scales.reciprocal().t(), out_dtype=output_dtype,
    scale_recipe_a=lhs_recipe, scale_recipe_b=rhs_recipe
)

# Calculate emulated F8 mm
out_emulated = mm_float8_emulated_block(
    x_fp8, x_scales, y_fp8.t(), y_scales.t(), output_dtype
)

cosine_sim = torch.nn.functional.cosine_similarity(
    out_scaled_mm.flatten().float(), out_emulated.flatten().float(), dim=0
)
self.assertGreaterEqual(float(cosine_sim), 0.999)

if output_dtype in {torch.bfloat16, torch.float16}:
    atol, rtol = 6e-1, 7e-2
else:
    atol, rtol = 7e-1, 2e-3

self.assertEqual(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)

# One last check against the full-precision reference, to ensure we
# didn't mess up the scaling itself and made the test trivial.
cosine_sim = torch.nn.functional.cosine_similarity(
    out_scaled_mm.flatten().float(), (x @ y.t()).flatten().float(), dim=0
)
self.assertGreaterEqual(float(cosine_sim), 0.999)

@unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
@unittest.skipIf(not IS_SM90, "cuBLAS blockwise scaling requires sm90+")
@ -1558,30 +1335,18 @@ class TestFP8Matmul(TestCase):
|
||||
x_fp8, x_scales = tensor_to_scale_block(x, e4m3_type, lhs_block, 128)
|
||||
y_fp8, y_scales = tensor_to_scale_block(y, e4m3_type, rhs_block, 128)
|
||||
|
||||
x_scales_original = x_scales
|
||||
y_scales_original = y_scales
|
||||
# 1x128 blocks need scales to be outer-dim-major
|
||||
if lhs_block == 1:
|
||||
x_scales = x_scales.t().contiguous().t()
|
||||
lhs_recipe = ScalingType.BlockWise1x128
|
||||
assert (x_scales.shape[0] == M and x_scales.shape[1] == K // 128), f"{x_scales.shape=}"
|
||||
assert (x_scales.stride(0) == 1 and x_scales.stride(1) in [1, M]), f"{x_scales.stride=}"
|
||||
else:
|
||||
lhs_recipe = ScalingType.BlockWise128x128
|
||||
x_scales, pad_amount = _pad_128x128_scales(x_scales)
|
||||
# scales in [M // 128, L4] -> [L4, M // 128]
|
||||
x_scales = x_scales.t()
|
||||
|
||||
if rhs_block == 1:
|
||||
y_scales = y_scales.t().contiguous().t()
|
||||
rhs_recipe = ScalingType.BlockWise1x128
|
||||
assert (y_scales.shape[0] == N and y_scales.shape[1] == K // 128), f"{y_scales.shape=}"
|
||||
assert (y_scales.stride(0) == 1 and y_scales.stride(1) in [1, N]), f"{y_scales.stride=}"
|
||||
else:
|
||||
rhs_recipe = ScalingType.BlockWise128x128
|
||||
y_scales, pad_amount = _pad_128x128_scales(y_scales)
|
||||
# Scale in [N // 128, L4] -> [L4, N // 128]
|
||||
y_scales = y_scales.t()
|
||||
|
||||
# Verify that actual F8 mm doesn't error
|
||||
scaled_mm_wrap(
|
||||
@ -1589,20 +1354,13 @@ class TestFP8Matmul(TestCase):
|
||||
y_fp8.t(),
|
||||
scale_a=x_scales,
|
||||
scale_recipe_a=lhs_recipe,
|
||||
# Note: No more .t() on scale_b, not necessary.
|
||||
scale_b=y_scales,
|
||||
scale_b=y_scales.t(),
|
||||
scale_recipe_b=rhs_recipe,
|
||||
out_dtype=output_dtype,
|
||||
)
|
||||
|
||||
# Verify that emulated F8 mm doesn't error
|
||||
mm_float8_emulated_block(
|
||||
x_fp8,
|
||||
x_scales_original,
|
||||
y_fp8.t(),
|
||||
y_scales_original.t(),
|
||||
output_dtype
|
||||
)
|
||||
mm_float8_emulated_block(x_fp8, x_scales, y_fp8.t(), y_scales.t(), output_dtype)
|
||||
|
||||
@skipIfRocm
|
||||
@onlyCUDA
|
||||
@ -1862,7 +1620,7 @@ class TestFP8Matmul(TestCase):
|
||||
(127, 96, 1024),
|
||||
(1025, 128, 96)
|
||||
], name_fn=lambda mkn: f"{mkn[0]}_{mkn[1]}_{mkn[2]}")
|
||||
@parametrize("recipe", ["mxfp8", "mxfp4", "nvfp4"])
|
||||
@parametrize("recipe", ["mxfp8", "mxfp4" if torch.version.hip else "nvfp4"])
|
||||
def test_blockwise_mxfp8_nvfp4_mxfp4_numerics(self, test_case_name, fast_accum, mkn, recipe) -> None:
|
||||
if (recipe == "nvfp4" or recipe == "mxfp4") and fast_accum:
|
||||
raise unittest.SkipTest("fast_accum not supported in nvfp4/mxfp4 cublas gemm, skipping")
|
||||
@ -1876,12 +1634,8 @@ class TestFP8Matmul(TestCase):
|
||||
if not (M % 16 == 0 and K % 128 == 0 and N % 16 == 0):
|
||||
raise unittest.SkipTest("M and N must be multiples of 16 and K must be multiple of 128 on ROCm, skipping")
|
||||
|
||||
fp4_scaling_dtype = torch.float8_e8m0fnu if recipe == "mxfp4" else torch.float8_e4m3fn
|
||||
BLOCK_SIZE = 16 if recipe == "nvfp4" else 32
|
||||
|
||||
if K % BLOCK_SIZE != 0:
|
||||
raise unittest.SkipTest(f"K ({K}) must be divisible by BLOCK_SIZE ({BLOCK_SIZE}), skipping")
|
||||
|
||||
fp4_scaling_dtype = torch.float8_e8m0fnu if torch.version.hip else torch.float8_e4m3fn
|
||||
BLOCK_SIZE = 32 if torch.version.hip else (16 if recipe == "nvfp4" else 32)
|
||||
require_exact_match = True
|
||||
approx_match_sqnr_target = 22.0
|
||||
|
||||
@ -2059,7 +1813,7 @@ class TestFP8Matmul(TestCase):
|
||||
B = B.clamp(min=min_val, max=max_val)
|
||||
B = _bfloat16_to_float4_e2m1fn_x2(B)
|
||||
|
||||
approx_match_sqnr_target = 15 if recipe == "mxfp4" else 15.8
|
||||
approx_match_sqnr_target = 15 if torch.version.hip else 15.8
|
||||
|
||||
C_ref = A_ref @ B_ref.t()
|
||||
|
||||
|
||||
@ -47,18 +47,11 @@ def get_all_examples():
|
||||
"import io",
|
||||
"import itertools",
|
||||
"",
|
||||
"from typing import Any, ClassVar, Generic, List, Tuple, Union",
|
||||
"from typing_extensions import Literal, get_origin, TypeAlias",
|
||||
"T: TypeAlias = object",
|
||||
"",
|
||||
"import numpy",
|
||||
"",
|
||||
"import torch",
|
||||
"import torch.nn.functional as F",
|
||||
"",
|
||||
"from typing_extensions import ParamSpec as _ParamSpec",
|
||||
"ParamSpec = _ParamSpec",
|
||||
"",
|
||||
# for requires_grad_ example
|
||||
# NB: We are parsing this file as Python 2, so we must use
|
||||
# Python 2 type annotation syntax
|
||||
|
||||
115 test/test_xpu.py
@ -14,8 +14,10 @@ from torch.testing import make_tensor
|
||||
from torch.testing._internal.autocast_test_lists import AutocastTestLists, TestAutocast
|
||||
from torch.testing._internal.common_device_type import (
|
||||
instantiate_device_type_tests,
|
||||
onlyXPU,
|
||||
OpDTypes,
|
||||
ops,
|
||||
skipXPUIf,
|
||||
)
|
||||
from torch.testing._internal.common_methods_invocations import ops_and_refs
|
||||
from torch.testing._internal.common_utils import (
|
||||
@ -72,8 +74,6 @@ _xpu_computation_ops = [
|
||||
|
||||
@unittest.skipIf(not TEST_XPU, "XPU not available, skipping tests")
|
||||
class TestXpu(TestCase):
|
||||
expandable_segments = False
|
||||
|
||||
def test_device_behavior(self):
|
||||
current_device = torch.xpu.current_device()
|
||||
torch.xpu.set_device(current_device)
|
||||
@ -385,6 +385,56 @@ if __name__ == "__main__":
|
||||
torch.xpu.set_rng_state(g_state0)
|
||||
self.assertEqual(2024, torch.xpu.initial_seed())
|
||||
|
||||
@onlyXPU
|
||||
@suppress_warnings
|
||||
@ops(_xpu_computation_ops, dtypes=any_common_cpu_xpu_one)
|
||||
def test_compare_cpu(self, device, dtype, op):
|
||||
def to_cpu(arg):
|
||||
if isinstance(arg, torch.Tensor):
|
||||
return arg.to(device="cpu")
|
||||
return arg
|
||||
|
||||
samples = op.reference_inputs(device, dtype)
|
||||
|
||||
for sample in samples:
|
||||
cpu_sample = sample.transform(to_cpu)
|
||||
xpu_results = op(sample.input, *sample.args, **sample.kwargs)
|
||||
cpu_results = op(cpu_sample.input, *cpu_sample.args, **cpu_sample.kwargs)
|
||||
|
||||
xpu_results = sample.output_process_fn_grad(xpu_results)
|
||||
cpu_results = cpu_sample.output_process_fn_grad(cpu_results)
|
||||
|
||||
# Lower tolerance because we are running this as a `@slowTest`
|
||||
# Don't want the periodic tests to fail frequently
|
||||
self.assertEqual(xpu_results, cpu_results, atol=1e-4, rtol=1e-4)
|
||||
|
||||
@onlyXPU
|
||||
@ops(_xpu_computation_ops, allowed_dtypes=(torch.bool,))
|
||||
def test_non_standard_bool_values(self, device, dtype, op):
|
||||
# Test boolean values other than 0x00 and 0x01 (gh-54789)
|
||||
def convert_boolean_tensors(x):
|
||||
if not isinstance(x, torch.Tensor) or x.dtype != torch.bool:
|
||||
return x
|
||||
|
||||
# Map False -> 0 and True -> Random value in [2, 255]
|
||||
true_vals = torch.randint(
|
||||
2, 255, x.shape, dtype=torch.uint8, device=x.device
|
||||
)
|
||||
false_vals = torch.zeros((), dtype=torch.uint8, device=x.device)
|
||||
x_int = torch.where(x, true_vals, false_vals)
|
||||
|
||||
ret = x_int.view(torch.bool)
|
||||
self.assertEqual(ret, x)
|
||||
return ret
|
||||
|
||||
for sample in op.sample_inputs(device, dtype):
|
||||
expect = op(sample.input, *sample.args, **sample.kwargs)
|
||||
|
||||
transformed = sample.transform(convert_boolean_tensors)
|
||||
actual = op(transformed.input, *transformed.args, **transformed.kwargs)
|
||||
|
||||
self.assertEqual(expect, actual)
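The test above targets gh-54789: a `torch.bool` tensor whose storage contains bytes other than 0x00/0x01 must still behave like ordinary booleans. A minimal CPU-only sketch of the same re-encoding trick; the expected output in the comments assumes a recent PyTorch build:

import torch

b = torch.tensor([False, True, True, False])

# Re-encode True as an arbitrary nonzero byte (7), False as 0, then
# reinterpret the raw bytes as torch.bool without copying.
true_vals = torch.full(b.shape, 7, dtype=torch.uint8)
false_vals = torch.zeros((), dtype=torch.uint8)
weird = torch.where(b, true_vals, false_vals).view(torch.bool)

print(weird)                        # tensor([False,  True,  True, False])
print(bool(torch.equal(weird, b)))  # True: nonzero bytes still compare as True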
|
||||
|
||||
def test_serialization_array_with_storage(self):
|
||||
x = torch.randn(5, 5).xpu()
|
||||
y = torch.zeros(2, 5, dtype=torch.int, device="xpu")
|
||||
@ -420,8 +470,6 @@ if __name__ == "__main__":
|
||||
self.assertEqual(copy.get_device(), original.get_device())
|
||||
|
||||
def test_out_of_memory(self):
|
||||
if self.expandable_segments:
|
||||
self.skipTest("Skipping OOM test for expandable segments allocator.")
|
||||
tensor = torch.zeros(1024, device="xpu") # noqa: F841
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Tried to allocate 800000000.00 GiB"):
|
||||
@ -431,8 +479,6 @@ if __name__ == "__main__":
|
||||
torch.empty(1024 * 1024 * 1024 * 8000000000, dtype=torch.int8, device="xpu")
|
||||
|
||||
def test_raises_oom(self):
|
||||
if self.expandable_segments:
|
||||
self.skipTest("Skipping OOM test for expandable segments allocator.")
|
||||
torch.xpu.memory.empty_cache()
|
||||
with self.assertRaises(torch.OutOfMemoryError):
|
||||
torch.empty(1024 * 1024 * 1024 * 1024, device="xpu")
|
||||
@ -545,7 +591,7 @@ if __name__ == "__main__":
|
||||
self.assertEqual(torch.accelerator.max_memory_allocated(), prev_max_allocated)
|
||||
self.assertEqual(torch.accelerator.max_memory_reserved(), prev_max_reserved)
|
||||
|
||||
@unittest.skipIf(
|
||||
@skipXPUIf(
|
||||
int(torch.version.xpu) < 20250000,
|
||||
"Test requires SYCL compiler version 2025.0.0 or newer.",
|
||||
)
|
||||
@ -593,8 +639,6 @@ if __name__ == "__main__":
|
||||
self.assertTrue(b"libsycl.so" in result)
|
||||
|
||||
def test_dlpack_conversion(self):
|
||||
if self.expandable_segments:
|
||||
self.skipTest("Skipping DLPack test for expandable segments allocator.")
|
||||
x = make_tensor((5,), dtype=torch.float32, device="xpu")
|
||||
if IS_WINDOWS and int(torch.version.xpu) < 20250000:
|
||||
with self.assertRaisesRegex(
|
||||
@ -608,58 +652,7 @@ if __name__ == "__main__":
|
||||
self.assertEqual(z, x)
|
||||
|
||||
|
||||
@unittest.skipIf(not TEST_XPU, "XPU not available, skipping tests")
|
||||
class TestXpuOps(TestCase):
|
||||
@suppress_warnings
|
||||
@ops(_xpu_computation_ops, dtypes=any_common_cpu_xpu_one)
|
||||
def test_compare_cpu(self, device, dtype, op):
|
||||
def to_cpu(arg):
|
||||
if isinstance(arg, torch.Tensor):
|
||||
return arg.to(device="cpu")
|
||||
return arg
|
||||
|
||||
samples = op.reference_inputs(device, dtype)
|
||||
|
||||
for sample in samples:
|
||||
cpu_sample = sample.transform(to_cpu)
|
||||
xpu_results = op(sample.input, *sample.args, **sample.kwargs)
|
||||
cpu_results = op(cpu_sample.input, *cpu_sample.args, **cpu_sample.kwargs)
|
||||
|
||||
xpu_results = sample.output_process_fn_grad(xpu_results)
|
||||
cpu_results = cpu_sample.output_process_fn_grad(cpu_results)
|
||||
|
||||
# Lower tolerance because we are running this as a `@slowTest`
|
||||
# Don't want the periodic tests to fail frequently
|
||||
self.assertEqual(xpu_results, cpu_results, atol=1e-4, rtol=1e-4)
|
||||
|
||||
@ops(_xpu_computation_ops, allowed_dtypes=(torch.bool,))
|
||||
def test_non_standard_bool_values(self, device, dtype, op):
|
||||
# Test boolean values other than 0x00 and 0x01 (gh-54789)
|
||||
def convert_boolean_tensors(x):
|
||||
if not isinstance(x, torch.Tensor) or x.dtype != torch.bool:
|
||||
return x
|
||||
|
||||
# Map False -> 0 and True -> Random value in [2, 255]
|
||||
true_vals = torch.randint(
|
||||
2, 255, x.shape, dtype=torch.uint8, device=x.device
|
||||
)
|
||||
false_vals = torch.zeros((), dtype=torch.uint8, device=x.device)
|
||||
x_int = torch.where(x, true_vals, false_vals)
|
||||
|
||||
ret = x_int.view(torch.bool)
|
||||
self.assertEqual(ret, x)
|
||||
return ret
|
||||
|
||||
for sample in op.sample_inputs(device, dtype):
|
||||
expect = op(sample.input, *sample.args, **sample.kwargs)
|
||||
|
||||
transformed = sample.transform(convert_boolean_tensors)
|
||||
actual = op(transformed.input, *transformed.args, **transformed.kwargs)
|
||||
|
||||
self.assertEqual(expect, actual)
|
||||
|
||||
|
||||
instantiate_device_type_tests(TestXpuOps, globals(), only_for="xpu", allow_xpu=True)
|
||||
instantiate_device_type_tests(TestXpu, globals(), only_for="xpu", allow_xpu=True)
|
||||
|
||||
|
||||
@unittest.skipIf(not TEST_XPU, "XPU not available, skipping tests")
|
||||
|
||||
@ -1,26 +0,0 @@
|
||||
# Owner(s): ["module: intel"]
|
||||
import pathlib
|
||||
import sys
|
||||
|
||||
from test_xpu import TestXpu, TestXpuOpsXPU # noqa: F401
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.common_utils import IS_WINDOWS, run_tests
|
||||
|
||||
|
||||
REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent
|
||||
sys.path.insert(0, str(REPO_ROOT))
|
||||
|
||||
from tools.stats.import_test_stats import get_disabled_tests
|
||||
|
||||
|
||||
sys.path.remove(str(REPO_ROOT))
|
||||
|
||||
if __name__ == "__main__":
|
||||
if torch.xpu.is_available() and not IS_WINDOWS:
|
||||
get_disabled_tests(".")
|
||||
|
||||
torch._C._accelerator_setAllocatorSettings("expandable_segments:True")
|
||||
TestXpu.expandable_segments = True
|
||||
|
||||
run_tests()
|
||||
@ -2,6 +2,7 @@ from __future__ import annotations

import argparse
import os
import re
import sys
import xml.etree.ElementTree as ET
from multiprocessing import cpu_count, Pool
@ -19,6 +20,19 @@ from tools.stats.upload_stats_lib import (
)


def should_upload_full_test_run(head_branch: str | None, head_repository: str) -> bool:
    """Return True if we should upload the full test_run dataset.

    Rules:
    - Only for the main repository (pytorch/pytorch)
    - If head_branch is 'main', or a tag of form 'trunk/{40-hex-sha}'
    """
    is_trunk_tag = bool(re.fullmatch(r"trunk/[0-9a-fA-F]{40}", (head_branch or "")))
    return head_repository == "pytorch/pytorch" and (
        head_branch == "main" or is_trunk_tag
    )


def parse_xml_report(
    tag: str,
    report: Path,
@ -287,7 +301,8 @@ if __name__ == "__main__":
        remove_nan_inf(failed_tests_cases),
    )

    if args.head_branch == "main" and args.head_repository == "pytorch/pytorch":
    # Upload full test_run only for trusted refs (main or trunk/{sha} tags)
    if should_upload_full_test_run(args.head_branch, args.head_repository):
        # For jobs on main branch, upload everything.
        upload_workflow_stats_to_s3(
            args.workflow_run_id,

28 tools/test/test_upload_gate.py (new file)
@ -0,0 +1,28 @@
import unittest

from tools.stats.upload_test_stats import should_upload_full_test_run


class TestUploadGate(unittest.TestCase):
    def test_main_branch_on_pytorch_repo(self) -> None:
        self.assertTrue(should_upload_full_test_run("main", "pytorch/pytorch"))

    def test_trunk_tag_valid_sha_on_pytorch_repo(self) -> None:
        sha = "a" * 40
        self.assertTrue(should_upload_full_test_run(f"trunk/{sha}", "pytorch/pytorch"))

    def test_trunk_tag_invalid_sha_on_pytorch_repo(self) -> None:
        # Not 40 hex chars
        self.assertFalse(should_upload_full_test_run("trunk/12345", "pytorch/pytorch"))

    def test_non_main_branch_on_pytorch_repo(self) -> None:
        self.assertFalse(
            should_upload_full_test_run("feature-branch", "pytorch/pytorch")
        )

    def test_main_branch_on_fork_repo(self) -> None:
        self.assertFalse(should_upload_full_test_run("main", "someone/fork"))


if __name__ == "__main__":
    unittest.main()
@ -663,9 +663,6 @@ class SymFloat:
    def __float__(self):
        return self.node.guard_float("", 0)

    def __int__(self):
        return self.__trunc__().__int__()

    # Symbolic power does NOT work with negative base, this is to avoid
    # potential complex outputs
    def __pow__(self, other):
@ -814,15 +811,6 @@ class SymBool:
        # Force specialization
        return hash(builtins.bool(self))

    def __sym_float__(self):
        """
        Provides a SymFloat representation (0.0 or 1.0) for this SymBool.
        Called by torch.sym_float() when casting SymBool to float.
        """
        from torch.fx.experimental.sym_node import wrap_node

        return wrap_node(self.node.sym_float())


def sym_not(a):
    r"""SymInt-aware utility for logical negation.

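The new `SymBool.__sym_float__` plugs into the `torch.sym_float()` protocol, which, as far as current releases go, prefers an object's `__sym_float__` method before falling back to a plain `float()` cast. A hedged sketch of that dispatch with an ordinary Python object; the `Flag` class is hypothetical and only illustrates the protocol:

import torch


class Flag:
    """Toy object exposing the __sym_float__ hook torch.sym_float() looks for."""

    def __init__(self, value: bool) -> None:
        self.value = value

    def __sym_float__(self):
        return 1.0 if self.value else 0.0


print(torch.sym_float(True))        # 1.0 -- plain bools fall back to float()
print(torch.sym_float(Flag(True)))  # 1.0 -- dispatched through __sym_float__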
@ -739,12 +739,6 @@ enable_aot_compile = False
|
||||
# HACK: this is for testing custom ops profiling only
|
||||
_custom_ops_profile: Optional[Any] = None
|
||||
|
||||
# Experimental: If True, graph module will register fx metadata during recompile()
|
||||
enrich_profiler_metadata: bool = Config( # type: ignore[var-annotated]
|
||||
default=False,
|
||||
env_name_default="TORCH_ENRICH_RPOFILER_STACK_TRACE",
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch.utils._config_typing import * # noqa: F401, F403
|
||||
|
||||
|
||||
@ -42,7 +42,7 @@ import weakref
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from os.path import dirname, join
|
||||
from typing import Any, NamedTuple, Optional, Sized, TYPE_CHECKING, Union
|
||||
from typing import Any, NamedTuple, Optional, TYPE_CHECKING, Union
|
||||
from unittest.mock import patch
|
||||
|
||||
import sympy
|
||||
@ -395,13 +395,6 @@ class OptimizedModule(torch.nn.Module):
|
||||
self._initialize()
|
||||
self.training = self._orig_mod.training
|
||||
|
||||
def __len__(self) -> int:
|
||||
# Proxy the len call to the original module
|
||||
if isinstance(self._orig_mod, Sized):
|
||||
return len(self._orig_mod)
|
||||
# Mimic python's default behavior for objects without a length
|
||||
raise TypeError(f"{type(self._orig_mod).__name__} does not support len()")
|
||||
|
||||
def _initialize(self) -> None:
|
||||
# Do this stuff in constructor to lower overhead slightly
|
||||
if isinstance(self.dynamo_ctx, DisableContext):
|
||||
|
||||
@ -1734,14 +1734,6 @@
|
||||
}
|
||||
],
|
||||
"GB0175": [
|
||||
{
|
||||
"Gb_type": "builtin isinstance() cannot determine type of argument",
|
||||
"Context": "isinstance({arg}, {isinstance_type_var})",
|
||||
"Explanation": "Dynamo doesn't have a rule to determine the type of argument {arg}",
|
||||
"Hints": [
|
||||
"This is likely to be a Dynamo bug. Please report an issue to PyTorch."
|
||||
]
|
||||
},
|
||||
{
|
||||
"Gb_type": "builtin isinstance() cannot determine type of argument",
|
||||
"Context": "isinstance({arg}, {isinstance_type})",
|
||||
@ -2923,19 +2915,5 @@
|
||||
"Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled."
|
||||
]
|
||||
}
|
||||
],
|
||||
"GB0287": [
|
||||
{
|
||||
"Gb_type": "unsupported type.__dict__['__annotations__'].__get__ call",
|
||||
"Context": "call_function {self}, args: {args}, kwargs: {kwargs}",
|
||||
"Explanation": "`torch.compile` only supports calling type.__dict__['__annotations__'].__get__ on a single constant argument (i.e. a type).",
|
||||
"Hints": [
|
||||
"Make sure your call to type.__dict__['__annotations__'] only has ",
|
||||
"one positional argument (no keyword arguments).",
|
||||
"Make sure the argument to type.__dict__['__annotations__'] is a constant ",
|
||||
"(i.e. type). For example, `object`, `int`, `MyCustomClass`.",
|
||||
"It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues."
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
@ -1,119 +0,0 @@
"""
Python polyfills for heapq
"""

from __future__ import annotations

import heapq
import importlib
import sys
from typing import TYPE_CHECKING, TypeVar

from ..decorators import substitute_in_graph


if TYPE_CHECKING:
    from types import ModuleType


_T = TypeVar("_T")


# Partially copied from CPython test/support/import_helper.py
# https://github.com/python/cpython/blob/bb8791c0b75b5970d109e5557bfcca8a578a02af/Lib/test/support/import_helper.py
def _save_and_remove_modules(names: set[str]) -> dict[str, ModuleType]:
    orig_modules = {}
    prefixes = tuple(name + "." for name in names)
    for modname in list(sys.modules):
        if modname in names or modname.startswith(prefixes):
            orig_modules[modname] = sys.modules.pop(modname)
    return orig_modules


def import_fresh_module(name: str, blocked: list[str]) -> ModuleType:
    # Keep track of modules saved for later restoration as well
    # as those which just need a blocking entry removed
    names = {name, *blocked}
    orig_modules = _save_and_remove_modules(names)
    for modname in blocked:
        sys.modules[modname] = None  # type: ignore[assignment]

    try:
        return importlib.import_module(name)
    finally:
        _save_and_remove_modules(names)
        sys.modules.update(orig_modules)


# Import the pure Python heapq module, blocking the C extension
py_heapq = import_fresh_module("heapq", blocked=["_heapq"])

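The deleted polyfill relied on the CPython test-suite trick shown above: hide the C accelerator in `sys.modules`, re-import the module so its pure-Python fallback is used, then restore the original entries. A standalone sketch of the same idea using only the standard library; the function name is illustrative, not the Dynamo helper:

import importlib
import sys


def fresh_pure_python(name: str, blocked: str):
    """Import `name` with its C accelerator `blocked` hidden, then restore sys.modules."""
    saved = {m: sys.modules.pop(m) for m in (name, blocked) if m in sys.modules}
    sys.modules[blocked] = None  # any import of the accelerator now raises ImportError
    try:
        return importlib.import_module(name)
    finally:
        for m in (name, blocked):
            sys.modules.pop(m, None)
        sys.modules.update(saved)


py_heapq = fresh_pure_python("heapq", "_heapq")
import heapq  # the regular, usually C-accelerated module

print(py_heapq is heapq)                    # False: two distinct module objects
print(py_heapq.heappush is heapq.heappush)  # False: pure-Python vs accelerated implementation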
__all__ = [
|
||||
"_heapify_max",
|
||||
"_heappop_max",
|
||||
"_heapreplace_max",
|
||||
"heapify",
|
||||
"heappop",
|
||||
"heappush",
|
||||
"heappushpop",
|
||||
"heapreplace",
|
||||
"merge",
|
||||
"nlargest",
|
||||
"nsmallest",
|
||||
]
|
||||
|
||||
|
||||
@substitute_in_graph(heapq._heapify_max)
|
||||
def _heapify_max(heap: list[_T], /) -> None:
|
||||
return py_heapq._heapify_max(heap)
|
||||
|
||||
|
||||
@substitute_in_graph(heapq._heappop_max) # type: ignore[attr-defined]
|
||||
def _heappop_max(heap: list[_T]) -> _T:
|
||||
return py_heapq._heappop_max(heap)
|
||||
|
||||
|
||||
@substitute_in_graph(heapq._heapreplace_max) # type: ignore[attr-defined]
|
||||
def _heapreplace_max(heap: list[_T], item: _T) -> _T:
|
||||
return py_heapq._heapreplace_max(heap, item)
|
||||
|
||||
|
||||
@substitute_in_graph(heapq.heapify)
|
||||
def heapify(heap: list[_T], /) -> None:
|
||||
return py_heapq.heapify(heap)
|
||||
|
||||
|
||||
@substitute_in_graph(heapq.heappop)
|
||||
def heappop(heap: list[_T], /) -> _T:
|
||||
return py_heapq.heappop(heap)
|
||||
|
||||
|
||||
@substitute_in_graph(heapq.heappush)
|
||||
def heappush(heap: list[_T], item: _T) -> None:
|
||||
return py_heapq.heappush(heap, item)
|
||||
|
||||
|
||||
@substitute_in_graph(heapq.heappushpop)
|
||||
def heappushpop(heap: list[_T], item: _T) -> _T:
|
||||
return py_heapq.heappushpop(heap, item)
|
||||
|
||||
|
||||
@substitute_in_graph(heapq.heapreplace)
|
||||
def heapreplace(heap: list[_T], item: _T) -> _T:
|
||||
return py_heapq.heapreplace(heap, item)
|
||||
|
||||
|
||||
@substitute_in_graph(heapq.merge) # type: ignore[arg-type]
|
||||
def merge(*iterables, key=None, reverse=False): # type: ignore[no-untyped-def]
|
||||
return py_heapq.merge(*iterables, key=key, reverse=reverse)
|
||||
|
||||
|
||||
@substitute_in_graph(heapq.nlargest) # type: ignore[arg-type]
|
||||
def nlargest(n, iterable, key=None): # type: ignore[no-untyped-def]
|
||||
return py_heapq.nlargest(n, iterable, key=key)
|
||||
|
||||
|
||||
@substitute_in_graph(heapq.nsmallest) # type: ignore[arg-type]
|
||||
def nsmallest(n, iterable, key=None): # type: ignore[no-untyped-def]
|
||||
return py_heapq.nsmallest(n, iterable, key=key)
|
||||
@ -405,7 +405,6 @@ isolate_fails_code_str = None
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
kernel._fn_name
|
||||
if isinstance(kernel, JITFunction)
|
||||
# pyrefly: ignore # missing-attribute
|
||||
else kernel.fn._fn_name
|
||||
)
|
||||
fn_name = fn_name.split(".")[-1]
|
||||
|
||||
@ -218,7 +218,7 @@ class CPythonTestCase(TestCase):
|
||||
if m:
|
||||
test_py_ver = tuple(map(int, m.group().removeprefix(prefix).split("_")))
|
||||
py_ver = sys.version_info[:2]
|
||||
if py_ver != test_py_ver:
|
||||
if py_ver < test_py_ver:
|
||||
expected = ".".join(map(str, test_py_ver))
|
||||
got = ".".join(map(str, py_ver))
|
||||
raise unittest.SkipTest(
|
||||
|
||||
File diff suppressed because it is too large.
@ -61,11 +61,7 @@ from ..utils import (
|
||||
raise_args_mismatch,
|
||||
tuple_methods,
|
||||
)
|
||||
from .base import (
|
||||
AsPythonConstantNotImplementedError,
|
||||
raise_type_error_exc,
|
||||
VariableTracker,
|
||||
)
|
||||
from .base import raise_type_error_exc, VariableTracker
|
||||
from .constant import ConstantVariable
|
||||
from .functions import NestedUserFunctionVariable, UserFunctionVariable
|
||||
from .user_defined import call_random_fn, is_standard_setattr, UserDefinedObjectVariable
|
||||
@ -1264,38 +1260,6 @@ class MethodWrapperVariable(VariableTracker):
|
||||
return variables.BuiltinVariable(object).call_method(
|
||||
tx, wrapper_name, [self_obj, *args], kwargs
|
||||
)
|
||||
elif (
|
||||
sys.version_info >= (3, 14)
|
||||
# for some reason, even if the below check passes,
|
||||
# self.method_wrapper may not be the same as type.__dict__["__annotations__"].__get__
|
||||
and self_obj is type.__dict__["__annotations__"]
|
||||
and wrapper_name == "__get__"
|
||||
):
|
||||
from .builder import SourcelessBuilder
|
||||
|
||||
if len(args) == 1 and not kwargs:
|
||||
try:
|
||||
return SourcelessBuilder.create(
|
||||
tx, self.method_wrapper(args[0].as_python_constant())
|
||||
)
|
||||
except AttributeError:
|
||||
raise_observed_exception(AttributeError, tx)
|
||||
except AsPythonConstantNotImplementedError:
|
||||
pass
|
||||
|
||||
unimplemented_v2(
|
||||
gb_type="unsupported type.__dict__['__annotations__'].__get__ call",
|
||||
context=f"call_function {self}, args: {args}, kwargs: {kwargs}",
|
||||
explanation="`torch.compile` only supports calling type.__dict__['__annotations__'].__get__ "
|
||||
"on a single constant argument (i.e. a type).",
|
||||
hints=[
|
||||
"Make sure your call to type.__dict__['__annotations__'] only has "
|
||||
"one positional argument (no keyword arguments).",
|
||||
"Make sure the argument to type.__dict__['__annotations__'] is a constant "
|
||||
"(i.e. type). For example, `object`, `int`, `MyCustomClass`.",
|
||||
*graph_break_hints.SUPPORTABLE,
|
||||
],
|
||||
)
|
||||
|
||||
return super().call_function(tx, args, kwargs)
|
||||
|
||||
|
||||
@ -23,7 +23,6 @@ import operator
|
||||
import textwrap
|
||||
import traceback
|
||||
import types
|
||||
from collections.abc import Sequence
|
||||
from contextlib import nullcontext
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
@ -632,7 +631,7 @@ class TensorVariable(VariableTracker):
|
||||
self,
|
||||
tx,
|
||||
name,
|
||||
args: Sequence[VariableTracker],
|
||||
args: "list[VariableTracker]",
|
||||
kwargs: "dict[str, VariableTracker]",
|
||||
) -> "VariableTracker":
|
||||
from .builder import SourcelessBuilder, VariableBuilder
|
||||
|
||||
@ -29,7 +29,6 @@ import contextlib
|
||||
import functools
|
||||
import inspect
|
||||
import operator
|
||||
from collections.abc import Sequence
|
||||
from types import TracebackType
|
||||
from typing import Any, Generator, Iterable, Optional, TYPE_CHECKING
|
||||
|
||||
@ -723,12 +722,12 @@ class TensorWithTFOverrideVariable(TensorVariable):
|
||||
self,
|
||||
tx: "InstructionTranslator",
|
||||
name: str,
|
||||
args: Sequence[VariableTracker],
|
||||
args: "list[VariableTracker]",
|
||||
kwargs: "dict[str, VariableTracker]",
|
||||
) -> "VariableTracker":
|
||||
# This code block implements inlining the __torch_function__ override
|
||||
# of `call_method`.
|
||||
tf_args = [self] + list(args)
|
||||
tf_args = [self] + args
|
||||
if can_dispatch_torch_function(tx, tf_args, kwargs):
|
||||
import torch
|
||||
|
||||
|
||||
@ -179,9 +179,6 @@ def aot_stage1_graph_capture(
        )
    )

    print(f"in aot_stage1_graph_capture. maybe_subclass_meta.fw_metadata.static_input_indices:{maybe_subclass_meta.fw_metadata.static_input_indices if maybe_subclass_meta is not None and maybe_subclass_meta.fw_metadata is not None else None}")
    print(f"in aot_stage1_graph_capture. aot_state.fw_metadata.static_input_indices:{aot_state.fw_metadata.static_input_indices}")

    return AOTGraphCapture(
        wrappers=wrappers,
        graph_module=graph,

@ -7,9 +7,9 @@
|
||||
|
||||
import contextlib
|
||||
import functools
|
||||
from collections.abc import Callable, Generator, Sequence
|
||||
from collections.abc import Callable, Generator
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Optional, TypeAlias
|
||||
from typing import Any, Optional, Sequence, TypeAlias
|
||||
|
||||
import torch
|
||||
import torch.utils._pytree as pytree
|
||||
|
||||
@ -264,7 +264,6 @@ def generate_ttir(
|
||||
|
||||
assert isinstance(kernel, JITFunction)
|
||||
|
||||
# pyrefly: ignore # missing-attribute
|
||||
context = triton._C.libtriton.ir.context()
|
||||
target = triton.runtime.driver.active.get_current_target()
|
||||
backend = triton.compiler.compiler.make_backend(target)
|
||||
@ -306,7 +305,6 @@ def generate_ttir(
|
||||
base_tensor = torch.empty(
|
||||
[elements_per_dim] * len(block_shape), dtype=a.dtype
|
||||
)
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
ordered_args[name] = TensorDescriptor.from_tensor(base_tensor, block_shape)
|
||||
elif isinstance(a, (FakeTensor, torch._inductor.ir.TensorBox)):
|
||||
with torch._C._DisableTorchDispatch():
|
||||
@ -370,7 +368,6 @@ def generate_ttir(
|
||||
|
||||
target = triton.runtime.driver.active.get_current_target()
|
||||
backend_ = triton.compiler.compiler.make_backend(target)
|
||||
# pyrefly: ignore # missing-attribute
|
||||
return backend_.get_attrs_descriptor(args, kernel.params)
|
||||
else:
|
||||
assert (
|
||||
@ -387,7 +384,6 @@ def generate_ttir(
|
||||
except TypeError: # Unknown arg `specialize_extra`
|
||||
# Older versions of Triton take specialize_extra as an arg to specialize_impl
|
||||
specialize_impl = functools.partial(
|
||||
# pyrefly: ignore # missing-argument
|
||||
triton.runtime.jit.create_specialize_impl(),
|
||||
specialize_extra=backend.get_arg_specialization,
|
||||
)
|
||||
@ -472,7 +468,6 @@ def generate_ttir(
|
||||
if i not in constexprs
|
||||
}
|
||||
|
||||
# pyrefly: ignore # missing-attribute
|
||||
triton._C.libtriton.ir.load_dialects(context)
|
||||
backend.load_dialects(context)
|
||||
|
||||
@ -482,29 +477,22 @@ def generate_ttir(
|
||||
# backward compatibility here.
|
||||
make_ir_sig_params = len(inspect.signature(src.make_ir).parameters)
|
||||
get_codegen_implementation_sig_params = len(
|
||||
# pyrefly: ignore # missing-attribute
|
||||
inspect.signature(backend.get_codegen_implementation).parameters
|
||||
)
|
||||
if make_ir_sig_params == 2:
|
||||
# pyrefly: ignore # missing-argument
|
||||
ttir_module = src.make_ir(options, context)
|
||||
elif make_ir_sig_params == 3:
|
||||
# pyrefly: ignore # missing-attribute
|
||||
codegen_fns = backend.get_codegen_implementation()
|
||||
# pyrefly: ignore # missing-argument
|
||||
ttir_module = src.make_ir(options, codegen_fns, context)
|
||||
elif make_ir_sig_params == 4:
|
||||
codegen_args = [options] if get_codegen_implementation_sig_params == 1 else []
|
||||
# pyrefly: ignore # missing-attribute
|
||||
codegen_fns = backend.get_codegen_implementation(*codegen_args)
|
||||
module_map = backend.get_module_map()
|
||||
ttir_module = src.make_ir(options, codegen_fns, module_map, context)
|
||||
else:
|
||||
codegen_args = [options] if get_codegen_implementation_sig_params == 1 else []
|
||||
# pyrefly: ignore # missing-attribute
|
||||
codegen_fns = backend.get_codegen_implementation(*codegen_args)
|
||||
module_map = backend.get_module_map()
|
||||
# pyrefly: ignore # bad-argument-count
|
||||
ttir_module = src.make_ir(target, options, codegen_fns, module_map, context)
|
||||
if not ttir_module.verify():
|
||||
raise RuntimeError("Verification for TTIR module has failed")
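The `make_ir` branching above is a version-compatibility pattern: inspect the callee's arity and pass only the arguments that this Triton release accepts. A generic, hedged sketch of the pattern with hypothetical stand-ins, not the actual Triton API:

import inspect


def call_make_ir(make_ir, options, codegen_fns, module_map, context):
    """Dispatch on how many parameters this version of make_ir accepts."""
    n = len(inspect.signature(make_ir).parameters)
    if n == 2:
        return make_ir(options, context)
    if n == 3:
        return make_ir(options, codegen_fns, context)
    return make_ir(options, codegen_fns, module_map, context)


# Hypothetical stand-ins for an older and a newer callee.
def make_ir_old(options, context):
    return ("old", options, context)


def make_ir_new(options, codegen_fns, module_map, context):
    return ("new", options, codegen_fns, module_map, context)


print(call_make_ir(make_ir_old, "opts", (), {}, "ctx"))
print(call_make_ir(make_ir_new, "opts", (), {}, "ctx"))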
|
||||
@ -1114,7 +1102,6 @@ def triton_kernel_wrapper_mutation_dense(
|
||||
from triton.tools.tensor_descriptor import TensorDescriptor
|
||||
|
||||
block_shape = stable_meta[0]
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
kwargs[k] = TensorDescriptor.from_tensor(tensor, block_shape)
|
||||
|
||||
# move as many positional arguments from dicts to args as we
|
||||
@ -1671,7 +1658,6 @@ class TritonHOPifier:
|
||||
"Passing multiple @triton.autotune decorators is not supported. "
|
||||
"Please use a single @triton.autotune decorator instead."
|
||||
)
|
||||
# pyrefly: ignore # missing-attribute
|
||||
iter_kernel = iter_kernel.fn
|
||||
|
||||
# Process the @triton.heuristics decorator:
|
||||
@ -1882,7 +1868,6 @@ class TritonHOPifier:
|
||||
|
||||
# Both for grid's meta as well as for the kernel, we need combined
|
||||
# args and kwargs combined and normalized
|
||||
# pyrefly: ignore # missing-attribute
|
||||
combined_args_raw = {**dict(zip(variable.kernel.arg_names, args)), **kwargs}
|
||||
|
||||
# precompute the grid for the kernel
|
||||
@ -2076,7 +2061,6 @@ class TraceableTritonKernelWrapper:
|
||||
kernel_idx: Optional[int],
|
||||
grid: Optional["TritonGridType"],
|
||||
) -> None:
|
||||
# pyrefly: ignore # bad-assignment
|
||||
self.kernel = None
|
||||
self.grid = None
|
||||
tracing_triton_hopifier_singleton.init_variable(self, kernel, kernel_idx, grid)
|
||||
|
||||
@ -2,9 +2,8 @@ import json
|
||||
import logging
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch._inductor.analysis.device_info import DeviceInfo, lookup_device_info
|
||||
@ -76,9 +75,7 @@ def _slow_conv2d_adapter(
|
||||
return conv_adapter(tuple(tmp), tuple(tmp2))
|
||||
|
||||
|
||||
@register_adapter(
|
||||
["convolution", "_convolution", "cudnn_convolution", "convolution_overrideable"]
|
||||
)
|
||||
@register_adapter(["convolution", "_convolution", "cudnn_convolution"])
|
||||
def conv_adapter(
|
||||
shapes: tuple[Any, ...], concrete: tuple[Any, ...]
|
||||
) -> tuple[tuple[Any], dict[Any, Any]]:
|
||||
|
||||
@ -1,9 +1,8 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Callable, Optional
|
||||
from unittest import mock
|
||||
|
||||
import torch
|
||||
|
||||
@ -14,7 +14,7 @@ from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from concurrent.futures.process import BrokenProcessPool
|
||||
from functools import partial
|
||||
from time import time, time_ns
|
||||
from typing import Any, Optional, TYPE_CHECKING
|
||||
from typing import Any, Callable, Optional, TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
from torch._dynamo.device_interface import get_registered_device_interfaces
|
||||
@ -60,8 +60,6 @@ from torch.utils._triton import has_triton_package
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
from torch._inductor.runtime.hints import HalideMeta
|
||||
from torch._inductor.runtime.triton_heuristics import CachingAutotuner
|
||||
|
||||
|
||||
@ -1,8 +1,7 @@
|
||||
import json
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from functools import partial
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
from torch._inductor.autoheuristic.autoheuristic_utils import (
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
import functools
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
from typing import Any, Callable
|
||||
|
||||
import torch
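These import hunks are part of a consistent migration: `typing.Callable` has been deprecated since Python 3.9 in favor of `collections.abc.Callable`, and annotation-only imports move under `TYPE_CHECKING`. A minimal sketch of the target style; the function is illustrative only:

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only needed by the type checker; avoids any runtime import cost.
    from collections.abc import Callable


def apply_twice(fn: Callable[[int], int], x: int) -> int:
    return fn(fn(x))


print(apply_twice(lambda v: v + 1, 0))  # 2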
|
||||
|
||||
|
||||
@ -14,10 +14,10 @@ import subprocess
|
||||
import sys
|
||||
import time
|
||||
import warnings
|
||||
from collections.abc import Callable, Iterable, Sequence
|
||||
from collections.abc import Iterable, Sequence
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from ctypes import byref, c_size_t, c_void_p, CDLL
|
||||
from typing import Any, IO, Optional, TYPE_CHECKING, Union
|
||||
from typing import Any, Callable, IO, Optional, TYPE_CHECKING, Union
|
||||
|
||||
import torch
|
||||
import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools
|
||||
|
||||
@ -2,10 +2,10 @@ import asyncio
|
||||
import sys
|
||||
import weakref
|
||||
from asyncio import AbstractEventLoop, Future
|
||||
from collections.abc import Awaitable, Callable, Coroutine, Generator, Iterator
|
||||
from collections.abc import Awaitable, Coroutine, Generator, Iterator
|
||||
from contextlib import contextmanager, ExitStack
|
||||
from contextvars import Context
|
||||
from typing import Any, Optional, Protocol, TypeVar
|
||||
from typing import Any, Callable, Optional, Protocol, TypeVar
|
||||
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
|
||||
|
||||
@ -1,8 +1,7 @@
|
||||
import logging
|
||||
import operator
|
||||
from collections.abc import Callable
|
||||
from functools import partial
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import sympy
|
||||
from sympy import Expr
|
||||
|
||||
@ -34,7 +34,7 @@ from pathlib import Path
|
||||
from tempfile import _TemporaryFileWrapper
|
||||
from time import time, time_ns
|
||||
from types import ModuleType
|
||||
from typing import Any, cast, Generic, NoReturn, TYPE_CHECKING, TypeVar, Union
|
||||
from typing import Any, Callable, cast, Generic, NoReturn, TYPE_CHECKING, TypeVar, Union
|
||||
from typing_extensions import override, Self
|
||||
|
||||
import torch
|
||||
@ -126,7 +126,7 @@ if config.is_fbcode():
|
||||
T = TypeVar("T")
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable, Generator, KeysView, Sequence
|
||||
from collections.abc import Generator, KeysView, Sequence
|
||||
from concurrent.futures import Future
|
||||
|
||||
from .compile_fx import _CompileFxKwargs
|
||||
|
||||
@ -17,6 +17,7 @@ from enum import auto, Enum
|
||||
from itertools import chain
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
cast,
|
||||
ClassVar,
|
||||
Generic,
|
||||
@ -70,7 +71,7 @@ from ..virtualized import (
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable, Iterator, MutableMapping, Sequence
|
||||
from collections.abc import Iterator, MutableMapping, Sequence
|
||||
|
||||
from torch.fx import GraphModule
|
||||
|
||||
|
||||
@ -8,9 +8,9 @@ import operator
|
||||
import re
|
||||
import sys
|
||||
import warnings
|
||||
from collections.abc import Callable, Sequence
|
||||
from collections.abc import Sequence
|
||||
from enum import Enum
|
||||
from typing import Any, cast, Optional, Union
|
||||
from typing import Any, Callable, cast, Optional, Union
|
||||
|
||||
import sympy
|
||||
|
||||
|
||||
@ -1,8 +1,7 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import contextlib
|
||||
import itertools
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Callable, Optional
|
||||
from unittest.mock import patch
|
||||
|
||||
import sympy
|
||||
|
||||
@ -2,9 +2,8 @@
|
||||
import contextlib
|
||||
import logging
|
||||
import math
|
||||
from collections.abc import Callable
|
||||
from functools import lru_cache
|
||||
from typing import Any, cast, Optional, TypeVar, Union
|
||||
from typing import Any, Callable, cast, Optional, TypeVar, Union
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
import contextlib
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import Any, cast, Optional, TypeVar
|
||||
from typing import Any, Callable, cast, Optional, TypeVar
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
|
||||
Some files were not shown because too many files have changed in this diff.