mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-06 09:17:11 +08:00
Compare commits: 12 commits, ciflow/ind...bf/bug-sta
| SHA1 |
|---|
| 8cef91fb74 |
| 527b1109a8 |
| 3144713325 |
| eefa16342c |
| d02f68f484 |
| 68eb55c4b2 |
| 8d4b8ab430 |
| afd50bdd29 |
| 56dfd4c74b |
| 24db5c4451 |
| cc8bfd1206 |
| c45b156605 |
```diff
@@ -28,7 +28,7 @@ CUDA_ARCHES_FULL_VERSION = {
     "12.6": "12.6.3",
     "12.8": "12.8.1",
     "12.9": "12.9.1",
-    "13.0": "13.0.2",
+    "13.0": "13.0.0",
 }
 CUDA_ARCHES_CUDNN_VERSION = {
     "12.6": "9",
```
.github/workflows/docker-release.yml (vendored, 1 line changed)

```diff
@@ -8,6 +8,7 @@ on:
       - docker.Makefile
       - .github/workflows/docker-release.yml
       - .github/scripts/generate_docker_release_matrix.py
+      - .github/scripts/generate_binary_build_matrix.py
   push:
     branches:
       - nightly
```
.github/workflows/inductor-unittest.yml (vendored, 8 lines changed)

```diff
@@ -115,10 +115,10 @@ jobs:
       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
       test-matrix: |
         { include: [
-          { config: "inductor_amx", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
-          { config: "inductor_amx", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
-          { config: "inductor_avx2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.10xlarge.avx2" },
-          { config: "inductor_avx2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.10xlarge.avx2" },
+          { config: "inductor_amx", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
+          { config: "inductor_amx", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
+          { config: "inductor_avx2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.avx2" },
+          { config: "inductor_avx2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.avx2" },
         ]}
     secrets: inherit
```
.github/workflows/inductor.yml (vendored, 14 lines changed)

```diff
@@ -84,13 +84,13 @@ jobs:
       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
      test-matrix: |
         { include: [
-          { config: "cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
-          { config: "cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
-          { config: "dynamic_cpu_inductor_huggingface", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
-          { config: "dynamic_cpu_inductor_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
-          { config: "dynamic_cpu_inductor_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
-          { config: "dynamic_cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
-          { config: "dynamic_cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
+          { config: "cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
+          { config: "cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
+          { config: "dynamic_cpu_inductor_huggingface", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
+          { config: "dynamic_cpu_inductor_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
+          { config: "dynamic_cpu_inductor_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
+          { config: "dynamic_cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
+          { config: "dynamic_cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
           { config: "inductor_torchbench_cpu_smoketest_perf", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.24xl.spr-metal" },
         ]}
       build-additional-packages: "vision audio torchao"
```
```diff
@@ -260,7 +260,7 @@ IF(USE_FBGEMM_GENAI)
   if(USE_CUDA)
     # To avoid increasing the build time/binary size unnecessarily, use an allow-list of kernels to build.
     # If you want to integrate a kernel from FBGEMM into torch, you have to add it here.
-    set(FBGEMM_CUTLASS_KERNELS_REGEX ".*(mx8mx8bf16_grouped|f4f4bf16_grouped).*")
+    set(FBGEMM_CUTLASS_KERNELS_REGEX ".*(mx8mx8bf16_grouped|f4f4bf16_grouped|f4f4bf16).*")
     file(GLOB_RECURSE fbgemm_genai_native_cuda_cu
       "${FBGEMM_GENAI_SRCS}/cutlass_extensions/*.cu"
       "${FBGEMM_GENAI_SRCS}/cutlass_extensions/**/*.cu")
```
```diff
@@ -23,8 +23,6 @@ C10_DIAGNOSTIC_POP()
 #endif
 namespace at {

 namespace {

 /*
   These const variables defined the fp32 precisions for different backend
   We have "generic", "cuda", "mkldnn" backend now and we can choose fp32
@@ -41,16 +39,6 @@ namespace {
     ->rnn
 */

-C10_ALWAYS_INLINE void warn_deprecated_fp32_precision_api(){
-  TORCH_WARN_ONCE(
-    "Please use the new API settings to control TF32 behavior, such as torch.backends.cudnn.conv.fp32_precision = 'tf32' "
-    "or torch.backends.cuda.matmul.fp32_precision = 'ieee'. Old settings, e.g, torch.backends.cuda.matmul.allow_tf32 = True, "
-    "torch.backends.cudnn.allow_tf32 = True, allowTF32CuDNN() and allowTF32CuBLAS() will be deprecated after Pytorch 2.9. Please see "
-    "https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices"
-  );
-}
 } // namespace

 Float32Backend str2backend(const std::string& name) {
   if (name == "generic")
     return Float32Backend::GENERIC;
@@ -206,7 +194,6 @@ bool Context::allowTF32CuDNN(std::optional<Float32Op> op) const {
   } else {
     return float32Precision(Float32Backend::CUDA, op.value()) == Float32Precision::TF32;
   }
-  warn_deprecated_fp32_precision_api();
   return allow_tf32_cudnn;
 }
@@ -214,7 +201,6 @@ void Context::setAllowTF32CuDNN(bool b) {
   setFloat32Precision(Float32Backend::CUDA, Float32Op::RNN, b ? Float32Precision::TF32 : Float32Precision::NONE);
   setFloat32Precision(Float32Backend::CUDA, Float32Op::CONV, b ? Float32Precision::TF32 : Float32Precision::NONE);
   allow_tf32_cudnn = b;
-  warn_deprecated_fp32_precision_api();
 }

 void Context::setSDPPriorityOrder(const std::vector<int64_t>& order) {
@@ -325,7 +311,6 @@ bool Context::allowTF32CuBLAS() const {
       "Current status indicate that you have used mix of the legacy and new APIs to set the TF32 status for cublas matmul. ",
       "We suggest only using the new API to set the TF32 flag. See also: ",
       "https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices");
-  warn_deprecated_fp32_precision_api();
   return allow_tf32_new;
 }
@@ -349,7 +334,6 @@ Float32MatmulPrecision Context::float32MatmulPrecision() const {
       "Current status indicate that you have used mix of the legacy and new APIs to set the matmul precision. ",
       "We suggest only using the new API for matmul precision. See also: ",
       "https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices");
-  warn_deprecated_fp32_precision_api();
   return float32_matmul_precision;
 }
@@ -377,7 +361,6 @@ Float32Precision Context::float32Precision(Float32Backend backend, Float32Op op) {

 void Context::setFloat32MatmulPrecision(const std::string &s) {
   auto match = [this](const std::string & s_) {
-    warn_deprecated_fp32_precision_api();
     // TODO: consider if CuDNN field needs to also be set for potential future CuDNN ops like multi-headed attention
     if (s_ == "highest") {
       float32_matmul_precision = at::Float32MatmulPrecision::HIGHEST;
```
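For reference, the removed warning pointed users at the new per-backend fp32 precision settings. A minimal migration sketch in Python, using only the setting names quoted in the warning string above (assumes a build recent enough to expose them):

```python
import torch

# Legacy toggles that the warning says are deprecated after PyTorch 2.9:
# torch.backends.cuda.matmul.allow_tf32 = True
# torch.backends.cudnn.allow_tf32 = True

# New-style per-backend / per-op controls named in the warning:
torch.backends.cuda.matmul.fp32_precision = "tf32"   # allow TF32 for matmuls
torch.backends.cudnn.conv.fp32_precision = "tf32"    # allow TF32 for convolutions
torch.backends.cuda.matmul.fp32_precision = "ieee"   # or request strict fp32 matmuls
```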
```diff
@@ -59,6 +59,24 @@
 // forward declare
 class cublasCommonArgs;

+#ifndef _WIN32
+namespace fbgemm_gpu {
+
+// NOTE(slayton58): FBGemm_GPU kernels come from <fbgemm_gpu/torch_ops.h> within the FBGemm repo.
+//                  To update supported ops means a submodule bump, which is.. painful. Instead, we
+//                  can simply forward-declare the methods we want to use.. Works at least as a short-term
+//                  thing, but should still be fixed somewhere/somehow.
+at::Tensor f4f4bf16(
+    at::Tensor,
+    at::Tensor,
+    at::Tensor,
+    at::Tensor,
+    std::optional<at::Tensor>,
+    bool use_mx);
+
+} // namespace fbgemm_gpu
+#endif
+
 using at::blas::ScalingType;
 using at::blas::SwizzleType;
```
@ -767,33 +785,6 @@ _scaled_rowwise_rowwise(
|
||||
return out;
|
||||
}
|
||||
|
||||
// Check the shapes & sizes of scales for deepseek-style (1x128, 128x128) scaling.
|
||||
// Wraps check_size_stride for easier integration, correctly handles cases where a dimension of the scale == 1,
|
||||
// and strides become somewhat meaningless
|
||||
void _check_deepseek_scale_stride(const Tensor& scale, const Tensor& t, const ScalingType scale_type) {
|
||||
if (scale_type == ScalingType::BlockWise1x128) {
|
||||
TORCH_CHECK_VALUE(check_size_stride(scale, 0, t.size(0), 1),
|
||||
"at dim=0 scale should have ", t.size(0), "elements and stride(0) ", 1, "if ", t.size(0), " > 1 - Got: ",
|
||||
"shape=", scale.sizes(), ", stride=", scale.strides());
|
||||
auto expected_size = ceil_div<int64_t>(t.size(1), 128);
|
||||
TORCH_CHECK_VALUE(check_size_stride(scale, 1, expected_size, t.size(0)),
|
||||
"at dim=1 scale should have ", expected_size, "elements and stride ", t.size(0), "if ", expected_size, " > 1 - Got: ",
|
||||
"shape=", scale.sizes(), ", stride=", scale.strides());
|
||||
} else if (scale_type == ScalingType::BlockWise128x128) {
|
||||
TORCH_CHECK_VALUE(check_size_stride(
|
||||
scale,
|
||||
0,
|
||||
ceil_div<int64_t>(t.size(0), 128),
|
||||
ceil_div<int64_t>(t.size(1), 128)),
|
||||
"at dim=0 scale should have ", ceil_div<int64_t>(t.size(0), 128), "elements and stride(0) ", ceil_div<int64_t>(t.size(1), 128), "if ", ceil_div<int64_t>(t.size(0), 128), " > 1 - Got: ",
|
||||
"shape=", scale.sizes(), ", stride=", scale.strides());
|
||||
TORCH_CHECK(check_size_stride(
|
||||
scale, 1, ceil_div<int64_t>(t.size(1), 128), 1),
|
||||
"at dim=1 scale should have ", ceil_div<int64_t>(t.size(1), 128), "elements and stride(1) ", 1, "if ", ceil_div<int64_t>(t.size(1), 128), " > 1 - Got: ",
|
||||
"shape=", scale.sizes(), ", stride=", scale.strides());
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
_check_deepseek_support() {
|
||||
#ifndef USE_ROCM
|
||||
```diff
@@ -806,7 +797,7 @@ _check_deepseek_support() {
   }
   // Only in cublasLt >= 12.9
   TORCH_CHECK_NOT_IMPLEMENTED(
-      CUBLAS_VERSION < 120900 || cublasLtGetVersion() < 120900,
+      CUBLAS_VERSION >= 120900 && cublasLtGetVersion() >= 120900,
      "DeepSeek style (1x128, 128x128) scaling requires cublasLt >= 12.9"
   );
 #endif
```
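To make the 1x128 / 128x128 ("DeepSeek-style") restrictions in these hunks concrete, here is a small, hypothetical Python sketch of the scale shapes and strides the checks expect for an FP8 GEMM of A[M, K] @ B[K, N] (pure tensor arithmetic, illustrative sizes only, no scaled-mm call):

```python
import torch

def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

M, K, N = 256, 768, 512  # example sizes only

# BlockWise1x128: one fp32 scale per 1x128 slice along K.
# A-side scale: shape [M, K // 128], expected outer-dim-major, i.e. strides (1, M).
scale_a = torch.randn(M, ceil_div(K, 128), dtype=torch.float32).t().contiguous().t()
assert scale_a.stride() == (1, M)

# BlockWise128x128: one fp32 scale per 128x128 tile.
# B-side scale: [ceil_div(K, 128), ceil_div(N, 128)] tiles, before the additional
# L -> L4 padding that cublasLt >= 12.9 requires (see the checks above).
scale_b = torch.randn(ceil_div(K, 128), ceil_div(N, 128), dtype=torch.float32)
print(scale_a.shape, scale_b.shape)  # torch.Size([256, 6]) torch.Size([6, 4])
```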
@ -823,23 +814,61 @@ _scaled_block1x128_block1x128(
|
||||
#ifndef USE_ROCM
|
||||
// Restrictions:
|
||||
// A, B are FP8, scales are fp32, shape K//128
|
||||
// CUDA: Only Hopper GPUs
|
||||
// As: [M x K // 128], stride: [1, M]
|
||||
// Bs: [N x K // 128], stride: [1, N]
|
||||
_check_deepseek_support();
|
||||
|
||||
TORCH_CHECK_VALUE(isFloat8Type(mat_a.scalar_type()) && isFloat8Type(mat_b.scalar_type()), "mat_a and mat_b must be fp8 types, got: ",
|
||||
mat_a.scalar_type(), mat_b.scalar_type());
|
||||
TORCH_CHECK_VALUE(scale_a.sizes()[0] == mat_a.sizes()[0] && scale_a.sizes()[1] == mat_a.sizes()[1] / 128 && scale_a.scalar_type() == kFloat,
|
||||
"scale_a must have shape ", mat_a.sizes()[0], " x ", mat_a.sizes()[1] / 128, " Float elements, got ", scale_a.sizes())
|
||||
TORCH_CHECK_VALUE(scale_b.sizes()[0] == ceil_div<int64_t>(mat_b.sizes()[0], 128) && scale_b.sizes()[1] == mat_b.sizes()[1] && scale_b.scalar_type() == kFloat,
|
||||
"scale_b must have shape ", ceil_div<int64_t>(mat_b.sizes()[0], 128), " x ", mat_b.sizes()[1], " Float elements, got ", scale_b.sizes())
|
||||
// check types
|
||||
TORCH_CHECK_VALUE(
|
||||
isFloat8Type(mat_a.scalar_type()) &&
|
||||
isFloat8Type(mat_b.scalar_type()),
|
||||
"mat_a and mat_b must be fp8 types, got: ", mat_a.scalar_type(), mat_b.scalar_type()
|
||||
);
|
||||
|
||||
const int64_t M = mat_a.sizes()[0];
|
||||
const int64_t K = mat_a.sizes()[1];
|
||||
const int64_t N = mat_b.sizes()[1];
|
||||
|
||||
// scale_a shape
|
||||
TORCH_CHECK_VALUE(
|
||||
scale_a.size(0) == M &&
|
||||
scale_a.size(1) == ceil_div<int64_t>(K, 128) &&
|
||||
scale_a.scalar_type() == kFloat,
|
||||
"scale_a must have shape ", M, " x ", ceil_div<int64_t>(K, 128), " Float elements, got ", scale_a.sizes()
|
||||
);
|
||||
// scale_a stride
|
||||
TORCH_CHECK_VALUE(
|
||||
scale_a.stride(0) == 1 &&
|
||||
(
|
||||
scale_a.stride(1) == M ||
|
||||
(scale_a.size(1) == 1 && scale_b.stride(1) == 1)
|
||||
),
|
||||
"scale_a strides must be (", 1, ", ", M, "); got: ", scale_a.strides()
|
||||
);
|
||||
|
||||
// scale_b shape
|
||||
TORCH_CHECK_VALUE(
|
||||
scale_b.size(0) == N &&
|
||||
scale_b.size(1) == ceil_div<int64_t>(K, 128) &&
|
||||
scale_b.scalar_type() == kFloat,
|
||||
"scale_b must have shape ", N, " x ", ceil_div<int64_t>(K, 128), " Float elements, got ", scale_b.sizes()
|
||||
);
|
||||
// scale_b stride
|
||||
TORCH_CHECK_VALUE(
|
||||
scale_b.stride(0) == 1 &&
|
||||
(
|
||||
scale_b.stride(1) == N ||
|
||||
(
|
||||
scale_b.size(1) == 1 &&
|
||||
scale_b.stride(1) == 1
|
||||
)
|
||||
),
|
||||
"scale_b strides must be (", 1, ", ", N, "); got: ", scale_a.strides()
|
||||
);
|
||||
|
||||
auto scaling_choice_a = ScalingType::BlockWise1x128;
|
||||
auto scaling_choice_b = ScalingType::BlockWise1x128;
|
||||
|
||||
// Check scale strides (including stride=1 small cases)
|
||||
_check_deepseek_scale_stride(scale_a, mat_a, scaling_choice_a);
|
||||
_check_deepseek_scale_stride(scale_b.t(), mat_b.t(), scaling_choice_b);
|
||||
|
||||
_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
|
||||
|
||||
return out;
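As a rough numerical picture of what the 1x128 x 1x128 path above computes, here is a hedged Python emulation: dequantize each 1x128 block of A and B with its fp32 scale, then do an ordinary matmul (the `hp_from_1x128` helper added further down in this diff follows the same idea; the function and sizes here are illustrative only):

```python
import torch

def dequant_1x128(x_lp: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # x_lp: [R, K] fp8, scale: [R, K // 128] fp32 -- one scale per 1x128 block.
    R, K = x_lp.shape
    x = x_lp.to(torch.float32).reshape(R, K // 128, 128)
    return (x / scale.unsqueeze(-1)).reshape(R, K)

M, K, N = 4, 256, 8
a = torch.randn(M, K).to(torch.float8_e4m3fn)
b = torch.randn(N, K).to(torch.float8_e4m3fn)
scale_a = torch.rand(M, K // 128) + 0.5
scale_b = torch.rand(N, K // 128) + 0.5

# Emulated result of the blockwise-scaled GEMM: dequantize, then A @ B^T.
out = dequant_1x128(a, scale_a) @ dequant_1x128(b, scale_b).t()
```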
|
||||
@ -861,24 +890,65 @@ _scaled_block128x128_block1x128(
|
||||
Tensor& out) {
|
||||
#ifndef USE_ROCM
|
||||
// Restrictions:
|
||||
// A, B are FP8, scales are fp32, shape K//128
|
||||
// CUDA: Only Hopper GPUs
|
||||
_check_deepseek_support();
|
||||
|
||||
TORCH_CHECK_VALUE(isFloat8Type(mat_a.scalar_type()) && isFloat8Type(mat_b.scalar_type()), "mat_a and mat_b must be fp8 types, got: ",
|
||||
mat_a.scalar_type(), mat_b.scalar_type());
|
||||
TORCH_CHECK_VALUE(scale_a.sizes()[0] == ceil_div<int64_t>(mat_a.sizes()[0], 128) && scale_a.sizes()[1] == ceil_div<int64_t>(mat_a.sizes()[1], 128) && scale_a.scalar_type() == kFloat,
|
||||
"scale_a must have shape ", ceil_div<int64_t>(mat_a.sizes()[0], 128), " x ", ceil_div<int64_t>(mat_a.sizes()[1], 128), " Float elements, got ", scale_a.sizes())
|
||||
TORCH_CHECK_VALUE(scale_b.sizes()[0] == ceil_div<int64_t>(mat_b.sizes()[0], 128) && scale_b.sizes()[1] == mat_b.sizes()[1] && scale_b.scalar_type() == kFloat,
|
||||
"scale_b must have shape ", ceil_div<int64_t>(mat_b.sizes()[0], 128), " x ", mat_b.sizes()[1], " Float elements, got ", scale_b.sizes())
|
||||
// A: [M, K], B: [K, N] are FP8, scales are fp32
|
||||
// As: [round_up(K // 128, 4), M // 128], stride: [M // 128, 1]
|
||||
// Bs: [N x K // 128], stride: [1, N]
|
||||
TORCH_CHECK_VALUE(
|
||||
isFloat8Type(mat_a.scalar_type()) &&
|
||||
isFloat8Type(mat_b.scalar_type()),
|
||||
"mat_a and mat_b must be fp8 types, got: ", mat_a.scalar_type(), mat_b.scalar_type()
|
||||
);
|
||||
|
||||
const int64_t M = mat_a.sizes()[0];
|
||||
const int64_t K = mat_a.sizes()[1];
|
||||
const int64_t N = mat_b.sizes()[1];
|
||||
|
||||
// scale_a shape
|
||||
TORCH_CHECK_VALUE(
|
||||
scale_a.size(0) == round_up<int64_t>(ceil_div<int64_t>(K, 128), 4) &&
|
||||
scale_a.size(1) == ceil_div<int64_t>(M, 128) &&
|
||||
scale_a.scalar_type() == kFloat,
|
||||
"scale_a must have shape ", round_up<int64_t>(ceil_div<int64_t>(K, 128), 4), " x ",
|
||||
ceil_div<int64_t>(M, 128), " Float elements, got ", scale_a.sizes()
|
||||
);
|
||||
// scale_a stride
|
||||
TORCH_CHECK_VALUE(
|
||||
scale_a.stride(0) == 1 &&
|
||||
(
|
||||
scale_a.stride(1) == round_up<int64_t>(ceil_div<int64_t>(K, 128), 4) ||
|
||||
(
|
||||
scale_a.size(1) == 1 &&
|
||||
scale_a.stride(1) == 1
|
||||
)
|
||||
),
|
||||
"scale_a must have strides (1, ", round_up<int64_t>(ceil_div<int64_t>(K, 128), 4), "); got ", scale_b.strides()
|
||||
);
|
||||
|
||||
// scale_b shape
|
||||
TORCH_CHECK_VALUE(
|
||||
scale_b.size(0) == N &&
|
||||
scale_b.size(1) == ceil_div<int64_t>(K, 128) &&
|
||||
scale_b.scalar_type() == kFloat,
|
||||
"scale_b must have shape ", N, " x ", ceil_div<int64_t>(K, 128), " Float elements, got ", scale_b.sizes()
|
||||
);
|
||||
// scale_b stride
|
||||
TORCH_CHECK_VALUE(
|
||||
scale_b.stride(0) == 1 &&
|
||||
(
|
||||
scale_b.stride(1) == N ||
|
||||
(
|
||||
scale_b.size(1) == 1 &&
|
||||
scale_b.stride(1) == 1
|
||||
)
|
||||
),
|
||||
"scale_b must have strides (1, ", N, "); got ", scale_b.strides()
|
||||
);
|
||||
|
||||
auto scaling_choice_a = ScalingType::BlockWise128x128;
|
||||
auto scaling_choice_b = ScalingType::BlockWise1x128;
|
||||
|
||||
// Check scale strides (including stride=1 small cases)
|
||||
_check_deepseek_scale_stride(scale_a, mat_a, scaling_choice_a);
|
||||
_check_deepseek_scale_stride(scale_b.t(), mat_b.t(), scaling_choice_b);
|
||||
|
||||
_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
|
||||
|
||||
return out;
|
||||
@ -900,24 +970,62 @@ _scaled_block1x128_block128x128(
|
||||
Tensor& out) {
|
||||
#ifndef USE_ROCM
|
||||
// Restrictions:
|
||||
// A, B are FP8, scales are fp32, A: shape K//128, B: K//128, N//128
|
||||
// CUDA: Only Hopper GPUs
|
||||
_check_deepseek_support();
|
||||
// A: [M, K], B: [K, N] are FP8, scales are fp32
|
||||
// As: [M x K // 128], stride: [1, M]
|
||||
// Bs: [round_up(K // 128, 4) x N // 128], stride: [1, N // 128]
|
||||
TORCH_CHECK_VALUE(
|
||||
isFloat8Type(mat_a.scalar_type()) &&
|
||||
isFloat8Type(mat_b.scalar_type()),
|
||||
"mat_a and mat_b must be fp8 types, got: ", mat_a.scalar_type(), mat_b.scalar_type()
|
||||
);
|
||||
|
||||
TORCH_CHECK_VALUE(isFloat8Type(mat_a.scalar_type()) && isFloat8Type(mat_b.scalar_type()), "mat_a and mat_b must be fp8 types, got: ",
|
||||
mat_a.scalar_type(), mat_b.scalar_type());
|
||||
TORCH_CHECK_VALUE(scale_a.sizes()[0] == mat_a.sizes()[0] && scale_a.sizes()[1] == mat_a.sizes()[1] / 128 && scale_a.scalar_type() == kFloat,
|
||||
"scale_a must have shape ", mat_a.sizes()[0], " x ", mat_a.sizes()[1] / 128, " Float elements, got ", scale_a.sizes())
|
||||
TORCH_CHECK_VALUE(scale_b.sizes()[0] == mat_b.sizes()[0] / 128 && scale_b.sizes()[1] == mat_b.sizes()[1] / 128 && scale_b.scalar_type() == kFloat,
|
||||
"scale_b must have shape ", mat_b.sizes()[0] / 128, " x ", mat_b.sizes()[1] / 128, " Float elements, got ", scale_b.sizes())
|
||||
int64_t M = mat_a.size(0);
|
||||
int64_t K = mat_a.size(1);
|
||||
int64_t N = mat_b.size(1);
|
||||
|
||||
// scale_a shape
|
||||
TORCH_CHECK_VALUE(
|
||||
scale_a.size(0) == M &&
|
||||
scale_a.size(1) == ceil_div<int64_t>(K, 128) &&
|
||||
scale_a.scalar_type() == kFloat,
|
||||
"scale_a must have shape ", M, " x ", ceil_div<int64_t>(K, 128), " Float elements, got ", scale_a.sizes()
|
||||
);
|
||||
// scale_a stride
|
||||
TORCH_CHECK_VALUE(
|
||||
scale_a.stride(0) == 1 &&
|
||||
(
|
||||
scale_a.stride(1) == M ||
|
||||
(
|
||||
scale_a.size(1) == 1 &&
|
||||
scale_a.stride(1) == 1
|
||||
)
|
||||
),
|
||||
"scale_a must have strides (1, ", M, "); got ", scale_b.strides()
|
||||
);
|
||||
// scale_b shape
|
||||
TORCH_CHECK_VALUE(
|
||||
scale_b.size(0) == round_up<int64_t>(ceil_div<int64_t>(K, 128), 4) &&
|
||||
scale_b.size(1) == ceil_div<int64_t>(N, 128) &&
|
||||
scale_b.scalar_type() == kFloat,
|
||||
"scale_b must have shape ", round_up<int64_t>(ceil_div<int64_t>(K, 128), 4), " x ", ceil_div<int64_t>(N, 128), " Float elements, got ", scale_b.sizes()
|
||||
);
|
||||
// scale_b stride
|
||||
TORCH_CHECK_VALUE(
|
||||
scale_b.stride(0) == 1 &&
|
||||
(
|
||||
scale_b.stride(1) == round_up<int64_t>(ceil_div<int64_t>(K, 128), 4) ||
|
||||
(
|
||||
scale_b.size(1) == 1 &&
|
||||
scale_b.stride(1) == 1
|
||||
)
|
||||
),
|
||||
"scale_b must have strides (1, ", round_up<int64_t>(ceil_div<int64_t>(K, 128), 4), "); got ", scale_b.strides()
|
||||
);
|
||||
|
||||
auto scaling_choice_a = ScalingType::BlockWise1x128;
|
||||
auto scaling_choice_b = ScalingType::BlockWise128x128;
|
||||
|
||||
// Check scale strides (including stride=1 small cases)
|
||||
_check_deepseek_scale_stride(scale_a, mat_a, scaling_choice_a);
|
||||
_check_deepseek_scale_stride(scale_b.t(), mat_b.t(), scaling_choice_b);
|
||||
|
||||
_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
|
||||
|
||||
return out;
|
||||
@ -997,26 +1105,47 @@ _scaled_mxfp4_mxfp4(
|
||||
const std::optional<Tensor>& bias,
|
||||
const c10::ScalarType out_dtype,
|
||||
Tensor& out) {
|
||||
#ifndef USE_ROCM
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "MXFP4 scaling supported on ROCM only");
|
||||
#endif
|
||||
#if defined(_WIN32) || (!defined(USE_ROCM) && !defined(USE_FBGEMM_GENAI))
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "MXFP4 scaling supported on ROCM and CUDA+FBGEMM_GENAI only");
|
||||
#else
|
||||
// Restrictions:
|
||||
// A, B are FP4, scales are e8m0, A: shape K//32, B: K, N//32
|
||||
TORCH_CHECK_VALUE(mat_a.scalar_type() == at::kFloat4_e2m1fn_x2 && mat_b.scalar_type() == at::kFloat4_e2m1fn_x2, "mat_a and mat_b must be fp4 types, got: ",
|
||||
mat_a.scalar_type(), mat_b.scalar_type());
|
||||
|
||||
auto scale_a_elems = ceil_div<int64_t>(2 * mat_a.size(0), 32) * mat_a.size(1);
|
||||
auto scale_b_elems = ceil_div<int64_t>(2 * mat_b.size(1), 32) * mat_b.size(0);
|
||||
// Packed FP4 format means actual-K = 2 * reported-K -- adjust
|
||||
auto K_multiplier = 2;
|
||||
#ifdef USE_ROCM
|
||||
// AMD
|
||||
auto scale_a_elems = ceil_div<int64_t>(K_multiplier * mat_a.size(0), 32) * mat_a.size(1);
|
||||
auto scale_b_elems = ceil_div<int64_t>(K_multiplier * mat_b.size(1), 32) * mat_b.size(0);
|
||||
#else
|
||||
// NVIDIA
|
||||
auto scale_a_elems = round_up<int64_t>(mat_a.size(0), 128) * round_up<int64_t>(ceil_div<int64_t>(K_multiplier * mat_a.size(1), 32), 4);
|
||||
auto scale_b_elems = round_up<int64_t>(mat_b.size(1), 128) * round_up<int64_t>(ceil_div<int64_t>(K_multiplier * mat_b.size(0), 32), 4);
|
||||
#endif
|
||||
TORCH_CHECK_VALUE(scale_a_elems == scale_a.numel(),
|
||||
"For Blockwise scaling scale_a should have ", scale_a_elems, " elements, got: ", scale_a.numel());
|
||||
TORCH_CHECK_VALUE(scale_b_elems == scale_b.numel(),
|
||||
"For Blockwise scaling scale_b should have ", scale_b_elems, " elements, got: ", scale_b.numel());
|
||||
|
||||
#ifdef USE_ROCM
|
||||
// AMD
|
||||
TORCH_CHECK_VALUE(swizzle_a == SwizzleType::NO_SWIZZLE, "scale_a must not be swizzled (NO_SWIZZLE format)");
|
||||
TORCH_CHECK_VALUE(swizzle_b == SwizzleType::NO_SWIZZLE, "scale_b must not be swizzled (NO_SWIZZLE format)");
|
||||
#else
|
||||
// NVIDIA
|
||||
TORCH_CHECK_VALUE(swizzle_a == SwizzleType::SWIZZLE_32_4_4, "scale_a must be swizzled to SWIZZLE_32_4_4 format");
|
||||
TORCH_CHECK_VALUE(swizzle_b == SwizzleType::SWIZZLE_32_4_4, "scale_b must be swizzled to SWIZZLE_32_4_4 format");
|
||||
#endif
|
||||
|
||||
TORCH_CHECK_VALUE(scale_a.is_contiguous() && scale_b.is_contiguous(),
|
||||
"For Blockwise scaling both scales should be contiguous");
|
||||
|
||||
TORCH_CHECK_VALUE(out.scalar_type() == out_dtype, "expected out.scalar_type() to be ", out_dtype, ", but got ", out_dtype);
|
||||
|
||||
#ifdef USE_ROCM
|
||||
// AMD
|
||||
auto scaling_choice_a = ScalingType::BlockWise1x32;
|
||||
auto scaling_choice_b = ScalingType::BlockWise1x32;
|
||||
|
||||
@ -1031,11 +1160,30 @@ _scaled_mxfp4_mxfp4(
|
||||
TORCH_CHECK_VALUE(out.scalar_type() == ScalarType::BFloat16 ||
|
||||
out.scalar_type() == ScalarType::Half,
|
||||
"Block-wise scaling only supports BFloat16 or Half output types");
|
||||
#else
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "Block-wise scaling for Float8_e8m0fnu requires ROCm 7.0 or later");
|
||||
#endif
|
||||
|
||||
return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out);
|
||||
#else
|
||||
// NVIDIA
|
||||
// NOTE(slayton58): fbgemm_gpu::f4f4bf16 does *not* allow passing an output tensor,
|
||||
// but we have one we need to use. Two clear options are to copy into
|
||||
// our output (slow), or use a move-assignment-operator (faster).
|
||||
// However, the compiler can complain about the explicit move preventing
|
||||
// copy elision because the return from f4f4bf16 is a temporary object.
|
||||
// So we don't explicitly move, and trust the compiler here...
|
||||
// In the longer term this should be fixed on the FBGemm side.
|
||||
out = fbgemm_gpu::f4f4bf16(
|
||||
mat_a,
|
||||
mat_b.transpose(-2, -1),
|
||||
scale_a,
|
||||
scale_b,
|
||||
std::nullopt, /* global_scale */
|
||||
true /* use_mx */
|
||||
);
|
||||
|
||||
return out;
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
|
||||
Tensor&
|
||||
```diff
@@ -1160,17 +1308,20 @@ _scaled_mm_cuda_v2_out(
       mat_a.size(0), "x", mat_a.size(1), " and ", mat_b.size(0), "x", mat_b.size(1), ")");
   }

+  // Handle fp4 packed-K dimension
+  int K_multiplier = (mat_a.scalar_type() == ScalarType::Float4_e2m1fn_x2) ? 2 : 1;
+
   TORCH_CHECK_VALUE(!bias || bias->numel() == mat_b.sizes()[1], "Bias must be size ", mat_b.sizes()[1],
       " but got ", bias->numel());
   TORCH_CHECK_VALUE(
-      mat_a.sizes()[1] % 16 == 0,
+      K_multiplier * mat_a.sizes()[1] % 16 == 0,
       "Expected trailing dimension of mat1 to be divisible by 16 ",
       "but got mat1 shape: (",
       mat_a.sizes()[0],
       "x",
-      mat_a.sizes()[1],
+      K_multiplier * mat_a.sizes()[1],
       ").");
-  TORCH_CHECK_VALUE(mat_b.sizes()[0] % 16 == 0 && mat_b.sizes()[1] % 16 == 0, "mat2 shape (", mat_b.sizes()[0], "x",
+  TORCH_CHECK_VALUE(K_multiplier * mat_b.sizes()[0] % 16 == 0 && mat_b.sizes()[1] % 16 == 0, "mat2 shape (", mat_b.sizes()[0], "x",
       mat_b.sizes()[1], ") must be divisible by 16");

   // TODO(slayton): Existing checks, not sure if they should really be here.
```
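The `K_multiplier` introduced above reflects that the packed FP4 dtype stores two values per element, so a tensor's reported trailing dimension is half the logical K. A tiny illustration of the adjusted divisibility check (plain arithmetic, names are mine):

```python
def logical_k(reported_k: int, is_packed_fp4: bool) -> int:
    # Packed FP4 (Float4_e2m1fn_x2) holds two FP4 values per element, so the
    # logical K dimension is twice the reported trailing dimension.
    k_multiplier = 2 if is_packed_fp4 else 1
    return k_multiplier * reported_k

# Reported K = 8 for a packed-FP4 tensor means logical K = 16, which satisfies
# the "divisible by 16" requirement checked above; the raw size alone would not.
assert logical_k(8, is_packed_fp4=True) % 16 == 0
assert logical_k(8, is_packed_fp4=False) % 16 != 0
```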
@ -4,6 +4,7 @@ import os
|
||||
import tempfile
|
||||
from threading import Event
|
||||
|
||||
import torch._inductor.config as config
|
||||
from torch._inductor.compile_worker.subproc_pool import (
|
||||
raise_testexc,
|
||||
SubprocException,
|
||||
@ -16,9 +17,12 @@ from torch.testing._internal.inductor_utils import HAS_CPU
|
||||
|
||||
|
||||
class TestCompileWorker(TestCase):
|
||||
def make_pool(self, size):
|
||||
return SubprocPool(size)
|
||||
|
||||
@skipIfWindows(msg="pass_fds not supported on Windows.")
|
||||
def test_basic_jobs(self):
|
||||
pool = SubprocPool(2)
|
||||
pool = self.make_pool(2)
|
||||
try:
|
||||
a = pool.submit(operator.add, 100, 1)
|
||||
b = pool.submit(operator.sub, 100, 1)
|
||||
@ -29,7 +33,7 @@ class TestCompileWorker(TestCase):
|
||||
|
||||
@skipIfWindows(msg="pass_fds not supported on Windows.")
|
||||
def test_exception(self):
|
||||
pool = SubprocPool(2)
|
||||
pool = self.make_pool(2)
|
||||
try:
|
||||
a = pool.submit(raise_testexc)
|
||||
with self.assertRaisesRegex(
|
||||
@ -42,7 +46,7 @@ class TestCompileWorker(TestCase):
|
||||
|
||||
@skipIfWindows(msg="pass_fds not supported on Windows.")
|
||||
def test_crash(self):
|
||||
pool = SubprocPool(2)
|
||||
pool = self.make_pool(2)
|
||||
try:
|
||||
with self.assertRaises(Exception):
|
||||
a = pool.submit(os._exit, 1)
|
||||
@ -58,7 +62,7 @@ class TestCompileWorker(TestCase):
|
||||
|
||||
@skipIfWindows(msg="pass_fds not supported on Windows.")
|
||||
def test_quiesce(self):
|
||||
pool = SubprocPool(2)
|
||||
pool = self.make_pool(2)
|
||||
try:
|
||||
a = pool.submit(operator.add, 100, 1)
|
||||
pool.quiesce()
|
||||
@ -75,7 +79,7 @@ class TestCompileWorker(TestCase):
|
||||
os.environ["ROLE_RANK"] = "0"
|
||||
with tempfile.NamedTemporaryFile(delete=True) as temp_log:
|
||||
os.environ["TORCHINDUCTOR_WORKER_LOGPATH"] = temp_log.name
|
||||
pool = SubprocPool(2)
|
||||
pool = self.make_pool(2)
|
||||
try:
|
||||
pool.submit(operator.add, 100, 1)
|
||||
self.assertEqual(os.path.exists(temp_log.name), True)
|
||||
@ -83,6 +87,12 @@ class TestCompileWorker(TestCase):
|
||||
pool.shutdown()
|
||||
|
||||
|
||||
@config.patch("quiesce_async_compile_time", 0.1)
|
||||
class TestCompileWorkerWithTimer(TestCompileWorker):
|
||||
def make_pool(self, size):
|
||||
return SubprocPool(size, quiesce=True)
|
||||
|
||||
|
||||
class TestTimer(TestCase):
|
||||
def test_basics(self):
|
||||
done = Event()
|
||||
|
||||
```diff
@@ -500,8 +500,13 @@ class PaddingTest(TestCaseBase):
         forward_wrapper = wrapper_codes[0]

         # make sure the load for softmax is aligned
+        if bias:
+            # addmm -> mm + bias and bias is fused with softmax
+            softmax_load_str = "tl.load(in_out_ptr0 + (r0_1 + 30528*x0)"
+        else:
+            softmax_load_str = "tl.load(in_ptr0 + (r0_1 + 30528*x0)"
         self.assertTrue(
-            "tl.load(in_ptr0 + (r0_1 + 30528*x0)" in forward_wrapper,
+            softmax_load_str in forward_wrapper,
             f"forward_wrapper: {forward_wrapper}",
         )
```
```diff
@@ -15280,7 +15280,7 @@ if RUN_GPU:
             ),
             (
                 fn3,
-                "triton_poi_fused_native_layer_norm_relu",
+                "triton_poi_fused_addmm_native_layer_norm",
                 (torch.randn(4, 4, device=GPU_TYPE),),
             ),
         ]
@@ -15293,7 +15293,7 @@ if RUN_GPU:
             ),
             (
                 fn3,
-                "triton_poi_fused_LayerNorm_ReLU",
+                "triton_poi_fused_LayerNorm_Linear_ReLU",
                 (torch.randn(4, 4, device=GPU_TYPE),),
             ),
         ]
```
```diff
@@ -1826,9 +1826,14 @@ def run_test_module(
     test_name = test.name

     # Printing the date here can help diagnose which tests are slow
-    print_to_stderr(f"Running {str(test)} ... [{datetime.now()}]")
+    start = time.perf_counter()
+    print_to_stderr(f"Running {str(test)} ... [{datetime.now()}][{start}]")
     handler = CUSTOM_HANDLERS.get(test_name, run_test)
     return_code = handler(test, test_directory, options)
+    end = time.perf_counter()
+    print_to_stderr(
+        f"Finished {str(test)} ... [{datetime.now()}][{end}], took {(end - start) / 60:.2f}min"
+    )
     assert isinstance(return_code, int) and not isinstance(return_code, bool), (
         f"While running {str(test)} got non integer return code {return_code}"
     )
```
@ -7413,6 +7413,140 @@ class TestCudaDeviceParametrized(TestCase):
|
||||
)
|
||||
|
||||
|
||||
class TestFXMemoryProfiler(TestCase):
|
||||
"""Tests for memory profiler augmentation with original stack traces."""
|
||||
|
||||
def collect_frames(
|
||||
self, augmented_snapshot, collect_device_traces=True, collect_segments=True
|
||||
):
|
||||
"""Collects all frames that has node metadata from a memory snapshot."""
|
||||
# Collect all frames with FX metadata
|
||||
fx_frames = []
|
||||
|
||||
# Check device traces for FX debug fields
|
||||
if collect_device_traces and "device_traces" in augmented_snapshot:
|
||||
for trace_list in augmented_snapshot["device_traces"]:
|
||||
for trace_entry in trace_list:
|
||||
if isinstance(trace_entry, dict) and "frames" in trace_entry:
|
||||
for frame in trace_entry["frames"]:
|
||||
if isinstance(frame, dict):
|
||||
# Check for FX debug fields
|
||||
if "fx_node_op" in frame or "fx_node_name" in frame:
|
||||
fx_frames.append(frame)
|
||||
|
||||
# Check segments/blocks for FX debug fields
|
||||
if collect_segments and "segments" in augmented_snapshot:
|
||||
for segment in augmented_snapshot["segments"]:
|
||||
if "blocks" in segment:
|
||||
for block in segment["blocks"]:
|
||||
if "frames" in block:
|
||||
for frame in block["frames"]:
|
||||
if isinstance(frame, dict):
|
||||
if "fx_node_op" in frame or "fx_node_name" in frame:
|
||||
fx_frames.append(frame)
|
||||
return fx_frames
|
||||
|
||||
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
|
||||
@torch._dynamo.config.patch("enrich_profiler_metadata", True)
|
||||
def test_fx_memory_profiler_augmentation(self):
|
||||
"""Test that memory snapshots are augmented with FX debug information."""
|
||||
|
||||
# Create a simple model
|
||||
class MLPModule(nn.Module):
|
||||
def __init__(self, device):
|
||||
super().__init__()
|
||||
torch.manual_seed(5)
|
||||
self.net1 = nn.Linear(10, 16, bias=True, device=device)
|
||||
self.relu = nn.ReLU()
|
||||
self.net2 = nn.Linear(16, 10, bias=True, device=device)
|
||||
|
||||
def forward(self, x):
|
||||
a = self.net1(x)
|
||||
b = self.relu(a)
|
||||
c = self.net2(b)
|
||||
return c
|
||||
|
||||
device = "cuda"
|
||||
mod = MLPModule(device)
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
torch.cuda.memory._record_memory_history()
|
||||
compiled = torch.compile(mod, backend="aot_eager", fullgraph=True)
|
||||
result = compiled(torch.randn(10, 10, device=device))
|
||||
augmented_snapshot = torch.cuda.memory._snapshot(
|
||||
augment_with_fx_traces=True
|
||||
)
|
||||
torch.cuda.memory._record_memory_history(enabled=None, clear_history=True)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
fx_frames = self.collect_frames(augmented_snapshot)
|
||||
if TEST_WITH_ROCM:
|
||||
self.assertGreater(len(fx_frames), 0)
|
||||
else:
|
||||
self.assertEqual(len(fx_frames), 12)
|
||||
|
||||
for frame in fx_frames:
|
||||
# Every FX frame should have both node_op and node_name
|
||||
self.assertIn("fx_node_op", frame)
|
||||
self.assertIn("fx_node_name", frame)
|
||||
self.assertIn("fx_node_target", frame)
|
||||
self.assertIn("fx_original_trace", frame)
|
||||
|
||||
self.assertIn(frame["fx_node_name"], ["addmm", "relu", "addmm_1"])
|
||||
fx_node_name = frame["fx_node_name"]
|
||||
if fx_node_name == "addmm":
|
||||
self.assertIn("a = self.net1(x)", frame["fx_original_trace"])
|
||||
elif fx_node_name == "addmm_1":
|
||||
self.assertIn("c = self.net2(b)", frame["fx_original_trace"])
|
||||
elif fx_node_name == "relu":
|
||||
self.assertIn("b = self.relu(a)", frame["fx_original_trace"])
|
||||
|
||||
# Test that when we have two graphs with the same src_code, they're not hashed
|
||||
# to the same metadata
|
||||
class MLPModule2(nn.Module):
|
||||
def __init__(self, device):
|
||||
super().__init__()
|
||||
torch.manual_seed(5)
|
||||
self.net1 = nn.Linear(10, 16, bias=True, device=device)
|
||||
self.relu = nn.ReLU()
|
||||
self.net2 = nn.Linear(16, 10, bias=True, device=device)
|
||||
|
||||
def forward(self, x):
|
||||
d = self.net1(x)
|
||||
e = self.relu(d)
|
||||
f = self.net2(e)
|
||||
return f
|
||||
|
||||
mod = MLPModule2(device)
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
torch.cuda.memory._record_memory_history()
|
||||
compiled = torch.compile(mod, backend="aot_eager", fullgraph=True)
|
||||
result = compiled(torch.randn(10, 10, device=device))
|
||||
augmented_snapshot = torch.cuda.memory._snapshot(
|
||||
augment_with_fx_traces=True
|
||||
)
|
||||
torch.cuda.memory._record_memory_history(enabled=None, clear_history=True)
|
||||
|
||||
# avoid collecting segments from previous run for unit test purpose
|
||||
fx_frames = self.collect_frames(augmented_snapshot, collect_segments=False)
|
||||
self.assertGreater(len(fx_frames), 0)
|
||||
|
||||
for frame in fx_frames:
|
||||
# Every FX frame should have both node_op and node_name
|
||||
self.assertIn("fx_node_op", frame)
|
||||
self.assertIn("fx_node_name", frame)
|
||||
self.assertIn("fx_node_target", frame)
|
||||
self.assertIn("fx_original_trace", frame)
|
||||
|
||||
self.assertIn(frame["fx_node_name"], ["addmm", "relu", "addmm_1"])
|
||||
fx_node_name = frame["fx_node_name"]
|
||||
if fx_node_name == "addmm":
|
||||
self.assertIn("d = self.net1(x)", frame["fx_original_trace"])
|
||||
elif fx_node_name == "addmm_1":
|
||||
self.assertIn("f = self.net2(e)", frame["fx_original_trace"])
|
||||
elif fx_node_name == "relu":
|
||||
self.assertIn("e = self.relu(d)", frame["fx_original_trace"])
|
||||
|
||||
|
||||
instantiate_parametrized_tests(TestCuda)
|
||||
instantiate_parametrized_tests(TestCudaMallocAsync)
|
||||
instantiate_parametrized_tests(TestCompileKernel)
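Outside the test harness, the workflow the new `TestFXMemoryProfiler` test exercises looks roughly like this. It is a hedged sketch: it assumes a CUDA build where the private `_record_memory_history` / `_snapshot` helpers accept the `augment_with_fx_traces` flag used above, and the frame walk mirrors `collect_frames`:

```python
import torch

@torch._dynamo.config.patch("enrich_profiler_metadata", True)
def profile_model(mod, example_input):
    torch.cuda.memory._record_memory_history()
    compiled = torch.compile(mod, backend="aot_eager", fullgraph=True)
    compiled(example_input)
    snapshot = torch.cuda.memory._snapshot(augment_with_fx_traces=True)
    torch.cuda.memory._record_memory_history(enabled=None, clear_history=True)

    # Augmented frames carry fx_node_name / fx_original_trace, as asserted above.
    for trace_list in snapshot.get("device_traces", []):
        for entry in trace_list:
            if not isinstance(entry, dict):
                continue
            for frame in entry.get("frames", []):
                if isinstance(frame, dict) and "fx_node_name" in frame:
                    print(frame["fx_node_name"], frame.get("fx_original_trace"))
    return snapshot
```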
|
||||
|
||||
```diff
@@ -771,6 +771,7 @@ class TestFX(JitTestCase):
         gm = GraphModule(tracer.root, graph)
         expected = {1: 2, 2: 3, 3: 4, 4: 5}
         self.assertTrue(set(expected.items()).issubset(set(gm._lineno_map.items())))
+        self.assertEqual(gm._prologue_start, 4)

         # test custom codegen
         def transform_code(code):
@@ -780,6 +781,7 @@ class TestFX(JitTestCase):
         gm.recompile()
         expected = {2: 2, 3: 3, 4: 4, 5: 5}
         self.assertTrue(set(expected.items()).issubset(set(gm._lineno_map.items())))
+        self.assertEqual(gm._prologue_start, 4)

     def test_graph_unique_names_manual(self):
         graph: torch.fx.Graph = torch.fx.Graph()
```
```diff
@@ -11,7 +11,7 @@ from typing import Optional
 import torch


-from torch.nn.functional import scaled_mm, scaled_grouped_mm, ScalingType, SwizzleType
+from torch.nn.functional import pad, scaled_mm, scaled_grouped_mm, ScalingType, SwizzleType
 from torch.testing._internal.common_cuda import (
     IS_SM90,
     _get_torch_cuda_version,
```
@ -107,11 +107,76 @@ def tensor_to_scale_block(
|
||||
x = x.unflatten(1, (-1, block_inner)).unflatten(0, (-1, block_outer))
|
||||
amax = x.abs().amax(dim=[1, 3], keepdim=True).float()
|
||||
scale = torch.finfo(float8_dtype).max / amax
|
||||
# if amax == 0, entire block = 0, set scale = 0 to ensure elements are
|
||||
# zero'd out correctly (and remove bad effects of / 0)
|
||||
scale[amax == 0] = 0
|
||||
|
||||
# Scale x, noting that blocks where amax == 0 are explicitly 0 now.
|
||||
x = x.mul(scale).to(float8_dtype)
|
||||
|
||||
# if amax == 0, all values in the block are 0, scale=0
|
||||
# but we need scale.reciprocal later, which breaks when scale=0...
|
||||
# So. Replace 0 -> 1 in the scale so we don't break things later.
|
||||
# Elements are already zeroed, so don't actually care what the scale
|
||||
# is, as long as it's not inf/nan.
|
||||
scale[scale == 0] = 1.
|
||||
|
||||
x = x.flatten(2, 3).flatten(0, 1)
|
||||
scale = scale.flatten(2, 3).flatten(0, 1)
|
||||
return x, scale
|
||||
|
||||
def hp_from_128x128(x_lp, x_scale):
|
||||
orig_shape = x_lp.shape
|
||||
M, K = orig_shape
|
||||
x_lp = x_lp.view(M // 128, 128, K // 128, 128)
|
||||
x_scale = x_scale.unsqueeze(1).unsqueeze(-1)
|
||||
x_hp = x_lp.to(torch.float32)
|
||||
x_hp = x_hp / x_scale
|
||||
return x_hp.reshape(orig_shape).to(torch.bfloat16)
|
||||
|
||||
def hp_to_128x128(x, x_scale):
|
||||
orig_shape = x.shape
|
||||
M, K = orig_shape
|
||||
x = x.view(M // 128, 128, K // 128, 128)
|
||||
x_scale = x_scale.unsqueeze(1).unsqueeze(-1)
|
||||
x_lp = x * x_scale
|
||||
|
||||
return x_lp.reshape(orig_shape).to(torch.float8_e4m3fn)
|
||||
|
||||
def hp_from_1x128(x_lp, x_scale):
|
||||
orig_shape = x_lp.shape
|
||||
x_lp = x_lp.reshape(x_lp.shape[0], x_lp.shape[-1] // 128, 128)
|
||||
x_hp = x_lp.to(torch.float32)
|
||||
x_hp = x_hp / x_scale.unsqueeze(-1)
|
||||
return x_hp.reshape(orig_shape).to(torch.bfloat16)
|
||||
|
||||
def hp_to_1x128(x, x_scale):
|
||||
orig_shape = x.shape
|
||||
x = x.reshape(x.shape[0], x.shape[-1] // 128, 128)
|
||||
x_lp = x * x_scale.unsqueeze(-1)
|
||||
return x_lp.reshape(orig_shape).to(torch.float8_e4m3fn)
|
||||
|
||||
|
||||
```python
# cublas requires specific padding for 128x128 scales, see:
# https://docs.nvidia.com/cuda/cublas/#element-1d-and-128x128-2d-block-scaling-for-fp8-data-types
# Notably L = ceil_div(K, 128),
#         L4 = round_up(L, 4),
# and then for A/B the shape must be
#         scale: [L4, ceil_div({M,N}, 128) and K/L/L4-major in memory.
#
# This routine pads L -> L4
def _pad_128x128_scales(scale: torch.Tensor) -> (torch.Tensor, int):
    # scale is either [L4, ceil_div(M, 128)] or [L4, ceil_div(N, 128)], stride: [1, L4]
    # However, we get passed it as [ceil_div(M, 128), L] or [ceil_div(N, 128), L]
    # so check inner dim % 4, and pad if necessary
    pad_amount = scale.shape[-1] % 4

    if pad_amount == 0:
        return scale, 0
    else:
        pad_amount = 4 - pad_amount
        return pad(scale, (0, pad_amount), "constant", 0), pad_amount


def round_up(x: int, y: int) -> int:
    return ((x + y - 1) // y) * y
```
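A quick usage sketch of the padding helper above: for K = 1280, L = ceil_div(K, 128) = 10, which is not a multiple of 4, so the scale is padded from 10 to L4 = 12 columns (sizes are illustrative only; `pad` is the functional import added at the top of this file):

```python
import math
import torch
from torch.nn.functional import pad

M, K = 384, 1280
L = math.ceil(K / 128)            # 10
scale = torch.randn(M // 128, L)  # [3, 10], one fp32 scale per 128x128 tile

padded, pad_amount = _pad_128x128_scales(scale)
print(padded.shape, pad_amount)   # torch.Size([3, 12]) 2
```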
@ -144,42 +209,36 @@ def infer_scale_swizzle(mat, scale):
|
||||
] == math.ceil(mat.shape[1] // 128):
|
||||
return ScalingType.BlockWise128x128, SwizzleType.NO_SWIZZLE
|
||||
|
||||
# if we're checking for nvfp4, need to adjust for packed-K
|
||||
K_multiplier = 2 if mat.dtype == torch.float4_e2m1fn_x2 else 1
|
||||
# NVFP4
|
||||
if (
|
||||
(scale.numel()
|
||||
== round_up(mat.shape[0], 128) * round_up(math.ceil(2 * mat.shape[1] // 16), 4)
|
||||
== round_up(mat.shape[0], 128) * round_up(math.ceil(K_multiplier * mat.shape[1] // 16), 4)
|
||||
or scale.numel()
|
||||
== round_up(mat.shape[1], 128) * round_up(math.ceil(2 * mat.shape[0] // 16), 4))
|
||||
== round_up(mat.shape[1], 128) * round_up(math.ceil(K_multiplier * mat.shape[0] // 16), 4))
|
||||
and mat.dtype == torch.float4_e2m1fn_x2
|
||||
and scale.dtype == torch.float8_e4m3fn
|
||||
):
|
||||
return ScalingType.BlockWise1x16, SwizzleType.SWIZZLE_32_4_4
|
||||
|
||||
# MXFP4 w/o swizzle
|
||||
if (
|
||||
(scale.numel() == 2 * math.ceil(mat.shape[0] // 32) * mat.shape[1]
|
||||
or scale.numel() == 2 * math.ceil(mat.shape[1] // 32) * mat.shape[0])
|
||||
and mat.dtype == torch.float4_e2m1fn_x2
|
||||
and scale.dtype == torch.float8_e8m0fnu
|
||||
):
|
||||
return ScalingType.BlockWise1x32, SwizzleType.NO_SWIZZLE
|
||||
|
||||
# MX formats
|
||||
if not torch.version.hip:
|
||||
# MXFP8 w/ swizzle
|
||||
# MX w/swizzle (NVIDIA)
|
||||
if (
|
||||
(scale.numel()
|
||||
== round_up(mat.shape[0], 128) * round_up(math.ceil(mat.shape[1] // 32), 4)
|
||||
== round_up(mat.shape[0], 128) * round_up(math.ceil(K_multiplier * mat.shape[1] // 32), 4)
|
||||
or scale.numel()
|
||||
== round_up(mat.shape[1], 128) * round_up(math.ceil(mat.shape[0] // 32), 4))
|
||||
== round_up(mat.shape[1], 128) * round_up(math.ceil(K_multiplier * mat.shape[0] // 32), 4))
|
||||
and scale.dtype == torch.float8_e8m0fnu
|
||||
):
|
||||
return ScalingType.BlockWise1x32, SwizzleType.SWIZZLE_32_4_4
|
||||
|
||||
else:
|
||||
# MXFP8 w/o swizzle
|
||||
# MX w/o swizzle (AMD)
|
||||
if (
|
||||
(scale.numel() == math.ceil(mat.shape[0] // 32) * mat.shape[1]
|
||||
or scale.numel() == math.ceil(mat.shape[1] // 32) * mat.shape[0])
|
||||
(scale.numel() == math.ceil(mat.shape[0] // 32) * K_multiplier * mat.shape[1]
|
||||
or scale.numel() == math.ceil(K_multiplier * mat.shape[1] // 32) * mat.shape[0])
|
||||
and scale.dtype == torch.float8_e8m0fnu
|
||||
):
|
||||
return ScalingType.BlockWise1x32, SwizzleType.NO_SWIZZLE
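For the swizzled (NVIDIA) branch above, the expected number of 1x32-block scale elements follows the round_up/ceil_div expression in the condition. A small worked example with illustrative sizes (the helper mirrors the `round_up` defined earlier in this file):

```python
import math

def round_up(x: int, y: int) -> int:
    return ((x + y - 1) // y) * y

M, K = 127, 1024  # illustrative sizes

# MX (1x32 block) scales, swizzled for cuBLAS: rows padded to 128,
# columns = ceil(K / 32) padded up to a multiple of 4.
swizzled_elems = round_up(M, 128) * round_up(math.ceil(K / 32), 4)
print(swizzled_elems)    # 128 * 32 = 4096

# Unswizzled (ROCm) layout: simply one scale per 1x32 block.
unswizzled_elems = math.ceil(K / 32) * M
print(unswizzled_elems)  # 32 * 127 = 4064
```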
|
||||
@ -1252,7 +1311,6 @@ class TestFP8Matmul(TestCase):
|
||||
else:
|
||||
test()
|
||||
|
||||
# Note: Removed parameterization over M,N,K from #163829 as it failed tests as-is
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
|
||||
@unittest.skipIf(not IS_SM90, "cuBLAS blockwise scaling requires sm90+")
|
||||
@unittest.skipIf(
|
||||
@ -1261,59 +1319,224 @@ class TestFP8Matmul(TestCase):
|
||||
)
|
||||
@parametrize("output_dtype", [torch.bfloat16, torch.float32])
|
||||
@parametrize("lhs_block,rhs_block", [(1, 1), (128, 1), (1, 128)])
|
||||
@parametrize("M,N,K", [(256, 768, 512)])
|
||||
@with_tf32_off
|
||||
def test_scaled_mm_vs_emulated_block_wise(self, output_dtype, lhs_block, rhs_block, M, N, K):
|
||||
@parametrize("M,N,K", [
|
||||
# Nice size
|
||||
(256, 768, 512),
|
||||
# Requires padding for 128x128 scale
|
||||
(384, 128, 1280),
|
||||
# M=N=K for eyes test
|
||||
(512, 512, 512),
|
||||
])
|
||||
@parametrize("test_case", [
|
||||
"x_eye_b_eye",
|
||||
"x_ones_y_ones_calc_scales",
|
||||
"x_ones_y_ones_set_scales",
|
||||
"x_ones_y_ones_modify_scales",
|
||||
"data_random_scales_one",
|
||||
"data_random_calc_scales",
|
||||
])
|
||||
def test_scaled_mm_block_wise_numerics(self, output_dtype, lhs_block, rhs_block, M, N, K, test_case):
|
||||
"""
|
||||
subsume test_scaled_mm_vs_emulated_block_wise for random inputs, random scales,
|
||||
do some other functional tests as well.
|
||||
|
||||
# Inputs (as generated are):
|
||||
# A: [M, K]
|
||||
# B: [N, K]
|
||||
# then scales are, for the 3 combinations:
|
||||
# 1x128 x 1x128:
|
||||
# As: [M, K // 128], stride: [1, M] -> scale.t().contiguous().t()
|
||||
# Bs: [N, K // 128], stride: [1, N] -> scale.t().contiguous().t()
|
||||
# 1x128 x 128x128
|
||||
# L4 = round_up(K // 128, 4)
|
||||
# As: [M, K // 128], stride: [1, M] -> scale.t().contiguous().t()
|
||||
# Bs: [L4, N // 128], stride: [1, L4] -> scale.t()
|
||||
# 128x128 x 1x128
|
||||
# L4 = round_up(K // 128, 4)
|
||||
# As: [L4, M // 128], stride: [1, L4]
|
||||
# Bs: [N, K // 128], stride: [1, N]
|
||||
"""
|
||||
torch.manual_seed(42)
|
||||
|
||||
x = torch.randn(M, K, device="cuda", dtype=output_dtype).pow(3)
|
||||
y = torch.randn(N, K, device="cuda", dtype=output_dtype).pow(3)
|
||||
def _adjust_lhs_scale(x_fp8, x_scales, lhs_block):
|
||||
M, K = x_fp8.shape
|
||||
x_scales_original = x_scales.clone()
|
||||
# 1x128 blocks need scales to be outer-dim-major
|
||||
if lhs_block == 1:
|
||||
x_scales = x_scales.t().contiguous().t()
|
||||
lhs_recipe = ScalingType.BlockWise1x128
|
||||
assert (x_scales.shape[0] == M and x_scales.shape[1] == K // 128), f"{x_scales.shape=}"
|
||||
assert (x_scales.stride(0) == 1 and x_scales.stride(1) in [1, M]), f"{x_scales.stride=}"
|
||||
x_hp = hp_from_1x128(x_fp8, x_scales_original)
|
||||
else:
|
||||
lhs_recipe = ScalingType.BlockWise128x128
|
||||
x_scales, pad_amount = _pad_128x128_scales(x_scales)
|
||||
# scales in [M // 128, L4] -> [L4, M // 128]
|
||||
x_scales = x_scales.t()
|
||||
x_hp = hp_from_128x128(x_fp8, x_scales_original)
|
||||
|
||||
x_fp8, x_scales = tensor_to_scale_block(x, e4m3_type, lhs_block, 128)
|
||||
y_fp8, y_scales = tensor_to_scale_block(y, e4m3_type, rhs_block, 128)
|
||||
return x_hp, lhs_recipe, x_scales, x_scales_original
|
||||
|
||||
# 1x128 blocks need scales to be outer-dim-major
|
||||
if lhs_block == 1:
|
||||
x_scales = x_scales.t().contiguous().t()
|
||||
lhs_recipe = ScalingType.BlockWise1x128
|
||||
def _adjust_rhs_scale(y_fp8, y_scales, rhs_block):
|
||||
N, K = y_fp8.shape
|
||||
y_scales_original = y_scales.clone()
|
||||
|
||||
if rhs_block == 1:
|
||||
y_scales = y_scales.t().contiguous().t()
|
||||
rhs_recipe = ScalingType.BlockWise1x128
|
||||
assert (y_scales.shape[0] == N and y_scales.shape[1] == K // 128), f"{y_scales.shape=}"
|
||||
assert (y_scales.stride(0) == 1 and y_scales.stride(1) in [1, N]), f"{y_scales.stride=}"
|
||||
y_hp = hp_from_1x128(y_fp8, y_scales_original)
|
||||
else:
|
||||
rhs_recipe = ScalingType.BlockWise128x128
|
||||
y_scales, pad_amount = _pad_128x128_scales(y_scales)
|
||||
# Scale in [N // 128, L4] -> [L4, N // 128]
|
||||
y_scales = y_scales.t()
|
||||
y_hp = hp_from_128x128(y_fp8, y_scales_original)
|
||||
|
||||
return y_hp, rhs_recipe, y_scales, y_scales_original
|
||||
|
||||
def _build_lhs(x, lhs_block):
|
||||
M, K = x.shape
|
||||
|
||||
x_fp8, x_scales = tensor_to_scale_block(x, e4m3_type, lhs_block, 128)
|
||||
x_scales_original = x_scales
|
||||
|
||||
x_hp, x_recipe, x_scales, x_scales_original = _adjust_lhs_scale(x_fp8, x_scales, lhs_block)
|
||||
|
||||
return x_hp, x_recipe, x_fp8, x_scales, x_scales_original
|
||||
|
||||
def _build_rhs(y, rhs_block):
|
||||
N, K = y.shape
|
||||
|
||||
y_fp8, y_scales = tensor_to_scale_block(y, e4m3_type, rhs_block, 128)
|
||||
y_hp, y_recipe, y_scales, y_scales_original = _adjust_rhs_scale(y_fp8, y_scales, rhs_block)
|
||||
|
||||
return y_hp, y_recipe, y_fp8, y_scales, y_scales_original
|
||||
|
||||
def _run_test(x_hp, x_recipe, x_fp8, x_scales, x_scales_original,
|
||||
y_hp, y_recipe, y_fp8, y_scales, y_scales_original):
|
||||
|
||||
# Calculate actual F8 mm
|
||||
out_scaled_mm = scaled_mm_wrap(
|
||||
x_fp8,
|
||||
y_fp8.t(),
|
||||
scale_a=x_scales.reciprocal(),
|
||||
scale_recipe_a=x_recipe,
|
||||
# Note: No more .t() on scale_b, not necessary.
|
||||
scale_b=y_scales.reciprocal(),
|
||||
scale_recipe_b=y_recipe,
|
||||
out_dtype=output_dtype,
|
||||
)
|
||||
|
||||
# Calculate emulated F8 mm
|
||||
out_emulated = mm_float8_emulated_block(
|
||||
x_fp8,
|
||||
x_scales_original,
|
||||
y_fp8.t(),
|
||||
y_scales_original.t(),
|
||||
output_dtype
|
||||
)
|
||||
|
||||
cosine_sim = torch.nn.functional.cosine_similarity(
|
||||
out_emulated.flatten().float(), (x @ y.t()).flatten().float(), dim=0
|
||||
)
|
||||
self.assertGreaterEqual(float(cosine_sim), 0.999)
|
||||
|
||||
cosine_sim = torch.nn.functional.cosine_similarity(
|
||||
out_scaled_mm.flatten().float(), out_emulated.flatten().float(), dim=0
|
||||
)
|
||||
self.assertGreaterEqual(float(cosine_sim), 0.999)
|
||||
|
||||
if output_dtype in {torch.bfloat16, torch.float16}:
|
||||
atol, rtol = 6e-1, 7e-2
|
||||
else:
|
||||
atol, rtol = 7e-1, 2e-3
|
||||
|
||||
self.assertEqual(out_scaled_mm, out_emulated.to(output_dtype), atol=atol, rtol=rtol)
|
||||
|
||||
# One last check against the full-precision reference, to ensure we
|
||||
# didn't mess up the scaling itself and made the test trivial.
|
||||
cosine_sim = torch.nn.functional.cosine_similarity(
|
||||
out_scaled_mm.flatten().float(), (x @ y.t()).flatten().float(), dim=0
|
||||
)
|
||||
self.assertGreaterEqual(float(cosine_sim), 0.999)
|
||||
|
||||
def _build_constant_scale(t, block, val):
|
||||
M, K = t.shape
|
||||
|
||||
if block == 1:
|
||||
scale_shape = M, K // 128
|
||||
else:
|
||||
scale_shape = M // 128, K // 128
|
||||
|
||||
scale = torch.full(scale_shape, val, device='cuda')
|
||||
|
||||
return scale
|
||||
|
||||
def hp_to_scaled(t, scale, block):
|
||||
if block == 1:
|
||||
return hp_to_1x128(t, scale)
|
||||
else:
|
||||
return hp_to_128x128(t, scale)
|
||||
|
||||
e4m3_type = torch.float8_e4m3fn
|
||||
|
||||
if test_case == "x_eye_b_eye":
|
||||
if M != K or M != N:
|
||||
return unittest.skip("a_eye_b_eye only defined for M = N = K")
|
||||
x = torch.eye(M, device='cuda')
|
||||
y = torch.eye(M, device='cuda')
|
||||
|
||||
x_hp, x_recipe, x_fp8, x_scales, x_scales_original = _build_lhs(x, lhs_block)
|
||||
y_hp, y_recipe, y_fp8, y_scales, y_scales_original = _build_lhs(y, rhs_block)
|
||||
elif test_case == "x_ones_y_ones_calc_scales":
|
||||
x = torch.full((M, K), 1.0, device='cuda')
|
||||
y = torch.full((N, K), 1.0, device='cuda')
|
||||
|
||||
x_hp, x_recipe, x_fp8, x_scales, x_scales_original = _build_lhs(x, lhs_block)
|
||||
y_hp, y_recipe, y_fp8, y_scales, y_scales_original = _build_lhs(y, rhs_block)
|
||||
elif test_case in ["x_ones_y_ones_set_scales", "x_ones_y_ones_modify_scales"]:
|
||||
x = torch.full((M, K), 1.0, device='cuda')
|
||||
y = torch.full((N, K), 1.0, device='cuda')
|
||||
|
||||
x_scales = _build_constant_scale(x, lhs_block, 1.)
|
||||
y_scales = _build_constant_scale(y, rhs_block, 1.)
|
||||
|
||||
if "modify" in test_case:
|
||||
x_scales[0, 0] = 4.
|
||||
y_scales[-1, -1] = 4.
|
||||
|
||||
x_fp8 = hp_to_scaled(x, x_scales, lhs_block)
|
||||
y_fp8 = hp_to_scaled(y, y_scales, rhs_block)
|
||||
|
||||
x_hp, x_recipe, x_scales, x_scales_original = _adjust_lhs_scale(x_fp8, x_scales, lhs_block)
|
||||
y_hp, y_recipe, y_scales, y_scales_original = _adjust_rhs_scale(y_fp8, y_scales, rhs_block)
|
||||
elif test_case == "data_random_scales_one":
|
||||
x = torch.randint(0, 255, (M, K), device='cuda', dtype=torch.uint8).to(torch.bfloat16)
|
||||
y = torch.randint(0, 255, (N, K), device='cuda', dtype=torch.uint8).to(torch.bfloat16)

x_scales = _build_constant_scale(x, lhs_block, 1.)
y_scales = _build_constant_scale(y, rhs_block, 1.)

x_fp8 = hp_to_scaled(x, x_scales, lhs_block)
y_fp8 = hp_to_scaled(y, y_scales, rhs_block)

x_hp, x_recipe, x_scales, x_scales_original = _adjust_lhs_scale(x_fp8, x_scales, lhs_block)
y_hp, y_recipe, y_scales, y_scales_original = _adjust_rhs_scale(y_fp8, y_scales, rhs_block)
elif test_case == "data_random_calc_scales":
# Note: Old test_scaled_mm_vs_emulated_block_wise test case
x = torch.randn(M, K, device="cuda", dtype=output_dtype)
y = torch.randn(N, K, device="cuda", dtype=output_dtype) * 1e-3

x_hp, x_recipe, x_fp8, x_scales, x_scales_original = _build_lhs(x, lhs_block)
y_hp, y_recipe, y_fp8, y_scales, y_scales_original = _build_lhs(y, rhs_block)
else:
lhs_recipe = ScalingType.BlockWise128x128
if rhs_block == 1:
y_scales = y_scales.t().contiguous().t()
rhs_recipe = ScalingType.BlockWise1x128
else:
rhs_recipe = ScalingType.BlockWise128x128
raise ValueError("Unknown test-case passed")

_run_test(x_hp, x_recipe, x_fp8, x_scales, x_scales_original,
y_hp, y_recipe, y_fp8, y_scales, y_scales_original)

# Calculate actual F8 mm
out_scaled_mm = scaled_mm_wrap(
x_fp8, y_fp8.t(), scale_a=x_scales.reciprocal(), scale_b=y_scales.reciprocal().t(), out_dtype=output_dtype,
scale_recipe_a=lhs_recipe, scale_recipe_b=rhs_recipe
)

# Calculate emulated F8 mm
out_emulated = mm_float8_emulated_block(
x_fp8, x_scales, y_fp8.t(), y_scales.t(), output_dtype
)

cosine_sim = torch.nn.functional.cosine_similarity(
out_scaled_mm.flatten().float(), out_emulated.flatten().float(), dim=0
)
self.assertGreaterEqual(float(cosine_sim), 0.999)

if output_dtype in {torch.bfloat16, torch.float16}:
atol, rtol = 6e-1, 7e-2
else:
atol, rtol = 7e-1, 2e-3

self.assertEqual(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)

# One last check against the full-precision reference, to ensure we
# didn't mess up the scaling itself and made the test trivial.
cosine_sim = torch.nn.functional.cosine_similarity(
out_scaled_mm.flatten().float(), (x @ y.t()).flatten().float(), dim=0
)
self.assertGreaterEqual(float(cosine_sim), 0.999)

@unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
@unittest.skipIf(not IS_SM90, "cuBLAS blockwise scaling requires sm90+")
@ -1335,18 +1558,30 @@ class TestFP8Matmul(TestCase):
x_fp8, x_scales = tensor_to_scale_block(x, e4m3_type, lhs_block, 128)
y_fp8, y_scales = tensor_to_scale_block(y, e4m3_type, rhs_block, 128)

x_scales_original = x_scales
y_scales_original = y_scales
# 1x128 blocks need scales to be outer-dim-major
if lhs_block == 1:
x_scales = x_scales.t().contiguous().t()
lhs_recipe = ScalingType.BlockWise1x128
assert (x_scales.shape[0] == M and x_scales.shape[1] == K // 128), f"{x_scales.shape=}"
assert (x_scales.stride(0) == 1 and x_scales.stride(1) in [1, M]), f"{x_scales.stride=}"
else:
lhs_recipe = ScalingType.BlockWise128x128
x_scales, pad_amount = _pad_128x128_scales(x_scales)
# scales in [M // 128, L4] -> [L4, M // 128]
x_scales = x_scales.t()

if rhs_block == 1:
y_scales = y_scales.t().contiguous().t()
rhs_recipe = ScalingType.BlockWise1x128
assert (y_scales.shape[0] == N and y_scales.shape[1] == K // 128), f"{y_scales.shape=}"
assert (y_scales.stride(0) == 1 and y_scales.stride(1) in [1, N]), f"{y_scales.stride=}"
else:
rhs_recipe = ScalingType.BlockWise128x128
y_scales, pad_amount = _pad_128x128_scales(y_scales)
# Scale in [N // 128, L4] -> [L4, N // 128]
y_scales = y_scales.t()

# Verify that actual F8 mm doesn't error
scaled_mm_wrap(
@ -1354,13 +1589,20 @@ class TestFP8Matmul(TestCase):
y_fp8.t(),
scale_a=x_scales,
scale_recipe_a=lhs_recipe,
scale_b=y_scales.t(),
# Note: No more .t() on scale_b, not necessary.
scale_b=y_scales,
scale_recipe_b=rhs_recipe,
out_dtype=output_dtype,
)

# Verify that emulated F8 mm doesn't error
mm_float8_emulated_block(x_fp8, x_scales, y_fp8.t(), y_scales.t(), output_dtype)
mm_float8_emulated_block(
x_fp8,
x_scales_original,
y_fp8.t(),
y_scales_original.t(),
output_dtype
)
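
Editor's note: the assertions above pin down the expected layout for 1x128 scales (shape [M, K // 128], stride(0) == 1, stride(1) == M). The following is a minimal, self-contained sketch of that layout, assuming plain bfloat16 inputs and a hypothetical helper name; it is not the test's tensor_to_scale_block implementation.

import torch

def blockwise_1x128_scales(t: torch.Tensor, block: int = 128) -> torch.Tensor:
    # t: [M, K] high-precision tensor; one scale per 1x128 block of a row.
    M, K = t.shape
    assert K % block == 0
    amax = t.abs().reshape(M, K // block, block).amax(dim=-1).float()
    scales = amax.clamp(min=1e-12) / torch.finfo(torch.float8_e4m3fn).max
    # Outer-dim-major layout: stride(0) == 1, stride(1) == M, as asserted above.
    return scales.t().contiguous().t()

if __name__ == "__main__":
    x = torch.randn(256, 512, dtype=torch.bfloat16)
    s = blockwise_1x128_scales(x)
    assert s.shape == (256, 512 // 128) and s.stride(0) == 1 and s.stride(1) == 256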

@skipIfRocm
@onlyCUDA
@ -1620,7 +1862,7 @@ class TestFP8Matmul(TestCase):
(127, 96, 1024),
(1025, 128, 96)
], name_fn=lambda mkn: f"{mkn[0]}_{mkn[1]}_{mkn[2]}")
@parametrize("recipe", ["mxfp8", "mxfp4" if torch.version.hip else "nvfp4"])
@parametrize("recipe", ["mxfp8", "mxfp4", "nvfp4"])
def test_blockwise_mxfp8_nvfp4_mxfp4_numerics(self, test_case_name, fast_accum, mkn, recipe) -> None:
if (recipe == "nvfp4" or recipe == "mxfp4") and fast_accum:
raise unittest.SkipTest("fast_accum not supported in nvfp4/mxfp4 cublas gemm, skipping")
@ -1634,8 +1876,12 @@ class TestFP8Matmul(TestCase):
if not (M % 16 == 0 and K % 128 == 0 and N % 16 == 0):
raise unittest.SkipTest("M and N must be multiples of 16 and K must be multiple of 128 on ROCm, skipping")

fp4_scaling_dtype = torch.float8_e8m0fnu if torch.version.hip else torch.float8_e4m3fn
BLOCK_SIZE = 32 if torch.version.hip else (16 if recipe == "nvfp4" else 32)
fp4_scaling_dtype = torch.float8_e8m0fnu if recipe == "mxfp4" else torch.float8_e4m3fn
BLOCK_SIZE = 16 if recipe == "nvfp4" else 32

if K % BLOCK_SIZE != 0:
raise unittest.SkipTest(f"K ({K}) must be divisible by BLOCK_SIZE ({BLOCK_SIZE}), skipping")

require_exact_match = True
approx_match_sqnr_target = 22.0

@ -1813,7 +2059,7 @@ class TestFP8Matmul(TestCase):
B = B.clamp(min=min_val, max=max_val)
B = _bfloat16_to_float4_e2m1fn_x2(B)

approx_match_sqnr_target = 15 if torch.version.hip else 15.8
approx_match_sqnr_target = 15 if recipe == "mxfp4" else 15.8

C_ref = A_ref @ B_ref.t()


@ -739,6 +739,12 @@ enable_aot_compile = False
# HACK: this is for testing custom ops profiling only
_custom_ops_profile: Optional[Any] = None

# Experimental: If True, graph module will register fx metadata during recompile()
enrich_profiler_metadata: bool = Config( # type: ignore[var-annotated]
default=False,
env_name_default="TORCH_ENRICH_RPOFILER_STACK_TRACE",
)

if TYPE_CHECKING:
from torch.utils._config_typing import * # noqa: F401, F403
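
Editor's note: a minimal sketch of opting into the experimental flag added above. It assumes the rest of this patch (FX metadata registration in GraphModule.recompile()) is applied; the traced function here is purely illustrative.

import torch

torch._dynamo.config.enrich_profiler_metadata = True

@torch.compile
def f(x):
    return torch.relu(x) + 1

f(torch.randn(8))  # recompile() now registers FX metadata for this graph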


@ -179,6 +179,9 @@ def aot_stage1_graph_capture(
)
)

print(f"in aot_stage1_graph_capture. maybe_subclass_meta.fw_metadata.static_input_indices:{maybe_subclass_meta.fw_metadata.static_input_indices if maybe_subclass_meta is not None and maybe_subclass_meta.fw_metadata is not None else None}")
print(f"in aot_stage1_graph_capture. aot_state.fw_metadata.static_input_indices:{aot_state.fw_metadata.static_input_indices}")

return AOTGraphCapture(
wrappers=wrappers,
graph_module=graph,

@ -423,6 +423,10 @@ def estimate_nccl_collective_runtime_from_fx_node(
from torch.distributed.distributed_c10d import _resolve_process_group

pg = _resolve_process_group(group_name)
if torch.distributed.distributed_c10d.get_backend(pg) == "fake":
# nccl estimator requires real process group
return None

fn = fx_node.target
assert isinstance(fn, torch._ops.OpOverload)
with torch.distributed._time_estimator(group=pg) as time_estimator:

@ -2318,7 +2318,7 @@ def compile_fx_forward(
# force the outputs of invoke_subgraph subgraph to follow the
# original strides
_recursive_record_user_visible_output_idxs(gm)

print(f"in compile_fx_forward. static_input_idxs:{get_static_input_idxs(fixed)}")
return inner_compile(
gm,
example_inputs,

@ -24,6 +24,7 @@ from typing_extensions import Never, ParamSpec
import torch._thread_safe_fork # noqa: F401
from torch._inductor import config
from torch._inductor.codecache import torch_key
from torch._inductor.compile_worker.timer import Timer
from torch._inductor.compile_worker.tracked_process_pool import (
TrackedProcessPoolExecutor,
)
@ -132,6 +133,7 @@ class SubprocPool:
nprocs: int,
pickler: Optional[SubprocPickler] = None,
kind: SubprocKind = SubprocKind.FORK,
quiesce: bool = False,
) -> None:
entry = os.path.join(os.path.dirname(__file__), "__main__.py")
self.pickler = pickler or SubprocPickler()
@ -216,6 +218,13 @@ class SubprocPool:
"pytorch.wait_counter.subproc_pool.first_job"
).guard()

if quiesce:
self.timer: Optional[Timer] = Timer(
config.quiesce_async_compile_time, self.quiesce
)
else:
self.timer = None

# Start thread last to ensure all member variables are initialized
# before any access.
self.read_thread.start()
@ -288,6 +297,8 @@ class SubprocPool:
with self.futures_lock:
if not self.running:
return
if self.timer:
self.timer.record_call()
if isinstance(result, _SubprocExceptionInfo):
# An exception occurred in the submitted job
self.pending_futures[job_id].set_exception(
@ -322,6 +333,8 @@ class SubprocPool:
with self.write_lock:
if not self.running:
return
if self.timer:
self.timer.quit()
self.running = False
self.running_waitcounter.__exit__()
_send_msg(self.write_pipe, MsgHeader.SHUTDOWN)

@ -17,7 +17,7 @@ class Timer:
self.background_thread: Optional[Thread] = None
self.last_called: Optional[float] = None
self.duration = duration
self.sleep_time = 60
self.sleep_time = duration / 2
self.call = call
self.exit = False
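
Editor's note: the change above makes the background thread poll at half the quiesce interval instead of a fixed 60 seconds. Below is a rough, hypothetical stand-in for the idle-timer pattern implied here; it is not the actual torch._inductor.compile_worker.timer.Timer.

import time
from threading import Thread
from typing import Callable, Optional

class IdleTimer:
    def __init__(self, duration: float, call: Callable[[], None]) -> None:
        self.duration = duration
        self.sleep_time = duration / 2
        self.call = call
        self.last_called: Optional[float] = None
        self.exit = False
        self.background_thread: Optional[Thread] = None

    def record_call(self) -> None:
        # Work arrived; reset the idle clock and lazily start the poller.
        self.last_called = time.time()
        if self.background_thread is None:
            self.background_thread = Thread(target=self._loop, daemon=True)
            self.background_thread.start()

    def quit(self) -> None:
        self.exit = True

    def _loop(self) -> None:
        while not self.exit:
            time.sleep(self.sleep_time)
            if self.last_called is not None and time.time() - self.last_called > self.duration:
                self.call()  # e.g. SubprocPool.quiesce
                self.last_called = None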


@ -964,6 +964,11 @@ quiesce_async_compile_pool: bool = Config(
default=False,
)

# Time in seconds to wait before quiescing
quiesce_async_compile_time: int = Config(
default=60,
)

# Whether or not to enable statically launching CUDA kernels
# compiled by triton (instead of using triton's own launcher)
use_static_cuda_launcher: bool = static_cuda_launcher_default()

@ -51,8 +51,8 @@ from ..utils import (
decode_device,
get_all_devices,
get_gpu_type,
has_uses_tagged_as,
is_gpu,
is_pointwise_use,
OPTIMUS_EXCLUDE_POST_GRAD,
)
from ..virtualized import V
@ -1510,8 +1510,10 @@ def should_prefer_unfused_addmm(match):
if not is_gpu(inp.meta["val"].device.type):
return False

output = match.output_node()
return all(is_pointwise_use(use) for use in output.users)
return has_uses_tagged_as(
match.output_node(),
(torch.Tag.pointwise, torch.Tag.reduction),
)


@register_graph_pattern(

@ -549,6 +549,70 @@ def is_pointwise_use(
return torch.Tag.pointwise in target.tags or is_pointwise_fn(target)


class LogicalConnective(enum.Enum):
OR = enum.auto()
AND = enum.auto()


def has_uses(
target: Node,
use_selector_fn: Callable[[torch._ops.OpOverload], bool] = lambda _: False,
use_aggregate_type: LogicalConnective = LogicalConnective.OR,
) -> bool:
"""
Given a target, explore the uses of `target` by applying `use_selector_fn`
on them, and then aggregate these booleans with the `use_aggregate_type`
logical connective.

Uses in view ops will follow the view's uses.
"""

def get_use_aggregate_fn(
use_aggregate_type: LogicalConnective,
) -> Callable[[Iterator[Any]], bool]:
match use_aggregate_type:
case LogicalConnective.AND:
return all
case LogicalConnective.OR:
return any
case _:
return any

use_aggregate_fn = get_use_aggregate_fn(use_aggregate_type)

def has_uses_impl(use: Node) -> bool:
if use.op != "call_function":
return False
if not (
isinstance(use.target, torch._ops.OpOverload)
or use.target is operator.getitem
):
return False

target = cast(torch._ops.OpOverload, use.target)
# Process getitem and view
if target is operator.getitem or is_view(target):
return use_aggregate_fn(has_uses_impl(user) for user in use.users)

return use_selector_fn(target)

return use_aggregate_fn(has_uses_impl(user) for user in target.users)


def has_uses_tagged_as(
target: Node,
use_tags: Collection[torch.Tag],
use_aggregate_type: LogicalConnective = LogicalConnective.OR,
) -> bool:
"""
Is there a use with given tags?
"""

return has_uses(
target, lambda use: any(tag in use_tags for tag in use.tags), use_aggregate_type
)
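
Editor's note: a small usage sketch for the helper added above, assuming this patch is applied (has_uses_tagged_as lives in torch._inductor.utils here). make_fx is used only to obtain an ATen-level graph with OpOverload targets.

import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._inductor.utils import has_uses_tagged_as

def f(a, b):
    return torch.mm(a, b) + 1

gm = make_fx(f)(torch.randn(4, 4), torch.randn(4, 4))
mm_node = next(n for n in gm.graph.nodes if n.target == torch.ops.aten.mm.default)

# True: the only use of the mm output is add.Tensor, which carries the pointwise tag.
print(has_uses_tagged_as(mm_node, (torch.Tag.pointwise, torch.Tag.reduction)))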


def gen_gm_and_inputs(
target: Any, args: list[Any], kwargs: dict[str, Any]
) -> tuple[GraphModule, list[torch.Tensor]]:

@ -31,10 +31,8 @@ template <typename T>
struct FromImpl {
static StableIValue call(
T val,
uint64_t extension_build_version,
bool is_internal) {
(void)extension_build_version; // Unused parameter
(void)is_internal; // Unused parameter
[[maybe_unused]] uint64_t extension_build_version,
[[maybe_unused]] bool is_internal) {
static_assert(
sizeof(T) <= sizeof(StableIValue),
"StableLibrary stack does not support parameter types larger than 64 bits.");
@ -75,10 +73,8 @@ template <>
struct FromImpl<ScalarType> {
static StableIValue call(
ScalarType val,
uint64_t extension_build_version,
bool is_internal) {
(void)extension_build_version; // Unused parameter
(void)is_internal; // Unused parameter
[[maybe_unused]] uint64_t extension_build_version,
[[maybe_unused]] bool is_internal) {
switch (val) {
case ScalarType::Byte:
return from(aoti_torch_dtype_uint8());
@ -133,10 +129,8 @@ template <>
struct FromImpl<std::nullopt_t> {
static StableIValue call(
std::nullopt_t val,
uint64_t extension_build_version,
bool is_internal) {
(void)extension_build_version; // Unused parameter
(void)is_internal; // Unused parameter
[[maybe_unused]] uint64_t extension_build_version,
[[maybe_unused]] bool is_internal) {
return from(nullptr);
}
};
@ -190,10 +184,8 @@ template <>
struct FromImpl<torch::stable::Tensor> {
static StableIValue call(
const torch::stable::Tensor& val,
uint64_t extension_build_version,
bool is_internal) {
(void)extension_build_version; // Unused parameter
(void)is_internal; // Unused parameter
[[maybe_unused]] uint64_t extension_build_version,
[[maybe_unused]] bool is_internal) {
AtenTensorHandle new_ath;
TORCH_ERROR_CODE_CHECK(aoti_torch_new_tensor_handle(val.get(), &new_ath));
return from(new_ath);
@ -209,10 +201,8 @@ template <typename T>
struct ToImpl {
static T call(
StableIValue val,
uint64_t extension_build_version,
bool is_internal) {
(void)extension_build_version; // Unused parameter
(void)is_internal; // Unused parameter
[[maybe_unused]] uint64_t extension_build_version,
[[maybe_unused]] bool is_internal) {
static_assert(std::is_trivially_copyable_v<T>);
// T may not have a default constructor. (For example, it might be
// c10::Device.) However, std::memcpy implicitly creates a T at the
@ -249,10 +239,8 @@ template <>
struct ToImpl<ScalarType> {
static ScalarType call(
StableIValue val,
uint64_t extension_build_version,
bool is_internal) {
(void)extension_build_version; // Unused parameter
(void)is_internal; // Unused parameter
[[maybe_unused]] uint64_t extension_build_version,
[[maybe_unused]] bool is_internal) {
int32_t shim_scalartype = to<int32_t>(val);
if (shim_scalartype == aoti_torch_dtype_uint8()) {
return ScalarType::Byte;
@ -309,10 +297,8 @@ template <>
struct ToImpl<std::nullopt_t> {
static std::nullopt_t call(
StableIValue val,
uint64_t extension_build_version,
bool is_internal) {
(void)extension_build_version; // Unused parameter
(void)is_internal; // Unused parameter
[[maybe_unused]] uint64_t extension_build_version,
[[maybe_unused]] bool is_internal) {
// val should be equivalent to from(nullptr)
return std::nullopt;
}
@ -350,10 +336,8 @@ template <>
struct ToImpl<torch::stable::Tensor> {
static torch::stable::Tensor call(
StableIValue val,
uint64_t extension_build_version,
bool is_internal) {
(void)extension_build_version; // Unused parameter
(void)is_internal; // Unused parameter
[[maybe_unused]] uint64_t extension_build_version,
[[maybe_unused]] bool is_internal) {
return torch::stable::Tensor(to<AtenTensorHandle>(val));
}
};

@ -4,12 +4,14 @@ r"""This package adds support for device memory management implemented in CUDA."
import collections
import contextlib
import ctypes
import os
import pickle
import re
import sys
import warnings
from inspect import signature
from typing import Any, Literal, Optional, TYPE_CHECKING
from typing_extensions import deprecated
from typing import Any, cast, Literal, Optional, TYPE_CHECKING, TypedDict
from typing_extensions import deprecated, NotRequired

import torch
from torch import _C
@ -29,6 +31,60 @@ if TYPE_CHECKING:
from torch.types import Device


# Type definitions for memory profiler
class _Frame(TypedDict):
"""Frame information from memory profiler snapshots."""

filename: str
line: int
name: str
# Fields added by FX augmentation (optional)
fx_node_op: NotRequired[str]
fx_node_name: NotRequired[str]
fx_node_target: NotRequired[str]
fx_original_trace: NotRequired[str]


class _Block(TypedDict):
"""Memory block information."""

size: int
requested_size: int
address: int
state: str
frames: list[_Frame]


class _Segment(TypedDict):
"""Memory segment information."""

address: int
total_size: int
stream: int
segment_type: str
allocated_size: int
active_size: int
blocks: list[_Block]


class _TraceEntry(TypedDict):
"""Memory trace entry information."""

action: str
addr: NotRequired[int]
frames: list[_Frame]
size: int
stream: int
device_free: NotRequired[int]


class _Snapshot(TypedDict):
"""Memory snapshot structure."""

segments: list[_Segment]
device_traces: NotRequired[list[list[_TraceEntry]]]
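
Editor's note: a small sketch of consuming these typed structures, assuming a _Snapshot-shaped dict such as the one returned by torch.cuda.memory._snapshot(); the helper name is illustrative.

def bytes_by_block_state(snapshot: "_Snapshot") -> dict[str, int]:
    # Sum block sizes per allocation state ("active_allocated", "inactive", ...).
    totals: dict[str, int] = {}
    for segment in snapshot["segments"]:
        for block in segment["blocks"]:
            totals[block["state"]] = totals.get(block["state"], 0) + block["size"]
    return totals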


__all__ = [
"caching_allocator_alloc",
"caching_allocator_delete",
@ -964,7 +1020,120 @@ def _record_memory_history_impl(
_record_memory_history.__signature__ = signature(_record_memory_history_impl) # type: ignore[attr-defined]


def _snapshot(device: "Device" = None):
def _augment_frames(frames: list[_Frame]) -> int:
"""
Augment a list of frames with FX debug information.

Args:
frames: List of frame dictionaries to augment

Returns:
The count of frames that were augmented.
"""
from torch.fx.graph_module import FX_GRAPH_MODULE_FILE_PREFIX

# Regex pattern to match FX generated files
_FX_GENERATED_PATTERN = re.compile(
rf"{re.escape(FX_GRAPH_MODULE_FILE_PREFIX)}.*\.py$"
)

count = 0
if not frames:
return count

for frame in frames:
if "filename" in frame and "line" in frame:
filename = frame["filename"]
lineno = frame["line"]

# Check if this looks like an FX generated file
if not _FX_GENERATED_PATTERN.search(os.path.basename(filename)):
continue

# Look up metadata from the global registry
from torch.fx.traceback import _FX_METADATA_REGISTRY

metadata = _FX_METADATA_REGISTRY.get(filename)
if metadata is None:
continue

lineno_map = metadata.get("lineno_map", {})
node_metadata = metadata.get("node_metadata", {})
prologue_start = metadata.get("prologue_start", 0)

# Get the node index for this line
node_idx = lineno_map.get(lineno - prologue_start)

if node_idx is not None and node_idx in node_metadata:
node_info = node_metadata[node_idx]
original_trace = node_info.get("stack_trace")
node_op = node_info.get("op")
node_name = node_info.get("name")
node_target = node_info.get("target")

# Always add node metadata
frame["fx_node_op"] = node_op
frame["fx_node_name"] = node_name
frame["fx_node_target"] = str(node_target)

# Add original trace if available
if original_trace:
frame["fx_original_trace"] = original_trace

count += 1

return count


def _augment_memory_snapshot_stack_traces(
snapshot: str | _Snapshot,
) -> _Snapshot:
"""
Augment a memory snapshot with original source stack traces from FX metadata.

IMPORTANT: This function reads from a global in-memory registry (_FX_METADATA_REGISTRY)
that is populated during graph module compilation. It must be called in the same
Python process where the FX graphs were compiled. It cannot be used to augment
snapshots loaded from disk in a different process.

Args:
snapshot: Either a memory snapshot dict or path to a snapshot pickle file

Returns:
The augmented snapshot dictionary with fx_node_op, fx_node_name,
fx_node_target, and fx_original_trace fields added to frames
"""

snapshot_dict: _Snapshot
if isinstance(snapshot, str):
# Load the memory snapshot
with open(snapshot, "rb") as f:
snapshot_dict = cast(_Snapshot, pickle.load(f))
else:
snapshot_dict = snapshot

# Process stack traces in the snapshot
augmented_count = 0

# Process blocks in segments (for regular allocations)
if "segments" in snapshot_dict:
for segment in snapshot_dict["segments"]:
if "blocks" in segment:
for block in segment["blocks"]:
if "frames" in block:
augmented_count += _augment_frames(block["frames"])

# Process device traces (for memory history)
if "device_traces" in snapshot_dict:
for trace_list in snapshot_dict["device_traces"]:
for trace_entry in trace_list:
if isinstance(trace_entry, dict) and "frames" in trace_entry:
augmented_count += _augment_frames(trace_entry["frames"])

return snapshot_dict


def _snapshot(device: "Device" = None, augment_with_fx_traces=False):
"""Save a snapshot of CUDA memory state at the time it was called.

The state is represented as a dictionary with the following structure.
@ -1012,6 +1181,11 @@ def _snapshot(device: "Device" = None):
filename: str
line: int
name: str
# Optional FX debug fields (present when augment_with_fx_traces=True
# and the frame corresponds to FX-generated code)
fx_node_op: str # FX node operation type (e.g., 'call_function', 'output')
fx_node_name: str # FX node name (e.g., 'linear', 'relu_1')
fx_original_trace: str # Original model source code stack trace


class TraceEntry(TypedDict):
@ -1041,13 +1215,23 @@ def _snapshot(device: "Device" = None):
device_free: int # only present for OOM, the amount of
# memory cuda still reports to be free

Args:
device: Device to capture snapshot for. If None, captures for current device.
augment_with_fx_traces: If True, augment stack trace frames with FX debug information
that maps generated FX code back to original model source code.
This adds fx_node_op, fx_node_name, fx_node_target, and
fx_original_trace fields to Frame objects. Default: False.

Returns:
The Snapshot dictionary object
"""
return _C._cuda_memorySnapshot(None)
s = _C._cuda_memorySnapshot(None)
if augment_with_fx_traces:
s = _augment_memory_snapshot_stack_traces(s) # type: ignore[assignment, arg-type]
return s


def _dump_snapshot(filename="dump_snapshot.pickle"):
def _dump_snapshot(filename="dump_snapshot.pickle", augment_with_fx_traces=False):
"""
Save a pickled version of the `torch.memory._snapshot()` dictionary to a file.

@ -1059,8 +1243,14 @@ def _dump_snapshot(filename="dump_snapshot.pickle"):

Args:
filename (str, optional): Name of the file to create. Defaults to "dump_snapshot.pickle".
augment_with_fx_traces (bool, optional): If True, augment the snapshot with FX debug information
before dumping. This maps generated FX code stack traces
back to original model source code. Defaults to False.
"""
s = _snapshot()
s = _snapshot(augment_with_fx_traces=augment_with_fx_traces)

with open(filename, "wb") as f:
pickle.dump(s, f)
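
Editor's note: an end-to-end sketch of the new flow, assuming this patch is applied. It requires CUDA and must run in the same process that compiled the graph, since the FX metadata registry is in-memory only.

import torch

torch._dynamo.config.enrich_profiler_metadata = True
torch.cuda.memory._record_memory_history()

model = torch.compile(torch.nn.Linear(64, 64).cuda())
model(torch.randn(8, 64, device="cuda")).sum().backward()

# Frames that point into fx_generated_*.py files gain fx_node_op,
# fx_node_name, fx_node_target, and fx_original_trace entries.
torch.cuda.memory._dump_snapshot("snapshot_fx.pickle", augment_with_fx_traces=True)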


@ -226,8 +226,10 @@ class PythonCode:
# Values in global scope during execution of `src_def`.
globals: dict[str, Any]
# Optional mapping from the forward function's line number to
# node index.
# node index. Line number starts at the prologue (i.e. forward()).
_lineno_map: Optional[dict[int, Optional[int]]]
# The line number of prologue in fn_code
_prologue_start: int = 0


def _format_target(base: str, target: str) -> str:
@ -854,7 +856,14 @@ class CodeGen:

{prologue}
{code}"""
return PythonCode(fn_code, globals_, _lineno_map=lineno_map)
# The +4 accounts for the empty lines before prologue in fn_code
prologue_start = wrap_stmts.count("\n") + 4
return PythonCode(
fn_code,
globals_,
_lineno_map=lineno_map,
_prologue_start=prologue_start,
)


# Ideally, we'd like to refactor all of the pytree logic into this codegen

@ -1,6 +1,8 @@
# mypy: allow-untyped-defs
import base64
import contextlib
import copy
import hashlib
import itertools
import linecache
import os
@ -36,6 +38,7 @@ __all__ = [
]

_USER_PRESERVED_ATTRIBUTES_KEY = "_user_preserved_attributes"
FX_GRAPH_MODULE_FILE_PREFIX = "fx_generated_"


# Normal exec loses the source code, however we can work with
@ -61,7 +64,13 @@ class _EvalCacheLoader:

key = self._get_key()
if co_fields:
key += f" from {co_fields['co_filename']}:{co_fields['co_firstlineno']} in {co_fields['co_name']}"
if "co_filename" in co_fields:
# If only co_filename is provided, use it directly as the key
if "co_firstlineno" not in co_fields or "co_name" not in co_fields:
key = co_fields["co_filename"]
else:
# Full co_fields with all three components
key += f" from {co_fields['co_filename']}:{co_fields['co_firstlineno']} in {co_fields['co_name']}"
self.eval_cache[key] = src

# Don't mutate globals so that this loader is only used
@ -353,6 +362,36 @@ def _print_readable(
return output


def _metadata_hash(code: str, node_metadata: dict) -> str:
"""
Create a content-addressed hash from code and metadata.

Args:
code: The source code string
node_metadata: Metadata for each node

Returns:
A 51-character base32-encoded hash
"""
import json

# Create a deterministic string representation of all components
# We use JSON to ensure consistent serialization
hash_data = {
"code": code,
"node_metadata": node_metadata,
}
hashing_str = json.dumps(hash_data).encode("utf-8")

# [:51] to strip off the "Q====" suffix common to every hash value.
return (
base64.b32encode(hashlib.sha256(hashing_str).digest())[:51]
.decode("utf-8")
.lower()
)


class _WrappedCall:
def __init__(self, cls, cls_call):
self.cls = cls
@ -825,9 +864,47 @@ class {module_name}(torch.nn.Module):
python_code = self._graph.python_code(root_module="self")
self._code = python_code.src
self._lineno_map = python_code._lineno_map
self._prologue_start = python_code._prologue_start

cls = type(self)
co_fields = self._graph._co_fields if hasattr(self._graph, "_co_fields") else {}
from torch._dynamo import config as dynamo_config

if dynamo_config.enrich_profiler_metadata:
# Generate metadata and register for profiler augmentation
node_metadata: dict[int, dict[str, Any]] = {}
for i, node in enumerate(self._graph.nodes):
node_metadata[i] = {
"name": node.name,
"op": node.op,
"target": str(node.target),
"stack_trace": node.meta.get("stack_trace", None),
}

# Generate a content-addressed filename based on hash of code and metadata
# This ensures the same code+metadata always generates the same filename
hash_value = _metadata_hash(self._code, node_metadata)
file_stem = f"{FX_GRAPH_MODULE_FILE_PREFIX}_{hash_value}"

filename = f"{file_stem}.py"

# Only include co_filename to use it directly as the cache key
co_fields = {
"co_filename": filename,
}

# Store metadata in global in-memory registry
metadata = {
"lineno_map": python_code._lineno_map,
"prologue_start": python_code._prologue_start,
"node_metadata": node_metadata,
}

# Register metadata in the global registry
from torch.fx.traceback import _register_fx_metadata

_register_fx_metadata(filename, metadata)

cls.forward = _forward_from_src(self._code, python_code.globals, co_fields)

# Determine whether this class explicitly defines a __call__ implementation

@ -38,6 +38,28 @@ current_meta: dict[str, Any] = {}
current_replay_node: Optional[Node] = None
should_preserve_node_meta = False

# =============================================================================
# FX Metadata Registry for Memory Profiler
# =============================================================================
# Global in-memory registry for FX metadata
# Maps module_name -> metadata dict containing lineno_map and node_metadata
_FX_METADATA_REGISTRY: dict[str, dict[str, Any]] = {}


def _register_fx_metadata(module_name: str, metadata: dict[str, Any]) -> None:
"""
Register FX metadata in the global in-memory registry.

This is called automatically during graph module compilation to store metadata
for later use by memory profiler augmentation.

Args:
module_name: The module identifier (content-addressed filename)
metadata: Metadata dict containing lineno_map, prologue_start, and node_metadata
"""
# TODO: add logging to tlparse
_FX_METADATA_REGISTRY[module_name] = metadata
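
Editor's note: an illustrative sketch of inspecting the registry directly after compiling with enrich_profiler_metadata enabled (same process only); it only uses the keys written by the registration code above.

from torch.fx.traceback import _FX_METADATA_REGISTRY

for filename, meta in _FX_METADATA_REGISTRY.items():
    nodes = meta["node_metadata"]
    print(filename, "prologue_start =", meta["prologue_start"], "nodes =", len(nodes))
    for idx, info in nodes.items():
        if info["op"] == "call_function":
            print(f"  node {idx}: {info['name']} -> {info['target']}")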


@compatibility(is_backward_compatible=False)
class NodeSourceAction(Enum):

@ -806,7 +806,29 @@ function format_frames(frames) {
}
const frame_strings = frames
.filter(frameFilter)
.map(f => `${f.filename}:${f.line}:${f.name}`);
.map(f => {
let frame_str = `${f.filename}:${f.line}:${f.name}`;

// Add FX debug information if available
if (f.fx_node_op || f.fx_node_name || f.fx_node_target) {
const fx_parts = [];
if (f.fx_node_name) fx_parts.push(`node=${f.fx_node_name}`);
if (f.fx_node_op) fx_parts.push(`op=${f.fx_node_op}`);
if (f.fx_node_target) fx_parts.push(`target=${f.fx_node_target}`);
frame_str += `\n >> FX: ${fx_parts.join(', ')}`;
}

if (f.fx_original_trace) {
frame_str += `\n >> Original Model Code:`;
const original_lines = f.fx_original_trace.trim().split('\n');
// Show all lines of the original trace
for (const line of original_lines) {
frame_str += `\n ${line}`;
}
}

return frame_str;
});
return elideRepeats(frame_strings).join('\n');
}