mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-23 14:59:34 +08:00
Compare commits
3 Commits
gh/janeyx9
...
cpp-docs-d
Author | SHA1 | Date | |
---|---|---|---|
5b6cc8215f | |||
1c43c9cfd0 | |||
102e0d5437 |
@ -1,15 +1,11 @@
|
||||
sphinx==5.3.0
|
||||
sphinx==7.2.6
|
||||
#Description: This is used to generate PyTorch docs
|
||||
#Pinned versions: 5.3.0
|
||||
#Pinned versions: 7.2.6
|
||||
|
||||
standard-imghdr==3.13.0; python_version >= "3.13"
|
||||
#Description: This is needed by Sphinx, so it needs to be added here.
|
||||
# The reasons are as follows:
|
||||
# 1) This module was removed from the Python standard library in Python 3.13 (https://peps.python.org/pep-0594/#imghdr);
|
||||
# 2) The current version of Sphinx (5.3.0) is not compatible with Python 3.13.
|
||||
# Once Sphinx is upgraded to a version compatible with Python 3.13 or later, we can remove this dependency.
|
||||
pytorch_sphinx_theme2==0.1.0
|
||||
#Description: This is needed to generate PyTorch docs
|
||||
#Pinned versions: 0.1.0
|
||||
|
||||
-e git+https://github.com/pytorch/pytorch_sphinx_theme.git@71e55749be14ceb56e7f8211a9fb649866b87ad4#egg=pytorch_sphinx_theme2
|
||||
# TODO: sphinxcontrib.katex 0.9.0 adds a local KaTeX server to speed up pre-rendering
|
||||
# but it doesn't seem to work and hangs around idly. The initial thought is that it is probably
|
||||
# something related to Docker setup. We can investigate this later.
|
||||
@ -36,17 +32,17 @@ tensorboard==2.18.0 ; python_version >= "3.13"
|
||||
#Description: This is used to generate PyTorch docs
|
||||
#Pinned versions: 2.13.0
|
||||
|
||||
breathe==4.34.0
|
||||
breathe==4.36.0
|
||||
#Description: This is used to generate PyTorch C++ docs
|
||||
#Pinned versions: 4.34.0
|
||||
#Pinned versions: 4.36.0
|
||||
|
||||
exhale==0.2.3
|
||||
exhale==0.3.7
|
||||
#Description: This is used to generate PyTorch C++ docs
|
||||
#Pinned versions: 0.2.3
|
||||
#Pinned versions: 0.3.7
|
||||
|
||||
docutils==0.16
|
||||
docutils==0.20
|
||||
#Description: This is used to generate PyTorch C++ docs
|
||||
#Pinned versions: 0.16
|
||||
#Pinned versions: 0.20
|
||||
|
||||
bs4==0.0.1
|
||||
#Description: This is used to generate PyTorch C++ docs
|
||||
@ -56,13 +52,13 @@ IPython==8.12.0
|
||||
#Description: This is used to generate PyTorch functorch docs
|
||||
#Pinned versions: 8.12.0
|
||||
|
||||
myst-nb==0.17.2
|
||||
myst-nb==1.3.0
|
||||
#Description: This is used to generate PyTorch functorch and torch.compile docs.
|
||||
#Pinned versions: 0.17.2
|
||||
#Pinned versions: 1.3.0
|
||||
|
||||
# The following are required to build torch.distributed.elastic.rendezvous.etcd* docs
|
||||
python-etcd==0.4.5
|
||||
sphinx-copybutton==0.5.0
|
||||
sphinx-design==0.4.0
|
||||
sphinx-design==0.6.1
|
||||
sphinxcontrib-mermaid==1.0.0
|
||||
myst-parser==0.18.1
|
||||
myst-parser==4.0.1
|
||||
|
@ -6,7 +6,7 @@ dependencies = [
|
||||
"GitPython==3.1.45",
|
||||
"docker==7.1.0",
|
||||
"pytest==7.3.2",
|
||||
"uv==0.9.5"
|
||||
"uv==0.8.6"
|
||||
]
|
||||
|
||||
[tool.setuptools]
|
||||
|
@ -102,8 +102,18 @@ if [ "$is_main_doc" = true ]; then
|
||||
echo coverage output not found
|
||||
exit 1
|
||||
elif [ $undocumented -gt 0 ]; then
|
||||
echo undocumented objects found:
|
||||
echo "======================================"
|
||||
echo "ERROR: $undocumented undocumented objects found!"
|
||||
echo "======================================"
|
||||
echo ""
|
||||
echo "Full coverage report:"
|
||||
cat build/coverage/python.txt
|
||||
echo ""
|
||||
echo "======================================"
|
||||
echo "Undocumented modules/objects (lines after TOTAL):"
|
||||
tail -n +$((lines - undocumented + 1)) build/coverage/python.txt
|
||||
echo "======================================"
|
||||
echo ""
|
||||
echo "Make sure you've updated relevant .rsts in docs/source!"
|
||||
echo "You can reproduce locally by running 'cd docs && make coverage && cat build/coverage/python.txt'"
|
||||
exit 1
|
||||
|
15
.github/workflows/periodic.yml
vendored
15
.github/workflows/periodic.yml
vendored
@ -147,16 +147,15 @@ jobs:
|
||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-debug
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9
|
||||
cuda-arch-list: 8.9
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "default", shard: 1, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
|
||||
{ config: "default", shard: 2, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
|
||||
{ config: "default", shard: 3, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
|
||||
{ config: "default", shard: 4, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
|
||||
{ config: "default", shard: 5, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
|
||||
{ config: "default", shard: 6, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
|
||||
{ config: "default", shard: 7, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
|
||||
{ config: "default", shard: 1, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
|
||||
{ config: "default", shard: 2, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
|
||||
{ config: "default", shard: 3, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
|
||||
{ config: "default", shard: 4, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
|
||||
{ config: "default", shard: 5, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
|
||||
{ config: "default", shard: 6, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
|
||||
{ config: "default", shard: 7, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
|
||||
]}
|
||||
secrets: inherit
|
||||
|
||||
|
149
.github/workflows/trunk-tagging.yml
vendored
149
.github/workflows/trunk-tagging.yml
vendored
@ -58,10 +58,8 @@ jobs:
|
||||
else
|
||||
COMMIT_SHA="${{ github.sha }}"
|
||||
fi
|
||||
{
|
||||
echo "sha=${COMMIT_SHA}"
|
||||
echo "tag_name=trunk/${COMMIT_SHA}"
|
||||
} >> "${GITHUB_OUTPUT}"
|
||||
echo "sha=${COMMIT_SHA}" >> "${GITHUB_OUTPUT}"
|
||||
echo "tag_name=trunk/${COMMIT_SHA}" >> "${GITHUB_OUTPUT}"
|
||||
|
||||
- name: Validate commit SHA
|
||||
run: |
|
||||
@ -89,7 +87,7 @@ jobs:
|
||||
echo "✅ Commit ${COMMIT_SHA} is valid (automatic push trigger)"
|
||||
fi
|
||||
|
||||
- name: Create and push tag(s) with retry
|
||||
- name: Create and push tag with retry
|
||||
id: check_tag
|
||||
env:
|
||||
TAG_NAME: ${{ steps.commit.outputs.tag_name }}
|
||||
@ -114,23 +112,14 @@ jobs:
|
||||
return 1
|
||||
}
|
||||
|
||||
# Counters for summary reporting
|
||||
created_count=0
|
||||
skipped_count=0
|
||||
failed_count=0
|
||||
# Exit early if tag already exists
|
||||
if check_tag_exists; then
|
||||
echo "✅ Tag already exists - no action needed"
|
||||
echo "exists=true" >> "${GITHUB_OUTPUT}"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Always write outputs once on exit
|
||||
finish() {
|
||||
set +e
|
||||
if [ -n "${GITHUB_OUTPUT:-}" ]; then
|
||||
{
|
||||
echo "created_count=${created_count}"
|
||||
echo "skipped_count=${skipped_count}"
|
||||
echo "failed_count=${failed_count}"
|
||||
} >> "${GITHUB_OUTPUT}"
|
||||
fi
|
||||
}
|
||||
trap finish EXIT
|
||||
echo "Tag ${TAG_NAME} does not exist, proceeding with creation"
|
||||
|
||||
# Retry configuration
|
||||
MAX_RETRIES=5
|
||||
@ -205,111 +194,31 @@ jobs:
|
||||
}
|
||||
}
|
||||
|
||||
# New behavior for push events: enumerate commits in the push and tag each one.
|
||||
# For workflow_dispatch, retain existing single-SHA behavior.
|
||||
|
||||
# Always fetch tags once up front to improve idempotency in loops
|
||||
git fetch origin --tags --quiet || true
|
||||
|
||||
if [ "${{ github.event_name }}" = "push" ]; then
|
||||
BEFORE_SHA="${{ github.event.before }}"
|
||||
AFTER_SHA="${{ github.sha }}" # same as event.after
|
||||
|
||||
# List commits introduced by this push (old..new), oldest first for stable ordering
|
||||
commits_file="$(mktemp)"
|
||||
git rev-list --reverse "${BEFORE_SHA}..${AFTER_SHA}" > "${commits_file}"
|
||||
|
||||
if [ ! -s "${commits_file}" ]; then
|
||||
echo "No new commits found between ${BEFORE_SHA}..${AFTER_SHA}; nothing to tag."
|
||||
rm -f "${commits_file}"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
commit_count="$(wc -l < "${commits_file}" | tr -d ' ')"
|
||||
echo "Found ${commit_count} commit(s) to tag for push:"
|
||||
while IFS= read -r sha; do
|
||||
printf ' %s\n' "${sha}"
|
||||
done < "${commits_file}"
|
||||
|
||||
while IFS= read -r sha; do
|
||||
TAG_NAME="trunk/${sha}"
|
||||
COMMIT_SHA="${sha}"
|
||||
|
||||
# If tag already exists locally or remotely, skip (idempotent)
|
||||
if check_tag_exists; then
|
||||
echo "✅ Tag ${TAG_NAME} already exists - skipping"
|
||||
skipped_count=$((skipped_count + 1))
|
||||
continue
|
||||
fi
|
||||
|
||||
echo "Tag ${TAG_NAME} does not exist, proceeding with creation"
|
||||
|
||||
if retry_with_backoff "tag_with_retry" "Creating tag ${TAG_NAME} for commit ${COMMIT_SHA}"; then
|
||||
created_count=$((created_count + 1))
|
||||
else
|
||||
echo "Tag creation failed after all retry attempts for ${TAG_NAME}"
|
||||
failed_count=$((failed_count + 1))
|
||||
fi
|
||||
done < "${commits_file}"
|
||||
|
||||
rm -f "${commits_file}"
|
||||
|
||||
if [ "${failed_count}" -gt 0 ]; then
|
||||
exit 1
|
||||
fi
|
||||
# Execute with retry
|
||||
if retry_with_backoff "tag_with_retry" "Creating tag ${TAG_NAME} for commit ${COMMIT_SHA}"; then
|
||||
echo "exists=false" >> "${GITHUB_OUTPUT}"
|
||||
exit 0
|
||||
else
|
||||
# workflow_dispatch path (single SHA tagging preserved)
|
||||
|
||||
# Exit early if tag already exists
|
||||
if check_tag_exists; then
|
||||
echo "✅ Tag already exists - no action needed"
|
||||
skipped_count=1
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "Tag ${TAG_NAME} does not exist, proceeding with creation"
|
||||
|
||||
if retry_with_backoff "tag_with_retry" "Creating tag ${TAG_NAME} for commit ${COMMIT_SHA}"; then
|
||||
created_count=1
|
||||
exit 0
|
||||
else
|
||||
echo "Tag creation failed after all retry attempts"
|
||||
failed_count=1
|
||||
exit 1
|
||||
fi
|
||||
echo "Tag creation failed after all retry attempts"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
- name: Tag creation summary
|
||||
if: always()
|
||||
run: |
|
||||
if [ "${{ github.event_name }}" = "push" ]; then
|
||||
echo "Trigger: push on main"
|
||||
echo "Created: ${{ steps.check_tag.outputs.created_count }}"
|
||||
echo "Skipped (already existed): ${{ steps.check_tag.outputs.skipped_count }}"
|
||||
echo "Failed: ${{ steps.check_tag.outputs.failed_count }}"
|
||||
if [ "${{ steps.check_tag.outputs.failed_count }}" = "0" ]; then
|
||||
echo "✅ Completed tagging for push range ${{ github.event.before }}..${{ github.sha }}"
|
||||
else
|
||||
echo "❌ Some tags failed to create for push range ${{ github.event.before }}..${{ github.sha }}"
|
||||
fi
|
||||
if [ "${{ steps.check_tag.outputs.exists }}" = "true" ]; then
|
||||
echo "✅ Tag ${{ steps.commit.outputs.tag_name }} already existed - no action needed"
|
||||
elif [ "${{ job.status }}" = "success" ]; then
|
||||
echo "✅ Successfully created tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
|
||||
else
|
||||
if [ "${{ steps.check_tag.outputs.failed_count }}" = "0" ]; then
|
||||
if [ "${{ steps.check_tag.outputs.created_count }}" = "0" ]; then
|
||||
echo "✅ Tag ${{ steps.commit.outputs.tag_name }} already existed - no action needed"
|
||||
else
|
||||
echo "✅ Successfully created tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
|
||||
fi
|
||||
else
|
||||
echo "❌ Failed to create tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "Tag details:"
|
||||
echo " Name: ${{ steps.commit.outputs.tag_name }}"
|
||||
echo " Commit: ${{ steps.commit.outputs.sha }}"
|
||||
echo " Trigger: ${{ github.event_name }}"
|
||||
if [ -n "${{ github.event.inputs.commit_sha }}" ]; then
|
||||
echo " Manual commit: ${{ github.event.inputs.commit_sha }}"
|
||||
fi
|
||||
echo "❌ Failed to create tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "Tag details:"
|
||||
echo " Name: ${{ steps.commit.outputs.tag_name }}"
|
||||
echo " Commit: ${{ steps.commit.outputs.sha }}"
|
||||
echo " Trigger: ${{ github.event_name }}"
|
||||
if [ -n "${{ github.event.inputs.commit_sha }}" ]; then
|
||||
echo " Manual commit: ${{ github.event.inputs.commit_sha }}"
|
||||
fi
|
||||
|
@ -9,7 +9,6 @@
|
||||
#include <ATen/cpu/vec/vec128/vec128_float_neon.h>
|
||||
#include <ATen/cpu/vec/vec128/vec128_half_neon.h>
|
||||
#include <ATen/cpu/vec/vec128/vec128_int_aarch64.h>
|
||||
#include <ATen/cpu/vec/vec128/vec128_uint_aarch64.h>
|
||||
#endif
|
||||
|
||||
#include <ATen/cpu/vec/vec128/vec128_convert.h>
|
||||
|
@ -1,378 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/cpu/vec/intrinsics.h>
|
||||
#include <ATen/cpu/vec/vec_base.h>
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <c10/util/irange.h>
|
||||
|
||||
namespace at::vec {
|
||||
// Note [CPU_CAPABILITY namespace]
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
// This header, and all of its subheaders, will be compiled with
|
||||
// different architecture flags for each supported set of vector
|
||||
// intrinsics. So we need to make sure they aren't inadvertently
|
||||
// linked together. We do this by declaring objects in an `inline
|
||||
// namespace` which changes the name mangling, but can still be
|
||||
// accessed as `at::vec`.
|
||||
inline namespace CPU_CAPABILITY {
|
||||
|
||||
#define VEC_UINT_NEON_TEMPLATE(vl, bit) \
|
||||
template <> \
|
||||
struct is_vec_specialized_for<uint##bit##_t> : std::bool_constant<true> {}; \
|
||||
\
|
||||
template <> \
|
||||
class Vectorized<uint##bit##_t> { \
|
||||
using neon_type = uint##bit##x##vl##_t; \
|
||||
\
|
||||
private: \
|
||||
neon_type values; \
|
||||
\
|
||||
public: \
|
||||
using value_type = uint##bit##_t; \
|
||||
using size_type = int; \
|
||||
static constexpr size_type size() { \
|
||||
return vl; \
|
||||
} \
|
||||
Vectorized() { \
|
||||
values = vdupq_n_u##bit(0); \
|
||||
} \
|
||||
Vectorized(neon_type v) : values(v) {} \
|
||||
Vectorized(uint##bit##_t val); \
|
||||
template < \
|
||||
typename... Args, \
|
||||
typename = std::enable_if_t<(sizeof...(Args) == size())>> \
|
||||
Vectorized(Args... vals) { \
|
||||
__at_align__ uint##bit##_t buffer[size()] = {vals...}; \
|
||||
values = vld1q_u##bit(buffer); \
|
||||
} \
|
||||
operator neon_type() const { \
|
||||
return values; \
|
||||
} \
|
||||
static Vectorized<uint##bit##_t> loadu( \
|
||||
const void* ptr, \
|
||||
uint64_t count = size()); \
|
||||
void store(void* ptr, uint64_t count = size()) const; \
|
||||
template <uint64_t mask> \
|
||||
static Vectorized<uint##bit##_t> blend( \
|
||||
const Vectorized<uint##bit##_t>& a, \
|
||||
const Vectorized<uint##bit##_t>& b); \
|
||||
static Vectorized<uint##bit##_t> blendv( \
|
||||
const Vectorized<uint##bit##_t>& a, \
|
||||
const Vectorized<uint##bit##_t>& b, \
|
||||
const Vectorized<uint##bit##_t>& mask_) { \
|
||||
return vbslq_u##bit(mask_.values, b, a); \
|
||||
} \
|
||||
template <typename step_t> \
|
||||
static Vectorized<uint##bit##_t> arange( \
|
||||
value_type base = 0, \
|
||||
step_t step = static_cast<step_t>(1)); \
|
||||
static Vectorized<uint##bit##_t> set( \
|
||||
const Vectorized<uint##bit##_t>& a, \
|
||||
const Vectorized<uint##bit##_t>& b, \
|
||||
uint64_t count = size()); \
|
||||
const uint##bit##_t& operator[](uint idx) const = delete; \
|
||||
uint##bit##_t& operator[](uint idx) = delete; \
|
||||
Vectorized<uint##bit##_t> abs() const { \
|
||||
return values; \
|
||||
} \
|
||||
Vectorized<uint##bit##_t> real() const { \
|
||||
return values; \
|
||||
} \
|
||||
Vectorized<uint##bit##_t> imag() const { \
|
||||
return vdupq_n_u##bit(0); \
|
||||
} \
|
||||
Vectorized<uint##bit##_t> conj() const { \
|
||||
return values; \
|
||||
} \
|
||||
Vectorized<uint##bit##_t> neg() const { \
|
||||
return vreinterpretq_u##bit##_s##bit( \
|
||||
vnegq_s##bit(vreinterpretq_s##bit##_u##bit(values))); \
|
||||
} \
|
||||
uint##bit##_t reduce_add() const { \
|
||||
return vaddvq_u##bit(values); \
|
||||
} \
|
||||
uint##bit##_t reduce_max() const; \
|
||||
Vectorized<uint##bit##_t> operator==( \
|
||||
const Vectorized<uint##bit##_t>& other) const { \
|
||||
return Vectorized<value_type>(vceqq_u##bit(values, other.values)); \
|
||||
} \
|
||||
Vectorized<uint##bit##_t> operator!=( \
|
||||
const Vectorized<uint##bit##_t>& other) const; \
|
||||
Vectorized<uint##bit##_t> operator<( \
|
||||
const Vectorized<uint##bit##_t>& other) const { \
|
||||
return Vectorized<value_type>(vcltq_u##bit(values, other.values)); \
|
||||
} \
|
||||
Vectorized<uint##bit##_t> operator<=( \
|
||||
const Vectorized<uint##bit##_t>& other) const { \
|
||||
return Vectorized<value_type>(vcleq_u##bit(values, other.values)); \
|
||||
} \
|
||||
Vectorized<uint##bit##_t> operator>( \
|
||||
const Vectorized<uint##bit##_t>& other) const { \
|
||||
return Vectorized<value_type>(vcgtq_u##bit(values, other.values)); \
|
||||
} \
|
||||
Vectorized<uint##bit##_t> operator>=( \
|
||||
const Vectorized<uint##bit##_t>& other) const { \
|
||||
return Vectorized<value_type>(vcgeq_u##bit(values, other.values)); \
|
||||
} \
|
||||
Vectorized<uint##bit##_t> eq( \
|
||||
const Vectorized<uint##bit##_t>& other) const; \
|
||||
Vectorized<uint##bit##_t> ne( \
|
||||
const Vectorized<uint##bit##_t>& other) const; \
|
||||
Vectorized<uint##bit##_t> gt( \
|
||||
const Vectorized<uint##bit##_t>& other) const; \
|
||||
Vectorized<uint##bit##_t> ge( \
|
||||
const Vectorized<uint##bit##_t>& other) const; \
|
||||
Vectorized<uint##bit##_t> lt( \
|
||||
const Vectorized<uint##bit##_t>& other) const; \
|
||||
Vectorized<uint##bit##_t> le( \
|
||||
const Vectorized<uint##bit##_t>& other) const; \
|
||||
}; \
|
||||
template <> \
|
||||
Vectorized<uint##bit##_t> inline operator+( \
|
||||
const Vectorized<uint##bit##_t>& a, \
|
||||
const Vectorized<uint##bit##_t>& b) { \
|
||||
return vaddq_u##bit(a, b); \
|
||||
} \
|
||||
template <> \
|
||||
Vectorized<uint##bit##_t> inline operator-( \
|
||||
const Vectorized<uint##bit##_t>& a, \
|
||||
const Vectorized<uint##bit##_t>& b) { \
|
||||
return vsubq_u##bit(a, b); \
|
||||
} \
|
||||
template <> \
|
||||
Vectorized<uint##bit##_t> inline operator&( \
|
||||
const Vectorized<uint##bit##_t>& a, \
|
||||
const Vectorized<uint##bit##_t>& b) { \
|
||||
return vandq_u##bit(a, b); \
|
||||
} \
|
||||
template <> \
|
||||
Vectorized<uint##bit##_t> inline operator|( \
|
||||
const Vectorized<uint##bit##_t>& a, \
|
||||
const Vectorized<uint##bit##_t>& b) { \
|
||||
return vorrq_u##bit(a, b); \
|
||||
} \
|
||||
template <> \
|
||||
Vectorized<uint##bit##_t> inline operator^( \
|
||||
const Vectorized<uint##bit##_t>& a, \
|
||||
const Vectorized<uint##bit##_t>& b) { \
|
||||
return veorq_u##bit(a, b); \
|
||||
} \
|
||||
Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::eq( \
|
||||
const Vectorized<uint##bit##_t>& other) const { \
|
||||
return (*this == other) & Vectorized<uint##bit##_t>(1); \
|
||||
} \
|
||||
Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::ne( \
|
||||
const Vectorized<uint##bit##_t>& other) const { \
|
||||
return (*this != other) & Vectorized<uint##bit##_t>(1); \
|
||||
} \
|
||||
Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::gt( \
|
||||
const Vectorized<uint##bit##_t>& other) const { \
|
||||
return (*this > other) & Vectorized<uint##bit##_t>(1); \
|
||||
} \
|
||||
Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::ge( \
|
||||
const Vectorized<uint##bit##_t>& other) const { \
|
||||
return (*this >= other) & Vectorized<uint##bit##_t>(1); \
|
||||
} \
|
||||
Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::lt( \
|
||||
const Vectorized<uint##bit##_t>& other) const { \
|
||||
return (*this < other) & Vectorized<uint##bit##_t>(1); \
|
||||
} \
|
||||
Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::le( \
|
||||
const Vectorized<uint##bit##_t>& other) const { \
|
||||
return (*this <= other) & Vectorized<uint##bit##_t>(1); \
|
||||
}
|
||||
|
||||
VEC_UINT_NEON_TEMPLATE(16, 8)
|
||||
|
||||
inline uint8_t Vectorized<uint8_t>::reduce_max() const {
|
||||
return vmaxvq_u8(values);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<uint8_t> inline operator*(
|
||||
const Vectorized<uint8_t>& a,
|
||||
const Vectorized<uint8_t>& b) {
|
||||
return vmulq_u8(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Vectorized<uint8_t> operator~(const Vectorized<uint8_t>& a) {
|
||||
return vmvnq_u8(a);
|
||||
}
|
||||
|
||||
inline Vectorized<uint8_t> Vectorized<uint8_t>::operator!=(
|
||||
const Vectorized<uint8_t>& other) const {
|
||||
return ~(*this == other);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<uint8_t> inline minimum(
|
||||
const Vectorized<uint8_t>& a,
|
||||
const Vectorized<uint8_t>& b) {
|
||||
return vminq_u8(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<uint8_t> inline maximum(
|
||||
const Vectorized<uint8_t>& a,
|
||||
const Vectorized<uint8_t>& b) {
|
||||
return vmaxq_u8(a, b);
|
||||
}
|
||||
|
||||
template <uint64_t mask>
|
||||
Vectorized<uint8_t> Vectorized<uint8_t>::blend(
|
||||
const Vectorized<uint8_t>& a,
|
||||
const Vectorized<uint8_t>& b) {
|
||||
// Build an array of flags: each bit of element is 1 if the corresponding bit
|
||||
// in 'mask' is set, 0 otherwise.
|
||||
uint8x16_t maskArray = {
|
||||
(mask & 1LL) ? 0xFF : 0,
|
||||
(mask & 2LL) ? 0xFF : 0,
|
||||
(mask & 4LL) ? 0xFF : 0,
|
||||
(mask & 8LL) ? 0xFF : 0,
|
||||
(mask & 16LL) ? 0xFF : 0,
|
||||
(mask & 32LL) ? 0xFF : 0,
|
||||
(mask & 64LL) ? 0xFF : 0,
|
||||
(mask & 128LL) ? 0xFF : 0,
|
||||
(mask & 256LL) ? 0xFF : 0,
|
||||
(mask & 512LL) ? 0xFF : 0,
|
||||
(mask & 1024LL) ? 0xFF : 0,
|
||||
(mask & 2048LL) ? 0xFF : 0,
|
||||
(mask & 4096LL) ? 0xFF : 0,
|
||||
(mask & 8192LL) ? 0xFF : 0,
|
||||
(mask & 16384LL) ? 0xFF : 0,
|
||||
(mask & 32768LL) ? 0xFF : 0};
|
||||
// Use BSL to select elements from b where the mask is 1, else from a
|
||||
return vbslq_u8(maskArray, b.values, a.values);
|
||||
}
|
||||
|
||||
#define VEC_UINT_NEON_OPS(vl, bit) \
|
||||
inline Vectorized<uint##bit##_t>::Vectorized(uint##bit##_t val) { \
|
||||
values = vdupq_n_u##bit(val); \
|
||||
} \
|
||||
inline Vectorized<uint##bit##_t> Vectorized<uint##bit##_t>::loadu( \
|
||||
const void* ptr, uint64_t count) { \
|
||||
if (count == size()) { \
|
||||
return vld1q_u##bit(reinterpret_cast<const uint##bit##_t*>(ptr)); \
|
||||
} else { \
|
||||
__at_align__ uint##bit##_t tmp_values[size()]; \
|
||||
for (const auto i : c10::irange(size())) { \
|
||||
tmp_values[i] = 0; \
|
||||
} \
|
||||
std::memcpy( \
|
||||
tmp_values, \
|
||||
reinterpret_cast<const uint##bit##_t*>(ptr), \
|
||||
count * sizeof(uint##bit##_t)); \
|
||||
return vld1q_u##bit(reinterpret_cast<const uint##bit##_t*>(tmp_values)); \
|
||||
} \
|
||||
} \
|
||||
inline void Vectorized<uint##bit##_t>::store(void* ptr, uint64_t count) \
|
||||
const { \
|
||||
if (count == size()) { \
|
||||
vst1q_u##bit(reinterpret_cast<uint##bit##_t*>(ptr), values); \
|
||||
} else { \
|
||||
uint##bit##_t tmp_values[size()]; \
|
||||
vst1q_u##bit(reinterpret_cast<uint##bit##_t*>(tmp_values), values); \
|
||||
std::memcpy(ptr, tmp_values, count * sizeof(uint##bit##_t)); \
|
||||
} \
|
||||
}
|
||||
|
||||
VEC_UINT_NEON_OPS(16, 8)
|
||||
|
||||
template <typename step_t>
|
||||
inline Vectorized<uint8_t> Vectorized<uint8_t>::arange(
|
||||
uint8_t base,
|
||||
step_t step) {
|
||||
const Vectorized<uint8_t> base_vec(base);
|
||||
const Vectorized<uint8_t> step_vec(step);
|
||||
const uint8x16_t step_sizes = {
|
||||
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
|
||||
return vmlaq_u8(base_vec, step_sizes, step_vec);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<uint8_t> inline operator>>(
|
||||
const Vectorized<uint8_t>& a,
|
||||
const Vectorized<uint8_t>& b) {
|
||||
uint8x16_t x = a;
|
||||
uint8x16_t bound = vdupq_n_u8(8);
|
||||
uint8x16_t z = vminq_u8(b, bound);
|
||||
return x >> z;
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<uint8_t> inline operator<<(
|
||||
const Vectorized<uint8_t>& a,
|
||||
const Vectorized<uint8_t>& b) {
|
||||
uint8x16_t bound = vdupq_n_u8(8);
|
||||
uint8x16_t z = vminq_u8(b, bound);
|
||||
return vshlq_u8(a, vreinterpretq_s8_u8(z));
|
||||
}
|
||||
|
||||
inline Vectorized<uint8_t> Vectorized<uint8_t>::set(
|
||||
const Vectorized<uint8_t>& a,
|
||||
const Vectorized<uint8_t>& b,
|
||||
uint64_t count) {
|
||||
if (count == 0) {
|
||||
return a;
|
||||
} else if (count >= 16) {
|
||||
return b;
|
||||
} else {
|
||||
// Build an array of flags: every bit of element i is 1 if i < count,
|
||||
// 0 otherwise.
|
||||
uint8x16_t maskArray = {
|
||||
static_cast<uint8_t>((count >= 1LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 2LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 3LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 4LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 5LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 6LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 7LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 8LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 9LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 10LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 11LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 12LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 13LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 14LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 15LL) ? 0xFF : 0),
|
||||
0};
|
||||
|
||||
// Use BSL to select elements from b where the mask is 1, else from a
|
||||
return vbslq_u8(maskArray, b.values, a.values);
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<uint8_t> inline operator/(
|
||||
const Vectorized<uint8_t>& a,
|
||||
const Vectorized<uint8_t>& b) {
|
||||
uint8x16_t x = a;
|
||||
uint8x16_t y = b;
|
||||
return x / y;
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<uint8_t> inline clamp(
|
||||
const Vectorized<uint8_t>& a,
|
||||
const Vectorized<uint8_t>& min,
|
||||
const Vectorized<uint8_t>& max) {
|
||||
return minimum(max, maximum(min, a));
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<uint8_t> inline clamp_max(
|
||||
const Vectorized<uint8_t>& a,
|
||||
const Vectorized<uint8_t>& max) {
|
||||
return minimum(max, a);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<uint8_t> inline clamp_min(
|
||||
const Vectorized<uint8_t>& a,
|
||||
const Vectorized<uint8_t>& min) {
|
||||
return maximum(min, a);
|
||||
}
|
||||
|
||||
} // namespace CPU_CAPABILITY
|
||||
} // namespace at::vec
|
@ -1390,7 +1390,7 @@ std::pair<Vectorized<float>, Vectorized<float>> inline convert_int8_to_float(
|
||||
|
||||
std::pair<Vectorized<float>, Vectorized<float>> inline convert_int8_to_float(
|
||||
at::vec::Vectorized<uint8_t> src) {
|
||||
auto u8x8 = vget_low_u8(src);
|
||||
auto u8x8 = vld1_u8(src.operator const uint8_t*());
|
||||
auto u16x8 = vmovl_u8(u8x8);
|
||||
auto u32x4_hi = vmovl_u16(vget_high_u16(u16x8));
|
||||
auto u32x4_lo = vmovl_u16(vget_low_u16(u16x8));
|
||||
@ -1412,7 +1412,7 @@ Vectorized<float> inline convert_int8_half_register_to_float(
|
||||
|
||||
Vectorized<float> inline convert_int8_half_register_to_float(
|
||||
at::vec::Vectorized<uint8_t> src) {
|
||||
auto u8x8 = vget_low_u8(src);
|
||||
auto u8x8 = vld1_u8(src.operator const uint8_t*());
|
||||
auto u16x8 = vmovl_u8(u8x8);
|
||||
auto u32x4_lo = vmovl_u16(vget_low_u16(u16x8));
|
||||
|
||||
|
@ -272,110 +272,28 @@ cuda::blas::GEMMAndBiasActivationEpilogue activation_to_gemm_and_blas_arg(Activa
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* Checks whether DISABLE_ADDMM_CUDA_LT is set.
|
||||
* Additionally, for ROCM we test whether the architecture supports the Lt.
|
||||
*/
|
||||
static bool isGloballyDisabledAddmmCudaLt(const at::Device& device) {
|
||||
// When hipBLASLt is not supported on the architecture, return true
|
||||
#ifdef USE_ROCM
|
||||
static const std::vector<std::string> archs = {
|
||||
static bool getDisableAddmmCudaLt() {
|
||||
static const auto env_value = c10::utils::get_env("DISABLE_ADDMM_CUDA_LT");
|
||||
if (env_value == "1") {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
#ifdef USE_ROCM
|
||||
static bool isSupportedHipLtROCmArch(int index) {
|
||||
static const std::vector<std::string> archs = {
|
||||
"gfx90a", "gfx942",
|
||||
#if ROCM_VERSION >= 60300
|
||||
#if ROCM_VERSION >= 60300
|
||||
"gfx1100", "gfx1101", "gfx1200", "gfx1201", "gfx908",
|
||||
#endif
|
||||
#if ROCM_VERSION >= 70000
|
||||
#endif
|
||||
#if ROCM_VERSION >= 70000
|
||||
"gfx950", "gfx1150", "gfx1151"
|
||||
#endif
|
||||
};
|
||||
const auto is_hipblas_lt_arch_supported = at::detail::getCUDAHooks().isGPUArch(archs, device.index());
|
||||
if (!is_hipblas_lt_arch_supported) {
|
||||
return true;
|
||||
}
|
||||
#endif
|
||||
|
||||
// Check whether it is disabled in the env
|
||||
static const auto is_addmm_cuda_lt_disabled = c10::utils::get_env("DISABLE_ADDMM_CUDA_LT");
|
||||
if (is_addmm_cuda_lt_disabled == "1") {
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
/*
|
||||
* Check whether we want to enable the Lt interface for the given input
|
||||
*/
|
||||
static bool isInputCompliesAddmmCudaLt(Tensor& result, const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha) {
|
||||
// Implies a 2D bias, which we currently do not send through Lt.
|
||||
// TODO: this check is done pre col-major input preparation,
|
||||
// so this condition can be relaxed in cases when a col-major
|
||||
// copy of result is needed.
|
||||
if (result.is_same(self)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
#if defined(USE_ROCM) && ROCM_VERSION == 60400
|
||||
// hipblaslt TT fp32 regression on ROCm 6.4, cannot use
|
||||
const auto args = cublasCommonArgs(mat1, mat2, result);
|
||||
if (args.transa == 't' && args.transb == 't') {
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
|
||||
const auto mat1_sizes = mat1.sizes();
|
||||
const auto mat2_sizes = mat2.sizes();
|
||||
#if defined(CUDA_VERSION) || defined(USE_ROCM)
|
||||
const auto scalar_type = mat1.scalar_type();
|
||||
return (beta.toComplexDouble() == 1.0
|
||||
// self.dim() == 1 && result.dim() == 2 && self.sizes()[0] == mat2_sizes[1]
|
||||
// is to use lt interface only when self is bias.
|
||||
&& self.dim() == 1 && self.sizes()[0] == mat2_sizes[1] && self.is_contiguous()
|
||||
&& result.dim() == 2 && result.is_contiguous()
|
||||
&& ( // some dtype restrictions
|
||||
#ifndef USE_ROCM
|
||||
scalar_type == at::ScalarType::Double ||
|
||||
#endif
|
||||
scalar_type == at::ScalarType::Float ||
|
||||
scalar_type == at::ScalarType::Half ||
|
||||
scalar_type == at::ScalarType::BFloat16
|
||||
)
|
||||
&& ( // some shape/stride restrictions
|
||||
// Strangely, if mat2 has only 1 row or column, we get
|
||||
// CUBLAS_STATUS_INVALID_VALUE error from cublasLtMatmulAlgoGetHeuristic.
|
||||
// NOTE: extension to mat1 because mat1/mat2 can be swapped based off
|
||||
// their row-/col-majorness.
|
||||
mat1_sizes[0] > 1 && mat1_sizes[1] > 1 &&
|
||||
mat2_sizes[0] > 1 && mat2_sizes[1] > 1
|
||||
// The last conditions is to skip 16b transA and non-trans-B having
|
||||
// leading dim >> rows when they are sliced from a large tensor
|
||||
// see fbcode/caffe2/test/test_linalg.py:test_corner_cases_of_cublasltmatmul
|
||||
#if !(defined(CUDA_VERSION) && CUDA_VERSION >= 12010 || defined(USE_ROCM))
|
||||
// Related to avoiding the leading stride >> leading dim problematic case
|
||||
// with 16b dtypes described above. For such dtypes we only allow inputs
|
||||
// which are either row- or col-major (i.e. non-overlapping, compact memory layout).
|
||||
// In that case the leading stride will be equal to the outer dim len.
|
||||
// Why do we catch this case here? The following `prepare_matrix_for_cublas` method
|
||||
// does not modify inputs as long as there is a stride of length 1
|
||||
// and the leading stride is at least max(1, other dim length), so we might
|
||||
// end up with contiguous cols but not rows (i.e. holes between different rows)
|
||||
// and vice versa.
|
||||
mat2_sizes[0] < 65535 * 32 && mat2_sizes[1] < 65535 * 32 &&
|
||||
mat1_sizes[0] < 65535 * 32 && mat1_sizes[1] < 65535 * 32 &&
|
||||
&& (
|
||||
// filter by dtype
|
||||
(scalar_type != at::ScalarType::Half && scalar_type != at::ScalarType::BFloat16) ||
|
||||
// check mat1/mat2 is row-/col-major
|
||||
(mat1.is_non_overlapping_and_dense() && mat2.is_non_overlapping_and_dense())
|
||||
)
|
||||
#endif
|
||||
)
|
||||
);
|
||||
#endif
|
||||
|
||||
// no compliance by default
|
||||
return false;
|
||||
#endif
|
||||
};
|
||||
return at::detail::getCUDAHooks().isGPUArch(archs, index);
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename scalar_t>
|
||||
void launchTunableGemmAndBias(cublasCommonArgs &args, const Scalar& alpha, const scalar_t* bias, cuda::blas::GEMMAndBiasActivationEpilogue activation) {
|
||||
@ -417,70 +335,7 @@ void launchTunableGemmAndBias(cublasCommonArgs &args, const Scalar& alpha, const
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t, typename res_scalar_t = scalar_t>
|
||||
bool launchGemmAndBiasCublasLt(
|
||||
// args contains result which is modified
|
||||
cublasCommonArgs& args,
|
||||
const Tensor& self,
|
||||
const Scalar& alpha,
|
||||
Activation activation = Activation::None
|
||||
) {
|
||||
const auto* self_ptr = self.const_data_ptr<scalar_t>();
|
||||
|
||||
const auto tuning_ctx = at::cuda::tunable::getTuningContext();
|
||||
if (tuning_ctx->IsTunableOpEnabled()) {
|
||||
// TODO: maybe also return some success state?
|
||||
launchTunableGemmAndBias<scalar_t>(
|
||||
args, alpha, self_ptr, activation_to_gemm_and_blas_arg(activation)
|
||||
);
|
||||
return true;
|
||||
}
|
||||
|
||||
return at::cuda::blas::gemm_and_bias<scalar_t, res_scalar_t>(
|
||||
args.transa == 't',
|
||||
args.transb == 't',
|
||||
args.m,
|
||||
args.n,
|
||||
args.k,
|
||||
alpha.to<at::opmath_type<scalar_t>>(),
|
||||
args.mata->const_data_ptr<scalar_t>(),
|
||||
args.lda,
|
||||
args.matb->const_data_ptr<scalar_t>(),
|
||||
args.ldb,
|
||||
self_ptr,
|
||||
args.result->data_ptr<res_scalar_t>(),
|
||||
args.result_ld,
|
||||
activation_to_gemm_and_blas_arg(activation)
|
||||
);
|
||||
}
|
||||
|
||||
template <typename scalar_t, typename res_scalar_t = scalar_t>
|
||||
bool launchGemmCublas(
|
||||
// args contains result which is modified
|
||||
cublasCommonArgs& args,
|
||||
const Scalar& alpha,
|
||||
const Scalar& beta
|
||||
) {
|
||||
at::cuda::blas::gemm<scalar_t, res_scalar_t>(
|
||||
args.transa,
|
||||
args.transb,
|
||||
args.m,
|
||||
args.n,
|
||||
args.k,
|
||||
alpha.to<at::opmath_type<scalar_t>>(),
|
||||
args.mata->const_data_ptr<scalar_t>(),
|
||||
args.lda,
|
||||
args.matb->const_data_ptr<scalar_t>(),
|
||||
args.ldb,
|
||||
beta.to<at::opmath_type<scalar_t>>(),
|
||||
args.result->data_ptr<res_scalar_t>(),
|
||||
args.result_ld
|
||||
);
|
||||
return true; // success!
|
||||
}
|
||||
|
||||
Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha, Activation activation=Activation::None, bool disable_addmm_cuda_lt_override=false) {
|
||||
// Shape checks {
|
||||
// Make sure to keep addmm_cuda below in sync with this code; it
|
||||
// preflights a check to try to avoid actually needing to call
|
||||
// expand().
|
||||
@ -490,62 +345,105 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
|
||||
"expected mat1 and mat2 to have the same dtype, but got: ", mat1.dtype(), " != ", mat2.dtype()
|
||||
)
|
||||
|
||||
if (result.is_same(self)) {
|
||||
TORCH_CHECK(result.dim() == 2, "tensors must be 2-D");
|
||||
TORCH_CHECK(self.sizes()[0] == mat1.sizes()[0], "self dim 0 must match mat1 dim 0");
|
||||
TORCH_CHECK(self.sizes()[1] == mat2.sizes()[1], "self dim 1 must match mat2 dim 1");
|
||||
}
|
||||
// } Shape checks
|
||||
|
||||
// NOLINTNEXTLINE(*c-array*)
|
||||
TensorArg targs[]{{result, "out", 0}, {self, "self", 1}, {mat1, "mat1", 2}, {mat2, "mat2", 3}};
|
||||
checkAllSameGPU(__func__, targs);
|
||||
|
||||
// Handle whether to use the Lt interface {
|
||||
static bool persistent_disable_addmm_cuda_lt = isGloballyDisabledAddmmCudaLt(self.device());
|
||||
IntArrayRef mat1_sizes = mat1.sizes();
|
||||
IntArrayRef mat2_sizes = mat2.sizes();
|
||||
IntArrayRef self__sizes;
|
||||
bool useLtInterface = false;
|
||||
#if defined(USE_ROCM)
|
||||
// When hipBLASLt is not supported on the architecture,
|
||||
// disable_addmm_cuda_lt will always be to set to true
|
||||
static bool disable_addmm_cuda_lt =
|
||||
!isSupportedHipLtROCmArch(self.device().index()) || getDisableAddmmCudaLt();
|
||||
#else
|
||||
static bool disable_addmm_cuda_lt = getDisableAddmmCudaLt();
|
||||
#endif
|
||||
// if lt path fails, we recurse back into this function here and force the lt path to off
|
||||
// we cannot update varible disable_addmm_cuda_lt from above since it is static and would be permanent
|
||||
bool disable_addmm_cuda_lt = persistent_disable_addmm_cuda_lt || disable_addmm_cuda_lt_override;
|
||||
#ifdef USE_ROCM
|
||||
// Conditioned on the device index, which is not persistent
|
||||
disable_addmm_cuda_lt = isGloballyDisabledAddmmCudaLt(self.device()) || disable_addmm_cuda_lt;
|
||||
#endif
|
||||
// Condition on the input
|
||||
disable_addmm_cuda_lt = !isInputCompliesAddmmCudaLt(result, self, mat1, mat2, beta, alpha) || disable_addmm_cuda_lt;
|
||||
// }
|
||||
|
||||
bool disable_addmm_cuda_lt_final = disable_addmm_cuda_lt || disable_addmm_cuda_lt_override;
|
||||
#if defined(USE_ROCM) && ROCM_VERSION == 60400
|
||||
// hipblaslt TT fp32 regression on ROCm 6.4, cannot use
|
||||
cublasCommonArgs _args(mat1, mat2, result);
|
||||
if (_args.transa == 't' && _args.transb == 't') {
|
||||
disable_addmm_cuda_lt_final = true;
|
||||
}
|
||||
#endif
|
||||
at::ScalarType scalar_type = mat1.scalar_type();
|
||||
bool is_float_output_with_half_input = (scalar_type == at::ScalarType::Half || scalar_type == at::ScalarType::BFloat16) && result.scalar_type() == at::ScalarType::Float;
|
||||
c10::MaybeOwned<Tensor> self_;
|
||||
if (&result != &self) {
|
||||
#if defined(CUDA_VERSION) || defined(USE_ROCM)
|
||||
// Strangely, if mat2 has only 1 row or column, we get
|
||||
// CUBLAS_STATUS_INVALID_VALUE error from cublasLtMatmulAlgoGetHeuristic.
|
||||
// self.dim() == 1 && result.dim() == 2 && self.sizes()[0] == mat2_sizes[1]
|
||||
// is to use lt interface only when self is bias.
|
||||
// for cuda 11.4, cublasLtMatmul is activated
|
||||
// the last two conditions is to skip 16b transA and non-trans-B having
|
||||
// leading dim >> rows when they are sliced from a large tensor
|
||||
// see fbcode/caffe2/test/test_linalg.py:test_corner_cases_of_cublasltmatmul
|
||||
if (!disable_addmm_cuda_lt_final) {
|
||||
useLtInterface = beta.toComplexDouble() == 1.0 && self.dim() == 1 &&
|
||||
result.dim() == 2 && self.sizes()[0] == mat2_sizes[1] &&
|
||||
self.is_contiguous() && result.is_contiguous() &&
|
||||
#ifdef USE_ROCM
|
||||
(scalar_type == at::ScalarType::Float ||
|
||||
scalar_type == at::ScalarType::Half ||
|
||||
scalar_type == at::ScalarType::BFloat16) &&
|
||||
#else
|
||||
(scalar_type == at::ScalarType::Double ||
|
||||
scalar_type == at::ScalarType::Float ||
|
||||
scalar_type == at::ScalarType::Half ||
|
||||
scalar_type == at::ScalarType::BFloat16) &&
|
||||
#endif
|
||||
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12010 || defined(USE_ROCM))
|
||||
mat2_sizes[0] > 1 && mat2_sizes[1] > 1;
|
||||
#else
|
||||
mat2_sizes[0] > 1 && mat2_sizes[1] > 1 &&
|
||||
mat2_sizes[0] < 65535 * 32 && mat2_sizes[1] < 65535 * 32 &&
|
||||
mat1_sizes[0] < 65535 * 32 && mat1_sizes[1] < 65535 * 32 &&
|
||||
// avoid leading dim >> rows bugs
|
||||
((mat1.strides()[0] == 1 && mat1.strides()[1] == mat1_sizes[0]) ||
|
||||
(mat1.strides()[1] == 1 && mat1.strides()[0] == mat1_sizes[1]) ||
|
||||
(scalar_type != at::ScalarType::Half &&
|
||||
scalar_type != at::ScalarType::BFloat16)) &&
|
||||
((mat2.strides()[0] == 1 && mat2.strides()[1] == mat2_sizes[0]) ||
|
||||
(mat2.strides()[1] == 1 && mat2.strides()[0] == mat2_sizes[1]) ||
|
||||
(scalar_type != at::ScalarType::Half &&
|
||||
scalar_type != at::ScalarType::BFloat16));
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
if (!useLtInterface) {
|
||||
self_ = expand_size(self, {mat1_sizes[0], mat2_sizes[1]}, "addmm");
|
||||
}
|
||||
self__sizes = self_->sizes();
|
||||
} else {
|
||||
self_ = c10::MaybeOwned<Tensor>::borrowed(self);
|
||||
self__sizes = self_->sizes();
|
||||
TORCH_CHECK(result.dim() == 2, "tensors must be 2-D");
|
||||
TORCH_CHECK(self__sizes[0] == mat1_sizes[0], "self_ dim 0 must match mat1 dim 0");
|
||||
TORCH_CHECK(self__sizes[1] == mat2_sizes[1], "self_ dim 1 must match mat2 dim 1");
|
||||
}
|
||||
|
||||
// Handle result/self shapes
|
||||
if (!result.is_same(self)) {
|
||||
at::native::resize_output(result, {mat1.sizes()[0], mat2.sizes()[1]});
|
||||
|
||||
const auto self_maybe_expanded = [&]() -> c10::MaybeOwned<Tensor> {
|
||||
if (disable_addmm_cuda_lt) {
|
||||
// When in non-Lt path we do expand self even before
|
||||
// check for beta != 0.0 to make sure that
|
||||
// test_sparse_csr.py::TestSparseCSRCUDA::test_addmm_errors_*
|
||||
// runs green.
|
||||
return expand_size(self, result.sizes(), "addmm");
|
||||
}
|
||||
// copy next, should broadcast
|
||||
return c10::MaybeOwned<Tensor>::borrowed(self);
|
||||
}();
|
||||
// We copy bias when in the non-Lt path
|
||||
if (beta.toComplexDouble() != 0.0 && disable_addmm_cuda_lt) {
|
||||
// NOTE: self should broadcast over result
|
||||
at::native::copy_(result, *self_maybe_expanded);
|
||||
if (&result != &self) {
|
||||
at::native::resize_output(result, {mat1_sizes[0], mat2_sizes[1]});
|
||||
if (beta.toComplexDouble() != 0.0 && !useLtInterface) {
|
||||
at::native::copy_(result, *self_);
|
||||
}
|
||||
}
|
||||
|
||||
// Short circuit on empty result
|
||||
if (result.numel() == 0) {
|
||||
|
||||
IntArrayRef result_sizes = result.sizes();
|
||||
if ((result_sizes[0] == 0) || (result_sizes[1] == 0)) {
|
||||
return result;
|
||||
}
|
||||
|
||||
// Short circuit if the reduction dim is empty
|
||||
if (mat1.sizes()[1] == 0) {
|
||||
cublasCommonArgs args(mat1, mat2, result);
|
||||
|
||||
if (mat1.numel() == 0) {
|
||||
// By definition, when beta==0, values in self should be ignored. nans and infs
|
||||
// should not propagate
|
||||
if (beta.toComplexDouble() == 0.) {
|
||||
@ -557,64 +455,158 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
|
||||
result,
|
||||
self.expand(result.sizes()),
|
||||
at::native::scalar_tensor(
|
||||
beta,
|
||||
self.scalar_type(),
|
||||
std::nullopt /* layout */,
|
||||
at::kCPU,
|
||||
std::nullopt /* pin_memory */
|
||||
)
|
||||
);
|
||||
beta,
|
||||
self.scalar_type(),
|
||||
std::nullopt /* layout */,
|
||||
at::kCPU,
|
||||
std::nullopt /* pin_memory */));
|
||||
}
|
||||
|
||||
cublasCommonArgs args(mat1, mat2, result);
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!args.result->is_conj());
|
||||
|
||||
// The Lt path
|
||||
if (!disable_addmm_cuda_lt) {
|
||||
bool lt_success = false;
|
||||
if (useLtInterface) {
|
||||
#if defined(USE_ROCM)
|
||||
bool okay = true;
|
||||
if (is_float_output_with_half_input) {
|
||||
#ifdef USE_ROCM
|
||||
TORCH_CHECK(false, "float output with half input is not enabled for ROCm");
|
||||
#else
|
||||
if (at::cuda::tunable::getTuningContext()->IsTunableOpEnabled()) {
|
||||
TORCH_CHECK(false, "Tunable GEMM is not supported for float output with reduced float input");
|
||||
}
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(
|
||||
scalar_type,
|
||||
"addmm_cuda_lt",
|
||||
[&] {
|
||||
lt_success = launchGemmAndBiasCublasLt<scalar_t, float>(args, self, alpha, activation);
|
||||
}
|
||||
);
|
||||
#endif
|
||||
} else {
|
||||
// !is_float_output_with_half_input
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(
|
||||
at::ScalarType::Half,
|
||||
at::ScalarType::BFloat16,
|
||||
scalar_type,
|
||||
"addmm_cuda_lt",
|
||||
[&] {
|
||||
lt_success = launchGemmAndBiasCublasLt<scalar_t>(args, self, alpha, activation);
|
||||
auto tuning_ctx = at::cuda::tunable::getTuningContext();
|
||||
if (tuning_ctx->IsTunableOpEnabled()) {
|
||||
launchTunableGemmAndBias<scalar_t>(
|
||||
args,
|
||||
alpha,
|
||||
(&result != &self) ? self.const_data_ptr<scalar_t>() : nullptr,
|
||||
activation_to_gemm_and_blas_arg(activation));
|
||||
} else {
|
||||
okay = at::cuda::blas::gemm_and_bias<scalar_t>(
|
||||
args.transa == 't',
|
||||
args.transb == 't',
|
||||
args.m,
|
||||
args.n,
|
||||
args.k,
|
||||
alpha.to<at::opmath_type<scalar_t>>(),
|
||||
args.mata->const_data_ptr<scalar_t>(),
|
||||
args.lda,
|
||||
args.matb->const_data_ptr<scalar_t>(),
|
||||
args.ldb,
|
||||
// This condition is needed for mm case on ROCm for hipblasLt path.
|
||||
// Passing the bias ptr as null to avoid accuracy issues for mm case.
|
||||
(&result != &self) ? self.const_data_ptr<scalar_t>() : nullptr,
|
||||
args.result->data_ptr<scalar_t>(),
|
||||
args.result_ld,
|
||||
activation_to_gemm_and_blas_arg(activation)
|
||||
);
|
||||
}
|
||||
);
|
||||
} // end is_float_output_with_half_input
|
||||
|
||||
if (!lt_success) {
|
||||
// lt path failed; recurse but disable lt path
|
||||
});
|
||||
}
|
||||
if (!okay) {
|
||||
// lt path failed; recurse but disable lt path
|
||||
return addmm_out_cuda_impl(result, self, mat1, mat2, beta, alpha, activation, true);
|
||||
}
|
||||
// end Lt path
|
||||
} else {
|
||||
// No Lt, we use a GEMM instead
|
||||
#else
|
||||
auto activation_epilogue = activation_to_gemm_and_blas_arg(activation);
|
||||
bool okay = true;
|
||||
if (is_float_output_with_half_input) {
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(
|
||||
scalar_type,
|
||||
"addmm_cuda_lt",
|
||||
[&] {
|
||||
auto tuning_ctx = at::cuda::tunable::getTuningContext();
|
||||
if (tuning_ctx->IsTunableOpEnabled()) {
|
||||
TORCH_CHECK(false, "Tunable GEMM is not supported for float output with reduced float input");
|
||||
}
|
||||
else {
|
||||
okay = at::cuda::blas::gemm_and_bias<scalar_t, float>(
|
||||
args.transa == 't',
|
||||
args.transb == 't',
|
||||
args.m,
|
||||
args.n,
|
||||
args.k,
|
||||
alpha.to<at::opmath_type<scalar_t>>(),
|
||||
args.mata->const_data_ptr<scalar_t>(),
|
||||
args.lda,
|
||||
args.matb->const_data_ptr<scalar_t>(),
|
||||
args.ldb,
|
||||
self.const_data_ptr<scalar_t>(),
|
||||
args.result->data_ptr<float>(),
|
||||
args.result_ld,
|
||||
activation_epilogue
|
||||
);
|
||||
}});
|
||||
} else {
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(
|
||||
at::ScalarType::Half,
|
||||
at::ScalarType::BFloat16,
|
||||
scalar_type,
|
||||
"addmm_cuda_lt",
|
||||
[&] {
|
||||
auto tuning_ctx = at::cuda::tunable::getTuningContext();
|
||||
if (tuning_ctx->IsTunableOpEnabled()) {
|
||||
launchTunableGemmAndBias<scalar_t>(
|
||||
args,
|
||||
alpha,
|
||||
self.const_data_ptr<scalar_t>(),
|
||||
activation_epilogue);
|
||||
}
|
||||
else {
|
||||
okay = at::cuda::blas::gemm_and_bias<scalar_t>(
|
||||
args.transa == 't',
|
||||
args.transb == 't',
|
||||
args.m,
|
||||
args.n,
|
||||
args.k,
|
||||
alpha.to<at::opmath_type<scalar_t>>(),
|
||||
args.mata->const_data_ptr<scalar_t>(),
|
||||
args.lda,
|
||||
args.matb->const_data_ptr<scalar_t>(),
|
||||
args.ldb,
|
||||
self.const_data_ptr<scalar_t>(),
|
||||
args.result->data_ptr<scalar_t>(),
|
||||
args.result_ld,
|
||||
activation_epilogue
|
||||
);
|
||||
}});
|
||||
}
|
||||
if (!okay) {
|
||||
// lt path failed; recurse but disable lt path
|
||||
return addmm_out_cuda_impl(result, self, mat1, mat2, beta, alpha, activation, true);
|
||||
}
|
||||
#endif
|
||||
} else
|
||||
{
|
||||
if (is_float_output_with_half_input) {
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(
|
||||
scalar_type,
|
||||
"addmm_cuda",
|
||||
[&] {
|
||||
launchGemmCublas<scalar_t, float>(args, alpha, beta);
|
||||
}
|
||||
);
|
||||
using opmath_t = at::opmath_type<scalar_t>;
|
||||
opmath_t alpha_val = alpha.to<opmath_t>();
|
||||
opmath_t beta_val = beta.to<opmath_t>();
|
||||
const scalar_t* mat1_ptr = args.mata->const_data_ptr<scalar_t>();
|
||||
const scalar_t* mat2_ptr = args.matb->const_data_ptr<scalar_t>();
|
||||
|
||||
float* result_ptr = args.result->mutable_data_ptr<float>();
|
||||
at::cuda::blas::gemm<scalar_t, float>(
|
||||
args.transa,
|
||||
args.transb,
|
||||
args.m,
|
||||
args.n,
|
||||
args.k,
|
||||
alpha_val,
|
||||
mat1_ptr,
|
||||
args.lda,
|
||||
mat2_ptr,
|
||||
args.ldb,
|
||||
beta_val,
|
||||
result_ptr,
|
||||
args.result_ld);
|
||||
});
|
||||
} else {
|
||||
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
|
||||
at::ScalarType::Half,
|
||||
@ -622,12 +614,28 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
|
||||
scalar_type,
|
||||
"addmm_cuda",
|
||||
[&] {
|
||||
launchGemmCublas<scalar_t>(args, alpha, beta);
|
||||
}
|
||||
);
|
||||
using opmath_t = at::opmath_type<scalar_t>;
|
||||
opmath_t alpha_val = alpha.to<opmath_t>();
|
||||
opmath_t beta_val = beta.to<opmath_t>();
|
||||
const scalar_t* mat1_ptr = args.mata->const_data_ptr<scalar_t>();
|
||||
const scalar_t* mat2_ptr = args.matb->const_data_ptr<scalar_t>();
|
||||
scalar_t* result_ptr = args.result->mutable_data_ptr<scalar_t>();
|
||||
at::cuda::blas::gemm<scalar_t>(
|
||||
args.transa,
|
||||
args.transb,
|
||||
args.m,
|
||||
args.n,
|
||||
args.k,
|
||||
alpha_val,
|
||||
mat1_ptr,
|
||||
args.lda,
|
||||
mat2_ptr,
|
||||
args.ldb,
|
||||
beta_val,
|
||||
result_ptr,
|
||||
args.result_ld);
|
||||
});
|
||||
}
|
||||
|
||||
// Apply epilogue
|
||||
switch (activation) {
|
||||
case Activation::RELU:
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
|
||||
@ -639,14 +647,14 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
|
||||
break;
|
||||
default: break;
|
||||
}
|
||||
} // end GEMM path
|
||||
}
|
||||
|
||||
// Preprocessor gate here needs to match the inverse of the check
|
||||
// gating activation_to_gemm_and_blas_arg above; here we are manually
|
||||
// performing a post-GELU because we weren't able to use the GELU
|
||||
// epilogue above.
|
||||
#if !defined(CUDA_VERSION) && !defined(USE_ROCM)
|
||||
if (!disable_addmm_cuda_lt && activation == Activation::GELU) {
|
||||
if (useLtInterface && activation == Activation::GELU) {
|
||||
at::gelu_(const_cast<Tensor&>(*args.result), "tanh");
|
||||
}
|
||||
#endif
|
||||
|
@ -23,7 +23,7 @@ namespace at::native {
|
||||
|
||||
// The maximum number of threads in a block
|
||||
#if defined(USE_ROCM)
|
||||
constexpr int MAX_BLOCK_SIZE = 1024;
|
||||
constexpr int MAX_BLOCK_SIZE = 256;
|
||||
#else
|
||||
constexpr int MAX_BLOCK_SIZE = 512;
|
||||
#endif
|
||||
@ -33,7 +33,7 @@ constexpr unsigned MAX_GRID_SIZE = 65535u;
|
||||
// Number of threads in a block given an input size up to MAX_BLOCK_SIZE
|
||||
static int getNumThreads(int nElem) {
|
||||
#if defined(USE_ROCM)
|
||||
int threadSizes[5] = { 64, 128, 256, 512, MAX_BLOCK_SIZE };
|
||||
int threadSizes[5] = { 16, 32, 64, 128, MAX_BLOCK_SIZE };
|
||||
#else
|
||||
int threadSizes[5] = { 32, 64, 128, 256, MAX_BLOCK_SIZE };
|
||||
#endif
|
||||
|
@ -92,16 +92,6 @@ inline thrust::pair<int64_t, int64_t> get_index_mapping2d(
|
||||
output_offset + output_y * output_dim_x + output_x);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ int64_t reflect_index(int64_t x, int64_t len) {
|
||||
const int64_t two = (len - 1) * 2;
|
||||
if (two <= 0) {
|
||||
return 0;
|
||||
}
|
||||
int64_t m = x % two;
|
||||
if (m < 0) m += two;
|
||||
return (m < len) ? m : (two - m);
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
__global__ void reflection_pad1d_out_kernel(
|
||||
const scalar_t * input, scalar_t * output,
|
||||
@ -116,28 +106,6 @@ __global__ void reflection_pad1d_out_kernel(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void reflection_pad1d_flat(
|
||||
const scalar_t* __restrict__ input,
|
||||
scalar_t* __restrict__ output,
|
||||
int64_t input_w, int64_t pad_l, int64_t pad_r,
|
||||
int64_t out_w, int64_t plane_count) {
|
||||
|
||||
const int64_t bx = blockDim.x;
|
||||
const int64_t tx = threadIdx.x;
|
||||
|
||||
const int64_t total = plane_count * out_w;
|
||||
const int64_t grid_stride = static_cast<int64_t>(bx) * gridDim.x;
|
||||
int64_t linear = static_cast<int64_t>(blockIdx.x) * bx + tx;
|
||||
|
||||
for (; linear < total; linear += grid_stride) {
|
||||
const int64_t plane = linear / out_w;
|
||||
const int64_t x = linear - plane * out_w;
|
||||
const int64_t j = reflect_index(x - pad_l, input_w);
|
||||
output[plane * out_w + x] = input[plane * input_w + j];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void reflection_pad1d_backward_out_kernel(
|
||||
scalar_t * grad_input, const scalar_t * grad_output,
|
||||
@ -742,44 +710,25 @@ TORCH_IMPL_FUNC(reflection_pad1d_out_cuda)
|
||||
int64_t input_w = input_.size(dim_w);
|
||||
int64_t output_w = input_w + pad_l + pad_r;
|
||||
|
||||
dim3 block_size(output_w > 256 ? 256 : output_w);
|
||||
dim3 grid_size((int)::ceil(output_w / 256.0), nplane, nbatch);
|
||||
|
||||
Tensor input = input_.contiguous();
|
||||
|
||||
const int block_x = static_cast<int>(std::min<int64_t>(256, std::max<int64_t>(1, output_w)));
|
||||
const cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
|
||||
const int max_x = prop->maxGridSize[0];
|
||||
const int max_y = prop->maxGridSize[1];
|
||||
const int max_z = prop->maxGridSize[2];
|
||||
|
||||
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBFloat16, input.scalar_type(), "reflection_pad1d_out", [&] {
|
||||
auto stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
const int64_t gx = at::ceil_div(output_w, static_cast<int64_t>(block_x));
|
||||
|
||||
const bool fits3d = (nplane <= max_y) && (nbatch <= max_z) && (gx <= max_x);
|
||||
|
||||
if (fits3d) {
|
||||
dim3 block(block_x, 1, 1);
|
||||
dim3 grid(gx, static_cast<unsigned>(nplane), static_cast<unsigned>(nbatch));
|
||||
reflection_pad1d_out_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||
input.const_data_ptr<scalar_t>(),
|
||||
output.mutable_data_ptr<scalar_t>(),
|
||||
input_w, pad_l, pad_r);
|
||||
} else {
|
||||
dim3 block(block_x, 1, 1);
|
||||
const int64_t plane_count = nplane * nbatch;
|
||||
const int64_t total_blocks = at::ceil_div(plane_count * output_w, static_cast<int64_t>(block_x));
|
||||
const int grid_x = static_cast<int>(std::min<int64_t>(max_x, std::max<int64_t>(1, total_blocks)));
|
||||
dim3 grid(grid_x, 1, 1);
|
||||
|
||||
reflection_pad1d_flat<scalar_t><<<grid, block, 0, stream>>>(
|
||||
input.const_data_ptr<scalar_t>(),
|
||||
output.mutable_data_ptr<scalar_t>(),
|
||||
input_w, pad_l, pad_r, output_w, plane_count);
|
||||
}
|
||||
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
});
|
||||
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
|
||||
kHalf, kBFloat16, input.scalar_type(), "reflection_pad1d_out_template", [&] {
|
||||
reflection_pad1d_out_kernel<<<
|
||||
grid_size,
|
||||
block_size,
|
||||
0,
|
||||
at::cuda::getCurrentCUDAStream()>>>(
|
||||
input.const_data_ptr<scalar_t>(),
|
||||
output.mutable_data_ptr<scalar_t>(),
|
||||
input_w,
|
||||
pad_l,
|
||||
pad_r);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
});
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(reflection_pad1d_backward_out_cuda)(const Tensor& grad_output_,
|
||||
|
@ -52,7 +52,7 @@ struct FusedAdagradMathFunctor {
|
||||
using opmath_t = at::opmath_type<scalar_t>;
|
||||
|
||||
C10_DEVICE __forceinline__ void operator()(
|
||||
int64_t chunk_size,
|
||||
int chunk_size,
|
||||
FusedOptimizerTensorListMetadata<3>& tl,
|
||||
const float* lr_ptr,
|
||||
const double& lr,
|
||||
@ -133,4 +133,4 @@ struct FusedAdagradMathFunctor {
|
||||
|
||||
} // namespace
|
||||
|
||||
} // namespace at::native
|
||||
} // namespace at::native
|
@ -1,8 +1,8 @@
|
||||
add_loop_eager,compile_time_instruction_count,3184000000,0.1
|
||||
add_loop_eager,compile_time_instruction_count,3070000000,0.1
|
||||
|
||||
|
||||
|
||||
add_loop_eager_dynamic,compile_time_instruction_count,4595000000,0.1
|
||||
add_loop_eager_dynamic,compile_time_instruction_count,4432000000,0.1
|
||||
|
||||
|
||||
|
||||
@ -18,7 +18,7 @@ add_loop_inductor_gpu,compile_time_instruction_count,26800000000,0.1
|
||||
|
||||
|
||||
|
||||
basic_modules_ListOfLinears_eager,compile_time_instruction_count,1096000000,0.1
|
||||
basic_modules_ListOfLinears_eager,compile_time_instruction_count,1048000000,0.1
|
||||
|
||||
|
||||
|
||||
@ -26,7 +26,7 @@ basic_modules_ListOfLinears_inductor,compile_time_instruction_count,15240000000,
|
||||
|
||||
|
||||
|
||||
basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17720000000,0.1
|
||||
basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17020000000,0.1
|
||||
|
||||
|
||||
|
||||
@ -34,11 +34,11 @@ basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,11090000
|
||||
|
||||
|
||||
|
||||
update_hint_regression,compile_time_instruction_count,1645000000,0.1
|
||||
update_hint_regression,compile_time_instruction_count,1719000000,0.1
|
||||
|
||||
|
||||
|
||||
sum_floordiv_regression,compile_time_instruction_count,3813000000,0.1
|
||||
sum_floordiv_regression,compile_time_instruction_count,3686995725,0.1
|
||||
|
||||
|
||||
|
||||
@ -50,31 +50,31 @@ symint_sum_loop,compile_time_instruction_count,4299000000,0.1
|
||||
|
||||
|
||||
|
||||
aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,1793000000,0.1
|
||||
aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,1869000000,0.1
|
||||
|
||||
|
||||
|
||||
aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5120000000,0.1
|
||||
aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5281000000,0.1
|
||||
|
||||
|
||||
|
||||
aotdispatcher_partitioner_cpu,compile_time_instruction_count,7936000000,0.1
|
||||
aotdispatcher_partitioner_cpu,compile_time_instruction_count,8333000000,0.1
|
||||
|
||||
|
||||
|
||||
aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1848000000,0.1
|
||||
aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1909000000,0.1
|
||||
|
||||
|
||||
|
||||
aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3152000000,0.1
|
||||
aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3442000000,0.1
|
||||
|
||||
|
||||
|
||||
aotdispatcher_training_subclass_cpu,compile_time_instruction_count,8301000000,0.1
|
||||
aotdispatcher_training_subclass_cpu,compile_time_instruction_count,9239000000,0.1
|
||||
|
||||
|
||||
|
||||
mm_loop_inductor_gpu,compile_time_instruction_count,4958000000,0.1
|
||||
mm_loop_inductor_gpu,compile_time_instruction_count,4820968837,0.1
|
||||
|
||||
|
||||
|
||||
@ -82,8 +82,8 @@ mm_loop_inductor_dynamic_gpu,compile_time_instruction_count,9051000000,0.1
|
||||
|
||||
|
||||
|
||||
basic_NestedModule_eager,compile_time_instruction_count,9990000000,0.1
|
||||
basic_NestedModule_eager,compile_time_instruction_count,9554000000,0.1
|
||||
|
||||
|
||||
|
||||
basic_InlineMod_eager,compile_time_instruction_count,8126000000,0.1
|
||||
basic_InlineMod_eager,compile_time_instruction_count,7618000000,0.1
|
||||
|
|
@ -48,89 +48,17 @@ PyTorch,sub,"sub_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float32",short,Fa
|
||||
PyTorch,div,"div_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float32",short,False,58.529255,0.000000
|
||||
PyTorch,mul,"mul_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float32",short,False,54.645077,0.000000
|
||||
PyTorch,add,add_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,4.397014,0.000000
|
||||
PyTorch,add,add_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,7.739000,0.000000
|
||||
PyTorch,add,add_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,7.786000,0.000000
|
||||
PyTorch,add,add_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,1.911000,0.000000
|
||||
PyTorch,add,add_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,59.243500,0.000000
|
||||
PyTorch,add,add_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,105.066000,0.000000
|
||||
PyTorch,add,add_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,106.076000,0.000000
|
||||
PyTorch,add,add_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,47.225000,0.000000
|
||||
PyTorch,add,add_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,57.947691,0.000000
|
||||
PyTorch,add,add_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,107.291000,0.000000
|
||||
PyTorch,add,add_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,107.224000,0.000000
|
||||
PyTorch,add,add_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,47.912000,0.000000
|
||||
PyTorch,sub,sub_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,1.925851,0.000000
|
||||
PyTorch,sub,sub_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,8.0240000,0.000000
|
||||
PyTorch,sub,sub_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,8.069000,0.000000
|
||||
PyTorch,sub,sub_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,1.938000,0.000000
|
||||
PyTorch,sub,sub_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,57.308320,0.000000
|
||||
PyTorch,sub,sub_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,107.091000,0.000000
|
||||
PyTorch,sub,sub_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,108.710000,0.000000
|
||||
PyTorch,sub,sub_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,47.502000,0.000000
|
||||
PyTorch,sub,sub_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,57.787743,0.000000
|
||||
PyTorch,sub,sub_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,108.863000,0.000000
|
||||
PyTorch,sub,sub_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,108.939000,0.000000
|
||||
PyTorch,sub,sub_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,47.603000,0.000000
|
||||
PyTorch,div,div_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,7.978539,0.000000
|
||||
PyTorch,div,div_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,8.741000,0.000000
|
||||
PyTorch,div,div_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,8.757000,0.000000
|
||||
PyTorch,div,div_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,8.774000,0.000000
|
||||
PyTorch,div,div_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,159.754860,0.000000
|
||||
PyTorch,div,div_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,165.552000,0.000000
|
||||
PyTorch,div,div_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,165.755000,0.000000
|
||||
PyTorch,div,div_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,165.714000,0.000000
|
||||
PyTorch,div,div_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,165.360235,0.000000
|
||||
PyTorch,div,div_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,168.376000,0.000000
|
||||
PyTorch,div,div_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,169.604000,0.000000
|
||||
PyTorch,div,div_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,168.428000,0.000000
|
||||
PyTorch,mul,mul_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,3.928136,0.000000
|
||||
PyTorch,mul,mul_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,7.402000,0.000000
|
||||
PyTorch,mul,mul_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,7.567000,0.000000
|
||||
PyTorch,mul,mul_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,4.020000,0.000000
|
||||
PyTorch,mul,mul_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,56.413499,0.000000
|
||||
PyTorch,mul,mul_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,104.638000,0.000000
|
||||
PyTorch,mul,mul_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,104.335000,0.000000
|
||||
PyTorch,mul,mul_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,48.612000,0.000000
|
||||
PyTorch,mul,mul_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,55.925090,0.000000
|
||||
PyTorch,mul,mul_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,106.110000,0.000000
|
||||
PyTorch,mul,mul_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,106.389000,0.000000
|
||||
PyTorch,mul,mul_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,48.195000,0.000000
|
||||
PyTorch,asr,asr_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,1.989000,0.000000
|
||||
PyTorch,asr,asr_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,7.999000,0.000000
|
||||
PyTorch,asr,asr_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,7.939000,0.000000
|
||||
PyTorch,asr,asr_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,1.980000,0.000000
|
||||
PyTorch,asr,asr_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,54.408000,0.000000
|
||||
PyTorch,asr,asr_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,105.647000,0.000000
|
||||
PyTorch,asr,asr_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,106.476000,0.000000
|
||||
PyTorch,asr,asr_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,48.784000,0.000000
|
||||
PyTorch,asr,asr_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,55.583000,0.000000
|
||||
PyTorch,asr,asr_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,108.083000,0.000000
|
||||
PyTorch,asr,asr_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,107.663000,0.000000
|
||||
PyTorch,asr,asr_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,48.283000,0.000000
|
||||
PyTorch,lsl,lsl_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,1.986000,0.000000
|
||||
PyTorch,lsl,lsl_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,7.676000,0.000000
|
||||
PyTorch,lsl,lsl_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,7.618000,0.000000
|
||||
PyTorch,lsl,lsl_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,1.982000,0.000000
|
||||
PyTorch,lsl,lsl_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,54.698000,0.000000
|
||||
PyTorch,lsl,lsl_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,105.899000,0.000000
|
||||
PyTorch,lsl,lsl_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,106.741000,0.000000
|
||||
PyTorch,lsl,lsl_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,51.182000,0.000000
|
||||
PyTorch,lsl,lsl_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,55.290000,0.000000
|
||||
PyTorch,lsl,lsl_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,107.744000,0.000000
|
||||
PyTorch,lsl,lsl_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,107.820000,0.000000
|
||||
PyTorch,lsl,lsl_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,51.298000,0.000000
|
||||
PyTorch,xor,xor_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,1.988000,0.000000
|
||||
PyTorch,xor,xor_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,7.689000,0.000000
|
||||
PyTorch,xor,xor_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,7.695000,0.000000
|
||||
PyTorch,xor,xor_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,1.978000,0.000000
|
||||
PyTorch,xor,xor_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,54.934000,0.000000
|
||||
PyTorch,xor,xor_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,105.217000,0.000000
|
||||
PyTorch,xor,xor_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,104.215000,0.000000
|
||||
PyTorch,xor,xor_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,47.115000,0.000000
|
||||
PyTorch,xor,xor_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,55.974000,0.000000
|
||||
PyTorch,xor,xor_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,106.828000,0.000000
|
||||
PyTorch,xor,xor_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,106.879000,0.000000
|
||||
PyTorch,xor,xor_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,48.197000,0.000000
|
||||
PyTorch,logical_and,"logical_and_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.bool",short,False,78.404254,0.000000
|
||||
PyTorch,logical_and,logical_and_M1_N1_K1_cpu_dtype_onetorch.bool_dtype_twotorch.bool,short,False,5.354032,0.000000
|
||||
PyTorch,logical_and,logical_and_M64_N64_K64_cpu_dtype_onetorch.bool_dtype_twotorch.bool,short,False,54.072783,0.000000
|
||||
@ -143,9 +71,6 @@ PyTorch,baddbmm,baddbmm_B2_M1_N8_K2_cpu_dtypetorch.float32,short,False,6.631313,
|
||||
PyTorch,baddbmm,baddbmm_B2_M1_N8_K2_cpu_dtypetorch.bfloat16,short,False,6.476986,0.000000
|
||||
PyTorch,baddbmm,baddbmm_B128_M64_N32_K64_cpu_dtypetorch.float32,short,False,266.065131,0.000000
|
||||
PyTorch,baddbmm,baddbmm_B128_M64_N32_K64_cpu_dtypetorch.bfloat16,short,False,295.503063,0.000000
|
||||
PyTorch,all,all_M1_N1_K1_cpu,short,False,5.773000,0.000000
|
||||
PyTorch,all,all_M64_N64_K64_cpu,short,False,89.427000,0.000000
|
||||
PyTorch,all,all_M64_N64_K128_cpu,short,False,120.119000,0.000000
|
||||
PyTorch,cat,"cat_sizes(1,1,1)_N2_dim0_cpu",short,False,4.301950,0.000000
|
||||
PyTorch,cat,"cat_sizes(512,512,2)_N2_dim1_cpu",short,False,99.093415,0.000000
|
||||
PyTorch,cat,"cat_sizes(128,1024,2)_N2_dim1_cpu",short,False,96.771578,0.000000
|
||||
|
|
@ -71,8 +71,8 @@ binary_short_configs = op_bench.config_list(
|
||||
],
|
||||
cross_product_configs={
|
||||
"device": ["cpu", "cuda"],
|
||||
"dtype_one": [torch.int32, torch.uint8],
|
||||
"dtype_two": [torch.int32, torch.uint8],
|
||||
"dtype_one": [torch.int32],
|
||||
"dtype_two": [torch.int32],
|
||||
},
|
||||
tags=["short"],
|
||||
)
|
||||
@ -82,8 +82,8 @@ binary_long_configs = op_bench.cross_product_configs(
|
||||
N=[32, 64],
|
||||
K=[256, 512],
|
||||
device=["cpu", "cuda"],
|
||||
dtype_one=[torch.int8, torch.int32, torch.uint8],
|
||||
dtype_two=[torch.int8, torch.int32, torch.uint8],
|
||||
dtype_one=[torch.int8, torch.int32],
|
||||
dtype_two=[torch.int8, torch.int32],
|
||||
tags=["long"],
|
||||
)
|
||||
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -207,6 +207,42 @@ templates_path = [
|
||||
]
|
||||
# TODO: document these and remove them from here.
|
||||
|
||||
# Fixes the duplicated
|
||||
autosummary_filename_map = {
|
||||
"torch.nn.utils.prune.identity": "torch.nn.utils.prune.identity_function",
|
||||
"torch.nn.utils.prune.Identity": "torch.nn.utils.prune.Identity_class",
|
||||
"torch.optim.adamw.adamw": "torch.optim.adamw.adamw_function",
|
||||
"torch.optim.adamw.AdamW": "torch.optim.adamw.AdamW_class",
|
||||
"torch.optim.asgd.asgd": "torch.optim.asgd.asgd_function",
|
||||
"torch.optim.asgd.ASGD": "torch.optim.asgd.ASGD_class",
|
||||
"torch.optim.nadam.nadam": "torch.optim.nadam.nadam_function",
|
||||
"torch.optim.nadam.NAdam": "torch.optim.nadam.NAdam_class",
|
||||
"torch.optim.radam.radam": "torch.optim.radam.radam_function",
|
||||
"torch.optim.radam.RAdam": "torch.optim.radam.RAdam_class",
|
||||
"torch.optim.rmsprop.rmsprop": "torch.optim.rmsprop.rmsprop_function",
|
||||
"torch.optim.rmsprop.RMSprop": "torch.optim.rmsprop.RMSprop_class",
|
||||
"torch.optim.rprop.rprop": "torch.optim.rprop.rprop_function",
|
||||
"torch.optim.rprop.Rprop": "torch.optim.rprop.Rprop_class",
|
||||
"torch.optim.sgd.sgd": "torch.optim.sgd.sgd_function",
|
||||
"torch.optim.sgd.SGD": "torch.optim.sgd.SGD_class",
|
||||
"torch.optim.adadelta.adadelta": "torch.optim.adadelta.adadelta_function",
|
||||
"torch.optim.adadelta.Adadelta": "torch.optim.adadelta.Adadelta_class",
|
||||
"torch.optim.adagrad.adagrad": "torch.optim.adagrad.adagrad_function",
|
||||
"torch.optim.adagrad.Adagrad": "torch.optim.adagrad.Adagrad_class",
|
||||
"torch.optim.adam.adam": "torch.optim.adam.adam_function",
|
||||
"torch.optim.adam.Adam": "torch.optim.adam.Adam_class",
|
||||
"torch.optim.adamax.adamax": "torch.optim.adamax.adamax_function",
|
||||
"torch.optim.adamax.Adamax": "torch.optim.adamax.Adamax_class",
|
||||
"torch.mtia.stream": "torch.mtia.stream_function",
|
||||
"torch.mtia.Stream": "torch.mtia.Stream_class",
|
||||
"torch.cpu.stream": "torch.cpu.stream_function",
|
||||
"torch.cpu.Stream": "torch.cpu.Stream_class",
|
||||
"torch.cuda.stream": "torch.cuda.stream_function",
|
||||
"torch.cuda.Stream": "torch.cuda.Stream_class",
|
||||
"torch.xpu.stream": "torch.xpu.stream_function",
|
||||
"torch.xpu.Stream": "torch.xpu.Stream_class",
|
||||
}
|
||||
|
||||
coverage_ignore_functions = [
|
||||
# torch
|
||||
"typename",
|
||||
@ -3193,6 +3229,11 @@ autodoc_type_aliases = {
|
||||
# Enable overriding of function signatures in the first line of the docstring.
|
||||
autodoc_docstring_signature = True
|
||||
|
||||
# Exclude inherited IntEnum methods that have RST formatting issues in their docstrings
|
||||
autodoc_default_options = {
|
||||
"exclude-members": "from_bytes, to_bytes",
|
||||
}
|
||||
|
||||
# -- katex javascript in header
|
||||
#
|
||||
# def setup(app):
|
||||
|
@ -253,7 +253,6 @@ regular full-precision tensor.
|
||||
.. autosummary::
|
||||
:toctree: generated
|
||||
:nosignatures:
|
||||
:template: classtemplate.rst
|
||||
|
||||
view
|
||||
as_strided
|
||||
|
@ -8,8 +8,7 @@ class TestAutocast(TestCase):
|
||||
def test_autocast_with_unsupported_type(self):
|
||||
with self.assertWarnsRegex(
|
||||
UserWarning,
|
||||
"In openreg autocast, but the target dtype is not supported."
|
||||
"openreg Autocast only supports dtypes of torch.bfloat16, torch.float16 currently.",
|
||||
"In openreg autocast, but the target dtype torch.float32 is not supported.",
|
||||
):
|
||||
with torch.autocast(device_type="openreg", dtype=torch.float32):
|
||||
_ = torch.ones(10)
|
||||
|
@ -67,21 +67,7 @@ class TestFullyShardMemory(FSDPTest):
|
||||
# allocate the cuBLAS workspaces before measuring the memory usage
|
||||
# since the workspace size can differ between hardwares
|
||||
lin = torch.nn.Linear(768, 768, device=device_type)
|
||||
# NOTE: before https://github.com/pytorch/pytorch/pull/163955,
|
||||
# the input shape was (1, 768), so that the forward gemm used
|
||||
# cublaslt, and the backward used cublas.
|
||||
# With the aforementioned PR, and with shape (1, 768),
|
||||
# the cublas path is used both in forward and in backward,
|
||||
# altering peak memory usage not accounting for cublaslt.
|
||||
# Here we change the input shape to (2, 768), and that swaps
|
||||
# the cublas/cublaslt selection in the forward/backward,
|
||||
# but that does not affect the peak memory usage stored in `base_mem_mb`.
|
||||
# Reasons for the flip:
|
||||
# before PR: no Lt in addmm when mat2 has nrows/ncols <= 1,
|
||||
# after PR: no Lt in addmm when either mat1 or mat2 have nrows/ncols <= 1,
|
||||
# since the input preparation can swap matrices based on output
|
||||
# row-/col-majorness.
|
||||
inp = torch.randn(2, 768, device=device_type)
|
||||
inp = torch.randn(1, 768, device=device_type)
|
||||
lin(inp).sum().backward()
|
||||
torch.get_device_module(device_type).empty_cache()
|
||||
base_mem_mb = self._get_peak_active_memory_mb()
|
||||
|
@ -288,18 +288,6 @@ class AnnotateTests(torch._dynamo.test_case.TestCase):
|
||||
('call_function', 'mul_2', {'pp_stage': 0, 'fdsp_bucket': 0})""", # noqa: B950
|
||||
)
|
||||
|
||||
def test_graph_break(self):
|
||||
def fn(x):
|
||||
with torch.fx.traceback.annotate({"pp_stage": 0}):
|
||||
x = torch.sin(x)
|
||||
torch._dynamo.graph_break()
|
||||
x = torch.cos(x)
|
||||
return x
|
||||
|
||||
opt_fn = torch.compile(fn, backend="eager")
|
||||
x = torch.randn(10, requires_grad=True)
|
||||
self.assertEqual(fn(x), opt_fn(x))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
@ -346,7 +346,7 @@ class TestAutocastMPS(TestCase):
|
||||
def test_mps_autocast_error_message(self):
|
||||
with self.assertWarnsRegex(
|
||||
UserWarning,
|
||||
"MPS Autocast only supports dtypes of torch.bfloat16, torch.float16 currently.",
|
||||
"MPS Autocast only supports dtype of torch.bfloat16 and torch.float16 currently.",
|
||||
):
|
||||
with torch.autocast(device_type="mps", dtype=torch.float32):
|
||||
_ = torch.ones(10)
|
||||
|
@ -6,7 +6,6 @@ import builtins
|
||||
import collections
|
||||
import contextlib
|
||||
import copy
|
||||
import gc
|
||||
import functools
|
||||
import inspect
|
||||
import io
|
||||
@ -20,7 +19,6 @@ import traceback
|
||||
import types
|
||||
import typing
|
||||
import unittest
|
||||
import weakref
|
||||
import warnings
|
||||
from math import sqrt
|
||||
from torch.multiprocessing import Process
|
||||
@ -1626,25 +1624,6 @@ class TestFX(JitTestCase):
|
||||
|
||||
self.assertTrue(neg not in relu.users)
|
||||
|
||||
@skipIfTorchDynamo("Dynamo does not free right away")
|
||||
def test_prepend_does_not_leak(self):
|
||||
g = Graph()
|
||||
x = g.placeholder("x")
|
||||
relu = g.call_function(torch.relu, (x,))
|
||||
neg = g.call_function(torch.neg, (x,))
|
||||
|
||||
relu.prepend(neg)
|
||||
|
||||
ref = weakref.ref(neg)
|
||||
g.erase_node(neg)
|
||||
del g
|
||||
del x
|
||||
del relu
|
||||
del neg
|
||||
gc.collect()
|
||||
|
||||
self.assertIsNone(ref())
|
||||
|
||||
def test_remove_uses_with_custom_filter(self):
|
||||
g: torch.fx.Graph = Graph()
|
||||
x: torch.fx.Node = g.placeholder("x")
|
||||
|
@ -7381,10 +7381,6 @@ torch.cuda.synchronize()
|
||||
@skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
|
||||
@parametrize("use_legacy_api", [True, False])
|
||||
@skipCPUIf(True, "SPDA Math NT fallback causes failure: see issue #133644")
|
||||
@unittest.skipIf(
|
||||
"RelWithAssert" in torch.__config__.show(),
|
||||
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
|
||||
)
|
||||
def test_dummy_mha_with_nt(self, device, use_legacy_api):
|
||||
bs = 3
|
||||
d1 = 2
|
||||
|
@ -8490,14 +8490,6 @@ class TestNNDeviceType(NNTestCase):
|
||||
y_cuda_contig = pool(x_cuda.contiguous())
|
||||
self.assertEqual(y_cuda_ch_last, y_cuda_contig)
|
||||
|
||||
@onlyCUDA
|
||||
def test_large_reflect_pad(self, device):
|
||||
# https://github.com/pytorch/pytorch/issues/165861
|
||||
x = torch.rand(2**16, 2, device="cuda")
|
||||
c = F.pad(x, (1, 1), mode="reflect")
|
||||
c_cpu = F.pad(x.cpu(), (1, 1), mode="reflect")
|
||||
self.assertEqual(c, c_cpu)
|
||||
|
||||
@onlyCUDA
|
||||
@largeTensorTest("48GB", "cpu")
|
||||
@largeTensorTest("48GB", "cuda")
|
||||
|
@ -247,10 +247,6 @@ class SparseSemiStructuredTensorCompileTest(torch._dynamo.test_case.TestCase):
|
||||
@unittest.skipIf(IS_WINDOWS, "torch.compile not supported on windows")
|
||||
@unittest.skipIf("cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS, "cusparselt not supported on this machine")
|
||||
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
|
||||
@unittest.skipIf(
|
||||
"RelWithAssert" in torch.__config__.show(),
|
||||
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
|
||||
)
|
||||
def test_sp24_compile(self) -> None:
|
||||
x = torch.randn([1024, 512], device="cuda", dtype=torch.float16, requires_grad=True)
|
||||
|
||||
@ -580,10 +576,6 @@ class TestSparseSemiStructuredTraining(TestCase):
|
||||
|
||||
@training_dtypes
|
||||
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
|
||||
@unittest.skipIf(
|
||||
"RelWithAssert" in torch.__config__.show(),
|
||||
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
|
||||
)
|
||||
def test_prune_dense_static_sort(self, dtype) -> None:
|
||||
# Ideally we would like to clone and compare, but that won't work because the sorting order will be different
|
||||
# instead we pass the pruned matrix to the CUDA implementation and preserve the sparsity pattern.
|
||||
@ -629,10 +621,6 @@ class TestSparseSemiStructuredTraining(TestCase):
|
||||
@training_dtypes
|
||||
@parametrize_backends
|
||||
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
|
||||
@unittest.skipIf(
|
||||
"RelWithAssert" in torch.__config__.show(),
|
||||
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
|
||||
)
|
||||
def test_pruning_algo_largest_abs_values_greedy(self, dtype, backend) -> None:
|
||||
inp = torch.tensor(
|
||||
[[4, 3, 2, 1], [-1, -3, 0.6, 0.5], [1, 2, 3, 4], [10, 2, -1, 5]],
|
||||
@ -670,10 +658,6 @@ class TestSparseSemiStructuredTraining(TestCase):
|
||||
@training_dtypes
|
||||
@parametrize_backends
|
||||
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
|
||||
@unittest.skipIf(
|
||||
"RelWithAssert" in torch.__config__.show(),
|
||||
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
|
||||
)
|
||||
def test_pack_both_ways_meta_correctness(self, dtype, backend) -> None:
|
||||
M, N = 128, 256
|
||||
# Construct x to make sure we always have exactly 8 elements per 4x4 tile
|
||||
@ -708,10 +692,6 @@ class TestSparseSemiStructuredTraining(TestCase):
|
||||
|
||||
@training_dtypes
|
||||
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
|
||||
@unittest.skipIf(
|
||||
"RelWithAssert" in torch.__config__.show(),
|
||||
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
|
||||
)
|
||||
def test_pack_both_ways_id(self, dtype) -> None:
|
||||
N = 512
|
||||
torch.manual_seed(0)
|
||||
@ -749,10 +729,6 @@ class TestSparseSemiStructuredTraining(TestCase):
|
||||
|
||||
@training_dtypes
|
||||
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
|
||||
@unittest.skipIf(
|
||||
"RelWithAssert" in torch.__config__.show(),
|
||||
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
|
||||
)
|
||||
def test_pack_both_ways_edge_case1(self, dtype) -> None:
|
||||
# In this case, the heuristic will keep 7 values out of 16
|
||||
# instead of 8. let's see how the kernel handles this
|
||||
@ -778,10 +754,6 @@ class TestSparseSemiStructuredTraining(TestCase):
|
||||
|
||||
@training_dtypes
|
||||
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
|
||||
@unittest.skipIf(
|
||||
"RelWithAssert" in torch.__config__.show(),
|
||||
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
|
||||
)
|
||||
def test_sp24_apply(self, dtype) -> None:
|
||||
M, N = 256, 1024
|
||||
x = torch.randn([M, N], dtype=dtype, device="cuda")
|
||||
@ -798,10 +770,6 @@ class TestSparseSemiStructuredTraining(TestCase):
|
||||
|
||||
@training_dtypes
|
||||
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
|
||||
@unittest.skipIf(
|
||||
"RelWithAssert" in torch.__config__.show(),
|
||||
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
|
||||
)
|
||||
def test_sp24_apply_dense(self, dtype) -> None:
|
||||
M, N = 256, 1024
|
||||
x = torch.randn([M, N], dtype=dtype, device="cuda")
|
||||
@ -840,10 +808,6 @@ class TestSparseSemiStructuredTraining(TestCase):
|
||||
|
||||
@training_dtypes
|
||||
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
|
||||
@unittest.skipIf(
|
||||
"RelWithAssert" in torch.__config__.show(),
|
||||
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
|
||||
)
|
||||
def test_sp24_matmuls(self, dtype) -> None:
|
||||
M, N, K = 64, 256, 1024
|
||||
a = torch.randn([M, K], device="cuda", dtype=dtype)
|
||||
@ -879,10 +843,6 @@ class TestSparseSemiStructuredTraining(TestCase):
|
||||
)
|
||||
|
||||
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
|
||||
@unittest.skipIf(
|
||||
"RelWithAssert" in torch.__config__.show(),
|
||||
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
|
||||
)
|
||||
def test_sp24_matmuls_mat_vec(self) -> None:
|
||||
a = torch.randn([64, 128], device="cuda", dtype=torch.float16)
|
||||
b = torch.randn([128], device="cuda", dtype=torch.float16)
|
||||
@ -893,10 +853,6 @@ class TestSparseSemiStructuredTraining(TestCase):
|
||||
torch.testing.assert_close(a_s @ b, (a * a_m) @ b, **atol_rtol_kw[a.dtype])
|
||||
|
||||
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
|
||||
@unittest.skipIf(
|
||||
"RelWithAssert" in torch.__config__.show(),
|
||||
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
|
||||
)
|
||||
def test_sp24_matmuls_bmm(self) -> None:
|
||||
a = torch.randn([64, 128], device="cuda", dtype=torch.float16)
|
||||
b = torch.randn([5, 6, 128], device="cuda", dtype=torch.float16)
|
||||
|
@ -2758,12 +2758,6 @@ class _NodeBase:
|
||||
return_type: Any,
|
||||
) -> None: ...
|
||||
def _update_args_kwargs(self, args: tuple[Any, ...], kwargs: dict[str, Any]): ...
|
||||
def _prepend(self, n: FxNode) -> None: ...
|
||||
def _remove_from_list(self) -> None: ...
|
||||
def __lt__(self, n: Self) -> _bool: ...
|
||||
def __gt__(self, n: Self) -> _bool: ...
|
||||
def __le__(self, n: Self) -> _bool: ...
|
||||
def __ge__(self, n: Self) -> _bool: ...
|
||||
|
||||
class _NodeIter(Iterator[FxNode]):
|
||||
def __init__(self, root: FxNode, reversed: _bool) -> None: ...
|
||||
|
@ -2810,15 +2810,5 @@
|
||||
"Ensure {user_cls.__name__} is a type of dict, OrderedDict, or defaultdict."
|
||||
]
|
||||
}
|
||||
],
|
||||
"GB0279": [
|
||||
{
|
||||
"Gb_type": "torch.fx.traceback.annotate escaped from compiled region",
|
||||
"Context": "str(self)",
|
||||
"Explanation": "Dynamo doesn't support graph break on torch.fx.traceback.annotate.",
|
||||
"Hints": [
|
||||
"It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues."
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
@ -1295,16 +1295,6 @@ class FxTracebackAnnotateVariable(ContextWrappingVariable):
|
||||
def fn_name(self):
|
||||
return "annotate"
|
||||
|
||||
def reconstruct_type(self, codegen: "PyCodegen"):
|
||||
unimplemented_v2(
|
||||
gb_type="torch.fx.traceback.annotate escaped from compiled region",
|
||||
context=str(self),
|
||||
explanation="Dynamo doesn't support graph break on torch.fx.traceback.annotate.",
|
||||
hints=[
|
||||
*graph_break_hints.SUPPORTABLE,
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class DynamoConfigPatchVariable(ContextWrappingVariable):
|
||||
"""represents torch._dynamo.patch_dynamo_config"""
|
||||
|
@ -3147,82 +3147,35 @@ class AlgorithmSelectorCache(PersistentCache):
|
||||
for i, x in enumerate(input_nodes)
|
||||
}
|
||||
example_inputs = list(unique_example_inputs.values())
|
||||
example_inputs_extern = []
|
||||
for input_node in input_nodes:
|
||||
if unique_example_inputs[input_node.get_name()].is_mkldnn:
|
||||
example_inputs_extern.append(
|
||||
unique_example_inputs[input_node.get_name()]
|
||||
)
|
||||
else:
|
||||
base = unique_example_inputs[input_node.get_name()]
|
||||
base = base if base._base is None else base._base
|
||||
sizes = tuple(
|
||||
V.graph.sizevars.atomically_apply_size_hint(
|
||||
size,
|
||||
example_inputs_extern = [
|
||||
(
|
||||
unique_example_inputs[input_node.get_name()]
|
||||
if unique_example_inputs[input_node.get_name()].is_mkldnn
|
||||
else torch.as_strided(
|
||||
unique_example_inputs[input_node.get_name()],
|
||||
V.graph.sizevars.size_hints(
|
||||
input_node.get_size(),
|
||||
fallback=config.unbacked_symint_fallback,
|
||||
hint_override=hint_override,
|
||||
)
|
||||
for size in input_node.get_size()
|
||||
)
|
||||
strides = tuple(
|
||||
V.graph.sizevars.atomically_apply_size_hint(
|
||||
stride,
|
||||
),
|
||||
V.graph.sizevars.size_hints(
|
||||
input_node.get_stride(),
|
||||
fallback=config.unbacked_symint_fallback,
|
||||
hint_override=hint_override,
|
||||
)
|
||||
for stride in input_node.get_stride()
|
||||
)
|
||||
storage_offset = V.graph.sizevars.atomically_apply_size_hint(
|
||||
input_node.get_layout().offset,
|
||||
fallback=config.unbacked_symint_fallback,
|
||||
hint_override=hint_override,
|
||||
)
|
||||
|
||||
# Check if the required storage size exceeds the current storage
|
||||
# to avoid illegal memory access
|
||||
needed_size = torch._prims_common.compute_required_storage_length(
|
||||
sizes, strides, storage_offset
|
||||
)
|
||||
current_size = base.storage().size()
|
||||
|
||||
if needed_size > current_size:
|
||||
# Create a new base tensor with sufficient storage
|
||||
new_base = torch.randn(
|
||||
needed_size,
|
||||
dtype=base.dtype,
|
||||
device=base.device,
|
||||
requires_grad=base.requires_grad,
|
||||
)
|
||||
base = new_base.as_strided(
|
||||
base.size(), base.stride(), base.storage_offset()
|
||||
)
|
||||
|
||||
example_inputs_extern.append(
|
||||
torch.as_strided(base, sizes, strides, storage_offset)
|
||||
),
|
||||
V.graph.sizevars.size_hint(
|
||||
input_node.get_layout().offset,
|
||||
fallback=config.unbacked_symint_fallback,
|
||||
hint_override=hint_override,
|
||||
),
|
||||
)
|
||||
)
|
||||
for input_node in input_nodes
|
||||
]
|
||||
out = cls.benchmark_example_value(layout, hint_override=hint_override)
|
||||
|
||||
# Also check the output tensor for storage size
|
||||
out_base = out if out._base is None else out._base
|
||||
out_offset = V.graph.sizevars.size_hint(layout.offset)
|
||||
needed_out_size = torch._prims_common.compute_required_storage_length(
|
||||
out.size(), out.stride(), out_offset
|
||||
out_extern = torch.as_strided(
|
||||
out, out.size(), out.stride(), V.graph.sizevars.size_hint(layout.offset)
|
||||
)
|
||||
current_out_size = out_base.storage().size()
|
||||
|
||||
if needed_out_size > current_out_size:
|
||||
# Create a new base tensor with sufficient storage
|
||||
new_out_base = torch.randn(
|
||||
needed_out_size,
|
||||
dtype=out_base.dtype,
|
||||
device=out_base.device,
|
||||
requires_grad=out_base.requires_grad,
|
||||
)
|
||||
out_base = new_out_base.as_strided(
|
||||
out_base.size(), out_base.stride(), out_base.storage_offset()
|
||||
)
|
||||
|
||||
out_extern = torch.as_strided(out_base, out.size(), out.stride(), out_offset)
|
||||
expected = None
|
||||
if VERIFY:
|
||||
choices[0].benchmark(*example_inputs_extern, out=out_extern)
|
||||
@ -3663,13 +3616,10 @@ class AlgorithmSelectorCache(PersistentCache):
|
||||
# So we need call as_strided in the end to 'view' the tensor with the correct
|
||||
# sizes/strides
|
||||
return AlgorithmSelectorCache.generate_example_value(
|
||||
tuple(
|
||||
V.graph.sizevars.atomically_apply_size_hint(
|
||||
size,
|
||||
fallback=config.unbacked_symint_fallback,
|
||||
hint_override=hint_override,
|
||||
)
|
||||
for size in node.get_size()
|
||||
V.graph.sizevars.size_hints(
|
||||
node.get_size(),
|
||||
fallback=config.unbacked_symint_fallback,
|
||||
hint_override=hint_override,
|
||||
),
|
||||
tuple(
|
||||
V.graph.sizevars.atomically_apply_size_hint(
|
||||
@ -3682,20 +3632,13 @@ class AlgorithmSelectorCache(PersistentCache):
|
||||
node.get_device(),
|
||||
node.get_dtype(),
|
||||
# pyrefly: ignore # missing-attribute
|
||||
V.graph.sizevars.atomically_apply_size_hint(
|
||||
node.layout.offset,
|
||||
node.layout.offset,
|
||||
V.graph.sizevars.size_hints(
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
V.graph.get_allocation_size(node),
|
||||
fallback=config.unbacked_symint_fallback,
|
||||
hint_override=hint_override,
|
||||
),
|
||||
tuple(
|
||||
V.graph.sizevars.atomically_apply_size_hint(
|
||||
size,
|
||||
fallback=config.unbacked_symint_fallback,
|
||||
hint_override=hint_override,
|
||||
)
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
for size in V.graph.get_allocation_size(node)
|
||||
),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
@ -230,9 +230,9 @@ class autocast:
|
||||
raise ValueError(
|
||||
f"Expected `device_type` of type `str`, got: `{type(device_type)}`"
|
||||
)
|
||||
self.fast_dtype = (
|
||||
torch.get_autocast_dtype(device_type) if dtype is None else dtype
|
||||
)
|
||||
if dtype is None:
|
||||
dtype = torch.get_autocast_dtype(device_type)
|
||||
self.fast_dtype = dtype
|
||||
if torch._jit_internal.is_scripting():
|
||||
self._enabled = enabled
|
||||
self.device = device_type
|
||||
@ -243,9 +243,6 @@ class autocast:
|
||||
raise RuntimeError(
|
||||
f"User specified an unsupported autocast device_type '{self.device}'"
|
||||
)
|
||||
|
||||
device_supported_dtypes = [torch.bfloat16, torch.float16]
|
||||
|
||||
self.custom_backend_name = torch._C._get_privateuse1_backend_name()
|
||||
if self.device == self.custom_backend_name:
|
||||
necessary_funcs = [
|
||||
@ -262,55 +259,110 @@ class autocast:
|
||||
assert hasattr(self.custom_device_mod, func), (
|
||||
message + f"But the func `{func}` is missing. \n"
|
||||
)
|
||||
device_supported_dtypes = self.custom_device_mod.get_amp_supported_dtype()
|
||||
|
||||
self._cache_enabled = (
|
||||
torch.is_autocast_cache_enabled()
|
||||
if cache_enabled is None
|
||||
else cache_enabled
|
||||
)
|
||||
self._cache_enabled = torch.is_autocast_cache_enabled()
|
||||
if (
|
||||
enabled
|
||||
and self.device == "cuda"
|
||||
and torch.cuda.amp.common.amp_definitely_not_available()
|
||||
):
|
||||
warnings.warn(
|
||||
"User provided device_type of 'cuda', but CUDA is not available. Disabling"
|
||||
)
|
||||
enabled = False
|
||||
if cache_enabled is not None:
|
||||
self._cache_enabled = cache_enabled
|
||||
|
||||
device_name = (
|
||||
self.device
|
||||
if self.device == self.custom_backend_name
|
||||
else self.device.upper()
|
||||
)
|
||||
if enabled:
|
||||
# Special case for CUDA AMP and bfloat16 support
|
||||
if self.device == "cuda":
|
||||
if torch.cuda.amp.common.amp_definitely_not_available():
|
||||
warnings.warn(
|
||||
"CUDA is not available or torch_xla is imported. Disabling autocast."
|
||||
)
|
||||
enabled = False
|
||||
elif (
|
||||
self.fast_dtype == torch.bfloat16
|
||||
and not torch.cuda.is_bf16_supported()
|
||||
):
|
||||
raise RuntimeError(
|
||||
"Current CUDA Device does not support bfloat16. Please switch dtype to float16."
|
||||
)
|
||||
elif self.fast_dtype not in device_supported_dtypes:
|
||||
error_message = (
|
||||
f"In {device_name} autocast, but the target dtype is not supported. Disabling autocast.\n"
|
||||
f"{device_name} Autocast only supports dtypes of "
|
||||
+ ", ".join(map(str, device_supported_dtypes))
|
||||
+ " currently."
|
||||
if self.device == "cpu":
|
||||
supported_dtype = [torch.bfloat16, torch.float16]
|
||||
if self.fast_dtype not in supported_dtype and enabled:
|
||||
error_message = "In CPU autocast, but the target dtype is not supported. Disabling autocast.\n"
|
||||
error_message += "CPU Autocast only supports dtype of "
|
||||
error_message += (
|
||||
", ".join(str(dtype) for dtype in supported_dtype) + " currently."
|
||||
)
|
||||
warnings.warn(error_message)
|
||||
enabled = False
|
||||
# Special case for MPS bfloat16 support on macOS < 14
|
||||
if (
|
||||
self.device == "mps"
|
||||
and self.fast_dtype == torch.bfloat16
|
||||
and not torch.backends.mps.is_macos_or_newer(14, 0)
|
||||
):
|
||||
elif self.device == "mtia":
|
||||
supported_dtype = [torch.bfloat16, torch.float16]
|
||||
if self.fast_dtype not in supported_dtype:
|
||||
error_message = "In MTIA autocast, but the target dtype is not supported. Disabling autocast.\n"
|
||||
error_message += "MTIA Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
|
||||
warnings.warn(error_message)
|
||||
enabled = False
|
||||
elif self.device == "maia":
|
||||
supported_dtype = [torch.bfloat16, torch.float16]
|
||||
if self.fast_dtype not in supported_dtype:
|
||||
error_message = "In MAIA autocast, but the target dtype is not supported. Disabling autocast.\n"
|
||||
error_message += "MAIA Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
|
||||
warnings.warn(error_message)
|
||||
enabled = False
|
||||
elif self.device == "xpu":
|
||||
supported_dtype = [torch.bfloat16, torch.float16]
|
||||
if self.fast_dtype not in supported_dtype:
|
||||
error_message = "In XPU autocast, but the target dtype is not supported. Disabling autocast.\n"
|
||||
error_message += "XPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
|
||||
warnings.warn(error_message)
|
||||
enabled = False
|
||||
elif self.device == "ipu":
|
||||
supported_dtypes = [torch.bfloat16, torch.float16]
|
||||
if self.fast_dtype not in supported_dtypes:
|
||||
error_message = "In IPU autocast, but the target dtype is not supported. Disabling autocast.\n"
|
||||
error_message += "IPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
|
||||
warnings.warn(error_message)
|
||||
enabled = False
|
||||
elif self.device == "hpu":
|
||||
supported_dtype = [torch.bfloat16, torch.float16]
|
||||
if self.fast_dtype not in supported_dtype:
|
||||
error_message = "In HPU autocast, but the target dtype is not supported. Disabling autocast.\n"
|
||||
error_message += "HPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
|
||||
warnings.warn(error_message)
|
||||
enabled = False
|
||||
elif self.device == self.custom_backend_name:
|
||||
supported_dtype = self.custom_device_mod.get_amp_supported_dtype()
|
||||
if self.fast_dtype not in supported_dtype:
|
||||
error_message = f"In {self.custom_backend_name} autocast, but the target dtype {self.fast_dtype} is not supported. "
|
||||
error_message += f"Disabling autocast.\n {self.custom_backend_name} Autocast only supports dtypes of "
|
||||
error_message += (
|
||||
", ".join(str(dtype) for dtype in supported_dtype) + " currently."
|
||||
)
|
||||
warnings.warn(error_message)
|
||||
enabled = False
|
||||
elif self.device == "cuda":
|
||||
if (
|
||||
enabled
|
||||
and self.fast_dtype == torch.bfloat16
|
||||
and not torch.cuda.is_bf16_supported()
|
||||
):
|
||||
raise RuntimeError(
|
||||
"Current CUDA Device does not support bfloat16. Please switch dtype to float16."
|
||||
)
|
||||
elif self.device == "mps":
|
||||
supported_dtype = [torch.bfloat16, torch.float16]
|
||||
if self.fast_dtype not in supported_dtype:
|
||||
error_message = (
|
||||
"In MPS autocast, but the target dtype is not supported. Disabling autocast.\n"
|
||||
"MPS Autocast only supports dtype of torch.bfloat16 and torch.float16 currently."
|
||||
)
|
||||
warnings.warn(error_message)
|
||||
enabled = False
|
||||
elif self.fast_dtype == torch.bfloat16:
|
||||
if not torch.backends.mps.is_macos_or_newer(14, 0):
|
||||
error_message = (
|
||||
"In MPS autocast, but the target dtype torch.bfloat16 is not supported "
|
||||
"on macOS versions below 14. Disabling autocast."
|
||||
)
|
||||
warnings.warn(error_message)
|
||||
enabled = False
|
||||
elif self.device == "xla":
|
||||
supported_dtype = [torch.float16, torch.bfloat16]
|
||||
if self.fast_dtype not in supported_dtype:
|
||||
error_message = "In XLA autocast, but the target dtype is not supported. Disabling autocast.\n"
|
||||
error_message += (
|
||||
"XLA Autocast only supports dtype of torch.bfloat16 currently."
|
||||
)
|
||||
warnings.warn(error_message)
|
||||
enabled = False
|
||||
self._enabled = enabled
|
||||
|
||||
def __enter__(self):
|
||||
|
@ -419,7 +419,7 @@ class TORCH_API Backend : public torch::CustomClassHolder {
|
||||
}
|
||||
|
||||
// Do not call this directly, use ProcessGroup::setGroupName instead.
|
||||
virtual void setGroupUid(const std::string& pg_uid) {
|
||||
void setGroupUid(const std::string& pg_uid) {
|
||||
pg_uid_ = pg_uid;
|
||||
}
|
||||
|
||||
|
@ -1,15 +1,11 @@
|
||||
#include <torch/csrc/fx/node.h>
|
||||
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/SmallVector.h>
|
||||
#include <structmember.h>
|
||||
#include <torch/csrc/utils/object_ptr.h>
|
||||
#include <torch/csrc/utils/pythoncapi_compat.h>
|
||||
#include <algorithm>
|
||||
|
||||
namespace {
|
||||
|
||||
using NodeSortKey = c10::SmallVector<int64_t, 4>;
|
||||
struct NodeBase;
|
||||
|
||||
// Thrown to exit out of a C++ function and return an error to Python.
|
||||
@ -167,41 +163,7 @@ struct NodeBase {
|
||||
PyObject* users;
|
||||
PyObject* _repr_fn;
|
||||
PyObject* meta;
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
|
||||
alignas(NodeSortKey) char sort_key_buf[sizeof(NodeSortKey)];
|
||||
|
||||
inline NodeSortKey& sort_key() {
|
||||
return *reinterpret_cast<NodeSortKey*>(sort_key_buf);
|
||||
}
|
||||
|
||||
inline void set_prev(NodeBase* value) {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(value);
|
||||
Py_INCREF(reinterpret_cast<PyObject*>(value));
|
||||
NodeBase* old = _prev;
|
||||
_prev = value;
|
||||
Py_DECREF(reinterpret_cast<PyObject*>(old));
|
||||
}
|
||||
|
||||
inline void set_next(NodeBase* value) {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(value);
|
||||
Py_INCREF(reinterpret_cast<PyObject*>(value));
|
||||
NodeBase* old = _next;
|
||||
_next = value;
|
||||
Py_DECREF(reinterpret_cast<PyObject*>(old));
|
||||
}
|
||||
|
||||
// Equivalent to:
|
||||
// p, n = self._prev, self._next
|
||||
// p._next, n._prev = n, p
|
||||
inline void remove_from_list() {
|
||||
if (this->_prev == this && this->_next == this) {
|
||||
return;
|
||||
}
|
||||
NodeBase* p = this->_prev;
|
||||
NodeBase* n = this->_next;
|
||||
p->set_next(n);
|
||||
n->set_prev(p);
|
||||
}
|
||||
PyObject* _sort_key;
|
||||
};
|
||||
|
||||
static PyObject* NodeBase_new(
|
||||
@ -211,8 +173,6 @@ static PyObject* NodeBase_new(
|
||||
PyObject* self = type->tp_alloc(type, 0);
|
||||
if (!self)
|
||||
return nullptr;
|
||||
new (reinterpret_cast<NodeBase*>(self)->sort_key_buf)
|
||||
NodeSortKey(); // placement new does not allocate
|
||||
return self;
|
||||
}
|
||||
|
||||
@ -241,6 +201,7 @@ static int NodeBase_init_fn(NodeBase* self, PyObject* args, PyObject* kwds) {
|
||||
self->users = PyDict_New();
|
||||
self->_repr_fn = Py_NewRef(Py_None);
|
||||
self->meta = PyDict_New();
|
||||
self->_sort_key = PyTuple_New(0);
|
||||
return 0;
|
||||
}
|
||||
|
||||
@ -260,6 +221,7 @@ static struct PyMemberDef NodeBase_members[] = {
|
||||
{"users", T_OBJECT_EX, offsetof(NodeBase, users), 0, nullptr},
|
||||
{"_repr_fn", T_OBJECT_EX, offsetof(NodeBase, _repr_fn), 0, nullptr},
|
||||
{"meta", T_OBJECT_EX, offsetof(NodeBase, meta), 0, nullptr},
|
||||
{"_sort_key", T_OBJECT_EX, offsetof(NodeBase, _sort_key), 0, nullptr},
|
||||
{nullptr} /* Sentinel */
|
||||
};
|
||||
|
||||
@ -277,6 +239,7 @@ static int NodeBase_traverse(NodeBase* self, visitproc visit, void* arg) {
|
||||
Py_VISIT(self->users);
|
||||
Py_VISIT(self->_repr_fn);
|
||||
Py_VISIT(self->meta);
|
||||
Py_VISIT(self->_sort_key);
|
||||
return 0;
|
||||
}
|
||||
|
||||
@ -294,12 +257,12 @@ static int NodeBase_clear(NodeBase* self) {
|
||||
Py_CLEAR(self->users);
|
||||
Py_CLEAR(self->_repr_fn);
|
||||
Py_CLEAR(self->meta);
|
||||
Py_CLEAR(self->_sort_key);
|
||||
return 0;
|
||||
}
|
||||
|
||||
static void NodeBase_dealloc(PyObject* self) {
|
||||
PyObject_GC_UnTrack(self);
|
||||
reinterpret_cast<NodeBase*>(self)->sort_key().~NodeSortKey();
|
||||
(void)NodeBase_clear((NodeBase*)self);
|
||||
Py_TYPE(self)->tp_free(self);
|
||||
}
|
||||
@ -358,195 +321,15 @@ static PyObject* NodeBase__update_args_kwargs(
|
||||
}
|
||||
}
|
||||
|
||||
static PyObject* NodeBase__remove_from_list(
|
||||
PyObject* self,
|
||||
PyObject* _ignored) {
|
||||
reinterpret_cast<NodeBase*>(self)->remove_from_list();
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
static PyObject* NodeBase__prepend(PyObject* self_, PyObject* arg) {
|
||||
if (self_ == arg) {
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
if (!is_node(arg)) {
|
||||
PyErr_SetString(PyExc_TypeError, "_prepend() argument must be a Node");
|
||||
return nullptr;
|
||||
}
|
||||
NodeBase* self = reinterpret_cast<NodeBase*>(self_);
|
||||
NodeBase* x = reinterpret_cast<NodeBase*>(arg);
|
||||
if (self->graph != x->graph) {
|
||||
PyErr_SetString(
|
||||
PyExc_AssertionError,
|
||||
"Attempting to move a Node into a different Graph");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
x->remove_from_list();
|
||||
NodeBase* p = self->_prev;
|
||||
p->set_next(x);
|
||||
x->set_prev(p);
|
||||
x->set_next(self);
|
||||
self->set_prev(x);
|
||||
|
||||
// Now compute x.sort_key()
|
||||
const NodeSortKey& psk = x->_prev->sort_key();
|
||||
const NodeSortKey& nsk = x->_next->sort_key();
|
||||
if (psk.size() > nsk.size()) {
|
||||
// prefix = psk[: len(nsk)+1]
|
||||
size_t slice_len = nsk.size() + 1;
|
||||
NodeSortKey prefix(psk.begin(), psk.begin() + slice_len);
|
||||
// last element is idx => increment by 1
|
||||
prefix.back()++;
|
||||
x->sort_key() = std::move(prefix);
|
||||
} else if (psk.size() < nsk.size()) {
|
||||
// prefix = nsk[: len(psk)+1]
|
||||
size_t slice_len = psk.size() + 1;
|
||||
NodeSortKey prefix(nsk.begin(), nsk.begin() + slice_len);
|
||||
// last element is idx => decrement by 1
|
||||
prefix.back()--;
|
||||
x->sort_key() = std::move(prefix);
|
||||
} else {
|
||||
// same length => add a 0
|
||||
x->sort_key() = psk;
|
||||
x->sort_key().emplace_back(0);
|
||||
}
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
// __lt__(self, other): Return self.sort_key < other.sort_key
|
||||
static PyObject* NodeBase___lt__(PyObject* self, PyObject* other) {
|
||||
// METH_O => one argument: 'other'
|
||||
if (!is_node(other)) {
|
||||
Py_RETURN_NOTIMPLEMENTED;
|
||||
}
|
||||
const NodeSortKey& lhs = reinterpret_cast<NodeBase*>(self)->sort_key();
|
||||
const NodeSortKey& rhs = reinterpret_cast<NodeBase*>(other)->sort_key();
|
||||
bool less = std::lexicographical_compare(
|
||||
lhs.begin(), lhs.end(), rhs.begin(), rhs.end());
|
||||
if (less)
|
||||
Py_RETURN_TRUE;
|
||||
Py_RETURN_FALSE;
|
||||
}
|
||||
|
||||
// __gt__(self, other): Return self.sort_key() > other.sort_key
|
||||
static PyObject* NodeBase___gt__(PyObject* self, PyObject* other) {
|
||||
if (!is_node(other)) {
|
||||
Py_RETURN_NOTIMPLEMENTED;
|
||||
}
|
||||
const NodeSortKey& lhs = reinterpret_cast<NodeBase*>(self)->sort_key();
|
||||
const NodeSortKey& rhs = reinterpret_cast<NodeBase*>(other)->sort_key();
|
||||
// "a > b" is equivalent to "b < a"
|
||||
bool greater = std::lexicographical_compare(
|
||||
rhs.begin(), rhs.end(), lhs.begin(), lhs.end());
|
||||
if (greater)
|
||||
Py_RETURN_TRUE;
|
||||
Py_RETURN_FALSE;
|
||||
}
|
||||
|
||||
static PyObject* NodeBase___ge__(PyObject* self, PyObject* other) {
|
||||
if (self == other) {
|
||||
Py_RETURN_TRUE;
|
||||
}
|
||||
return NodeBase___gt__(self, other);
|
||||
}
|
||||
|
||||
// __le__(self, other): Return not (self > other)
|
||||
static PyObject* NodeBase___le__(PyObject* self, PyObject* other) {
|
||||
if (self == other) {
|
||||
Py_RETURN_TRUE;
|
||||
}
|
||||
return NodeBase___lt__(self, other);
|
||||
}
|
||||
|
||||
// Convert the NodeBase::sort_key vector<long> into a Python tuple of ints
|
||||
// Only used by pickle/__getstate__
|
||||
static PyObject* NodeBase_get_sort_key(PyObject* self, void* /*closure*/) {
|
||||
NodeBase* node = reinterpret_cast<NodeBase*>(self);
|
||||
const NodeSortKey& vec = node->sort_key();
|
||||
Py_ssize_t n = static_cast<Py_ssize_t>(vec.size());
|
||||
THPObjectPtr tuple(PyTuple_New(n));
|
||||
if (!tuple) {
|
||||
return nullptr; // Out of memory
|
||||
}
|
||||
for (Py_ssize_t i = 0; i < n; i++) {
|
||||
PyObject* value = PyLong_FromSsize_t(vec[i]);
|
||||
if (!value) {
|
||||
return nullptr;
|
||||
}
|
||||
PyTuple_SET_ITEM(tuple.get(), i, value);
|
||||
}
|
||||
return tuple.release();
|
||||
}
|
||||
|
||||
// Setter for NodeBase::sort_key: expects a Python tuple of ints, e.g.
|
||||
// node._sort_key = (1,2,3) Only used by pickle/__setstate__
|
||||
static int NodeBase_set_sort_key(
|
||||
PyObject* self,
|
||||
PyObject* value,
|
||||
void* /*closure*/) {
|
||||
NodeBase* node = reinterpret_cast<NodeBase*>(self);
|
||||
if (!PyTuple_Check(value)) {
|
||||
PyErr_SetString(PyExc_TypeError, "_sort_key must be an tuple of ints");
|
||||
return -1;
|
||||
}
|
||||
Py_ssize_t size = PyTuple_GET_SIZE(value);
|
||||
NodeSortKey new_vec;
|
||||
new_vec.reserve(size);
|
||||
for (Py_ssize_t i = 0; i < size; i++) {
|
||||
int64_t val = PyLong_AsSsize_t(PyTuple_GET_ITEM(value, i));
|
||||
if (val == -1 && PyErr_Occurred()) {
|
||||
return -1;
|
||||
}
|
||||
new_vec.emplace_back(val);
|
||||
}
|
||||
node->sort_key() = std::move(new_vec);
|
||||
return 0;
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
|
||||
static PyMethodDef NodeBase_methods[] = {
|
||||
{"_update_args_kwargs",
|
||||
(PyCFunction)(void*)(NodeBase__update_args_kwargs),
|
||||
METH_FASTCALL,
|
||||
"Internal method: do not call directly."},
|
||||
{"_remove_from_list",
|
||||
(PyCFunction)(void*)(NodeBase__remove_from_list),
|
||||
METH_NOARGS,
|
||||
"Internal method: do not call directly."},
|
||||
{"_prepend",
|
||||
(PyCFunction)(void*)(NodeBase__prepend),
|
||||
METH_O,
|
||||
"Internal method: do not call directly."},
|
||||
{"__lt__",
|
||||
(PyCFunction)(void*)NodeBase___lt__,
|
||||
METH_O,
|
||||
"Return True if self.sort_key < other.sort_key"},
|
||||
{"__gt__",
|
||||
(PyCFunction)(void*)NodeBase___gt__,
|
||||
METH_O,
|
||||
"Return True if self.sort_key > other.sort_key"},
|
||||
{"__ge__",
|
||||
(PyCFunction)(void*)NodeBase___ge__,
|
||||
METH_O,
|
||||
"Return True if self.sort_key >= other.sort_key"},
|
||||
{"__le__",
|
||||
(PyCFunction)(void*)NodeBase___le__,
|
||||
METH_O,
|
||||
"Return True if self.sort_key <= other.sort_key"},
|
||||
{nullptr, nullptr, 0, nullptr} // Sentinel
|
||||
};
|
||||
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
|
||||
static PyGetSetDef NodeBase_getset[] = {
|
||||
{"_sort_key", // attribute name in Python
|
||||
(getter)NodeBase_get_sort_key, // C getter function
|
||||
(setter)NodeBase_set_sort_key, // C setter function
|
||||
(char*)"The sort key as a tuple of ints", // docstring
|
||||
nullptr},
|
||||
{nullptr, nullptr, nullptr, nullptr, nullptr} // Sentinel
|
||||
};
|
||||
|
||||
PyTypeObject NodeBaseType = {
|
||||
PyVarObject_HEAD_INIT(nullptr, 0)
|
||||
"torch._C._NodeBase", /* tp_name */
|
||||
@ -578,7 +361,7 @@ PyTypeObject NodeBaseType = {
|
||||
nullptr, /* tp_iternext */
|
||||
NodeBase_methods, /* tp_methods */
|
||||
NodeBase_members, /* tp_members */
|
||||
NodeBase_getset, /* tp_getset */
|
||||
nullptr, /* tp_getset */
|
||||
nullptr, /* tp_base */
|
||||
nullptr, /* tp_dict */
|
||||
nullptr, /* tp_descr_get */
|
||||
|
@ -1,12 +1,11 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
|
||||
#include <torch/headeronly/macros/Macros.h>
|
||||
#include <torch/headeronly/util/shim_utils.h>
|
||||
|
||||
#include <memory>
|
||||
|
||||
HIDDEN_NAMESPACE_BEGIN(torch, stable, accelerator)
|
||||
namespace torch::stable::accelerator {
|
||||
|
||||
using DeleterFnPtr = void (*)(void*);
|
||||
|
||||
@ -76,4 +75,4 @@ inline DeviceIndex getCurrentDeviceIndex() {
|
||||
return device_index;
|
||||
}
|
||||
|
||||
HIDDEN_NAMESPACE_END(torch, stable, accelerator)
|
||||
} // namespace torch::stable::accelerator
|
||||
|
@ -9,9 +9,8 @@
|
||||
|
||||
#include <torch/csrc/inductor/aoti_torch/generated/c_shim_aten.h>
|
||||
#include <torch/headeronly/core/ScalarType.h>
|
||||
#include <torch/headeronly/macros/Macros.h>
|
||||
|
||||
HIDDEN_NAMESPACE_BEGIN(torch, stable)
|
||||
namespace torch::stable {
|
||||
|
||||
// We expect this to be the stable version of the empty_like op that takes in
|
||||
// no kwargs (device, dtype, layout, memory_format). We will add kwargs
|
||||
@ -245,4 +244,4 @@ inline torch::stable::Tensor clone(const torch::stable::Tensor& self) {
|
||||
return torch::stable::detail::to<torch::stable::Tensor>(stack[0]);
|
||||
}
|
||||
|
||||
HIDDEN_NAMESPACE_END(torch, stable)
|
||||
} // namespace torch::stable
|
||||
|
@ -4,13 +4,12 @@
|
||||
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
|
||||
#include <torch/csrc/stable/tensor_struct.h>
|
||||
#include <torch/headeronly/core/ScalarType.h>
|
||||
#include <torch/headeronly/macros/Macros.h>
|
||||
#include <torch/headeronly/util/Exception.h>
|
||||
#include <torch/headeronly/util/shim_utils.h>
|
||||
|
||||
#include <optional>
|
||||
|
||||
HIDDEN_NAMESPACE_BEGIN(torch, stable, detail)
|
||||
namespace torch::stable::detail {
|
||||
|
||||
// forward declare so that the from/to() implementations in the detail
|
||||
// namespace of library.h where the real work is done can compile.
|
||||
@ -336,7 +335,7 @@ inline T to(StableIValue val) {
|
||||
return detail::ToImpl<T>::call(val);
|
||||
}
|
||||
|
||||
HIDDEN_NAMESPACE_END(torch, stable, detail)
|
||||
} // namespace torch::stable::detail
|
||||
|
||||
// [global from/to deprecation note]
|
||||
// WARNING! the following APIs will be removed!! We deprecated global from/to
|
||||
|
@ -8,10 +8,9 @@
|
||||
#include <torch/csrc/stable/stableivalue_conversions.h>
|
||||
#include <torch/csrc/stable/tensor_struct.h>
|
||||
#include <torch/headeronly/core/ScalarType.h>
|
||||
#include <torch/headeronly/macros/Macros.h>
|
||||
#include <torch/headeronly/util/shim_utils.h>
|
||||
|
||||
HIDDEN_NAMESPACE_BEGIN(torch, stable)
|
||||
namespace torch::stable {
|
||||
|
||||
using torch::headeronly::ScalarType;
|
||||
|
||||
@ -22,4 +21,4 @@ inline ScalarType Tensor::scalar_type() const {
|
||||
torch::stable::detail::from(dtype));
|
||||
}
|
||||
|
||||
HIDDEN_NAMESPACE_END(torch, stable)
|
||||
} // namespace torch::stable
|
||||
|
@ -2,7 +2,6 @@
|
||||
|
||||
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
|
||||
#include <torch/headeronly/core/ScalarType.h>
|
||||
#include <torch/headeronly/macros/Macros.h>
|
||||
#include <torch/headeronly/util/Exception.h>
|
||||
#include <torch/headeronly/util/shim_utils.h>
|
||||
#include <climits>
|
||||
@ -10,7 +9,7 @@
|
||||
|
||||
#include <torch/csrc/stable/accelerator.h>
|
||||
|
||||
HIDDEN_NAMESPACE_BEGIN(torch, stable)
|
||||
namespace torch::stable {
|
||||
|
||||
using accelerator::DeviceIndex;
|
||||
using torch::headeronly::ScalarType;
|
||||
@ -169,4 +168,4 @@ class Tensor {
|
||||
// =============================================================================
|
||||
};
|
||||
|
||||
HIDDEN_NAMESPACE_END(torch, stable)
|
||||
} // namespace torch::stable
|
||||
|
@ -385,7 +385,41 @@ class Node(_NodeBase):
|
||||
Args:
|
||||
x (Node): The node to put before this node. Must be a member of the same graph.
|
||||
"""
|
||||
self._prepend(x)
|
||||
assert self.graph == x.graph, "Attempting to move a Node into a different Graph"
|
||||
if self == x:
|
||||
log.debug(
|
||||
"Trying to prepend a node to itself. This behavior has no effect on the graph."
|
||||
)
|
||||
return
|
||||
x._remove_from_list()
|
||||
p = self._prev
|
||||
p._next, x._prev = x, p
|
||||
x._next, self._prev = self, x
|
||||
|
||||
# compute x._sort_key
|
||||
psk = x._prev._sort_key
|
||||
nsk = x._next._sort_key
|
||||
if len(psk) > len(nsk):
|
||||
idx: int
|
||||
*prefix, idx = psk[: len(nsk) + 1]
|
||||
x._sort_key = (*prefix, idx + 1)
|
||||
elif len(psk) < len(nsk):
|
||||
*prefix, idx = nsk[: len(psk) + 1]
|
||||
x._sort_key = (*prefix, idx - 1)
|
||||
else: # same length, increase length by 1
|
||||
x._sort_key = (*psk, 0)
|
||||
|
||||
def __gt__(self, other: "Node") -> bool:
|
||||
return self._sort_key > other._sort_key
|
||||
|
||||
def __lt__(self, other: "Node") -> bool:
|
||||
return self._sort_key < other._sort_key
|
||||
|
||||
def __ge__(self, other: "Node") -> bool:
|
||||
return self > other or self == other
|
||||
|
||||
def __le__(self, other: "Node") -> bool:
|
||||
return self < other or self == other
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
def append(self, x: "Node") -> None:
|
||||
@ -396,7 +430,11 @@ class Node(_NodeBase):
|
||||
Args:
|
||||
x (Node): The node to put after this node. Must be a member of the same graph.
|
||||
"""
|
||||
self._next._prepend(x)
|
||||
self._next.prepend(x)
|
||||
|
||||
def _remove_from_list(self) -> None:
|
||||
p, n = self._prev, self._next
|
||||
p._next, n._prev = n, p
|
||||
|
||||
@property
|
||||
def args(self) -> tuple[Argument, ...]:
|
||||
|
@ -611,52 +611,4 @@ __host__ __device__
|
||||
#define C10_RETURN_MOVE_IF_OLD_COMPILER 0
|
||||
#endif
|
||||
|
||||
// The HIDDEN_NAMESPACE_BEGIN and HIDDEN_NAMESPACE_END below
|
||||
// are needed for maintaining robustness in our header APIs in
|
||||
// torch/headeronly and torch/csrc/stable under the namespaces
|
||||
// torch::headeronly and torch::stable respectively. We enforce
|
||||
// hidden visibility for these APIs because we want to enable
|
||||
// loading custom extensions compiled against different libtorch
|
||||
// versions where these APIs may have changed.
|
||||
|
||||
// Helper macros for nested namespace expansion
|
||||
#define _HIDDEN_NS_EXPAND(...) __VA_ARGS__
|
||||
#define _HIDDEN_NS_GET_MACRO(_1, _2, _3, NAME, ...) NAME
|
||||
|
||||
// Macros to handle 1-3 namespace levels
|
||||
#define _HIDDEN_NS_1(n1) namespace n1 __attribute__((visibility("hidden"))) {
|
||||
#define _HIDDEN_NS_2(n1, n2) \
|
||||
namespace n1 { \
|
||||
namespace n2 __attribute__((visibility("hidden"))) {
|
||||
#define _HIDDEN_NS_3(n1, n2, n3) \
|
||||
namespace n1::n2 { \
|
||||
namespace n3 __attribute__((visibility("hidden"))) {
|
||||
|
||||
// Macros to close namespaces
|
||||
#define _HIDDEN_NS_END_1(n1) }
|
||||
#define _HIDDEN_NS_END_N(n1, ...) \
|
||||
} \
|
||||
}
|
||||
|
||||
#if !defined(HIDDEN_NAMESPACE_BEGIN)
|
||||
#if defined(__GNUG__) && !defined(_WIN32)
|
||||
#define HIDDEN_NAMESPACE_BEGIN(...) \
|
||||
_HIDDEN_NS_EXPAND(_HIDDEN_NS_GET_MACRO( \
|
||||
__VA_ARGS__, _HIDDEN_NS_3, _HIDDEN_NS_2, _HIDDEN_NS_1)(__VA_ARGS__))
|
||||
#else
|
||||
#define HIDDEN_NAMESPACE_BEGIN(...) namespace __VA_ARGS__ {
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#if !defined(HIDDEN_NAMESPACE_END)
|
||||
#if defined(__GNUG__) && !defined(_WIN32)
|
||||
#define HIDDEN_NAMESPACE_END(...) \
|
||||
_HIDDEN_NS_EXPAND(_HIDDEN_NS_GET_MACRO( \
|
||||
__VA_ARGS__, _HIDDEN_NS_END_N, _HIDDEN_NS_END_N, _HIDDEN_NS_END_1)( \
|
||||
__VA_ARGS__))
|
||||
#else
|
||||
#define HIDDEN_NAMESPACE_END(...) }
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#endif // C10_MACROS_MACROS_H_
|
||||
|
Reference in New Issue
Block a user