Compare commits

...

11 Commits

Author SHA1 Message Date
e7592f4005 [CI] Move the periodic debug tests to newer runner (#165158)
Previously: g3 = NVIDIA Tesla M60
Now: g6 = NVIDIA L4
Also changes the CUDA arch list accordingly.

Pros:
More memory, newer GPU

Cons:
That was one of the few remaining tests on g3 runners, so we probably lost coverage?

We can probably run more tests in parallel now, but I'm not going to do that here.

Disabled a bunch of sparse and nestedtensor tests that were previously skipped (presumably due to not having sufficient hardware?). They are now failing with:
```
Traceback (most recent call last):
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/testing/_internal/common_utils.py", line 3293, in wrapper
    method(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/testing/_internal/common_utils.py", line 3292, in wrapper
    with policy():
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/testing/_internal/common_utils.py", line 2532, in __enter__
    self.beforeStreams[-1].synchronize()
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/cuda/streams.py", line 105, in synchronize
    super().synchronize()
torch.AcceleratorError: CUDA error: device-side assert triggered
Search for `cudaErrorAssert' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Exception raised from stream_synchronize at /var/lib/jenkins/workspace/c10/cuda/CUDAFunctions.h:120 (most recent call first):
C++ CapturedTraceback:
#4 std::_Function_handler<std::shared_ptr<c10::LazyValue<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > const> (), c10::SetStackTraceFetcher(std::function<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > ()>)::{lambda()#1}>::_M_invoke(std::_Any_data const&) from Logging.cpp:0
#5 c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) from ??:0
#6 c10::cuda::c10_cuda_check_implementation(int, char const*, char const*, unsigned int, bool) [clone .cold] from CUDAException.cpp:0
#7 THCPStream_synchronize(_object*, _object*) from Stream.cpp:0
#8 cfunction_vectorcall_NOARGS from /usr/local/src/conda/python-3.10.14/Objects/methodobject.c:489
#9 _PyObject_VectorcallTstate from /usr/local/src/conda/python-3.10.14/Include/cpython/abstract.h:114
#10 _PyEval_EvalFrame from /usr/local/src/conda/python-3.10.14/Include/internal/pycore_ceval.h:46
#11 _PyObject_VectorcallTstate from /usr/local/src/conda/python-3.10.14/Include/cpython/abstract.h:114
#12 _PyEval_EvalFrame from /usr/local/src/conda/python-3.10.14/Include/internal/pycore_ceval.h:46
```
When run with CUDA_LAUNCH_BLOCKING=1, I got a ton of output like:
```

/var/lib/jenkins/workspace/third_party/cutlass/include/cutlass/integer_subbyte.h:124: cutlass::integer_subbyte<Bits, Signed>::integer_subbyte(unsigned int) [with int Bits = 2; __nv_bool Signed = false]: block: [5,3,0], thread: [2,7,0] Assertion `value < upper_bound` failed.
/var/lib/jenkins/workspace/third_party/cutlass/include/cutlass/integer_subbyte.h:124: cutlass::integer_subbyte<Bits, Signed>::integer_subbyte(unsigned int) [with int Bits = 2; __nv_bool Signed = false]: block: [5,3,0], thread: [3,7,0] Assertion `value < upper_bound` failed.
/var/lib/jenkins/workspace/third_party/cutlass/include/cutlass/integer_subbyte.h:124: cutlass::integer_subbyte<Bits, Signed>::integer_subbyte(unsigned int) [with int Bits = 2; __nv_bool Signed = false]: block: [3,8,0], thread: [0,0,0] Assertion `value < upper_bound` failed.
/var/lib/jenkins/workspace/third_party/cutlass/include/cutlass/integer_subbyte.h:124: cutlass::integer_subbyte<Bits, Signed>::integer_subbyte(unsigned int) [with int Bits = 2; __nv_bool Signed = false]: block: [3,8,0], thread: [1,0,0] Assertion `value < upper_bound` failed.
/var/lib/jenkins/workspace/third_party/cutlass/include/cutlass/integer_subbyte.h:124: cutlass::integer_subbyte<Bits, Signed>::integer_subbyte(unsigned int) [with int Bits = 2; __nv_bool Signed = false]: block: [3,8,0], thread: [2,0,0] Assertion `value < upper_bound` failed.
/var/lib/jenkins/workspace/third_party/cutlass/include/cutlass/integer_subbyte.h:124: cutlass::integer_subbyte<Bits, Signed>::integer_subbyte(unsigned int) [with int Bits = 2; __nv_bool Signed = false]: block: [3,8,0], thread: [3,0,0] Assertion `value < upper_bound` failed.
/var/lib/jenkins/workspace/third_party/cutlass/include/cutlass/integer_subbyte.h:124: cutlass::integer_subbyte<Bits, Signed>::integer_subbyte(unsigned int) [with int Bits = 2; __nv_bool Signed = false]: block: [3,8,0], thread: [0,1,0] Assertion `value < upper_bound` failed.
/var/lib/jenkins/workspace/third_party/cutlass/include/cutlass/integer_subbyte.h:124: cutlass::integer_subbyte<Bits, Signed>::integer_subbyte(unsigned int) [with int Bits = 2; __nv_bool Signed = false]: block: [3,8,0], thread: [1,1,0] Assertion `value < upper_bound` failed.
/var/lib/jenkins/workspace/third_party/cutlass/include/cutlass/integer_subbyte.h:124: cutlass::integer_subbyte<Bits, Signed>::integer_subbyte(unsigned int) [with int Bits = 2; __nv_bool Signed = false]: block: [3,8,0], thread: [3,1,0] Assertion `value < upper_bound` failed.
/var/lib/jenkins/workspace/third_party/cutlass/include/cutlass/integer_subbyte.h:124: cutlass::integer_subbyte<Bits, Signed>::integer_subbyte(unsigned int) [with int Bits = 2; __nv_bool Signed = false]: block: [3,8,0], thread: [0,2,0] Assertion `value < upper_bound` failed.
/var/lib/jenkins/workspace/third_party/cutlass/include/cutlass/integer_subbyte.h:124: cutlass::integer_subbyte<Bits, Signed>::integer_subbyte(unsigned int) [with int Bits = 2; __nv_bool Signed = false]: block: [3,8,0], thread: [2,2,0] Assertion `value < upper_bound` failed.
/var/lib/jenkins/workspace/third_party/cutlass/include/cutlass/integer_subbyte.h:124: cutlass::integer_subbyte<Bits, Signed>::integer_subbyte(unsigned int) [with int Bits = 2; __nv_bool Signed = false]: block: [3,8,0], thread: [3,2,0] Assertion `value < upper_bound` failed.
/var/lib/jenkins/workspace/third_party/cutlass/include/cutlass/integer_subbyte.h:124: cutlass::integer_subbyte<Bits, Signed>::integer_subbyte(unsigned int) [with int Bits = 2; __nv_bool Signed = false]: block: [3,8,0], thread: [0,3,0] Assertion `value < upper_bound` failed.
/var/lib/jenkins/workspace/third_party/cutlass/include/cutlass/integer_subbyte.h:124: cutlass::integer_subbyte<Bits, Signed>::integer_subbyte(unsigned int) [with int Bits = 2; __nv_bool Signed = false]: block: [3,8,0], thread: [1,3,0] Assertion `value < upper_bound` failed.
/var/lib/jenkins/workspace/third_party/cutlass/include/cutlass/integer_subbyte.h:124: cutlass::integer_subbyte<Bits, Signed>::integer_subbyte(unsigned int) [with int Bits = 2; __nv_bool Signed = false]: block: [3,8,0], thread: [1,4,0] Assertion `value < upper_bound` failed.
/var/lib/jenkins/workspace/third_party/cutlass/include/cutlass/integer_subbyte.h:124: cutlass::integer_subbyte<Bits, Signed>::integer_subbyte(unsigned int) [with int Bits = 2; __nv_bool Signed = false]: block: [3,8,0], thread: [3,4,0] Assertion `value < upper_bound` failed.
```
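For reference, a minimal sketch of the env-var mechanics used to get the output above (not the actual repro; the variable must be set before CUDA is initialized to take effect):
```
# Sketch: CUDA_LAUNCH_BLOCKING must be set before torch initializes CUDA,
# so the device-side assert surfaces at the offending launch instead of at
# a later stream synchronize.
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"  # set before any CUDA work happens

import torch

x = torch.randn(8, device="cuda")
torch.cuda.synchronize()  # errors from prior launches are now reported eagerly
```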

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165158
Approved by: https://github.com/seemethere
2025-10-21 21:28:12 +00:00
d334c3649d [CUDA] fix reflection padding for large batch size (#165942)
Fixes [#165861](https://github.com/pytorch/pytorch/issues/165861)
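A minimal repro sketch, mirroring the regression test added in this PR (assumes a CUDA device; 2**16 planes exceeds the 65535 grid-dimension limit the old launch configuration relied on):
```
import torch
import torch.nn.functional as F

# 2**16 rows -> more planes than fit in one CUDA grid dimension (65535),
# which the previous reflection_pad1d launch configuration assumed.
x = torch.rand(2**16, 2, device="cuda")
out_cuda = F.pad(x, (1, 1), mode="reflect")
out_cpu = F.pad(x.cpu(), (1, 1), mode="reflect")
torch.testing.assert_close(out_cuda.cpu(), out_cpu)
```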

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165942
Approved by: https://github.com/eqy
2025-10-21 21:07:38 +00:00
9f82535c5a [ROCm] [Normalization] Update block size (#165941)
* Seeing up to 6x improvement

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165941
Approved by: https://github.com/jeffdaily
2025-10-21 20:53:05 +00:00
5b35fc8777 Support multiple commits on push events in trunk tagging workflow (#165937)
Context:
* this workflow is used to create tags like `trunk/{sha}` for all `main` commits
* those tags are used by [autorevert](https://github.com/pytorch/test-infra/blob/main/aws/lambda/pytorch-auto-revert/README.md) to rerun selected workflows

Problem: currently the workflow creates only a single tag per push event, while ghstack pushes multiple commits per single push.

This PR supports tag creation for all commits in the push event.
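In essence (a condensed Python sketch of the shell logic added to the workflow below; the function names here are illustrative, not the workflow's actual helpers):
```
# Sketch: enumerate the commits introduced by a push (event.before..event.after)
# and create a trunk/<sha> tag for each, skipping tags that already exist.
import subprocess

def git(*args: str) -> str:
    return subprocess.run(
        ["git", *args], check=True, capture_output=True, text=True
    ).stdout

def tag_push_range(before_sha: str, after_sha: str) -> None:
    shas = git("rev-list", "--reverse", f"{before_sha}..{after_sha}").split()
    existing = set(git("tag", "--list", "trunk/*").split())
    for sha in shas:
        tag = f"trunk/{sha}"
        if tag in existing:
            continue  # idempotent: the tag was created by an earlier run
        git("tag", tag, sha)
        git("push", "origin", tag)
```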

Complementary autorevert PR: https://github.com/pytorch/test-infra/pull/7291

---

### Testing

I created an identical copy of this workflow in my personal repo: https://github.com/izaitsevfb/pr-head-test/actions/workflows/trunk-tagging.yml

See action runs there.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165937
Approved by: https://github.com/huydhn
2025-10-21 20:52:34 +00:00
2f38eece7c [CUDA][cuBLAS] addmm -- some refactoring for easier navigation between the Lt and non-Lt paths (#163955)
As per title. Additionally, some Lt selection conditions are revisited, and some redundancy removed (especially in the ROCm vs non-ROCm paths).
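For orientation: the Lt path fuses the bias (and optional activation) epilogue into the GEMM, is selected roughly when the bias is 1-D and beta == 1 (see isInputCompliesAddmmCudaLt in the diff below), and can be globally disabled via the DISABLE_ADDMM_CUDA_LT environment variable. A hedged usage sketch:
```
import os
# The env check is cached statically, so this must be set before the first
# addmm dispatch to have any effect.
# os.environ["DISABLE_ADDMM_CUDA_LT"] = "1"  # uncomment to force the plain GEMM path

import torch

mat1 = torch.randn(64, 32, device="cuda", dtype=torch.half)
mat2 = torch.randn(32, 128, device="cuda", dtype=torch.half)
bias = torch.randn(128, device="cuda", dtype=torch.half)  # 1-D bias of size n

# beta == 1 with a contiguous 1-D bias is the shape the cublasLt
# gemm_and_bias epilogue targets; other cases fall back to a cuBLAS GEMM.
out = torch.addmm(bias, mat1, mat2)  # == bias + mat1 @ mat2
```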

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163955
Approved by: https://github.com/ngimel, https://github.com/eqy
2025-10-21 20:48:12 +00:00
830e789a55 [dynamo][annotate] Graph break cleanly on fx.traceback.annotate reconstruction (#166006)
This avoids generation of bad bytecode, which led to a really confusing
error. I am not sure why we can't reconstruct cleanly; it has to do with
the input being a dict, while other supported ctx managers take bools.

Fixing that is for another day. Let's give a good error message for now.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166006
Approved by: https://github.com/yushangdi, https://github.com/SherlockNoMad
2025-10-21 20:48:04 +00:00
ad4dc52bf6 Revert "shrink_group implementation to expose ncclCommShrink API (#164518)"
This reverts commit 4e643422f63a3cdd71bd141615f98de6bb54d15f.

Reverted https://github.com/pytorch/pytorch/pull/164518 on behalf of https://github.com/albanD due to Breaks lint ([comment](https://github.com/pytorch/pytorch/pull/164518#issuecomment-3429426503))
2025-10-21 20:24:14 +00:00
dac9ed9790 Bump uv from 0.8.6 to 0.9.5 in /.ci/lumen_cli (#166017)
Bumps [uv](https://github.com/astral-sh/uv) from 0.8.6 to 0.9.5.
- [Release notes](https://github.com/astral-sh/uv/releases)
- [Changelog](https://github.com/astral-sh/uv/blob/main/CHANGELOG.md)
- [Commits](https://github.com/astral-sh/uv/compare/0.8.6...0.9.5)

---
updated-dependencies:
- dependency-name: uv
  dependency-version: 0.9.5
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-10-21 13:16:30 -07:00
1c7fe8f861 [BugFix] chunk_size should always be int64_t (#165971)
Inspired by https://github.com/pytorch/pytorch/pull/156872
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165971
Approved by: https://github.com/albanD
2025-10-21 19:52:47 +00:00
4e643422f6 shrink_group implementation to expose ncclCommShrink API (#164518)
Closes #164529

To expose the new [ncclCommShrink](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html#ncclcommshrink) API to PyTorch.

This is useful when you need to exclude certain GPUs or nodes from a collective operation, for example in fault tolerance scenarios or when dynamically adjusting resource utilization.

For more info:  [Shrinking a communicator](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/communicators.html#shrinking-a-communicator)
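The Python-side shrink_group signature is not shown in this comparison, so it is not reproduced here. For context only, excluding ranks today means building a fresh group over the survivors with the existing API; ncclCommShrink aims to make this cheaper by deriving the smaller communicator from the existing one. A hedged sketch of the scenario:
```
import torch.distributed as dist

# Assumes a default NCCL process group is already initialized and that
# `failed_rank` must be excluded (e.g. after a GPU failure).
def exclude_rank(failed_rank: int) -> dist.ProcessGroup:
    survivors = [r for r in range(dist.get_world_size()) if r != failed_rank]
    # Today: build a brand-new communicator over the surviving ranks.
    # With ncclCommShrink exposed, the backend can shrink the existing one instead.
    return dist.new_group(ranks=survivors, backend="nccl")
```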

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164518
Approved by: https://github.com/kwen2501
2025-10-21 19:47:33 +00:00
3c3b278872 [reland][fx] Move Node._prepend/Node._remove_from_list to C++ (#165882)
Relands #148261 that was reverted by #150542
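At the Python level these back the public Node.prepend and node-removal bookkeeping; a small usage sketch (mirroring the leak-regression test added in this diff):
```
import torch
from torch.fx import Graph

g = Graph()
x = g.placeholder("x")
relu = g.call_function(torch.relu, (x,))
neg = g.call_function(torch.neg, (x,))

# Move `neg` so it sits immediately before `relu` in the node list; the
# doubly-linked-list update behind this is what moved to C++.
relu.prepend(neg)
print([n.name for n in g.nodes])  # ['x', 'neg', 'relu']
```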

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165882
Approved by: https://github.com/ezyang
2025-10-21 19:43:55 +00:00
19 changed files with 813 additions and 370 deletions

View File

@ -6,7 +6,7 @@ dependencies = [
"GitPython==3.1.45",
"docker==7.1.0",
"pytest==7.3.2",
"uv==0.8.6"
"uv==0.9.5"
]
[tool.setuptools]

View File

@ -147,15 +147,16 @@ jobs:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-debug
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9
cuda-arch-list: 8.9
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 2, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 3, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 4, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 5, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 6, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 7, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 1, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 2, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 3, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 4, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 5, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 6, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 7, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
]}
secrets: inherit

View File

@ -58,8 +58,10 @@ jobs:
else
COMMIT_SHA="${{ github.sha }}"
fi
echo "sha=${COMMIT_SHA}" >> "${GITHUB_OUTPUT}"
echo "tag_name=trunk/${COMMIT_SHA}" >> "${GITHUB_OUTPUT}"
{
echo "sha=${COMMIT_SHA}"
echo "tag_name=trunk/${COMMIT_SHA}"
} >> "${GITHUB_OUTPUT}"
- name: Validate commit SHA
run: |
@ -87,7 +89,7 @@ jobs:
echo "✅ Commit ${COMMIT_SHA} is valid (automatic push trigger)"
fi
- name: Create and push tag with retry
- name: Create and push tag(s) with retry
id: check_tag
env:
TAG_NAME: ${{ steps.commit.outputs.tag_name }}
@ -112,14 +114,23 @@ jobs:
return 1
}
# Exit early if tag already exists
if check_tag_exists; then
echo "✅ Tag already exists - no action needed"
echo "exists=true" >> "${GITHUB_OUTPUT}"
exit 0
fi
# Counters for summary reporting
created_count=0
skipped_count=0
failed_count=0
echo "Tag ${TAG_NAME} does not exist, proceeding with creation"
# Always write outputs once on exit
finish() {
set +e
if [ -n "${GITHUB_OUTPUT:-}" ]; then
{
echo "created_count=${created_count}"
echo "skipped_count=${skipped_count}"
echo "failed_count=${failed_count}"
} >> "${GITHUB_OUTPUT}"
fi
}
trap finish EXIT
# Retry configuration
MAX_RETRIES=5
@ -194,31 +205,111 @@ jobs:
}
}
# Execute with retry
if retry_with_backoff "tag_with_retry" "Creating tag ${TAG_NAME} for commit ${COMMIT_SHA}"; then
echo "exists=false" >> "${GITHUB_OUTPUT}"
# New behavior for push events: enumerate commits in the push and tag each one.
# For workflow_dispatch, retain existing single-SHA behavior.
# Always fetch tags once up front to improve idempotency in loops
git fetch origin --tags --quiet || true
if [ "${{ github.event_name }}" = "push" ]; then
BEFORE_SHA="${{ github.event.before }}"
AFTER_SHA="${{ github.sha }}" # same as event.after
# List commits introduced by this push (old..new), oldest first for stable ordering
commits_file="$(mktemp)"
git rev-list --reverse "${BEFORE_SHA}..${AFTER_SHA}" > "${commits_file}"
if [ ! -s "${commits_file}" ]; then
echo "No new commits found between ${BEFORE_SHA}..${AFTER_SHA}; nothing to tag."
rm -f "${commits_file}"
exit 0
fi
commit_count="$(wc -l < "${commits_file}" | tr -d ' ')"
echo "Found ${commit_count} commit(s) to tag for push:"
while IFS= read -r sha; do
printf ' %s\n' "${sha}"
done < "${commits_file}"
while IFS= read -r sha; do
TAG_NAME="trunk/${sha}"
COMMIT_SHA="${sha}"
# If tag already exists locally or remotely, skip (idempotent)
if check_tag_exists; then
echo "✅ Tag ${TAG_NAME} already exists - skipping"
skipped_count=$((skipped_count + 1))
continue
fi
echo "Tag ${TAG_NAME} does not exist, proceeding with creation"
if retry_with_backoff "tag_with_retry" "Creating tag ${TAG_NAME} for commit ${COMMIT_SHA}"; then
created_count=$((created_count + 1))
else
echo "Tag creation failed after all retry attempts for ${TAG_NAME}"
failed_count=$((failed_count + 1))
fi
done < "${commits_file}"
rm -f "${commits_file}"
if [ "${failed_count}" -gt 0 ]; then
exit 1
fi
exit 0
else
echo "Tag creation failed after all retry attempts"
exit 1
# workflow_dispatch path (single SHA tagging preserved)
# Exit early if tag already exists
if check_tag_exists; then
echo "✅ Tag already exists - no action needed"
skipped_count=1
exit 0
fi
echo "Tag ${TAG_NAME} does not exist, proceeding with creation"
if retry_with_backoff "tag_with_retry" "Creating tag ${TAG_NAME} for commit ${COMMIT_SHA}"; then
created_count=1
exit 0
else
echo "Tag creation failed after all retry attempts"
failed_count=1
exit 1
fi
fi
- name: Tag creation summary
if: always()
run: |
if [ "${{ steps.check_tag.outputs.exists }}" = "true" ]; then
echo "✅ Tag ${{ steps.commit.outputs.tag_name }} already existed - no action needed"
elif [ "${{ job.status }}" = "success" ]; then
echo "✅ Successfully created tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
if [ "${{ github.event_name }}" = "push" ]; then
echo "Trigger: push on main"
echo "Created: ${{ steps.check_tag.outputs.created_count }}"
echo "Skipped (already existed): ${{ steps.check_tag.outputs.skipped_count }}"
echo "Failed: ${{ steps.check_tag.outputs.failed_count }}"
if [ "${{ steps.check_tag.outputs.failed_count }}" = "0" ]; then
echo "✅ Completed tagging for push range ${{ github.event.before }}..${{ github.sha }}"
else
echo "❌ Some tags failed to create for push range ${{ github.event.before }}..${{ github.sha }}"
fi
else
echo "❌ Failed to create tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
fi
if [ "${{ steps.check_tag.outputs.failed_count }}" = "0" ]; then
if [ "${{ steps.check_tag.outputs.created_count }}" = "0" ]; then
echo "✅ Tag ${{ steps.commit.outputs.tag_name }} already existed - no action needed"
else
echo "✅ Successfully created tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
fi
else
echo "❌ Failed to create tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
fi
echo ""
echo "Tag details:"
echo " Name: ${{ steps.commit.outputs.tag_name }}"
echo " Commit: ${{ steps.commit.outputs.sha }}"
echo " Trigger: ${{ github.event_name }}"
if [ -n "${{ github.event.inputs.commit_sha }}" ]; then
echo " Manual commit: ${{ github.event.inputs.commit_sha }}"
echo ""
echo "Tag details:"
echo " Name: ${{ steps.commit.outputs.tag_name }}"
echo " Commit: ${{ steps.commit.outputs.sha }}"
echo " Trigger: ${{ github.event_name }}"
if [ -n "${{ github.event.inputs.commit_sha }}" ]; then
echo " Manual commit: ${{ github.event.inputs.commit_sha }}"
fi
fi

View File

@ -272,28 +272,110 @@ cuda::blas::GEMMAndBiasActivationEpilogue activation_to_gemm_and_blas_arg(Activa
}
}
static bool getDisableAddmmCudaLt() {
static const auto env_value = c10::utils::get_env("DISABLE_ADDMM_CUDA_LT");
if (env_value == "1") {
return true;
}
return false;
/*
* Checks whether DISABLE_ADDMM_CUDA_LT is set.
* Additionally, for ROCM we test whether the architecture supports the Lt.
*/
static bool isGloballyDisabledAddmmCudaLt(const at::Device& device) {
// When hipBLASLt is not supported on the architecture, return true
#ifdef USE_ROCM
static const std::vector<std::string> archs = {
"gfx90a", "gfx942",
#if ROCM_VERSION >= 60300
"gfx1100", "gfx1101", "gfx1200", "gfx1201", "gfx908",
#endif
#if ROCM_VERSION >= 70000
"gfx950", "gfx1150", "gfx1151"
#endif
};
const auto is_hipblas_lt_arch_supported = at::detail::getCUDAHooks().isGPUArch(archs, device.index());
if (!is_hipblas_lt_arch_supported) {
return true;
}
#endif
// Check whether it is disabled in the env
static const auto is_addmm_cuda_lt_disabled = c10::utils::get_env("DISABLE_ADDMM_CUDA_LT");
if (is_addmm_cuda_lt_disabled == "1") {
return true;
}
return false;
}
#ifdef USE_ROCM
static bool isSupportedHipLtROCmArch(int index) {
static const std::vector<std::string> archs = {
"gfx90a", "gfx942",
#if ROCM_VERSION >= 60300
"gfx1100", "gfx1101", "gfx1200", "gfx1201", "gfx908",
#endif
#if ROCM_VERSION >= 70000
"gfx950", "gfx1150", "gfx1151"
#endif
};
return at::detail::getCUDAHooks().isGPUArch(archs, index);
/*
* Check whether for the given input we want to enable the Lt interface
*/
static bool isInputCompliesAddmmCudaLt(Tensor& result, const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha) {
// Implies a 2D bias, which we currently do not send through Lt.
// TODO: this check is done pre col-major input preparation,
// so this condition can be relaxed in cases when a col-major
// copy of result is needed.
if (result.is_same(self)) {
return false;
}
#if defined(USE_ROCM) && ROCM_VERSION == 60400
// hipblaslt TT fp32 regression on ROCm 6.4, cannot use
const auto args = cublasCommonArgs(mat1, mat2, result);
if (args.transa == 't' && args.transb == 't') {
return false;
}
#endif
const auto mat1_sizes = mat1.sizes();
const auto mat2_sizes = mat2.sizes();
#if defined(CUDA_VERSION) || defined(USE_ROCM)
const auto scalar_type = mat1.scalar_type();
return (beta.toComplexDouble() == 1.0
// self.dim() == 1 && result.dim() == 2 && self.sizes()[0] == mat2_sizes[1]
// is to use lt interface only when self is bias.
&& self.dim() == 1 && self.sizes()[0] == mat2_sizes[1] && self.is_contiguous()
&& result.dim() == 2 && result.is_contiguous()
&& ( // some dtype restrictions
#ifndef USE_ROCM
scalar_type == at::ScalarType::Double ||
#endif
scalar_type == at::ScalarType::Float ||
scalar_type == at::ScalarType::Half ||
scalar_type == at::ScalarType::BFloat16
)
&& ( // some shape/stride restrictions
// Strangely, if mat2 has only 1 row or column, we get
// CUBLAS_STATUS_INVALID_VALUE error from cublasLtMatmulAlgoGetHeuristic.
// NOTE: extension to mat1 because mat1/mat2 can be swapped based off
// their row-/col-majorness.
mat1_sizes[0] > 1 && mat1_sizes[1] > 1 &&
mat2_sizes[0] > 1 && mat2_sizes[1] > 1
// The last conditions are to skip 16b transA and non-trans-B having
// leading dim >> rows when they are sliced from a large tensor
// see fbcode/caffe2/test/test_linalg.py:test_corner_cases_of_cublasltmatmul
#if !(defined(CUDA_VERSION) && CUDA_VERSION >= 12010 || defined(USE_ROCM))
// Related to avoiding the leading stride >> leading dim problematic case
// with 16b dtypes described above. For such dtypes we only allow inputs
// which are either row- or col-major (i.e. non-overlapping, compact memory layout).
// In that case the leading stride will be equal to the outer dim len.
// Why do we catch this case here? The following `prepare_matrix_for_cublas` method
// does not modify inputs as long as there is a stride of length 1
// and the leading stride is at least max(1, other dim length), so we might
// end up with contiguous cols but not rows (i.e. holes between different rows)
// and vice versa.
mat2_sizes[0] < 65535 * 32 && mat2_sizes[1] < 65535 * 32 &&
mat1_sizes[0] < 65535 * 32 && mat1_sizes[1] < 65535 * 32
&& (
// filter by dtype
(scalar_type != at::ScalarType::Half && scalar_type != at::ScalarType::BFloat16) ||
// check mat1/mat2 is row-/col-major
(mat1.is_non_overlapping_and_dense() && mat2.is_non_overlapping_and_dense())
)
#endif
)
);
#endif
// no compliance by default
return false;
}
#endif
template <typename scalar_t>
void launchTunableGemmAndBias(cublasCommonArgs &args, const Scalar& alpha, const scalar_t* bias, cuda::blas::GEMMAndBiasActivationEpilogue activation) {
@ -335,7 +417,70 @@ void launchTunableGemmAndBias(cublasCommonArgs &args, const Scalar& alpha, const
}
}
template <typename scalar_t, typename res_scalar_t = scalar_t>
bool launchGemmAndBiasCublasLt(
// args contains result which is modified
cublasCommonArgs& args,
const Tensor& self,
const Scalar& alpha,
Activation activation = Activation::None
) {
const auto* self_ptr = self.const_data_ptr<scalar_t>();
const auto tuning_ctx = at::cuda::tunable::getTuningContext();
if (tuning_ctx->IsTunableOpEnabled()) {
// TODO: maybe also return some success state?
launchTunableGemmAndBias<scalar_t>(
args, alpha, self_ptr, activation_to_gemm_and_blas_arg(activation)
);
return true;
}
return at::cuda::blas::gemm_and_bias<scalar_t, res_scalar_t>(
args.transa == 't',
args.transb == 't',
args.m,
args.n,
args.k,
alpha.to<at::opmath_type<scalar_t>>(),
args.mata->const_data_ptr<scalar_t>(),
args.lda,
args.matb->const_data_ptr<scalar_t>(),
args.ldb,
self_ptr,
args.result->data_ptr<res_scalar_t>(),
args.result_ld,
activation_to_gemm_and_blas_arg(activation)
);
}
template <typename scalar_t, typename res_scalar_t = scalar_t>
bool launchGemmCublas(
// args contains result which is modified
cublasCommonArgs& args,
const Scalar& alpha,
const Scalar& beta
) {
at::cuda::blas::gemm<scalar_t, res_scalar_t>(
args.transa,
args.transb,
args.m,
args.n,
args.k,
alpha.to<at::opmath_type<scalar_t>>(),
args.mata->const_data_ptr<scalar_t>(),
args.lda,
args.matb->const_data_ptr<scalar_t>(),
args.ldb,
beta.to<at::opmath_type<scalar_t>>(),
args.result->data_ptr<res_scalar_t>(),
args.result_ld
);
return true; // success!
}
Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha, Activation activation=Activation::None, bool disable_addmm_cuda_lt_override=false) {
// Shape checks {
// Make sure to keep addmm_cuda below in sync with this code; it
// preflights a check to try to avoid actually needing to call
// expand().
@ -345,105 +490,62 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
"expected mat1 and mat2 to have the same dtype, but got: ", mat1.dtype(), " != ", mat2.dtype()
)
if (result.is_same(self)) {
TORCH_CHECK(result.dim() == 2, "tensors must be 2-D");
TORCH_CHECK(self.sizes()[0] == mat1.sizes()[0], "self dim 0 must match mat1 dim 0");
TORCH_CHECK(self.sizes()[1] == mat2.sizes()[1], "self dim 1 must match mat2 dim 1");
}
// } Shape checks
// NOLINTNEXTLINE(*c-array*)
TensorArg targs[]{{result, "out", 0}, {self, "self", 1}, {mat1, "mat1", 2}, {mat2, "mat2", 3}};
checkAllSameGPU(__func__, targs);
IntArrayRef mat1_sizes = mat1.sizes();
IntArrayRef mat2_sizes = mat2.sizes();
IntArrayRef self__sizes;
bool useLtInterface = false;
#if defined(USE_ROCM)
// When hipBLASLt is not supported on the architecture,
// disable_addmm_cuda_lt will always be to set to true
static bool disable_addmm_cuda_lt =
!isSupportedHipLtROCmArch(self.device().index()) || getDisableAddmmCudaLt();
#else
static bool disable_addmm_cuda_lt = getDisableAddmmCudaLt();
#endif
// Handle whether to use the Lt interface {
static bool persistent_disable_addmm_cuda_lt = isGloballyDisabledAddmmCudaLt(self.device());
// if lt path fails, we recurse back into this function here and force the lt path to off
// we cannot update the variable disable_addmm_cuda_lt from above since it is static and the change would be permanent
bool disable_addmm_cuda_lt_final = disable_addmm_cuda_lt || disable_addmm_cuda_lt_override;
#if defined(USE_ROCM) && ROCM_VERSION == 60400
// hipblaslt TT fp32 regression on ROCm 6.4, cannot use
cublasCommonArgs _args(mat1, mat2, result);
if (_args.transa == 't' && _args.transb == 't') {
disable_addmm_cuda_lt_final = true;
}
#endif
bool disable_addmm_cuda_lt = persistent_disable_addmm_cuda_lt || disable_addmm_cuda_lt_override;
#ifdef USE_ROCM
// Conditioned on the device index, which is not persistent
disable_addmm_cuda_lt = isGloballyDisabledAddmmCudaLt(self.device()) || disable_addmm_cuda_lt;
#endif
// Condition on the input
disable_addmm_cuda_lt = !isInputCompliesAddmmCudaLt(result, self, mat1, mat2, beta, alpha) || disable_addmm_cuda_lt;
// }
at::ScalarType scalar_type = mat1.scalar_type();
bool is_float_output_with_half_input = (scalar_type == at::ScalarType::Half || scalar_type == at::ScalarType::BFloat16) && result.scalar_type() == at::ScalarType::Float;
c10::MaybeOwned<Tensor> self_;
if (&result != &self) {
#if defined(CUDA_VERSION) || defined(USE_ROCM)
// Strangely, if mat2 has only 1 row or column, we get
// CUBLAS_STATUS_INVALID_VALUE error from cublasLtMatmulAlgoGetHeuristic.
// self.dim() == 1 && result.dim() == 2 && self.sizes()[0] == mat2_sizes[1]
// is to use lt interface only when self is bias.
// for cuda 11.4, cublasLtMatmul is activated
// the last two conditions is to skip 16b transA and non-trans-B having
// leading dim >> rows when they are sliced from a large tensor
// see fbcode/caffe2/test/test_linalg.py:test_corner_cases_of_cublasltmatmul
if (!disable_addmm_cuda_lt_final) {
useLtInterface = beta.toComplexDouble() == 1.0 && self.dim() == 1 &&
result.dim() == 2 && self.sizes()[0] == mat2_sizes[1] &&
self.is_contiguous() && result.is_contiguous() &&
#ifdef USE_ROCM
(scalar_type == at::ScalarType::Float ||
scalar_type == at::ScalarType::Half ||
scalar_type == at::ScalarType::BFloat16) &&
#else
(scalar_type == at::ScalarType::Double ||
scalar_type == at::ScalarType::Float ||
scalar_type == at::ScalarType::Half ||
scalar_type == at::ScalarType::BFloat16) &&
#endif
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12010 || defined(USE_ROCM))
mat2_sizes[0] > 1 && mat2_sizes[1] > 1;
#else
mat2_sizes[0] > 1 && mat2_sizes[1] > 1 &&
mat2_sizes[0] < 65535 * 32 && mat2_sizes[1] < 65535 * 32 &&
mat1_sizes[0] < 65535 * 32 && mat1_sizes[1] < 65535 * 32 &&
// avoid leading dim >> rows bugs
((mat1.strides()[0] == 1 && mat1.strides()[1] == mat1_sizes[0]) ||
(mat1.strides()[1] == 1 && mat1.strides()[0] == mat1_sizes[1]) ||
(scalar_type != at::ScalarType::Half &&
scalar_type != at::ScalarType::BFloat16)) &&
((mat2.strides()[0] == 1 && mat2.strides()[1] == mat2_sizes[0]) ||
(mat2.strides()[1] == 1 && mat2.strides()[0] == mat2_sizes[1]) ||
(scalar_type != at::ScalarType::Half &&
scalar_type != at::ScalarType::BFloat16));
#endif
}
#endif
if (!useLtInterface) {
self_ = expand_size(self, {mat1_sizes[0], mat2_sizes[1]}, "addmm");
}
self__sizes = self_->sizes();
} else {
self_ = c10::MaybeOwned<Tensor>::borrowed(self);
self__sizes = self_->sizes();
TORCH_CHECK(result.dim() == 2, "tensors must be 2-D");
TORCH_CHECK(self__sizes[0] == mat1_sizes[0], "self_ dim 0 must match mat1 dim 0");
TORCH_CHECK(self__sizes[1] == mat2_sizes[1], "self_ dim 1 must match mat2 dim 1");
}
if (&result != &self) {
at::native::resize_output(result, {mat1_sizes[0], mat2_sizes[1]});
if (beta.toComplexDouble() != 0.0 && !useLtInterface) {
at::native::copy_(result, *self_);
// Handle result/self shapes
if (!result.is_same(self)) {
at::native::resize_output(result, {mat1.sizes()[0], mat2.sizes()[1]});
const auto self_maybe_expanded = [&]() -> c10::MaybeOwned<Tensor> {
if (disable_addmm_cuda_lt) {
// When in non-Lt path we do expand self even before
// check for beta != 0.0 to make sure that
// test_sparse_csr.py::TestSparseCSRCUDA::test_addmm_errors_*
// runs green.
return expand_size(self, result.sizes(), "addmm");
}
// copy next, should broadcast
return c10::MaybeOwned<Tensor>::borrowed(self);
}();
// We copy bias when in the non-Lt path
if (beta.toComplexDouble() != 0.0 && disable_addmm_cuda_lt) {
// NOTE: self should broadcast over result
at::native::copy_(result, *self_maybe_expanded);
}
}
IntArrayRef result_sizes = result.sizes();
if ((result_sizes[0] == 0) || (result_sizes[1] == 0)) {
// Short circuit on empty result
if (result.numel() == 0) {
return result;
}
cublasCommonArgs args(mat1, mat2, result);
if (mat1.numel() == 0) {
// Short circuit if the reduction dim is empty
if (mat1.sizes()[1] == 0) {
// By definition, when beta==0, values in self should be ignored. nans and infs
// should not propagate
if (beta.toComplexDouble() == 0.) {
@ -455,158 +557,64 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
result,
self.expand(result.sizes()),
at::native::scalar_tensor(
beta,
self.scalar_type(),
std::nullopt /* layout */,
at::kCPU,
std::nullopt /* pin_memory */));
beta,
self.scalar_type(),
std::nullopt /* layout */,
at::kCPU,
std::nullopt /* pin_memory */
)
);
}
cublasCommonArgs args(mat1, mat2, result);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!args.result->is_conj());
if (useLtInterface) {
#if defined(USE_ROCM)
bool okay = true;
// The Lt path
if (!disable_addmm_cuda_lt) {
bool lt_success = false;
if (is_float_output_with_half_input) {
#ifdef USE_ROCM
TORCH_CHECK(false, "float output with half input is not enabled for ROCm");
} else {
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
scalar_type,
"addmm_cuda_lt",
[&] {
auto tuning_ctx = at::cuda::tunable::getTuningContext();
if (tuning_ctx->IsTunableOpEnabled()) {
launchTunableGemmAndBias<scalar_t>(
args,
alpha,
(&result != &self) ? self.const_data_ptr<scalar_t>() : nullptr,
activation_to_gemm_and_blas_arg(activation));
} else {
okay = at::cuda::blas::gemm_and_bias<scalar_t>(
args.transa == 't',
args.transb == 't',
args.m,
args.n,
args.k,
alpha.to<at::opmath_type<scalar_t>>(),
args.mata->const_data_ptr<scalar_t>(),
args.lda,
args.matb->const_data_ptr<scalar_t>(),
args.ldb,
// This condition is needed for mm case on ROCm for hipblasLt path.
// Passing the bias ptr as null to avoid accuracy issues for mm case.
(&result != &self) ? self.const_data_ptr<scalar_t>() : nullptr,
args.result->data_ptr<scalar_t>(),
args.result_ld,
activation_to_gemm_and_blas_arg(activation)
);
}
});
}
if (!okay) {
// lt path failed; recurse but disable lt path
return addmm_out_cuda_impl(result, self, mat1, mat2, beta, alpha, activation, true);
}
#else
auto activation_epilogue = activation_to_gemm_and_blas_arg(activation);
bool okay = true;
if (is_float_output_with_half_input) {
#else
if (at::cuda::tunable::getTuningContext()->IsTunableOpEnabled()) {
TORCH_CHECK(false, "Tunable GEMM is not supported for float output with reduced float input");
}
AT_DISPATCH_REDUCED_FLOATING_TYPES(
scalar_type,
"addmm_cuda_lt",
[&] {
auto tuning_ctx = at::cuda::tunable::getTuningContext();
if (tuning_ctx->IsTunableOpEnabled()) {
TORCH_CHECK(false, "Tunable GEMM is not supported for float output with reduced float input");
lt_success = launchGemmAndBiasCublasLt<scalar_t, float>(args, self, alpha, activation);
}
else {
okay = at::cuda::blas::gemm_and_bias<scalar_t, float>(
args.transa == 't',
args.transb == 't',
args.m,
args.n,
args.k,
alpha.to<at::opmath_type<scalar_t>>(),
args.mata->const_data_ptr<scalar_t>(),
args.lda,
args.matb->const_data_ptr<scalar_t>(),
args.ldb,
self.const_data_ptr<scalar_t>(),
args.result->data_ptr<float>(),
args.result_ld,
activation_epilogue
);
}});
);
#endif
} else {
// !is_float_output_with_half_input
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
scalar_type,
"addmm_cuda_lt",
[&] {
auto tuning_ctx = at::cuda::tunable::getTuningContext();
if (tuning_ctx->IsTunableOpEnabled()) {
launchTunableGemmAndBias<scalar_t>(
args,
alpha,
self.const_data_ptr<scalar_t>(),
activation_epilogue);
lt_success = launchGemmAndBiasCublasLt<scalar_t>(args, self, alpha, activation);
}
else {
okay = at::cuda::blas::gemm_and_bias<scalar_t>(
args.transa == 't',
args.transb == 't',
args.m,
args.n,
args.k,
alpha.to<at::opmath_type<scalar_t>>(),
args.mata->const_data_ptr<scalar_t>(),
args.lda,
args.matb->const_data_ptr<scalar_t>(),
args.ldb,
self.const_data_ptr<scalar_t>(),
args.result->data_ptr<scalar_t>(),
args.result_ld,
activation_epilogue
);
}});
}
if (!okay) {
// lt path failed; recurse but disable lt path
);
} // end is_float_output_with_half_input
if (!lt_success) {
// lt path failed; recurse but disable lt path
return addmm_out_cuda_impl(result, self, mat1, mat2, beta, alpha, activation, true);
}
#endif
} else
{
// end Lt path
} else {
// No Lt, we use a GEMM instead
if (is_float_output_with_half_input) {
AT_DISPATCH_REDUCED_FLOATING_TYPES(
scalar_type,
"addmm_cuda",
[&] {
using opmath_t = at::opmath_type<scalar_t>;
opmath_t alpha_val = alpha.to<opmath_t>();
opmath_t beta_val = beta.to<opmath_t>();
const scalar_t* mat1_ptr = args.mata->const_data_ptr<scalar_t>();
const scalar_t* mat2_ptr = args.matb->const_data_ptr<scalar_t>();
float* result_ptr = args.result->mutable_data_ptr<float>();
at::cuda::blas::gemm<scalar_t, float>(
args.transa,
args.transb,
args.m,
args.n,
args.k,
alpha_val,
mat1_ptr,
args.lda,
mat2_ptr,
args.ldb,
beta_val,
result_ptr,
args.result_ld);
});
launchGemmCublas<scalar_t, float>(args, alpha, beta);
}
);
} else {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
at::ScalarType::Half,
@ -614,28 +622,12 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
scalar_type,
"addmm_cuda",
[&] {
using opmath_t = at::opmath_type<scalar_t>;
opmath_t alpha_val = alpha.to<opmath_t>();
opmath_t beta_val = beta.to<opmath_t>();
const scalar_t* mat1_ptr = args.mata->const_data_ptr<scalar_t>();
const scalar_t* mat2_ptr = args.matb->const_data_ptr<scalar_t>();
scalar_t* result_ptr = args.result->mutable_data_ptr<scalar_t>();
at::cuda::blas::gemm<scalar_t>(
args.transa,
args.transb,
args.m,
args.n,
args.k,
alpha_val,
mat1_ptr,
args.lda,
mat2_ptr,
args.ldb,
beta_val,
result_ptr,
args.result_ld);
});
launchGemmCublas<scalar_t>(args, alpha, beta);
}
);
}
// Apply epilogue
switch (activation) {
case Activation::RELU:
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
@ -647,14 +639,14 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
break;
default: break;
}
}
} // end GEMM path
// Preprocessor gate here needs to match the inverse of the check
// gating activation_to_gemm_and_blas_arg above; here we are manually
// performing a post-GELU because we weren't able to use the GELU
// epilogue above.
#if !defined(CUDA_VERSION) && !defined(USE_ROCM)
if (useLtInterface && activation == Activation::GELU) {
if (!disable_addmm_cuda_lt && activation == Activation::GELU) {
at::gelu_(const_cast<Tensor&>(*args.result), "tanh");
}
#endif

View File

@ -23,7 +23,7 @@ namespace at::native {
// The maximum number of threads in a block
#if defined(USE_ROCM)
constexpr int MAX_BLOCK_SIZE = 256;
constexpr int MAX_BLOCK_SIZE = 1024;
#else
constexpr int MAX_BLOCK_SIZE = 512;
#endif
@ -33,7 +33,7 @@ constexpr unsigned MAX_GRID_SIZE = 65535u;
// Number of threads in a block given an input size up to MAX_BLOCK_SIZE
static int getNumThreads(int nElem) {
#if defined(USE_ROCM)
int threadSizes[5] = { 16, 32, 64, 128, MAX_BLOCK_SIZE };
int threadSizes[5] = { 64, 128, 256, 512, MAX_BLOCK_SIZE };
#else
int threadSizes[5] = { 32, 64, 128, 256, MAX_BLOCK_SIZE };
#endif

View File

@ -92,6 +92,16 @@ inline thrust::pair<int64_t, int64_t> get_index_mapping2d(
output_offset + output_y * output_dim_x + output_x);
}
__device__ __forceinline__ int64_t reflect_index(int64_t x, int64_t len) {
const int64_t two = (len - 1) * 2;
if (two <= 0) {
return 0;
}
int64_t m = x % two;
if (m < 0) m += two;
return (m < len) ? m : (two - m);
}
template<typename scalar_t>
__global__ void reflection_pad1d_out_kernel(
const scalar_t * input, scalar_t * output,
@ -106,6 +116,28 @@ __global__ void reflection_pad1d_out_kernel(
}
}
template <typename scalar_t>
__global__ void reflection_pad1d_flat(
const scalar_t* __restrict__ input,
scalar_t* __restrict__ output,
int64_t input_w, int64_t pad_l, int64_t pad_r,
int64_t out_w, int64_t plane_count) {
const int64_t bx = blockDim.x;
const int64_t tx = threadIdx.x;
const int64_t total = plane_count * out_w;
const int64_t grid_stride = static_cast<int64_t>(bx) * gridDim.x;
int64_t linear = static_cast<int64_t>(blockIdx.x) * bx + tx;
for (; linear < total; linear += grid_stride) {
const int64_t plane = linear / out_w;
const int64_t x = linear - plane * out_w;
const int64_t j = reflect_index(x - pad_l, input_w);
output[plane * out_w + x] = input[plane * input_w + j];
}
}
template <typename scalar_t>
__global__ void reflection_pad1d_backward_out_kernel(
scalar_t * grad_input, const scalar_t * grad_output,
@ -710,25 +742,44 @@ TORCH_IMPL_FUNC(reflection_pad1d_out_cuda)
int64_t input_w = input_.size(dim_w);
int64_t output_w = input_w + pad_l + pad_r;
dim3 block_size(output_w > 256 ? 256 : output_w);
dim3 grid_size((int)::ceil(output_w / 256.0), nplane, nbatch);
Tensor input = input_.contiguous();
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
kHalf, kBFloat16, input.scalar_type(), "reflection_pad1d_out_template", [&] {
reflection_pad1d_out_kernel<<<
grid_size,
block_size,
0,
at::cuda::getCurrentCUDAStream()>>>(
input.const_data_ptr<scalar_t>(),
output.mutable_data_ptr<scalar_t>(),
input_w,
pad_l,
pad_r);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
const int block_x = static_cast<int>(std::min<int64_t>(256, std::max<int64_t>(1, output_w)));
const cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
const int max_x = prop->maxGridSize[0];
const int max_y = prop->maxGridSize[1];
const int max_z = prop->maxGridSize[2];
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBFloat16, input.scalar_type(), "reflection_pad1d_out", [&] {
auto stream = at::cuda::getCurrentCUDAStream();
const int64_t gx = at::ceil_div(output_w, static_cast<int64_t>(block_x));
const bool fits3d = (nplane <= max_y) && (nbatch <= max_z) && (gx <= max_x);
if (fits3d) {
dim3 block(block_x, 1, 1);
dim3 grid(gx, static_cast<unsigned>(nplane), static_cast<unsigned>(nbatch));
reflection_pad1d_out_kernel<scalar_t><<<grid, block, 0, stream>>>(
input.const_data_ptr<scalar_t>(),
output.mutable_data_ptr<scalar_t>(),
input_w, pad_l, pad_r);
} else {
dim3 block(block_x, 1, 1);
const int64_t plane_count = nplane * nbatch;
const int64_t total_blocks = at::ceil_div(plane_count * output_w, static_cast<int64_t>(block_x));
const int grid_x = static_cast<int>(std::min<int64_t>(max_x, std::max<int64_t>(1, total_blocks)));
dim3 grid(grid_x, 1, 1);
reflection_pad1d_flat<scalar_t><<<grid, block, 0, stream>>>(
input.const_data_ptr<scalar_t>(),
output.mutable_data_ptr<scalar_t>(),
input_w, pad_l, pad_r, output_w, plane_count);
}
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
}
TORCH_IMPL_FUNC(reflection_pad1d_backward_out_cuda)(const Tensor& grad_output_,

View File

@ -52,7 +52,7 @@ struct FusedAdagradMathFunctor {
using opmath_t = at::opmath_type<scalar_t>;
C10_DEVICE __forceinline__ void operator()(
int chunk_size,
int64_t chunk_size,
FusedOptimizerTensorListMetadata<3>& tl,
const float* lr_ptr,
const double& lr,
@ -133,4 +133,4 @@ struct FusedAdagradMathFunctor {
} // namespace
} // namespace at::native
} // namespace at::native

View File

@ -1,8 +1,8 @@
add_loop_eager,compile_time_instruction_count,3070000000,0.1
add_loop_eager,compile_time_instruction_count,3184000000,0.1
add_loop_eager_dynamic,compile_time_instruction_count,4432000000,0.1
add_loop_eager_dynamic,compile_time_instruction_count,4595000000,0.1
@ -18,7 +18,7 @@ add_loop_inductor_gpu,compile_time_instruction_count,26800000000,0.1
basic_modules_ListOfLinears_eager,compile_time_instruction_count,1048000000,0.1
basic_modules_ListOfLinears_eager,compile_time_instruction_count,1096000000,0.1
@ -26,7 +26,7 @@ basic_modules_ListOfLinears_inductor,compile_time_instruction_count,15240000000,
basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17020000000,0.1
basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17720000000,0.1
@ -34,11 +34,11 @@ basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,11090000
update_hint_regression,compile_time_instruction_count,1719000000,0.1
update_hint_regression,compile_time_instruction_count,1645000000,0.1
sum_floordiv_regression,compile_time_instruction_count,3686995725,0.1
sum_floordiv_regression,compile_time_instruction_count,3813000000,0.1
@ -50,31 +50,31 @@ symint_sum_loop,compile_time_instruction_count,4299000000,0.1
aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,1869000000,0.1
aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,1793000000,0.1
aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5281000000,0.1
aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5120000000,0.1
aotdispatcher_partitioner_cpu,compile_time_instruction_count,8333000000,0.1
aotdispatcher_partitioner_cpu,compile_time_instruction_count,7936000000,0.1
aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1909000000,0.1
aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1848000000,0.1
aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3442000000,0.1
aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3152000000,0.1
aotdispatcher_training_subclass_cpu,compile_time_instruction_count,9239000000,0.1
aotdispatcher_training_subclass_cpu,compile_time_instruction_count,8301000000,0.1
mm_loop_inductor_gpu,compile_time_instruction_count,4820968837,0.1
mm_loop_inductor_gpu,compile_time_instruction_count,4958000000,0.1
@ -82,8 +82,8 @@ mm_loop_inductor_dynamic_gpu,compile_time_instruction_count,9051000000,0.1
basic_NestedModule_eager,compile_time_instruction_count,9554000000,0.1
basic_NestedModule_eager,compile_time_instruction_count,9990000000,0.1
basic_InlineMod_eager,compile_time_instruction_count,7618000000,0.1
basic_InlineMod_eager,compile_time_instruction_count,8126000000,0.1

1 add_loop_eager compile_time_instruction_count 3070000000 3184000000 0.1
2 add_loop_eager_dynamic compile_time_instruction_count 4432000000 4595000000 0.1
3 add_loop_inductor compile_time_instruction_count 29660000000 29660000000 0.1
4 add_loop_inductor_dynamic_gpu compile_time_instruction_count 39910000000 39910000000 0.1
5 add_loop_inductor_gpu compile_time_instruction_count 26800000000 26800000000 0.1
6 basic_modules_ListOfLinears_eager compile_time_instruction_count 1048000000 1096000000 0.1
7 basic_modules_ListOfLinears_inductor compile_time_instruction_count 15240000000 15240000000 0.1
8 basic_modules_ListOfLinears_inductor_gpu_force_shape_pad compile_time_instruction_count 17020000000 17720000000 0.1
18 aotdispatcher_training_nosubclass_cpu compile_time_instruction_count 3442000000 3152000000 0.1
19 aotdispatcher_training_subclass_cpu compile_time_instruction_count 9239000000 8301000000 0.1
20 mm_loop_inductor_gpu compile_time_instruction_count 4820968837 4958000000 0.1
21 mm_loop_inductor_dynamic_gpu compile_time_instruction_count 9051000000 9051000000 0.1
22 basic_NestedModule_eager compile_time_instruction_count 9554000000 9990000000 0.1
23 basic_InlineMod_eager compile_time_instruction_count 7618000000 8126000000 0.1

View File

@ -67,7 +67,21 @@ class TestFullyShardMemory(FSDPTest):
# allocate the cuBLAS workspaces before measuring the memory usage
# since the workspace size can differ between hardwares
lin = torch.nn.Linear(768, 768, device=device_type)
inp = torch.randn(1, 768, device=device_type)
# NOTE: before https://github.com/pytorch/pytorch/pull/163955,
# the input shape was (1, 768), so that the forward gemm used
# cublaslt, and the backward used cublas.
# With the aforementioned PR, and with shape (1, 768),
# the cublas path is used both in forward and in backward,
# altering peak memory usage, which then does not account for cublaslt.
# Here we change the input shape to (2, 768), and that swaps
# the cublas/cublaslt selection in the forward/backward,
# but that does not affect the peak memory usage stored in `base_mem_mb`.
# Reasons for the flip:
# before PR: no Lt in addmm when mat2 has nrows/ncols <= 1,
# after PR: no Lt in addmm when either mat1 or mat2 have nrows/ncols <= 1,
# since the input preparation can swap matrices based on output
# row-/col-majorness.
inp = torch.randn(2, 768, device=device_type)
lin(inp).sum().backward()
torch.get_device_module(device_type).empty_cache()
base_mem_mb = self._get_peak_active_memory_mb()

View File

@ -288,6 +288,18 @@ class AnnotateTests(torch._dynamo.test_case.TestCase):
('call_function', 'mul_2', {'pp_stage': 0, 'fdsp_bucket': 0})""", # noqa: B950
)
def test_graph_break(self):
def fn(x):
with torch.fx.traceback.annotate({"pp_stage": 0}):
x = torch.sin(x)
torch._dynamo.graph_break()
x = torch.cos(x)
return x
opt_fn = torch.compile(fn, backend="eager")
x = torch.randn(10, requires_grad=True)
self.assertEqual(fn(x), opt_fn(x))
if __name__ == "__main__":
run_tests()

View File

@ -6,6 +6,7 @@ import builtins
import collections
import contextlib
import copy
import gc
import functools
import inspect
import io
@ -19,6 +20,7 @@ import traceback
import types
import typing
import unittest
import weakref
import warnings
from math import sqrt
from torch.multiprocessing import Process
@ -1624,6 +1626,25 @@ class TestFX(JitTestCase):
self.assertTrue(neg not in relu.users)
@skipIfTorchDynamo("Dynamo does not free right away")
def test_prepend_does_not_leak(self):
g = Graph()
x = g.placeholder("x")
relu = g.call_function(torch.relu, (x,))
neg = g.call_function(torch.neg, (x,))
relu.prepend(neg)
ref = weakref.ref(neg)
g.erase_node(neg)
del g
del x
del relu
del neg
gc.collect()
self.assertIsNone(ref())
def test_remove_uses_with_custom_filter(self):
g: torch.fx.Graph = Graph()
x: torch.fx.Node = g.placeholder("x")

View File

@ -7381,6 +7381,10 @@ torch.cuda.synchronize()
@skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
@parametrize("use_legacy_api", [True, False])
@skipCPUIf(True, "SPDA Math NT fallback causes failure: see issue #133644")
@unittest.skipIf(
"RelWithAssert" in torch.__config__.show(),
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
)
def test_dummy_mha_with_nt(self, device, use_legacy_api):
bs = 3
d1 = 2

View File

@ -8490,6 +8490,14 @@ class TestNNDeviceType(NNTestCase):
y_cuda_contig = pool(x_cuda.contiguous())
self.assertEqual(y_cuda_ch_last, y_cuda_contig)
@onlyCUDA
def test_large_reflect_pad(self, device):
# https://github.com/pytorch/pytorch/issues/165861
x = torch.rand(2**16, 2, device="cuda")
c = F.pad(x, (1, 1), mode="reflect")
c_cpu = F.pad(x.cpu(), (1, 1), mode="reflect")
self.assertEqual(c, c_cpu)
@onlyCUDA
@largeTensorTest("48GB", "cpu")
@largeTensorTest("48GB", "cuda")

View File

@ -247,6 +247,10 @@ class SparseSemiStructuredTensorCompileTest(torch._dynamo.test_case.TestCase):
@unittest.skipIf(IS_WINDOWS, "torch.compile not supported on windows")
@unittest.skipIf("cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS, "cusparselt not supported on this machine")
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
@unittest.skipIf(
"RelWithAssert" in torch.__config__.show(),
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
)
def test_sp24_compile(self) -> None:
x = torch.randn([1024, 512], device="cuda", dtype=torch.float16, requires_grad=True)
@ -576,6 +580,10 @@ class TestSparseSemiStructuredTraining(TestCase):
@training_dtypes
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
@unittest.skipIf(
"RelWithAssert" in torch.__config__.show(),
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
)
def test_prune_dense_static_sort(self, dtype) -> None:
# Ideally we would like to clone and compare, but that won't work because the sorting order will be different
# instead we pass the pruned matrix to the CUDA implementation and preserve the sparsity pattern.
@ -621,6 +629,10 @@ class TestSparseSemiStructuredTraining(TestCase):
@training_dtypes
@parametrize_backends
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
@unittest.skipIf(
"RelWithAssert" in torch.__config__.show(),
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
)
def test_pruning_algo_largest_abs_values_greedy(self, dtype, backend) -> None:
inp = torch.tensor(
[[4, 3, 2, 1], [-1, -3, 0.6, 0.5], [1, 2, 3, 4], [10, 2, -1, 5]],
@ -658,6 +670,10 @@ class TestSparseSemiStructuredTraining(TestCase):
@training_dtypes
@parametrize_backends
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
@unittest.skipIf(
"RelWithAssert" in torch.__config__.show(),
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
)
def test_pack_both_ways_meta_correctness(self, dtype, backend) -> None:
M, N = 128, 256
# Construct x to make sure we always have exactly 8 elements per 4x4 tile
@ -692,6 +708,10 @@ class TestSparseSemiStructuredTraining(TestCase):
@training_dtypes
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
@unittest.skipIf(
"RelWithAssert" in torch.__config__.show(),
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
)
def test_pack_both_ways_id(self, dtype) -> None:
N = 512
torch.manual_seed(0)
@ -729,6 +749,10 @@ class TestSparseSemiStructuredTraining(TestCase):
@training_dtypes
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
@unittest.skipIf(
"RelWithAssert" in torch.__config__.show(),
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
)
def test_pack_both_ways_edge_case1(self, dtype) -> None:
# In this case, the heuristic will keep 7 values out of 16
# instead of 8. let's see how the kernel handles this
@ -754,6 +778,10 @@ class TestSparseSemiStructuredTraining(TestCase):
@training_dtypes
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
@unittest.skipIf(
"RelWithAssert" in torch.__config__.show(),
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
)
def test_sp24_apply(self, dtype) -> None:
M, N = 256, 1024
x = torch.randn([M, N], dtype=dtype, device="cuda")
@ -770,6 +798,10 @@ class TestSparseSemiStructuredTraining(TestCase):
@training_dtypes
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
@unittest.skipIf(
"RelWithAssert" in torch.__config__.show(),
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
)
def test_sp24_apply_dense(self, dtype) -> None:
M, N = 256, 1024
x = torch.randn([M, N], dtype=dtype, device="cuda")
@ -808,6 +840,10 @@ class TestSparseSemiStructuredTraining(TestCase):
@training_dtypes
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
@unittest.skipIf(
"RelWithAssert" in torch.__config__.show(),
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
)
def test_sp24_matmuls(self, dtype) -> None:
M, N, K = 64, 256, 1024
a = torch.randn([M, K], device="cuda", dtype=dtype)
@ -843,6 +879,10 @@ class TestSparseSemiStructuredTraining(TestCase):
)
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
@unittest.skipIf(
"RelWithAssert" in torch.__config__.show(),
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
)
def test_sp24_matmuls_mat_vec(self) -> None:
a = torch.randn([64, 128], device="cuda", dtype=torch.float16)
b = torch.randn([128], device="cuda", dtype=torch.float16)
@ -853,6 +893,10 @@ class TestSparseSemiStructuredTraining(TestCase):
torch.testing.assert_close(a_s @ b, (a * a_m) @ b, **atol_rtol_kw[a.dtype])
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
@unittest.skipIf(
"RelWithAssert" in torch.__config__.show(),
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
)
def test_sp24_matmuls_bmm(self) -> None:
a = torch.randn([64, 128], device="cuda", dtype=torch.float16)
b = torch.randn([5, 6, 128], device="cuda", dtype=torch.float16)

View File
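The skip decorators added above key off the build-configuration string rather than a dedicated flag. A minimal sketch of the same check, assuming only a working `torch` install (the `RelWithAssert` substring appears in `torch.__config__.show()` on the assert-enabled CI builds this PR skips on):

```python
import torch

# Same predicate the added @unittest.skipIf decorators use: the build
# configuration string mentions RelWithAssert on assert-enabled builds.
is_debug_assert_build = "RelWithAssert" in torch.__config__.show()
print("skipping sparse semi-structured tests:", is_debug_assert_build)
```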

@@ -2758,6 +2758,12 @@ class _NodeBase:
return_type: Any,
) -> None: ...
def _update_args_kwargs(self, args: tuple[Any, ...], kwargs: dict[str, Any]): ...
def _prepend(self, n: FxNode) -> None: ...
def _remove_from_list(self) -> None: ...
def __lt__(self, n: Self) -> _bool: ...
def __gt__(self, n: Self) -> _bool: ...
def __le__(self, n: Self) -> _bool: ...
def __ge__(self, n: Self) -> _bool: ...
class _NodeIter(Iterator[FxNode]):
def __init__(self, root: FxNode, reversed: _bool) -> None: ...

View File
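The stub additions above mirror the ordering helpers that the C++ `_NodeBase` now implements. A small sketch of how the comparisons behave from Python, assuming only `torch.fx.symbolic_trace` (nodes compare by their position in the graph's node list):

```python
import torch.fx as fx

def f(x):
    return (x + 1) * 2

gm = fx.symbolic_trace(f)
nodes = list(gm.graph.nodes)  # placeholder, add, mul, output (insertion order)

# Node.__lt__/__gt__/__le__/__ge__ delegate to the internal sort key,
# so sorting nodes reproduces their order in the graph.
assert nodes[0] < nodes[-1]
assert sorted(nodes) == nodes
```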

@@ -2810,5 +2810,15 @@
"Ensure {user_cls.__name__} is a type of dict, OrderedDict, or defaultdict."
]
}
],
"GB0279": [
{
"Gb_type": "torch.fx.traceback.annotate escaped from compiled region",
"Context": "str(self)",
"Explanation": "Dynamo doesn't support graph break on torch.fx.traceback.annotate.",
"Hints": [
"It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues."
]
}
]
}

View File

@@ -1295,6 +1295,16 @@ class FxTracebackAnnotateVariable(ContextWrappingVariable):
def fn_name(self):
return "annotate"
def reconstruct_type(self, codegen: "PyCodegen"):
unimplemented_v2(
gb_type="torch.fx.traceback.annotate escaped from compiled region",
context=str(self),
explanation="Dynamo doesn't support graph break on torch.fx.traceback.annotate.",
hints=[
*graph_break_hints.SUPPORTABLE,
],
)
class DynamoConfigPatchVariable(ContextWrappingVariable):
"""represents torch._dynamo.patch_dynamo_config"""

View File
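Together, the GB0279 registry entry and the `reconstruct_type` hook above cover the case where `torch.fx.traceback.annotate` would have to survive a graph break. A hypothetical repro sketch, hedged because the exact payload accepted by `annotate` is an assumption here; the point is only that a graph break fires while the context manager is active inside a compiled region:

```python
import torch
import torch.fx.traceback as fx_traceback

@torch.compile
def fn(x):
    # Payload dict is illustrative; what matters is the graph break
    # occurring while annotate() is the active context manager.
    with fx_traceback.annotate({"stage": "example"}):
        torch._dynamo.graph_break()
        return x + 1

fn(torch.randn(4))  # expected to hit the GB0279 "escaped from compiled region" path
```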

@@ -1,11 +1,15 @@
#include <torch/csrc/fx/node.h>
#include <c10/util/Exception.h>
#include <c10/util/SmallVector.h>
#include <structmember.h>
#include <torch/csrc/utils/object_ptr.h>
#include <torch/csrc/utils/pythoncapi_compat.h>
#include <algorithm>
namespace {
using NodeSortKey = c10::SmallVector<int64_t, 4>;
struct NodeBase;
// Thrown to exit out of a C++ function and return an error to Python.
@@ -163,7 +167,41 @@ struct NodeBase {
PyObject* users;
PyObject* _repr_fn;
PyObject* meta;
PyObject* _sort_key;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
alignas(NodeSortKey) char sort_key_buf[sizeof(NodeSortKey)];
inline NodeSortKey& sort_key() {
return *reinterpret_cast<NodeSortKey*>(sort_key_buf);
}
inline void set_prev(NodeBase* value) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(value);
Py_INCREF(reinterpret_cast<PyObject*>(value));
NodeBase* old = _prev;
_prev = value;
Py_DECREF(reinterpret_cast<PyObject*>(old));
}
inline void set_next(NodeBase* value) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(value);
Py_INCREF(reinterpret_cast<PyObject*>(value));
NodeBase* old = _next;
_next = value;
Py_DECREF(reinterpret_cast<PyObject*>(old));
}
// Equivalent to:
// p, n = self._prev, self._next
// p._next, n._prev = n, p
inline void remove_from_list() {
if (this->_prev == this && this->_next == this) {
return;
}
NodeBase* p = this->_prev;
NodeBase* n = this->_next;
p->set_next(n);
n->set_prev(p);
}
};
static PyObject* NodeBase_new(
@@ -173,6 +211,8 @@ static PyObject* NodeBase_new(
PyObject* self = type->tp_alloc(type, 0);
if (!self)
return nullptr;
new (reinterpret_cast<NodeBase*>(self)->sort_key_buf)
NodeSortKey(); // placement new does not allocate
return self;
}
@@ -201,7 +241,6 @@ static int NodeBase_init_fn(NodeBase* self, PyObject* args, PyObject* kwds) {
self->users = PyDict_New();
self->_repr_fn = Py_NewRef(Py_None);
self->meta = PyDict_New();
self->_sort_key = PyTuple_New(0);
return 0;
}
@@ -221,7 +260,6 @@ static struct PyMemberDef NodeBase_members[] = {
{"users", T_OBJECT_EX, offsetof(NodeBase, users), 0, nullptr},
{"_repr_fn", T_OBJECT_EX, offsetof(NodeBase, _repr_fn), 0, nullptr},
{"meta", T_OBJECT_EX, offsetof(NodeBase, meta), 0, nullptr},
{"_sort_key", T_OBJECT_EX, offsetof(NodeBase, _sort_key), 0, nullptr},
{nullptr} /* Sentinel */
};
@@ -239,7 +277,6 @@ static int NodeBase_traverse(NodeBase* self, visitproc visit, void* arg) {
Py_VISIT(self->users);
Py_VISIT(self->_repr_fn);
Py_VISIT(self->meta);
Py_VISIT(self->_sort_key);
return 0;
}
@@ -257,12 +294,12 @@ static int NodeBase_clear(NodeBase* self) {
Py_CLEAR(self->users);
Py_CLEAR(self->_repr_fn);
Py_CLEAR(self->meta);
Py_CLEAR(self->_sort_key);
return 0;
}
static void NodeBase_dealloc(PyObject* self) {
PyObject_GC_UnTrack(self);
reinterpret_cast<NodeBase*>(self)->sort_key().~NodeSortKey();
(void)NodeBase_clear((NodeBase*)self);
Py_TYPE(self)->tp_free(self);
}
@@ -321,15 +358,195 @@ static PyObject* NodeBase__update_args_kwargs(
}
}
static PyObject* NodeBase__remove_from_list(
PyObject* self,
PyObject* _ignored) {
reinterpret_cast<NodeBase*>(self)->remove_from_list();
Py_RETURN_NONE;
}
static PyObject* NodeBase__prepend(PyObject* self_, PyObject* arg) {
if (self_ == arg) {
Py_RETURN_NONE;
}
if (!is_node(arg)) {
PyErr_SetString(PyExc_TypeError, "_prepend() argument must be a Node");
return nullptr;
}
NodeBase* self = reinterpret_cast<NodeBase*>(self_);
NodeBase* x = reinterpret_cast<NodeBase*>(arg);
if (self->graph != x->graph) {
PyErr_SetString(
PyExc_AssertionError,
"Attempting to move a Node into a different Graph");
return nullptr;
}
x->remove_from_list();
NodeBase* p = self->_prev;
p->set_next(x);
x->set_prev(p);
x->set_next(self);
self->set_prev(x);
// Now compute x.sort_key()
const NodeSortKey& psk = x->_prev->sort_key();
const NodeSortKey& nsk = x->_next->sort_key();
if (psk.size() > nsk.size()) {
// prefix = psk[: len(nsk)+1]
size_t slice_len = nsk.size() + 1;
NodeSortKey prefix(psk.begin(), psk.begin() + slice_len);
// last element is idx => increment by 1
prefix.back()++;
x->sort_key() = std::move(prefix);
} else if (psk.size() < nsk.size()) {
// prefix = nsk[: len(psk)+1]
size_t slice_len = psk.size() + 1;
NodeSortKey prefix(nsk.begin(), nsk.begin() + slice_len);
// last element is idx => decrement by 1
prefix.back()--;
x->sort_key() = std::move(prefix);
} else {
// same length => add a 0
x->sort_key() = psk;
x->sort_key().emplace_back(0);
}
Py_RETURN_NONE;
}
// __lt__(self, other): Return self.sort_key < other.sort_key
static PyObject* NodeBase___lt__(PyObject* self, PyObject* other) {
// METH_O => one argument: 'other'
if (!is_node(other)) {
Py_RETURN_NOTIMPLEMENTED;
}
const NodeSortKey& lhs = reinterpret_cast<NodeBase*>(self)->sort_key();
const NodeSortKey& rhs = reinterpret_cast<NodeBase*>(other)->sort_key();
bool less = std::lexicographical_compare(
lhs.begin(), lhs.end(), rhs.begin(), rhs.end());
if (less)
Py_RETURN_TRUE;
Py_RETURN_FALSE;
}
// __gt__(self, other): Return self.sort_key > other.sort_key
static PyObject* NodeBase___gt__(PyObject* self, PyObject* other) {
if (!is_node(other)) {
Py_RETURN_NOTIMPLEMENTED;
}
const NodeSortKey& lhs = reinterpret_cast<NodeBase*>(self)->sort_key();
const NodeSortKey& rhs = reinterpret_cast<NodeBase*>(other)->sort_key();
// "a > b" is equivalent to "b < a"
bool greater = std::lexicographical_compare(
rhs.begin(), rhs.end(), lhs.begin(), lhs.end());
if (greater)
Py_RETURN_TRUE;
Py_RETURN_FALSE;
}
static PyObject* NodeBase___ge__(PyObject* self, PyObject* other) {
if (self == other) {
Py_RETURN_TRUE;
}
return NodeBase___gt__(self, other);
}
// __le__(self, other): Return not (self > other)
static PyObject* NodeBase___le__(PyObject* self, PyObject* other) {
if (self == other) {
Py_RETURN_TRUE;
}
return NodeBase___lt__(self, other);
}
// Convert the NodeBase::sort_key (a SmallVector of int64_t) into a Python tuple of ints
// Only used by pickle/__getstate__
static PyObject* NodeBase_get_sort_key(PyObject* self, void* /*closure*/) {
NodeBase* node = reinterpret_cast<NodeBase*>(self);
const NodeSortKey& vec = node->sort_key();
Py_ssize_t n = static_cast<Py_ssize_t>(vec.size());
THPObjectPtr tuple(PyTuple_New(n));
if (!tuple) {
return nullptr; // Out of memory
}
for (Py_ssize_t i = 0; i < n; i++) {
PyObject* value = PyLong_FromSsize_t(vec[i]);
if (!value) {
return nullptr;
}
PyTuple_SET_ITEM(tuple.get(), i, value);
}
return tuple.release();
}
// Setter for NodeBase::sort_key: expects a Python tuple of ints, e.g.
// node._sort_key = (1, 2, 3). Only used by pickle/__setstate__
static int NodeBase_set_sort_key(
PyObject* self,
PyObject* value,
void* /*closure*/) {
NodeBase* node = reinterpret_cast<NodeBase*>(self);
if (!PyTuple_Check(value)) {
PyErr_SetString(PyExc_TypeError, "_sort_key must be an tuple of ints");
return -1;
}
Py_ssize_t size = PyTuple_GET_SIZE(value);
NodeSortKey new_vec;
new_vec.reserve(size);
for (Py_ssize_t i = 0; i < size; i++) {
int64_t val = PyLong_AsSsize_t(PyTuple_GET_ITEM(value, i));
if (val == -1 && PyErr_Occurred()) {
return -1;
}
new_vec.emplace_back(val);
}
node->sort_key() = std::move(new_vec);
return 0;
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
static PyMethodDef NodeBase_methods[] = {
{"_update_args_kwargs",
(PyCFunction)(void*)(NodeBase__update_args_kwargs),
METH_FASTCALL,
"Internal method: do not call directly."},
{"_remove_from_list",
(PyCFunction)(void*)(NodeBase__remove_from_list),
METH_NOARGS,
"Internal method: do not call directly."},
{"_prepend",
(PyCFunction)(void*)(NodeBase__prepend),
METH_O,
"Internal method: do not call directly."},
{"__lt__",
(PyCFunction)(void*)NodeBase___lt__,
METH_O,
"Return True if self.sort_key < other.sort_key"},
{"__gt__",
(PyCFunction)(void*)NodeBase___gt__,
METH_O,
"Return True if self.sort_key > other.sort_key"},
{"__ge__",
(PyCFunction)(void*)NodeBase___ge__,
METH_O,
"Return True if self.sort_key >= other.sort_key"},
{"__le__",
(PyCFunction)(void*)NodeBase___le__,
METH_O,
"Return True if self.sort_key <= other.sort_key"},
{nullptr, nullptr, 0, nullptr} // Sentinel
};
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
static PyGetSetDef NodeBase_getset[] = {
{"_sort_key", // attribute name in Python
(getter)NodeBase_get_sort_key, // C getter function
(setter)NodeBase_set_sort_key, // C setter function
(char*)"The sort key as a tuple of ints", // docstring
nullptr},
{nullptr, nullptr, nullptr, nullptr, nullptr} // Sentinel
};
PyTypeObject NodeBaseType = {
PyVarObject_HEAD_INIT(nullptr, 0)
"torch._C._NodeBase", /* tp_name */
@@ -361,7 +578,7 @@ PyTypeObject NodeBaseType = {
nullptr, /* tp_iternext */
NodeBase_methods, /* tp_methods */
NodeBase_members, /* tp_members */
nullptr, /* tp_getset */
NodeBase_getset, /* tp_getset */
nullptr, /* tp_base */
nullptr, /* tp_dict */
nullptr, /* tp_descr_get */

View File
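For readers skimming the C++ above: `_prepend` keeps each node's position encoded as a prefix-based sort key, the same scheme the Python implementation removed in the next file used (the new `_sort_key` getter/setter only materializes it as a tuple for pickling). A standalone sketch of just the key computation, using plain tuples instead of `Node` objects:

```python
# Sketch of the sort-key update performed when x is inserted between
# prev (key psk) and next (key nsk); mirrors both the C++ NodeBase__prepend
# logic above and the removed Python Node.prepend below.
def new_sort_key(psk: tuple, nsk: tuple) -> tuple:
    if len(psk) > len(nsk):
        *prefix, idx = psk[: len(nsk) + 1]
        return (*prefix, idx + 1)
    elif len(psk) < len(nsk):
        *prefix, idx = nsk[: len(psk) + 1]
        return (*prefix, idx - 1)
    else:  # same length: descend one level
        return (*psk, 0)

# Inserting between keys (1,) and (2,) yields (1, 0), which sorts
# lexicographically between them, so list order is preserved.
assert (1,) < new_sort_key((1,), (2,)) < (2,)
```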

@@ -385,41 +385,7 @@ class Node(_NodeBase):
Args:
x (Node): The node to put before this node. Must be a member of the same graph.
"""
assert self.graph == x.graph, "Attempting to move a Node into a different Graph"
if self == x:
log.debug(
"Trying to prepend a node to itself. This behavior has no effect on the graph."
)
return
x._remove_from_list()
p = self._prev
p._next, x._prev = x, p
x._next, self._prev = self, x
# compute x._sort_key
psk = x._prev._sort_key
nsk = x._next._sort_key
if len(psk) > len(nsk):
idx: int
*prefix, idx = psk[: len(nsk) + 1]
x._sort_key = (*prefix, idx + 1)
elif len(psk) < len(nsk):
*prefix, idx = nsk[: len(psk) + 1]
x._sort_key = (*prefix, idx - 1)
else: # same length, increase length by 1
x._sort_key = (*psk, 0)
def __gt__(self, other: "Node") -> bool:
return self._sort_key > other._sort_key
def __lt__(self, other: "Node") -> bool:
return self._sort_key < other._sort_key
def __ge__(self, other: "Node") -> bool:
return self > other or self == other
def __le__(self, other: "Node") -> bool:
return self < other or self == other
self._prepend(x)
@compatibility(is_backward_compatible=True)
def append(self, x: "Node") -> None:
@@ -430,11 +396,7 @@ class Node(_NodeBase):
Args:
x (Node): The node to put after this node. Must be a member of the same graph.
"""
self._next.prepend(x)
def _remove_from_list(self) -> None:
p, n = self._prev, self._next
p._next, n._prev = n, p
self._next._prepend(x)
@property
def args(self) -> tuple[Argument, ...]: