Compare commits

..

3 Commits

Author SHA1 Message Date
5b6cc8215f Change python doc push script to print the undocumented modules 2025-10-21 12:30:49 -07:00
1c43c9cfd0 Update 2025-10-21 12:30:49 -07:00
102e0d5437 Test 2025-10-21 12:30:49 -07:00
23 changed files with 442 additions and 839 deletions

View File

@ -1,15 +1,11 @@
sphinx==5.3.0
sphinx==7.2.6
#Description: This is used to generate PyTorch docs
#Pinned versions: 5.3.0
#Pinned versions: 7.2.6
standard-imghdr==3.13.0; python_version >= "3.13"
#Description: This is needed by Sphinx, so it needs to be added here.
# The reasons are as follows:
# 1) This module has been removed from the Python standard library since Python 3.13(https://peps.python.org/pep-0594/#imghdr);
# 2) The current version of Sphinx (5.3.0) is not compatible with Python 3.13.
# Once Sphinx is upgraded to a version compatible with Python 3.13 or later, we can remove this dependency.
pytorch_sphinx_theme2==0.1.0
#Description: This is needed to generate PyTorch docs
#Pinned versions: 0.1.0
-e git+https://github.com/pytorch/pytorch_sphinx_theme.git@71e55749be14ceb56e7f8211a9fb649866b87ad4#egg=pytorch_sphinx_theme2
# TODO: sphinxcontrib.katex 0.9.0 adds a local KaTeX server to speed up pre-rendering
# but it doesn't seem to work and hangs around idly. The initial thought that it is probably
# something related to Docker setup. We can investigate this later.
@ -36,17 +32,17 @@ tensorboard==2.18.0 ; python_version >= "3.13"
#Description: This is used to generate PyTorch docs
#Pinned versions: 2.13.0
breathe==4.34.0
breathe==4.36.0
#Description: This is used to generate PyTorch C++ docs
#Pinned versions: 4.34.0
#Pinned versions: 4.36.0
exhale==0.2.3
exhale==0.3.7
#Description: This is used to generate PyTorch C++ docs
#Pinned versions: 0.2.3
#Pinned versions: 0.3.7
docutils==0.16
docutils==0.20
#Description: This is used to generate PyTorch C++ docs
#Pinned versions: 0.16
#Pinned versions: 0.20
bs4==0.0.1
#Description: This is used to generate PyTorch C++ docs
@ -56,13 +52,13 @@ IPython==8.12.0
#Description: This is used to generate PyTorch functorch docs
#Pinned versions: 8.12.0
myst-nb==0.17.2
myst-nb==1.3.0
#Description: This is used to generate PyTorch functorch and torch.compile docs.
#Pinned versions: 0.17.2
#Pinned versions: 1.3.0
# The following are required to build torch.distributed.elastic.rendezvous.etcd* docs
python-etcd==0.4.5
sphinx-copybutton==0.5.0
sphinx-design==0.4.0
sphinx-design==0.6.1
sphinxcontrib-mermaid==1.0.0
myst-parser==0.18.1
myst-parser==4.0.1

View File

@ -6,7 +6,7 @@ dependencies = [
"GitPython==3.1.45",
"docker==7.1.0",
"pytest==7.3.2",
"uv==0.9.5"
"uv==0.8.6"
]
[tool.setuptools]

View File

@ -102,8 +102,18 @@ if [ "$is_main_doc" = true ]; then
echo coverage output not found
exit 1
elif [ $undocumented -gt 0 ]; then
echo undocumented objects found:
echo "======================================"
echo "ERROR: $undocumented undocumented objects found!"
echo "======================================"
echo ""
echo "Full coverage report:"
cat build/coverage/python.txt
echo ""
echo "======================================"
echo "Undocumented modules/objects (lines after TOTAL):"
tail -n +$((lines - undocumented + 1)) build/coverage/python.txt
echo "======================================"
echo ""
echo "Make sure you've updated relevant .rsts in docs/source!"
echo "You can reproduce locally by running 'cd docs && make coverage && cat build/coverage/python.txt'"
exit 1

View File

@ -147,16 +147,15 @@ jobs:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-debug
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9
cuda-arch-list: 8.9
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 2, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 3, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 4, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 5, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 6, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 7, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 1, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 2, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 3, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 4, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 5, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 6, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 7, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
]}
secrets: inherit

View File

@ -58,10 +58,8 @@ jobs:
else
COMMIT_SHA="${{ github.sha }}"
fi
{
echo "sha=${COMMIT_SHA}"
echo "tag_name=trunk/${COMMIT_SHA}"
} >> "${GITHUB_OUTPUT}"
echo "sha=${COMMIT_SHA}" >> "${GITHUB_OUTPUT}"
echo "tag_name=trunk/${COMMIT_SHA}" >> "${GITHUB_OUTPUT}"
- name: Validate commit SHA
run: |
@ -89,7 +87,7 @@ jobs:
echo "✅ Commit ${COMMIT_SHA} is valid (automatic push trigger)"
fi
- name: Create and push tag(s) with retry
- name: Create and push tag with retry
id: check_tag
env:
TAG_NAME: ${{ steps.commit.outputs.tag_name }}
@ -114,23 +112,14 @@ jobs:
return 1
}
# Counters for summary reporting
created_count=0
skipped_count=0
failed_count=0
# Exit early if tag already exists
if check_tag_exists; then
echo "✅ Tag already exists - no action needed"
echo "exists=true" >> "${GITHUB_OUTPUT}"
exit 0
fi
# Always write outputs once on exit
finish() {
set +e
if [ -n "${GITHUB_OUTPUT:-}" ]; then
{
echo "created_count=${created_count}"
echo "skipped_count=${skipped_count}"
echo "failed_count=${failed_count}"
} >> "${GITHUB_OUTPUT}"
fi
}
trap finish EXIT
echo "Tag ${TAG_NAME} does not exist, proceeding with creation"
# Retry configuration
MAX_RETRIES=5
@ -205,111 +194,31 @@ jobs:
}
}
# New behavior for push events: enumerate commits in the push and tag each one.
# For workflow_dispatch, retain existing single-SHA behavior.
# Always fetch tags once up front to improve idempotency in loops
git fetch origin --tags --quiet || true
if [ "${{ github.event_name }}" = "push" ]; then
BEFORE_SHA="${{ github.event.before }}"
AFTER_SHA="${{ github.sha }}" # same as event.after
# List commits introduced by this push (old..new), oldest first for stable ordering
commits_file="$(mktemp)"
git rev-list --reverse "${BEFORE_SHA}..${AFTER_SHA}" > "${commits_file}"
if [ ! -s "${commits_file}" ]; then
echo "No new commits found between ${BEFORE_SHA}..${AFTER_SHA}; nothing to tag."
rm -f "${commits_file}"
exit 0
fi
commit_count="$(wc -l < "${commits_file}" | tr -d ' ')"
echo "Found ${commit_count} commit(s) to tag for push:"
while IFS= read -r sha; do
printf ' %s\n' "${sha}"
done < "${commits_file}"
while IFS= read -r sha; do
TAG_NAME="trunk/${sha}"
COMMIT_SHA="${sha}"
# If tag already exists locally or remotely, skip (idempotent)
if check_tag_exists; then
echo "✅ Tag ${TAG_NAME} already exists - skipping"
skipped_count=$((skipped_count + 1))
continue
fi
echo "Tag ${TAG_NAME} does not exist, proceeding with creation"
if retry_with_backoff "tag_with_retry" "Creating tag ${TAG_NAME} for commit ${COMMIT_SHA}"; then
created_count=$((created_count + 1))
else
echo "Tag creation failed after all retry attempts for ${TAG_NAME}"
failed_count=$((failed_count + 1))
fi
done < "${commits_file}"
rm -f "${commits_file}"
if [ "${failed_count}" -gt 0 ]; then
exit 1
fi
# Execute with retry
if retry_with_backoff "tag_with_retry" "Creating tag ${TAG_NAME} for commit ${COMMIT_SHA}"; then
echo "exists=false" >> "${GITHUB_OUTPUT}"
exit 0
else
# workflow_dispatch path (single SHA tagging preserved)
# Exit early if tag already exists
if check_tag_exists; then
echo "✅ Tag already exists - no action needed"
skipped_count=1
exit 0
fi
echo "Tag ${TAG_NAME} does not exist, proceeding with creation"
if retry_with_backoff "tag_with_retry" "Creating tag ${TAG_NAME} for commit ${COMMIT_SHA}"; then
created_count=1
exit 0
else
echo "Tag creation failed after all retry attempts"
failed_count=1
exit 1
fi
echo "Tag creation failed after all retry attempts"
exit 1
fi
- name: Tag creation summary
if: always()
run: |
if [ "${{ github.event_name }}" = "push" ]; then
echo "Trigger: push on main"
echo "Created: ${{ steps.check_tag.outputs.created_count }}"
echo "Skipped (already existed): ${{ steps.check_tag.outputs.skipped_count }}"
echo "Failed: ${{ steps.check_tag.outputs.failed_count }}"
if [ "${{ steps.check_tag.outputs.failed_count }}" = "0" ]; then
echo "✅ Completed tagging for push range ${{ github.event.before }}..${{ github.sha }}"
else
echo "❌ Some tags failed to create for push range ${{ github.event.before }}..${{ github.sha }}"
fi
if [ "${{ steps.check_tag.outputs.exists }}" = "true" ]; then
echo "✅ Tag ${{ steps.commit.outputs.tag_name }} already existed - no action needed"
elif [ "${{ job.status }}" = "success" ]; then
echo "✅ Successfully created tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
else
if [ "${{ steps.check_tag.outputs.failed_count }}" = "0" ]; then
if [ "${{ steps.check_tag.outputs.created_count }}" = "0" ]; then
echo "✅ Tag ${{ steps.commit.outputs.tag_name }} already existed - no action needed"
else
echo "✅ Successfully created tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
fi
else
echo "❌ Failed to create tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
fi
echo ""
echo "Tag details:"
echo " Name: ${{ steps.commit.outputs.tag_name }}"
echo " Commit: ${{ steps.commit.outputs.sha }}"
echo " Trigger: ${{ github.event_name }}"
if [ -n "${{ github.event.inputs.commit_sha }}" ]; then
echo " Manual commit: ${{ github.event.inputs.commit_sha }}"
fi
echo "❌ Failed to create tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
fi
echo ""
echo "Tag details:"
echo " Name: ${{ steps.commit.outputs.tag_name }}"
echo " Commit: ${{ steps.commit.outputs.sha }}"
echo " Trigger: ${{ github.event_name }}"
if [ -n "${{ github.event.inputs.commit_sha }}" ]; then
echo " Manual commit: ${{ github.event.inputs.commit_sha }}"
fi

