Mirror of https://github.com/pytorch/pytorch.git, synced 2025-11-17 16:46:31 +08:00

Compare commits: 1 commit (whc/shardi... vs quote-pyte...)

Commit: a2200be9c7
@ -129,7 +129,7 @@ function install_129 {
}

function install_128 {
CUDNN_VERSION=9.10.2.21
CUDNN_VERSION=9.8.0.87
echo "Installing CUDA 12.8.1 and cuDNN ${CUDNN_VERSION} and NVSHMEM and NCCL and cuSparseLt-0.7.1"
# install CUDA 12.8.1 in the same container
install_cuda 12.8.1 cuda_12.8.1_570.124.06_linux

@ -272,18 +272,6 @@ def smoke_test_cuda(
torch_cudnn_version = cudnn_to_version_str(torch.backends.cudnn.version())
print(f"Torch cuDNN version: {torch_cudnn_version}")

torch_cudnn_compile_version = torch._C._cudnn.getCompileVersion()
print(f"Torch cuDNN compile-time version: {torch_cudnn_compile_version}")
torch_cudnn_runtime_version = tuple(
[int(x) for x in torch_cudnn_version.split(".")]
)
if torch_cudnn_runtime_version != torch_cudnn_compile_version:
raise RuntimeError(
"cuDNN runtime version doesn't match compile version. "
f"Loaded: {torch_cudnn_runtime_version} "
f"Expected: {torch_cudnn_compile_version}"
)

if sys.platform in ["linux", "linux2"]:
torch_nccl_version = ".".join(str(v) for v in torch.cuda.nccl.version())
print(f"Torch nccl; version: {torch_nccl_version}")

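For reference, the runtime-vs-compile-time cuDNN check removed above can be reproduced standalone; a minimal sketch (assumes a CUDA-enabled PyTorch build):

import torch

# Version of the cuDNN library actually loaded at runtime (integer encoding)
runtime = torch.backends.cudnn.version()
# Version of the cuDNN headers PyTorch was compiled against, reported as a tuple
compile_time = torch._C._cudnn.getCompileVersion()
print(f"runtime={runtime}, compile_time={compile_time}")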
@ -337,7 +337,7 @@ test_python() {

test_python_smoke() {
# Smoke tests for H100/B200
time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 inductor/test_max_autotune $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 inductor/test_max_autotune inductor/test_cutedsl_grouped_mm $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
assert_git_not_dirty
}

12 .github/actions/pytest-cache-download/action.yml (vendored)
@ -38,9 +38,9 @@ runs:
run: |
python3 .github/scripts/pytest_cache.py \
--download \
--cache_dir $GITHUB_WORKSPACE/$CACHE_DIR \
--pr_identifier $GITHUB_REF \
--job_identifier $JOB_IDENTIFIER \
--temp_dir $RUNNER_TEMP \
--repo $REPO \
--bucket $BUCKET \
--cache_dir "$GITHUB_WORKSPACE/$CACHE_DIR" \
--pr_identifier "$GITHUB_REF" \
--job_identifier "$JOB_IDENTIFIER" \
--temp_dir "$RUNNER_TEMP" \
--repo "$REPO" \
--bucket "$BUCKET" \

16 .github/actions/pytest-cache-upload/action.yml (vendored)
@ -47,11 +47,11 @@ runs:
run: |
python3 .github/scripts/pytest_cache.py \
--upload \
--cache_dir $GITHUB_WORKSPACE/$CACHE_DIR \
--pr_identifier $GITHUB_REF \
--job_identifier $JOB_IDENTIFIER \
--sha $SHA \
--test_config $TEST_CONFIG \
--shard $SHARD \
--repo $REPO \
--temp_dir $RUNNER_TEMP \
--cache_dir "$GITHUB_WORKSPACE/$CACHE_DIR" \
--pr_identifier "$GITHUB_REF" \
--job_identifier "$JOB_IDENTIFIER" \
--sha "$SHA" \
--test_config "$TEST_CONFIG" \
--shard "$SHARD" \
--repo "$REPO" \
--temp_dir "$RUNNER_TEMP" \

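The quoting change in these two actions matters when any expanded value contains whitespace; a small illustration of the word-splitting difference, using shlex to mimic POSIX shell splitting (the path below is hypothetical):

import shlex

workspace = "/home/runner/work space"  # hypothetical directory containing a space

print(shlex.split(f"--cache_dir {workspace}/.pytest_cache"))
# ['--cache_dir', '/home/runner/work', 'space/.pytest_cache']  -> argument broken in two

print(shlex.split(f'--cache_dir "{workspace}/.pytest_cache"'))
# ['--cache_dir', '/home/runner/work space/.pytest_cache']     -> one intact argument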
4 .github/workflows/_rocm-test.yml (vendored)
@ -97,8 +97,8 @@ jobs:
shell: bash
run: |
ngpu=$(rocminfo | grep -c -E 'Name:.*\sgfx')
if [[ $ngpu -lt 2 ]]; then #We are temporarily reducing this down to 2 from 4 so that we can run tests on nodes with less gpus.
echo "Error: only $ngpu GPU(s) detected, at least 2 GPUs are needed for distributed jobs"
if [[ $ngpu -lt 4 ]]; then
echo "Error: only $ngpu GPU(s) detected, at least 4 GPUs are needed for distributed jobs"
exit 1
fi

8 .github/workflows/inductor-unittest.yml (vendored)
@ -115,10 +115,10 @@ jobs:
|
||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "inductor_amx", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
|
||||
{ config: "inductor_amx", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
|
||||
{ config: "inductor_avx2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.avx2" },
|
||||
{ config: "inductor_avx2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.avx2" },
|
||||
{ config: "inductor_amx", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
|
||||
{ config: "inductor_amx", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
|
||||
{ config: "inductor_avx2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.10xlarge.avx2" },
|
||||
{ config: "inductor_avx2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.10xlarge.avx2" },
|
||||
]}
|
||||
secrets: inherit
|
||||
|
||||
|
||||
14 .github/workflows/inductor.yml (vendored)
@ -84,13 +84,13 @@ jobs:
|
||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
|
||||
{ config: "cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
|
||||
{ config: "dynamic_cpu_inductor_huggingface", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
|
||||
{ config: "dynamic_cpu_inductor_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
|
||||
{ config: "dynamic_cpu_inductor_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
|
||||
{ config: "dynamic_cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
|
||||
{ config: "dynamic_cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
|
||||
{ config: "cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
|
||||
{ config: "cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
|
||||
{ config: "dynamic_cpu_inductor_huggingface", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
|
||||
{ config: "dynamic_cpu_inductor_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
|
||||
{ config: "dynamic_cpu_inductor_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
|
||||
{ config: "dynamic_cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
|
||||
{ config: "dynamic_cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
|
||||
{ config: "inductor_torchbench_cpu_smoketest_perf", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.24xl.spr-metal" },
|
||||
]}
|
||||
build-additional-packages: "vision audio torchao"
|
||||
|
||||
1 .gitignore (vendored)
@ -127,6 +127,7 @@ torch/test/
torch/utils/benchmark/utils/valgrind_wrapper/callgrind.h
torch/utils/benchmark/utils/valgrind_wrapper/valgrind.h
torch/version.py
torch/_inductor/kernel/vendored_templates/*
minifier_launcher.py
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_fwd_d*
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_bwd_d*

20 SECURITY.md
@ -1,7 +1,7 @@
# Security Policy

- [**Reporting a Vulnerability**](#reporting-a-vulnerability)
- [**Using PyTorch Securely**](#using-pytorch-securely)
- [**Using Pytorch Securely**](#using-pytorch-securely)
- [Untrusted models](#untrusted-models)
- [TorchScript models](#torchscript-models)
- [Untrusted inputs](#untrusted-inputs)
@ -10,28 +10,28 @@
- [**CI/CD security principles**](#cicd-security-principles)
## Reporting Security Issues

Beware that none of the topics under [Using PyTorch Securely](#using-pytorch-securely) are considered vulnerabilities of PyTorch.
Beware that none of the topics under [Using Pytorch Securely](#using-pytorch-securely) are considered vulnerabilities of Pytorch.

However, if you believe you have found a security vulnerability in PyTorch, we encourage you to let us know right away. We will investigate all legitimate reports and do our best to quickly fix the problem.

Please report security issues using https://github.com/pytorch/pytorch/security/advisories/new

All reports submitted through the security advisories mechanism would **either be made public or dismissed by the team within 90 days of the submission**. If advisory has been closed on the grounds that it is not a security issue, please do not hesitate to create an [new issue](https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml) as it is still likely a valid issue within the framework.
All reports submitted thru the security advisories mechanism would **either be made public or dismissed by the team within 90 days of the submission**. If advisory has been closed on the grounds that it is not a security issue, please do not hesitate to create an [new issue](https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml) as it is still likely a valid issue within the framework.

Please refer to the following page for our responsible disclosure policy, reward guidelines, and those things that should not be reported:

https://www.facebook.com/whitehat


## Using PyTorch Securely
**PyTorch models are programs**, so treat its security seriously -- running untrusted models is equivalent to running untrusted code. In general we recommend that model weights and the python code for the model are distributed independently. That said, be careful about where you get the python code from and who wrote it (preferentially check for a provenance or checksums, do not run any pip installed package).
## Using Pytorch Securely
**Pytorch models are programs**, so treat its security seriously -- running untrusted models is equivalent to running untrusted code. In general we recommend that model weights and the python code for the model are distributed independently. That said, be careful about where you get the python code from and who wrote it (preferentially check for a provenance or checksums, do not run any pip installed package).

### Untrusted models
Be careful when running untrusted models. This classification includes models created by unknown developers or utilizing data obtained from unknown sources[^data-poisoning-sources].

**Prefer to execute untrusted models within a secure, isolated environment such as a sandbox** (e.g., containers, virtual machines). This helps protect your system from potentially malicious code. You can find further details and instructions in [this page](https://developers.google.com/code-sandboxing).

**Be mindful of risky model formats**. Give preference to share and load weights with the appropriate format for your use case. [Safetensors](https://huggingface.co/docs/safetensors/en/index) gives the most safety but is the most restricted in what it supports. [`torch.load`](https://pytorch.org/docs/stable/generated/torch.load.html#torch.load) has a significantly larger surface of attack but is more flexible in what it can serialize. See the documentation for more details.
**Be mindful of risky model formats**. Give preference to share and load weights with the appropriate format for your use case. [safetensors](https://huggingface.co/docs/safetensors/en/index) gives the most safety but is the most restricted in what it supports. [`torch.load`](https://pytorch.org/docs/stable/generated/torch.load.html#torch.load) has a significantly larger surface of attack but is more flexible in what it can serialize. See the documentation for more details.

Even for more secure serialization formats, unexpected inputs to the downstream system can cause diverse security threats (e.g. denial of service, out of bound reads/writes) and thus we recommend extensive validation of any untrusted inputs.

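A minimal sketch of the safer loading options referred to above (illustrative only; the checkpoint file names are hypothetical):

import torch
from safetensors.torch import load_file  # safetensors: plain tensors, no code execution

# Load weights from a safetensors checkpoint
state_dict = load_file("model.safetensors")

# torch.load with weights_only=True restricts unpickling to tensor data,
# which narrows the attack surface compared to a full pickle load
state_dict = torch.load("model.pt", map_location="cpu", weights_only=True)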
@ -43,7 +43,7 @@ Important Note: The trustworthiness of a model is not binary. You must always de

### TorchScript models

TorchScript models should be treated the same way as locally executable code from an unknown source. Only run TorchScript models if you trust the provider. Please note, that tools for introspecting TorchScript models (such as `torch.utils.model_dump`) may also execute partial or full code stored in those models, therefore they should be used only if you trust the provider of the binary you are about to load.
TorchScript models should treated the same way as locally executable code from an unknown source. Only run TorchScript models if you trust the provider. Please note, that tools for introspecting TorchScript models (such as `torch.utils.model_dump`) may also execute partial or full code stored in those models, therefore they should be used only if you trust the provider of the binary you are about to load.

### Untrusted inputs during training and prediction

@ -59,9 +59,9 @@ If applicable, prepare your model against bad inputs and prompt injections. Some

### Data privacy

**Take special security measures if you train your models with sensitive data**. Prioritize [sandboxing](https://developers.google.com/code-sandboxing) your models and:
- Do not feed sensitive data to an untrusted model (even if runs in a sandboxed environment)
- If you consider publishing a model that was partially trained with sensitive data, be aware that data can potentially be recovered from the trained weights (especially if the model overfits).
**Take special security measures if your model if you train models with sensitive data**. Prioritize [sandboxing](https://developers.google.com/code-sandboxing) your models and:
- Do not feed sensitive data to untrusted model (even if runs in a sandboxed environment)
- If you consider publishing a model that was partially trained with sensitive data, be aware that data can potentially be recovered from the trained weights (especially if model overfits).

### Using distributed features

@ -260,7 +260,7 @@ IF(USE_FBGEMM_GENAI)
|
||||
if(USE_CUDA)
|
||||
# To avoid increasing the build time/binary size unnecessarily, use an allow-list of kernels to build.
|
||||
# If you want to integrate a kernel from FBGEMM into torch, you have to add it here.
|
||||
set(FBGEMM_CUTLASS_KERNELS_REGEX ".*(mx8mx8bf16_grouped|f4f4bf16_grouped|f4f4bf16).*")
|
||||
set(FBGEMM_CUTLASS_KERNELS_REGEX ".*(mx8mx8bf16_grouped|f4f4bf16_grouped).*")
|
||||
file(GLOB_RECURSE fbgemm_genai_native_cuda_cu
|
||||
"${FBGEMM_GENAI_SRCS}/cutlass_extensions/*.cu"
|
||||
"${FBGEMM_GENAI_SRCS}/cutlass_extensions/**/*.cu")
|
||||
|
||||
@ -23,6 +23,8 @@ C10_DIAGNOSTIC_POP()
|
||||
#endif
|
||||
namespace at {
|
||||
|
||||
namespace {
|
||||
|
||||
/*
|
||||
These const variables defined the fp32 precisions for different backend
|
||||
We have "generic", "cuda", "mkldnn" backend now and we can choose fp32
|
||||
@ -39,6 +41,16 @@ namespace at {
|
||||
->rnn
|
||||
*/
|
||||
|
||||
C10_ALWAYS_INLINE void warn_deprecated_fp32_precision_api(){
TORCH_WARN_ONCE(
"Please use the new API settings to control TF32 behavior, such as torch.backends.cudnn.conv.fp32_precision = 'tf32' "
"or torch.backends.cuda.matmul.fp32_precision = 'ieee'. Old settings, e.g, torch.backends.cuda.matmul.allow_tf32 = True, "
"torch.backends.cudnn.allow_tf32 = True, allowTF32CuDNN() and allowTF32CuBLAS() will be deprecated after Pytorch 2.9. Please see "
"https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices"
);
}
} // namespace
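For reference, the new-style settings named in this warning can be exercised from Python as follows (a minimal sketch mirroring the attribute names quoted in the message):

import torch

# New per-backend, per-op precision controls referenced by the warning
torch.backends.cuda.matmul.fp32_precision = "tf32"  # allow TF32 for CUDA matmuls
torch.backends.cudnn.conv.fp32_precision = "tf32"   # allow TF32 for cuDNN convolutions

# Legacy flags that the warning says will be deprecated after PyTorch 2.9
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True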
|
||||
Float32Backend str2backend(const std::string& name) {
|
||||
if (name == "generic")
|
||||
return Float32Backend::GENERIC;
|
||||
@ -194,6 +206,7 @@ bool Context::allowTF32CuDNN(std::optional<Float32Op> op) const {
|
||||
} else {
|
||||
return float32Precision(Float32Backend::CUDA, op.value()) == Float32Precision::TF32;
|
||||
}
|
||||
warn_deprecated_fp32_precision_api();
|
||||
return allow_tf32_cudnn;
|
||||
}
|
||||
|
||||
@ -201,6 +214,7 @@ void Context::setAllowTF32CuDNN(bool b) {
|
||||
setFloat32Precision(Float32Backend::CUDA, Float32Op::RNN, b ? Float32Precision::TF32 : Float32Precision::NONE);
|
||||
setFloat32Precision(Float32Backend::CUDA, Float32Op::CONV, b ? Float32Precision::TF32 : Float32Precision::NONE);
|
||||
allow_tf32_cudnn = b;
|
||||
warn_deprecated_fp32_precision_api();
|
||||
}
|
||||
|
||||
void Context::setSDPPriorityOrder(const std::vector<int64_t>& order) {
|
||||
@ -311,6 +325,7 @@ bool Context::allowTF32CuBLAS() const {
|
||||
"Current status indicate that you have used mix of the legacy and new APIs to set the TF32 status for cublas matmul. ",
|
||||
"We suggest only using the new API to set the TF32 flag. See also: ",
|
||||
"https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices");
|
||||
warn_deprecated_fp32_precision_api();
|
||||
return allow_tf32_new;
|
||||
}
|
||||
|
||||
@ -334,6 +349,7 @@ Float32MatmulPrecision Context::float32MatmulPrecision() const {
|
||||
"Current status indicate that you have used mix of the legacy and new APIs to set the matmul precision. ",
|
||||
"We suggest only using the new API for matmul precision. See also: ",
|
||||
"https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices");
|
||||
warn_deprecated_fp32_precision_api();
|
||||
return float32_matmul_precision;
|
||||
}
|
||||
|
||||
@ -361,6 +377,7 @@ Float32Precision Context::float32Precision(Float32Backend backend, Float32Op op)
|
||||
|
||||
void Context::setFloat32MatmulPrecision(const std::string &s) {
|
||||
auto match = [this](const std::string & s_) {
|
||||
warn_deprecated_fp32_precision_api();
|
||||
// TODO: consider if CuDNN field needs to also be set for potential future CuDNN ops like multi-headed attention
|
||||
if (s_ == "highest") {
|
||||
float32_matmul_precision = at::Float32MatmulPrecision::HIGHEST;
|
||||
|
||||
@ -59,24 +59,6 @@
|
||||
// forward declare
|
||||
class cublasCommonArgs;
|
||||
|
||||
#ifndef _WIN32
|
||||
namespace fbgemm_gpu {
|
||||
|
||||
// NOTE(slayton58): FBGemm_GPU kernels come from <fbgemm_gpu/torch_ops.h> within the FBGemm repo.
|
||||
// To update supported ops means a submodule bump, which is.. painful. Instead, we
|
||||
// can simply forward-declare the methods we want to use.. Works at least as a short-term
|
||||
// thing, but should still be fixed somewhere/somehow.
|
||||
at::Tensor f4f4bf16(
|
||||
at::Tensor,
|
||||
at::Tensor,
|
||||
at::Tensor,
|
||||
at::Tensor,
|
||||
std::optional<at::Tensor>,
|
||||
bool use_mx);
|
||||
|
||||
} // namespace fbgemm_gpu
|
||||
#endif
|
||||
|
||||
using at::blas::ScalingType;
|
||||
using at::blas::SwizzleType;
|
||||
|
||||
@ -1105,47 +1087,26 @@ _scaled_mxfp4_mxfp4(
|
||||
const std::optional<Tensor>& bias,
|
||||
const c10::ScalarType out_dtype,
|
||||
Tensor& out) {
|
||||
#if defined(_WIN32) || (!defined(USE_ROCM) && !defined(USE_FBGEMM_GENAI))
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "MXFP4 scaling supported on ROCM and CUDA+FBGEMM_GENAI only");
|
||||
#else
|
||||
#ifndef USE_ROCM
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "MXFP4 scaling supported on ROCM only");
|
||||
#endif
|
||||
// Restrictions:
|
||||
// A, B are FP4, scales are e8m0, A: shape K//32, B: K, N//32
|
||||
TORCH_CHECK_VALUE(mat_a.scalar_type() == at::kFloat4_e2m1fn_x2 && mat_b.scalar_type() == at::kFloat4_e2m1fn_x2, "mat_a and mat_b must be fp4 types, got: ",
|
||||
mat_a.scalar_type(), mat_b.scalar_type());
|
||||
|
||||
// Packed FP4 format means actual-K = 2 * reported-K -- adjust
|
||||
auto K_multiplier = 2;
|
||||
#ifdef USE_ROCM
|
||||
// AMD
|
||||
auto scale_a_elems = ceil_div<int64_t>(K_multiplier * mat_a.size(0), 32) * mat_a.size(1);
|
||||
auto scale_b_elems = ceil_div<int64_t>(K_multiplier * mat_b.size(1), 32) * mat_b.size(0);
|
||||
#else
|
||||
// NVIDIA
|
||||
auto scale_a_elems = round_up<int64_t>(mat_a.size(0), 128) * round_up<int64_t>(ceil_div<int64_t>(K_multiplier * mat_a.size(1), 32), 4);
|
||||
auto scale_b_elems = round_up<int64_t>(mat_b.size(1), 128) * round_up<int64_t>(ceil_div<int64_t>(K_multiplier * mat_b.size(0), 32), 4);
|
||||
#endif
|
||||
auto scale_a_elems = ceil_div<int64_t>(2 * mat_a.size(0), 32) * mat_a.size(1);
|
||||
auto scale_b_elems = ceil_div<int64_t>(2 * mat_b.size(1), 32) * mat_b.size(0);
|
||||
TORCH_CHECK_VALUE(scale_a_elems == scale_a.numel(),
|
||||
"For Blockwise scaling scale_a should have ", scale_a_elems, " elements, got: ", scale_a.numel());
|
||||
TORCH_CHECK_VALUE(scale_b_elems == scale_b.numel(),
|
||||
"For Blockwise scaling scale_b should have ", scale_b_elems, " elements, got: ", scale_b.numel());
|
||||
|
||||
#ifdef USE_ROCM
|
||||
// AMD
|
||||
TORCH_CHECK_VALUE(swizzle_a == SwizzleType::NO_SWIZZLE, "scale_a must not be swizzled (NO_SWIZZLE format)");
|
||||
TORCH_CHECK_VALUE(swizzle_b == SwizzleType::NO_SWIZZLE, "scale_b must not be swizzled (NO_SWIZZLE format)");
|
||||
#else
|
||||
// NVIDIA
|
||||
TORCH_CHECK_VALUE(swizzle_a == SwizzleType::SWIZZLE_32_4_4, "scale_a must be swizzled to SWIZZLE_32_4_4 format");
|
||||
TORCH_CHECK_VALUE(swizzle_b == SwizzleType::SWIZZLE_32_4_4, "scale_b must be swizzled to SWIZZLE_32_4_4 format");
|
||||
#endif
|
||||
|
||||
TORCH_CHECK_VALUE(scale_a.is_contiguous() && scale_b.is_contiguous(),
|
||||
"For Blockwise scaling both scales should be contiguous");
|
||||
|
||||
TORCH_CHECK_VALUE(out.scalar_type() == out_dtype, "expected out.scalar_type() to be ", out_dtype, ", but got ", out_dtype);
|
||||
|
||||
#ifdef USE_ROCM
|
||||
// AMD
|
||||
auto scaling_choice_a = ScalingType::BlockWise1x32;
|
||||
auto scaling_choice_b = ScalingType::BlockWise1x32;
|
||||
|
||||
@ -1160,30 +1121,11 @@ _scaled_mxfp4_mxfp4(
|
||||
TORCH_CHECK_VALUE(out.scalar_type() == ScalarType::BFloat16 ||
|
||||
out.scalar_type() == ScalarType::Half,
|
||||
"Block-wise scaling only supports BFloat16 or Half output types");
|
||||
#else
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "Block-wise scaling for Float8_e8m0fnu requires ROCm 7.0 or later");
|
||||
#endif
|
||||
|
||||
return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out);
|
||||
#else
|
||||
// NVIDIA
|
||||
// NOTE(slayton58): fbgemm_gpu::f4f4bf16 does *not* allow passing an output tensor,
|
||||
// but we have one we need to use. Two clear options are to copy into
|
||||
// our output (slow), or use a move-assignment-operator (faster).
|
||||
// However, the compiler can complain about the explicit move preventing
|
||||
// copy elision because the return from f4f4bf16 is a temporary object.
|
||||
// So we don't explicitly move, and trust the compiler here...
|
||||
// In the longer term this should be fixed on the FBGemm side.
|
||||
out = fbgemm_gpu::f4f4bf16(
|
||||
mat_a,
|
||||
mat_b.transpose(-2, -1),
|
||||
scale_a,
|
||||
scale_b,
|
||||
std::nullopt, /* global_scale */
|
||||
true /* use_mx */
|
||||
);
|
||||
|
||||
return out;
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
|
||||
Tensor&
|
||||
@ -1308,20 +1250,17 @@ _scaled_mm_cuda_v2_out(
|
||||
mat_a.size(0), "x", mat_a.size(1), " and ", mat_b.size(0), "x", mat_b.size(1), ")");
|
||||
}
|
||||
|
||||
// Handle fp4 packed-K dimension
|
||||
int K_multiplier = (mat_a.scalar_type() == ScalarType::Float4_e2m1fn_x2) ? 2 : 1;
|
||||
|
||||
TORCH_CHECK_VALUE(!bias || bias->numel() == mat_b.sizes()[1], "Bias must be size ", mat_b.sizes()[1],
|
||||
" but got ", bias->numel());
|
||||
TORCH_CHECK_VALUE(
|
||||
K_multiplier * mat_a.sizes()[1] % 16 == 0,
|
||||
mat_a.sizes()[1] % 16 == 0,
|
||||
"Expected trailing dimension of mat1 to be divisible by 16 ",
|
||||
"but got mat1 shape: (",
|
||||
mat_a.sizes()[0],
|
||||
"x",
|
||||
K_multiplier * mat_a.sizes()[1],
|
||||
mat_a.sizes()[1],
|
||||
").");
|
||||
TORCH_CHECK_VALUE(K_multiplier * mat_b.sizes()[0] % 16 == 0 && mat_b.sizes()[1] % 16 == 0, "mat2 shape (", mat_b.sizes()[0], "x",
|
||||
TORCH_CHECK_VALUE(mat_b.sizes()[0] % 16 == 0 && mat_b.sizes()[1] % 16 == 0, "mat2 shape (", mat_b.sizes()[0], "x",
|
||||
mat_b.sizes()[1], ") must be divisible by 16");
|
||||
|
||||
// TODO(slayton): Existing checks, not sure if they should really be here.
|
||||
|
||||
34 setup.py
@ -630,6 +630,37 @@ def mirror_files_into_torchgen() -> None:
|
||||
raise RuntimeError("Check the file paths in `mirror_files_into_torchgen()`")
|
||||
|
||||
|
||||
def mirror_inductor_external_kernels() -> None:
|
||||
"""
|
||||
Copy external kernels into Inductor so they are importable.
|
||||
"""
|
||||
paths = [
|
||||
(
|
||||
CWD / "torch/_inductor/kernel/vendored_templates/cutedsl_grouped_gemm.py",
|
||||
CWD
|
||||
/ "third_party/cutlass/examples/python/CuTeDSL/blackwell/grouped_gemm.py",
|
||||
),
|
||||
]
|
||||
for new_path, orig_path in paths:
|
||||
# Create the dirs involved in new_path if they don't exist
|
||||
if not new_path.exists():
|
||||
new_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Copy the files from the orig location to the new location
|
||||
if orig_path.is_file():
|
||||
shutil.copyfile(orig_path, new_path)
|
||||
continue
|
||||
if orig_path.is_dir():
|
||||
if new_path.exists():
|
||||
# copytree fails if the tree exists already, so remove it.
|
||||
shutil.rmtree(new_path)
|
||||
shutil.copytree(orig_path, new_path)
|
||||
continue
|
||||
raise RuntimeError(
|
||||
"Check the file paths in `mirror_inductor_external_kernels()`"
|
||||
)
|
||||
|
||||
|
||||
# ATTENTION: THIS IS AI SLOP
|
||||
def extract_variant_from_version(version: str) -> str:
|
||||
"""Extract variant from version string, defaulting to 'cpu'."""
|
||||
@ -1616,6 +1647,8 @@ def main() -> None:
|
||||
if RUN_BUILD_DEPS:
|
||||
build_deps()
|
||||
|
||||
mirror_inductor_external_kernels()
|
||||
|
||||
(
|
||||
ext_modules,
|
||||
cmdclass,
|
||||
@ -1649,6 +1682,7 @@ def main() -> None:
|
||||
"_inductor/codegen/aoti_runtime/*.cpp",
|
||||
"_inductor/script.ld",
|
||||
"_inductor/kernel/flex/templates/*.jinja",
|
||||
"_inductor/kernel/templates/*.jinja",
|
||||
"_export/serde/*.yaml",
|
||||
"_export/serde/*.thrift",
|
||||
"share/cmake/ATen/*.cmake",
|
||||
|
||||
@ -32,7 +32,6 @@ from torch.distributed.tensor._ops._einsum_strategy import (
|
||||
)
|
||||
from torch.distributed.tensor._ops.utils import (
|
||||
register_op_strategy,
|
||||
register_single_dim_strategy,
|
||||
replicate_op_strategy,
|
||||
)
|
||||
from torch.distributed.tensor.debug import CommDebugMode
|
||||
@ -656,202 +655,5 @@ TestStrategyHashingWithLocalTensor = create_local_tensor_test_class(
|
||||
TestStrategyHashing,
|
||||
)
|
||||
|
||||
|
||||
class TestSingleDimStrategy(DTensorTestBase):
|
||||
@with_comms
|
||||
def test_register_single_dim_strategy_replaces_existing_rule(self):
|
||||
"""
|
||||
Test that calling register_single_dim_strategy works and replaces an existing registered rule.
|
||||
"""
|
||||
from torch.distributed.tensor._ops._matrix_ops import (
|
||||
_mm_like_strategy,
|
||||
gen_single_dim_einsum_strategies,
|
||||
)
|
||||
|
||||
mesh = self.build_device_mesh()
|
||||
|
||||
# Create test inputs
|
||||
lhs_tensor = torch.randn(6, 8)
|
||||
rhs_tensor = torch.randn(8, 12)
|
||||
lhs_tensor_meta = extract_tensor_meta(lhs_tensor)
|
||||
rhs_tensor_meta = extract_tensor_meta(rhs_tensor)
|
||||
|
||||
# Test a specific input sharding combination
|
||||
lhs_placement = (Shard(1),)
|
||||
rhs_placement = (Shard(0),)
|
||||
lhs_spec = DTensorSpec(mesh, lhs_placement, lhs_tensor_meta)
|
||||
rhs_spec = DTensorSpec(mesh, rhs_placement, rhs_tensor_meta)
|
||||
|
||||
# Create the OpSchema for mm operation
|
||||
op_schema = OpSchema(
|
||||
torch.ops.aten.mm.default,
|
||||
(
|
||||
OpStrategy([OpSpec(lhs_spec)]),
|
||||
OpStrategy([OpSpec(rhs_spec)]),
|
||||
),
|
||||
{},
|
||||
)
|
||||
|
||||
# Get the strategies from the old mm_like_strategy (what was used before)
|
||||
old_style_strategy = _mm_like_strategy("mk,kn->mn", mesh, op_schema)
|
||||
|
||||
# Get the strategies from the new register_single_dim_strategy approach
|
||||
# First, we need to get the single dim strategy function
|
||||
def mm_single_dim_strategy_func(op_schema: OpSchema):
|
||||
return gen_single_dim_einsum_strategies("mk,kn->mn", mesh)
|
||||
|
||||
# Now expand it to full strategy using the same logic as register_single_dim_strategy
|
||||
single_dim_strategies = mm_single_dim_strategy_func(op_schema)
|
||||
all_mesh_dim_strategies = [single_dim_strategies] * mesh.ndim
|
||||
strategy_combs = itertools.product(*all_mesh_dim_strategies)
|
||||
all_strategies = []
|
||||
for strategy_comb in strategy_combs:
|
||||
spec_list = [
|
||||
DTensorSpec(mesh, tuple(specs)) for specs in zip(*strategy_comb)
|
||||
]
|
||||
all_strategies.append(
|
||||
OpSpec(output_specs=spec_list[0], input_specs=spec_list[1:])
|
||||
)
|
||||
new_style_strategy = OpStrategy(all_strategies)
|
||||
|
||||
# Verify that both strategies produce the same set of shardings
|
||||
old_strategy_set = {str(strategy) for strategy in old_style_strategy.strategies}
|
||||
new_strategy_set = {str(strategy) for strategy in new_style_strategy.strategies}
|
||||
|
||||
self.assertEqual(
|
||||
old_strategy_set,
|
||||
new_strategy_set,
|
||||
"Old and new strategies should produce the same shardings",
|
||||
)
|
||||
|
||||
# Verify that the registration actually works by checking the propagator
|
||||
propagator = DTensor._op_dispatcher.sharding_propagator
|
||||
|
||||
# Save the original strategy if it exists
|
||||
original_strategy = None
|
||||
if torch.ops.aten.mm.default in propagator.op_strategy_funcs:
|
||||
original_strategy = propagator.op_strategy_funcs[torch.ops.aten.mm.default]
|
||||
|
||||
try:
|
||||
# Register a custom single-dim strategy
|
||||
@register_single_dim_strategy(torch.ops.aten.mm.default)
|
||||
def custom_mm_single_dim_strategy(op_schema: OpSchema):
|
||||
return gen_single_dim_einsum_strategies("mk,kn->mn", mesh)
|
||||
|
||||
# Verify the strategy was registered
|
||||
self.assertIn(
|
||||
torch.ops.aten.mm.default,
|
||||
propagator.op_strategy_funcs,
|
||||
"Strategy should be registered after calling register_single_dim_strategy",
|
||||
)
|
||||
|
||||
# Verify it replaced any existing rule
|
||||
registered_func = propagator.op_strategy_funcs[torch.ops.aten.mm.default]
|
||||
self.assertIsNotNone(
|
||||
registered_func, "Registered strategy function should not be None"
|
||||
)
|
||||
|
||||
# Test that the registered strategy produces valid output
|
||||
result_strategy = registered_func(op_schema)
|
||||
self.assertIsInstance(
|
||||
result_strategy, OpStrategy, "Result should be an OpStrategy"
|
||||
)
|
||||
self.assertGreater(
|
||||
len(result_strategy.strategies),
|
||||
0,
|
||||
"Strategy should contain at least one OpSpec",
|
||||
)
|
||||
|
||||
finally:
|
||||
# Restore original strategy if it existed
|
||||
if original_strategy is not None:
|
||||
propagator.op_strategy_funcs[torch.ops.aten.mm.default] = (
|
||||
original_strategy
|
||||
)
|
||||
else:
|
||||
if torch.ops.aten.mm.default in propagator.op_strategy_funcs:
|
||||
del propagator.op_strategy_funcs[torch.ops.aten.mm.default]
|
||||
# Clear the cache
|
||||
propagator.propagate_op_sharding.cache.cache_clear()
|
||||
|
||||
@with_comms
|
||||
def test_single_dim_strategy_shardings_match_full_strategy(self):
|
||||
"""
|
||||
Verify that the shardings produced by a single-dim strategy match those produced
|
||||
by the full strategy implementation.
|
||||
"""
|
||||
from torch.distributed.tensor._ops._matrix_ops import (
|
||||
gen_single_dim_einsum_strategies,
|
||||
)
|
||||
|
||||
mesh = self.build_device_mesh()
|
||||
|
||||
# Create test inputs
|
||||
lhs_tensor = torch.randn(6, 8)
|
||||
rhs_tensor = torch.randn(8, 12)
|
||||
lhs_tensor_meta = extract_tensor_meta(lhs_tensor)
|
||||
rhs_tensor_meta = extract_tensor_meta(rhs_tensor)
|
||||
|
||||
# Test multiple input sharding combinations
|
||||
mm_combs = (
|
||||
(Shard(0), Replicate()),
|
||||
(Replicate(), Shard(1)),
|
||||
(Shard(1), Shard(0)),
|
||||
(Replicate(), Replicate()),
|
||||
)
|
||||
|
||||
for lhs_placement, rhs_placement in mm_combs:
|
||||
lhs_spec = DTensorSpec(mesh, (lhs_placement,), lhs_tensor_meta)
|
||||
rhs_spec = DTensorSpec(mesh, (rhs_placement,), rhs_tensor_meta)
|
||||
|
||||
op_schema = OpSchema(
|
||||
torch.ops.aten.mm.default,
|
||||
(
|
||||
OpStrategy([OpSpec(lhs_spec)]),
|
||||
OpStrategy([OpSpec(rhs_spec)]),
|
||||
),
|
||||
{},
|
||||
)
|
||||
|
||||
# Get single-dim strategies
|
||||
single_dim_strategies = gen_single_dim_einsum_strategies("mk,kn->mn", mesh)
|
||||
|
||||
# Expand to full strategy (mimicking what register_single_dim_strategy does)
|
||||
all_mesh_dim_strategies = [single_dim_strategies] * mesh.ndim
|
||||
strategy_combs = itertools.product(*all_mesh_dim_strategies)
|
||||
expanded_strategies = []
|
||||
for strategy_comb in strategy_combs:
|
||||
spec_list = [
|
||||
DTensorSpec(mesh, tuple(specs)) for specs in zip(*strategy_comb)
|
||||
]
|
||||
expanded_strategies.append(
|
||||
OpSpec(output_specs=spec_list[0], input_specs=spec_list[1:])
|
||||
)
|
||||
|
||||
# Verify that for the given input shardings, we can find a matching strategy
|
||||
# with zero redistribute cost
|
||||
found_zero_cost_strategy = False
|
||||
for strategy in expanded_strategies:
|
||||
if strategy.input_specs == (lhs_spec, rhs_spec):
|
||||
# This strategy should have zero redistribute cost since inputs match
|
||||
found_zero_cost_strategy = True
|
||||
# In a real strategy, redistribute costs would be computed
|
||||
# Here we just verify the structure is correct
|
||||
self.assertEqual(
|
||||
len(strategy.input_specs),
|
||||
2,
|
||||
"MM should have exactly 2 input specs",
|
||||
)
|
||||
self.assertIsNotNone(
|
||||
strategy.output_specs, "Output spec should not be None"
|
||||
)
|
||||
break
|
||||
|
||||
self.assertTrue(
|
||||
found_zero_cost_strategy,
|
||||
f"Should find a strategy matching input shardings {lhs_placement}, {rhs_placement}",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
||||
@ -167,14 +167,6 @@ def _pack_fp8_wrap(x):
|
||||
if not x.dtype.is_floating_point:
|
||||
return x
|
||||
|
||||
if type(x) is not torch.Tensor:
|
||||
# Check only during compilation
|
||||
# Test calls hooks to get reference output
|
||||
ctx = torch._functorch._aot_autograd.graph_compile._get_saved_tensor_hook_context()
|
||||
assert ctx["_fw_graph"] is not None
|
||||
assert ctx["_bw_graph"] is not None
|
||||
assert ctx["_node"] is not None
|
||||
|
||||
return (x.dtype, x.to(torch.float8_e5m2))
|
||||
|
||||
|
||||
@ -184,13 +176,6 @@ def _unpack_fp8_wrap(x):
|
||||
return x
|
||||
|
||||
dtype, tensor = x
|
||||
if type(tensor) is not torch.Tensor:
|
||||
# Check only during compilation
|
||||
# Test calls hooks to get reference output
|
||||
ctx = torch._functorch._aot_autograd.graph_compile._get_saved_tensor_hook_context()
|
||||
assert ctx["_fw_graph"] is not None
|
||||
assert ctx["_bw_graph"] is not None
|
||||
assert ctx["_node"] is not None
|
||||
return tensor.to(dtype)
|
||||
|
||||
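For context, pack/unpack hooks like `_pack_fp8_wrap` / `_unpack_fp8_wrap` above are installed through PyTorch's saved-tensors-hooks mechanism; a minimal standalone sketch (illustrative, not part of the diff):

import torch

def pack(x):
    # Downcast saved activations to fp8 to save memory, remembering the original dtype
    if isinstance(x, torch.Tensor) and x.dtype.is_floating_point:
        return (x.dtype, x.to(torch.float8_e5m2))
    return x

def unpack(packed):
    if isinstance(packed, tuple):
        dtype, t = packed
        return t.to(dtype)
    return packed

x = torch.randn(4, 4, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    y = (x * x).sum()
y.backward()  # backward sees the dequantized (lossy) copy of x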
|
||||
|
||||
@ -4,7 +4,6 @@ import os
|
||||
import tempfile
|
||||
from threading import Event
|
||||
|
||||
import torch._inductor.config as config
|
||||
from torch._inductor.compile_worker.subproc_pool import (
|
||||
raise_testexc,
|
||||
SubprocException,
|
||||
@ -17,12 +16,9 @@ from torch.testing._internal.inductor_utils import HAS_CPU
|
||||
|
||||
|
||||
class TestCompileWorker(TestCase):
|
||||
def make_pool(self, size):
|
||||
return SubprocPool(size)
|
||||
|
||||
@skipIfWindows(msg="pass_fds not supported on Windows.")
|
||||
def test_basic_jobs(self):
|
||||
pool = self.make_pool(2)
|
||||
pool = SubprocPool(2)
|
||||
try:
|
||||
a = pool.submit(operator.add, 100, 1)
|
||||
b = pool.submit(operator.sub, 100, 1)
|
||||
@ -33,7 +29,7 @@ class TestCompileWorker(TestCase):
|
||||
|
||||
@skipIfWindows(msg="pass_fds not supported on Windows.")
|
||||
def test_exception(self):
|
||||
pool = self.make_pool(2)
|
||||
pool = SubprocPool(2)
|
||||
try:
|
||||
a = pool.submit(raise_testexc)
|
||||
with self.assertRaisesRegex(
|
||||
@ -46,7 +42,7 @@ class TestCompileWorker(TestCase):
|
||||
|
||||
@skipIfWindows(msg="pass_fds not supported on Windows.")
|
||||
def test_crash(self):
|
||||
pool = self.make_pool(2)
|
||||
pool = SubprocPool(2)
|
||||
try:
|
||||
with self.assertRaises(Exception):
|
||||
a = pool.submit(os._exit, 1)
|
||||
@ -62,7 +58,7 @@ class TestCompileWorker(TestCase):
|
||||
|
||||
@skipIfWindows(msg="pass_fds not supported on Windows.")
|
||||
def test_quiesce(self):
|
||||
pool = self.make_pool(2)
|
||||
pool = SubprocPool(2)
|
||||
try:
|
||||
a = pool.submit(operator.add, 100, 1)
|
||||
pool.quiesce()
|
||||
@ -79,7 +75,7 @@ class TestCompileWorker(TestCase):
|
||||
os.environ["ROLE_RANK"] = "0"
|
||||
with tempfile.NamedTemporaryFile(delete=True) as temp_log:
|
||||
os.environ["TORCHINDUCTOR_WORKER_LOGPATH"] = temp_log.name
|
||||
pool = self.make_pool(2)
|
||||
pool = SubprocPool(2)
|
||||
try:
|
||||
pool.submit(operator.add, 100, 1)
|
||||
self.assertEqual(os.path.exists(temp_log.name), True)
|
||||
@ -87,12 +83,6 @@ class TestCompileWorker(TestCase):
|
||||
pool.shutdown()
|
||||
|
||||
|
||||
@config.patch("quiesce_async_compile_time", 0.1)
|
||||
class TestCompileWorkerWithTimer(TestCompileWorker):
|
||||
def make_pool(self, size):
|
||||
return SubprocPool(size, quiesce=True)
|
||||
|
||||
|
||||
class TestTimer(TestCase):
|
||||
def test_basics(self):
|
||||
done = Event()
|
||||
|
||||
154 test/inductor/test_cutedsl_grouped_mm.py (Normal file)
@ -0,0 +1,154 @@
|
||||
# Owner(s): ["module: inductor"]
|
||||
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch._inductor import config
|
||||
from torch._inductor.codegen.cuda.cuda_env import is_datacenter_blackwell_arch
|
||||
from torch._inductor.test_case import run_tests, TestCase as InductorTestCase
|
||||
from torch._inductor.utils import ensure_cute_available
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
parametrize,
|
||||
)
|
||||
|
||||
|
||||
@unittest.skipIf(
|
||||
not (ensure_cute_available() and is_datacenter_blackwell_arch()),
|
||||
"CuTeDSL library or Blackwell device not available",
|
||||
)
|
||||
@instantiate_parametrized_tests
|
||||
class TestCuTeDSLGroupedGemm(InductorTestCase):
|
||||
def _get_inputs(
|
||||
self,
|
||||
group_size: int,
|
||||
M_hint: int,
|
||||
K: int,
|
||||
N: int,
|
||||
device: str,
|
||||
dtype: torch.dtype,
|
||||
alignment: int = 16,
|
||||
) -> tuple[Tensor, Tensor, Tensor]:
|
||||
# --- Random, tile-aligned M sizes ---
|
||||
M_sizes = (
|
||||
torch.randint(1, (M_hint // alignment) + 1, (group_size,), dtype=torch.int)
|
||||
* alignment
|
||||
)
|
||||
|
||||
M_total = torch.sum(M_sizes).item()
|
||||
|
||||
# --- Construct input tensors ---
|
||||
A = torch.randn(int(M_total), K, dtype=dtype, device=device) * 0.1
|
||||
B = torch.randn((group_size, K, N), dtype=dtype, device=device) * 0.01
|
||||
|
||||
# --- Build offsets (no leading zero, strictly increasing) ---
|
||||
offsets = torch.cumsum(M_sizes, dim=0).to(dtype=torch.int32, device=device)
|
||||
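# Illustration (not part of the diff): for group_size=3 and M_sizes = [128, 64, 32],
# offsets = cumsum(M_sizes) = [128, 192, 224], so rows [0:128) of A belong to group 0,
# rows [128:192) to group 1, and rows [192:224) to group 2.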
|
||||
return (A, B, offsets)
|
||||
|
||||
@parametrize("group_size", (2, 8))
|
||||
@parametrize("M_hint", (256, 1024))
|
||||
@parametrize("K", (64, 128))
|
||||
@parametrize("N", (128, 256))
|
||||
def test_grouped_gemm_basic(self, group_size: int, M_hint: int, K: int, N: int):
|
||||
device = "cuda"
|
||||
dtype = torch.bfloat16
|
||||
|
||||
A, B, offsets = self._get_inputs(group_size, M_hint, K, N, device, dtype)
|
||||
|
||||
def grouped_gemm_fn(A_packed, B_batched, offs):
|
||||
return torch._grouped_mm(A_packed, B_batched, offs=offs)
|
||||
|
||||
# Eager execution
|
||||
c_eager = grouped_gemm_fn(A, B, offsets)
|
||||
|
||||
# Test with Cute backend
|
||||
with config.patch(
|
||||
{
|
||||
"max_autotune": True,
|
||||
"max_autotune_gemm_backends": "CUTEDSL",
|
||||
"test_configs.autotune_choice_name_regex": "cutedsl",
|
||||
"autotune_fallback_to_aten": False,
|
||||
}
|
||||
):
|
||||
grouped_gemm_compiled = torch.compile(
|
||||
grouped_gemm_fn, backend="inductor", dynamic=False
|
||||
)
|
||||
c_compiled = grouped_gemm_compiled(A, B, offsets)
|
||||
|
||||
self.assertEqual(c_eager.dtype, dtype)
|
||||
self.assertEqual(c_compiled.dtype, dtype)
|
||||
torch.testing.assert_close(c_eager, c_compiled)
|
||||
|
||||
@parametrize("layout_A", ("contiguous", "offset", "padded", "view"))
|
||||
@parametrize("layout_B", ("contiguous", "broadcasted"))
|
||||
def test_grouped_gemm_assorted_layouts(
|
||||
self,
|
||||
layout_A: str,
|
||||
layout_B: str,
|
||||
):
|
||||
device = "cuda"
|
||||
dtype = torch.bfloat16
|
||||
|
||||
G, K, N = 8, 64, 128
|
||||
M_sizes = [128] * G
|
||||
sum_M = sum(M_sizes)
|
||||
offsets = torch.tensor(
|
||||
[sum(M_sizes[: i + 1]) for i in range(G)], dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
A_base = torch.randn(sum_M, K, device=device, dtype=dtype)
|
||||
A = A_base
|
||||
|
||||
if layout_A == "offset":
|
||||
# allocate bigger buffer than needed, use nonzero storage offset
|
||||
storage = torch.randn(sum_M * K + 512, device=device, dtype=dtype)
|
||||
offset = 128 # skip first 128 elements
|
||||
A = torch.as_strided(storage[offset:], (sum_M, K), (K, 1))
|
||||
elif layout_A == "padded":
|
||||
# simulate row pitch > K (row_stride = K + pad)
|
||||
row_pitch = K + 8
|
||||
storage = torch.randn(sum_M * row_pitch, device=device, dtype=dtype)
|
||||
A = torch.as_strided(storage, (sum_M, K), (row_pitch, 1))
|
||||
elif layout_A == "view":
|
||||
A_storage = torch.randn(sum_M * K, device=device, dtype=dtype)
|
||||
A = A_storage.view(sum_M, K)
|
||||
assert A._base is not None
|
||||
assert A.shape == (sum_M, K)
|
||||
|
||||
B = torch.randn((G, K, N), dtype=dtype, device=device) * 0.01
|
||||
|
||||
if layout_B == "broadcasted":
|
||||
# Broadcast B across groups (zero stride along G)
|
||||
B = B[0].expand(G, K, N)
|
||||
assert B.stride(0) == 0
|
||||
|
||||
def grouped_gemm_fn(A_packed, B_batched, offs):
|
||||
return torch._grouped_mm(A_packed, B_batched, offs=offs)
|
||||
|
||||
# --- eager ---
|
||||
c_eager = grouped_gemm_fn(A, B, offsets)
|
||||
|
||||
# --- compiled (CUTE backend) ---
|
||||
with config.patch(
|
||||
{
|
||||
"max_autotune": True,
|
||||
"max_autotune_gemm_backends": "CUTEDSL",
|
||||
"test_configs.autotune_choice_name_regex": "cutedsl",
|
||||
"autotune_fallback_to_aten": False,
|
||||
}
|
||||
):
|
||||
grouped_gemm_compiled = torch.compile(
|
||||
grouped_gemm_fn, backend="inductor", dynamic=False
|
||||
)
|
||||
c_compiled = grouped_gemm_compiled(A, B, offsets)
|
||||
|
||||
self.assertEqual(c_eager.dtype, dtype)
|
||||
self.assertEqual(c_compiled.dtype, dtype)
|
||||
torch.testing.assert_close(c_eager, c_compiled)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
@ -117,22 +117,6 @@ class MixOrderReductionTest(TestBase):
|
||||
metrics.codegen_mix_order_reduction,
|
||||
)
|
||||
|
||||
@inductor_config.patch(coordinate_descent_tuning=True)
|
||||
def test_XBLOCK_coordest_tuning(self):
|
||||
"""
|
||||
We should skip XBLOCK coordinate descent tuning for
|
||||
mix order reduction.
|
||||
"""
|
||||
if not inductor_config.triton.mix_order_reduction:
|
||||
self.skipTest("Mix order reduction not enabled")
|
||||
|
||||
def f(x):
|
||||
return x.sum(dim=-1), x.sum(dim=0)
|
||||
|
||||
x = torch.randn(32768, 256, dtype=torch.float, device=GPU_TYPE)
|
||||
self.check_numeric(f, (x,))
|
||||
self.assertEqual(metrics.codegen_mix_order_reduction, 1)
|
||||
|
||||
@inductor_config.patch(unroll_reductions_threshold=1)
|
||||
def test_3layer_split_reduction(self):
|
||||
"""
|
||||
|
||||
@ -500,13 +500,8 @@ class PaddingTest(TestCaseBase):
|
||||
forward_wrapper = wrapper_codes[0]
|
||||
|
||||
# make sure the load for softmax is aligned
|
||||
if bias:
|
||||
# addmm -> mm + bias and bias is fused with softmax
|
||||
softmax_load_str = "tl.load(in_out_ptr0 + (r0_1 + 30528*x0)"
|
||||
else:
|
||||
softmax_load_str = "tl.load(in_ptr0 + (r0_1 + 30528*x0)"
|
||||
self.assertTrue(
|
||||
softmax_load_str in forward_wrapper,
|
||||
"tl.load(in_ptr0 + (r0_1 + 30528*x0)" in forward_wrapper,
|
||||
f"forward_wrapper: {forward_wrapper}",
|
||||
)
|
||||
|
||||
|
||||
@ -15280,7 +15280,7 @@ if RUN_GPU:
|
||||
),
|
||||
(
|
||||
fn3,
|
||||
"triton_poi_fused_addmm_native_layer_norm",
|
||||
"triton_poi_fused_native_layer_norm_relu",
|
||||
(torch.randn(4, 4, device=GPU_TYPE),),
|
||||
),
|
||||
]
|
||||
@ -15293,7 +15293,7 @@ if RUN_GPU:
|
||||
),
|
||||
(
|
||||
fn3,
|
||||
"triton_poi_fused_LayerNorm_Linear_ReLU",
|
||||
"triton_poi_fused_LayerNorm_ReLU",
|
||||
(torch.randn(4, 4, device=GPU_TYPE),),
|
||||
),
|
||||
]
|
||||
|
||||
@ -1826,14 +1826,9 @@ def run_test_module(
|
||||
test_name = test.name
|
||||
|
||||
# Printing the date here can help diagnose which tests are slow
|
||||
start = time.perf_counter()
|
||||
print_to_stderr(f"Running {str(test)} ... [{datetime.now()}][{start}]")
|
||||
print_to_stderr(f"Running {str(test)} ... [{datetime.now()}]")
|
||||
handler = CUSTOM_HANDLERS.get(test_name, run_test)
|
||||
return_code = handler(test, test_directory, options)
|
||||
end = time.perf_counter()
|
||||
print_to_stderr(
|
||||
f"Finished {str(test)} ... [{datetime.now()}][{end}], took {(end - start) / 60:.2f}min"
|
||||
)
|
||||
assert isinstance(return_code, int) and not isinstance(return_code, bool), (
|
||||
f"While running {str(test)} got non integer return code {return_code}"
|
||||
)
|
||||
|
||||
@ -7413,140 +7413,6 @@ class TestCudaDeviceParametrized(TestCase):
|
||||
)
|
||||
|
||||
|
||||
class TestFXMemoryProfiler(TestCase):
|
||||
"""Tests for memory profiler augmentation with original stack traces."""
|
||||
|
||||
def collect_frames(
|
||||
self, augmented_snapshot, collect_device_traces=True, collect_segments=True
|
||||
):
|
||||
"""Collects all frames that has node metadata from a memory snapshot."""
|
||||
# Collect all frames with FX metadata
|
||||
fx_frames = []
|
||||
|
||||
# Check device traces for FX debug fields
|
||||
if collect_device_traces and "device_traces" in augmented_snapshot:
|
||||
for trace_list in augmented_snapshot["device_traces"]:
|
||||
for trace_entry in trace_list:
|
||||
if isinstance(trace_entry, dict) and "frames" in trace_entry:
|
||||
for frame in trace_entry["frames"]:
|
||||
if isinstance(frame, dict):
|
||||
# Check for FX debug fields
|
||||
if "fx_node_op" in frame or "fx_node_name" in frame:
|
||||
fx_frames.append(frame)
|
||||
|
||||
# Check segments/blocks for FX debug fields
|
||||
if collect_segments and "segments" in augmented_snapshot:
|
||||
for segment in augmented_snapshot["segments"]:
|
||||
if "blocks" in segment:
|
||||
for block in segment["blocks"]:
|
||||
if "frames" in block:
|
||||
for frame in block["frames"]:
|
||||
if isinstance(frame, dict):
|
||||
if "fx_node_op" in frame or "fx_node_name" in frame:
|
||||
fx_frames.append(frame)
|
||||
return fx_frames
|
||||
|
||||
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
|
||||
@torch._dynamo.config.patch("enrich_profiler_metadata", True)
|
||||
def test_fx_memory_profiler_augmentation(self):
|
||||
"""Test that memory snapshots are augmented with FX debug information."""
|
||||
|
||||
# Create a simple model
|
||||
class MLPModule(nn.Module):
|
||||
def __init__(self, device):
|
||||
super().__init__()
|
||||
torch.manual_seed(5)
|
||||
self.net1 = nn.Linear(10, 16, bias=True, device=device)
|
||||
self.relu = nn.ReLU()
|
||||
self.net2 = nn.Linear(16, 10, bias=True, device=device)
|
||||
|
||||
def forward(self, x):
|
||||
a = self.net1(x)
|
||||
b = self.relu(a)
|
||||
c = self.net2(b)
|
||||
return c
|
||||
|
||||
device = "cuda"
|
||||
mod = MLPModule(device)
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
torch.cuda.memory._record_memory_history()
|
||||
compiled = torch.compile(mod, backend="aot_eager", fullgraph=True)
|
||||
result = compiled(torch.randn(10, 10, device=device))
|
||||
augmented_snapshot = torch.cuda.memory._snapshot(
|
||||
augment_with_fx_traces=True
|
||||
)
|
||||
torch.cuda.memory._record_memory_history(enabled=None, clear_history=True)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
fx_frames = self.collect_frames(augmented_snapshot)
|
||||
if TEST_WITH_ROCM:
|
||||
self.assertGreater(len(fx_frames), 0)
|
||||
else:
|
||||
self.assertEqual(len(fx_frames), 12)
|
||||
|
||||
for frame in fx_frames:
|
||||
# Every FX frame should have both node_op and node_name
|
||||
self.assertIn("fx_node_op", frame)
|
||||
self.assertIn("fx_node_name", frame)
|
||||
self.assertIn("fx_node_target", frame)
|
||||
self.assertIn("fx_original_trace", frame)
|
||||
|
||||
self.assertIn(frame["fx_node_name"], ["addmm", "relu", "addmm_1"])
|
||||
fx_node_name = frame["fx_node_name"]
|
||||
if fx_node_name == "addmm":
|
||||
self.assertIn("a = self.net1(x)", frame["fx_original_trace"])
|
||||
elif fx_node_name == "addmm_1":
|
||||
self.assertIn("c = self.net2(b)", frame["fx_original_trace"])
|
||||
elif fx_node_name == "relu":
|
||||
self.assertIn("b = self.relu(a)", frame["fx_original_trace"])
|
||||
|
||||
# Test that when we have two graphs with the same src_code, they're not hashed
|
||||
# to the same metadata
|
||||
class MLPModule2(nn.Module):
|
||||
def __init__(self, device):
|
||||
super().__init__()
|
||||
torch.manual_seed(5)
|
||||
self.net1 = nn.Linear(10, 16, bias=True, device=device)
|
||||
self.relu = nn.ReLU()
|
||||
self.net2 = nn.Linear(16, 10, bias=True, device=device)
|
||||
|
||||
def forward(self, x):
|
||||
d = self.net1(x)
|
||||
e = self.relu(d)
|
||||
f = self.net2(e)
|
||||
return f
|
||||
|
||||
mod = MLPModule2(device)
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
torch.cuda.memory._record_memory_history()
|
||||
compiled = torch.compile(mod, backend="aot_eager", fullgraph=True)
|
||||
result = compiled(torch.randn(10, 10, device=device))
|
||||
augmented_snapshot = torch.cuda.memory._snapshot(
|
||||
augment_with_fx_traces=True
|
||||
)
|
||||
torch.cuda.memory._record_memory_history(enabled=None, clear_history=True)
|
||||
|
||||
# avoid collecting segments from previous run for unit test purpose
|
||||
fx_frames = self.collect_frames(augmented_snapshot, collect_segments=False)
|
||||
self.assertGreater(len(fx_frames), 0)
|
||||
|
||||
for frame in fx_frames:
|
||||
# Every FX frame should have both node_op and node_name
|
||||
self.assertIn("fx_node_op", frame)
|
||||
self.assertIn("fx_node_name", frame)
|
||||
self.assertIn("fx_node_target", frame)
|
||||
self.assertIn("fx_original_trace", frame)
|
||||
|
||||
self.assertIn(frame["fx_node_name"], ["addmm", "relu", "addmm_1"])
|
||||
fx_node_name = frame["fx_node_name"]
|
||||
if fx_node_name == "addmm":
|
||||
self.assertIn("d = self.net1(x)", frame["fx_original_trace"])
|
||||
elif fx_node_name == "addmm_1":
|
||||
self.assertIn("f = self.net2(e)", frame["fx_original_trace"])
|
||||
elif fx_node_name == "relu":
|
||||
self.assertIn("e = self.relu(d)", frame["fx_original_trace"])
|
||||
|
||||
|
||||
instantiate_parametrized_tests(TestCuda)
|
||||
instantiate_parametrized_tests(TestCudaMallocAsync)
|
||||
instantiate_parametrized_tests(TestCompileKernel)
|
||||
|
||||
@ -771,7 +771,6 @@ class TestFX(JitTestCase):
|
||||
gm = GraphModule(tracer.root, graph)
|
||||
expected = {1: 2, 2: 3, 3: 4, 4: 5}
|
||||
self.assertTrue(set(expected.items()).issubset(set(gm._lineno_map.items())))
|
||||
self.assertEqual(gm._prologue_start, 4)
|
||||
|
||||
# test custom codegen
|
||||
def transform_code(code):
|
||||
@ -781,7 +780,6 @@ class TestFX(JitTestCase):
|
||||
gm.recompile()
|
||||
expected = {2: 2, 3: 3, 4: 4, 5: 5}
|
||||
self.assertTrue(set(expected.items()).issubset(set(gm._lineno_map.items())))
|
||||
self.assertEqual(gm._prologue_start, 4)
|
||||
|
||||
def test_graph_unique_names_manual(self):
|
||||
graph: torch.fx.Graph = torch.fx.Graph()
|
||||
|
||||
@ -209,36 +209,42 @@ def infer_scale_swizzle(mat, scale):
|
||||
] == math.ceil(mat.shape[1] // 128):
|
||||
return ScalingType.BlockWise128x128, SwizzleType.NO_SWIZZLE
|
||||
|
||||
# if we're checking for nvfp4, need to adjust for packed-K
|
||||
K_multiplier = 2 if mat.dtype == torch.float4_e2m1fn_x2 else 1
|
||||
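# Illustration (not part of the diff): a float4_e2m1fn_x2 tensor packs two FP4 values
# per element, so a tensor reported as (M, K) actually holds M * 2K FP4 values; the
# block-scale element counts below are therefore computed against K_multiplier * K.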
# NVFP4
|
||||
if (
|
||||
(scale.numel()
|
||||
== round_up(mat.shape[0], 128) * round_up(math.ceil(K_multiplier * mat.shape[1] // 16), 4)
|
||||
== round_up(mat.shape[0], 128) * round_up(math.ceil(2 * mat.shape[1] // 16), 4)
|
||||
or scale.numel()
|
||||
== round_up(mat.shape[1], 128) * round_up(math.ceil(K_multiplier * mat.shape[0] // 16), 4))
|
||||
== round_up(mat.shape[1], 128) * round_up(math.ceil(2 * mat.shape[0] // 16), 4))
|
||||
and mat.dtype == torch.float4_e2m1fn_x2
|
||||
and scale.dtype == torch.float8_e4m3fn
|
||||
):
|
||||
return ScalingType.BlockWise1x16, SwizzleType.SWIZZLE_32_4_4
|
||||
|
||||
# MX formats
|
||||
# MXFP4 w/o swizzle
|
||||
if (
|
||||
(scale.numel() == 2 * math.ceil(mat.shape[0] // 32) * mat.shape[1]
|
||||
or scale.numel() == 2 * math.ceil(mat.shape[1] // 32) * mat.shape[0])
|
||||
and mat.dtype == torch.float4_e2m1fn_x2
|
||||
and scale.dtype == torch.float8_e8m0fnu
|
||||
):
|
||||
return ScalingType.BlockWise1x32, SwizzleType.NO_SWIZZLE
|
||||
|
||||
if not torch.version.hip:
|
||||
# MX w/swizzle (NVIDIA)
|
||||
# MXFP8 w/ swizzle
|
||||
if (
|
||||
(scale.numel()
|
||||
== round_up(mat.shape[0], 128) * round_up(math.ceil(K_multiplier * mat.shape[1] // 32), 4)
|
||||
== round_up(mat.shape[0], 128) * round_up(math.ceil(mat.shape[1] // 32), 4)
|
||||
or scale.numel()
|
||||
== round_up(mat.shape[1], 128) * round_up(math.ceil(K_multiplier * mat.shape[0] // 32), 4))
|
||||
== round_up(mat.shape[1], 128) * round_up(math.ceil(mat.shape[0] // 32), 4))
|
||||
and scale.dtype == torch.float8_e8m0fnu
|
||||
):
|
||||
return ScalingType.BlockWise1x32, SwizzleType.SWIZZLE_32_4_4
|
||||
|
||||
else:
|
||||
# MX w/o swizzle (AMD)
|
||||
# MXFP8 w/o swizzle
|
||||
if (
|
||||
(scale.numel() == math.ceil(mat.shape[0] // 32) * K_multiplier * mat.shape[1]
|
||||
or scale.numel() == math.ceil(K_multiplier * mat.shape[1] // 32) * mat.shape[0])
|
||||
(scale.numel() == math.ceil(mat.shape[0] // 32) * mat.shape[1]
|
||||
or scale.numel() == math.ceil(mat.shape[1] // 32) * mat.shape[0])
|
||||
and scale.dtype == torch.float8_e8m0fnu
|
||||
):
|
||||
return ScalingType.BlockWise1x32, SwizzleType.NO_SWIZZLE
|
||||
@ -1862,7 +1868,7 @@ class TestFP8Matmul(TestCase):
|
||||
(127, 96, 1024),
|
||||
(1025, 128, 96)
|
||||
], name_fn=lambda mkn: f"{mkn[0]}_{mkn[1]}_{mkn[2]}")
|
||||
@parametrize("recipe", ["mxfp8", "mxfp4", "nvfp4"])
|
||||
@parametrize("recipe", ["mxfp8", "mxfp4" if torch.version.hip else "nvfp4"])
|
||||
def test_blockwise_mxfp8_nvfp4_mxfp4_numerics(self, test_case_name, fast_accum, mkn, recipe) -> None:
|
||||
if (recipe == "nvfp4" or recipe == "mxfp4") and fast_accum:
|
||||
raise unittest.SkipTest("fast_accum not supported in nvfp4/mxfp4 cublas gemm, skipping")
|
||||
@ -1876,12 +1882,8 @@ class TestFP8Matmul(TestCase):
|
||||
if not (M % 16 == 0 and K % 128 == 0 and N % 16 == 0):
|
||||
raise unittest.SkipTest("M and N must be multiples of 16 and K must be multiple of 128 on ROCm, skipping")
|
||||
|
||||
fp4_scaling_dtype = torch.float8_e8m0fnu if recipe == "mxfp4" else torch.float8_e4m3fn
|
||||
BLOCK_SIZE = 16 if recipe == "nvfp4" else 32
|
||||
|
||||
if K % BLOCK_SIZE != 0:
|
||||
raise unittest.SkipTest(f"K ({K}) must be divisible by BLOCK_SIZE ({BLOCK_SIZE}), skipping")
|
||||
|
||||
fp4_scaling_dtype = torch.float8_e8m0fnu if torch.version.hip else torch.float8_e4m3fn
|
||||
BLOCK_SIZE = 32 if torch.version.hip else (16 if recipe == "nvfp4" else 32)
|
||||
require_exact_match = True
|
||||
approx_match_sqnr_target = 22.0
|
||||
|
||||
@ -2059,7 +2061,7 @@ class TestFP8Matmul(TestCase):
|
||||
B = B.clamp(min=min_val, max=max_val)
|
||||
B = _bfloat16_to_float4_e2m1fn_x2(B)
|
||||
|
||||
approx_match_sqnr_target = 15 if recipe == "mxfp4" else 15.8
|
||||
approx_match_sqnr_target = 15 if torch.version.hip else 15.8
|
||||
|
||||
C_ref = A_ref @ B_ref.t()
|
||||
|
||||
|
||||
@ -739,12 +739,6 @@ enable_aot_compile = False
|
||||
# HACK: this is for testing custom ops profiling only
|
||||
_custom_ops_profile: Optional[Any] = None
|
||||
|
||||
# Experimental: If True, graph module will register fx metadata during recompile()
|
||||
enrich_profiler_metadata: bool = Config( # type: ignore[var-annotated]
|
||||
default=False,
|
||||
env_name_default="TORCH_ENRICH_RPOFILER_STACK_TRACE",
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch.utils._config_typing import * # noqa: F401, F403
|
||||
|
||||
|
||||
@ -25,9 +25,6 @@ from typing import Any, Optional, TYPE_CHECKING, Union
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
import threading
|
||||
from contextlib import contextmanager
|
||||
|
||||
import torch
|
||||
import torch.utils._pytree as pytree
|
||||
import torch.utils.dlpack
|
||||
@ -100,43 +97,6 @@ from .utils import (
|
||||
)
|
||||
|
||||
|
||||
_thread_local = threading.local()
|
||||
|
||||
|
||||
# Saved tensor hooks context
# Compiled saved tensor hooks are a convenient way to inline logic into the graphs
# for values saved from forward to backward (e.g. activation quantization).
# In the base implementation the user has no additional information about a saved value
# inside the hook, except the FakeTensor shape, dtype, device, etc.
# _get_saved_tensor_hook_context gives additional graph information about that saved value,
# which can be used to decide which pack/unpack to apply to a particular saved value.
# This lets users reuse the saved-tensors-hooks API to apply selective pack/unpack in a
# graph-aware way.
# The alternative would be requiring users to write a custom pass that mucks with forward
# outputs and backward input metadata, which takes significantly more effort.
#
# For now the context exposes the forward graph, the backward graph, and the current saved
# node, whose node.meta carries additional information about that fx.Node.
# Warning: This API may change without backward compatibility.
|
||||
@contextmanager
|
||||
def _saved_tensor_hook_context(state: dict[str, Any]):
|
||||
previous_state = getattr(_thread_local, "state", None)
|
||||
try:
|
||||
_thread_local.state = state
|
||||
yield
|
||||
finally:
|
||||
# Clean up: restore previous state or remove attribute
|
||||
if previous_state is not None:
|
||||
_thread_local.state = previous_state
|
||||
else:
|
||||
if hasattr(_thread_local, "state"):
|
||||
delattr(_thread_local, "state")
|
||||
|
||||
|
||||
def _get_saved_tensor_hook_context() -> dict[str, Any] | None:
|
||||
return getattr(_thread_local, "state", None)
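A minimal usage sketch (the hook names and the downcast policy below are illustrative assumptions, not part of this diff): a pack hook can consult this context to make a graph-aware decision per saved value.
# Hypothetical pack/unpack hook pair that downcasts only activations produced
# by aten.mm, using the context exposed by _get_saved_tensor_hook_context().
def selective_pack_hook(saved_tensor: torch.Tensor) -> torch.Tensor:
    ctx = _get_saved_tensor_hook_context()
    if ctx is not None and ctx["_node"].target is torch.ops.aten.mm.default:
        return saved_tensor.to(torch.bfloat16)  # quantize matmul outputs only
    return saved_tensor

def selective_unpack_hook(packed: torch.Tensor) -> torch.Tensor:
    # Restore the compute dtype if the pack hook downcast this value.
    return packed.to(torch.float32) if packed.dtype == torch.bfloat16 else packed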
|
||||
|
||||
|
||||
zip = strict_zip
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
@ -1137,11 +1097,7 @@ def maybe_inline_graph_saved_tensors_hooks(
|
||||
if not isinstance(val, torch.Tensor):
|
||||
continue
|
||||
|
||||
def _get_extra_info() -> dict[str, Any]:
|
||||
return {"_fw_graph": fw_g, "_bw_graph": bw_g, "_node": saved}
|
||||
|
||||
with _saved_tensor_hook_context(_get_extra_info()):
|
||||
pack_out_val = pack_hook_gm(val)
|
||||
pack_out_val = pack_hook_gm(val)
|
||||
|
||||
requires_sc_handling = any(
|
||||
is_traceable_wrapper_subclass(x) for x in pytree.tree_leaves(pack_out_val)
|
||||
@ -1153,17 +1109,16 @@ def maybe_inline_graph_saved_tensors_hooks(
|
||||
" in the pack hook, and reconstructing the subclass in the unpack hook"
|
||||
)
|
||||
|
||||
with _saved_tensor_hook_context(_get_extra_info()):
|
||||
pack_gm = prepare_hook_gm(aot_config, pack_hook_gm, (val,))
|
||||
pack_g = pack_gm.graph
|
||||
maybe_log_graph(
|
||||
pack_gm,
|
||||
f"saved_tensors_pack_hook {saved.name}",
|
||||
aot_config,
|
||||
lambda: f"aot_saved_tensors_hooks_pack {saved.name}",
|
||||
structured_logs,
|
||||
)
|
||||
pack_out_val = pack_gm(val)
|
||||
pack_gm = prepare_hook_gm(aot_config, pack_hook_gm, (val,))
|
||||
pack_g = pack_gm.graph
|
||||
maybe_log_graph(
|
||||
pack_gm,
|
||||
f"saved_tensors_pack_hook {saved.name}",
|
||||
aot_config,
|
||||
lambda: f"aot_saved_tensors_hooks_pack {saved.name}",
|
||||
structured_logs,
|
||||
)
|
||||
pack_out_val = pack_gm(val)
|
||||
|
||||
# Install pack hook graph as epilogue of fw_module.
|
||||
# Saved tensor output becomes input of pack hook graph.
|
||||
@ -1233,16 +1188,15 @@ def maybe_inline_graph_saved_tensors_hooks(
|
||||
# Install unpack hook graph as a prologue of backward graph
|
||||
# Saved tensors inputs are replaced with packed tensors and packed sym scalars.
|
||||
# The saved tensors inputs usages in the graph are replaced with unpack hook graph outputs.
|
||||
with _saved_tensor_hook_context(_get_extra_info()):
|
||||
unpack_gm = prepare_hook_gm(aot_config, unpack_hook_gm, (pack_out_val,))
|
||||
unpack_g = unpack_gm.graph
|
||||
maybe_log_graph(
|
||||
unpack_gm,
|
||||
f"saved_tensors_unpack_hook {saved.name}",
|
||||
aot_config,
|
||||
lambda: f"aot_saved_tensors_hooks_unpack {saved.name}",
|
||||
structured_logs,
|
||||
)
|
||||
unpack_gm = prepare_hook_gm(aot_config, unpack_hook_gm, (pack_out_val,))
|
||||
unpack_g = unpack_gm.graph
|
||||
maybe_log_graph(
|
||||
unpack_gm,
|
||||
f"saved_tensors_unpack_hook {saved.name}",
|
||||
aot_config,
|
||||
lambda: f"aot_saved_tensors_hooks_unpack {saved.name}",
|
||||
structured_logs,
|
||||
)
|
||||
|
||||
def find_saved_in_bw_inputs(bw_inputs):
|
||||
for n in bw_inputs:
|
||||
|
||||
@ -498,7 +498,6 @@ def generate_ttir(
# pyrefly: ignore # missing-attribute
codegen_fns = backend.get_codegen_implementation(*codegen_args)
module_map = backend.get_module_map()
# pyrefly: ignore[missing-argument,bad-argument-type]
ttir_module = src.make_ir(options, codegen_fns, module_map, context)
else:
codegen_args = [options] if get_codegen_implementation_sig_params == 1 else []
||||
|
||||
@ -423,10 +423,6 @@ def estimate_nccl_collective_runtime_from_fx_node(
from torch.distributed.distributed_c10d import _resolve_process_group

pg = _resolve_process_group(group_name)
if torch.distributed.distributed_c10d.get_backend(pg) == "fake":
# nccl estimator requires real process group
return None

fn = fx_node.target
assert isinstance(fn, torch._ops.OpOverload)
with torch.distributed._time_estimator(group=pg) as time_estimator:
||||
|
||||
@ -24,7 +24,6 @@ from typing_extensions import Never, ParamSpec
|
||||
import torch._thread_safe_fork # noqa: F401
|
||||
from torch._inductor import config
|
||||
from torch._inductor.codecache import torch_key
|
||||
from torch._inductor.compile_worker.timer import Timer
|
||||
from torch._inductor.compile_worker.tracked_process_pool import (
|
||||
TrackedProcessPoolExecutor,
|
||||
)
|
||||
@ -133,7 +132,6 @@ class SubprocPool:
|
||||
nprocs: int,
|
||||
pickler: Optional[SubprocPickler] = None,
|
||||
kind: SubprocKind = SubprocKind.FORK,
|
||||
quiesce: bool = False,
|
||||
) -> None:
|
||||
entry = os.path.join(os.path.dirname(__file__), "__main__.py")
|
||||
self.pickler = pickler or SubprocPickler()
|
||||
@ -218,13 +216,6 @@ class SubprocPool:
|
||||
"pytorch.wait_counter.subproc_pool.first_job"
|
||||
).guard()
|
||||
|
||||
if quiesce:
|
||||
self.timer: Optional[Timer] = Timer(
|
||||
config.quiesce_async_compile_time, self.quiesce
|
||||
)
|
||||
else:
|
||||
self.timer = None
|
||||
|
||||
# Start thread last to ensure all member variables are initialized
|
||||
# before any access.
|
||||
self.read_thread.start()
|
||||
@ -297,8 +288,6 @@ class SubprocPool:
|
||||
with self.futures_lock:
|
||||
if not self.running:
|
||||
return
|
||||
if self.timer:
|
||||
self.timer.record_call()
|
||||
if isinstance(result, _SubprocExceptionInfo):
|
||||
# An exception occurred in the submitted job
|
||||
self.pending_futures[job_id].set_exception(
|
||||
@ -333,8 +322,6 @@ class SubprocPool:
|
||||
with self.write_lock:
|
||||
if not self.running:
|
||||
return
|
||||
if self.timer:
|
||||
self.timer.quit()
|
||||
self.running = False
|
||||
self.running_waitcounter.__exit__()
|
||||
_send_msg(self.write_pipe, MsgHeader.SHUTDOWN)
|
||||
|
||||
@ -17,7 +17,7 @@ class Timer:
|
||||
self.background_thread: Optional[Thread] = None
|
||||
self.last_called: Optional[float] = None
|
||||
self.duration = duration
|
||||
self.sleep_time = duration / 2
|
||||
self.sleep_time = 60
|
||||
self.call = call
|
||||
self.exit = False
|
||||
|
||||
|
||||
@ -546,6 +546,10 @@ max_autotune_flex_search_space: Literal["DEFAULT", "EXHAUSTIVE"] = os.environ.ge
|
||||
"TORCHINDUCTOR_MAX_AUTOTUNE_FLEX_SEARCH_SPACE", "DEFAULT"
|
||||
).upper() # type: ignore[assignment]
|
||||
|
||||
cutedsl_enable_autotuning: bool = (
|
||||
os.environ.get("CUTEDSL_ENABLE_AUTOTUNING", "0") == "1"
|
||||
)
|
||||
|
||||
# DEPRECATED. This setting is ignored.
|
||||
autotune_fallback_to_aten = False
|
||||
|
||||
@ -960,11 +964,6 @@ quiesce_async_compile_pool: bool = Config(
|
||||
default=False,
|
||||
)
|
||||
|
||||
# Time in seconds to wait before quiescing
|
||||
quiesce_async_compile_time: int = Config(
|
||||
default=60,
|
||||
)
|
||||
|
||||
# Whether or not to enable statically launching CUDA kernels
|
||||
# compiled by triton (instead of using triton's own launcher)
|
||||
use_static_cuda_launcher: bool = static_cuda_launcher_default()
|
||||
|
||||
@ -51,8 +51,8 @@ from ..utils import (
|
||||
decode_device,
|
||||
get_all_devices,
|
||||
get_gpu_type,
|
||||
has_uses_tagged_as,
|
||||
is_gpu,
|
||||
is_pointwise_use,
|
||||
OPTIMUS_EXCLUDE_POST_GRAD,
|
||||
)
|
||||
from ..virtualized import V
|
||||
@ -1510,10 +1510,8 @@ def should_prefer_unfused_addmm(match):
|
||||
if not is_gpu(inp.meta["val"].device.type):
|
||||
return False
|
||||
|
||||
return has_uses_tagged_as(
|
||||
match.output_node(),
|
||||
(torch.Tag.pointwise, torch.Tag.reduction),
|
||||
)
|
||||
output = match.output_node()
|
||||
return all(is_pointwise_use(use) for use in output.users)
|
||||
|
||||
|
||||
@register_graph_pattern(
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
@ -12,6 +14,7 @@ from torch.fx.experimental.symbolic_shapes import has_free_unbacked_symbols
|
||||
from .. import config
|
||||
from ..codegen.wrapper import PythonWrapperCodegen
|
||||
from ..ir import _IntLike, Layout, TensorBox
|
||||
from ..utils import load_template
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
@ -254,3 +257,7 @@ def is_batch_stride_largest_or_zero(mat1, mat2, layout) -> bool:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
_KERNEL_TEMPLATE_DIR = Path(__file__).parent / "templates"
|
||||
load_kernel_template = partial(load_template, template_dir=_KERNEL_TEMPLATE_DIR)
|
||||
|
||||
@ -1,10 +1,11 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from torch._dynamo.utils import counters
|
||||
from torch._inductor.codegen.cutedsl.cutedsl_template import CuteDSLTemplate
|
||||
from torch._inductor.runtime.triton_compat import tl
|
||||
from torch._inductor.virtualized import V
|
||||
from torch.utils._triton import has_triton
|
||||
@ -18,19 +19,25 @@ from ..select_algorithm import (
|
||||
TritonTemplate,
|
||||
)
|
||||
from ..utils import (
|
||||
ensure_cute_available,
|
||||
get_gpu_shared_memory,
|
||||
get_num_sms,
|
||||
has_free_symbols,
|
||||
use_aten_gemm_kernels,
|
||||
use_blackwell_cutedsl_grouped_mm,
|
||||
use_triton_template,
|
||||
)
|
||||
from .mm_common import (
|
||||
_is_static_problem,
|
||||
check_supported_striding,
|
||||
load_kernel_template,
|
||||
persistent_grouped_mm_grid,
|
||||
)
|
||||
|
||||
|
||||
if ensure_cute_available():
|
||||
from torch._inductor.template_heuristics.cutedsl import get_groupgemm_configs
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
aten = torch.ops.aten
|
||||
|
||||
@ -513,6 +520,11 @@ triton_scaled_grouped_mm_template = TritonTemplate(
|
||||
source=triton_grouped_mm_source,
|
||||
)
|
||||
|
||||
cutedsl_grouped_mm_template = CuteDSLTemplate(
|
||||
name="grouped_gemm_cutedsl",
|
||||
source=load_kernel_template("cutedsl_mm_grouped"),
|
||||
)
|
||||
|
||||
|
||||
def grouped_mm_args(
|
||||
mat1: TensorBox,
|
||||
@ -714,43 +726,44 @@ def _tuned_grouped_mm_common(
|
||||
# Checking only for the equality of corresponding dims of
|
||||
# multiplicands here, relying on meta function checks for
|
||||
# everything else.
|
||||
if len(m1_size) == 2:
|
||||
if len(m2_size) == 2:
|
||||
m, k1 = m1_size
|
||||
k2, _ = m2_size
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
g = offs.get_size()[0]
|
||||
V.graph.sizevars.check_equals(k1, k2)
|
||||
a_is_2d, b_is_2d = True, True
|
||||
else:
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
g1 = offs.layout.size[0]
|
||||
m, k1 = m1_size
|
||||
g2, k2, _ = m2_size
|
||||
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
|
||||
V.graph.sizevars.check_equals(k1, k2)
|
||||
a_is_2d, b_is_2d = True, False
|
||||
else:
|
||||
if len(m2_size) == 2:
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
g1 = offs.layout.size[0]
|
||||
g2, m, k1 = m1_size
|
||||
k2, _ = m2_size
|
||||
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
|
||||
V.graph.sizevars.check_equals(k1, k2)
|
||||
a_is_2d, b_is_2d = False, True
|
||||
else:
|
||||
g1, m, k1 = m1_size
|
||||
g2, k2, _ = m2_size
|
||||
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
|
||||
V.graph.sizevars.check_equals(k1, k2)
|
||||
a_is_2d, b_is_2d = False, False
|
||||
|
||||
if (
|
||||
is_nonzero
|
||||
and use_triton_template(layout)
|
||||
and can_use_triton_kernel(mat_a, mat_b, offs, bias, scale_result)
|
||||
):
|
||||
scaled = scale_a is not None
|
||||
if len(m1_size) == 2:
|
||||
if len(m2_size) == 2:
|
||||
m, k1 = m1_size
|
||||
k2, _ = m2_size
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
g = offs.get_size()[0]
|
||||
V.graph.sizevars.check_equals(k1, k2)
|
||||
a_is_2d, b_is_2d = True, True
|
||||
else:
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
g1 = offs.layout.size[0]
|
||||
m, k1 = m1_size
|
||||
g2, k2, _ = m2_size
|
||||
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
|
||||
V.graph.sizevars.check_equals(k1, k2)
|
||||
a_is_2d, b_is_2d = True, False
|
||||
else:
|
||||
if len(m2_size) == 2:
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
g1 = offs.layout.size[0]
|
||||
g2, m, k1 = m1_size
|
||||
k2, _ = m2_size
|
||||
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
|
||||
V.graph.sizevars.check_equals(k1, k2)
|
||||
a_is_2d, b_is_2d = False, True
|
||||
else:
|
||||
g1, m, k1 = m1_size
|
||||
g2, k2, _ = m2_size
|
||||
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
|
||||
V.graph.sizevars.check_equals(k1, k2)
|
||||
a_is_2d, b_is_2d = False, False
|
||||
|
||||
a_is_k_major = mat_a.get_stride()[-1] == 1
|
||||
b_is_k_major = mat_b.get_stride()[-2] == 1
|
||||
@ -788,6 +801,22 @@ def _tuned_grouped_mm_common(
|
||||
**config.kwargs,
|
||||
)
|
||||
|
||||
if use_blackwell_cutedsl_grouped_mm(
|
||||
mat_a, mat_b, layout, a_is_2d, b_is_2d, offs, bias, scale_result
|
||||
):
|
||||
for config in get_groupgemm_configs():
|
||||
kwargs = dict(
|
||||
ACC_DTYPE="cutlass.Float32",
|
||||
)
|
||||
|
||||
cutedsl_grouped_mm_template.maybe_append_choice(
|
||||
choices,
|
||||
input_nodes=input_nodes,
|
||||
layout=layout,
|
||||
**kwargs,
|
||||
**asdict(config),
|
||||
)
|
||||
|
||||
input_gen_fns = {
|
||||
4: lambda x: create_offsets(
|
||||
x, m1_size, m2_size, offs.get_size() if offs is not None else None
|
||||
|
||||
333 torch/_inductor/kernel/templates/cutedsl_mm_grouped.py.jinja Normal file
@ -0,0 +1,333 @@
|
||||
import functools
|
||||
from torch._inductor.runtime.runtime_utils import ceildiv
|
||||
from cutlass.utils import TensorMapUpdateMode
|
||||
{{gen_defines()}}
|
||||
# ---- Import GroupedGemm implementation, copied on PyTorch build from Cutlass repository: cutlass/examples/python/CuTeDSL/blackwell/grouped_gemm.py ----
|
||||
from torch._inductor.kernel.vendored_templates.cutedsl_grouped_gemm import (
|
||||
GroupedGemmKernel,
|
||||
)
|
||||
|
||||
|
||||
# Note about caching:
|
||||
# Each instantiated CuTeDSL grouped GEMM kernel file generated by Inductor
|
||||
# maintains its own local caching system. At this stage, all compile-time
|
||||
# constexprs (e.g., TILE_M, TILE_N, CLUSTER_M/N, USE_2_CTA) and the kernel
|
||||
# name itself ({{kernel_name}}) are permanently baked into the file, so they
|
||||
# do not need to be included in any cache key.
|
||||
#
|
||||
# The caching mechanism is split into two levels:
|
||||
#
|
||||
# 1. prep_cache
|
||||
# Caches the compiled executor for build_group_ptrs_from_bases(). This
|
||||
# kernel depends only on the tensor shapes, strides, and dtypes of A/B/C,
|
||||
# and can therefore be safely reused across runs with different group
|
||||
# partitioning (`offs`).
|
||||
#
|
||||
# 2. gemm_cache
|
||||
# Caches the compiled Grouped GEMM executor. Its key extends the prep
|
||||
# cache key with hardware- and grid-specific parameters:
|
||||
# (prep_cache_key, max_active_clusters, total_num_clusters).
|
||||
# This is necessary because different `offs` tensors can change the
|
||||
# per-group problem sizes and thus alter `total_num_clusters`, which in
|
||||
# turn changes the grid shape and persistent scheduler configuration.
|
||||
# Kernels compiled for one grid cannot be safely reused for another.
|
||||
#
|
||||
#
|
||||
# Additionally, note the @lru_cache decorator on get_hardware_info(). Empirically,
|
||||
# hw.get_max_active_clusters() triggers significant MLIR recompilation overhead,
|
||||
# despite depending only on the GPU type. We cache this function to mitigate
|
||||
# redundant recompiles even when shape/stride/dtype cache misses force kernel
|
||||
# regeneration. A follow-up study will investigate the root cause.
|
||||
|
||||
prep_cache = {}
|
||||
gemm_cache = {}
|
||||
|
||||
|
||||
@functools.lru_cache
|
||||
def get_hardware_info():
|
||||
hw = cutlass.utils.HardwareInfo()
|
||||
sm_count = hw.get_max_active_clusters(1)
|
||||
max_active_clusters = hw.get_max_active_clusters(CLUSTER_M * CLUSTER_N)
|
||||
|
||||
return (sm_count, max_active_clusters)
|
||||
|
||||
|
||||
def get_prep_cache_key(input_a, input_b, output):
|
||||
"""
|
||||
Returns a tuple key for caching the preprocessing kernel executor based on kernel name,
|
||||
shapes, strides, and dtypes of input/output tensors.
|
||||
"""
|
||||
return (
|
||||
tuple(input_a.shape),
|
||||
tuple(input_a.stride()),
|
||||
input_a.dtype,
|
||||
tuple(input_b.shape),
|
||||
tuple(input_b.stride()),
|
||||
input_b.dtype,
|
||||
tuple(output.shape),
|
||||
tuple(output.stride()),
|
||||
output.dtype,
|
||||
)
|
||||
|
||||
|
||||
def get_gemm_cache_key(prep_cache_key, max_active_clusters, total_num_clusters):
|
||||
"""
|
||||
Returns a tuple key for caching the gemm kernel executor by extending the
|
||||
prep cache key with hardware- and grid-specific parameters.
|
||||
"""
|
||||
return (
|
||||
prep_cache_key,
|
||||
max_active_clusters,
|
||||
total_num_clusters,
|
||||
)
|
||||
|
||||
|
||||
@cute.kernel
|
||||
def build_group_ptrs_from_bases_kernel(
|
||||
base_A_u64: cutlass.Int64, # device addr of input_a (bytes)
|
||||
base_B_u64: cutlass.Int64, # device addr of input_b (bytes)
|
||||
base_C_u64: cutlass.Int64, # device addr of Output (bytes)
|
||||
offs: cute.Tensor, # [G], cutlass.Int32/64 cumulative
|
||||
K: cutlass.Constexpr,
|
||||
N: cutlass.Constexpr,
|
||||
sizeof_element: cutlass.Int32, # bytes
|
||||
# -------- STRIDES (in ELEMENTS) --------
|
||||
stride_A_m_elems: cutlass.Constexpr, # A.stride(0)
|
||||
stride_A_k_elems: cutlass.Constexpr, # A.stride(1)
|
||||
stride_B0_elems: cutlass.Constexpr, # B.stride(0)
|
||||
stride_Bk_elems: cutlass.Constexpr, # B.stride(1)
|
||||
stride_Bn_elems: cutlass.Constexpr, # B.stride(2)
|
||||
stride_C_m_elems: cutlass.Constexpr, # C.stride(0)
|
||||
stride_C_n_elems: cutlass.Constexpr, # C.stride(1)
|
||||
# -------- OUTPUTS --------
|
||||
out_ptrs: cute.Tensor, # [G,3] cutlass.Int64: (A_ptr, B_ptr, C_ptr)
|
||||
out_problem: cute.Tensor, # [G,4] cutlass.Int32: (m_g, n, k, 1)
|
||||
out_strides_abc: cute.Tensor, # [G,3,2] cutlass.Int32 [[A_m,A_k],[B_n,B_k],[C_m,C_n]]
|
||||
):
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
g = tidx
|
||||
|
||||
m_beg_i32 = 0
|
||||
if g > 0:
|
||||
m_beg_i32 = offs[g - 1]
|
||||
m_end_i32 = offs[g]
|
||||
m_g_i32 = m_end_i32 - m_beg_i32
|
||||
|
||||
a_byte_off = (
|
||||
cutlass.Int64(m_beg_i32) * stride_A_m_elems * cutlass.Int64(sizeof_element)
|
||||
)
|
||||
c_byte_off = (
|
||||
cutlass.Int64(m_beg_i32) * stride_C_m_elems * cutlass.Int64(sizeof_element)
|
||||
)
|
||||
b_byte_off = cutlass.Int64(g) * stride_B0_elems * cutlass.Int64(sizeof_element)
|
||||
|
||||
# ---- pointers ----
|
||||
out_ptrs[g, 0] = base_A_u64 + a_byte_off
|
||||
out_ptrs[g, 1] = base_B_u64 + b_byte_off
|
||||
out_ptrs[g, 2] = base_C_u64 + c_byte_off
|
||||
|
||||
# ---- (m, n, k, 1) ----
|
||||
out_problem[g, 0] = m_g_i32
|
||||
out_problem[g, 1] = N
|
||||
out_problem[g, 2] = K
|
||||
out_problem[g, 3] = cutlass.Int32(1)
|
||||
|
||||
# ---- strides ----
|
||||
out_strides_abc[g, 0, 0] = cutlass.Int32(stride_A_m_elems)
|
||||
out_strides_abc[g, 0, 1] = cutlass.Int32(stride_A_k_elems)
|
||||
out_strides_abc[g, 1, 0] = cutlass.Int32(stride_Bn_elems)
|
||||
out_strides_abc[g, 1, 1] = cutlass.Int32(stride_Bk_elems)
|
||||
out_strides_abc[g, 2, 0] = cutlass.Int32(stride_C_m_elems)
|
||||
out_strides_abc[g, 2, 1] = cutlass.Int32(stride_C_n_elems)
|
||||
|
||||
|
||||
@cute.jit
|
||||
def launch_build_group_ptrs_from_bases(
|
||||
base_A_u64: cutlass.Int64,
|
||||
base_B_u64: cutlass.Int64,
|
||||
base_C_u64: cutlass.Int64,
|
||||
offs: cute.Tensor,
|
||||
G: cutlass.Constexpr,
|
||||
K: cutlass.Constexpr,
|
||||
N: cutlass.Constexpr,
|
||||
sizeof_element: cutlass.Constexpr,
|
||||
stride_A_m_elems: cutlass.Constexpr,
|
||||
stride_A_k_elems: cutlass.Constexpr,
|
||||
stride_B0_elems: cutlass.Constexpr,
|
||||
stride_Bk_elems: cutlass.Constexpr,
|
||||
stride_Bn_elems: cutlass.Constexpr,
|
||||
stride_C_m_elems: cutlass.Constexpr,
|
||||
stride_C_n_elems: cutlass.Constexpr,
|
||||
out_ptrs: cute.Tensor, # [G,3] cutlass.Int64
|
||||
out_problem: cute.Tensor, # [G,4] cutlass.Int32
|
||||
out_strides_abc: cute.Tensor, # [3,2] cutlass.Int32
|
||||
stream: cuda.CUstream,
|
||||
):
|
||||
build_group_ptrs_from_bases_kernel(
|
||||
base_A_u64,
|
||||
base_B_u64,
|
||||
base_C_u64,
|
||||
offs,
|
||||
K,
|
||||
N,
|
||||
sizeof_element,
|
||||
stride_A_m_elems,
|
||||
stride_A_k_elems,
|
||||
stride_B0_elems,
|
||||
stride_Bk_elems,
|
||||
stride_Bn_elems,
|
||||
stride_C_m_elems,
|
||||
stride_C_n_elems,
|
||||
out_ptrs,
|
||||
out_problem,
|
||||
out_strides_abc,
|
||||
).launch(grid=(1, 1, 1), block=(G, 1, 1), stream=stream)
|
||||
|
||||
|
||||
{{def_kernel("input_a", "input_b", "input_a_offs")}}
|
||||
stream = cuda.CUstream(stream)
|
||||
|
||||
input_b = input_b.transpose(1, 2)
|
||||
|
||||
sumM, K = input_a.shape
|
||||
G, N, Kb = input_b.shape
|
||||
|
||||
dev = input_a.device
|
||||
|
||||
base_A_u64 = int(input_a.data_ptr())
|
||||
base_B_u64 = int(input_b.data_ptr())
|
||||
base_C_u64 = int({{get_output()}}.data_ptr())
|
||||
|
||||
ptrs_t = torch.empty((G, 3), device=dev, dtype=torch.int64)
|
||||
probs_t = torch.empty((G, 4), device=dev, dtype=torch.int32)
|
||||
strides_t = torch.empty((G, 3, 2), device=dev, dtype=torch.int32)
|
||||
ptrs = from_dlpack(ptrs_t)
|
||||
probs = from_dlpack(probs_t)
|
||||
strides = from_dlpack(strides_t)
|
||||
|
||||
prep_cache_key = get_prep_cache_key(input_a, input_b, {{get_output()}})
|
||||
prep_executor = prep_cache.get(prep_cache_key)
|
||||
|
||||
if prep_executor is None:
|
||||
sizeof_element = int(input_a.element_size())
|
||||
sA_m, sA_k = map(int, input_a.stride())
|
||||
sB_0, sB_n, sB_k = map(int, input_b.stride())
|
||||
sC_m, sC_n = map(int, {{get_output()}}.stride())
|
||||
|
||||
prep_executor = cute.compile(
|
||||
launch_build_group_ptrs_from_bases,
|
||||
base_A_u64=base_A_u64,
|
||||
base_B_u64=base_B_u64,
|
||||
base_C_u64=base_C_u64,
|
||||
offs=from_dlpack(input_a_offs),
|
||||
G=int(G),
|
||||
K=int(K),
|
||||
N=int(N),
|
||||
sizeof_element=sizeof_element,
|
||||
stride_A_m_elems=sA_m,
|
||||
stride_A_k_elems=sA_k,
|
||||
stride_B0_elems=sB_0,
|
||||
stride_Bk_elems=sB_k,
|
||||
stride_Bn_elems=sB_n,
|
||||
stride_C_m_elems=sC_m,
|
||||
stride_C_n_elems=sC_n,
|
||||
out_ptrs=ptrs,
|
||||
out_problem=probs,
|
||||
out_strides_abc=strides,
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
prep_cache[prep_cache_key] = prep_executor
|
||||
|
||||
prep_executor(
|
||||
base_A_u64=base_A_u64,
|
||||
base_B_u64=base_B_u64,
|
||||
base_C_u64=base_C_u64,
|
||||
offs=from_dlpack(input_a_offs),
|
||||
out_ptrs=ptrs,
|
||||
out_problem=probs,
|
||||
out_strides_abc=strides,
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
# --- Tensormap workspace per SM ---
|
||||
num_tensormap_buffers, max_active_clusters = get_hardware_info()
|
||||
tensormap_shape = (
|
||||
num_tensormap_buffers,
|
||||
GroupedGemmKernel.num_tensormaps,
|
||||
GroupedGemmKernel.bytes_per_tensormap // 8,
|
||||
)
|
||||
tensormap_workspace_t = torch.empty(tensormap_shape, device=dev, dtype=torch.int64)
|
||||
tensormap_workspace = from_dlpack(tensormap_workspace_t)
|
||||
|
||||
# --- Total clusters ---
|
||||
def compute_total_num_clusters(
|
||||
problem_sizes_mnkl,
|
||||
cluster_tile_shape_mn,
|
||||
):
|
||||
total_num_clusters = 0
|
||||
for m, n, _, _ in problem_sizes_mnkl:
|
||||
num_clusters_mn = tuple(
|
||||
ceildiv(x, y) for x, y in zip((m, n), cluster_tile_shape_mn)
|
||||
)
|
||||
total_num_clusters += functools.reduce(lambda x, y: x * y, num_clusters_mn)
|
||||
return total_num_clusters
|
||||
|
||||
# Compute cluster tile shape
|
||||
def compute_cluster_tile_shape(
|
||||
mma_tiler_mn,
|
||||
cluster_shape_mn,
|
||||
use_2cta_instrs,
|
||||
):
|
||||
cta_tile_shape_mn = list(mma_tiler_mn)
|
||||
if use_2cta_instrs:
|
||||
cta_tile_shape_mn[0] = cta_tile_shape_mn[0] // 2
|
||||
return tuple(x * y for x, y in zip(cta_tile_shape_mn, cluster_shape_mn))
|
||||
|
||||
cluster_tile_shape_mn = compute_cluster_tile_shape(
|
||||
(TILE_M, TILE_N), (CLUSTER_M, CLUSTER_N), bool(USE_2_CTA)
|
||||
)
|
||||
|
||||
total_num_clusters = int(compute_total_num_clusters(probs_t, cluster_tile_shape_mn))
|
||||
|
||||
gemm_cache_key = get_gemm_cache_key(
|
||||
prep_cache_key, max_active_clusters, total_num_clusters
|
||||
)
|
||||
gemm_executor = gemm_cache.get(gemm_cache_key)
|
||||
|
||||
if gemm_executor is None:
|
||||
grouped_gemm = GroupedGemmKernel(
|
||||
acc_dtype=ACC_DTYPE,
|
||||
use_2cta_instrs=USE_2_CTA,
|
||||
mma_tiler_mn=(TILE_M, TILE_N),
|
||||
cluster_shape_mn=(CLUSTER_M, CLUSTER_N),
|
||||
tensormap_update_mode=TENSORMAP_UPDATE_MODE,
|
||||
)
|
||||
|
||||
gemm_executor = cute.compile(
|
||||
grouped_gemm,
|
||||
from_dlpack(input_a.unsqueeze(-1), assumed_align=16),
|
||||
from_dlpack(input_b[0].unsqueeze(-1), assumed_align=16),
|
||||
from_dlpack({{get_output()}}.unsqueeze(-1), assumed_align=16),
|
||||
G,
|
||||
probs,
|
||||
strides,
|
||||
ptrs,
|
||||
total_num_clusters,
|
||||
tensormap_workspace,
|
||||
max_active_clusters,
|
||||
stream,
|
||||
)
|
||||
|
||||
gemm_cache[gemm_cache_key] = gemm_executor
|
||||
|
||||
gemm_executor(
|
||||
from_dlpack(input_a.unsqueeze(-1), assumed_align=16),
|
||||
from_dlpack(input_b[0].unsqueeze(-1), assumed_align=16),
|
||||
from_dlpack({{get_output()}}.unsqueeze(-1), assumed_align=16),
|
||||
probs,
|
||||
strides,
|
||||
ptrs,
|
||||
tensormap_workspace,
|
||||
stream,
|
||||
)
|
||||
@ -5,8 +5,6 @@ import logging
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
|
||||
from .hints import TRITON_MAX_BLOCK
|
||||
from .runtime_utils import red_text, triton_config_to_hashable
|
||||
|
||||
@ -56,7 +54,6 @@ class CoordescTuner:
|
||||
name="unknown",
|
||||
size_hints=None,
|
||||
inductor_meta=None,
|
||||
frozen_fields=None,
|
||||
):
|
||||
self.is_mm = is_mm # we will tune num_stages for mm
|
||||
|
||||
@ -69,9 +66,6 @@ class CoordescTuner:
|
||||
self.name = name
|
||||
self.size_hints = size_hints
|
||||
self.inductor_meta = inductor_meta or {}
|
||||
self.frozen_fields: OrderedSet[str] = (
|
||||
OrderedSet(frozen_fields) if frozen_fields is not None else OrderedSet()
|
||||
)
|
||||
|
||||
def get_config_max(self, prefix: str) -> int:
|
||||
max_block = TRITON_MAX_BLOCK[prefix.upper()]
|
||||
@ -123,7 +117,7 @@ class CoordescTuner:
|
||||
out.append("num_stages")
|
||||
out.remove("ZBLOCK") # ZBLOCK=1 always in native matmul
|
||||
|
||||
return [f for f in out if f not in self.frozen_fields]
|
||||
return out
|
||||
|
||||
def value_too_large(self, name: str, val: int) -> bool:
|
||||
block_suffix = "BLOCK"
|
||||
|
||||
@ -336,7 +336,6 @@ class CachingAutotuner(KernelInterface):
|
||||
name=self.fn.__name__,
|
||||
size_hints=size_hints,
|
||||
inductor_meta=self.inductor_meta,
|
||||
frozen_fields=self.get_coordesc_frozen_fields(),
|
||||
)
|
||||
self.filename = filename
|
||||
|
||||
@ -366,13 +365,6 @@ class CachingAutotuner(KernelInterface):
|
||||
# Mode for launch grid calculation
|
||||
self.grid_mode: Literal["python", "cpp"] = "python"
|
||||
|
||||
def get_coordesc_frozen_fields(self) -> OrderedSet[str]:
|
||||
out: OrderedSet[str] = OrderedSet()
|
||||
if self.inductor_meta.get("RSPLIT_SIZE"):
|
||||
# We fix XBLOCK for mix order reduction
|
||||
out.add("XBLOCK")
|
||||
return out
|
||||
|
||||
def is_statically_launchable(self):
|
||||
"""
|
||||
Checks if every compiled kernel is statically launchable, which
|
||||
|
||||
141 torch/_inductor/template_heuristics/cutedsl.py Normal file
@ -0,0 +1,141 @@
|
||||
from dataclasses import dataclass
|
||||
from enum import auto, Enum
|
||||
from itertools import product
|
||||
|
||||
import torch._inductor.config as config
|
||||
|
||||
|
||||
class TensorMapUpdateMode(Enum):
|
||||
"""Enum mirroring cutlass.utils.TensorMapUpdateMode to decouple this file from a cutlass dependency."""
|
||||
|
||||
SMEM = auto()
|
||||
GMEM = auto()
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CuTeGemmConfig:
|
||||
TILE_M: int = 128
|
||||
TILE_N: int = 192
|
||||
CLUSTER_M: int = 2
|
||||
CLUSTER_N: int = 1
|
||||
USE_2_CTA: bool = False
|
||||
TENSORMAP_UPDATE_MODE: TensorMapUpdateMode = TensorMapUpdateMode.SMEM
|
||||
|
||||
|
||||
def get_exhaustive_groupgemm_configs() -> list[CuTeGemmConfig]:
|
||||
"""
|
||||
Returns the exhaustive configuration set for the Blackwell CuTeDSL Grouped GEMM kernel.
|
||||
For information regarding valid config sets, see:
|
||||
https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/blackwell/grouped_gemm.py
|
||||
"""
|
||||
|
||||
# Tile_n is always the same regardless of 2cta
|
||||
tile_n_vals = [32, 64, 96, 128, 160, 192, 224, 256]
|
||||
|
||||
# Valid clusters
|
||||
clusters_no_2cta = [
|
||||
(1, 1),
|
||||
(1, 2),
|
||||
(1, 4),
|
||||
(1, 8),
|
||||
(1, 16),
|
||||
(2, 1),
|
||||
(2, 2),
|
||||
(2, 4),
|
||||
(2, 8),
|
||||
(4, 1),
|
||||
(4, 2),
|
||||
(4, 4),
|
||||
(8, 1),
|
||||
(8, 2),
|
||||
(16, 1),
|
||||
]
|
||||
clusters_2cta = [
|
||||
(2, 1),
|
||||
(2, 2),
|
||||
(2, 4),
|
||||
(2, 8),
|
||||
(4, 1),
|
||||
(4, 2),
|
||||
(4, 4),
|
||||
(8, 1),
|
||||
(8, 2),
|
||||
(16, 1),
|
||||
]
|
||||
|
||||
configs: list[CuTeGemmConfig] = []
|
||||
|
||||
for use_2cta, cluster_set, tile_m_range in [
|
||||
(False, clusters_no_2cta, [64, 128]),
|
||||
(True, clusters_2cta, [128, 256]),
|
||||
]:
|
||||
for tensormap_update_mode, tile_m, tile_n, (cluster_m, cluster_n) in product(
|
||||
[TensorMapUpdateMode.SMEM, TensorMapUpdateMode.GMEM],
|
||||
tile_m_range,
|
||||
tile_n_vals,
|
||||
cluster_set,
|
||||
):
|
||||
configs.append(
|
||||
CuTeGemmConfig(
|
||||
tile_m,
|
||||
tile_n,
|
||||
cluster_m,
|
||||
cluster_n,
|
||||
USE_2_CTA=use_2cta,
|
||||
TENSORMAP_UPDATE_MODE=tensormap_update_mode,
|
||||
)
|
||||
)
|
||||
|
||||
return configs
|
||||
|
||||
|
||||
def get_default_groupgemm_configs() -> list[CuTeGemmConfig]:
|
||||
"""
|
||||
Returns the default configuration set for the Blackwell CuTeDSL Grouped GEMM kernel.
|
||||
"""
|
||||
|
||||
config_tuples = [
|
||||
(128, 256, 2, 1, False, TensorMapUpdateMode.SMEM),
|
||||
(256, 160, 2, 1, True, TensorMapUpdateMode.GMEM),
|
||||
(256, 256, 2, 1, True, TensorMapUpdateMode.GMEM),
|
||||
(64, 32, 1, 1, False, TensorMapUpdateMode.GMEM),
|
||||
(64, 256, 1, 2, False, TensorMapUpdateMode.SMEM),
|
||||
(128, 256, 1, 2, False, TensorMapUpdateMode.SMEM),
|
||||
(256, 256, 2, 2, True, TensorMapUpdateMode.GMEM),
|
||||
(128, 256, 1, 2, False, TensorMapUpdateMode.GMEM),
|
||||
(64, 32, 1, 1, False, TensorMapUpdateMode.SMEM),
|
||||
(256, 256, 2, 1, True, TensorMapUpdateMode.SMEM),
|
||||
(128, 256, 1, 1, False, TensorMapUpdateMode.GMEM),
|
||||
(256, 256, 8, 1, True, TensorMapUpdateMode.GMEM),
|
||||
(64, 32, 1, 2, False, TensorMapUpdateMode.SMEM),
|
||||
(256, 192, 2, 1, True, TensorMapUpdateMode.GMEM),
|
||||
(256, 256, 2, 2, True, TensorMapUpdateMode.SMEM),
|
||||
(128, 96, 1, 2, False, TensorMapUpdateMode.SMEM),
|
||||
(64, 192, 1, 1, False, TensorMapUpdateMode.SMEM),
|
||||
(64, 64, 1, 1, False, TensorMapUpdateMode.GMEM),
|
||||
(64, 192, 1, 1, False, TensorMapUpdateMode.GMEM),
|
||||
(128, 64, 1, 1, False, TensorMapUpdateMode.GMEM),
|
||||
(64, 160, 1, 1, False, TensorMapUpdateMode.GMEM),
|
||||
(64, 256, 1, 1, False, TensorMapUpdateMode.GMEM),
|
||||
]
|
||||
|
||||
return [CuTeGemmConfig(*args) for args in config_tuples]
|
||||
|
||||
|
||||
def get_groupgemm_configs() -> list[CuTeGemmConfig]:
|
||||
"""
|
||||
Returns the configuration set for the Blackwell CuTeDSL Grouped GEMM kernel.
|
||||
|
||||
Note: CuTeDSL autotuning is still experimental — enabling it may trigger kernel launch failures
|
||||
or unstable results. By default, autotuning is disabled and we return only
|
||||
a single baseline config.
|
||||
"""
|
||||
if (
|
||||
config.cutedsl_enable_autotuning
|
||||
and config.max_autotune_gemm_search_space == "EXHAUSTIVE"
|
||||
):
|
||||
return get_exhaustive_groupgemm_configs()
|
||||
elif config.cutedsl_enable_autotuning:
|
||||
return get_default_groupgemm_configs()
|
||||
else:
|
||||
return [get_default_groupgemm_configs()[0]]
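For reference, a hypothetical way to exercise the non-default branches above from user code; overriding the config attributes directly is an assumption shown only for illustration (cutedsl_enable_autotuning normally comes from the CUTEDSL_ENABLE_AUTOTUNING env var added in the torch/_inductor/config.py hunk earlier in this diff).
# Sketch only: opt in to CuTeDSL grouped-GEMM autotuning over the exhaustive space.
import torch._inductor.config as inductor_config

inductor_config.cutedsl_enable_autotuning = True
inductor_config.max_autotune_gemm_search_space = "EXHAUSTIVE"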
|
||||
@ -549,70 +549,6 @@ def is_pointwise_use(
|
||||
return torch.Tag.pointwise in target.tags or is_pointwise_fn(target)
|
||||
|
||||
|
||||
class LogicalConnective(enum.Enum):
|
||||
OR = enum.auto()
|
||||
AND = enum.auto()
|
||||
|
||||
|
||||
def has_uses(
|
||||
target: Node,
|
||||
use_selector_fn: Callable[[torch._ops.OpOverload], bool] = lambda _: False,
|
||||
use_aggregate_type: LogicalConnective = LogicalConnective.OR,
|
||||
) -> bool:
|
||||
"""
|
||||
Given a target, explore the uses of `target` by applying `use_selector_fn`
|
||||
on them, and then aggregate these booleans with the `use_aggregate_type`
|
||||
logical connective.
|
||||
|
||||
Uses in view ops will follow the views uses.
|
||||
"""
|
||||
|
||||
def get_use_aggregate_fn(
|
||||
use_aggregate_type: LogicalConnective,
|
||||
) -> Callable[[Iterator[Any]], bool]:
|
||||
match use_aggregate_type:
|
||||
case LogicalConnective.AND:
|
||||
return all
|
||||
case LogicalConnective.OR:
|
||||
return any
|
||||
case _:
|
||||
return any
|
||||
|
||||
use_aggregate_fn = get_use_aggregate_fn(use_aggregate_type)
|
||||
|
||||
def has_uses_impl(use: Node) -> bool:
|
||||
if use.op != "call_function":
|
||||
return False
|
||||
if not (
|
||||
isinstance(use.target, torch._ops.OpOverload)
|
||||
or use.target is operator.getitem
|
||||
):
|
||||
return False
|
||||
|
||||
target = cast(torch._ops.OpOverload, use.target)
|
||||
# Process getitem and view
|
||||
if target is operator.getitem or is_view(target):
|
||||
return use_aggregate_fn(has_uses_impl(user) for user in use.users)
|
||||
|
||||
return use_selector_fn(target)
|
||||
|
||||
return use_aggregate_fn(has_uses_impl(user) for user in target.users)
|
||||
|
||||
|
||||
def has_uses_tagged_as(
|
||||
target: Node,
|
||||
use_tags: Collection[torch.Tag],
|
||||
use_aggregate_type: LogicalConnective = LogicalConnective.OR,
|
||||
) -> bool:
|
||||
"""
|
||||
Is there a use with given tags?
|
||||
"""
|
||||
|
||||
return has_uses(
|
||||
target, lambda use: any(tag in use_tags for tag in use.tags), use_aggregate_type
|
||||
)
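A brief, hypothetical call site (`node` stands for any fx.Node in the graph):
# Does any user of `node` (looking through getitem and view ops) carry the
# pointwise or reduction tag?  Pass use_aggregate_type=LogicalConnective.AND
# to require it of every such user instead.
feeds_pointwise_or_reduction = has_uses_tagged_as(
    node,
    (torch.Tag.pointwise, torch.Tag.reduction),
)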
|
||||
|
||||
|
||||
def gen_gm_and_inputs(
|
||||
target: Any, args: list[Any], kwargs: dict[str, Any]
|
||||
) -> tuple[GraphModule, list[torch.Tensor]]:
|
||||
@ -1975,6 +1911,77 @@ def use_triton_blackwell_tma_template(
|
||||
return has_triton_tensor_descriptor_host_tma() and is_datacenter_blackwell_arch()
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=1)
|
||||
def ensure_cute_available() -> bool:
|
||||
"""Check if CuTeDSL is importable; cache the result for reuse.
|
||||
|
||||
Call ensure_cute_available.cache_clear() after installing CuTeDSL
|
||||
in the same interpreter to retry the import.
|
||||
"""
|
||||
try:
|
||||
return importlib.util.find_spec("cutlass.cute") is not None
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
|
||||
def use_blackwell_cutedsl_grouped_mm(
|
||||
mat_a: Any,
|
||||
mat_b: Any,
|
||||
layout: Layout,
|
||||
a_is_2d: bool,
|
||||
b_is_2d: bool,
|
||||
offs: Optional[Any],
|
||||
bias: Optional[Any],
|
||||
scale_result: Optional[Any],
|
||||
) -> bool:
|
||||
"""
|
||||
Returns True if we can use the blackwell kernel for grouped mm.
|
||||
Required conditions:
1. CuTeDSL is available
2. We are on a Blackwell arch
3. The dtype is bf16
4. Max autotune or max autotune gemm is enabled
5. A, B, and the output are 16B aligned
6. We are not using dynamic shapes
7. A is 2d
8. B is 3d
9. Offsets are provided
10. Bias and Scale are not provided
|
||||
"""
|
||||
if not ensure_cute_available():
|
||||
return False
|
||||
|
||||
from .codegen.cuda.cuda_env import is_datacenter_blackwell_arch
|
||||
|
||||
if not is_gpu(layout.device.type) and is_datacenter_blackwell_arch():
|
||||
return False
|
||||
|
||||
layout_dtypes = [torch.bfloat16]
|
||||
if not _use_template_for_gpu(layout, layout_dtypes):
|
||||
return False
|
||||
|
||||
if not (config.max_autotune or config.max_autotune_gemm):
|
||||
return False
|
||||
|
||||
# Checks for 16B ptr and stride alignment
|
||||
if not can_use_tma(mat_a, mat_b, output_layout=layout):
|
||||
return False
|
||||
|
||||
if any(is_dynamic(x) for x in [mat_a, mat_b]):
|
||||
return False
|
||||
|
||||
if not a_is_2d or b_is_2d:
|
||||
return False
|
||||
|
||||
if offs is None:
|
||||
return False
|
||||
|
||||
if bias is not None or scale_result is not None:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def use_cutlass_template(layout: Layout, m: int, n: int, k: int) -> bool:
|
||||
from .virtualized import V
|
||||
|
||||
|
||||
@ -31,8 +31,10 @@ template <typename T>
|
||||
struct FromImpl {
|
||||
static StableIValue call(
|
||||
T val,
|
||||
[[maybe_unused]] uint64_t extension_build_version,
|
||||
[[maybe_unused]] bool is_internal) {
|
||||
uint64_t extension_build_version,
|
||||
bool is_internal) {
|
||||
(void)extension_build_version; // Unused parameter
|
||||
(void)is_internal; // Unused parameter
|
||||
static_assert(
|
||||
sizeof(T) <= sizeof(StableIValue),
|
||||
"StableLibrary stack does not support parameter types larger than 64 bits.");
|
||||
@ -73,8 +75,10 @@ template <>
|
||||
struct FromImpl<ScalarType> {
|
||||
static StableIValue call(
|
||||
ScalarType val,
|
||||
[[maybe_unused]] uint64_t extension_build_version,
|
||||
[[maybe_unused]] bool is_internal) {
|
||||
uint64_t extension_build_version,
|
||||
bool is_internal) {
|
||||
(void)extension_build_version; // Unused parameter
|
||||
(void)is_internal; // Unused parameter
|
||||
switch (val) {
|
||||
case ScalarType::Byte:
|
||||
return from(aoti_torch_dtype_uint8());
|
||||
@ -129,8 +133,10 @@ template <>
|
||||
struct FromImpl<std::nullopt_t> {
|
||||
static StableIValue call(
|
||||
std::nullopt_t val,
|
||||
[[maybe_unused]] uint64_t extension_build_version,
|
||||
[[maybe_unused]] bool is_internal) {
|
||||
uint64_t extension_build_version,
|
||||
bool is_internal) {
|
||||
(void)extension_build_version; // Unused parameter
|
||||
(void)is_internal; // Unused parameter
|
||||
return from(nullptr);
|
||||
}
|
||||
};
|
||||
@ -184,8 +190,10 @@ template <>
|
||||
struct FromImpl<torch::stable::Tensor> {
|
||||
static StableIValue call(
|
||||
const torch::stable::Tensor& val,
|
||||
[[maybe_unused]] uint64_t extension_build_version,
|
||||
[[maybe_unused]] bool is_internal) {
|
||||
uint64_t extension_build_version,
|
||||
bool is_internal) {
|
||||
(void)extension_build_version; // Unused parameter
|
||||
(void)is_internal; // Unused parameter
|
||||
AtenTensorHandle new_ath;
|
||||
TORCH_ERROR_CODE_CHECK(aoti_torch_new_tensor_handle(val.get(), &new_ath));
|
||||
return from(new_ath);
|
||||
@ -201,8 +209,10 @@ template <typename T>
|
||||
struct ToImpl {
|
||||
static T call(
|
||||
StableIValue val,
|
||||
[[maybe_unused]] uint64_t extension_build_version,
|
||||
[[maybe_unused]] bool is_internal) {
|
||||
uint64_t extension_build_version,
|
||||
bool is_internal) {
|
||||
(void)extension_build_version; // Unused parameter
|
||||
(void)is_internal; // Unused parameter
|
||||
static_assert(std::is_trivially_copyable_v<T>);
|
||||
// T may not have a default constructor. (For example, it might be
|
||||
// c10::Device.) However, std::memcpy implicitly creates a T at the
|
||||
@ -239,8 +249,10 @@ template <>
|
||||
struct ToImpl<ScalarType> {
|
||||
static ScalarType call(
|
||||
StableIValue val,
|
||||
[[maybe_unused]] uint64_t extension_build_version,
|
||||
[[maybe_unused]] bool is_internal) {
|
||||
uint64_t extension_build_version,
|
||||
bool is_internal) {
|
||||
(void)extension_build_version; // Unused parameter
|
||||
(void)is_internal; // Unused parameter
|
||||
int32_t shim_scalartype = to<int32_t>(val);
|
||||
if (shim_scalartype == aoti_torch_dtype_uint8()) {
|
||||
return ScalarType::Byte;
|
||||
@ -297,8 +309,10 @@ template <>
|
||||
struct ToImpl<std::nullopt_t> {
|
||||
static std::nullopt_t call(
|
||||
StableIValue val,
|
||||
[[maybe_unused]] uint64_t extension_build_version,
|
||||
[[maybe_unused]] bool is_internal) {
|
||||
uint64_t extension_build_version,
|
||||
bool is_internal) {
|
||||
(void)extension_build_version; // Unused parameter
|
||||
(void)is_internal; // Unused parameter
|
||||
// val should be equivalent to from(nullptr)
|
||||
return std::nullopt;
|
||||
}
|
||||
@ -336,8 +350,10 @@ template <>
|
||||
struct ToImpl<torch::stable::Tensor> {
|
||||
static torch::stable::Tensor call(
|
||||
StableIValue val,
|
||||
[[maybe_unused]] uint64_t extension_build_version,
|
||||
[[maybe_unused]] bool is_internal) {
|
||||
uint64_t extension_build_version,
|
||||
bool is_internal) {
|
||||
(void)extension_build_version; // Unused parameter
|
||||
(void)is_internal; // Unused parameter
|
||||
return torch::stable::Tensor(to<AtenTensorHandle>(val));
|
||||
}
|
||||
};
|
||||
|
||||
@ -1228,7 +1228,7 @@ def _get_pynvml_handler(device: "Device" = None):
|
||||
"nvidia-ml-py does not seem to be installed or it can't be imported."
|
||||
# pyrefly: ignore [invalid-inheritance]
|
||||
) from _PYNVML_ERR
|
||||
# pyrefly: ignore [import-error,missing-module-attribute]
|
||||
# pyrefly: ignore [import-error]
|
||||
from pynvml import NVMLError_DriverNotLoaded
|
||||
|
||||
try:
|
||||
|
||||
@ -4,14 +4,12 @@ r"""This package adds support for device memory management implemented in CUDA."
|
||||
import collections
|
||||
import contextlib
|
||||
import ctypes
|
||||
import os
|
||||
import pickle
|
||||
import re
|
||||
import sys
|
||||
import warnings
|
||||
from inspect import signature
|
||||
from typing import Any, cast, Literal, Optional, TYPE_CHECKING, TypedDict
|
||||
from typing_extensions import deprecated, NotRequired
|
||||
from typing import Any, Literal, Optional, TYPE_CHECKING
|
||||
from typing_extensions import deprecated
|
||||
|
||||
import torch
|
||||
from torch import _C
|
||||
@ -31,60 +29,6 @@ if TYPE_CHECKING:
|
||||
from torch.types import Device
|
||||
|
||||
|
||||
# Type definitions for memory profiler
|
||||
class _Frame(TypedDict):
|
||||
"""Frame information from memory profiler snapshots."""
|
||||
|
||||
filename: str
|
||||
line: int
|
||||
name: str
|
||||
# Fields added by FX augmentation (optional)
|
||||
fx_node_op: NotRequired[str]
|
||||
fx_node_name: NotRequired[str]
|
||||
fx_node_target: NotRequired[str]
|
||||
fx_original_trace: NotRequired[str]
|
||||
|
||||
|
||||
class _Block(TypedDict):
|
||||
"""Memory block information."""
|
||||
|
||||
size: int
|
||||
requested_size: int
|
||||
address: int
|
||||
state: str
|
||||
frames: list[_Frame]
|
||||
|
||||
|
||||
class _Segment(TypedDict):
|
||||
"""Memory segment information."""
|
||||
|
||||
address: int
|
||||
total_size: int
|
||||
stream: int
|
||||
segment_type: str
|
||||
allocated_size: int
|
||||
active_size: int
|
||||
blocks: list[_Block]
|
||||
|
||||
|
||||
class _TraceEntry(TypedDict):
|
||||
"""Memory trace entry information."""
|
||||
|
||||
action: str
|
||||
addr: NotRequired[int]
|
||||
frames: list[_Frame]
|
||||
size: int
|
||||
stream: int
|
||||
device_free: NotRequired[int]
|
||||
|
||||
|
||||
class _Snapshot(TypedDict):
|
||||
"""Memory snapshot structure."""
|
||||
|
||||
segments: list[_Segment]
|
||||
device_traces: NotRequired[list[list[_TraceEntry]]]
|
||||
|
||||
|
||||
__all__ = [
|
||||
"caching_allocator_alloc",
|
||||
"caching_allocator_delete",
|
||||
@ -828,7 +772,7 @@ def list_gpu_processes(device: "Device" = None) -> str:
|
||||
import pynvml # type: ignore[import]
|
||||
except ModuleNotFoundError:
|
||||
return "pynvml module not found, please install nvidia-ml-py"
|
||||
# pyrefly: ignore [import-error,missing-module-attribute]
|
||||
# pyrefly: ignore [import-error]
|
||||
from pynvml import NVMLError_DriverNotLoaded
|
||||
|
||||
try:
|
||||
@ -1020,120 +964,7 @@ def _record_memory_history_impl(
|
||||
_record_memory_history.__signature__ = signature(_record_memory_history_impl) # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def _augment_frames(frames: list[_Frame]) -> int:
|
||||
"""
|
||||
Augment a list of frames with FX debug information.
|
||||
|
||||
Args:
|
||||
frames: List of frame dictionaries to augment
|
||||
|
||||
Returns:
|
||||
The count of frames that were augmented.
|
||||
"""
|
||||
from torch.fx.graph_module import FX_GRAPH_MODULE_FILE_PREFIX
|
||||
|
||||
# Regex pattern to match FX generated files
|
||||
_FX_GENERATED_PATTERN = re.compile(
|
||||
rf"{re.escape(FX_GRAPH_MODULE_FILE_PREFIX)}.*\.py$"
|
||||
)
|
||||
|
||||
count = 0
|
||||
if not frames:
|
||||
return count
|
||||
|
||||
for frame in frames:
|
||||
if "filename" in frame and "line" in frame:
|
||||
filename = frame["filename"]
|
||||
lineno = frame["line"]
|
||||
|
||||
# Check if this looks like an FX generated file
|
||||
if not _FX_GENERATED_PATTERN.search(os.path.basename(filename)):
|
||||
continue
|
||||
|
||||
# Look up metadata from the global registry
|
||||
from torch.fx.traceback import _FX_METADATA_REGISTRY
|
||||
|
||||
metadata = _FX_METADATA_REGISTRY.get(filename)
|
||||
if metadata is None:
|
||||
continue
|
||||
|
||||
lineno_map = metadata.get("lineno_map", {})
|
||||
node_metadata = metadata.get("node_metadata", {})
|
||||
prologue_start = metadata.get("prologue_start", 0)
|
||||
|
||||
# Get the node index for this line
|
||||
node_idx = lineno_map.get(lineno - prologue_start)
|
||||
|
||||
if node_idx is not None and node_idx in node_metadata:
|
||||
node_info = node_metadata[node_idx]
|
||||
original_trace = node_info.get("stack_trace")
|
||||
node_op = node_info.get("op")
|
||||
node_name = node_info.get("name")
|
||||
node_target = node_info.get("target")
|
||||
|
||||
# Always add node metadata
|
||||
frame["fx_node_op"] = node_op
|
||||
frame["fx_node_name"] = node_name
|
||||
frame["fx_node_target"] = str(node_target)
|
||||
|
||||
# Add original trace if available
|
||||
if original_trace:
|
||||
frame["fx_original_trace"] = original_trace
|
||||
|
||||
count += 1
|
||||
|
||||
return count
|
||||
|
||||
|
||||
def _augment_memory_snapshot_stack_traces(
|
||||
snapshot: str | _Snapshot,
|
||||
) -> _Snapshot:
|
||||
"""
|
||||
Augment a memory snapshot with original source stack traces from FX metadata.
|
||||
|
||||
IMPORTANT: This function reads from a global in-memory registry (_FX_METADATA_REGISTRY)
|
||||
that is populated during graph module compilation. It must be called in the same
|
||||
Python process where the FX graphs were compiled. It cannot be used to augment
|
||||
snapshots loaded from disk in a different process.
|
||||
|
||||
Args:
|
||||
snapshot: Either a memory snapshot dict or path to a snapshot pickle file
|
||||
|
||||
Returns:
|
||||
The augmented snapshot dictionary with fx_node_op, fx_node_name,
|
||||
fx_original_trace, and fx_node_info fields added to frames
|
||||
"""
|
||||
|
||||
snapshot_dict: _Snapshot
|
||||
if isinstance(snapshot, str):
|
||||
# Load the memory snapshot
|
||||
with open(snapshot, "rb") as f:
|
||||
snapshot_dict = cast(_Snapshot, pickle.load(f))
|
||||
else:
|
||||
snapshot_dict = snapshot
|
||||
|
||||
# Process stack traces in the snapshot
|
||||
augmented_count = 0
|
||||
|
||||
# Process blocks in segments (for regular allocations)
|
||||
if "segments" in snapshot_dict:
|
||||
for segment in snapshot_dict["segments"]:
|
||||
if "blocks" in segment:
|
||||
for block in segment["blocks"]:
|
||||
if "frames" in block:
|
||||
augmented_count += _augment_frames(block["frames"])
|
||||
|
||||
# Process device traces (for memory history)
|
||||
if "device_traces" in snapshot_dict:
|
||||
for trace_list in snapshot_dict["device_traces"]:
|
||||
for trace_entry in trace_list:
|
||||
if isinstance(trace_entry, dict) and "frames" in trace_entry:
|
||||
augmented_count += _augment_frames(trace_entry["frames"])
|
||||
|
||||
return snapshot_dict
|
||||
|
||||
|
||||
def _snapshot(device: "Device" = None, augment_with_fx_traces=False):
|
||||
def _snapshot(device: "Device" = None):
|
||||
"""Save a snapshot of CUDA memory state at the time it was called.
|
||||
|
||||
The state is represented as a dictionary with the following structure.
|
||||
@ -1181,11 +1012,6 @@ def _snapshot(device: "Device" = None, augment_with_fx_traces=False):
|
||||
filename: str
|
||||
line: int
|
||||
name: str
|
||||
# Optional FX debug fields (present when augment_with_fx_traces=True
|
||||
# and the frame corresponds to FX-generated code)
|
||||
fx_node_op: str # FX node operation type (e.g., 'call_function', 'output')
|
||||
fx_node_name: str # FX node name (e.g., 'linear', 'relu_1')
|
||||
fx_original_trace: str # Original model source code stack trace
|
||||
|
||||
|
||||
class TraceEntry(TypedDict):
|
||||
@ -1215,23 +1041,13 @@ def _snapshot(device: "Device" = None, augment_with_fx_traces=False):
|
||||
device_free: int # only present for OOM, the amount of
|
||||
# memory cuda still reports to be free
|
||||
|
||||
Args:
|
||||
device: Device to capture snapshot for. If None, captures for current device.
|
||||
augment_with_fx_traces: If True, augment stack trace frames with FX debug information
|
||||
that maps generated FX code back to original model source code.
|
||||
This adds fx_node_op, fx_node_name, fx_original_trace, and
|
||||
fx_node_info fields to Frame objects. Default: False.
|
||||
|
||||
Returns:
|
||||
The Snapshot dictionary object
|
||||
"""
|
||||
s = _C._cuda_memorySnapshot(None)
|
||||
if augment_with_fx_traces:
|
||||
s = _augment_memory_snapshot_stack_traces(s) # type: ignore[assignment, arg-type]
|
||||
return s
|
||||
return _C._cuda_memorySnapshot(None)
|
||||
|
||||
|
||||
def _dump_snapshot(filename="dump_snapshot.pickle", augment_with_fx_traces=False):
|
||||
def _dump_snapshot(filename="dump_snapshot.pickle"):
|
||||
"""
|
||||
Save a pickled version of the `torch.memory._snapshot()` dictionary to a file.
|
||||
|
||||
@ -1243,14 +1059,8 @@ def _dump_snapshot(filename="dump_snapshot.pickle", augment_with_fx_traces=False
|
||||
|
||||
Args:
|
||||
filename (str, optional): Name of the file to create. Defaults to "dump_snapshot.pickle".
|
||||
augment_with_fx_traces (bool, optional): If True, augment the snapshot with FX debug information
|
||||
before dumping. This maps generated FX code stack traces
|
||||
back to original model source code. Defaults to False.
|
||||
verbose (bool, optional): If True and augment_with_fx_traces is True, print verbose debug output
|
||||
during augmentation. Defaults to False.
|
||||
"""
|
||||
s = _snapshot(augment_with_fx_traces=augment_with_fx_traces)
|
||||
|
||||
s = _snapshot()
|
||||
with open(filename, "wb") as f:
|
||||
pickle.dump(s, f)


@ -23,7 +23,6 @@ from torch.distributed.tensor._ops.utils import (
map_placements_after_broadcast,
prod,
register_op_strategy,
register_single_dim_strategy,
)
from torch.distributed.tensor._utils import (
compute_local_shape_and_global_offset,
@ -238,130 +237,10 @@ def dot_strategy(op_schema: OpSchema) -> OpStrategy:
return _mm_like_strategy("i,i->", mesh, op_schema)


# @register_op_strategy(aten.mm.default)
# def mm_strategy(op_schema: OpSchema) -> OpStrategy:
# mesh = op_schema.get_mesh_from_args()
# return _mm_like_strategy("mk,kn->mn", mesh, op_schema)


from ._einsum_strategy import EinsumDims


def gen_single_dim_einsum_strategies(
equation: str,
mesh: DeviceMesh,
*,
linearity: bool = False,
) -> list[Placement]:
"""
Generate a strategy list for the ops that follow einsum style notation.

In principle, each mesh dim is independent of other device mesh dim when we
generate strategies. So we generate strategy over each device mesh dim and
do product combination on all mesh dims. We basically follow the below rule
for each device mesh dim:

1. Shard on contracting dim: When both inputs shard on contracting dim over
the same device dim. The result will be Partial over that device dim.

2. Shard on noncontracting dim:
2.1: Shard on batch dim: output, both inputs all should shard on batch
dim.
2.2: Shard on lhs only dim or rhs only dim: both output and lhs or rhs
input should shard on this free dim.

3. Linearity (Partial): If enabled, set Partial on output and inputs over
the same device mesh dim.
"""
# parse einop equation and extract dims
input_dims, output_dim = EinsumDims.parse_equation(equation)
edims = EinsumDims.parse_dims(input_dims, output_dim)
all_mesh_dim_strategies = []

# generate strategies for each mesh dim and do cartesian product for final strategy. E.g., for a 2D mesh, we can have [P(),R,R]
strategies_over_one_mesh_dim = []

# placement list stores placements of [output, input1, input2, ...]
# first we always have replicate all for inputs and output
placement_list: list[Placement] = [Replicate()] * (len(input_dims) + 1)
strategies_over_one_mesh_dim.append(placement_list)

# split batch dim
for batch_dim in edims.batch_dims:
output_batch_dim = output_dim.index(batch_dim)
placement_list = [Shard(output_batch_dim)]
for input_dim in input_dims:
input_batch_dim = input_dim.index(batch_dim)
placement_list.append(Shard(input_batch_dim))

strategies_over_one_mesh_dim.append(placement_list)

# split contracting dim
for contracting_dim in edims.contracting_dims:
# Contracting dim can shard on same device axis for both inputs. This
# results in the output being Partial on that device axis. For example:
# bmk_{x},k_{x}n -> bmn{Ux} (becomes partial over device axis x)
placement_list = [Partial()]
for input_dim in input_dims:
input_contracting_dim = input_dim.index(contracting_dim)
placement_list.append(Shard(input_contracting_dim))

strategies_over_one_mesh_dim.append(placement_list)

# split lhs free dim
for lhs_dim in edims.lhs_out_only_dims:
lhs_free_dim_output = output_dim.index(lhs_dim)
lhs_free_dim_input = input_dims[0].index(lhs_dim)
# this means split the lhs input and output
# i.e. S(0), R -> S(0)
lhs_placement_list: list[Placement] = [
Shard(lhs_free_dim_output),
Shard(lhs_free_dim_input),
Replicate(),
]
strategies_over_one_mesh_dim.append(lhs_placement_list)

# split rhs free dim
for rhs_dim in edims.rhs_out_only_dims:
rhs_free_dim_output = output_dim.index(rhs_dim)
rhs_free_dim_input = input_dims[1].index(rhs_dim)
rhs_placement_list: list[Placement] = [
Shard(rhs_free_dim_output),
Replicate(),
Shard(rhs_free_dim_input),
]
strategies_over_one_mesh_dim.append(rhs_placement_list)

# linearity strategy
if linearity:
linearity_placement_list: list[Placement] = [Partial()]
for _ in input_dims:
linearity_placement_list.append(Partial())
strategies_over_one_mesh_dim.append(linearity_placement_list)

# generate strategies for entire mesh
# all_mesh_dim_strategies = [strategies_over_one_mesh_dim] * mesh.ndim
# strategy_combs = itertools.product(*all_mesh_dim_strategies)
# all_strategies = []
# for strategy_comb in strategy_combs:
# spec_list = [DTensorSpec(mesh, tuple(specs)) for specs in zip(*strategy_comb)]
# strat = OpSpec(output_specs=spec_list[0], input_specs=spec_list[1:])
# all_strategies.append(strat)

# return OpStrategy(all_strategies)
return strategies_over_one_mesh_dim
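
For concreteness, this is what the rules above work out to for the plain matmul equation "mk,kn->mn" on a single mesh dim; a hand-written illustration (not captured output), with each row read as [output, lhs, rhs]:

from torch.distributed.tensor.placement_types import Partial, Replicate, Shard

# Single-mesh-dim strategies for "mk,kn->mn"; the full strategy is the
# cartesian product of these rows across all mesh dims.
mm_single_dim = [
    [Replicate(), Replicate(), Replicate()],  # all replicated
    [Partial(), Shard(1), Shard(0)],          # contract over k -> output Partial
    [Shard(0), Shard(0), Replicate()],        # shard lhs-only dim m
    [Shard(1), Replicate(), Shard(1)],        # shard rhs-only dim n
]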


@register_single_dim_strategy(aten.mm.default)
def mm_single_dim_strategy(op_schema: OpSchema) -> list[Placement]:
self_strategy, mat2_strategy = op_schema.args_schema
if not isinstance(self_strategy, OpStrategy):
raise AssertionError(f"Expected OpStrategy, got {type(self_strategy)}")
if not isinstance(mat2_strategy, OpStrategy):
raise AssertionError(f"Expected OpStrategy, got {type(mat2_strategy)}")
# generate all possible strategies for mm
@register_op_strategy(aten.mm.default)
def mm_strategy(op_schema: OpSchema) -> OpStrategy:
mesh = op_schema.get_mesh_from_args()
return gen_single_dim_einsum_strategies("mk,kn->mn", mesh)
return _mm_like_strategy("mk,kn->mn", mesh, op_schema)


@register_op_strategy(aten.addmm.default)

@ -41,8 +41,6 @@ from torch.distributed.tensor.placement_types import (
aten = torch.ops.aten


# WHC- I think anywhere this is used, we can replace it with a corresponding single-dim passthrough strategy
# (anyshard, replicate, partial can all pass through- and then expand that to the mesh dims later)
def propagate_single_input_strategy(op_schema: OpSchema) -> StrategyType:
# For ops with a single tensor input, we perform a 1:1 mapping such that
# for each strategy that the input supports, we create a corresponding strategy.
@ -99,28 +97,6 @@ register_op_strategy(
)(propagate_single_input_strategy)


"""
WHC- equal_strategy is an example of baking an optimization into the sharding rule.

The unoptimized equal strategy (for one mesh dim) should look like this
S, S -> S
R, R -> R
P, P -> P * - this could work, I think, if we supported a Partial of boolean and reduction?
And this should be expanded to the full mesh.

But what this rule actually does is
- compare the two tensor args to equal- look at the strategies for each, which represent the I-O sharding relationship for the
op that produced those tensor args. Pick the one that has the strategy (OpSpec) with the most Shard() placements in it.
Why? Because converting the other arg from R->S is cheaper than converting S->R.

- start with the assumption that the 'equal' op has the same strategy as the op that produced its max-shard input
- then adjust the placements from partial to replicate since we don't support partial in equal
- finally, produce an OpSpec that only populates the 'output_specs' of OpSpec

TODO: why is it ok to populate only the output_specs of an OpSpec? Is it defined to mean that all input specs are the same as the output spec?
"""


@register_op_strategy(
[
aten.equal.default,
@ -164,19 +140,6 @@ def equal_strategy(op_schema: OpSchema) -> StrategyType:
return equal_strategy


"""
WHC
seems like we could replace this with a single-mesh-dim strategy
S->S
R->R
P->R

The P->R thing is odd, but makes sense:
* can't support P->P since it would be incorrect to create a new 'partial' tensor from ones, which would no longer be ones if we replicated them
* don't want to omit the support for input Partial because we'd force a replication on the input which would be wasteful
"""


@register_op_strategy(
[
aten.empty_like.default,
@ -518,19 +481,6 @@ def replicate_tensor_dim(
)


"""
WHC- example of a complicated 'follow your inputs' strategy that would be useful to try out as a simple rule

seems very simple to write this way

assert input, src same ndim
for i in range(input.ndim):
if i != slice_dim:
Shard(i), Shard(i) -> Shard(i)

"""


@register_op_strategy(aten.slice_scatter.default, schema_info=RuntimeSchemaInfo(2))
def gen_slice_scatter_strategy(op_schema: OpSchema) -> StrategyType:
# 1. number of dimensions in input and src need to match.

@ -4,7 +4,8 @@ import functools
import itertools
import operator
from collections.abc import Callable, Iterable, Sequence
from typing import cast, Optional, Union
from typing import cast, Optional, TypeVar, Union
from typing_extensions import ParamSpec

import torch
from torch._prims_common import DimsSequenceType, DimsType
@ -29,6 +30,10 @@ from torch.distributed.tensor.placement_types import (
)


_T = TypeVar("_T")
_P = ParamSpec("_P")


# convenient wrapper to register sharding propagation rules
def register_prop_rule(
op: Union[torch._ops.OpOverload, list[torch._ops.OpOverload]],
@ -49,61 +54,11 @@ def register_prop_rule(
return wrapper


def register_single_dim_strategy(
op: Union[torch._ops.OpOverload, list[torch._ops.OpOverload]],
schema_info: Optional[RuntimeSchemaInfo] = None,
) -> Callable[
[Callable[[OpSchema], list[Placement]]], Callable[[OpSchema], StrategyType]
]:
"""
Registers a simplified op strategy that only considers a single mesh dim, taking care to expand it
to cover all the mesh dims present in the runtime inputs.
"""

def expanded_registration_wrapper(
single_dim_strategy: Callable[[OpSchema], list[Placement]],
) -> Callable[[OpSchema], StrategyType]:
def _expanded_strategy(op_schema: OpSchema) -> StrategyType:
"""
Expands the single_mesh_dim impl across all mesh dims, and expands ShardingPlaceholder into all
sharding types used by inputs.
"""
inputs_strategy = op_schema.args_strategy
mesh = inputs_strategy[0].mesh
strategies_over_one_mesh_dim = single_dim_strategy(op_schema)

# TODO: handle 'ShardingPlaceholder' expansion (doesn't exist yet)
# TODO: filter out 'invalid' placements
# - ShardVar needs to say whether 'even sharding' is required or not

# copied from einsum strategy..
# TODO: identify differences between this and 'expand_' util
all_mesh_dim_strategies = [strategies_over_one_mesh_dim] * mesh.ndim
strategy_combs = itertools.product(*all_mesh_dim_strategies)
all_strategies = []
for strategy_comb in strategy_combs:
spec_list = [
DTensorSpec(mesh, tuple(specs)) for specs in zip(*strategy_comb)
]
all_strategies.append(
OpSpec(output_specs=spec_list[0], input_specs=spec_list[1:])
)

return OpStrategy(all_strategies)

# register_op_strategy returns another wrapper that actually does the strategy registration,
# we just add another layer of wrapping that expands the single_dim_strategy into a strategy that's
# compatible with register_op_strategy
register_op_strategy(op, schema_info)(_expanded_strategy)
return _expanded_strategy

return expanded_registration_wrapper
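
The expansion step above is just a cartesian product of the per-mesh-dim rows; a toy standalone illustration of the combinatorics (plain strings instead of DTensorSpec/OpSpec, purely for intuition):

import itertools

# Per-mesh-dim rows for mm, written as "output,lhs,rhs".
rows = ["R,R,R", "P,S(1),S(0)", "S(0),S(0),R", "S(1),R,S(1)"]

mesh_ndim = 2  # e.g. a 2-D device mesh
for combo in itertools.product(rows, repeat=mesh_ndim):
    # Each combo picks one row per mesh dim, giving 4**2 = 16 candidate specs.
    print(combo)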


def register_op_strategy(
op: Union[torch._ops.OpOverload, list[torch._ops.OpOverload]],
schema_info: Optional[RuntimeSchemaInfo] = None,
) -> Callable[[Callable[[OpSchema], StrategyType]], Callable[[OpSchema], StrategyType]]:
op, schema_info=None
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
# pyre-fixme[2]: Parameter must be annotated.

# For every ATen op that accepts any args in this list,
# the arg itself can impact the strides (and potentially the sharding strategy)
# of the output tensor.
@ -113,9 +68,7 @@ def register_op_strategy(
"memory_format",
]

def wrapper(
impl: Callable[[OpSchema], StrategyType],
) -> Callable[[OpSchema], StrategyType]:
def wrapper(impl):
if isinstance(op, list):
overloads = op
else:
@ -206,10 +159,7 @@ def prod(xs: Iterable[int]) -> int:


def is_tensor_shardable(shape: Sequence[int], spec: DTensorSpec) -> bool:
"""Check if the spec matches these criteria:
* any Shard placements in spec refer to valid tensor dims
* no empty local tensors (uneven sharding OK, as long as last rank has >0 size)
"""
"""Check if the shape is shardable according to the spec."""
# number of shards in each tensor dimension
shards_map = [1] * len(shape)
for i, placement in enumerate(spec.placements):

@ -226,10 +226,8 @@ class PythonCode:
# Values in global scope during execution of `src_def`.
globals: dict[str, Any]
# Optional mapping from the forward function's line number to
# node index. Line number starts at the prologue (i.e. forward()).
# node index.
_lineno_map: Optional[dict[int, Optional[int]]]
# The line number of prologue in fn_code
_prologue_start: int = 0


def _format_target(base: str, target: str) -> str:
@ -856,14 +854,7 @@ class CodeGen:

{prologue}
{code}"""
# The +4 accounts for the empty lines before prologue in fn_code
prologue_start = wrap_stmts.count("\n") + 4
return PythonCode(
fn_code,
globals_,
_lineno_map=lineno_map,
_prologue_start=prologue_start,
)
return PythonCode(fn_code, globals_, _lineno_map=lineno_map)


# Ideally, we'd like to refactor all of the pytree logic into this codegen

@ -1,8 +1,6 @@
# mypy: allow-untyped-defs
import base64
import contextlib
import copy
import hashlib
import itertools
import linecache
import os
@ -38,7 +36,6 @@ __all__ = [
]

_USER_PRESERVED_ATTRIBUTES_KEY = "_user_preserved_attributes"
FX_GRAPH_MODULE_FILE_PREFIX = "fx_generated_"


# Normal exec loses the source code, however we can work with
@ -64,13 +61,7 @@ class _EvalCacheLoader:

key = self._get_key()
if co_fields:
if "co_filename" in co_fields:
# If only co_filename is provided, use it directly as the key
if "co_firstlineno" not in co_fields or "co_name" not in co_fields:
key = co_fields["co_filename"]
else:
# Full co_fields with all three components
key += f" from {co_fields['co_filename']}:{co_fields['co_firstlineno']} in {co_fields['co_name']}"
key += f" from {co_fields['co_filename']}:{co_fields['co_firstlineno']} in {co_fields['co_name']}"
self.eval_cache[key] = src

# Don't mutate globals so that this loader is only used
@ -362,36 +353,6 @@ def _print_readable(
return output


def _metadata_hash(code: str, node_metadata: dict) -> str:
"""
Create a content-addressed hash from code and metadata.

Args:
code: The source code string
lineno_map: Mapping from line numbers to node indices
node_metadata: Metadata for each node

Returns:
A 51-character base32-encoded hash
"""
import json

# Create a deterministic string representation of all components
# We use JSON to ensure consistent serialization
hash_data = {
"code": code,
"node_metadata": node_metadata,
}
hashing_str = json.dumps(hash_data).encode("utf-8")

# [:51] to strip off the "Q====" suffix common to every hash value.
return (
base64.b32encode(hashlib.sha256(hashing_str).digest())[:51]
.decode("utf-8")
.lower()
)
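
A quick standalone check of the content-addressed naming this enables; the hashing shown above is re-implemented inline here rather than importing the private helper, and the sample code string and node metadata are made up for illustration:

import base64
import hashlib
import json

code = "def forward(self, x):\n    return x + 1\n"
node_metadata = {0: {"name": "x", "op": "placeholder", "target": "x", "stack_trace": None}}

hashing_str = json.dumps({"code": code, "node_metadata": node_metadata}).encode("utf-8")
digest = base64.b32encode(hashlib.sha256(hashing_str).digest())[:51].decode("utf-8").lower()

# The same code + metadata always yields the same digest; the graph module
# filename is built from FX_GRAPH_MODULE_FILE_PREFIX plus this digest.
print(digest)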


class _WrappedCall:
def __init__(self, cls, cls_call):
self.cls = cls
@ -864,47 +825,9 @@ class {module_name}(torch.nn.Module):
python_code = self._graph.python_code(root_module="self")
self._code = python_code.src
self._lineno_map = python_code._lineno_map
self._prologue_start = python_code._prologue_start

cls = type(self)
co_fields = self._graph._co_fields if hasattr(self._graph, "_co_fields") else {}
from torch._dynamo import config as dynamo_config

if dynamo_config.enrich_profiler_metadata:
# Generate metadata and register for profiler augmentation
node_metadata: dict[int, dict[str, Any]] = {}
for i, node in enumerate(self._graph.nodes):
node_metadata[i] = {
"name": node.name,
"op": node.op,
"target": str(node.target),
"stack_trace": node.meta.get("stack_trace", None),
}

# Generate a content-addressed filename based on hash of code and metadata
# This ensures the same code+metadata always generates the same filename
hash_value = _metadata_hash(self._code, node_metadata)
file_stem = f"{FX_GRAPH_MODULE_FILE_PREFIX}_{hash_value}"

filename = f"{file_stem}.py"

# Only include co_filename to use it directly as the cache key
co_fields = {
"co_filename": filename,
}

# Store metadata in global in-memory registry
metadata = {
"lineno_map": python_code._lineno_map,
"prologue_start": python_code._prologue_start,
"node_metadata": node_metadata,
}

# Register metadata in the global registry
from torch.fx.traceback import _register_fx_metadata

_register_fx_metadata(filename, metadata)

cls.forward = _forward_from_src(self._code, python_code.globals, co_fields)

# Determine whether this class explicitly defines a __call__ implementation
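
For reference, a small sketch of how this code path gets exercised (assumes the enrich_profiler_metadata flag referenced above is available on the dynamo config of this branch; tracing builds a GraphModule, and recompile() runs during construction):

import torch
import torch.fx
from torch._dynamo import config as dynamo_config

dynamo_config.enrich_profiler_metadata = True  # flag checked in the diff above

class M(torch.nn.Module):
    def forward(self, x):
        return x.relu() + 1

# With the flag on, recompile() registers lineno_map/node_metadata under the
# content-addressed fx_generated_... filename during tracing.
gm = torch.fx.symbolic_trace(M())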

@ -38,28 +38,6 @@ current_meta: dict[str, Any] = {}
current_replay_node: Optional[Node] = None
should_preserve_node_meta = False

# =============================================================================
# FX Metadata Registry for Memory Profiler
# =============================================================================
# Global in-memory registry for FX metadata
# Maps module_name -> metadata dict containing lineno_map and node_metadata
_FX_METADATA_REGISTRY: dict[str, dict[str, Any]] = {}


def _register_fx_metadata(module_name: str, metadata: dict[str, Any]) -> None:
"""
Register FX metadata in the global in-memory registry.

This is called automatically during graph module compilation to store metadata
for later use by memory profiler augmentation.

Args:
module_name: The module identifier (content-addressed filename)
metadata: Metadata dict containing lineno_map, node_metadata, and source_code
"""
# TODO: add logging to tlparse
_FX_METADATA_REGISTRY[module_name] = metadata
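
A minimal direct illustration of the registry contract (hypothetical filename and values; normally recompile() is what calls _register_fx_metadata, as shown in the graph_module hunk above):

from torch.fx.traceback import _FX_METADATA_REGISTRY, _register_fx_metadata

_register_fx_metadata(
    "fx_generated_example.py",  # hypothetical content-addressed name
    {
        "lineno_map": {1: 0},
        "prologue_start": 4,
        "node_metadata": {0: {"name": "x", "op": "placeholder"}},
    },
)
print(_FX_METADATA_REGISTRY["fx_generated_example.py"]["node_metadata"])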


@compatibility(is_backward_compatible=False)
class NodeSourceAction(Enum):

@ -17,5 +17,230 @@ def is_stdlib_module(module: str) -> bool:


def _get_stdlib_modules():
assert sys.version_info >= (3, 10)
return sys.stdlib_module_names
if sys.version_info.major == 3: # noqa: UP036
if sys.version_info.minor == 9:
return stdlib3_9
if sys.version_info.minor >= 10: # noqa: YTT204
return sys.stdlib_module_names # type: ignore[attr-defined]
elif sys.version_info.major > 3: # noqa: UP036
return sys.stdlib_module_names # type: ignore[attr-defined]

raise RuntimeError(f"Unsupported Python version: {sys.version_info}")


stdlib3_9 = {
"_thread", "abc", "aifc", "argparse", "array", "ast", "asynchat", "asyncio", "asyncore", "atexit",
"audioop", "base64", "bdb", "binascii", "binhex", "bisect", "builtins", "bz2", "cProfile", "calendar",
"cgi", "cgitb", "chunk", "cmath", "cmd", "code", "codecs", "codeop", "collections", "colorsys",
"compileall", "concurrent", "configparser", "contextlib", "contextvars", "copy", "copyreg", "crypt", "csv", "ctypes",
"curses", "dataclasses", "datetime", "dbm", "decimal", "difflib", "dis", "distutils", "doctest", "email",
"encodings", "ensurepip", "enum", "errno", "faulthandler", "fcntl", "filecmp", "fileinput", "fnmatch", "formatter",
"fractions", "ftplib", "functools", "gc", "getopt", "getpass", "gettext", "glob", "graphlib", "grp",
"gzip", "hashlib", "heapq", "hmac", "html", "http", "imaplib", "imghdr", "imp", "importlib",
"inspect", "io", "ipaddress", "itertools", "json", "keyword", "lib2to3", "linecache", "locale", "logging",
"lzma", "mailbox", "mailcap", "marshal", "math", "mimetypes", "mmap", "modulefinder", "msilib", "msvcrt",
"multiprocessing", "netrc", "nis", "nntplib", "ntpath", "numbers", "operator", "optparse", "os", "ossaudiodev",
"parser", "pathlib", "pdb", "pickle", "pickletools", "pipes", "pkgutil", "platform", "plistlib", "poplib",
"posix", "posixpath", "pprint", "profile", "pstats", "pty", "pwd", "py_compile", "pyclbr", "pydoc",
"queue", "quopri", "random", "re", "readline", "reprlib", "resource", "rlcompleter", "runpy", "sched",
"secrets", "select", "selectors", "shelve", "shlex", "shutil", "signal", "site", "smtpd", "smtplib",
"sndhdr", "socket", "socketserver", "spwd", "sqlite3", "sre", "sre_compile", "sre_constants", "sre_parse", "ssl",
"stat", "statistics", "string", "stringprep", "struct", "subprocess", "sunau", "symbol", "symtable", "sys",
"sysconfig", "syslog", "tabnanny", "tarfile", "telnetlib", "tempfile", "termios", "test", "textwrap", "threading",
"time", "timeit", "tkinter", "token", "tokenize", "trace", "traceback", "tracemalloc", "tty", "turtle",
"turtledemo", "types", "typing", "unicodedata", "unittest", "urllib", "uu", "uuid", "venv", "warnings",
"wave", "weakref", "webbrowser", "winreg", "winsound", "wsgiref", "xdrlib", "xml", "xmlrpc", "zipapp",
"zipfile", "zipimport", "zlib", "zoneinfo",
}

@ -806,29 +806,7 @@ function format_frames(frames) {
}
const frame_strings = frames
.filter(frameFilter)
.map(f => {
let frame_str = `${f.filename}:${f.line}:${f.name}`;

// Add FX debug information if available
if (f.fx_node_op || f.fx_node_name || f.fx_node_target) {
const fx_parts = [];
if (f.fx_node_name) fx_parts.push(`node=${f.fx_node_name}`);
if (f.fx_node_op) fx_parts.push(`op=${f.fx_node_op}`);
if (f.fx_node_target) fx_parts.push(`target=${f.fx_node_target}`);
frame_str += `\n >> FX: ${fx_parts.join(', ')}`;
}

if (f.fx_original_trace) {
frame_str += `\n >> Original Model Code:`;
const original_lines = f.fx_original_trace.trim().split('\n');
// Show all lines of the original trace
for (const line of original_lines) {
frame_str += `\n ${line}`;
}
}

return frame_str;
});
.map(f => `${f.filename}:${f.line}:${f.name}`);
return elideRepeats(frame_strings).join('\n');
}