Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-05 16:44:58 +08:00)

Compare commits: 5 commits (cpp-docs-d... ... ciflow/roc...)

| SHA1 |
|---|
| 272aa7e960 |
| cdaece5cef |
| a6a06137e1 |
| 78f6d47469 |
| ccd3919116 |
@@ -1,11 +1,15 @@
sphinx==7.2.6
sphinx==5.3.0
#Description: This is used to generate PyTorch docs
#Pinned versions: 7.2.6
#Pinned versions: 5.3.0

pytorch_sphinx_theme2==0.2.0
#Description: This is needed to generate PyTorch docs
#Pinned versions: 0.2.0
standard-imghdr==3.13.0; python_version >= "3.13"
#Description: This is needed by Sphinx, so it needs to be added here.
# The reasons are as follows:
# 1) This module has been removed from the Python standard library since Python 3.13(https://peps.python.org/pep-0594/#imghdr);
# 2) The current version of Sphinx (5.3.0) is not compatible with Python 3.13.
# Once Sphinx is upgraded to a version compatible with Python 3.13 or later, we can remove this dependency.

-e git+https://github.com/pytorch/pytorch_sphinx_theme.git@71e55749be14ceb56e7f8211a9fb649866b87ad4#egg=pytorch_sphinx_theme2
# TODO: sphinxcontrib.katex 0.9.0 adds a local KaTeX server to speed up pre-rendering
# but it doesn't seem to work and hangs around idly. The initial thought that it is probably
# something related to Docker setup. We can investigate this later.

@@ -32,17 +36,17 @@ tensorboard==2.18.0 ; python_version >= "3.13"
#Description: This is used to generate PyTorch docs
#Pinned versions: 2.13.0

breathe==4.36.0
breathe==4.34.0
#Description: This is used to generate PyTorch C++ docs
#Pinned versions: 4.36.0
#Pinned versions: 4.34.0

exhale==0.3.7
exhale==0.2.3
#Description: This is used to generate PyTorch C++ docs
#Pinned versions: 0.3.7
#Pinned versions: 0.2.3

docutils==0.20
docutils==0.16
#Description: This is used to generate PyTorch C++ docs
#Pinned versions: 0.20
#Pinned versions: 0.16

bs4==0.0.1
#Description: This is used to generate PyTorch C++ docs

@@ -52,13 +56,13 @@ IPython==8.12.0
#Description: This is used to generate PyTorch functorch docs
#Pinned versions: 8.12.0

myst-nb==1.3.0
myst-nb==0.17.2
#Description: This is used to generate PyTorch functorch and torch.compile docs.
#Pinned versions: 1.3.0
#Pinned versions: 0.17.2

# The following are required to build torch.distributed.elastic.rendezvous.etcd* docs
python-etcd==0.4.5
sphinx-copybutton==0.5.0
sphinx-design==0.6.1
sphinx-design==0.4.0
sphinxcontrib-mermaid==1.0.0
myst-parser==4.0.1
myst-parser==0.18.1
@@ -89,41 +89,23 @@ if [ "$is_main_doc" = true ]; then

make coverage
# Now we have the coverage report, we need to make sure it is empty.
# Sphinx 7.2.6+ format: python.txt contains a statistics table with a TOTAL row
# showing the undocumented count in the third column.
# Example: | TOTAL | 99.83% | 2 |
# Count the number of lines in the file and turn that number into a variable
# $lines. The `cut -f1 ...` is to only parse the number, not the filename
# Skip the report header by subtracting 2: the header will be output even if
# there are no undocumented items.
#
# Also: see docs/source/conf.py for "coverage_ignore*" items, which should
# be documented then removed from there.

# Extract undocumented count from TOTAL row in Sphinx 7.2.6 statistics table
# The table format is: | Module | Coverage | Undocumented |
# Extract the third column (undocumented count) from the TOTAL row
undocumented=$(grep "| TOTAL" build/coverage/python.txt | awk -F'|' '{print $4}' | tr -d ' ')

if [ -z "$undocumented" ] || ! [[ "$undocumented" =~ ^[0-9]+$ ]]; then
lines=$(wc -l build/coverage/python.txt 2>/dev/null |cut -f1 -d' ')
undocumented=$((lines - 2))
if [ $undocumented -lt 0 ]; then
echo coverage output not found
exit 1
elif [ "$undocumented" -gt 0 ]; then
set +x # Disable command echoing for cleaner output
echo ""
echo "====================="
echo "UNDOCUMENTED OBJECTS:"
echo "====================="
echo ""
# Find the line number of the TOTAL row and print only what comes after it
total_line=$(grep -n "| TOTAL" build/coverage/python.txt | cut -d: -f1)
if [ -n "$total_line" ]; then
# Print only the detailed list (skip the statistics table)
tail -n +$((total_line + 2)) build/coverage/python.txt
else
# Fallback to showing entire file if TOTAL line not found
cat build/coverage/python.txt
fi
echo ""
elif [ $undocumented -gt 0 ]; then
echo undocumented objects found:
cat build/coverage/python.txt
echo "Make sure you've updated relevant .rsts in docs/source!"
echo "You can reproduce locally by running 'cd docs && make coverage && tail -n +\$((grep -n \"| TOTAL\" build/coverage/python.txt | cut -d: -f1) + 2)) build/coverage/python.txt'"
set -x # Re-enable command echoing
echo "You can reproduce locally by running 'cd docs && make coverage && cat build/coverage/python.txt'"
exit 1
fi
else
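For reference, a small Python sketch of the same TOTAL-row extraction that the shell pipeline above performs. The sample report text is hypothetical; only the `| TOTAL | 99.83% | 2 |` row shape is taken from the comment in the script, and the real file is the one `make coverage` writes to build/coverage/python.txt.

```python
# Minimal sketch: extract the undocumented count from a Sphinx coverage
# statistics table. The sample layout below is a guess for illustration.
sample_report = """\
Statistics
----------
| Module        | Coverage | Undocumented |
| torch.example | 99.83%   | 2            |
| TOTAL         | 99.83%   | 2            |
"""

def undocumented_count(report: str) -> int:
    # Mirrors the shell pipeline: grep "| TOTAL" | awk -F'|' '{print $4}'
    for line in report.splitlines():
        if line.startswith("| TOTAL"):
            return int(line.split("|")[3].strip())
    raise ValueError("TOTAL row not found")

print(undocumented_count(sample_report))  # -> 2
```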
@@ -28,7 +28,7 @@ CUDA_ARCHES_FULL_VERSION = {
"12.6": "12.6.3",
"12.8": "12.8.1",
"12.9": "12.9.1",
"13.0": "13.0.0",
"13.0": "13.0.2",
}
CUDA_ARCHES_CUDNN_VERSION = {
"12.6": "9",
.github/workflows/docker-release.yml (vendored, 1 change)

@@ -8,7 +8,6 @@ on:
- docker.Makefile
- .github/workflows/docker-release.yml
- .github/scripts/generate_docker_release_matrix.py
- .github/scripts/generate_binary_build_matrix.py
push:
branches:
- nightly
.github/workflows/inductor-unittest.yml (vendored, 8 changes)

@@ -115,10 +115,10 @@ jobs:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
test-matrix: |
{ include: [
{ config: "inductor_amx", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
{ config: "inductor_amx", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
{ config: "inductor_avx2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.avx2" },
{ config: "inductor_avx2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.avx2" },
{ config: "inductor_amx", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
{ config: "inductor_amx", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
{ config: "inductor_avx2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.10xlarge.avx2" },
{ config: "inductor_avx2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.10xlarge.avx2" },
]}
secrets: inherit
.github/workflows/inductor.yml (vendored, 14 changes)

@@ -84,13 +84,13 @@ jobs:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
test-matrix: |
{ include: [
{ config: "cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
{ config: "cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
{ config: "dynamic_cpu_inductor_huggingface", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
{ config: "dynamic_cpu_inductor_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
{ config: "dynamic_cpu_inductor_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
{ config: "dynamic_cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
{ config: "dynamic_cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
{ config: "cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
{ config: "cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
{ config: "dynamic_cpu_inductor_huggingface", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
{ config: "dynamic_cpu_inductor_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
{ config: "dynamic_cpu_inductor_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
{ config: "dynamic_cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
{ config: "dynamic_cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
{ config: "inductor_torchbench_cpu_smoketest_perf", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.24xl.spr-metal" },
]}
build-additional-packages: "vision audio torchao"
SECURITY.md (20 changes)

@@ -1,7 +1,7 @@
# Security Policy

- [**Reporting a Vulnerability**](#reporting-a-vulnerability)
- [**Using PyTorch Securely**](#using-pytorch-securely)
- [**Using Pytorch Securely**](#using-pytorch-securely)
  - [Untrusted models](#untrusted-models)
  - [TorchScript models](#torchscript-models)
  - [Untrusted inputs](#untrusted-inputs)

@@ -10,28 +10,28 @@
- [**CI/CD security principles**](#cicd-security-principles)
## Reporting Security Issues

Beware that none of the topics under [Using PyTorch Securely](#using-pytorch-securely) are considered vulnerabilities of PyTorch.
Beware that none of the topics under [Using Pytorch Securely](#using-pytorch-securely) are considered vulnerabilities of Pytorch.

However, if you believe you have found a security vulnerability in PyTorch, we encourage you to let us know right away. We will investigate all legitimate reports and do our best to quickly fix the problem.

Please report security issues using https://github.com/pytorch/pytorch/security/advisories/new

All reports submitted through the security advisories mechanism would **either be made public or dismissed by the team within 90 days of the submission**. If advisory has been closed on the grounds that it is not a security issue, please do not hesitate to create an [new issue](https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml) as it is still likely a valid issue within the framework.
All reports submitted thru the security advisories mechanism would **either be made public or dismissed by the team within 90 days of the submission**. If advisory has been closed on the grounds that it is not a security issue, please do not hesitate to create an [new issue](https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml) as it is still likely a valid issue within the framework.

Please refer to the following page for our responsible disclosure policy, reward guidelines, and those things that should not be reported:

https://www.facebook.com/whitehat

## Using PyTorch Securely
**PyTorch models are programs**, so treat its security seriously -- running untrusted models is equivalent to running untrusted code. In general we recommend that model weights and the python code for the model are distributed independently. That said, be careful about where you get the python code from and who wrote it (preferentially check for a provenance or checksums, do not run any pip installed package).
## Using Pytorch Securely
**Pytorch models are programs**, so treat its security seriously -- running untrusted models is equivalent to running untrusted code. In general we recommend that model weights and the python code for the model are distributed independently. That said, be careful about where you get the python code from and who wrote it (preferentially check for a provenance or checksums, do not run any pip installed package).

### Untrusted models
Be careful when running untrusted models. This classification includes models created by unknown developers or utilizing data obtained from unknown sources[^data-poisoning-sources].

**Prefer to execute untrusted models within a secure, isolated environment such as a sandbox** (e.g., containers, virtual machines). This helps protect your system from potentially malicious code. You can find further details and instructions in [this page](https://developers.google.com/code-sandboxing).

**Be mindful of risky model formats**. Give preference to share and load weights with the appropriate format for your use case. [Safetensors](https://huggingface.co/docs/safetensors/en/index) gives the most safety but is the most restricted in what it supports. [`torch.load`](https://pytorch.org/docs/stable/generated/torch.load.html#torch.load) has a significantly larger surface of attack but is more flexible in what it can serialize. See the documentation for more details.
**Be mindful of risky model formats**. Give preference to share and load weights with the appropriate format for your use case. [safetensors](https://huggingface.co/docs/safetensors/en/index) gives the most safety but is the most restricted in what it supports. [`torch.load`](https://pytorch.org/docs/stable/generated/torch.load.html#torch.load) has a significantly larger surface of attack but is more flexible in what it can serialize. See the documentation for more details.

Even for more secure serialization formats, unexpected inputs to the downstream system can cause diverse security threats (e.g. denial of service, out of bound reads/writes) and thus we recommend extensive validation of any untrusted inputs.

@@ -43,7 +43,7 @@ Important Note: The trustworthiness of a model is not binary. You must always de

### TorchScript models

TorchScript models should be treated the same way as locally executable code from an unknown source. Only run TorchScript models if you trust the provider. Please note, that tools for introspecting TorchScript models (such as `torch.utils.model_dump`) may also execute partial or full code stored in those models, therefore they should be used only if you trust the provider of the binary you are about to load.
TorchScript models should treated the same way as locally executable code from an unknown source. Only run TorchScript models if you trust the provider. Please note, that tools for introspecting TorchScript models (such as `torch.utils.model_dump`) may also execute partial or full code stored in those models, therefore they should be used only if you trust the provider of the binary you are about to load.

### Untrusted inputs during training and prediction

@@ -59,9 +59,9 @@ If applicable, prepare your model against bad inputs and prompt injections. Some

### Data privacy

**Take special security measures if you train your models with sensitive data**. Prioritize [sandboxing](https://developers.google.com/code-sandboxing) your models and:
- Do not feed sensitive data to an untrusted model (even if runs in a sandboxed environment)
- If you consider publishing a model that was partially trained with sensitive data, be aware that data can potentially be recovered from the trained weights (especially if the model overfits).
**Take special security measures if your model if you train models with sensitive data**. Prioritize [sandboxing](https://developers.google.com/code-sandboxing) your models and:
- Do not feed sensitive data to untrusted model (even if runs in a sandboxed environment)
- If you consider publishing a model that was partially trained with sensitive data, be aware that data can potentially be recovered from the trained weights (especially if model overfits).

### Using distributed features
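To illustrate the "risky model formats" guidance above, here is a hedged Python sketch of the two loading paths it contrasts. The extension-based dispatch and the safetensors dependency are our own assumptions for the example, not part of SECURITY.md.

```python
# Sketch only: load untrusted weights defensively, as the policy recommends.
import torch

def load_untrusted_weights(path: str):
    """Load a state_dict from an untrusted file as defensively as possible."""
    if path.endswith(".safetensors"):
        # safetensors stores raw tensors only -- no executable payload.
        from safetensors.torch import load_file  # third-party: pip install safetensors
        return load_file(path)
    # torch.load is more flexible but has a larger attack surface;
    # weights_only=True restricts unpickling to tensors and plain containers.
    return torch.load(path, map_location="cpu", weights_only=True)
```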
@@ -260,7 +260,7 @@ IF(USE_FBGEMM_GENAI)
if(USE_CUDA)
# To avoid increasing the build time/binary size unnecessarily, use an allow-list of kernels to build.
# If you want to integrate a kernel from FBGEMM into torch, you have to add it here.
set(FBGEMM_CUTLASS_KERNELS_REGEX ".*(mx8mx8bf16_grouped|f4f4bf16_grouped|f4f4bf16).*")
set(FBGEMM_CUTLASS_KERNELS_REGEX ".*(mx8mx8bf16_grouped|f4f4bf16_grouped).*")
file(GLOB_RECURSE fbgemm_genai_native_cuda_cu
"${FBGEMM_GENAI_SRCS}/cutlass_extensions/*.cu"
"${FBGEMM_GENAI_SRCS}/cutlass_extensions/**/*.cu")
@@ -23,6 +23,8 @@ C10_DIAGNOSTIC_POP()
#endif
namespace at {

namespace {

/*
These const variables defined the fp32 precisions for different backend
We have "generic", "cuda", "mkldnn" backend now and we can choose fp32

@@ -39,6 +41,16 @@ namespace at {
->rnn
*/

C10_ALWAYS_INLINE void warn_deprecated_fp32_precision_api(){
TORCH_WARN_ONCE(
"Please use the new API settings to control TF32 behavior, such as torch.backends.cudnn.conv.fp32_precision = 'tf32' "
"or torch.backends.cuda.matmul.fp32_precision = 'ieee'. Old settings, e.g, torch.backends.cuda.matmul.allow_tf32 = True, "
"torch.backends.cudnn.allow_tf32 = True, allowTF32CuDNN() and allowTF32CuBLAS() will be deprecated after Pytorch 2.9. Please see "
"https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices"
);
}
} // namespace

Float32Backend str2backend(const std::string& name) {
if (name == "generic")
return Float32Backend::GENERIC;

@@ -194,6 +206,7 @@ bool Context::allowTF32CuDNN(std::optional<Float32Op> op) const {
} else {
return float32Precision(Float32Backend::CUDA, op.value()) == Float32Precision::TF32;
}
warn_deprecated_fp32_precision_api();
return allow_tf32_cudnn;
}

@@ -201,6 +214,7 @@ void Context::setAllowTF32CuDNN(bool b) {
setFloat32Precision(Float32Backend::CUDA, Float32Op::RNN, b ? Float32Precision::TF32 : Float32Precision::NONE);
setFloat32Precision(Float32Backend::CUDA, Float32Op::CONV, b ? Float32Precision::TF32 : Float32Precision::NONE);
allow_tf32_cudnn = b;
warn_deprecated_fp32_precision_api();
}

void Context::setSDPPriorityOrder(const std::vector<int64_t>& order) {

@@ -311,6 +325,7 @@ bool Context::allowTF32CuBLAS() const {
"Current status indicate that you have used mix of the legacy and new APIs to set the TF32 status for cublas matmul. ",
"We suggest only using the new API to set the TF32 flag. See also: ",
"https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices");
warn_deprecated_fp32_precision_api();
return allow_tf32_new;
}

@@ -334,6 +349,7 @@ Float32MatmulPrecision Context::float32MatmulPrecision() const {
"Current status indicate that you have used mix of the legacy and new APIs to set the matmul precision. ",
"We suggest only using the new API for matmul precision. See also: ",
"https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices");
warn_deprecated_fp32_precision_api();
return float32_matmul_precision;
}

@@ -361,6 +377,7 @@ Float32Precision Context::float32Precision(Float32Backend backend, Float32Op op)

void Context::setFloat32MatmulPrecision(const std::string &s) {
auto match = [this](const std::string & s_) {
warn_deprecated_fp32_precision_api();
// TODO: consider if CuDNN field needs to also be set for potential future CuDNN ops like multi-headed attention
if (s_ == "highest") {
float32_matmul_precision = at::Float32MatmulPrecision::HIGHEST;
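The hunk above adds warn_deprecated_fp32_precision_api(), whose message points callers at the new per-backend fp32_precision settings. Below is a small illustrative Python sketch of that migration, using only the setting names quoted in the warning string; whether they are available depends on the PyTorch build.

```python
# Sketch of the migration described by warn_deprecated_fp32_precision_api().
import torch

# Legacy flags, slated for deprecation after PyTorch 2.9:
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# New per-backend / per-op precision controls named in the warning:
torch.backends.cuda.matmul.fp32_precision = "tf32"  # or "ieee" to keep full fp32
torch.backends.cudnn.conv.fp32_precision = "tf32"
```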
@@ -388,6 +388,7 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
#ifndef USE_ROCM
at::Half halpha;
at::Half hbeta;
uint32_t mask = -1;
#endif
void * alpha_ptr = &alpha;
void * beta_ptr = &beta;

@@ -427,7 +428,7 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
auto fp16_reduction = at::globalContext().allowFP16ReductionCuBLAS();
if (fp16_reduction !=
at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) {
uint32_t mask =
mask =
fp16_reduction ==
at::CuBLASReductionOption::DisallowReducedPrecisionAllowSplitK
? (CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE |

@@ -444,7 +445,7 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
auto bf16_reduction = at::globalContext().allowBF16ReductionCuBLAS();
if (bf16_reduction !=
at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) {
uint32_t mask =
mask =
bf16_reduction ==
at::CuBLASReductionOption::DisallowReducedPrecisionAllowSplitK
? (CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE |

@@ -511,17 +512,41 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
cublasStatus_t cublasStatus = CUBLAS_STATUS_SUCCESS;
cublasLtMatmulHeuristicResult_t heuristicResult = {};
int returnedResult = 0;
TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
ltHandle,
computeDesc.descriptor(),
Adesc.descriptor(),
Bdesc.descriptor(),
Cdesc.descriptor(),
Cdesc.descriptor(),
preference.descriptor(),
1,
&heuristicResult,
&returnedResult));
// on Blackwell+, we fake a n > 1 matmul when querying heuristics
// to prevent cuBLASLt from dispatching to a GEMV kernel for batch-invariance
#ifndef USE_ROCM
const bool lie_to_cublaslt = mask == CUBLASLT_REDUCTION_SCHEME_NONE && n == 1 && at::cuda::getCurrentDeviceProperties()->major >= 10;
#else
const bool lie_to_cublaslt = false;
#endif
if (lie_to_cublaslt) {
CuBlasLtMatrixLayout FakeBdesc(abType, k, 2, ldb, opb == CUBLAS_OP_T);
CuBlasLtMatrixLayout FakeCdesc(cType, m, 2, ldc);

TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
ltHandle,
computeDesc.descriptor(),
Adesc.descriptor(),
FakeBdesc.descriptor(),
FakeCdesc.descriptor(),
FakeCdesc.descriptor(),
preference.descriptor(),
1,
&heuristicResult,
&returnedResult));
} else {
TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
ltHandle,
computeDesc.descriptor(),
Adesc.descriptor(),
Bdesc.descriptor(),
Cdesc.descriptor(),
Cdesc.descriptor(),
preference.descriptor(),
1,
&heuristicResult,
&returnedResult));
}
if (returnedResult == 0) {
cublasStatus = CUBLAS_STATUS_NOT_SUPPORTED;
}
@@ -59,24 +59,6 @@
// forward declare
class cublasCommonArgs;

#ifndef _WIN32
namespace fbgemm_gpu {

// NOTE(slayton58): FBGemm_GPU kernels come from <fbgemm_gpu/torch_ops.h> within the FBGemm repo.
// To update supported ops means a submodule bump, which is.. painful. Instead, we
// can simply forward-declare the methods we want to use.. Works at least as a short-term
// thing, but should still be fixed somewhere/somehow.
at::Tensor f4f4bf16(
at::Tensor,
at::Tensor,
at::Tensor,
at::Tensor,
std::optional<at::Tensor>,
bool use_mx);

} // namespace fbgemm_gpu
#endif

using at::blas::ScalingType;
using at::blas::SwizzleType;
@@ -785,6 +767,33 @@ _scaled_rowwise_rowwise(
return out;
}

// Check the shapes & sizes of scales for deepseek-style (1x128, 128x128) scaling.
// Wraps check_size_stride for easier integration, correctly handles cases where a dimension of the scale == 1,
// and strides become somewhat meaningless
void _check_deepseek_scale_stride(const Tensor& scale, const Tensor& t, const ScalingType scale_type) {
if (scale_type == ScalingType::BlockWise1x128) {
TORCH_CHECK_VALUE(check_size_stride(scale, 0, t.size(0), 1),
"at dim=0 scale should have ", t.size(0), "elements and stride(0) ", 1, "if ", t.size(0), " > 1 - Got: ",
"shape=", scale.sizes(), ", stride=", scale.strides());
auto expected_size = ceil_div<int64_t>(t.size(1), 128);
TORCH_CHECK_VALUE(check_size_stride(scale, 1, expected_size, t.size(0)),
"at dim=1 scale should have ", expected_size, "elements and stride ", t.size(0), "if ", expected_size, " > 1 - Got: ",
"shape=", scale.sizes(), ", stride=", scale.strides());
} else if (scale_type == ScalingType::BlockWise128x128) {
TORCH_CHECK_VALUE(check_size_stride(
scale,
0,
ceil_div<int64_t>(t.size(0), 128),
ceil_div<int64_t>(t.size(1), 128)),
"at dim=0 scale should have ", ceil_div<int64_t>(t.size(0), 128), "elements and stride(0) ", ceil_div<int64_t>(t.size(1), 128), "if ", ceil_div<int64_t>(t.size(0), 128), " > 1 - Got: ",
"shape=", scale.sizes(), ", stride=", scale.strides());
TORCH_CHECK(check_size_stride(
scale, 1, ceil_div<int64_t>(t.size(1), 128), 1),
"at dim=1 scale should have ", ceil_div<int64_t>(t.size(1), 128), "elements and stride(1) ", 1, "if ", ceil_div<int64_t>(t.size(1), 128), " > 1 - Got: ",
"shape=", scale.sizes(), ", stride=", scale.strides());
}
}

void
_check_deepseek_support() {
#ifndef USE_ROCM

@@ -797,7 +806,7 @@ _check_deepseek_support() {
}
// Only in cublasLt >= 12.9
TORCH_CHECK_NOT_IMPLEMENTED(
CUBLAS_VERSION >= 120900 && cublasLtGetVersion() >= 120900,
CUBLAS_VERSION < 120900 || cublasLtGetVersion() < 120900,
"DeepSeek style (1x128, 128x128) scaling requires cublasLt >= 12.9"
);
#endif
@@ -814,61 +823,23 @@ _scaled_block1x128_block1x128(
#ifndef USE_ROCM
// Restrictions:
// A, B are FP8, scales are fp32, shape K//128
// As: [M x K // 128], stride: [1, M]
// Bs: [N x K // 128], stride: [1, N]
// CUDA: Only Hopper GPUs
_check_deepseek_support();

// check types
TORCH_CHECK_VALUE(
isFloat8Type(mat_a.scalar_type()) &&
isFloat8Type(mat_b.scalar_type()),
"mat_a and mat_b must be fp8 types, got: ", mat_a.scalar_type(), mat_b.scalar_type()
);

const int64_t M = mat_a.sizes()[0];
const int64_t K = mat_a.sizes()[1];
const int64_t N = mat_b.sizes()[1];

// scale_a shape
TORCH_CHECK_VALUE(
scale_a.size(0) == M &&
scale_a.size(1) == ceil_div<int64_t>(K, 128) &&
scale_a.scalar_type() == kFloat,
"scale_a must have shape ", M, " x ", ceil_div<int64_t>(K, 128), " Float elements, got ", scale_a.sizes()
);
// scale_a stride
TORCH_CHECK_VALUE(
scale_a.stride(0) == 1 &&
(
scale_a.stride(1) == M ||
(scale_a.size(1) == 1 && scale_b.stride(1) == 1)
),
"scale_a strides must be (", 1, ", ", M, "); got: ", scale_a.strides()
);

// scale_b shape
TORCH_CHECK_VALUE(
scale_b.size(0) == N &&
scale_b.size(1) == ceil_div<int64_t>(K, 128) &&
scale_b.scalar_type() == kFloat,
"scale_b must have shape ", N, " x ", ceil_div<int64_t>(K, 128), " Float elements, got ", scale_b.sizes()
);
// scale_b stride
TORCH_CHECK_VALUE(
scale_b.stride(0) == 1 &&
(
scale_b.stride(1) == N ||
(
scale_b.size(1) == 1 &&
scale_b.stride(1) == 1
)
),
"scale_b strides must be (", 1, ", ", N, "); got: ", scale_a.strides()
);
TORCH_CHECK_VALUE(isFloat8Type(mat_a.scalar_type()) && isFloat8Type(mat_b.scalar_type()), "mat_a and mat_b must be fp8 types, got: ",
mat_a.scalar_type(), mat_b.scalar_type());
TORCH_CHECK_VALUE(scale_a.sizes()[0] == mat_a.sizes()[0] && scale_a.sizes()[1] == mat_a.sizes()[1] / 128 && scale_a.scalar_type() == kFloat,
"scale_a must have shape ", mat_a.sizes()[0], " x ", mat_a.sizes()[1] / 128, " Float elements, got ", scale_a.sizes())
TORCH_CHECK_VALUE(scale_b.sizes()[0] == ceil_div<int64_t>(mat_b.sizes()[0], 128) && scale_b.sizes()[1] == mat_b.sizes()[1] && scale_b.scalar_type() == kFloat,
"scale_b must have shape ", ceil_div<int64_t>(mat_b.sizes()[0], 128), " x ", mat_b.sizes()[1], " Float elements, got ", scale_b.sizes())

auto scaling_choice_a = ScalingType::BlockWise1x128;
auto scaling_choice_b = ScalingType::BlockWise1x128;

// Check scale strides (including stride=1 small cases)
_check_deepseek_scale_stride(scale_a, mat_a, scaling_choice_a);
_check_deepseek_scale_stride(scale_b.t(), mat_b.t(), scaling_choice_b);

_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);

return out;
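To make the Restrictions comment in _scaled_block1x128_block1x128 concrete, here is a brief illustrative Python sketch (not part of the diff) that computes the scale shapes it describes for given M, K, N; the helper name is ours.

```python
# Sketch: expected DeepSeek-style (1x128, 1x128) scale shapes, following the
# Restrictions comment above (As: [M, K // 128], Bs: [N, K // 128], fp32).
from math import ceil

def deepseek_1x128_scale_shapes(M: int, K: int, N: int):
    k_blocks = ceil(K / 128)           # ceil_div(K, 128)
    scale_a_shape = (M, k_blocks)      # one fp32 scale per 1x128 block of A
    scale_b_shape = (N, k_blocks)      # one fp32 scale per 1x128 block of B
    return scale_a_shape, scale_b_shape

# Example: M=512, K=4096, N=1024 -> ((512, 32), (1024, 32))
print(deepseek_1x128_scale_shapes(512, 4096, 1024))
```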
@@ -890,65 +861,24 @@ _scaled_block128x128_block1x128(
Tensor& out) {
#ifndef USE_ROCM
// Restrictions:
// A, B are FP8, scales are fp32, shape K//128
// CUDA: Only Hopper GPUs
_check_deepseek_support();

// A: [M, K], B: [K, N] are FP8, scales are fp32
// As: [round_up(K // 128, 4), M // 128], stride: [M // 128, 1]
// Bs: [N x K // 128], stride: [1, N]
TORCH_CHECK_VALUE(
isFloat8Type(mat_a.scalar_type()) &&
isFloat8Type(mat_b.scalar_type()),
"mat_a and mat_b must be fp8 types, got: ", mat_a.scalar_type(), mat_b.scalar_type()
);

const int64_t M = mat_a.sizes()[0];
const int64_t K = mat_a.sizes()[1];
const int64_t N = mat_b.sizes()[1];

// scale_a shape
TORCH_CHECK_VALUE(
scale_a.size(0) == round_up<int64_t>(ceil_div<int64_t>(K, 128), 4) &&
scale_a.size(1) == ceil_div<int64_t>(M, 128) &&
scale_a.scalar_type() == kFloat,
"scale_a must have shape ", round_up<int64_t>(ceil_div<int64_t>(K, 128), 4), " x ",
ceil_div<int64_t>(M, 128), " Float elements, got ", scale_a.sizes()
);
// scale_a stride
TORCH_CHECK_VALUE(
scale_a.stride(0) == 1 &&
(
scale_a.stride(1) == round_up<int64_t>(ceil_div<int64_t>(K, 128), 4) ||
(
scale_a.size(1) == 1 &&
scale_a.stride(1) == 1
)
),
"scale_a must have strides (1, ", round_up<int64_t>(ceil_div<int64_t>(K, 128), 4), "); got ", scale_b.strides()
);

// scale_b shape
TORCH_CHECK_VALUE(
scale_b.size(0) == N &&
scale_b.size(1) == ceil_div<int64_t>(K, 128) &&
scale_b.scalar_type() == kFloat,
"scale_b must have shape ", N, " x ", ceil_div<int64_t>(K, 128), " Float elements, got ", scale_b.sizes()
);
// scale_b stride
TORCH_CHECK_VALUE(
scale_b.stride(0) == 1 &&
(
scale_b.stride(1) == N ||
(
scale_b.size(1) == 1 &&
scale_b.stride(1) == 1
)
),
"scale_b must have strides (1, ", N, "); got ", scale_b.strides()
);
TORCH_CHECK_VALUE(isFloat8Type(mat_a.scalar_type()) && isFloat8Type(mat_b.scalar_type()), "mat_a and mat_b must be fp8 types, got: ",
mat_a.scalar_type(), mat_b.scalar_type());
TORCH_CHECK_VALUE(scale_a.sizes()[0] == ceil_div<int64_t>(mat_a.sizes()[0], 128) && scale_a.sizes()[1] == ceil_div<int64_t>(mat_a.sizes()[1], 128) && scale_a.scalar_type() == kFloat,
"scale_a must have shape ", ceil_div<int64_t>(mat_a.sizes()[0], 128), " x ", ceil_div<int64_t>(mat_a.sizes()[1], 128), " Float elements, got ", scale_a.sizes())
TORCH_CHECK_VALUE(scale_b.sizes()[0] == ceil_div<int64_t>(mat_b.sizes()[0], 128) && scale_b.sizes()[1] == mat_b.sizes()[1] && scale_b.scalar_type() == kFloat,
"scale_b must have shape ", ceil_div<int64_t>(mat_b.sizes()[0], 128), " x ", mat_b.sizes()[1], " Float elements, got ", scale_b.sizes())

auto scaling_choice_a = ScalingType::BlockWise128x128;
auto scaling_choice_b = ScalingType::BlockWise1x128;

// Check scale strides (including stride=1 small cases)
_check_deepseek_scale_stride(scale_a, mat_a, scaling_choice_a);
_check_deepseek_scale_stride(scale_b.t(), mat_b.t(), scaling_choice_b);

_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);

return out;
@@ -970,62 +900,24 @@ _scaled_block1x128_block128x128(
Tensor& out) {
#ifndef USE_ROCM
// Restrictions:
// A, B are FP8, scales are fp32, A: shape K//128, B: K//128, N//128
// CUDA: Only Hopper GPUs
_check_deepseek_support();
// A: [M, K], B: [K, N] are FP8, scales are fp32
// As: [M x K // 128], stride: [1, M]
// Bs: [round_up(K // 128, 4) x N // 128], stride: [1, N // 128]
TORCH_CHECK_VALUE(
isFloat8Type(mat_a.scalar_type()) &&
isFloat8Type(mat_b.scalar_type()),
"mat_a and mat_b must be fp8 types, got: ", mat_a.scalar_type(), mat_b.scalar_type()
);

int64_t M = mat_a.size(0);
int64_t K = mat_a.size(1);
int64_t N = mat_b.size(1);

// scale_a shape
TORCH_CHECK_VALUE(
scale_a.size(0) == M &&
scale_a.size(1) == ceil_div<int64_t>(K, 128) &&
scale_a.scalar_type() == kFloat,
"scale_a must have shape ", M, " x ", ceil_div<int64_t>(K, 128), " Float elements, got ", scale_a.sizes()
);
// scale_a stride
TORCH_CHECK_VALUE(
scale_a.stride(0) == 1 &&
(
scale_a.stride(1) == M ||
(
scale_a.size(1) == 1 &&
scale_a.stride(1) == 1
)
),
"scale_a must have strides (1, ", M, "); got ", scale_b.strides()
);
// scale_b shape
TORCH_CHECK_VALUE(
scale_b.size(0) == round_up<int64_t>(ceil_div<int64_t>(K, 128), 4) &&
scale_b.size(1) == ceil_div<int64_t>(N, 128) &&
scale_b.scalar_type() == kFloat,
"scale_b must have shape ", round_up<int64_t>(ceil_div<int64_t>(K, 128), 4), " x ", ceil_div<int64_t>(N, 128), " Float elements, got ", scale_b.sizes()
);
// scale_b stride
TORCH_CHECK_VALUE(
scale_b.stride(0) == 1 &&
(
scale_b.stride(1) == round_up<int64_t>(ceil_div<int64_t>(K, 128), 4) ||
(
scale_b.size(1) == 1 &&
scale_b.stride(1) == 1
)
),
"scale_b must have strides (1, ", round_up<int64_t>(ceil_div<int64_t>(K, 128), 4), "); got ", scale_b.strides()
);
TORCH_CHECK_VALUE(isFloat8Type(mat_a.scalar_type()) && isFloat8Type(mat_b.scalar_type()), "mat_a and mat_b must be fp8 types, got: ",
mat_a.scalar_type(), mat_b.scalar_type());
TORCH_CHECK_VALUE(scale_a.sizes()[0] == mat_a.sizes()[0] && scale_a.sizes()[1] == mat_a.sizes()[1] / 128 && scale_a.scalar_type() == kFloat,
"scale_a must have shape ", mat_a.sizes()[0], " x ", mat_a.sizes()[1] / 128, " Float elements, got ", scale_a.sizes())
TORCH_CHECK_VALUE(scale_b.sizes()[0] == mat_b.sizes()[0] / 128 && scale_b.sizes()[1] == mat_b.sizes()[1] / 128 && scale_b.scalar_type() == kFloat,
"scale_b must have shape ", mat_b.sizes()[0] / 128, " x ", mat_b.sizes()[1] / 128, " Float elements, got ", scale_b.sizes())

auto scaling_choice_a = ScalingType::BlockWise1x128;
auto scaling_choice_b = ScalingType::BlockWise128x128;

// Check scale strides (including stride=1 small cases)
_check_deepseek_scale_stride(scale_a, mat_a, scaling_choice_a);
_check_deepseek_scale_stride(scale_b.t(), mat_b.t(), scaling_choice_b);

_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);

return out;
@@ -1105,47 +997,26 @@ _scaled_mxfp4_mxfp4(
const std::optional<Tensor>& bias,
const c10::ScalarType out_dtype,
Tensor& out) {
#if defined(_WIN32) || (!defined(USE_ROCM) && !defined(USE_FBGEMM_GENAI))
TORCH_CHECK_NOT_IMPLEMENTED(false, "MXFP4 scaling supported on ROCM and CUDA+FBGEMM_GENAI only");
#else
#ifndef USE_ROCM
TORCH_CHECK_NOT_IMPLEMENTED(false, "MXFP4 scaling supported on ROCM only");
#endif
// Restrictions:
// A, B are FP4, scales are e8m0, A: shape K//32, B: K, N//32
TORCH_CHECK_VALUE(mat_a.scalar_type() == at::kFloat4_e2m1fn_x2 && mat_b.scalar_type() == at::kFloat4_e2m1fn_x2, "mat_a and mat_b must be fp4 types, got: ",
mat_a.scalar_type(), mat_b.scalar_type());

// Packed FP4 format means actual-K = 2 * reported-K -- adjust
auto K_multiplier = 2;
#ifdef USE_ROCM
// AMD
auto scale_a_elems = ceil_div<int64_t>(K_multiplier * mat_a.size(0), 32) * mat_a.size(1);
auto scale_b_elems = ceil_div<int64_t>(K_multiplier * mat_b.size(1), 32) * mat_b.size(0);
#else
// NVIDIA
auto scale_a_elems = round_up<int64_t>(mat_a.size(0), 128) * round_up<int64_t>(ceil_div<int64_t>(K_multiplier * mat_a.size(1), 32), 4);
auto scale_b_elems = round_up<int64_t>(mat_b.size(1), 128) * round_up<int64_t>(ceil_div<int64_t>(K_multiplier * mat_b.size(0), 32), 4);
#endif
auto scale_a_elems = ceil_div<int64_t>(2 * mat_a.size(0), 32) * mat_a.size(1);
auto scale_b_elems = ceil_div<int64_t>(2 * mat_b.size(1), 32) * mat_b.size(0);
TORCH_CHECK_VALUE(scale_a_elems == scale_a.numel(),
"For Blockwise scaling scale_a should have ", scale_a_elems, " elements, got: ", scale_a.numel());
TORCH_CHECK_VALUE(scale_b_elems == scale_b.numel(),
"For Blockwise scaling scale_b should have ", scale_b_elems, " elements, got: ", scale_b.numel());

#ifdef USE_ROCM
// AMD
TORCH_CHECK_VALUE(swizzle_a == SwizzleType::NO_SWIZZLE, "scale_a must not be swizzled (NO_SWIZZLE format)");
TORCH_CHECK_VALUE(swizzle_b == SwizzleType::NO_SWIZZLE, "scale_b must not be swizzled (NO_SWIZZLE format)");
#else
// NVIDIA
TORCH_CHECK_VALUE(swizzle_a == SwizzleType::SWIZZLE_32_4_4, "scale_a must be swizzled to SWIZZLE_32_4_4 format");
TORCH_CHECK_VALUE(swizzle_b == SwizzleType::SWIZZLE_32_4_4, "scale_b must be swizzled to SWIZZLE_32_4_4 format");
#endif

TORCH_CHECK_VALUE(scale_a.is_contiguous() && scale_b.is_contiguous(),
"For Blockwise scaling both scales should be contiguous");

TORCH_CHECK_VALUE(out.scalar_type() == out_dtype, "expected out.scalar_type() to be ", out_dtype, ", but got ", out_dtype);

#ifdef USE_ROCM
// AMD
auto scaling_choice_a = ScalingType::BlockWise1x32;
auto scaling_choice_b = ScalingType::BlockWise1x32;

@@ -1160,30 +1031,11 @@ _scaled_mxfp4_mxfp4(
TORCH_CHECK_VALUE(out.scalar_type() == ScalarType::BFloat16 ||
out.scalar_type() == ScalarType::Half,
"Block-wise scaling only supports BFloat16 or Half output types");
#else
TORCH_CHECK_NOT_IMPLEMENTED(false, "Block-wise scaling for Float8_e8m0fnu requires ROCm 7.0 or later");
#endif

return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out);
#else
// NVIDIA
// NOTE(slayton58): fbgemm_gpu::f4f4bf16 does *not* allow passing an output tensor,
// but we have one we need to use. Two clear options are to copy into
// our output (slow), or use a move-assignment-operator (faster).
// However, the compiler can complain about the explicit move preventing
// copy elision because the return from f4f4bf16 is a temporary object.
// So we don't explicitly move, and trust the compiler here...
// In the longer term this should be fixed on the FBGemm side.
out = fbgemm_gpu::f4f4bf16(
mat_a,
mat_b.transpose(-2, -1),
scale_a,
scale_b,
std::nullopt, /* global_scale */
true /* use_mx */
);

return out;
#endif
#endif
}
Tensor&

@@ -1308,20 +1160,17 @@ _scaled_mm_cuda_v2_out(
mat_a.size(0), "x", mat_a.size(1), " and ", mat_b.size(0), "x", mat_b.size(1), ")");
}

// Handle fp4 packed-K dimension
int K_multiplier = (mat_a.scalar_type() == ScalarType::Float4_e2m1fn_x2) ? 2 : 1;

TORCH_CHECK_VALUE(!bias || bias->numel() == mat_b.sizes()[1], "Bias must be size ", mat_b.sizes()[1],
" but got ", bias->numel());
TORCH_CHECK_VALUE(
K_multiplier * mat_a.sizes()[1] % 16 == 0,
mat_a.sizes()[1] % 16 == 0,
"Expected trailing dimension of mat1 to be divisible by 16 ",
"but got mat1 shape: (",
mat_a.sizes()[0],
"x",
K_multiplier * mat_a.sizes()[1],
mat_a.sizes()[1],
").");
TORCH_CHECK_VALUE(K_multiplier * mat_b.sizes()[0] % 16 == 0 && mat_b.sizes()[1] % 16 == 0, "mat2 shape (", mat_b.sizes()[0], "x",
TORCH_CHECK_VALUE(mat_b.sizes()[0] % 16 == 0 && mat_b.sizes()[1] % 16 == 0, "mat2 shape (", mat_b.sizes()[0], "x",
mat_b.sizes()[1], ") must be divisible by 16");

// TODO(slayton): Existing checks, not sure if they should really be here.
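As a reference for the packed-FP4 adjustment above: Float4_e2m1fn_x2 packs two FP4 values per element, so the effective inner dimension is twice the reported K before the divisible-by-16 checks. A small illustrative Python sketch follows; the helper name and example shapes are ours.

```python
# Sketch only: reproduce the shape checks from the hunk above in plain Python.
def check_scaled_mm_shapes(m: int, k: int, n: int, packed_fp4: bool) -> None:
    k_mult = 2 if packed_fp4 else 1  # packed FP4: two values per stored element
    # mat1 is (m, k): its trailing (inner) dimension must be divisible by 16.
    if (k_mult * k) % 16 != 0:
        raise ValueError(
            f"Expected trailing dimension of mat1 to be divisible by 16, "
            f"but got mat1 shape: ({m}x{k_mult * k})."
        )
    # mat2 is (k, n): both of its dimensions must be divisible by 16.
    if (k_mult * k) % 16 != 0 or n % 16 != 0:
        raise ValueError(f"mat2 shape ({k}x{n}) must be divisible by 16")

check_scaled_mm_shapes(32, 24, 64, packed_fp4=True)  # effective K = 48, checks pass
```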
@@ -206,43 +206,19 @@ templates_path = [
os.path.join(os.path.dirname(pytorch_sphinx_theme2.__file__), "templates"),
]
# TODO: document these and remove them from here.
# Fixes the duplicated
autosummary_filename_map = {
"torch.nn.utils.prune.identity": "torch.nn.utils.prune.identity_function",
"torch.nn.utils.prune.Identity": "torch.nn.utils.prune.Identity_class",
"torch.optim.adamw.adamw": "torch.optim.adamw.adamw_function",
"torch.optim.adamw.AdamW": "torch.optim.adamw.AdamW_class",
"torch.optim.asgd.asgd": "torch.optim.asgd.asgd_function",
"torch.optim.asgd.ASGD": "torch.optim.asgd.ASGD_class",
"torch.optim.nadam.nadam": "torch.optim.nadam.nadam_function",
"torch.optim.nadam.NAdam": "torch.optim.nadam.NAdam_class",
"torch.optim.radam.radam": "torch.optim.radam.radam_function",
"torch.optim.radam.RAdam": "torch.optim.radam.RAdam_class",
"torch.optim.rmsprop.rmsprop": "torch.optim.rmsprop.rmsprop_function",
"torch.optim.rmsprop.RMSprop": "torch.optim.rmsprop.RMSprop_class",
"torch.optim.rprop.rprop": "torch.optim.rprop.rprop_function",
"torch.optim.rprop.Rprop": "torch.optim.rprop.Rprop_class",
"torch.optim.sgd.sgd": "torch.optim.sgd.sgd_function",
"torch.optim.sgd.SGD": "torch.optim.sgd.SGD_class",
"torch.optim.adadelta.adadelta": "torch.optim.adadelta.adadelta_function",
"torch.optim.adadelta.Adadelta": "torch.optim.adadelta.Adadelta_class",
"torch.optim.adagrad.adagrad": "torch.optim.adagrad.adagrad_function",
"torch.optim.adagrad.Adagrad": "torch.optim.adagrad.Adagrad_class",
"torch.optim.adam.adam": "torch.optim.adam.adam_function",
"torch.optim.adam.Adam": "torch.optim.adam.Adam_class",
"torch.optim.adamax.adamax": "torch.optim.adamax.adamax_function",
"torch.optim.adamax.Adamax": "torch.optim.adamax.Adamax_class",
"torch.mtia.stream": "torch.mtia.stream_function",
"torch.mtia.Stream": "torch.mtia.Stream_class",
"torch.cpu.stream": "torch.cpu.stream_function",
"torch.cpu.Stream": "torch.cpu.Stream_class",
"torch.cuda.stream": "torch.cuda.stream_function",
"torch.cuda.Stream": "torch.cuda.Stream_class",
"torch.xpu.stream": "torch.xpu.stream_function",
"torch.xpu.Stream": "torch.xpu.Stream_class",
}

coverage_ignore_functions = [
# torch
"typename",
# torch.cuda._sanitizer
"zip_arguments",
"zip_by_key",
# torch.distributed.autograd
"is_available",
# torch.distributed.checkpoint.state_dict
"gc_context",
# torch.distributed.elastic.events
"record_rdzv_event",
# torch.distributed.elastic.metrics
"initialize_metrics",
# torch.distributed.elastic.rendezvous.registry

@@ -3219,11 +3195,6 @@ autodoc_type_aliases = {
# Enable overriding of function signatures in the first line of the docstring.
autodoc_docstring_signature = True

# Exclude inherited IntEnum methods that have RST formatting issues in their docstrings
autodoc_default_options = {
"exclude-members": "from_bytes, to_bytes",
}

# -- katex javascript in header
#
# def setup(app):
@@ -253,6 +253,7 @@ regular full-precision tensor.
.. autosummary::
    :toctree: generated
    :nosignatures:
    :template: classtemplate.rst

    view
    as_strided
@@ -1019,28 +1019,6 @@ class DTensorMeshTest(DTensorTestBase):
except ValueError:
self.fail("Unexpected ValueError raised with run_check=False")

@with_comms
def test_as_strided_identity(self):
# Test calling as_strided with the same size/stride/offset as input tensor
# This should be a no-op but currently fails
device_mesh = self.build_device_mesh()
placements = [Shard(0)]
local_tensor = torch.randn(3, 4, device=self.device_type)
dtensor = DTensor.from_local(local_tensor, device_mesh, placements)

# Get the current size, stride, and storage_offset
size = dtensor.size()
stride = dtensor.stride()
storage_offset = dtensor.storage_offset()

# Call as_strided with the exact same parameters
result = dtensor.as_strided(size, stride, storage_offset)

# The result should be identical to the input
self.assertEqual(result.size(), dtensor.size())
self.assertEqual(result.stride(), dtensor.stride())
self.assertEqual(result.to_local(), dtensor.to_local())

DTensorMeshTestWithLocalTensor = create_local_tensor_test_class(
DTensorMeshTest,
@@ -4,7 +4,6 @@ import os
import tempfile
from threading import Event

import torch._inductor.config as config
from torch._inductor.compile_worker.subproc_pool import (
raise_testexc,
SubprocException,

@@ -17,12 +16,9 @@ from torch.testing._internal.inductor_utils import HAS_CPU

class TestCompileWorker(TestCase):
def make_pool(self, size):
return SubprocPool(size)

@skipIfWindows(msg="pass_fds not supported on Windows.")
def test_basic_jobs(self):
pool = self.make_pool(2)
pool = SubprocPool(2)
try:
a = pool.submit(operator.add, 100, 1)
b = pool.submit(operator.sub, 100, 1)

@@ -33,7 +29,7 @@ class TestCompileWorker(TestCase):

@skipIfWindows(msg="pass_fds not supported on Windows.")
def test_exception(self):
pool = self.make_pool(2)
pool = SubprocPool(2)
try:
a = pool.submit(raise_testexc)
with self.assertRaisesRegex(

@@ -46,7 +42,7 @@ class TestCompileWorker(TestCase):

@skipIfWindows(msg="pass_fds not supported on Windows.")
def test_crash(self):
pool = self.make_pool(2)
pool = SubprocPool(2)
try:
with self.assertRaises(Exception):
a = pool.submit(os._exit, 1)

@@ -62,7 +58,7 @@ class TestCompileWorker(TestCase):

@skipIfWindows(msg="pass_fds not supported on Windows.")
def test_quiesce(self):
pool = self.make_pool(2)
pool = SubprocPool(2)
try:
a = pool.submit(operator.add, 100, 1)
pool.quiesce()

@@ -79,7 +75,7 @@ class TestCompileWorker(TestCase):
os.environ["ROLE_RANK"] = "0"
with tempfile.NamedTemporaryFile(delete=True) as temp_log:
os.environ["TORCHINDUCTOR_WORKER_LOGPATH"] = temp_log.name
pool = self.make_pool(2)
pool = SubprocPool(2)
try:
pool.submit(operator.add, 100, 1)
self.assertEqual(os.path.exists(temp_log.name), True)

@@ -87,12 +83,6 @@ class TestCompileWorker(TestCase):
pool.shutdown()

@config.patch("quiesce_async_compile_time", 0.1)
class TestCompileWorkerWithTimer(TestCompileWorker):
def make_pool(self, size):
return SubprocPool(size, quiesce=True)

class TestTimer(TestCase):
def test_basics(self):
done = Event()
@@ -15,8 +15,9 @@ from torch.testing._internal.common_utils import (
is_navi3_arch,
parametrize,
patch_test_members,
TEST_XPU,
)
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU_AND_TRITON
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CUDA_AND_TRITON
from torch.testing._internal.triton_utils import requires_gpu

@@ -60,6 +61,11 @@ class TestDecomposeAddMM(torch.nn.Module):

@requires_gpu
@unittest.skipIf(
TEST_XPU,
"Intel GPU has not enabled decompose_mem_bound_mm PASS in "
"torch/_inductor/fx_passes/decompose_mem_bound_mm.py",
)
@torch._inductor.config.patch(
post_grad_fusion_options={
"decompose_mm_pass": {},

@@ -138,7 +144,7 @@ class TestDecomposeMemMM(TestCase):

self.compare_pred(module, traced, input)

expected_val = 1 if should_decompose and HAS_GPU_AND_TRITON else 0
expected_val = 1 if should_decompose and HAS_CUDA_AND_TRITON else 0
self.assertEqual(
counters["inductor"]["decompose_bmm"],
expected_val,

@@ -149,7 +155,7 @@ class TestDecomposeMemMM(TestCase):
self.compare_parameters(module, traced)
self.compare_gradients(module, traced)

expected_val = 3 if should_decompose and HAS_GPU_AND_TRITON else 0
expected_val = 3 if should_decompose and HAS_CUDA_AND_TRITON else 0
self.assertEqual(
counters["inductor"]["decompose_bmm"],
expected_val,

@@ -198,7 +204,7 @@ class TestDecomposeMemMM(TestCase):

self.compare_pred(module, traced, input)

expected_val = 1 if should_decompose and HAS_GPU_AND_TRITON else 0
expected_val = 1 if should_decompose and HAS_CUDA_AND_TRITON else 0
if has_bias:
self.assertEqual(
counters["inductor"]["decompose_addmm"],

@@ -253,7 +259,7 @@ class TestDecomposeMemMM(TestCase):

self.compare_pred(module, traced, input)

expected_val = 1 if should_decompose and HAS_GPU_AND_TRITON else 0
expected_val = 1 if should_decompose and HAS_CUDA_AND_TRITON else 0
if has_bias:
self.assertEqual(
counters["inductor"]["decompose_addmm"],

@@ -298,7 +304,7 @@ class TestDecomposeMemMM(TestCase):

self.compare_pred(module, traced, input)

expected_val = 1 if should_decompose and HAS_GPU_AND_TRITON else 0
expected_val = 1 if should_decompose and HAS_CUDA_AND_TRITON else 0
self.assertEqual(
counters["inductor"]["decompose_mm"],
expected_val,

@@ -310,7 +316,7 @@ class TestDecomposeMemMM(TestCase):
self.compare_parameters(module, traced)
self.compare_gradients(module, traced)

expected_val = 1 if should_decompose and HAS_GPU_AND_TRITON else 0
expected_val = 1 if should_decompose and HAS_CUDA_AND_TRITON else 0
self.assertEqual(
counters["inductor"]["decompose_mm"] - decompose_mm_fwd,
expected_val,

@@ -368,7 +374,7 @@ class TestDecomposeMemMM(TestCase):

self.compare_pred(module, traced, input)

expected_val = 1 if should_decompose and HAS_GPU_AND_TRITON else 0
expected_val = 1 if should_decompose and HAS_CUDA_AND_TRITON else 0
self.assertEqual(
counters["inductor"]["decompose_mm"],
expected_val,

@@ -380,7 +386,7 @@ class TestDecomposeMemMM(TestCase):
self.compare_parameters(module, traced)
self.compare_gradients(module, traced)

expected_val = 1 if should_decompose and HAS_GPU_AND_TRITON else 0
expected_val = 1 if should_decompose and HAS_CUDA_AND_TRITON else 0
self.assertEqual(
counters["inductor"]["decompose_mm"] - decompose_mm_fwd,
expected_val,

@@ -404,7 +410,7 @@ class TestDecomposeMemMM(TestCase):

self.compare_pred(module, traced, input)

expected_val = 1 if should_decompose and HAS_GPU_AND_TRITON else 0
expected_val = 1 if should_decompose and HAS_CUDA_AND_TRITON else 0
if has_bias:
self.assertEqual(
counters["inductor"]["decompose_addmm"],

@@ -418,7 +424,7 @@ class TestDecomposeMemMM(TestCase):
self.compare_gradients(module, traced)

expected_val = 0
if HAS_GPU_AND_TRITON:
if HAS_CUDA_AND_TRITON:
expected_val = 1 if has_bias else 2

self.assertEqual(

@@ -441,8 +447,12 @@ class TestDecomposeMemMM(TestCase):

_, code = run_and_get_code(foo, input1, input2)

# two kernels generated
FileCheck().check_count(".run(", 2, exactly=True).run(code[0])
if GPU_TYPE == "xpu":
# only 1 kernel generated on the XPU stack
FileCheck().check_count(".run(", 1, exactly=True).run(code[0])
else:
# two kernels generated
FileCheck().check_count(".run(", 2, exactly=True).run(code[0])

def test_check_device(self):
m = 5

@@ -452,7 +462,7 @@ class TestDecomposeMemMM(TestCase):

input1 = torch.randn(m, k, device=GPU_TYPE)
input2 = torch.randn(k, n, device=GPU_TYPE)
self.assertTrue(check_device(input1, input2, device=GPU_TYPE))
self.assertTrue(check_device(input1, input2))
self.assertFalse(check_device(input1, input2, device="cpu"))

input1 = torch.randn(m, k)
@@ -500,13 +500,8 @@ class PaddingTest(TestCaseBase):
forward_wrapper = wrapper_codes[0]

# make sure the load for softmax is aligned
if bias:
# addmm -> mm + bias and bias is fused with softmax
softmax_load_str = "tl.load(in_out_ptr0 + (r0_1 + 30528*x0)"
else:
softmax_load_str = "tl.load(in_ptr0 + (r0_1 + 30528*x0)"
self.assertTrue(
softmax_load_str in forward_wrapper,
"tl.load(in_ptr0 + (r0_1 + 30528*x0)" in forward_wrapper,
f"forward_wrapper: {forward_wrapper}",
)
@@ -15280,7 +15280,7 @@ if RUN_GPU:
),
(
fn3,
"triton_poi_fused_addmm_native_layer_norm",
"triton_poi_fused_native_layer_norm_relu",
(torch.randn(4, 4, device=GPU_TYPE),),
),
]

@@ -15293,7 +15293,7 @@ if RUN_GPU:
),
(
fn3,
"triton_poi_fused_LayerNorm_Linear_ReLU",
"triton_poi_fused_LayerNorm_ReLU",
(torch.randn(4, 4, device=GPU_TYPE),),
),
]
@ -1826,14 +1826,9 @@ def run_test_module(
|
||||
test_name = test.name
|
||||
|
||||
# Printing the date here can help diagnose which tests are slow
|
||||
start = time.perf_counter()
|
||||
print_to_stderr(f"Running {str(test)} ... [{datetime.now()}][{start}]")
|
||||
print_to_stderr(f"Running {str(test)} ... [{datetime.now()}]")
|
||||
handler = CUSTOM_HANDLERS.get(test_name, run_test)
|
||||
return_code = handler(test, test_directory, options)
|
||||
end = time.perf_counter()
|
||||
print_to_stderr(
|
||||
f"Finished {str(test)} ... [{datetime.now()}][{end}], took {(end - start) / 60:.2f}min"
|
||||
)
|
||||
assert isinstance(return_code, int) and not isinstance(return_code, bool), (
|
||||
f"While running {str(test)} got non integer return code {return_code}"
|
||||
)
|
||||
|
||||
@ -7413,140 +7413,6 @@ class TestCudaDeviceParametrized(TestCase):
|
||||
)
|
||||
|
||||
|
||||
class TestFXMemoryProfiler(TestCase):
|
||||
"""Tests for memory profiler augmentation with original stack traces."""
|
||||
|
||||
def collect_frames(
|
||||
self, augmented_snapshot, collect_device_traces=True, collect_segments=True
|
||||
):
|
||||
"""Collects all frames that has node metadata from a memory snapshot."""
|
||||
# Collect all frames with FX metadata
|
||||
fx_frames = []
|
||||
|
||||
# Check device traces for FX debug fields
|
||||
if collect_device_traces and "device_traces" in augmented_snapshot:
|
||||
for trace_list in augmented_snapshot["device_traces"]:
|
||||
for trace_entry in trace_list:
|
||||
if isinstance(trace_entry, dict) and "frames" in trace_entry:
|
||||
for frame in trace_entry["frames"]:
|
||||
if isinstance(frame, dict):
|
||||
# Check for FX debug fields
|
||||
if "fx_node_op" in frame or "fx_node_name" in frame:
|
||||
fx_frames.append(frame)
|
||||
|
||||
# Check segments/blocks for FX debug fields
|
||||
if collect_segments and "segments" in augmented_snapshot:
|
||||
for segment in augmented_snapshot["segments"]:
|
||||
if "blocks" in segment:
|
||||
for block in segment["blocks"]:
|
||||
if "frames" in block:
|
||||
for frame in block["frames"]:
|
||||
if isinstance(frame, dict):
|
||||
if "fx_node_op" in frame or "fx_node_name" in frame:
|
||||
fx_frames.append(frame)
|
||||
return fx_frames
|
||||
|
||||
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
|
||||
@torch._dynamo.config.patch("enrich_profiler_metadata", True)
|
||||
def test_fx_memory_profiler_augmentation(self):
|
||||
"""Test that memory snapshots are augmented with FX debug information."""
|
||||
|
||||
# Create a simple model
|
||||
class MLPModule(nn.Module):
|
||||
def __init__(self, device):
|
||||
super().__init__()
|
||||
torch.manual_seed(5)
|
||||
self.net1 = nn.Linear(10, 16, bias=True, device=device)
|
||||
self.relu = nn.ReLU()
|
||||
self.net2 = nn.Linear(16, 10, bias=True, device=device)
|
||||
|
||||
def forward(self, x):
|
||||
a = self.net1(x)
|
||||
b = self.relu(a)
|
||||
c = self.net2(b)
|
||||
return c
|
||||
|
||||
device = "cuda"
|
||||
mod = MLPModule(device)
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
torch.cuda.memory._record_memory_history()
|
||||
compiled = torch.compile(mod, backend="aot_eager", fullgraph=True)
|
||||
result = compiled(torch.randn(10, 10, device=device))
|
||||
augmented_snapshot = torch.cuda.memory._snapshot(
|
||||
augment_with_fx_traces=True
|
||||
)
|
||||
torch.cuda.memory._record_memory_history(enabled=None, clear_history=True)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
fx_frames = self.collect_frames(augmented_snapshot)
|
||||
if TEST_WITH_ROCM:
|
||||
self.assertGreater(len(fx_frames), 0)
|
||||
else:
|
||||
self.assertEqual(len(fx_frames), 12)
|
||||
|
||||
for frame in fx_frames:
|
||||
# Every FX frame should have both node_op and node_name
|
||||
self.assertIn("fx_node_op", frame)
|
||||
self.assertIn("fx_node_name", frame)
|
||||
self.assertIn("fx_node_target", frame)
|
||||
self.assertIn("fx_original_trace", frame)
|
||||
|
||||
self.assertIn(frame["fx_node_name"], ["addmm", "relu", "addmm_1"])
|
||||
fx_node_name = frame["fx_node_name"]
|
||||
if fx_node_name == "addmm":
|
||||
self.assertIn("a = self.net1(x)", frame["fx_original_trace"])
|
||||
elif fx_node_name == "addmm_1":
|
||||
self.assertIn("c = self.net2(b)", frame["fx_original_trace"])
|
||||
elif fx_node_name == "relu":
|
||||
self.assertIn("b = self.relu(a)", frame["fx_original_trace"])
|
||||
|
||||
# Test that when we have two graphs with the same src_code, they're not hashed
|
||||
# to the same metadata
|
||||
class MLPModule2(nn.Module):
|
||||
def __init__(self, device):
|
||||
super().__init__()
|
||||
torch.manual_seed(5)
|
||||
self.net1 = nn.Linear(10, 16, bias=True, device=device)
|
||||
self.relu = nn.ReLU()
|
||||
self.net2 = nn.Linear(16, 10, bias=True, device=device)
|
||||
|
||||
def forward(self, x):
|
||||
d = self.net1(x)
|
||||
e = self.relu(d)
|
||||
f = self.net2(e)
|
||||
return f
|
||||
|
||||
mod = MLPModule2(device)
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
torch.cuda.memory._record_memory_history()
|
||||
compiled = torch.compile(mod, backend="aot_eager", fullgraph=True)
|
||||
result = compiled(torch.randn(10, 10, device=device))
|
||||
augmented_snapshot = torch.cuda.memory._snapshot(
|
||||
augment_with_fx_traces=True
|
||||
)
|
||||
torch.cuda.memory._record_memory_history(enabled=None, clear_history=True)
|
||||
|
||||
# avoid collecting segments from previous run for unit test purpose
|
||||
fx_frames = self.collect_frames(augmented_snapshot, collect_segments=False)
|
||||
self.assertGreater(len(fx_frames), 0)
|
||||
|
||||
for frame in fx_frames:
|
||||
# Every FX frame should have both node_op and node_name
|
||||
self.assertIn("fx_node_op", frame)
|
||||
self.assertIn("fx_node_name", frame)
|
||||
self.assertIn("fx_node_target", frame)
|
||||
self.assertIn("fx_original_trace", frame)
|
||||
|
||||
self.assertIn(frame["fx_node_name"], ["addmm", "relu", "addmm_1"])
|
||||
fx_node_name = frame["fx_node_name"]
|
||||
if fx_node_name == "addmm":
|
||||
self.assertIn("d = self.net1(x)", frame["fx_original_trace"])
|
||||
elif fx_node_name == "addmm_1":
|
||||
self.assertIn("f = self.net2(e)", frame["fx_original_trace"])
|
||||
elif fx_node_name == "relu":
|
||||
self.assertIn("e = self.relu(d)", frame["fx_original_trace"])
|
||||
|
||||
|
||||
instantiate_parametrized_tests(TestCuda)
|
||||
instantiate_parametrized_tests(TestCudaMallocAsync)
|
||||
instantiate_parametrized_tests(TestCompileKernel)
|
||||
|
||||
@ -771,7 +771,6 @@ class TestFX(JitTestCase):
|
||||
gm = GraphModule(tracer.root, graph)
|
||||
expected = {1: 2, 2: 3, 3: 4, 4: 5}
|
||||
self.assertTrue(set(expected.items()).issubset(set(gm._lineno_map.items())))
|
||||
self.assertEqual(gm._prologue_start, 4)
|
||||
|
||||
# test custom codegen
|
||||
def transform_code(code):
|
||||
@ -781,7 +780,6 @@ class TestFX(JitTestCase):
|
||||
gm.recompile()
|
||||
expected = {2: 2, 3: 3, 4: 4, 5: 5}
|
||||
self.assertTrue(set(expected.items()).issubset(set(gm._lineno_map.items())))
|
||||
self.assertEqual(gm._prologue_start, 4)
|
||||
|
||||
def test_graph_unique_names_manual(self):
|
||||
graph: torch.fx.Graph = torch.fx.Graph()
|
||||
|
||||
@ -359,6 +359,29 @@ class TestMatmulCuda(InductorTestCase):
|
||||
self.assertEqual(agrad, a.grad)
|
||||
self.assertEqual(bgrad, b.grad)
|
||||
|
||||
@onlyCUDA
|
||||
@skipIfRocm
|
||||
@dtypes(torch.half, torch.bfloat16)
|
||||
@unittest.skipIf(not SM100OrLater, "cuBLAS integration for batch invariance is only on Blackwell")
|
||||
@serialTest()
|
||||
def test_cublas_batch_invariance_blackwell(self, device, dtype):
|
||||
orig_bf16 = torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction
|
||||
orig_fp16 = torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction
|
||||
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = (False, False)
|
||||
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = (False, False)
|
||||
with blas_library_context('cublaslt'):
|
||||
N = 2048
|
||||
K = 6144
|
||||
M_max = 32
|
||||
x = torch.randn(M_max, K, device="cuda", dtype=torch.bfloat16)
|
||||
w = torch.randn(N, K, device="cuda", dtype=torch.bfloat16).t()
|
||||
full = x @ w
|
||||
xx = x[:1]
|
||||
out = xx @ w
|
||||
self.assertEqual(full[:1], out, atol=0., rtol=0.)
|
||||
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = orig_bf16
|
||||
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = orig_fp16
|
||||
|
||||
@unittest.skipIf(not SM80OrLater, "Grouped gemm supported only on SM80 or greater")
|
||||
@parametrize("strided", [False, True])
|
||||
@parametrize("a_row_major", [False, True])
|
||||
|
||||
@ -11,7 +11,7 @@ from typing import Optional
|
||||
import torch
|
||||
|
||||
|
||||
from torch.nn.functional import pad, scaled_mm, scaled_grouped_mm, ScalingType, SwizzleType
|
||||
from torch.nn.functional import scaled_mm, scaled_grouped_mm, ScalingType, SwizzleType
|
||||
from torch.testing._internal.common_cuda import (
|
||||
IS_SM90,
|
||||
_get_torch_cuda_version,
|
||||
@ -107,76 +107,11 @@ def tensor_to_scale_block(
|
||||
x = x.unflatten(1, (-1, block_inner)).unflatten(0, (-1, block_outer))
|
||||
amax = x.abs().amax(dim=[1, 3], keepdim=True).float()
|
||||
scale = torch.finfo(float8_dtype).max / amax
|
||||
# if amax == 0, entire block = 0, set scale = 0 to ensure elements are
|
||||
# zero'd out correctly (and remove bad effects of / 0)
|
||||
scale[amax == 0] = 0
|
||||
|
||||
# Scale x, noting that blocks where amax == 0 are explicitly 0 now.
|
||||
x = x.mul(scale).to(float8_dtype)
|
||||
|
||||
# if amax == 0, all values in the block are 0, scale=0
|
||||
# but we need scale.reciprocal later, which breaks when scale=0...
|
||||
# So. Replace 0 -> 1 in the scale so we don't break things later.
|
||||
# Elements are already zeroed, so don't actually care what the scale
|
||||
# is, as long as it's not inf/nan.
|
||||
scale[scale == 0] = 1.
|
||||
|
||||
x = x.flatten(2, 3).flatten(0, 1)
|
||||
scale = scale.flatten(2, 3).flatten(0, 1)
|
||||
return x, scale
|
||||
|
||||
def hp_from_128x128(x_lp, x_scale):
|
||||
orig_shape = x_lp.shape
|
||||
M, K = orig_shape
|
||||
x_lp = x_lp.view(M // 128, 128, K // 128, 128)
|
||||
x_scale = x_scale.unsqueeze(1).unsqueeze(-1)
|
||||
x_hp = x_lp.to(torch.float32)
|
||||
x_hp = x_hp / x_scale
|
||||
return x_hp.reshape(orig_shape).to(torch.bfloat16)
|
||||
|
||||
def hp_to_128x128(x, x_scale):
|
||||
orig_shape = x.shape
|
||||
M, K = orig_shape
|
||||
x = x.view(M // 128, 128, K // 128, 128)
|
||||
x_scale = x_scale.unsqueeze(1).unsqueeze(-1)
|
||||
x_lp = x * x_scale
|
||||
|
||||
return x_lp.reshape(orig_shape).to(torch.float8_e4m3fn)
|
||||
|
||||
def hp_from_1x128(x_lp, x_scale):
|
||||
orig_shape = x_lp.shape
|
||||
x_lp = x_lp.reshape(x_lp.shape[0], x_lp.shape[-1] // 128, 128)
|
||||
x_hp = x_lp.to(torch.float32)
|
||||
x_hp = x_hp / x_scale.unsqueeze(-1)
|
||||
return x_hp.reshape(orig_shape).to(torch.bfloat16)
|
||||
|
||||
def hp_to_1x128(x, x_scale):
|
||||
orig_shape = x.shape
|
||||
x = x.reshape(x.shape[0], x.shape[-1] // 128, 128)
|
||||
x_lp = x * x_scale.unsqueeze(-1)
|
||||
return x_lp.reshape(orig_shape).to(torch.float8_e4m3fn)
|
||||
|
||||
|
||||
# cublas requires specific padding for 128x128 scales, see:
|
||||
# https://docs.nvidia.com/cuda/cublas/#element-1d-and-128x128-2d-block-scaling-for-fp8-data-types
|
||||
# Notably L = ceil_div(K, 128),
|
||||
# L4 = round_up(L, 4),
|
||||
# and then for A/B the shape must be
|
||||
# scale: [L4, ceil_div({M,N}, 128) and K/L/L4-major in memory.
|
||||
#
|
||||
# This routine pads L -> L4
|
||||
def _pad_128x128_scales(scale: torch.Tensor) -> (torch.Tensor, int):
|
||||
# scale is either [L4, ceil_div(M, 128)] or [L4, ceil_div(N, 128)], stride: [1, L4]
|
||||
# However, we get passed it as [ceil_div(M, 128), L] or [ceil_div(N, 128), L]
|
||||
# so check inner dim % 4, and pad if necessary
|
||||
pad_amount = scale.shape[-1] % 4
|
||||
|
||||
if pad_amount == 0:
|
||||
return scale, 0
|
||||
else:
|
||||
pad_amount = 4 - pad_amount
|
||||
return pad(scale, (0, pad_amount), "constant", 0), pad_amount
|
||||
|
||||
|
||||
def round_up(x: int, y: int) -> int:
|
||||
return ((x + y - 1) // y) * y
|
||||
@ -209,36 +144,42 @@ def infer_scale_swizzle(mat, scale):
|
||||
] == math.ceil(mat.shape[1] // 128):
|
||||
return ScalingType.BlockWise128x128, SwizzleType.NO_SWIZZLE
|
||||
|
||||
# if we're checking for nvfp4, need to adjust for packed-K
|
||||
K_multiplier = 2 if mat.dtype == torch.float4_e2m1fn_x2 else 1
|
||||
# NVFP4
|
||||
if (
|
||||
(scale.numel()
|
||||
== round_up(mat.shape[0], 128) * round_up(math.ceil(K_multiplier * mat.shape[1] // 16), 4)
|
||||
== round_up(mat.shape[0], 128) * round_up(math.ceil(2 * mat.shape[1] // 16), 4)
|
||||
or scale.numel()
|
||||
== round_up(mat.shape[1], 128) * round_up(math.ceil(K_multiplier * mat.shape[0] // 16), 4))
|
||||
== round_up(mat.shape[1], 128) * round_up(math.ceil(2 * mat.shape[0] // 16), 4))
|
||||
and mat.dtype == torch.float4_e2m1fn_x2
|
||||
and scale.dtype == torch.float8_e4m3fn
|
||||
):
|
||||
return ScalingType.BlockWise1x16, SwizzleType.SWIZZLE_32_4_4
|
||||
|
||||
# MX formats
|
||||
# MXFP4 w/o swizzle
|
||||
if (
|
||||
(scale.numel() == 2 * math.ceil(mat.shape[0] // 32) * mat.shape[1]
|
||||
or scale.numel() == 2 * math.ceil(mat.shape[1] // 32) * mat.shape[0])
|
||||
and mat.dtype == torch.float4_e2m1fn_x2
|
||||
and scale.dtype == torch.float8_e8m0fnu
|
||||
):
|
||||
return ScalingType.BlockWise1x32, SwizzleType.NO_SWIZZLE
|
||||
|
||||
if not torch.version.hip:
|
||||
# MX w/swizzle (NVIDIA)
|
||||
# MXFP8 w/ swizzle
|
||||
if (
|
||||
(scale.numel()
|
||||
== round_up(mat.shape[0], 128) * round_up(math.ceil(K_multiplier * mat.shape[1] // 32), 4)
|
||||
== round_up(mat.shape[0], 128) * round_up(math.ceil(mat.shape[1] // 32), 4)
|
||||
or scale.numel()
|
||||
== round_up(mat.shape[1], 128) * round_up(math.ceil(K_multiplier * mat.shape[0] // 32), 4))
|
||||
== round_up(mat.shape[1], 128) * round_up(math.ceil(mat.shape[0] // 32), 4))
|
||||
and scale.dtype == torch.float8_e8m0fnu
|
||||
):
|
||||
return ScalingType.BlockWise1x32, SwizzleType.SWIZZLE_32_4_4
|
||||
|
||||
else:
|
||||
# MX w/o swizzle (AMD)
|
||||
# MXFP8 w/o swizzle
|
||||
if (
|
||||
(scale.numel() == math.ceil(mat.shape[0] // 32) * K_multiplier * mat.shape[1]
|
||||
or scale.numel() == math.ceil(K_multiplier * mat.shape[1] // 32) * mat.shape[0])
|
||||
(scale.numel() == math.ceil(mat.shape[0] // 32) * mat.shape[1]
|
||||
or scale.numel() == math.ceil(mat.shape[1] // 32) * mat.shape[0])
|
||||
and scale.dtype == torch.float8_e8m0fnu
|
||||
):
|
||||
return ScalingType.BlockWise1x32, SwizzleType.NO_SWIZZLE
|
||||
@ -1311,6 +1252,7 @@ class TestFP8Matmul(TestCase):
|
||||
else:
|
||||
test()
|
||||
|
||||
# Note: Removed parameterization over M,N,K from #163829 as it failed tests as-is
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
|
||||
@unittest.skipIf(not IS_SM90, "cuBLAS blockwise scaling requires sm90+")
|
||||
@unittest.skipIf(
|
||||
@ -1319,224 +1261,59 @@ class TestFP8Matmul(TestCase):
|
||||
)
|
||||
@parametrize("output_dtype", [torch.bfloat16, torch.float32])
|
||||
@parametrize("lhs_block,rhs_block", [(1, 1), (128, 1), (1, 128)])
|
||||
@parametrize("M,N,K", [
|
||||
# Nice size
|
||||
(256, 768, 512),
|
||||
# Requires padding for 128x128 scale
|
||||
(384, 128, 1280),
|
||||
# M=N=K for eyes test
|
||||
(512, 512, 512),
|
||||
])
|
||||
@parametrize("test_case", [
|
||||
"x_eye_b_eye",
|
||||
"x_ones_y_ones_calc_scales",
|
||||
"x_ones_y_ones_set_scales",
|
||||
"x_ones_y_ones_modify_scales",
|
||||
"data_random_scales_one",
|
||||
"data_random_calc_scales",
|
||||
])
|
||||
def test_scaled_mm_block_wise_numerics(self, output_dtype, lhs_block, rhs_block, M, N, K, test_case):
|
||||
"""
|
||||
subsume test_scaled_mm_vs_emulated_block_wise for random inputs, random scales,
|
||||
do some other functional tests as well.
|
||||
|
||||
# Inputs (as generated are):
|
||||
# A: [M, K]
|
||||
# B: [N, K]
|
||||
# then scales are, for the 3 combinations:
|
||||
# 1x128 x 1x128:
|
||||
# As: [M, K // 128], stride: [1, M] -> scale.t().contiguous().t()
|
||||
# Bs: [N, K // 128], stride: [1, N] -> scale.t().contiguous().t()
|
||||
# 1x128 x 128x128
|
||||
# L4 = round_up(K // 128, 4)
|
||||
# As: [M, K // 128], stride: [1, M] -> scale.t().contiguous().t()
|
||||
# Bs: [L4, N // 128], stride: [1, L4] -> scale.t()
|
||||
# 128x128 x 1x128
|
||||
# L4 = round_up(K // 128, 4)
|
||||
# As: [L4, M // 128], stride: [1, L4]
|
||||
# Bs: [N, K // 128], stride: [1, N]
|
||||
"""
|
||||
@parametrize("M,N,K", [(256, 768, 512)])
|
||||
@with_tf32_off
|
||||
def test_scaled_mm_vs_emulated_block_wise(self, output_dtype, lhs_block, rhs_block, M, N, K):
|
||||
torch.manual_seed(42)
|
||||
|
||||
def _adjust_lhs_scale(x_fp8, x_scales, lhs_block):
|
||||
M, K = x_fp8.shape
|
||||
x_scales_original = x_scales.clone()
|
||||
# 1x128 blocks need scales to be outer-dim-major
|
||||
if lhs_block == 1:
|
||||
x_scales = x_scales.t().contiguous().t()
|
||||
lhs_recipe = ScalingType.BlockWise1x128
|
||||
assert (x_scales.shape[0] == M and x_scales.shape[1] == K // 128), f"{x_scales.shape=}"
|
||||
assert (x_scales.stride(0) == 1 and x_scales.stride(1) in [1, M]), f"{x_scales.stride=}"
|
||||
x_hp = hp_from_1x128(x_fp8, x_scales_original)
|
||||
else:
|
||||
lhs_recipe = ScalingType.BlockWise128x128
|
||||
x_scales, pad_amount = _pad_128x128_scales(x_scales)
|
||||
# scales in [M // 128, L4] -> [L4, M // 128]
|
||||
x_scales = x_scales.t()
|
||||
x_hp = hp_from_128x128(x_fp8, x_scales_original)
|
||||
x = torch.randn(M, K, device="cuda", dtype=output_dtype).pow(3)
|
||||
y = torch.randn(N, K, device="cuda", dtype=output_dtype).pow(3)
|
||||
|
||||
return x_hp, lhs_recipe, x_scales, x_scales_original
|
||||
x_fp8, x_scales = tensor_to_scale_block(x, e4m3_type, lhs_block, 128)
|
||||
y_fp8, y_scales = tensor_to_scale_block(y, e4m3_type, rhs_block, 128)
|
||||
|
||||
def _adjust_rhs_scale(y_fp8, y_scales, rhs_block):
|
||||
N, K = y_fp8.shape
|
||||
y_scales_original = y_scales.clone()
|
||||
|
||||
if rhs_block == 1:
|
||||
y_scales = y_scales.t().contiguous().t()
|
||||
rhs_recipe = ScalingType.BlockWise1x128
|
||||
assert (y_scales.shape[0] == N and y_scales.shape[1] == K // 128), f"{y_scales.shape=}"
|
||||
assert (y_scales.stride(0) == 1 and y_scales.stride(1) in [1, N]), f"{y_scales.stride=}"
|
||||
y_hp = hp_from_1x128(y_fp8, y_scales_original)
|
||||
else:
|
||||
rhs_recipe = ScalingType.BlockWise128x128
|
||||
y_scales, pad_amount = _pad_128x128_scales(y_scales)
|
||||
# Scale in [N // 128, L4] -> [L4, N // 128]
|
||||
y_scales = y_scales.t()
|
||||
y_hp = hp_from_128x128(y_fp8, y_scales_original)
|
||||
|
||||
return y_hp, rhs_recipe, y_scales, y_scales_original
|
||||
|
||||
def _build_lhs(x, lhs_block):
|
||||
M, K = x.shape
|
||||
|
||||
x_fp8, x_scales = tensor_to_scale_block(x, e4m3_type, lhs_block, 128)
|
||||
x_scales_original = x_scales
|
||||
|
||||
x_hp, x_recipe, x_scales, x_scales_original = _adjust_lhs_scale(x_fp8, x_scales, lhs_block)
|
||||
|
||||
return x_hp, x_recipe, x_fp8, x_scales, x_scales_original
|
||||
|
||||
def _build_rhs(y, rhs_block):
|
||||
N, K = y.shape
|
||||
|
||||
y_fp8, y_scales = tensor_to_scale_block(y, e4m3_type, rhs_block, 128)
|
||||
y_hp, y_recipe, y_scales, y_scales_original = _adjust_rhs_scale(y_fp8, y_scales, rhs_block)
|
||||
|
||||
return y_hp, y_recipe, y_fp8, y_scales, y_scales_original
|
||||
|
||||
def _run_test(x_hp, x_recipe, x_fp8, x_scales, x_scales_original,
|
||||
y_hp, y_recipe, y_fp8, y_scales, y_scales_original):
|
||||
|
||||
# Calculate actual F8 mm
|
||||
out_scaled_mm = scaled_mm_wrap(
|
||||
x_fp8,
|
||||
y_fp8.t(),
|
||||
scale_a=x_scales.reciprocal(),
|
||||
scale_recipe_a=x_recipe,
|
||||
# Note: No more .t() on scale_b, not necessary.
|
||||
scale_b=y_scales.reciprocal(),
|
||||
scale_recipe_b=y_recipe,
|
||||
out_dtype=output_dtype,
|
||||
)
|
||||
|
||||
# Calculate emulated F8 mm
|
||||
out_emulated = mm_float8_emulated_block(
|
||||
x_fp8,
|
||||
x_scales_original,
|
||||
y_fp8.t(),
|
||||
y_scales_original.t(),
|
||||
output_dtype
|
||||
)
|
||||
|
||||
cosine_sim = torch.nn.functional.cosine_similarity(
|
||||
out_emulated.flatten().float(), (x @ y.t()).flatten().float(), dim=0
|
||||
)
|
||||
self.assertGreaterEqual(float(cosine_sim), 0.999)
|
||||
|
||||
cosine_sim = torch.nn.functional.cosine_similarity(
|
||||
out_scaled_mm.flatten().float(), out_emulated.flatten().float(), dim=0
|
||||
)
|
||||
self.assertGreaterEqual(float(cosine_sim), 0.999)
|
||||
|
||||
if output_dtype in {torch.bfloat16, torch.float16}:
|
||||
atol, rtol = 6e-1, 7e-2
|
||||
else:
|
||||
atol, rtol = 7e-1, 2e-3
|
||||
|
||||
self.assertEqual(out_scaled_mm, out_emulated.to(output_dtype), atol=atol, rtol=rtol)
|
||||
|
||||
# One last check against the full-precision reference, to ensure we
|
||||
# didn't mess up the scaling itself and made the test trivial.
|
||||
cosine_sim = torch.nn.functional.cosine_similarity(
|
||||
out_scaled_mm.flatten().float(), (x @ y.t()).flatten().float(), dim=0
|
||||
)
|
||||
self.assertGreaterEqual(float(cosine_sim), 0.999)
|
||||
|
||||
def _build_constant_scale(t, block, val):
|
||||
M, K = t.shape
|
||||
|
||||
if block == 1:
|
||||
scale_shape = M, K // 128
|
||||
else:
|
||||
scale_shape = M // 128, K // 128
|
||||
|
||||
scale = torch.full(scale_shape, val, device='cuda')
|
||||
|
||||
return scale
|
||||
|
||||
def hp_to_scaled(t, scale, block):
|
||||
if block == 1:
|
||||
return hp_to_1x128(t, scale)
|
||||
else:
|
||||
return hp_to_128x128(t, scale)
|
||||
|
||||
e4m3_type = torch.float8_e4m3fn
|
||||
|
||||
if test_case == "x_eye_b_eye":
|
||||
if M != K or M != N:
|
||||
return unittest.skip("a_eye_b_eye only defined for M = N = K")
|
||||
x = torch.eye(M, device='cuda')
|
||||
y = torch.eye(M, device='cuda')
|
||||
|
||||
x_hp, x_recipe, x_fp8, x_scales, x_scales_original = _build_lhs(x, lhs_block)
|
||||
y_hp, y_recipe, y_fp8, y_scales, y_scales_original = _build_lhs(y, rhs_block)
|
||||
elif test_case == "x_ones_y_ones_calc_scales":
|
||||
x = torch.full((M, K), 1.0, device='cuda')
|
||||
y = torch.full((N, K), 1.0, device='cuda')
|
||||
|
||||
x_hp, x_recipe, x_fp8, x_scales, x_scales_original = _build_lhs(x, lhs_block)
|
||||
y_hp, y_recipe, y_fp8, y_scales, y_scales_original = _build_lhs(y, rhs_block)
|
||||
elif test_case in ["x_ones_y_ones_set_scales", "x_ones_y_ones_modify_scales"]:
|
||||
x = torch.full((M, K), 1.0, device='cuda')
|
||||
y = torch.full((N, K), 1.0, device='cuda')
|
||||
|
||||
x_scales = _build_constant_scale(x, lhs_block, 1.)
|
||||
y_scales = _build_constant_scale(y, rhs_block, 1.)
|
||||
|
||||
if "modify" in test_case:
|
||||
x_scales[0, 0] = 4.
|
||||
y_scales[-1, -1] = 4.
|
||||
|
||||
x_fp8 = hp_to_scaled(x, x_scales, lhs_block)
|
||||
y_fp8 = hp_to_scaled(y, y_scales, rhs_block)
|
||||
|
||||
x_hp, x_recipe, x_scales, x_scales_original = _adjust_lhs_scale(x_fp8, x_scales, lhs_block)
|
||||
y_hp, y_recipe, y_scales, y_scales_original = _adjust_rhs_scale(y_fp8, y_scales, rhs_block)
|
||||
elif test_case == "data_random_scales_one":
|
||||
x = torch.randint(0, 255, (M, K), device='cuda', dtype=torch.uint8).to(torch.bfloat16)
|
||||
y = torch.randint(0, 255, (N, K), device='cuda', dtype=torch.uint8).to(torch.bfloat16)
|
||||
|
||||
x_scales = _build_constant_scale(x, lhs_block, 1.)
|
||||
y_scales = _build_constant_scale(y, rhs_block, 1.)
|
||||
|
||||
x_fp8 = hp_to_scaled(x, x_scales, lhs_block)
|
||||
y_fp8 = hp_to_scaled(y, y_scales, rhs_block)
|
||||
|
||||
x_hp, x_recipe, x_scales, x_scales_original = _adjust_lhs_scale(x_fp8, x_scales, lhs_block)
|
||||
y_hp, y_recipe, y_scales, y_scales_original = _adjust_rhs_scale(y_fp8, y_scales, rhs_block)
|
||||
elif test_case == "data_random_calc_scales":
|
||||
# Note: Old test_scaled_mm_vs_emulated_block_wise test case
|
||||
x = torch.randn(M, K, device="cuda", dtype=output_dtype)
|
||||
y = torch.randn(N, K, device="cuda", dtype=output_dtype) * 1e-3
|
||||
|
||||
x_hp, x_recipe, x_fp8, x_scales, x_scales_original = _build_lhs(x, lhs_block)
|
||||
y_hp, y_recipe, y_fp8, y_scales, y_scales_original = _build_lhs(y, rhs_block)
|
||||
# 1x128 blocks need scales to be outer-dim-major
|
||||
if lhs_block == 1:
|
||||
x_scales = x_scales.t().contiguous().t()
|
||||
lhs_recipe = ScalingType.BlockWise1x128
|
||||
else:
|
||||
raise ValueError("Unknown test-case passed")
|
||||
lhs_recipe = ScalingType.BlockWise128x128
|
||||
if rhs_block == 1:
|
||||
y_scales = y_scales.t().contiguous().t()
|
||||
rhs_recipe = ScalingType.BlockWise1x128
|
||||
else:
|
||||
rhs_recipe = ScalingType.BlockWise128x128
|
||||
|
||||
_run_test(x_hp, x_recipe, x_fp8, x_scales, x_scales_original,
|
||||
y_hp, y_recipe, y_fp8, y_scales, y_scales_original)
|
||||
|
||||
# Calculate actual F8 mm
|
||||
out_scaled_mm = scaled_mm_wrap(
|
||||
x_fp8, y_fp8.t(), scale_a=x_scales.reciprocal(), scale_b=y_scales.reciprocal().t(), out_dtype=output_dtype,
|
||||
scale_recipe_a=lhs_recipe, scale_recipe_b=rhs_recipe
|
||||
)
|
||||
|
||||
# Calculate emulated F8 mm
|
||||
out_emulated = mm_float8_emulated_block(
|
||||
x_fp8, x_scales, y_fp8.t(), y_scales.t(), output_dtype
|
||||
)
|
||||
|
||||
cosine_sim = torch.nn.functional.cosine_similarity(
|
||||
out_scaled_mm.flatten().float(), out_emulated.flatten().float(), dim=0
|
||||
)
|
||||
self.assertGreaterEqual(float(cosine_sim), 0.999)
|
||||
|
||||
if output_dtype in {torch.bfloat16, torch.float16}:
|
||||
atol, rtol = 6e-1, 7e-2
|
||||
else:
|
||||
atol, rtol = 7e-1, 2e-3
|
||||
|
||||
self.assertEqual(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)
|
||||
|
||||
# One last check against the full-precision reference, to ensure we
|
||||
# didn't mess up the scaling itself and made the test trivial.
|
||||
cosine_sim = torch.nn.functional.cosine_similarity(
|
||||
out_scaled_mm.flatten().float(), (x @ y.t()).flatten().float(), dim=0
|
||||
)
|
||||
self.assertGreaterEqual(float(cosine_sim), 0.999)
|
||||
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
|
||||
@unittest.skipIf(not IS_SM90, "cuBLAS blockwise scaling requires sm90+")
|
||||
@ -1558,30 +1335,18 @@ class TestFP8Matmul(TestCase):
|
||||
x_fp8, x_scales = tensor_to_scale_block(x, e4m3_type, lhs_block, 128)
|
||||
y_fp8, y_scales = tensor_to_scale_block(y, e4m3_type, rhs_block, 128)
|
||||
|
||||
x_scales_original = x_scales
|
||||
y_scales_original = y_scales
|
||||
# 1x128 blocks need scales to be outer-dim-major
|
||||
if lhs_block == 1:
|
||||
x_scales = x_scales.t().contiguous().t()
|
||||
lhs_recipe = ScalingType.BlockWise1x128
|
||||
assert (x_scales.shape[0] == M and x_scales.shape[1] == K // 128), f"{x_scales.shape=}"
|
||||
assert (x_scales.stride(0) == 1 and x_scales.stride(1) in [1, M]), f"{x_scales.stride=}"
|
||||
else:
|
||||
lhs_recipe = ScalingType.BlockWise128x128
|
||||
x_scales, pad_amount = _pad_128x128_scales(x_scales)
|
||||
# scales in [M // 128, L4] -> [L4, M // 128]
|
||||
x_scales = x_scales.t()
|
||||
|
||||
if rhs_block == 1:
|
||||
y_scales = y_scales.t().contiguous().t()
|
||||
rhs_recipe = ScalingType.BlockWise1x128
|
||||
assert (y_scales.shape[0] == N and y_scales.shape[1] == K // 128), f"{y_scales.shape=}"
|
||||
assert (y_scales.stride(0) == 1 and y_scales.stride(1) in [1, N]), f"{y_scales.stride=}"
|
||||
else:
|
||||
rhs_recipe = ScalingType.BlockWise128x128
|
||||
y_scales, pad_amount = _pad_128x128_scales(y_scales)
|
||||
# Scale in [N // 128, L4] -> [L4, N // 128]
|
||||
y_scales = y_scales.t()
|
||||
|
||||
# Verify that actual F8 mm doesn't error
|
||||
scaled_mm_wrap(
|
||||
@ -1589,20 +1354,13 @@ class TestFP8Matmul(TestCase):
|
||||
y_fp8.t(),
|
||||
scale_a=x_scales,
|
||||
scale_recipe_a=lhs_recipe,
|
||||
# Note: No more .t() on scale_b, not necessary.
|
||||
scale_b=y_scales,
|
||||
scale_b=y_scales.t(),
|
||||
scale_recipe_b=rhs_recipe,
|
||||
out_dtype=output_dtype,
|
||||
)
|
||||
|
||||
# Verify that emulated F8 mm doesn't error
|
||||
mm_float8_emulated_block(
|
||||
x_fp8,
|
||||
x_scales_original,
|
||||
y_fp8.t(),
|
||||
y_scales_original.t(),
|
||||
output_dtype
|
||||
)
|
||||
mm_float8_emulated_block(x_fp8, x_scales, y_fp8.t(), y_scales.t(), output_dtype)
|
||||
|
||||
@skipIfRocm
|
||||
@onlyCUDA
|
||||
@ -1862,7 +1620,7 @@ class TestFP8Matmul(TestCase):
|
||||
(127, 96, 1024),
|
||||
(1025, 128, 96)
|
||||
], name_fn=lambda mkn: f"{mkn[0]}_{mkn[1]}_{mkn[2]}")
|
||||
@parametrize("recipe", ["mxfp8", "mxfp4", "nvfp4"])
|
||||
@parametrize("recipe", ["mxfp8", "mxfp4" if torch.version.hip else "nvfp4"])
|
||||
def test_blockwise_mxfp8_nvfp4_mxfp4_numerics(self, test_case_name, fast_accum, mkn, recipe) -> None:
|
||||
if (recipe == "nvfp4" or recipe == "mxfp4") and fast_accum:
|
||||
raise unittest.SkipTest("fast_accum not supported in nvfp4/mxfp4 cublas gemm, skipping")
|
||||
@ -1876,12 +1634,8 @@ class TestFP8Matmul(TestCase):
|
||||
if not (M % 16 == 0 and K % 128 == 0 and N % 16 == 0):
|
||||
raise unittest.SkipTest("M and N must be multiples of 16 and K must be multiple of 128 on ROCm, skipping")
|
||||
|
||||
fp4_scaling_dtype = torch.float8_e8m0fnu if recipe == "mxfp4" else torch.float8_e4m3fn
|
||||
BLOCK_SIZE = 16 if recipe == "nvfp4" else 32
|
||||
|
||||
if K % BLOCK_SIZE != 0:
|
||||
raise unittest.SkipTest(f"K ({K}) must be divisible by BLOCK_SIZE ({BLOCK_SIZE}), skipping")
|
||||
|
||||
fp4_scaling_dtype = torch.float8_e8m0fnu if torch.version.hip else torch.float8_e4m3fn
|
||||
BLOCK_SIZE = 32 if torch.version.hip else (16 if recipe == "nvfp4" else 32)
|
||||
require_exact_match = True
|
||||
approx_match_sqnr_target = 22.0
|
||||
|
||||
@ -2059,7 +1813,7 @@ class TestFP8Matmul(TestCase):
|
||||
B = B.clamp(min=min_val, max=max_val)
|
||||
B = _bfloat16_to_float4_e2m1fn_x2(B)
|
||||
|
||||
approx_match_sqnr_target = 15 if recipe == "mxfp4" else 15.8
|
||||
approx_match_sqnr_target = 15 if torch.version.hip else 15.8
|
||||
|
||||
C_ref = A_ref @ B_ref.t()
|
||||
|
||||
|
||||
115
test/test_xpu.py
115
test/test_xpu.py
@ -14,8 +14,10 @@ from torch.testing import make_tensor
|
||||
from torch.testing._internal.autocast_test_lists import AutocastTestLists, TestAutocast
|
||||
from torch.testing._internal.common_device_type import (
|
||||
instantiate_device_type_tests,
|
||||
onlyXPU,
|
||||
OpDTypes,
|
||||
ops,
|
||||
skipXPUIf,
|
||||
)
|
||||
from torch.testing._internal.common_methods_invocations import ops_and_refs
|
||||
from torch.testing._internal.common_utils import (
|
||||
@ -72,8 +74,6 @@ _xpu_computation_ops = [
|
||||
|
||||
@unittest.skipIf(not TEST_XPU, "XPU not available, skipping tests")
|
||||
class TestXpu(TestCase):
|
||||
expandable_segments = False
|
||||
|
||||
def test_device_behavior(self):
|
||||
current_device = torch.xpu.current_device()
|
||||
torch.xpu.set_device(current_device)
|
||||
@ -385,6 +385,56 @@ if __name__ == "__main__":
|
||||
torch.xpu.set_rng_state(g_state0)
|
||||
self.assertEqual(2024, torch.xpu.initial_seed())
|
||||
|
||||
@onlyXPU
|
||||
@suppress_warnings
|
||||
@ops(_xpu_computation_ops, dtypes=any_common_cpu_xpu_one)
|
||||
def test_compare_cpu(self, device, dtype, op):
|
||||
def to_cpu(arg):
|
||||
if isinstance(arg, torch.Tensor):
|
||||
return arg.to(device="cpu")
|
||||
return arg
|
||||
|
||||
samples = op.reference_inputs(device, dtype)
|
||||
|
||||
for sample in samples:
|
||||
cpu_sample = sample.transform(to_cpu)
|
||||
xpu_results = op(sample.input, *sample.args, **sample.kwargs)
|
||||
cpu_results = op(cpu_sample.input, *cpu_sample.args, **cpu_sample.kwargs)
|
||||
|
||||
xpu_results = sample.output_process_fn_grad(xpu_results)
|
||||
cpu_results = cpu_sample.output_process_fn_grad(cpu_results)
|
||||
|
||||
# Lower tolerance because we are running this as a `@slowTest`
|
||||
# Don't want the periodic tests to fail frequently
|
||||
self.assertEqual(xpu_results, cpu_results, atol=1e-4, rtol=1e-4)
|
||||
|
||||
@onlyXPU
|
||||
@ops(_xpu_computation_ops, allowed_dtypes=(torch.bool,))
|
||||
def test_non_standard_bool_values(self, device, dtype, op):
|
||||
# Test boolean values other than 0x00 and 0x01 (gh-54789)
|
||||
def convert_boolean_tensors(x):
|
||||
if not isinstance(x, torch.Tensor) or x.dtype != torch.bool:
|
||||
return x
|
||||
|
||||
# Map False -> 0 and True -> Random value in [2, 255]
|
||||
true_vals = torch.randint(
|
||||
2, 255, x.shape, dtype=torch.uint8, device=x.device
|
||||
)
|
||||
false_vals = torch.zeros((), dtype=torch.uint8, device=x.device)
|
||||
x_int = torch.where(x, true_vals, false_vals)
|
||||
|
||||
ret = x_int.view(torch.bool)
|
||||
self.assertEqual(ret, x)
|
||||
return ret
|
||||
|
||||
for sample in op.sample_inputs(device, dtype):
|
||||
expect = op(sample.input, *sample.args, **sample.kwargs)
|
||||
|
||||
transformed = sample.transform(convert_boolean_tensors)
|
||||
actual = op(transformed.input, *transformed.args, **transformed.kwargs)
|
||||
|
||||
self.assertEqual(expect, actual)
|
||||
|
||||
def test_serialization_array_with_storage(self):
|
||||
x = torch.randn(5, 5).xpu()
|
||||
y = torch.zeros(2, 5, dtype=torch.int, device="xpu")
|
||||
@ -420,8 +470,6 @@ if __name__ == "__main__":
|
||||
self.assertEqual(copy.get_device(), original.get_device())
|
||||
|
||||
def test_out_of_memory(self):
|
||||
if self.expandable_segments:
|
||||
self.skipTest("Skipping OOM test for expandable segments allocator.")
|
||||
tensor = torch.zeros(1024, device="xpu") # noqa: F841
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Tried to allocate 800000000.00 GiB"):
|
||||
@ -431,8 +479,6 @@ if __name__ == "__main__":
|
||||
torch.empty(1024 * 1024 * 1024 * 8000000000, dtype=torch.int8, device="xpu")
|
||||
|
||||
def test_raises_oom(self):
|
||||
if self.expandable_segments:
|
||||
self.skipTest("Skipping OOM test for expandable segments allocator.")
|
||||
torch.xpu.memory.empty_cache()
|
||||
with self.assertRaises(torch.OutOfMemoryError):
|
||||
torch.empty(1024 * 1024 * 1024 * 1024, device="xpu")
|
||||
@ -545,7 +591,7 @@ if __name__ == "__main__":
|
||||
self.assertEqual(torch.accelerator.max_memory_allocated(), prev_max_allocated)
|
||||
self.assertEqual(torch.accelerator.max_memory_reserved(), prev_max_reserved)
|
||||
|
||||
@unittest.skipIf(
|
||||
@skipXPUIf(
|
||||
int(torch.version.xpu) < 20250000,
|
||||
"Test requires SYCL compiler version 2025.0.0 or newer.",
|
||||
)
|
||||
@ -593,8 +639,6 @@ if __name__ == "__main__":
|
||||
self.assertTrue(b"libsycl.so" in result)
|
||||
|
||||
def test_dlpack_conversion(self):
|
||||
if self.expandable_segments:
|
||||
self.skipTest("Skipping DLPack test for expandable segments allocator.")
|
||||
x = make_tensor((5,), dtype=torch.float32, device="xpu")
|
||||
if IS_WINDOWS and int(torch.version.xpu) < 20250000:
|
||||
with self.assertRaisesRegex(
|
||||
@ -608,58 +652,7 @@ if __name__ == "__main__":
|
||||
self.assertEqual(z, x)
|
||||
|
||||
|
||||
@unittest.skipIf(not TEST_XPU, "XPU not available, skipping tests")
|
||||
class TestXpuOps(TestCase):
|
||||
@suppress_warnings
|
||||
@ops(_xpu_computation_ops, dtypes=any_common_cpu_xpu_one)
|
||||
def test_compare_cpu(self, device, dtype, op):
|
||||
def to_cpu(arg):
|
||||
if isinstance(arg, torch.Tensor):
|
||||
return arg.to(device="cpu")
|
||||
return arg
|
||||
|
||||
samples = op.reference_inputs(device, dtype)
|
||||
|
||||
for sample in samples:
|
||||
cpu_sample = sample.transform(to_cpu)
|
||||
xpu_results = op(sample.input, *sample.args, **sample.kwargs)
|
||||
cpu_results = op(cpu_sample.input, *cpu_sample.args, **cpu_sample.kwargs)
|
||||
|
||||
xpu_results = sample.output_process_fn_grad(xpu_results)
|
||||
cpu_results = cpu_sample.output_process_fn_grad(cpu_results)
|
||||
|
||||
# Lower tolerance because we are running this as a `@slowTest`
|
||||
# Don't want the periodic tests to fail frequently
|
||||
self.assertEqual(xpu_results, cpu_results, atol=1e-4, rtol=1e-4)
|
||||
|
||||
@ops(_xpu_computation_ops, allowed_dtypes=(torch.bool,))
|
||||
def test_non_standard_bool_values(self, device, dtype, op):
|
||||
# Test boolean values other than 0x00 and 0x01 (gh-54789)
|
||||
def convert_boolean_tensors(x):
|
||||
if not isinstance(x, torch.Tensor) or x.dtype != torch.bool:
|
||||
return x
|
||||
|
||||
# Map False -> 0 and True -> Random value in [2, 255]
|
||||
true_vals = torch.randint(
|
||||
2, 255, x.shape, dtype=torch.uint8, device=x.device
|
||||
)
|
||||
false_vals = torch.zeros((), dtype=torch.uint8, device=x.device)
|
||||
x_int = torch.where(x, true_vals, false_vals)
|
||||
|
||||
ret = x_int.view(torch.bool)
|
||||
self.assertEqual(ret, x)
|
||||
return ret
|
||||
|
||||
for sample in op.sample_inputs(device, dtype):
|
||||
expect = op(sample.input, *sample.args, **sample.kwargs)
|
||||
|
||||
transformed = sample.transform(convert_boolean_tensors)
|
||||
actual = op(transformed.input, *transformed.args, **transformed.kwargs)
|
||||
|
||||
self.assertEqual(expect, actual)
|
||||
|
||||
|
||||
instantiate_device_type_tests(TestXpuOps, globals(), only_for="xpu", allow_xpu=True)
|
||||
instantiate_device_type_tests(TestXpu, globals(), only_for="xpu", allow_xpu=True)
|
||||
|
||||
|
||||
@unittest.skipIf(not TEST_XPU, "XPU not available, skipping tests")
|
||||
|
||||
@ -1,26 +0,0 @@
|
||||
# Owner(s): ["module: intel"]
|
||||
import pathlib
|
||||
import sys
|
||||
|
||||
from test_xpu import TestXpu, TestXpuOpsXPU # noqa: F401
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.common_utils import IS_WINDOWS, run_tests
|
||||
|
||||
|
||||
REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent
|
||||
sys.path.insert(0, str(REPO_ROOT))
|
||||
|
||||
from tools.stats.import_test_stats import get_disabled_tests
|
||||
|
||||
|
||||
sys.path.remove(str(REPO_ROOT))
|
||||
|
||||
if __name__ == "__main__":
|
||||
if torch.xpu.is_available() and not IS_WINDOWS:
|
||||
get_disabled_tests(".")
|
||||
|
||||
torch._C._accelerator_setAllocatorSettings("expandable_segments:True")
|
||||
TestXpu.expandable_segments = True
|
||||
|
||||
run_tests()
|
||||
@ -739,12 +739,6 @@ enable_aot_compile = False
|
||||
# HACK: this is for testing custom ops profiling only
|
||||
_custom_ops_profile: Optional[Any] = None
|
||||
|
||||
# Experimental: If True, graph module will register fx metadata during recompile()
|
||||
enrich_profiler_metadata: bool = Config( # type: ignore[var-annotated]
|
||||
default=False,
|
||||
env_name_default="TORCH_ENRICH_RPOFILER_STACK_TRACE",
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch.utils._config_typing import * # noqa: F401, F403
|
||||
|
||||
|
||||
@ -423,10 +423,6 @@ def estimate_nccl_collective_runtime_from_fx_node(
|
||||
from torch.distributed.distributed_c10d import _resolve_process_group
|
||||
|
||||
pg = _resolve_process_group(group_name)
|
||||
if torch.distributed.distributed_c10d.get_backend(pg) == "fake":
|
||||
# nccl estimator requires real process group
|
||||
return None
|
||||
|
||||
fn = fx_node.target
|
||||
assert isinstance(fn, torch._ops.OpOverload)
|
||||
with torch.distributed._time_estimator(group=pg) as time_estimator:
|
||||
|
||||
@ -24,7 +24,6 @@ from typing_extensions import Never, ParamSpec
|
||||
import torch._thread_safe_fork # noqa: F401
|
||||
from torch._inductor import config
|
||||
from torch._inductor.codecache import torch_key
|
||||
from torch._inductor.compile_worker.timer import Timer
|
||||
from torch._inductor.compile_worker.tracked_process_pool import (
|
||||
TrackedProcessPoolExecutor,
|
||||
)
|
||||
@ -133,7 +132,6 @@ class SubprocPool:
|
||||
nprocs: int,
|
||||
pickler: Optional[SubprocPickler] = None,
|
||||
kind: SubprocKind = SubprocKind.FORK,
|
||||
quiesce: bool = False,
|
||||
) -> None:
|
||||
entry = os.path.join(os.path.dirname(__file__), "__main__.py")
|
||||
self.pickler = pickler or SubprocPickler()
|
||||
@ -218,13 +216,6 @@ class SubprocPool:
|
||||
"pytorch.wait_counter.subproc_pool.first_job"
|
||||
).guard()
|
||||
|
||||
if quiesce:
|
||||
self.timer: Optional[Timer] = Timer(
|
||||
config.quiesce_async_compile_time, self.quiesce
|
||||
)
|
||||
else:
|
||||
self.timer = None
|
||||
|
||||
# Start thread last to ensure all member variables are initialized
|
||||
# before any access.
|
||||
self.read_thread.start()
|
||||
@ -297,8 +288,6 @@ class SubprocPool:
|
||||
with self.futures_lock:
|
||||
if not self.running:
|
||||
return
|
||||
if self.timer:
|
||||
self.timer.record_call()
|
||||
if isinstance(result, _SubprocExceptionInfo):
|
||||
# An exception occurred in the submitted job
|
||||
self.pending_futures[job_id].set_exception(
|
||||
@ -333,8 +322,6 @@ class SubprocPool:
|
||||
with self.write_lock:
|
||||
if not self.running:
|
||||
return
|
||||
if self.timer:
|
||||
self.timer.quit()
|
||||
self.running = False
|
||||
self.running_waitcounter.__exit__()
|
||||
_send_msg(self.write_pipe, MsgHeader.SHUTDOWN)
|
||||
|
||||
@ -17,7 +17,7 @@ class Timer:
|
||||
self.background_thread: Optional[Thread] = None
|
||||
self.last_called: Optional[float] = None
|
||||
self.duration = duration
|
||||
self.sleep_time = duration / 2
|
||||
self.sleep_time = 60
|
||||
self.call = call
|
||||
self.exit = False
|
||||
|
||||
|
||||
@ -964,11 +964,6 @@ quiesce_async_compile_pool: bool = Config(
|
||||
default=False,
|
||||
)
|
||||
|
||||
# Time in seconds to wait before quiescing
|
||||
quiesce_async_compile_time: int = Config(
|
||||
default=60,
|
||||
)
|
||||
|
||||
# Whether or not to enable statically launching CUDA kernels
|
||||
# compiled by triton (instead of using triton's own launcher)
|
||||
use_static_cuda_launcher: bool = static_cuda_launcher_default()
|
||||
|
||||
@ -66,9 +66,7 @@ def should_decompose_bmm(mat1, mat2) -> bool:
|
||||
return False
|
||||
if len(mat1.shape) != 3 or len(mat2.shape) != 3:
|
||||
return False
|
||||
if check_device(mat1, mat2, device="cuda") or check_device(
|
||||
mat1, mat2, device="xpu"
|
||||
):
|
||||
if check_device(mat1, mat2, device="cuda"):
|
||||
if mat1.shape[0] < min_first_dimension_decomposition:
|
||||
return False
|
||||
# 2 of m, n, k must be <= MAX_OTHER_DIMENSION_DECOMPOSITION
|
||||
@ -132,10 +130,7 @@ def should_decompose_mm(mat1, mat2) -> bool:
|
||||
"skip_dynamic_shape_dim_check", False
|
||||
):
|
||||
return (
|
||||
(
|
||||
check_device(mat1, mat2, device="cuda")
|
||||
or check_device(mat1, mat2, device="xpu")
|
||||
)
|
||||
check_device(mat1, mat2, device="cuda")
|
||||
and statically_known_true(
|
||||
mat1.shape[0] >= min_first_dimension_decomposition
|
||||
)
|
||||
@ -156,10 +151,7 @@ def should_decompose_mm(mat1, mat2) -> bool:
|
||||
# case 2: we decompose mm if the input is dynamic shape
|
||||
else:
|
||||
return (
|
||||
(
|
||||
check_device(mat1, mat2, device="cuda")
|
||||
or check_device(mat1, mat2, device="xpu")
|
||||
)
|
||||
check_device(mat1, mat2, device="cuda")
|
||||
and (
|
||||
statically_known_true(
|
||||
mat1.shape[0] >= min_first_dimension_decomposition
|
||||
|
||||
@ -51,8 +51,8 @@ from ..utils import (
|
||||
decode_device,
|
||||
get_all_devices,
|
||||
get_gpu_type,
|
||||
has_uses_tagged_as,
|
||||
is_gpu,
|
||||
is_pointwise_use,
|
||||
OPTIMUS_EXCLUDE_POST_GRAD,
|
||||
)
|
||||
from ..virtualized import V
|
||||
@ -1510,10 +1510,8 @@ def should_prefer_unfused_addmm(match):
|
||||
if not is_gpu(inp.meta["val"].device.type):
|
||||
return False
|
||||
|
||||
return has_uses_tagged_as(
|
||||
match.output_node(),
|
||||
(torch.Tag.pointwise, torch.Tag.reduction),
|
||||
)
|
||||
output = match.output_node()
|
||||
return all(is_pointwise_use(use) for use in output.users)
|
||||
|
||||
|
||||
@register_graph_pattern(
|
||||
|
||||
@ -549,70 +549,6 @@ def is_pointwise_use(
|
||||
return torch.Tag.pointwise in target.tags or is_pointwise_fn(target)
|
||||
|
||||
|
||||
class LogicalConnective(enum.Enum):
|
||||
OR = enum.auto()
|
||||
AND = enum.auto()
|
||||
|
||||
|
||||
def has_uses(
|
||||
target: Node,
|
||||
use_selector_fn: Callable[[torch._ops.OpOverload], bool] = lambda _: False,
|
||||
use_aggregate_type: LogicalConnective = LogicalConnective.OR,
|
||||
) -> bool:
|
||||
"""
|
||||
Given a target, explore the uses of `target` by applying `use_selector_fn`
|
||||
on them, and then aggregate these booleans with the `use_aggregate_type`
|
||||
logical connective.
|
||||
|
||||
Uses in view ops will follow the views uses.
|
||||
"""
|
||||
|
||||
def get_use_aggregate_fn(
|
||||
use_aggregate_type: LogicalConnective,
|
||||
) -> Callable[[Iterator[Any]], bool]:
|
||||
match use_aggregate_type:
|
||||
case LogicalConnective.AND:
|
||||
return all
|
||||
case LogicalConnective.OR:
|
||||
return any
|
||||
case _:
|
||||
return any
|
||||
|
||||
use_aggregate_fn = get_use_aggregate_fn(use_aggregate_type)
|
||||
|
||||
def has_uses_impl(use: Node) -> bool:
|
||||
if use.op != "call_function":
|
||||
return False
|
||||
if not (
|
||||
isinstance(use.target, torch._ops.OpOverload)
|
||||
or use.target is operator.getitem
|
||||
):
|
||||
return False
|
||||
|
||||
target = cast(torch._ops.OpOverload, use.target)
|
||||
# Process getitem and view
|
||||
if target is operator.getitem or is_view(target):
|
||||
return use_aggregate_fn(has_uses_impl(user) for user in use.users)
|
||||
|
||||
return use_selector_fn(target)
|
||||
|
||||
return use_aggregate_fn(has_uses_impl(user) for user in target.users)
|
||||
|
||||
|
||||
def has_uses_tagged_as(
|
||||
target: Node,
|
||||
use_tags: Collection[torch.Tag],
|
||||
use_aggregate_type: LogicalConnective = LogicalConnective.OR,
|
||||
) -> bool:
|
||||
"""
|
||||
Is there a use with given tags?
|
||||
"""
|
||||
|
||||
return has_uses(
|
||||
target, lambda use: any(tag in use_tags for tag in use.tags), use_aggregate_type
|
||||
)
|
||||
|
||||
|
||||
def gen_gm_and_inputs(
|
||||
target: Any, args: list[Any], kwargs: dict[str, Any]
|
||||
) -> tuple[GraphModule, list[torch.Tensor]]:
|
||||
|
||||
@ -31,8 +31,10 @@ template <typename T>
|
||||
struct FromImpl {
|
||||
static StableIValue call(
|
||||
T val,
|
||||
[[maybe_unused]] uint64_t extension_build_version,
|
||||
[[maybe_unused]] bool is_internal) {
|
||||
uint64_t extension_build_version,
|
||||
bool is_internal) {
|
||||
(void)extension_build_version; // Unused parameter
|
||||
(void)is_internal; // Unused parameter
|
||||
static_assert(
|
||||
sizeof(T) <= sizeof(StableIValue),
|
||||
"StableLibrary stack does not support parameter types larger than 64 bits.");
|
||||
@ -73,8 +75,10 @@ template <>
|
||||
struct FromImpl<ScalarType> {
|
||||
static StableIValue call(
|
||||
ScalarType val,
|
||||
[[maybe_unused]] uint64_t extension_build_version,
|
||||
[[maybe_unused]] bool is_internal) {
|
||||
uint64_t extension_build_version,
|
||||
bool is_internal) {
|
||||
(void)extension_build_version; // Unused parameter
|
||||
(void)is_internal; // Unused parameter
|
||||
switch (val) {
|
||||
case ScalarType::Byte:
|
||||
return from(aoti_torch_dtype_uint8());
|
||||
@ -129,8 +133,10 @@ template <>
|
||||
struct FromImpl<std::nullopt_t> {
|
||||
static StableIValue call(
|
||||
std::nullopt_t val,
|
||||
[[maybe_unused]] uint64_t extension_build_version,
|
||||
[[maybe_unused]] bool is_internal) {
|
||||
uint64_t extension_build_version,
|
||||
bool is_internal) {
|
||||
(void)extension_build_version; // Unused parameter
|
||||
(void)is_internal; // Unused parameter
|
||||
return from(nullptr);
|
||||
}
|
||||
};
|
||||
@ -184,8 +190,10 @@ template <>
|
||||
struct FromImpl<torch::stable::Tensor> {
|
||||
static StableIValue call(
|
||||
const torch::stable::Tensor& val,
|
||||
[[maybe_unused]] uint64_t extension_build_version,
|
||||
[[maybe_unused]] bool is_internal) {
|
||||
uint64_t extension_build_version,
|
||||
bool is_internal) {
|
||||
(void)extension_build_version; // Unused parameter
|
||||
(void)is_internal; // Unused parameter
|
||||
AtenTensorHandle new_ath;
|
||||
TORCH_ERROR_CODE_CHECK(aoti_torch_new_tensor_handle(val.get(), &new_ath));
|
||||
return from(new_ath);
|
||||
@ -201,8 +209,10 @@ template <typename T>
|
||||
struct ToImpl {
|
||||
static T call(
|
||||
StableIValue val,
|
||||
[[maybe_unused]] uint64_t extension_build_version,
|
||||
[[maybe_unused]] bool is_internal) {
|
||||
uint64_t extension_build_version,
|
||||
bool is_internal) {
|
||||
(void)extension_build_version; // Unused parameter
|
||||
(void)is_internal; // Unused parameter
|
||||
static_assert(std::is_trivially_copyable_v<T>);
|
||||
// T may not have a default constructor. (For example, it might be
|
||||
// c10::Device.) However, std::memcpy implicitly creates a T at the
|
||||
@ -239,8 +249,10 @@ template <>
|
||||
struct ToImpl<ScalarType> {
|
||||
static ScalarType call(
|
||||
StableIValue val,
|
||||
[[maybe_unused]] uint64_t extension_build_version,
|
||||
[[maybe_unused]] bool is_internal) {
|
||||
uint64_t extension_build_version,
|
||||
bool is_internal) {
|
||||
(void)extension_build_version; // Unused parameter
|
||||
(void)is_internal; // Unused parameter
|
||||
int32_t shim_scalartype = to<int32_t>(val);
|
||||
if (shim_scalartype == aoti_torch_dtype_uint8()) {
|
||||
return ScalarType::Byte;
|
||||
@ -297,8 +309,10 @@ template <>
|
||||
struct ToImpl<std::nullopt_t> {
|
||||
static std::nullopt_t call(
|
||||
StableIValue val,
|
||||
[[maybe_unused]] uint64_t extension_build_version,
|
||||
[[maybe_unused]] bool is_internal) {
|
||||
uint64_t extension_build_version,
|
||||
bool is_internal) {
|
||||
(void)extension_build_version; // Unused parameter
|
||||
(void)is_internal; // Unused parameter
|
||||
// val should be equivalent to from(nullptr)
|
||||
return std::nullopt;
|
||||
}
|
||||
@ -336,8 +350,10 @@ template <>
|
||||
struct ToImpl<torch::stable::Tensor> {
|
||||
static torch::stable::Tensor call(
|
||||
StableIValue val,
|
||||
[[maybe_unused]] uint64_t extension_build_version,
|
||||
[[maybe_unused]] bool is_internal) {
|
||||
uint64_t extension_build_version,
|
||||
bool is_internal) {
|
||||
(void)extension_build_version; // Unused parameter
|
||||
(void)is_internal; // Unused parameter
|
||||
return torch::stable::Tensor(to<AtenTensorHandle>(val));
|
||||
}
|
||||
};
|
||||
|
||||
@ -4,14 +4,12 @@ r"""This package adds support for device memory management implemented in CUDA."
|
||||
import collections
|
||||
import contextlib
|
||||
import ctypes
|
||||
import os
|
||||
import pickle
|
||||
import re
|
||||
import sys
|
||||
import warnings
|
||||
from inspect import signature
|
||||
from typing import Any, cast, Literal, Optional, TYPE_CHECKING, TypedDict
|
||||
from typing_extensions import deprecated, NotRequired
|
||||
from typing import Any, Literal, Optional, TYPE_CHECKING
|
||||
from typing_extensions import deprecated
|
||||
|
||||
import torch
|
||||
from torch import _C
|
||||
@ -31,60 +29,6 @@ if TYPE_CHECKING:
|
||||
from torch.types import Device
|
||||
|
||||
|
||||
# Type definitions for memory profiler
|
||||
class _Frame(TypedDict):
|
||||
"""Frame information from memory profiler snapshots."""
|
||||
|
||||
filename: str
|
||||
line: int
|
||||
name: str
|
||||
# Fields added by FX augmentation (optional)
|
||||
fx_node_op: NotRequired[str]
|
||||
fx_node_name: NotRequired[str]
|
||||
fx_node_target: NotRequired[str]
|
||||
fx_original_trace: NotRequired[str]
|
||||
|
||||
|
||||
class _Block(TypedDict):
|
||||
"""Memory block information."""
|
||||
|
||||
size: int
|
||||
requested_size: int
|
||||
address: int
|
||||
state: str
|
||||
frames: list[_Frame]
|
||||
|
||||
|
||||
class _Segment(TypedDict):
|
||||
"""Memory segment information."""
|
||||
|
||||
address: int
|
||||
total_size: int
|
||||
stream: int
|
||||
segment_type: str
|
||||
allocated_size: int
|
||||
active_size: int
|
||||
blocks: list[_Block]
|
||||
|
||||
|
||||
class _TraceEntry(TypedDict):
|
||||
"""Memory trace entry information."""
|
||||
|
||||
action: str
|
||||
addr: NotRequired[int]
|
||||
frames: list[_Frame]
|
||||
size: int
|
||||
stream: int
|
||||
device_free: NotRequired[int]
|
||||
|
||||
|
||||
class _Snapshot(TypedDict):
|
||||
"""Memory snapshot structure."""
|
||||
|
||||
segments: list[_Segment]
|
||||
device_traces: NotRequired[list[list[_TraceEntry]]]
|
||||
|
||||
|
||||
__all__ = [
|
||||
"caching_allocator_alloc",
|
||||
"caching_allocator_delete",
|
||||
@ -1020,120 +964,7 @@ def _record_memory_history_impl(
|
||||
_record_memory_history.__signature__ = signature(_record_memory_history_impl) # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def _augment_frames(frames: list[_Frame]) -> int:
|
||||
"""
|
||||
Augment a list of frames with FX debug information.
|
||||
|
||||
Args:
|
||||
frames: List of frame dictionaries to augment
|
||||
|
||||
Returns:
|
||||
The count of frames that were augmented.
|
||||
"""
|
||||
from torch.fx.graph_module import FX_GRAPH_MODULE_FILE_PREFIX
|
||||
|
||||
# Regex pattern to match FX generated files
|
||||
_FX_GENERATED_PATTERN = re.compile(
|
||||
rf"{re.escape(FX_GRAPH_MODULE_FILE_PREFIX)}.*\.py$"
|
||||
)
|
||||
|
||||
count = 0
|
||||
if not frames:
|
||||
return count
|
||||
|
||||
    for frame in frames:
        if "filename" in frame and "line" in frame:
            filename = frame["filename"]
            lineno = frame["line"]

            # Check if this looks like an FX generated file
            if not _FX_GENERATED_PATTERN.search(os.path.basename(filename)):
                continue

            # Look up metadata from the global registry
            from torch.fx.traceback import _FX_METADATA_REGISTRY

            metadata = _FX_METADATA_REGISTRY.get(filename)
            if metadata is None:
                continue

            lineno_map = metadata.get("lineno_map", {})
            node_metadata = metadata.get("node_metadata", {})
            prologue_start = metadata.get("prologue_start", 0)

            # Get the node index for this line
            node_idx = lineno_map.get(lineno - prologue_start)

            if node_idx is not None and node_idx in node_metadata:
                node_info = node_metadata[node_idx]
                original_trace = node_info.get("stack_trace")
                node_op = node_info.get("op")
                node_name = node_info.get("name")
                node_target = node_info.get("target")

                # Always add node metadata
                frame["fx_node_op"] = node_op
                frame["fx_node_name"] = node_name
                frame["fx_node_target"] = str(node_target)

                # Add original trace if available
                if original_trace:
                    frame["fx_original_trace"] = original_trace

                count += 1

    return count


def _augment_memory_snapshot_stack_traces(
    snapshot: str | _Snapshot,
) -> _Snapshot:
    """
    Augment a memory snapshot with original source stack traces from FX metadata.

    IMPORTANT: This function reads from a global in-memory registry (_FX_METADATA_REGISTRY)
    that is populated during graph module compilation. It must be called in the same
    Python process where the FX graphs were compiled. It cannot be used to augment
    snapshots loaded from disk in a different process.

    Args:
        snapshot: Either a memory snapshot dict or a path to a snapshot pickle file

    Returns:
        The augmented snapshot dictionary, with fx_node_op, fx_node_name,
        fx_node_target, and fx_original_trace fields added to frames
    """

    snapshot_dict: _Snapshot
    if isinstance(snapshot, str):
        # Load the memory snapshot
        with open(snapshot, "rb") as f:
            snapshot_dict = cast(_Snapshot, pickle.load(f))
    else:
        snapshot_dict = snapshot

    # Process stack traces in the snapshot
    augmented_count = 0

    # Process blocks in segments (for regular allocations)
    if "segments" in snapshot_dict:
        for segment in snapshot_dict["segments"]:
            if "blocks" in segment:
                for block in segment["blocks"]:
                    if "frames" in block:
                        augmented_count += _augment_frames(block["frames"])

    # Process device traces (for memory history)
    if "device_traces" in snapshot_dict:
        for trace_list in snapshot_dict["device_traces"]:
            for trace_entry in trace_list:
                if isinstance(trace_entry, dict) and "frames" in trace_entry:
                    augmented_count += _augment_frames(trace_entry["frames"])

    return snapshot_dict


def _snapshot(device: "Device" = None, augment_with_fx_traces=False):
def _snapshot(device: "Device" = None):
    """Save a snapshot of CUDA memory state at the time it was called.

    The state is represented as a dictionary with the following structure.
@ -1181,11 +1012,6 @@ def _snapshot(device: "Device" = None, augment_with_fx_traces=False):
            filename: str
            line: int
            name: str
            # Optional FX debug fields (present when augment_with_fx_traces=True
            # and the frame corresponds to FX-generated code)
            fx_node_op: str  # FX node operation type (e.g., 'call_function', 'output')
            fx_node_name: str  # FX node name (e.g., 'linear', 'relu_1')
            fx_original_trace: str  # Original model source code stack trace


        class TraceEntry(TypedDict):
@ -1215,23 +1041,13 @@ def _snapshot(device: "Device" = None, augment_with_fx_traces=False):
            device_free: int  # only present for OOM, the amount of
                              # memory cuda still reports to be free

    Args:
        device: Device to capture snapshot for. If None, captures for current device.
        augment_with_fx_traces: If True, augment stack trace frames with FX debug information
                                that maps generated FX code back to original model source code.
                                This adds fx_node_op, fx_node_name, fx_node_target, and
                                fx_original_trace fields to Frame objects. Default: False.

    Returns:
        The Snapshot dictionary object
    """
    s = _C._cuda_memorySnapshot(None)
    if augment_with_fx_traces:
        s = _augment_memory_snapshot_stack_traces(s)  # type: ignore[assignment, arg-type]
    return s
    return _C._cuda_memorySnapshot(None)


def _dump_snapshot(filename="dump_snapshot.pickle", augment_with_fx_traces=False):
def _dump_snapshot(filename="dump_snapshot.pickle"):
    """
    Save a pickled version of the `torch.memory._snapshot()` dictionary to a file.

@ -1243,14 +1059,8 @@ def _dump_snapshot(filename="dump_snapshot.pickle", augment_with_fx_traces=False

    Args:
        filename (str, optional): Name of the file to create. Defaults to "dump_snapshot.pickle".
        augment_with_fx_traces (bool, optional): If True, augment the snapshot with FX debug
                                                 information before dumping. This maps generated
                                                 FX code stack traces back to original model
                                                 source code. Defaults to False.
    """
    s = _snapshot(augment_with_fx_traces=augment_with_fx_traces)

    s = _snapshot()
    with open(filename, "wb") as f:
        pickle.dump(s, f)

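For orientation, here is a minimal sketch of how the augmented path shown on the old side of this hunk would be driven end to end. It assumes a CUDA build, uses the torch.cuda.memory._record_memory_history / _dump_snapshot helpers, and the augment_with_fx_traces flag from the removed signatures; the model and sizes are illustrative only.

# Illustrative sketch, not part of the diff: exercise the old-side
# augment_with_fx_traces path in a single process.
import torch

def capture_augmented_snapshot():
    # Opt in so GraphModule.recompile() registers FX metadata
    # (see the torch/fx changes later in this diff).
    torch._dynamo.config.enrich_profiler_metadata = True

    torch.cuda.memory._record_memory_history(max_entries=100000)

    model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU()).cuda()
    compiled = torch.compile(model)
    compiled(torch.randn(8, 64, device="cuda"))

    # On the old side, frames in the pickle gain fx_node_op / fx_node_name /
    # fx_node_target / fx_original_trace fields.
    torch.cuda.memory._dump_snapshot("snapshot.pickle", augment_with_fx_traces=True)
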
@ -9,7 +9,6 @@ import torch
import torch.distributed as dist
import torch.distributed.tensor._api as dtensor
import torch.distributed.tensor._random as random
from torch._library.utils import fill_defaults
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
from torch.distributed.tensor._op_schema import OpInfo, OpSchema, OutputSpecType
@ -35,23 +34,6 @@ aten = torch.ops.aten
logger = logging.getLogger(__name__)


def as_strided_handler(
    op_call: torch._ops.OpOverload,
    args: tuple[object, ...],
    kwargs: dict[str, object],
):
    args, kwargs = fill_defaults(op_call._schema, args, kwargs)
    assert not kwargs
    tensor, size, stride, storage_offset = args
    if (
        tensor.size() == tuple(size)
        and tensor.stride() == tuple(stride)
        and (storage_offset is None or tensor.storage_offset() == storage_offset)
    ):
        return torch.ops.aten.alias.default(tensor)
    raise RuntimeError("as_strided not supported with DTensor")


def is_same_size_handler(
    op_call: torch._ops.OpOverload,
    args: tuple[object, ...],
@ -139,7 +121,6 @@ class OpDispatcher:
            aten.convolution.default: convolution_handler,
            aten.convolution_backward.default: convolution_backward_handler,
            aten._amp_foreach_non_finite_check_and_unscale_.default: found_inf_reduce_handler,
            aten.as_strided.default: as_strided_handler,
        }

        # This flag is used internally to control whether we treat the torch.Tensor(non-DTensor)

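As a quick illustration of what the removed as_strided_handler accepts: the only call it forwards is the degenerate view whose size, stride, and storage offset already match the input, i.e. an alias. The helper below is hypothetical and simply mirrors that check on a plain tensor.

import torch

def is_noop_as_strided(t, size, stride, storage_offset=None):
    # Mirrors the handler's test: the requested view must describe exactly
    # the layout the tensor already has, so aliasing it is safe.
    return (
        t.size() == tuple(size)
        and t.stride() == tuple(stride)
        and (storage_offset is None or t.storage_offset() == storage_offset)
    )

x = torch.randn(4, 8)
assert is_noop_as_strided(x, (4, 8), x.stride())      # alias case: allowed
assert not is_noop_as_strided(x, (8, 4), (1, 8))      # genuine restride: rejected
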
@ -84,7 +84,6 @@ register_op_strategy(
        aten.clone.default,
        aten.contiguous.default,
        aten.detach.default,
        aten.alias.default,
        aten.fill_.Scalar,
        aten.view.dtype,
        aten.zero_.default,

@ -226,10 +226,8 @@ class PythonCode:
    # Values in global scope during execution of `src_def`.
    globals: dict[str, Any]
    # Optional mapping from the forward function's line number to
    # node index. Line number starts at the prologue (i.e. forward()).
    # node index.
    _lineno_map: Optional[dict[int, Optional[int]]]
    # The line number of prologue in fn_code
    _prologue_start: int = 0


def _format_target(base: str, target: str) -> str:
@ -856,14 +854,7 @@ class CodeGen:

{prologue}
{code}"""
        # The +4 accounts for the empty lines before prologue in fn_code
        prologue_start = wrap_stmts.count("\n") + 4
        return PythonCode(
            fn_code,
            globals_,
            _lineno_map=lineno_map,
            _prologue_start=prologue_start,
        )
        return PythonCode(fn_code, globals_, _lineno_map=lineno_map)


# Ideally, we'd like to refactor all of the pytree logic into this codegen

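To make the removed _lineno_map / _prologue_start bookkeeping concrete, here is a short sketch of the lookup a consumer performs to map a line of the generated forward() back to a node index; the dictionaries and numbers below are invented for illustration.

# Hypothetical values standing in for PythonCode._lineno_map and
# PythonCode._prologue_start from the old side of this hunk.
lineno_map = {1: 0, 2: 1, 3: 2}   # offset within forward() -> node index
prologue_start = 12               # line in fn_code where the prologue begins

def node_index_for_line(lineno):
    # Same arithmetic the memory-profiler augmentation uses: shift the
    # absolute line number by the prologue start, then consult the map.
    return lineno_map.get(lineno - prologue_start)

assert node_index_for_line(14) == 1    # line 14 of fn_code -> second node
assert node_index_for_line(5) is None  # lines before the prologue map to nothing
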
@ -1,8 +1,6 @@
# mypy: allow-untyped-defs
import base64
import contextlib
import copy
import hashlib
import itertools
import linecache
import os
@ -38,7 +36,6 @@ __all__ = [
]

_USER_PRESERVED_ATTRIBUTES_KEY = "_user_preserved_attributes"
FX_GRAPH_MODULE_FILE_PREFIX = "fx_generated_"


# Normal exec loses the source code, however we can work with
@ -64,13 +61,7 @@ class _EvalCacheLoader:

        key = self._get_key()
        if co_fields:
            if "co_filename" in co_fields:
                # If only co_filename is provided, use it directly as the key
                if "co_firstlineno" not in co_fields or "co_name" not in co_fields:
                    key = co_fields["co_filename"]
                else:
                    # Full co_fields with all three components
                    key += f" from {co_fields['co_filename']}:{co_fields['co_firstlineno']} in {co_fields['co_name']}"
            key += f" from {co_fields['co_filename']}:{co_fields['co_firstlineno']} in {co_fields['co_name']}"
        self.eval_cache[key] = src

        # Don't mutate globals so that this loader is only used
@ -362,36 +353,6 @@ def _print_readable(
    return output


def _metadata_hash(code: str, node_metadata: dict) -> str:
    """
    Create a content-addressed hash from code and metadata.

    Args:
        code: The source code string
        node_metadata: Metadata for each node

    Returns:
        A 51-character base32-encoded hash
    """
    import json

    # Create a deterministic string representation of all components
    # We use JSON to ensure consistent serialization
    hash_data = {
        "code": code,
        "node_metadata": node_metadata,
    }
    hashing_str = json.dumps(hash_data).encode("utf-8")

    # [:51] to strip off the "Q====" suffix common to every hash value.
    return (
        base64.b32encode(hashlib.sha256(hashing_str).digest())[:51]
        .decode("utf-8")
        .lower()
    )

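A brief sketch of what the content-addressed naming built on _metadata_hash produces; it assumes the _metadata_hash and FX_GRAPH_MODULE_FILE_PREFIX definitions above are in scope, and the code string and node metadata are made up.

# Assumes _metadata_hash and FX_GRAPH_MODULE_FILE_PREFIX from the old side above.
code = "def forward(self, x):\n    relu = torch.relu(x)\n    return relu\n"
node_metadata = {
    0: {"name": "x", "op": "placeholder", "target": "x", "stack_trace": None},
    1: {"name": "relu", "op": "call_function", "target": "torch.relu", "stack_trace": None},
}

hash_value = _metadata_hash(code, node_metadata)
filename = f"{FX_GRAPH_MODULE_FILE_PREFIX}_{hash_value}.py"
# Identical code + metadata always yields the same name, something like
# "fx_generated__<51-char-hash>.py", which is the key the profiler uses later.
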
class _WrappedCall:
    def __init__(self, cls, cls_call):
        self.cls = cls
@ -864,47 +825,9 @@ class {module_name}(torch.nn.Module):
        python_code = self._graph.python_code(root_module="self")
        self._code = python_code.src
        self._lineno_map = python_code._lineno_map
        self._prologue_start = python_code._prologue_start

        cls = type(self)
        co_fields = self._graph._co_fields if hasattr(self._graph, "_co_fields") else {}
        from torch._dynamo import config as dynamo_config

        if dynamo_config.enrich_profiler_metadata:
            # Generate metadata and register for profiler augmentation
            node_metadata: dict[int, dict[str, Any]] = {}
            for i, node in enumerate(self._graph.nodes):
                node_metadata[i] = {
                    "name": node.name,
                    "op": node.op,
                    "target": str(node.target),
                    "stack_trace": node.meta.get("stack_trace", None),
                }

            # Generate a content-addressed filename based on hash of code and metadata
            # This ensures the same code+metadata always generates the same filename
            hash_value = _metadata_hash(self._code, node_metadata)
            file_stem = f"{FX_GRAPH_MODULE_FILE_PREFIX}_{hash_value}"

            filename = f"{file_stem}.py"

            # Only include co_filename to use it directly as the cache key
            co_fields = {
                "co_filename": filename,
            }

            # Store metadata in global in-memory registry
            metadata = {
                "lineno_map": python_code._lineno_map,
                "prologue_start": python_code._prologue_start,
                "node_metadata": node_metadata,
            }

            # Register metadata in the global registry
            from torch.fx.traceback import _register_fx_metadata

            _register_fx_metadata(filename, metadata)

        cls.forward = _forward_from_src(self._code, python_code.globals, co_fields)

        # Determine whether this class explicitly defines a __call__ implementation

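For reference, the registry entry assembled in the removed block above has the following shape; every value here is invented for illustration and only mirrors the keys used in this diff.

# Shape of the metadata dict recompile() registers on the old side.
metadata = {
    "lineno_map": {1: 0, 2: 1},   # forward() line offset -> node index
    "prologue_start": 12,         # where the prologue begins in fn_code
    "node_metadata": {
        0: {"name": "x", "op": "placeholder", "target": "x", "stack_trace": None},
        1: {
            "name": "relu",
            "op": "call_function",
            "target": "torch.relu",
            "stack_trace": "File \"model.py\", line 10, in forward\n    return x.relu()",
        },
    },
}
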
@ -38,28 +38,6 @@ current_meta: dict[str, Any] = {}
current_replay_node: Optional[Node] = None
should_preserve_node_meta = False

# =============================================================================
# FX Metadata Registry for Memory Profiler
# =============================================================================
# Global in-memory registry for FX metadata
# Maps module_name -> metadata dict containing lineno_map and node_metadata
_FX_METADATA_REGISTRY: dict[str, dict[str, Any]] = {}


def _register_fx_metadata(module_name: str, metadata: dict[str, Any]) -> None:
    """
    Register FX metadata in the global in-memory registry.

    This is called automatically during graph module compilation to store metadata
    for later use by memory profiler augmentation.

    Args:
        module_name: The module identifier (content-addressed filename)
        metadata: Metadata dict containing lineno_map, prologue_start, and node_metadata
    """
    # TODO: add logging to tlparse
    _FX_METADATA_REGISTRY[module_name] = metadata


@compatibility(is_backward_compatible=False)
class NodeSourceAction(Enum):

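For completeness, a tiny round-trip through the removed registry API; the module name and metadata are illustrative, and these imports exist only on the old side of this diff.

from torch.fx.traceback import _FX_METADATA_REGISTRY, _register_fx_metadata

# Hypothetical content-addressed module name, mirroring what
# GraphModule.recompile() passes in on the old side.
module_name = "fx_generated__examplehash.py"
_register_fx_metadata(
    module_name,
    {"lineno_map": {1: 0}, "prologue_start": 10, "node_metadata": {}},
)

# The memory-snapshot augmentation later looks the same key up by filename.
assert _FX_METADATA_REGISTRY[module_name]["prologue_start"] == 10
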
@ -806,29 +806,7 @@ function format_frames(frames) {
  }
  const frame_strings = frames
    .filter(frameFilter)
    .map(f => {
      let frame_str = `${f.filename}:${f.line}:${f.name}`;

      // Add FX debug information if available
      if (f.fx_node_op || f.fx_node_name || f.fx_node_target) {
        const fx_parts = [];
        if (f.fx_node_name) fx_parts.push(`node=${f.fx_node_name}`);
        if (f.fx_node_op) fx_parts.push(`op=${f.fx_node_op}`);
        if (f.fx_node_target) fx_parts.push(`target=${f.fx_node_target}`);
        frame_str += `\n >> FX: ${fx_parts.join(', ')}`;
      }

      if (f.fx_original_trace) {
        frame_str += `\n >> Original Model Code:`;
        const original_lines = f.fx_original_trace.trim().split('\n');
        // Show all lines of the original trace
        for (const line of original_lines) {
          frame_str += `\n ${line}`;
        }
      }

      return frame_str;
    });
    .map(f => `${f.filename}:${f.line}:${f.name}`);
  return elideRepeats(frame_strings).join('\n');
}