View File

@ -272,110 +272,28 @@ cuda::blas::GEMMAndBiasActivationEpilogue activation_to_gemm_and_blas_arg(Activa
}
}
/*
* Checks whether DISABLE_ADDMM_CUDA_LT is set.
* Additionally, for ROCM we test whether the architecture supports the Lt.
*/
static bool isGloballyDisabledAddmmCudaLt(const at::Device& device) {
// When hipBLASLt is not supported on the architecture, return true
#ifdef USE_ROCM
static const std::vector<std::string> archs = {
static bool getDisableAddmmCudaLt() {
static const auto env_value = c10::utils::get_env("DISABLE_ADDMM_CUDA_LT");
if (env_value == "1") {
return true;
}
return false;
}
#ifdef USE_ROCM
static bool isSupportedHipLtROCmArch(int index) {
static const std::vector<std::string> archs = {
"gfx90a", "gfx942",
#if ROCM_VERSION >= 60300
#if ROCM_VERSION >= 60300
"gfx1100", "gfx1101", "gfx1200", "gfx1201", "gfx908",
#endif
#if ROCM_VERSION >= 70000
#endif
#if ROCM_VERSION >= 70000
"gfx950", "gfx1150", "gfx1151"
#endif
};
const auto is_hipblas_lt_arch_supported = at::detail::getCUDAHooks().isGPUArch(archs, device.index());
if (!is_hipblas_lt_arch_supported) {
return true;
}
#endif
// Check whether it is disabled in the env
static const auto is_addmm_cuda_lt_disabled = c10::utils::get_env("DISABLE_ADDMM_CUDA_LT");
if (is_addmm_cuda_lt_disabled == "1") {
return true;
}
return false;
}
/*
* Check whether for the given input we want to enable the Lt interface
*/
static bool isInputCompliesAddmmCudaLt(Tensor& result, const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha) {
// Implies 2D bias which we currently not send through Lt.
// TODO: this check is done pre col-major input preparation,
// so, this condition can be ralexed in cases when a col-major
// copy of result is needed.
if (result.is_same(self)) {
return false;
}
#if defined(USE_ROCM) && ROCM_VERSION == 60400
// hipblaslt TT fp32 regression on ROCm 6.4, cannot use
const auto args = cublasCommonArgs(mat1, mat2, result);
if (args.transa == 't' && args.transb == 't') {
return false;
}
#endif
const auto mat1_sizes = mat1.sizes();
const auto mat2_sizes = mat2.sizes();
#if defined(CUDA_VERSION) || defined(USE_ROCM)
const auto scalar_type = mat1.scalar_type();
return (beta.toComplexDouble() == 1.0
// self.dim() == 1 && result.dim() == 2 && self.sizes()[0] == mat2_sizes[1]
// is to use lt interface only when self is bias.
&& self.dim() == 1 && self.sizes()[0] == mat2_sizes[1] && self.is_contiguous()
&& result.dim() == 2 && result.is_contiguous()
&& ( // some dtype restrictions
#ifndef USE_ROCM
scalar_type == at::ScalarType::Double ||
#endif
scalar_type == at::ScalarType::Float ||
scalar_type == at::ScalarType::Half ||
scalar_type == at::ScalarType::BFloat16
)
&& ( // some shape/stride restrictions
// Strangely, if mat2 has only 1 row or column, we get
// CUBLAS_STATUS_INVALID_VALUE error from cublasLtMatmulAlgoGetHeuristic.
// NOTE: extension to mat1 because mat1/mat2 can be swapped based off
// their row-/col-majorness.
mat1_sizes[0] > 1 && mat1_sizes[1] > 1 &&
mat2_sizes[0] > 1 && mat2_sizes[1] > 1
// The last conditions is to skip 16b transA and non-trans-B having
// leading dim >> rows when they are sliced from a large tensor
// see fbcode/caffe2/test/test_linalg.py:test_corner_cases_of_cublasltmatmul
#if !(defined(CUDA_VERSION) && CUDA_VERSION >= 12010 || defined(USE_ROCM))
// Related to avoiding the leading stride >> leading dim problematic case
// with 16b dtypes described above. For such dtypes we only allow inputs
// which are either row- or col-major (i.e. non-overlapping, compact memory layout).
// In that case the leading stride will be equal to the outer dim len.
// Why do we catch this case here? The following `prepare_matrix_for_cublas` method
// does not modify inputs as long as there is a stride of length 1
// and the leading stride is at least max(1, other dim length), so we might
// end up with contiguous cols but not rows (i.e. holes between different rows)
// and vice versa.
mat2_sizes[0] < 65535 * 32 && mat2_sizes[1] < 65535 * 32 &&
mat1_sizes[0] < 65535 * 32 && mat1_sizes[1] < 65535 * 32 &&
&& (
// filter by dtype
(scalar_type != at::ScalarType::Half && scalar_type != at::ScalarType::BFloat16) ||
// check mat1/mat2 is row-/col-major
(mat1.is_non_overlapping_and_dense() && mat2.is_non_overlapping_and_dense())
)
#endif
)
);
#endif
// no compliance by default
return false;
#endif
};
return at::detail::getCUDAHooks().isGPUArch(archs, index);
}
#endif
template <typename scalar_t>
void launchTunableGemmAndBias(cublasCommonArgs &args, const Scalar& alpha, const scalar_t* bias, cuda::blas::GEMMAndBiasActivationEpilogue activation) {
@ -417,70 +335,7 @@ void launchTunableGemmAndBias(cublasCommonArgs &args, const Scalar& alpha, const
}
}
template <typename scalar_t, typename res_scalar_t = scalar_t>
bool launchGemmAndBiasCublasLt(
// args contains result which is modified
cublasCommonArgs& args,
const Tensor& self,
const Scalar& alpha,
Activation activation = Activation::None
) {
const auto* self_ptr = self.const_data_ptr<scalar_t>();
const auto tuning_ctx = at::cuda::tunable::getTuningContext();
if (tuning_ctx->IsTunableOpEnabled()) {
// TODO: maybe also return some success state?
launchTunableGemmAndBias<scalar_t>(
args, alpha, self_ptr, activation_to_gemm_and_blas_arg(activation)
);
return true;
}
return at::cuda::blas::gemm_and_bias<scalar_t, res_scalar_t>(
args.transa == 't',
args.transb == 't',
args.m,
args.n,
args.k,
alpha.to<at::opmath_type<scalar_t>>(),
args.mata->const_data_ptr<scalar_t>(),
args.lda,
args.matb->const_data_ptr<scalar_t>(),
args.ldb,
self_ptr,
args.result->data_ptr<res_scalar_t>(),
args.result_ld,
activation_to_gemm_and_blas_arg(activation)
);
}
template <typename scalar_t, typename res_scalar_t = scalar_t>
bool launchGemmCublas(
// args contains result which is modified
cublasCommonArgs& args,
const Scalar& alpha,
const Scalar& beta
) {
at::cuda::blas::gemm<scalar_t, res_scalar_t>(
args.transa,
args.transb,
args.m,
args.n,
args.k,
alpha.to<at::opmath_type<scalar_t>>(),
args.mata->const_data_ptr<scalar_t>(),
args.lda,
args.matb->const_data_ptr<scalar_t>(),
args.ldb,
beta.to<at::opmath_type<scalar_t>>(),
args.result->data_ptr<res_scalar_t>(),
args.result_ld
);
return true; // success!
}
Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha, Activation activation=Activation::None, bool disable_addmm_cuda_lt_override=false) {
// Shape checks {
// Make sure to keep addmm_cuda below in sync with this code; it
// preflights a check to try to avoid actually needing to call
// expand().
@ -490,62 +345,105 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
"expected mat1 and mat2 to have the same dtype, but got: ", mat1.dtype(), " != ", mat2.dtype()
)
if (result.is_same(self)) {
TORCH_CHECK(result.dim() == 2, "tensors must be 2-D");
TORCH_CHECK(self.sizes()[0] == mat1.sizes()[0], "self dim 0 must match mat1 dim 0");
TORCH_CHECK(self.sizes()[1] == mat2.sizes()[1], "self dim 1 must match mat2 dim 1");
}
// } Shape checks
// NOLINTNEXTLINE(*c-array*)
TensorArg targs[]{{result, "out", 0}, {self, "self", 1}, {mat1, "mat1", 2}, {mat2, "mat2", 3}};
checkAllSameGPU(__func__, targs);
// Handle whether to use the Lt interface {
static bool persistent_disable_addmm_cuda_lt = isGloballyDisabledAddmmCudaLt(self.device());
IntArrayRef mat1_sizes = mat1.sizes();
IntArrayRef mat2_sizes = mat2.sizes();
IntArrayRef self__sizes;
bool useLtInterface = false;
#if defined(USE_ROCM)
// When hipBLASLt is not supported on the architecture,
// disable_addmm_cuda_lt will always be to set to true
static bool disable_addmm_cuda_lt =
!isSupportedHipLtROCmArch(self.device().index()) || getDisableAddmmCudaLt();
#else
static bool disable_addmm_cuda_lt = getDisableAddmmCudaLt();
#endif
// if lt path fails, we recurse back into this function here and force the lt path to off
// we cannot update varible disable_addmm_cuda_lt from above since it is static and would be permanent
bool disable_addmm_cuda_lt = persistent_disable_addmm_cuda_lt || disable_addmm_cuda_lt_override;
#ifdef USE_ROCM
// Conditioned on the device index, which is not persistent
disable_addmm_cuda_lt = isGloballyDisabledAddmmCudaLt(self.device()) || disable_addmm_cuda_lt;
#endif
// Condition on the input
disable_addmm_cuda_lt = !isInputCompliesAddmmCudaLt(result, self, mat1, mat2, beta, alpha) || disable_addmm_cuda_lt;
// }
bool disable_addmm_cuda_lt_final = disable_addmm_cuda_lt || disable_addmm_cuda_lt_override;
#if defined(USE_ROCM) && ROCM_VERSION == 60400
// hipblaslt TT fp32 regression on ROCm 6.4, cannot use
cublasCommonArgs _args(mat1, mat2, result);
if (_args.transa == 't' && _args.transb == 't') {
disable_addmm_cuda_lt_final = true;
}
#endif
at::ScalarType scalar_type = mat1.scalar_type();
bool is_float_output_with_half_input = (scalar_type == at::ScalarType::Half || scalar_type == at::ScalarType::BFloat16) && result.scalar_type() == at::ScalarType::Float;
c10::MaybeOwned<Tensor> self_;
if (&result != &self) {
#if defined(CUDA_VERSION) || defined(USE_ROCM)
// Strangely, if mat2 has only 1 row or column, we get
// CUBLAS_STATUS_INVALID_VALUE error from cublasLtMatmulAlgoGetHeuristic.
// self.dim() == 1 && result.dim() == 2 && self.sizes()[0] == mat2_sizes[1]
// is to use lt interface only when self is bias.
// for cuda 11.4, cublasLtMatmul is activated
// the last two conditions is to skip 16b transA and non-trans-B having
// leading dim >> rows when they are sliced from a large tensor
// see fbcode/caffe2/test/test_linalg.py:test_corner_cases_of_cublasltmatmul
if (!disable_addmm_cuda_lt_final) {
useLtInterface = beta.toComplexDouble() == 1.0 && self.dim() == 1 &&
result.dim() == 2 && self.sizes()[0] == mat2_sizes[1] &&
self.is_contiguous() && result.is_contiguous() &&
#ifdef USE_ROCM
(scalar_type == at::ScalarType::Float ||
scalar_type == at::ScalarType::Half ||
scalar_type == at::ScalarType::BFloat16) &&
#else
(scalar_type == at::ScalarType::Double ||
scalar_type == at::ScalarType::Float ||
scalar_type == at::ScalarType::Half ||
scalar_type == at::ScalarType::BFloat16) &&
#endif
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12010 || defined(USE_ROCM))
mat2_sizes[0] > 1 && mat2_sizes[1] > 1;
#else
mat2_sizes[0] > 1 && mat2_sizes[1] > 1 &&
mat2_sizes[0] < 65535 * 32 && mat2_sizes[1] < 65535 * 32 &&
mat1_sizes[0] < 65535 * 32 && mat1_sizes[1] < 65535 * 32 &&
// avoid leading dim >> rows bugs
((mat1.strides()[0] == 1 && mat1.strides()[1] == mat1_sizes[0]) ||
(mat1.strides()[1] == 1 && mat1.strides()[0] == mat1_sizes[1]) ||
(scalar_type != at::ScalarType::Half &&
scalar_type != at::ScalarType::BFloat16)) &&
((mat2.strides()[0] == 1 && mat2.strides()[1] == mat2_sizes[0]) ||
(mat2.strides()[1] == 1 && mat2.strides()[0] == mat2_sizes[1]) ||
(scalar_type != at::ScalarType::Half &&
scalar_type != at::ScalarType::BFloat16));
#endif
}
#endif
if (!useLtInterface) {
self_ = expand_size(self, {mat1_sizes[0], mat2_sizes[1]}, "addmm");
}
self__sizes = self_->sizes();
} else {
self_ = c10::MaybeOwned<Tensor>::borrowed(self);
self__sizes = self_->sizes();
TORCH_CHECK(result.dim() == 2, "tensors must be 2-D");
TORCH_CHECK(self__sizes[0] == mat1_sizes[0], "self_ dim 0 must match mat1 dim 0");
TORCH_CHECK(self__sizes[1] == mat2_sizes[1], "self_ dim 1 must match mat2 dim 1");
}
// Handle result/self shapes
if (!result.is_same(self)) {
at::native::resize_output(result, {mat1.sizes()[0], mat2.sizes()[1]});
const auto self_maybe_expanded = [&]() -> c10::MaybeOwned<Tensor> {
if (disable_addmm_cuda_lt) {
// When in non-Lt path we do expand self even before
// check for beta != 0.0 to make sure that
// test_sparse_csr.py::TestSparseCSRCUDA::test_addmm_errors_*
// runs green.
return expand_size(self, result.sizes(), "addmm");
}
// copy next, should broadcast
return c10::MaybeOwned<Tensor>::borrowed(self);
}();
// We copy bias when in the non-Lt path
if (beta.toComplexDouble() != 0.0 && disable_addmm_cuda_lt) {
// NOTE: self should broadcast over result
at::native::copy_(result, *self_maybe_expanded);
if (&result != &self) {
at::native::resize_output(result, {mat1_sizes[0], mat2_sizes[1]});
if (beta.toComplexDouble() != 0.0 && !useLtInterface) {
at::native::copy_(result, *self_);
}
}
// Short circuit on empty result
if (result.numel() == 0) {
IntArrayRef result_sizes = result.sizes();
if ((result_sizes[0] == 0) || (result_sizes[1] == 0)) {
return result;
}
// Short circuit if the reduction dim is empty
if (mat1.sizes()[1] == 0) {
cublasCommonArgs args(mat1, mat2, result);
if (mat1.numel() == 0) {
// By definition, when beta==0, values in self should be ignored. nans and infs
// should not propagate
if (beta.toComplexDouble() == 0.) {
@ -557,64 +455,158 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
result,
self.expand(result.sizes()),
at::native::scalar_tensor(
beta,
self.scalar_type(),
std::nullopt /* layout */,
at::kCPU,
std::nullopt /* pin_memory */
)
);
beta,
self.scalar_type(),
std::nullopt /* layout */,
at::kCPU,
std::nullopt /* pin_memory */));
}
cublasCommonArgs args(mat1, mat2, result);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!args.result->is_conj());
// The Lt path
if (!disable_addmm_cuda_lt) {
bool lt_success = false;
if (useLtInterface) {
#if defined(USE_ROCM)
bool okay = true;
if (is_float_output_with_half_input) {
#ifdef USE_ROCM
TORCH_CHECK(false, "float output with half input is not enabled for ROCm");
#else
if (at::cuda::tunable::getTuningContext()->IsTunableOpEnabled()) {
TORCH_CHECK(false, "Tunable GEMM is not supported for float output with reduced float input");
}
AT_DISPATCH_REDUCED_FLOATING_TYPES(
scalar_type,
"addmm_cuda_lt",
[&] {
lt_success = launchGemmAndBiasCublasLt<scalar_t, float>(args, self, alpha, activation);
}
);
#endif
} else {
// !is_float_output_with_half_input
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
scalar_type,
"addmm_cuda_lt",
[&] {
lt_success = launchGemmAndBiasCublasLt<scalar_t>(args, self, alpha, activation);
auto tuning_ctx = at::cuda::tunable::getTuningContext();
if (tuning_ctx->IsTunableOpEnabled()) {
launchTunableGemmAndBias<scalar_t>(
args,
alpha,
(&result != &self) ? self.const_data_ptr<scalar_t>() : nullptr,
activation_to_gemm_and_blas_arg(activation));
} else {
okay = at::cuda::blas::gemm_and_bias<scalar_t>(
args.transa == 't',
args.transb == 't',
args.m,
args.n,
args.k,
alpha.to<at::opmath_type<scalar_t>>(),
args.mata->const_data_ptr<scalar_t>(),
args.lda,
args.matb->const_data_ptr<scalar_t>(),
args.ldb,
// This condition is needed for mm case on ROCm for hipblasLt path.
// Passing the bias ptr as null to avoid accuracy issues for mm case.
(&result != &self) ? self.const_data_ptr<scalar_t>() : nullptr,
args.result->data_ptr<scalar_t>(),
args.result_ld,
activation_to_gemm_and_blas_arg(activation)
);
}
);
} // end is_float_output_with_half_input
if (!lt_success) {
// lt path failed; recurse but disable lt path
});
}
if (!okay) {
// lt path failed; recurse but disable lt path
return addmm_out_cuda_impl(result, self, mat1, mat2, beta, alpha, activation, true);
}
// end Lt path
} else {
// No Lt, we use a GEMM instead
#else
auto activation_epilogue = activation_to_gemm_and_blas_arg(activation);
bool okay = true;
if (is_float_output_with_half_input) {
AT_DISPATCH_REDUCED_FLOATING_TYPES(
scalar_type,
"addmm_cuda_lt",
[&] {
auto tuning_ctx = at::cuda::tunable::getTuningContext();
if (tuning_ctx->IsTunableOpEnabled()) {
TORCH_CHECK(false, "Tunable GEMM is not supported for float output with reduced float input");
}
else {
okay = at::cuda::blas::gemm_and_bias<scalar_t, float>(
args.transa == 't',
args.transb == 't',
args.m,
args.n,
args.k,
alpha.to<at::opmath_type<scalar_t>>(),
args.mata->const_data_ptr<scalar_t>(),
args.lda,
args.matb->const_data_ptr<scalar_t>(),
args.ldb,
self.const_data_ptr<scalar_t>(),
args.result->data_ptr<float>(),
args.result_ld,
activation_epilogue
);
}});
} else {
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
scalar_type,
"addmm_cuda_lt",
[&] {
auto tuning_ctx = at::cuda::tunable::getTuningContext();
if (tuning_ctx->IsTunableOpEnabled()) {
launchTunableGemmAndBias<scalar_t>(
args,
alpha,
self.const_data_ptr<scalar_t>(),
activation_epilogue);
}
else {
okay = at::cuda::blas::gemm_and_bias<scalar_t>(
args.transa == 't',
args.transb == 't',
args.m,
args.n,
args.k,
alpha.to<at::opmath_type<scalar_t>>(),
args.mata->const_data_ptr<scalar_t>(),
args.lda,
args.matb->const_data_ptr<scalar_t>(),
args.ldb,
self.const_data_ptr<scalar_t>(),
args.result->data_ptr<scalar_t>(),
args.result_ld,
activation_epilogue
);
}});
}
if (!okay) {
// lt path failed; recurse but disable lt path
return addmm_out_cuda_impl(result, self, mat1, mat2, beta, alpha, activation, true);
}
#endif
} else
{
if (is_float_output_with_half_input) {
AT_DISPATCH_REDUCED_FLOATING_TYPES(
scalar_type,
"addmm_cuda",
[&] {
launchGemmCublas<scalar_t, float>(args, alpha, beta);
}
);
using opmath_t = at::opmath_type<scalar_t>;
opmath_t alpha_val = alpha.to<opmath_t>();
opmath_t beta_val = beta.to<opmath_t>();
const scalar_t* mat1_ptr = args.mata->const_data_ptr<scalar_t>();
const scalar_t* mat2_ptr = args.matb->const_data_ptr<scalar_t>();
float* result_ptr = args.result->mutable_data_ptr<float>();
at::cuda::blas::gemm<scalar_t, float>(
args.transa,
args.transb,
args.m,
args.n,
args.k,
alpha_val,
mat1_ptr,
args.lda,
mat2_ptr,
args.ldb,
beta_val,
result_ptr,
args.result_ld);
});
} else {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
at::ScalarType::Half,
@ -622,12 +614,28 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
scalar_type,
"addmm_cuda",
[&] {
launchGemmCublas<scalar_t>(args, alpha, beta);
}
);
using opmath_t = at::opmath_type<scalar_t>;
opmath_t alpha_val = alpha.to<opmath_t>();
opmath_t beta_val = beta.to<opmath_t>();
const scalar_t* mat1_ptr = args.mata->const_data_ptr<scalar_t>();
const scalar_t* mat2_ptr = args.matb->const_data_ptr<scalar_t>();
scalar_t* result_ptr = args.result->mutable_data_ptr<scalar_t>();
at::cuda::blas::gemm<scalar_t>(
args.transa,
args.transb,
args.m,
args.n,
args.k,
alpha_val,
mat1_ptr,
args.lda,
mat2_ptr,
args.ldb,
beta_val,
result_ptr,
args.result_ld);
});
}
// Apply epilogue
switch (activation) {
case Activation::RELU:
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
@ -639,14 +647,14 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
break;
default: break;
}
} // end GEMM path
}
// Preprocessor gate here needs to match the inverse of the check
// gating activation_to_gemm_and_blas_arg above; here we are manually
// performing a post-GELU because we weren't able to use the GELU
// epilogue above.
#if !defined(CUDA_VERSION) && !defined(USE_ROCM)
if (!disable_addmm_cuda_lt && activation == Activation::GELU) {
if (useLtInterface && activation == Activation::GELU) {
at::gelu_(const_cast<Tensor&>(*args.result), "tanh");
}
#endif

View File

@ -23,7 +23,7 @@ namespace at::native {
// The maximum number of threads in a block
#if defined(USE_ROCM)
constexpr int MAX_BLOCK_SIZE = 1024;
constexpr int MAX_BLOCK_SIZE = 256;
#else
constexpr int MAX_BLOCK_SIZE = 512;
#endif
@ -33,7 +33,7 @@ constexpr unsigned MAX_GRID_SIZE = 65535u;
// Number of threads in a block given an input size up to MAX_BLOCK_SIZE
static int getNumThreads(int nElem) {
#if defined(USE_ROCM)
int threadSizes[5] = { 64, 128, 256, 512, MAX_BLOCK_SIZE };
int threadSizes[5] = { 16, 32, 64, 128, MAX_BLOCK_SIZE };
#else
int threadSizes[5] = { 32, 64, 128, 256, MAX_BLOCK_SIZE };
#endif

View File

@ -92,16 +92,6 @@ inline thrust::pair<int64_t, int64_t> get_index_mapping2d(
output_offset + output_y * output_dim_x + output_x);
}
__device__ __forceinline__ int64_t reflect_index(int64_t x, int64_t len) {
const int64_t two = (len - 1) * 2;
if (two <= 0) {
return 0;
}
int64_t m = x % two;
if (m < 0) m += two;
return (m < len) ? m : (two - m);
}
template<typename scalar_t>
__global__ void reflection_pad1d_out_kernel(
const scalar_t * input, scalar_t * output,
@ -116,28 +106,6 @@ __global__ void reflection_pad1d_out_kernel(
}
}
template <typename scalar_t>
__global__ void reflection_pad1d_flat(
const scalar_t* __restrict__ input,
scalar_t* __restrict__ output,
int64_t input_w, int64_t pad_l, int64_t pad_r,
int64_t out_w, int64_t plane_count) {
const int64_t bx = blockDim.x;
const int64_t tx = threadIdx.x;
const int64_t total = plane_count * out_w;
const int64_t grid_stride = static_cast<int64_t>(bx) * gridDim.x;
int64_t linear = static_cast<int64_t>(blockIdx.x) * bx + tx;
for (; linear < total; linear += grid_stride) {
const int64_t plane = linear / out_w;
const int64_t x = linear - plane * out_w;
const int64_t j = reflect_index(x - pad_l, input_w);
output[plane * out_w + x] = input[plane * input_w + j];
}
}
template <typename scalar_t>
__global__ void reflection_pad1d_backward_out_kernel(
scalar_t * grad_input, const scalar_t * grad_output,
@ -742,44 +710,25 @@ TORCH_IMPL_FUNC(reflection_pad1d_out_cuda)
int64_t input_w = input_.size(dim_w);
int64_t output_w = input_w + pad_l + pad_r;
dim3 block_size(output_w > 256 ? 256 : output_w);
dim3 grid_size((int)::ceil(output_w / 256.0), nplane, nbatch);
Tensor input = input_.contiguous();
const int block_x = static_cast<int>(std::min<int64_t>(256, std::max<int64_t>(1, output_w)));
const cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
const int max_x = prop->maxGridSize[0];
const int max_y = prop->maxGridSize[1];
const int max_z = prop->maxGridSize[2];
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBFloat16, input.scalar_type(), "reflection_pad1d_out", [&] {
auto stream = at::cuda::getCurrentCUDAStream();
const int64_t gx = at::ceil_div(output_w, static_cast<int64_t>(block_x));
const bool fits3d = (nplane <= max_y) && (nbatch <= max_z) && (gx <= max_x);
if (fits3d) {
dim3 block(block_x, 1, 1);
dim3 grid(gx, static_cast<unsigned>(nplane), static_cast<unsigned>(nbatch));
reflection_pad1d_out_kernel<scalar_t><<<grid, block, 0, stream>>>(
input.const_data_ptr<scalar_t>(),
output.mutable_data_ptr<scalar_t>(),
input_w, pad_l, pad_r);
} else {
dim3 block(block_x, 1, 1);
const int64_t plane_count = nplane * nbatch;
const int64_t total_blocks = at::ceil_div(plane_count * output_w, static_cast<int64_t>(block_x));
const int grid_x = static_cast<int>(std::min<int64_t>(max_x, std::max<int64_t>(1, total_blocks)));
dim3 grid(grid_x, 1, 1);
reflection_pad1d_flat<scalar_t><<<grid, block, 0, stream>>>(
input.const_data_ptr<scalar_t>(),
output.mutable_data_ptr<scalar_t>(),
input_w, pad_l, pad_r, output_w, plane_count);
}
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
kHalf, kBFloat16, input.scalar_type(), "reflection_pad1d_out_template", [&] {
reflection_pad1d_out_kernel<<<
grid_size,
block_size,
0,
at::cuda::getCurrentCUDAStream()>>>(
input.const_data_ptr<scalar_t>(),
output.mutable_data_ptr<scalar_t>(),
input_w,
pad_l,
pad_r);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
}
TORCH_IMPL_FUNC(reflection_pad1d_backward_out_cuda)(const Tensor& grad_output_,

View File

@ -52,7 +52,7 @@ struct FusedAdagradMathFunctor {
using opmath_t = at::opmath_type<scalar_t>;
C10_DEVICE __forceinline__ void operator()(
int64_t chunk_size,
int chunk_size,
FusedOptimizerTensorListMetadata<3>& tl,
const float* lr_ptr,
const double& lr,
@ -133,4 +133,4 @@ struct FusedAdagradMathFunctor {
} // namespace
} // namespace at::native
} // namespace at::native

View File

@ -1,8 +1,8 @@
add_loop_eager,compile_time_instruction_count,3184000000,0.1
add_loop_eager,compile_time_instruction_count,3070000000,0.1
add_loop_eager_dynamic,compile_time_instruction_count,4595000000,0.1
add_loop_eager_dynamic,compile_time_instruction_count,4432000000,0.1
@ -18,7 +18,7 @@ add_loop_inductor_gpu,compile_time_instruction_count,26800000000,0.1
basic_modules_ListOfLinears_eager,compile_time_instruction_count,1096000000,0.1
basic_modules_ListOfLinears_eager,compile_time_instruction_count,1048000000,0.1
@ -26,7 +26,7 @@ basic_modules_ListOfLinears_inductor,compile_time_instruction_count,15240000000,
basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17720000000,0.1
basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17020000000,0.1
@ -34,11 +34,11 @@ basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,11090000
update_hint_regression,compile_time_instruction_count,1645000000,0.1
update_hint_regression,compile_time_instruction_count,1719000000,0.1
sum_floordiv_regression,compile_time_instruction_count,3813000000,0.1
sum_floordiv_regression,compile_time_instruction_count,3686995725,0.1
@ -50,31 +50,31 @@ symint_sum_loop,compile_time_instruction_count,4299000000,0.1
aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,1793000000,0.1
aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,1869000000,0.1
aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5120000000,0.1
aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5281000000,0.1
aotdispatcher_partitioner_cpu,compile_time_instruction_count,7936000000,0.1
aotdispatcher_partitioner_cpu,compile_time_instruction_count,8333000000,0.1
aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1848000000,0.1
aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1909000000,0.1
aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3152000000,0.1
aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3442000000,0.1
aotdispatcher_training_subclass_cpu,compile_time_instruction_count,8301000000,0.1
aotdispatcher_training_subclass_cpu,compile_time_instruction_count,9239000000,0.1
mm_loop_inductor_gpu,compile_time_instruction_count,4958000000,0.1
mm_loop_inductor_gpu,compile_time_instruction_count,4820968837,0.1
@ -82,8 +82,8 @@ mm_loop_inductor_dynamic_gpu,compile_time_instruction_count,9051000000,0.1
basic_NestedModule_eager,compile_time_instruction_count,9990000000,0.1
basic_NestedModule_eager,compile_time_instruction_count,9554000000,0.1
basic_InlineMod_eager,compile_time_instruction_count,8126000000,0.1
basic_InlineMod_eager,compile_time_instruction_count,7618000000,0.1

1 add_loop_eager compile_time_instruction_count 3184000000 3070000000 0.1
2 add_loop_eager_dynamic compile_time_instruction_count 4595000000 4432000000 0.1
3 add_loop_inductor compile_time_instruction_count 29660000000 29660000000 0.1
4 add_loop_inductor_dynamic_gpu compile_time_instruction_count 39910000000 39910000000 0.1
5 add_loop_inductor_gpu compile_time_instruction_count 26800000000 26800000000 0.1
6 basic_modules_ListOfLinears_eager compile_time_instruction_count 1096000000 1048000000 0.1
7 basic_modules_ListOfLinears_inductor compile_time_instruction_count 15240000000 15240000000 0.1
8 basic_modules_ListOfLinears_inductor_gpu_force_shape_pad compile_time_instruction_count 17720000000 17020000000 0.1
18 aotdispatcher_training_nosubclass_cpu compile_time_instruction_count 3152000000 3442000000 0.1
19 aotdispatcher_training_subclass_cpu compile_time_instruction_count 8301000000 9239000000 0.1
20 mm_loop_inductor_gpu compile_time_instruction_count 4958000000 4820968837 0.1
21 mm_loop_inductor_dynamic_gpu compile_time_instruction_count 9051000000 9051000000 0.1
22 basic_NestedModule_eager compile_time_instruction_count 9990000000 9554000000 0.1
23 basic_InlineMod_eager compile_time_instruction_count 8126000000 7618000000 0.1
24
26
27
28
29
30
31
32
34
35
36
37
38
39
40
41
42
43
44
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
82
83
84
85
86
87
88
89

View File

@ -207,6 +207,42 @@ templates_path = [
]
# TODO: document these and remove them from here.
# Fixes the duplicated
autosummary_filename_map = {
"torch.nn.utils.prune.identity": "torch.nn.utils.prune.identity_function",
"torch.nn.utils.prune.Identity": "torch.nn.utils.prune.Identity_class",
"torch.optim.adamw.adamw": "torch.optim.adamw.adamw_function",
"torch.optim.adamw.AdamW": "torch.optim.adamw.AdamW_class",
"torch.optim.asgd.asgd": "torch.optim.asgd.asgd_function",
"torch.optim.asgd.ASGD": "torch.optim.asgd.ASGD_class",
"torch.optim.nadam.nadam": "torch.optim.nadam.nadam_function",
"torch.optim.nadam.NAdam": "torch.optim.nadam.NAdam_class",
"torch.optim.radam.radam": "torch.optim.radam.radam_function",
"torch.optim.radam.RAdam": "torch.optim.radam.RAdam_class",
"torch.optim.rmsprop.rmsprop": "torch.optim.rmsprop.rmsprop_function",
"torch.optim.rmsprop.RMSprop": "torch.optim.rmsprop.RMSprop_class",
"torch.optim.rprop.rprop": "torch.optim.rprop.rprop_function",
"torch.optim.rprop.Rprop": "torch.optim.rprop.Rprop_class",
"torch.optim.sgd.sgd": "torch.optim.sgd.sgd_function",
"torch.optim.sgd.SGD": "torch.optim.sgd.SGD_class",
"torch.optim.adadelta.adadelta": "torch.optim.adadelta.adadelta_function",
"torch.optim.adadelta.Adadelta": "torch.optim.adadelta.Adadelta_class",
"torch.optim.adagrad.adagrad": "torch.optim.adagrad.adagrad_function",
"torch.optim.adagrad.Adagrad": "torch.optim.adagrad.Adagrad_class",
"torch.optim.adam.adam": "torch.optim.adam.adam_function",
"torch.optim.adam.Adam": "torch.optim.adam.Adam_class",
"torch.optim.adamax.adamax": "torch.optim.adamax.adamax_function",
"torch.optim.adamax.Adamax": "torch.optim.adamax.Adamax_class",
"torch.mtia.stream": "torch.mtia.stream_function",
"torch.mtia.Stream": "torch.mtia.Stream_class",
"torch.cpu.stream": "torch.cpu.stream_function",
"torch.cpu.Stream": "torch.cpu.Stream_class",
"torch.cuda.stream": "torch.cuda.stream_function",
"torch.cuda.Stream": "torch.cuda.Stream_class",
"torch.xpu.stream": "torch.xpu.stream_function",
"torch.xpu.Stream": "torch.xpu.Stream_class",
}
coverage_ignore_functions = [
# torch
"typename",
@ -3193,6 +3229,11 @@ autodoc_type_aliases = {
# Enable overriding of function signatures in the first line of the docstring.
autodoc_docstring_signature = True
# Exclude inherited IntEnum methods that have RST formatting issues in their docstrings
autodoc_default_options = {
"exclude-members": "from_bytes, to_bytes",
}
# -- katex javascript in header
#
# def setup(app):

View File

@ -253,7 +253,6 @@ regular full-precision tensor.
.. autosummary::
:toctree: generated
:nosignatures:
:template: classtemplate.rst
view
as_strided

View File

@ -67,21 +67,7 @@ class TestFullyShardMemory(FSDPTest):
# allocate the cuBLAS workspaces before measuring the memory usage
# since the workspace size can differ between hardwares
lin = torch.nn.Linear(768, 768, device=device_type)
# NOTE: before https://github.com/pytorch/pytorch/pull/163955,
# the input shape was (1, 768), so that the forward gemm used
# cublaslt, and the backward used cublas.
# With the aforementioned PR, and with shape (1, 768),
# the cublas path is used both in forward and in backward,
# altering peak memory usage not accounting for cublaslt.
# Here we change the input shape to (2, 768), and that swaps
# the cublas/cublaslt selection in the forward/backward,
# but that does not affect the peak memory usage stored in `base_mem_mb`.
# Reasons for the flip:
# before PR: no Lt in addmm when mat2 has nrows/ncols <= 1,
# after PR: no Lt in addmm when either mat1 or mat2 have nrows/ncols <= 1,
# since the input preparation can swap matrices based on output
# row-/col-majorness.
inp = torch.randn(2, 768, device=device_type)
inp = torch.randn(1, 768, device=device_type)
lin(inp).sum().backward()
torch.get_device_module(device_type).empty_cache()
base_mem_mb = self._get_peak_active_memory_mb()

View File

@ -288,18 +288,6 @@ class AnnotateTests(torch._dynamo.test_case.TestCase):
('call_function', 'mul_2', {'pp_stage': 0, 'fdsp_bucket': 0})""", # noqa: B950
)
def test_graph_break(self):
def fn(x):
with torch.fx.traceback.annotate({"pp_stage": 0}):
x = torch.sin(x)
torch._dynamo.graph_break()
x = torch.cos(x)
return x
opt_fn = torch.compile(fn, backend="eager")
x = torch.randn(10, requires_grad=True)
self.assertEqual(fn(x), opt_fn(x))
if __name__ == "__main__":
run_tests()

View File

@ -6,7 +6,6 @@ import builtins
import collections
import contextlib
import copy
import gc
import functools
import inspect
import io
@ -20,7 +19,6 @@ import traceback
import types
import typing
import unittest
import weakref
import warnings
from math import sqrt
from torch.multiprocessing import Process
@ -1626,25 +1624,6 @@ class TestFX(JitTestCase):
self.assertTrue(neg not in relu.users)
@skipIfTorchDynamo("Dynamo does not free right away")
def test_prepend_does_not_leak(self):
g = Graph()
x = g.placeholder("x")
relu = g.call_function(torch.relu, (x,))
neg = g.call_function(torch.neg, (x,))
relu.prepend(neg)
ref = weakref.ref(neg)
g.erase_node(neg)
del g
del x
del relu
del neg
gc.collect()
self.assertIsNone(ref())
def test_remove_uses_with_custom_filter(self):
g: torch.fx.Graph = Graph()
x: torch.fx.Node = g.placeholder("x")

View File

@ -7381,10 +7381,6 @@ torch.cuda.synchronize()
@skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
@parametrize("use_legacy_api", [True, False])
@skipCPUIf(True, "SPDA Math NT fallback causes failure: see issue #133644")
@unittest.skipIf(
"RelWithAssert" in torch.__config__.show(),
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
)
def test_dummy_mha_with_nt(self, device, use_legacy_api):
bs = 3
d1 = 2

View File

@ -8490,14 +8490,6 @@ class TestNNDeviceType(NNTestCase):
y_cuda_contig = pool(x_cuda.contiguous())
self.assertEqual(y_cuda_ch_last, y_cuda_contig)
@onlyCUDA
def test_large_reflect_pad(self, device):
# https://github.com/pytorch/pytorch/issues/165861
x = torch.rand(2**16, 2, device="cuda")
c = F.pad(x, (1, 1), mode="reflect")
c_cpu = F.pad(x.cpu(), (1, 1), mode="reflect")
self.assertEqual(c, c_cpu)
@onlyCUDA
@largeTensorTest("48GB", "cpu")
@largeTensorTest("48GB", "cuda")

View File

@ -247,10 +247,6 @@ class SparseSemiStructuredTensorCompileTest(torch._dynamo.test_case.TestCase):
@unittest.skipIf(IS_WINDOWS, "torch.compile not supported on windows")
@unittest.skipIf("cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS, "cusparselt not supported on this machine")
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
@unittest.skipIf(
"RelWithAssert" in torch.__config__.show(),
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
)
def test_sp24_compile(self) -> None:
x = torch.randn([1024, 512], device="cuda", dtype=torch.float16, requires_grad=True)
@ -580,10 +576,6 @@ class TestSparseSemiStructuredTraining(TestCase):
@training_dtypes
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
@unittest.skipIf(
"RelWithAssert" in torch.__config__.show(),
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
)
def test_prune_dense_static_sort(self, dtype) -> None:
# Ideally we would like to clone and compare, but that won't work because the sorting order will be different
# instead we pass the pruned matrix to the CUDA implementation and preserve the sparsity pattern.
@ -629,10 +621,6 @@ class TestSparseSemiStructuredTraining(TestCase):
@training_dtypes
@parametrize_backends
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
@unittest.skipIf(
"RelWithAssert" in torch.__config__.show(),
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
)
def test_pruning_algo_largest_abs_values_greedy(self, dtype, backend) -> None:
inp = torch.tensor(
[[4, 3, 2, 1], [-1, -3, 0.6, 0.5], [1, 2, 3, 4], [10, 2, -1, 5]],
@ -670,10 +658,6 @@ class TestSparseSemiStructuredTraining(TestCase):
@training_dtypes
@parametrize_backends
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
@unittest.skipIf(
"RelWithAssert" in torch.__config__.show(),
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
)
def test_pack_both_ways_meta_correctness(self, dtype, backend) -> None:
M, N = 128, 256
# Construct x to make sure we always have exactly 8 elements per 4x4 tile
@ -708,10 +692,6 @@ class TestSparseSemiStructuredTraining(TestCase):
@training_dtypes
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
@unittest.skipIf(
"RelWithAssert" in torch.__config__.show(),
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
)
def test_pack_both_ways_id(self, dtype) -> None:
N = 512
torch.manual_seed(0)
@ -749,10 +729,6 @@ class TestSparseSemiStructuredTraining(TestCase):
@training_dtypes
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
@unittest.skipIf(
"RelWithAssert" in torch.__config__.show(),
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
)
def test_pack_both_ways_edge_case1(self, dtype) -> None:
# In this case, the heuristic will keep 7 values out of 16
# instead of 8. let's see how the kernel handles this
@ -778,10 +754,6 @@ class TestSparseSemiStructuredTraining(TestCase):
@training_dtypes
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
@unittest.skipIf(
"RelWithAssert" in torch.__config__.show(),
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
)
def test_sp24_apply(self, dtype) -> None:
M, N = 256, 1024
x = torch.randn([M, N], dtype=dtype, device="cuda")
@ -798,10 +770,6 @@ class TestSparseSemiStructuredTraining(TestCase):
@training_dtypes
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
@unittest.skipIf(
"RelWithAssert" in torch.__config__.show(),
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
)
def test_sp24_apply_dense(self, dtype) -> None:
M, N = 256, 1024
x = torch.randn([M, N], dtype=dtype, device="cuda")
@ -840,10 +808,6 @@ class TestSparseSemiStructuredTraining(TestCase):
@training_dtypes
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
@unittest.skipIf(
"RelWithAssert" in torch.__config__.show(),
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
)
def test_sp24_matmuls(self, dtype) -> None:
M, N, K = 64, 256, 1024
a = torch.randn([M, K], device="cuda", dtype=dtype)
@ -879,10 +843,6 @@ class TestSparseSemiStructuredTraining(TestCase):
)
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
@unittest.skipIf(
"RelWithAssert" in torch.__config__.show(),
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
)
def test_sp24_matmuls_mat_vec(self) -> None:
a = torch.randn([64, 128], device="cuda", dtype=torch.float16)
b = torch.randn([128], device="cuda", dtype=torch.float16)
@ -893,10 +853,6 @@ class TestSparseSemiStructuredTraining(TestCase):
torch.testing.assert_close(a_s @ b, (a * a_m) @ b, **atol_rtol_kw[a.dtype])
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
@unittest.skipIf(
"RelWithAssert" in torch.__config__.show(),
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
)
def test_sp24_matmuls_bmm(self) -> None:
a = torch.randn([64, 128], device="cuda", dtype=torch.float16)
b = torch.randn([5, 6, 128], device="cuda", dtype=torch.float16)

View File

@ -2758,12 +2758,6 @@ class _NodeBase:
return_type: Any,
) -> None: ...
def _update_args_kwargs(self, args: tuple[Any, ...], kwargs: dict[str, Any]): ...
def _prepend(self, n: FxNode) -> None: ...
def _remove_from_list(self) -> None: ...
def __lt__(self, n: Self) -> _bool: ...
def __gt__(self, n: Self) -> _bool: ...
def __le__(self, n: Self) -> _bool: ...
def __ge__(self, n: Self) -> _bool: ...
class _NodeIter(Iterator[FxNode]):
def __init__(self, root: FxNode, reversed: _bool) -> None: ...

View File

@ -2810,15 +2810,5 @@
"Ensure {user_cls.__name__} is a type of dict, OrderedDict, or defaultdict."
]
}
],
"GB0279": [
{
"Gb_type": "torch.fx.traceback.annotate escaped from compiled region",
"Context": "str(self)",
"Explanation": "Dynamo doesn't support graph break on torch.fx.traceback.annotate.",
"Hints": [
"It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues."
]
}
]
}

View File

@ -1295,16 +1295,6 @@ class FxTracebackAnnotateVariable(ContextWrappingVariable):
def fn_name(self):
return "annotate"
def reconstruct_type(self, codegen: "PyCodegen"):
unimplemented_v2(
gb_type="torch.fx.traceback.annotate escaped from compiled region",
context=str(self),
explanation="Dynamo doesn't support graph break on torch.fx.traceback.annotate.",
hints=[
*graph_break_hints.SUPPORTABLE,
],
)
class DynamoConfigPatchVariable(ContextWrappingVariable):
"""represents torch._dynamo.patch_dynamo_config"""

View File

@ -1,15 +1,11 @@
#include <torch/csrc/fx/node.h>
#include <c10/util/Exception.h>
#include <c10/util/SmallVector.h>
#include <structmember.h>
#include <torch/csrc/utils/object_ptr.h>
#include <torch/csrc/utils/pythoncapi_compat.h>
#include <algorithm>
namespace {
using NodeSortKey = c10::SmallVector<int64_t, 4>;
struct NodeBase;
// Thrown to exit out of a C++ function and return an error to Python.
@ -167,41 +163,7 @@ struct NodeBase {
PyObject* users;
PyObject* _repr_fn;
PyObject* meta;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
alignas(NodeSortKey) char sort_key_buf[sizeof(NodeSortKey)];
inline NodeSortKey& sort_key() {
return *reinterpret_cast<NodeSortKey*>(sort_key_buf);
}
inline void set_prev(NodeBase* value) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(value);
Py_INCREF(reinterpret_cast<PyObject*>(value));
NodeBase* old = _prev;
_prev = value;
Py_DECREF(reinterpret_cast<PyObject*>(old));
}
inline void set_next(NodeBase* value) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(value);
Py_INCREF(reinterpret_cast<PyObject*>(value));
NodeBase* old = _next;
_next = value;
Py_DECREF(reinterpret_cast<PyObject*>(old));
}
// Equivalent to:
// p, n = self._prev, self._next
// p._next, n._prev = n, p
inline void remove_from_list() {
if (this->_prev == this && this->_next == this) {
return;
}
NodeBase* p = this->_prev;
NodeBase* n = this->_next;
p->set_next(n);
n->set_prev(p);
}
PyObject* _sort_key;
};
static PyObject* NodeBase_new(
@ -211,8 +173,6 @@ static PyObject* NodeBase_new(
PyObject* self = type->tp_alloc(type, 0);
if (!self)
return nullptr;
new (reinterpret_cast<NodeBase*>(self)->sort_key_buf)
NodeSortKey(); // placement new does not allocate
return self;
}
@ -241,6 +201,7 @@ static int NodeBase_init_fn(NodeBase* self, PyObject* args, PyObject* kwds) {
self->users = PyDict_New();
self->_repr_fn = Py_NewRef(Py_None);
self->meta = PyDict_New();
self->_sort_key = PyTuple_New(0);
return 0;
}
@ -260,6 +221,7 @@ static struct PyMemberDef NodeBase_members[] = {
{"users", T_OBJECT_EX, offsetof(NodeBase, users), 0, nullptr},
{"_repr_fn", T_OBJECT_EX, offsetof(NodeBase, _repr_fn), 0, nullptr},
{"meta", T_OBJECT_EX, offsetof(NodeBase, meta), 0, nullptr},
{"_sort_key", T_OBJECT_EX, offsetof(NodeBase, _sort_key), 0, nullptr},
{nullptr} /* Sentinel */
};
@ -277,6 +239,7 @@ static int NodeBase_traverse(NodeBase* self, visitproc visit, void* arg) {
Py_VISIT(self->users);
Py_VISIT(self->_repr_fn);
Py_VISIT(self->meta);
Py_VISIT(self->_sort_key);
return 0;
}
@ -294,12 +257,12 @@ static int NodeBase_clear(NodeBase* self) {
Py_CLEAR(self->users);
Py_CLEAR(self->_repr_fn);
Py_CLEAR(self->meta);
Py_CLEAR(self->_sort_key);
return 0;
}
static void NodeBase_dealloc(PyObject* self) {
PyObject_GC_UnTrack(self);
reinterpret_cast<NodeBase*>(self)->sort_key().~NodeSortKey();
(void)NodeBase_clear((NodeBase*)self);
Py_TYPE(self)->tp_free(self);
}
@ -358,195 +321,15 @@ static PyObject* NodeBase__update_args_kwargs(
}
}
static PyObject* NodeBase__remove_from_list(
PyObject* self,
PyObject* _ignored) {
reinterpret_cast<NodeBase*>(self)->remove_from_list();
Py_RETURN_NONE;
}
static PyObject* NodeBase__prepend(PyObject* self_, PyObject* arg) {
if (self_ == arg) {
Py_RETURN_NONE;
}
if (!is_node(arg)) {
PyErr_SetString(PyExc_TypeError, "_prepend() argument must be a Node");
return nullptr;
}
NodeBase* self = reinterpret_cast<NodeBase*>(self_);
NodeBase* x = reinterpret_cast<NodeBase*>(arg);
if (self->graph != x->graph) {
PyErr_SetString(
PyExc_AssertionError,
"Attempting to move a Node into a different Graph");
return nullptr;
}
x->remove_from_list();
NodeBase* p = self->_prev;
p->set_next(x);
x->set_prev(p);
x->set_next(self);
self->set_prev(x);
// Now compute x.sort_key()
const NodeSortKey& psk = x->_prev->sort_key();
const NodeSortKey& nsk = x->_next->sort_key();
if (psk.size() > nsk.size()) {
// prefix = psk[: len(nsk)+1]
size_t slice_len = nsk.size() + 1;
NodeSortKey prefix(psk.begin(), psk.begin() + slice_len);
// last element is idx => increment by 1
prefix.back()++;
x->sort_key() = std::move(prefix);
} else if (psk.size() < nsk.size()) {
// prefix = nsk[: len(psk)+1]
size_t slice_len = psk.size() + 1;
NodeSortKey prefix(nsk.begin(), nsk.begin() + slice_len);
// last element is idx => decrement by 1
prefix.back()--;
x->sort_key() = std::move(prefix);
} else {
// same length => add a 0
x->sort_key() = psk;
x->sort_key().emplace_back(0);
}
Py_RETURN_NONE;
}
// __lt__(self, other): Return self.sort_key < other.sort_key
static PyObject* NodeBase___lt__(PyObject* self, PyObject* other) {
// METH_O => one argument: 'other'
if (!is_node(other)) {
Py_RETURN_NOTIMPLEMENTED;
}
const NodeSortKey& lhs = reinterpret_cast<NodeBase*>(self)->sort_key();
const NodeSortKey& rhs = reinterpret_cast<NodeBase*>(other)->sort_key();
bool less = std::lexicographical_compare(
lhs.begin(), lhs.end(), rhs.begin(), rhs.end());
if (less)
Py_RETURN_TRUE;
Py_RETURN_FALSE;
}
// __gt__(self, other): Return self.sort_key() > other.sort_key
static PyObject* NodeBase___gt__(PyObject* self, PyObject* other) {
if (!is_node(other)) {
Py_RETURN_NOTIMPLEMENTED;
}
const NodeSortKey& lhs = reinterpret_cast<NodeBase*>(self)->sort_key();
const NodeSortKey& rhs = reinterpret_cast<NodeBase*>(other)->sort_key();
// "a > b" is equivalent to "b < a"
bool greater = std::lexicographical_compare(
rhs.begin(), rhs.end(), lhs.begin(), lhs.end());
if (greater)
Py_RETURN_TRUE;
Py_RETURN_FALSE;
}
static PyObject* NodeBase___ge__(PyObject* self, PyObject* other) {
if (self == other) {
Py_RETURN_TRUE;
}
return NodeBase___gt__(self, other);
}
// __le__(self, other): Return not (self > other)
static PyObject* NodeBase___le__(PyObject* self, PyObject* other) {
if (self == other) {
Py_RETURN_TRUE;
}
return NodeBase___lt__(self, other);
}
// Convert the NodeBase::sort_key vector<long> into a Python tuple of ints
// Only used by pickle/__getstate__
static PyObject* NodeBase_get_sort_key(PyObject* self, void* /*closure*/) {
NodeBase* node = reinterpret_cast<NodeBase*>(self);
const NodeSortKey& vec = node->sort_key();
Py_ssize_t n = static_cast<Py_ssize_t>(vec.size());
THPObjectPtr tuple(PyTuple_New(n));
if (!tuple) {
return nullptr; // Out of memory
}
for (Py_ssize_t i = 0; i < n; i++) {
PyObject* value = PyLong_FromSsize_t(vec[i]);
if (!value) {
return nullptr;
}
PyTuple_SET_ITEM(tuple.get(), i, value);
}
return tuple.release();
}
// Setter for NodeBase::sort_key: expects a Python tuple of ints, e.g.
// node._sort_key = (1,2,3) Only used by pickle/__setstate__
static int NodeBase_set_sort_key(
PyObject* self,
PyObject* value,
void* /*closure*/) {
NodeBase* node = reinterpret_cast<NodeBase*>(self);
if (!PyTuple_Check(value)) {
PyErr_SetString(PyExc_TypeError, "_sort_key must be an tuple of ints");
return -1;
}
Py_ssize_t size = PyTuple_GET_SIZE(value);
NodeSortKey new_vec;
new_vec.reserve(size);
for (Py_ssize_t i = 0; i < size; i++) {
int64_t val = PyLong_AsSsize_t(PyTuple_GET_ITEM(value, i));
if (val == -1 && PyErr_Occurred()) {
return -1;
}
new_vec.emplace_back(val);
}
node->sort_key() = std::move(new_vec);
return 0;
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
static PyMethodDef NodeBase_methods[] = {
{"_update_args_kwargs",
(PyCFunction)(void*)(NodeBase__update_args_kwargs),
METH_FASTCALL,
"Internal method: do not call directly."},
{"_remove_from_list",
(PyCFunction)(void*)(NodeBase__remove_from_list),
METH_NOARGS,
"Internal method: do not call directly."},
{"_prepend",
(PyCFunction)(void*)(NodeBase__prepend),
METH_O,
"Internal method: do not call directly."},
{"__lt__",
(PyCFunction)(void*)NodeBase___lt__,
METH_O,
"Return True if self.sort_key < other.sort_key"},
{"__gt__",
(PyCFunction)(void*)NodeBase___gt__,
METH_O,
"Return True if self.sort_key > other.sort_key"},
{"__ge__",
(PyCFunction)(void*)NodeBase___ge__,
METH_O,
"Return True if self.sort_key >= other.sort_key"},
{"__le__",
(PyCFunction)(void*)NodeBase___le__,
METH_O,
"Return True if self.sort_key <= other.sort_key"},
{nullptr, nullptr, 0, nullptr} // Sentinel
};
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
static PyGetSetDef NodeBase_getset[] = {
{"_sort_key", // attribute name in Python
(getter)NodeBase_get_sort_key, // C getter function
(setter)NodeBase_set_sort_key, // C setter function
(char*)"The sort key as a tuple of ints", // docstring
nullptr},
{nullptr, nullptr, nullptr, nullptr, nullptr} // Sentinel
};
PyTypeObject NodeBaseType = {
PyVarObject_HEAD_INIT(nullptr, 0)
"torch._C._NodeBase", /* tp_name */
@ -578,7 +361,7 @@ PyTypeObject NodeBaseType = {
nullptr, /* tp_iternext */
NodeBase_methods, /* tp_methods */
NodeBase_members, /* tp_members */
NodeBase_getset, /* tp_getset */
nullptr, /* tp_getset */
nullptr, /* tp_base */
nullptr, /* tp_dict */
nullptr, /* tp_descr_get */

View File

@ -385,7 +385,41 @@ class Node(_NodeBase):
Args:
x (Node): The node to put before this node. Must be a member of the same graph.
"""
self._prepend(x)
assert self.graph == x.graph, "Attempting to move a Node into a different Graph"
if self == x:
log.debug(
"Trying to prepend a node to itself. This behavior has no effect on the graph."
)
return
x._remove_from_list()
p = self._prev
p._next, x._prev = x, p
x._next, self._prev = self, x
# compute x._sort_key
psk = x._prev._sort_key
nsk = x._next._sort_key
if len(psk) > len(nsk):
idx: int
*prefix, idx = psk[: len(nsk) + 1]
x._sort_key = (*prefix, idx + 1)
elif len(psk) < len(nsk):
*prefix, idx = nsk[: len(psk) + 1]
x._sort_key = (*prefix, idx - 1)
else: # same length, increase length by 1
x._sort_key = (*psk, 0)
def __gt__(self, other: "Node") -> bool:
return self._sort_key > other._sort_key
def __lt__(self, other: "Node") -> bool:
return self._sort_key < other._sort_key
def __ge__(self, other: "Node") -> bool:
return self > other or self == other
def __le__(self, other: "Node") -> bool:
return self < other or self == other
@compatibility(is_backward_compatible=True)
def append(self, x: "Node") -> None:
@ -396,7 +430,11 @@ class Node(_NodeBase):
Args:
x (Node): The node to put after this node. Must be a member of the same graph.
"""
self._next._prepend(x)
self._next.prepend(x)
def _remove_from_list(self) -> None:
p, n = self._prev, self._next
p._next, n._prev = n, p
@property
def args(self) -> tuple[Argument, ...]: