Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-23 06:34:55 +08:00)

Compare commits: ciflow/ind ... cpp-docs-d (3 commits)

Commits (SHA1): 5b6cc8215f, 1c43c9cfd0, 102e0d5437
@@ -1,15 +1,11 @@
sphinx==5.3.0
sphinx==7.2.6
#Description: This is used to generate PyTorch docs
#Pinned versions: 5.3.0
#Pinned versions: 7.2.6

standard-imghdr==3.13.0; python_version >= "3.13"
#Description: This is needed by Sphinx, so it needs to be added here.
# The reasons are as follows:
# 1) This module has been removed from the Python standard library since Python 3.13(https://peps.python.org/pep-0594/#imghdr);
# 2) The current version of Sphinx (5.3.0) is not compatible with Python 3.13.
# Once Sphinx is upgraded to a version compatible with Python 3.13 or later, we can remove this dependency.
pytorch_sphinx_theme2==0.1.0
#Description: This is needed to generate PyTorch docs
#Pinned versions: 0.1.0

-e git+https://github.com/pytorch/pytorch_sphinx_theme.git@71e55749be14ceb56e7f8211a9fb649866b87ad4#egg=pytorch_sphinx_theme2
# TODO: sphinxcontrib.katex 0.9.0 adds a local KaTeX server to speed up pre-rendering
# but it doesn't seem to work and hangs around idly. The initial thought that it is probably
# something related to Docker setup. We can investigate this later.
@@ -36,17 +32,17 @@ tensorboard==2.18.0 ; python_version >= "3.13"
#Description: This is used to generate PyTorch docs
#Pinned versions: 2.13.0

breathe==4.34.0
breathe==4.36.0
#Description: This is used to generate PyTorch C++ docs
#Pinned versions: 4.34.0
#Pinned versions: 4.36.0

exhale==0.2.3
exhale==0.3.7
#Description: This is used to generate PyTorch C++ docs
#Pinned versions: 0.2.3
#Pinned versions: 0.3.7

docutils==0.16
docutils==0.20
#Description: This is used to generate PyTorch C++ docs
#Pinned versions: 0.16
#Pinned versions: 0.20

bs4==0.0.1
#Description: This is used to generate PyTorch C++ docs
@@ -56,13 +52,13 @@ IPython==8.12.0
#Description: This is used to generate PyTorch functorch docs
#Pinned versions: 8.12.0

myst-nb==0.17.2
myst-nb==1.3.0
#Description: This is used to generate PyTorch functorch and torch.compile docs.
#Pinned versions: 0.17.2
#Pinned versions: 1.3.0

# The following are required to build torch.distributed.elastic.rendezvous.etcd* docs
python-etcd==0.4.5
sphinx-copybutton==0.5.0
sphinx-design==0.4.0
sphinx-design==0.6.1
sphinxcontrib-mermaid==1.0.0
myst-parser==0.18.1
myst-parser==4.0.1
@@ -6,7 +6,7 @@ dependencies = [
    "GitPython==3.1.45",
    "docker==7.1.0",
    "pytest==7.3.2",
    "uv==0.9.5"
    "uv==0.8.6"
]

[tool.setuptools]
@@ -102,8 +102,18 @@ if [ "$is_main_doc" = true ]; then
    echo coverage output not found
    exit 1
  elif [ $undocumented -gt 0 ]; then
    echo undocumented objects found:
    echo "======================================"
    echo "ERROR: $undocumented undocumented objects found!"
    echo "======================================"
    echo ""
    echo "Full coverage report:"
    cat build/coverage/python.txt
    echo ""
    echo "======================================"
    echo "Undocumented modules/objects (lines after TOTAL):"
    tail -n +$((lines - undocumented + 1)) build/coverage/python.txt
    echo "======================================"
    echo ""
    echo "Make sure you've updated relevant .rsts in docs/source!"
    echo "You can reproduce locally by running 'cd docs && make coverage && cat build/coverage/python.txt'"
    exit 1
.github/workflows/periodic.yml (vendored, 15 changed lines)

@@ -147,16 +147,15 @@ jobs:
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
      build-environment: linux-jammy-cuda12.8-py3.10-gcc9-debug
      docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9
      cuda-arch-list: 8.9
      test-matrix: |
        { include: [
          { config: "default", shard: 1, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
          { config: "default", shard: 2, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
          { config: "default", shard: 3, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
          { config: "default", shard: 4, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
          { config: "default", shard: 5, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
          { config: "default", shard: 6, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
          { config: "default", shard: 7, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
          { config: "default", shard: 1, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
          { config: "default", shard: 2, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
          { config: "default", shard: 3, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
          { config: "default", shard: 4, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
          { config: "default", shard: 5, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
          { config: "default", shard: 6, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
          { config: "default", shard: 7, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
        ]}
    secrets: inherit
.github/workflows/trunk-tagging.yml (vendored, 149 changed lines)

@@ -58,10 +58,8 @@ jobs:
          else
            COMMIT_SHA="${{ github.sha }}"
          fi
          {
            echo "sha=${COMMIT_SHA}"
            echo "tag_name=trunk/${COMMIT_SHA}"
          } >> "${GITHUB_OUTPUT}"
          echo "sha=${COMMIT_SHA}" >> "${GITHUB_OUTPUT}"
          echo "tag_name=trunk/${COMMIT_SHA}" >> "${GITHUB_OUTPUT}"

      - name: Validate commit SHA
        run: |

@@ -89,7 +87,7 @@
            echo "✅ Commit ${COMMIT_SHA} is valid (automatic push trigger)"
          fi

      - name: Create and push tag(s) with retry
      - name: Create and push tag with retry
        id: check_tag
        env:
          TAG_NAME: ${{ steps.commit.outputs.tag_name }}

@@ -114,23 +112,14 @@
            return 1
          }

          # Counters for summary reporting
          created_count=0
          skipped_count=0
          failed_count=0
          # Exit early if tag already exists
          if check_tag_exists; then
            echo "✅ Tag already exists - no action needed"
            echo "exists=true" >> "${GITHUB_OUTPUT}"
            exit 0
          fi

          # Always write outputs once on exit
          finish() {
            set +e
            if [ -n "${GITHUB_OUTPUT:-}" ]; then
              {
                echo "created_count=${created_count}"
                echo "skipped_count=${skipped_count}"
                echo "failed_count=${failed_count}"
              } >> "${GITHUB_OUTPUT}"
            fi
          }
          trap finish EXIT
          echo "Tag ${TAG_NAME} does not exist, proceeding with creation"

          # Retry configuration
          MAX_RETRIES=5
@@ -205,111 +194,31 @@
          }
          }

          # New behavior for push events: enumerate commits in the push and tag each one.
          # For workflow_dispatch, retain existing single-SHA behavior.

          # Always fetch tags once up front to improve idempotency in loops
          git fetch origin --tags --quiet || true

          if [ "${{ github.event_name }}" = "push" ]; then
            BEFORE_SHA="${{ github.event.before }}"
            AFTER_SHA="${{ github.sha }}" # same as event.after

            # List commits introduced by this push (old..new), oldest first for stable ordering
            commits_file="$(mktemp)"
            git rev-list --reverse "${BEFORE_SHA}..${AFTER_SHA}" > "${commits_file}"

            if [ ! -s "${commits_file}" ]; then
              echo "No new commits found between ${BEFORE_SHA}..${AFTER_SHA}; nothing to tag."
              rm -f "${commits_file}"
              exit 0
            fi

            commit_count="$(wc -l < "${commits_file}" | tr -d ' ')"
            echo "Found ${commit_count} commit(s) to tag for push:"
            while IFS= read -r sha; do
              printf ' %s\n' "${sha}"
            done < "${commits_file}"

            while IFS= read -r sha; do
              TAG_NAME="trunk/${sha}"
              COMMIT_SHA="${sha}"

              # If tag already exists locally or remotely, skip (idempotent)
              if check_tag_exists; then
                echo "✅ Tag ${TAG_NAME} already exists - skipping"
                skipped_count=$((skipped_count + 1))
                continue
              fi

              echo "Tag ${TAG_NAME} does not exist, proceeding with creation"

              if retry_with_backoff "tag_with_retry" "Creating tag ${TAG_NAME} for commit ${COMMIT_SHA}"; then
                created_count=$((created_count + 1))
              else
                echo "Tag creation failed after all retry attempts for ${TAG_NAME}"
                failed_count=$((failed_count + 1))
              fi
            done < "${commits_file}"

            rm -f "${commits_file}"

            if [ "${failed_count}" -gt 0 ]; then
              exit 1
            fi
          # Execute with retry
          if retry_with_backoff "tag_with_retry" "Creating tag ${TAG_NAME} for commit ${COMMIT_SHA}"; then
            echo "exists=false" >> "${GITHUB_OUTPUT}"
            exit 0
          else
          # workflow_dispatch path (single SHA tagging preserved)

            # Exit early if tag already exists
            if check_tag_exists; then
              echo "✅ Tag already exists - no action needed"
              skipped_count=1
              exit 0
            fi

            echo "Tag ${TAG_NAME} does not exist, proceeding with creation"

            if retry_with_backoff "tag_with_retry" "Creating tag ${TAG_NAME} for commit ${COMMIT_SHA}"; then
              created_count=1
              exit 0
            else
              echo "Tag creation failed after all retry attempts"
              failed_count=1
              exit 1
            fi
            echo "Tag creation failed after all retry attempts"
            exit 1
          fi

      - name: Tag creation summary
        if: always()
        run: |
          if [ "${{ github.event_name }}" = "push" ]; then
            echo "Trigger: push on main"
            echo "Created: ${{ steps.check_tag.outputs.created_count }}"
            echo "Skipped (already existed): ${{ steps.check_tag.outputs.skipped_count }}"
            echo "Failed: ${{ steps.check_tag.outputs.failed_count }}"
            if [ "${{ steps.check_tag.outputs.failed_count }}" = "0" ]; then
              echo "✅ Completed tagging for push range ${{ github.event.before }}..${{ github.sha }}"
            else
              echo "❌ Some tags failed to create for push range ${{ github.event.before }}..${{ github.sha }}"
            fi
          if [ "${{ steps.check_tag.outputs.exists }}" = "true" ]; then
            echo "✅ Tag ${{ steps.commit.outputs.tag_name }} already existed - no action needed"
          elif [ "${{ job.status }}" = "success" ]; then
            echo "✅ Successfully created tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
          else
            if [ "${{ steps.check_tag.outputs.failed_count }}" = "0" ]; then
              if [ "${{ steps.check_tag.outputs.created_count }}" = "0" ]; then
                echo "✅ Tag ${{ steps.commit.outputs.tag_name }} already existed - no action needed"
              else
                echo "✅ Successfully created tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
              fi
            else
              echo "❌ Failed to create tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
            fi

            echo ""
            echo "Tag details:"
            echo " Name: ${{ steps.commit.outputs.tag_name }}"
            echo " Commit: ${{ steps.commit.outputs.sha }}"
            echo " Trigger: ${{ github.event_name }}"
            if [ -n "${{ github.event.inputs.commit_sha }}" ]; then
              echo " Manual commit: ${{ github.event.inputs.commit_sha }}"
            fi
            echo "❌ Failed to create tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
          fi

          echo ""
          echo "Tag details:"
          echo " Name: ${{ steps.commit.outputs.tag_name }}"
          echo " Commit: ${{ steps.commit.outputs.sha }}"
          echo " Trigger: ${{ github.event_name }}"
          if [ -n "${{ github.event.inputs.commit_sha }}" ]; then
            echo " Manual commit: ${{ github.event.inputs.commit_sha }}"
          fi
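The push path above reduces to a small loop: list the commits in BEFORE..AFTER, skip any commit whose trunk/<sha> tag already exists, retry tag creation with exponential backoff, and count created/skipped/failed for the summary step. A minimal C++ sketch of that control flow (illustrative only; the workflow does this in bash, and tag_exists/create_tag here are hypothetical stand-ins for its git calls):

#include <chrono>
#include <cstdio>
#include <functional>
#include <string>
#include <thread>
#include <vector>

// Retry an operation with exponential backoff (the workflow uses MAX_RETRIES=5).
bool retry_with_backoff(const std::function<bool()>& op, int max_retries = 5) {
  auto delay = std::chrono::seconds(1);
  for (int attempt = 1; attempt <= max_retries; ++attempt) {
    if (op()) {
      return true;
    }
    std::this_thread::sleep_for(delay);
    delay *= 2;  // 1s, 2s, 4s, ...
  }
  return false;
}

// Tag every commit of a push range, counting created/skipped/failed outcomes.
// tag_exists and create_tag stand in for the git rev-parse/ls-remote and
// tag/push calls the workflow performs in bash.
int tag_push_range(const std::vector<std::string>& commits,
                   const std::function<bool(const std::string&)>& tag_exists,
                   const std::function<bool(const std::string&)>& create_tag) {
  int created = 0, skipped = 0, failed = 0;
  for (const auto& sha : commits) {
    const std::string tag = "trunk/" + sha;
    if (tag_exists(tag)) {  // idempotent: skip tags that already exist
      ++skipped;
      continue;
    }
    if (retry_with_backoff([&] { return create_tag(tag); })) {
      ++created;
    } else {
      ++failed;
    }
  }
  std::printf("created=%d skipped=%d failed=%d\n", created, skipped, failed);
  return failed > 0 ? 1 : 0;  // the job fails if any tag could not be created
}

The failed count is what turns into the job's exit status, which is why the summary step can distinguish a partial failure from a run where every tag already existed.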
@@ -9,7 +9,6 @@
#include <ATen/cpu/vec/vec128/vec128_float_neon.h>
#include <ATen/cpu/vec/vec128/vec128_half_neon.h>
#include <ATen/cpu/vec/vec128/vec128_int_aarch64.h>
#include <ATen/cpu/vec/vec128/vec128_uint_aarch64.h>
#endif

#include <ATen/cpu/vec/vec128/vec128_convert.h>
@ -1,378 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/cpu/vec/intrinsics.h>
|
||||
#include <ATen/cpu/vec/vec_base.h>
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <c10/util/irange.h>
|
||||
|
||||
namespace at::vec {
|
||||
// Note [CPU_CAPABILITY namespace]
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
// This header, and all of its subheaders, will be compiled with
|
||||
// different architecture flags for each supported set of vector
|
||||
// intrinsics. So we need to make sure they aren't inadvertently
|
||||
// linked together. We do this by declaring objects in an `inline
|
||||
// namespace` which changes the name mangling, but can still be
|
||||
// accessed as `at::vec`.
|
||||
inline namespace CPU_CAPABILITY {
|
||||
|
||||
#define VEC_UINT_NEON_TEMPLATE(vl, bit) \
|
||||
template <> \
|
||||
struct is_vec_specialized_for<uint##bit##_t> : std::bool_constant<true> {}; \
|
||||
\
|
||||
template <> \
|
||||
class Vectorized<uint##bit##_t> { \
|
||||
using neon_type = uint##bit##x##vl##_t; \
|
||||
\
|
||||
private: \
|
||||
neon_type values; \
|
||||
\
|
||||
public: \
|
||||
using value_type = uint##bit##_t; \
|
||||
using size_type = int; \
|
||||
static constexpr size_type size() { \
|
||||
return vl; \
|
||||
} \
|
||||
Vectorized() { \
|
||||
values = vdupq_n_u##bit(0); \
|
||||
} \
|
||||
Vectorized(neon_type v) : values(v) {} \
|
||||
Vectorized(uint##bit##_t val); \
|
||||
template < \
|
||||
typename... Args, \
|
||||
typename = std::enable_if_t<(sizeof...(Args) == size())>> \
|
||||
Vectorized(Args... vals) { \
|
||||
__at_align__ uint##bit##_t buffer[size()] = {vals...}; \
|
||||
values = vld1q_u##bit(buffer); \
|
||||
} \
|
||||
operator neon_type() const { \
|
||||
return values; \
|
||||
} \
|
||||
static Vectorized<uint##bit##_t> loadu( \
|
||||
const void* ptr, \
|
||||
uint64_t count = size()); \
|
||||
void store(void* ptr, uint64_t count = size()) const; \
|
||||
template <uint64_t mask> \
|
||||
static Vectorized<uint##bit##_t> blend( \
|
||||
const Vectorized<uint##bit##_t>& a, \
|
||||
const Vectorized<uint##bit##_t>& b); \
|
||||
static Vectorized<uint##bit##_t> blendv( \
|
||||
const Vectorized<uint##bit##_t>& a, \
|
||||
const Vectorized<uint##bit##_t>& b, \
|
||||
const Vectorized<uint##bit##_t>& mask_) { \
|
||||
return vbslq_u##bit(mask_.values, b, a); \
|
||||
} \
|
||||
template <typename step_t> \
|
||||
static Vectorized<uint##bit##_t> arange( \
|
||||
value_type base = 0, \
|
||||
step_t step = static_cast<step_t>(1)); \
|
||||
static Vectorized<uint##bit##_t> set( \
|
||||
const Vectorized<uint##bit##_t>& a, \
|
||||
const Vectorized<uint##bit##_t>& b, \
|
||||
uint64_t count = size()); \
|
||||
const uint##bit##_t& operator[](uint idx) const = delete; \
|
||||
uint##bit##_t& operator[](uint idx) = delete; \
|
||||
Vectorized<uint##bit##_t> abs() const { \
|
||||
return values; \
|
||||
} \
|
||||
Vectorized<uint##bit##_t> real() const { \
|
||||
return values; \
|
||||
} \
|
||||
Vectorized<uint##bit##_t> imag() const { \
|
||||
return vdupq_n_u##bit(0); \
|
||||
} \
|
||||
Vectorized<uint##bit##_t> conj() const { \
|
||||
return values; \
|
||||
} \
|
||||
Vectorized<uint##bit##_t> neg() const { \
|
||||
return vreinterpretq_u##bit##_s##bit( \
|
||||
vnegq_s##bit(vreinterpretq_s##bit##_u##bit(values))); \
|
||||
} \
|
||||
uint##bit##_t reduce_add() const { \
|
||||
return vaddvq_u##bit(values); \
|
||||
} \
|
||||
uint##bit##_t reduce_max() const; \
|
||||
Vectorized<uint##bit##_t> operator==( \
|
||||
const Vectorized<uint##bit##_t>& other) const { \
|
||||
return Vectorized<value_type>(vceqq_u##bit(values, other.values)); \
|
||||
} \
|
||||
Vectorized<uint##bit##_t> operator!=( \
|
||||
const Vectorized<uint##bit##_t>& other) const; \
|
||||
Vectorized<uint##bit##_t> operator<( \
|
||||
const Vectorized<uint##bit##_t>& other) const { \
|
||||
return Vectorized<value_type>(vcltq_u##bit(values, other.values)); \
|
||||
} \
|
||||
Vectorized<uint##bit##_t> operator<=( \
|
||||
const Vectorized<uint##bit##_t>& other) const { \
|
||||
return Vectorized<value_type>(vcleq_u##bit(values, other.values)); \
|
||||
} \
|
||||
Vectorized<uint##bit##_t> operator>( \
|
||||
const Vectorized<uint##bit##_t>& other) const { \
|
||||
return Vectorized<value_type>(vcgtq_u##bit(values, other.values)); \
|
||||
} \
|
||||
Vectorized<uint##bit##_t> operator>=( \
|
||||
const Vectorized<uint##bit##_t>& other) const { \
|
||||
return Vectorized<value_type>(vcgeq_u##bit(values, other.values)); \
|
||||
} \
|
||||
Vectorized<uint##bit##_t> eq( \
|
||||
const Vectorized<uint##bit##_t>& other) const; \
|
||||
Vectorized<uint##bit##_t> ne( \
|
||||
const Vectorized<uint##bit##_t>& other) const; \
|
||||
Vectorized<uint##bit##_t> gt( \
|
||||
const Vectorized<uint##bit##_t>& other) const; \
|
||||
Vectorized<uint##bit##_t> ge( \
|
||||
const Vectorized<uint##bit##_t>& other) const; \
|
||||
Vectorized<uint##bit##_t> lt( \
|
||||
const Vectorized<uint##bit##_t>& other) const; \
|
||||
Vectorized<uint##bit##_t> le( \
|
||||
const Vectorized<uint##bit##_t>& other) const; \
|
||||
}; \
|
||||
template <> \
|
||||
Vectorized<uint##bit##_t> inline operator+( \
|
||||
const Vectorized<uint##bit##_t>& a, \
|
||||
const Vectorized<uint##bit##_t>& b) { \
|
||||
return vaddq_u##bit(a, b); \
|
||||
} \
|
||||
template <> \
|
||||
Vectorized<uint##bit##_t> inline operator-( \
|
||||
const Vectorized<uint##bit##_t>& a, \
|
||||
const Vectorized<uint##bit##_t>& b) { \
|
||||
return vsubq_u##bit(a, b); \
|
||||
} \
|
||||
template <> \
|
||||
Vectorized<uint##bit##_t> inline operator&( \
|
||||
const Vectorized<uint##bit##_t>& a, \
|
||||
const Vectorized<uint##bit##_t>& b) { \
|
||||
return vandq_u##bit(a, b); \
|
||||
} \
|
||||
template <> \
|
||||
Vectorized<uint##bit##_t> inline operator|( \
|
||||
const Vectorized<uint##bit##_t>& a, \
|
||||
const Vectorized<uint##bit##_t>& b) { \
|
||||
return vorrq_u##bit(a, b); \
|
||||
} \
|
||||
template <> \
|
||||
Vectorized<uint##bit##_t> inline operator^( \
|
||||
const Vectorized<uint##bit##_t>& a, \
|
||||
const Vectorized<uint##bit##_t>& b) { \
|
||||
return veorq_u##bit(a, b); \
|
||||
} \
|
||||
Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::eq( \
|
||||
const Vectorized<uint##bit##_t>& other) const { \
|
||||
return (*this == other) & Vectorized<uint##bit##_t>(1); \
|
||||
} \
|
||||
Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::ne( \
|
||||
const Vectorized<uint##bit##_t>& other) const { \
|
||||
return (*this != other) & Vectorized<uint##bit##_t>(1); \
|
||||
} \
|
||||
Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::gt( \
|
||||
const Vectorized<uint##bit##_t>& other) const { \
|
||||
return (*this > other) & Vectorized<uint##bit##_t>(1); \
|
||||
} \
|
||||
Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::ge( \
|
||||
const Vectorized<uint##bit##_t>& other) const { \
|
||||
return (*this >= other) & Vectorized<uint##bit##_t>(1); \
|
||||
} \
|
||||
Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::lt( \
|
||||
const Vectorized<uint##bit##_t>& other) const { \
|
||||
return (*this < other) & Vectorized<uint##bit##_t>(1); \
|
||||
} \
|
||||
Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::le( \
|
||||
const Vectorized<uint##bit##_t>& other) const { \
|
||||
return (*this <= other) & Vectorized<uint##bit##_t>(1); \
|
||||
}
|
||||
|
||||
VEC_UINT_NEON_TEMPLATE(16, 8)
|
||||
|
||||
inline uint8_t Vectorized<uint8_t>::reduce_max() const {
|
||||
return vmaxvq_u8(values);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<uint8_t> inline operator*(
|
||||
const Vectorized<uint8_t>& a,
|
||||
const Vectorized<uint8_t>& b) {
|
||||
return vmulq_u8(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Vectorized<uint8_t> operator~(const Vectorized<uint8_t>& a) {
|
||||
return vmvnq_u8(a);
|
||||
}
|
||||
|
||||
inline Vectorized<uint8_t> Vectorized<uint8_t>::operator!=(
|
||||
const Vectorized<uint8_t>& other) const {
|
||||
return ~(*this == other);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<uint8_t> inline minimum(
|
||||
const Vectorized<uint8_t>& a,
|
||||
const Vectorized<uint8_t>& b) {
|
||||
return vminq_u8(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<uint8_t> inline maximum(
|
||||
const Vectorized<uint8_t>& a,
|
||||
const Vectorized<uint8_t>& b) {
|
||||
return vmaxq_u8(a, b);
|
||||
}
|
||||
|
||||
template <uint64_t mask>
|
||||
Vectorized<uint8_t> Vectorized<uint8_t>::blend(
|
||||
const Vectorized<uint8_t>& a,
|
||||
const Vectorized<uint8_t>& b) {
|
||||
// Build an array of flags: each bit of element is 1 if the corresponding bit
|
||||
// in 'mask' is set, 0 otherwise.
|
||||
uint8x16_t maskArray = {
|
||||
(mask & 1LL) ? 0xFF : 0,
|
||||
(mask & 2LL) ? 0xFF : 0,
|
||||
(mask & 4LL) ? 0xFF : 0,
|
||||
(mask & 8LL) ? 0xFF : 0,
|
||||
(mask & 16LL) ? 0xFF : 0,
|
||||
(mask & 32LL) ? 0xFF : 0,
|
||||
(mask & 64LL) ? 0xFF : 0,
|
||||
(mask & 128LL) ? 0xFF : 0,
|
||||
(mask & 256LL) ? 0xFF : 0,
|
||||
(mask & 512LL) ? 0xFF : 0,
|
||||
(mask & 1024LL) ? 0xFF : 0,
|
||||
(mask & 2048LL) ? 0xFF : 0,
|
||||
(mask & 4096LL) ? 0xFF : 0,
|
||||
(mask & 8192LL) ? 0xFF : 0,
|
||||
(mask & 16384LL) ? 0xFF : 0,
|
||||
(mask & 32768LL) ? 0xFF : 0};
|
||||
// Use BSL to select elements from b where the mask is 1, else from a
|
||||
return vbslq_u8(maskArray, b.values, a.values);
|
||||
}
|
||||
|
||||
#define VEC_UINT_NEON_OPS(vl, bit) \
|
||||
inline Vectorized<uint##bit##_t>::Vectorized(uint##bit##_t val) { \
|
||||
values = vdupq_n_u##bit(val); \
|
||||
} \
|
||||
inline Vectorized<uint##bit##_t> Vectorized<uint##bit##_t>::loadu( \
|
||||
const void* ptr, uint64_t count) { \
|
||||
if (count == size()) { \
|
||||
return vld1q_u##bit(reinterpret_cast<const uint##bit##_t*>(ptr)); \
|
||||
} else { \
|
||||
__at_align__ uint##bit##_t tmp_values[size()]; \
|
||||
for (const auto i : c10::irange(size())) { \
|
||||
tmp_values[i] = 0; \
|
||||
} \
|
||||
std::memcpy( \
|
||||
tmp_values, \
|
||||
reinterpret_cast<const uint##bit##_t*>(ptr), \
|
||||
count * sizeof(uint##bit##_t)); \
|
||||
return vld1q_u##bit(reinterpret_cast<const uint##bit##_t*>(tmp_values)); \
|
||||
} \
|
||||
} \
|
||||
inline void Vectorized<uint##bit##_t>::store(void* ptr, uint64_t count) \
|
||||
const { \
|
||||
if (count == size()) { \
|
||||
vst1q_u##bit(reinterpret_cast<uint##bit##_t*>(ptr), values); \
|
||||
} else { \
|
||||
uint##bit##_t tmp_values[size()]; \
|
||||
vst1q_u##bit(reinterpret_cast<uint##bit##_t*>(tmp_values), values); \
|
||||
std::memcpy(ptr, tmp_values, count * sizeof(uint##bit##_t)); \
|
||||
} \
|
||||
}
|
||||
|
||||
VEC_UINT_NEON_OPS(16, 8)
|
||||
|
||||
template <typename step_t>
|
||||
inline Vectorized<uint8_t> Vectorized<uint8_t>::arange(
|
||||
uint8_t base,
|
||||
step_t step) {
|
||||
const Vectorized<uint8_t> base_vec(base);
|
||||
const Vectorized<uint8_t> step_vec(step);
|
||||
const uint8x16_t step_sizes = {
|
||||
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
|
||||
return vmlaq_u8(base_vec, step_sizes, step_vec);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<uint8_t> inline operator>>(
|
||||
const Vectorized<uint8_t>& a,
|
||||
const Vectorized<uint8_t>& b) {
|
||||
uint8x16_t x = a;
|
||||
uint8x16_t bound = vdupq_n_u8(8);
|
||||
uint8x16_t z = vminq_u8(b, bound);
|
||||
return x >> z;
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<uint8_t> inline operator<<(
|
||||
const Vectorized<uint8_t>& a,
|
||||
const Vectorized<uint8_t>& b) {
|
||||
uint8x16_t bound = vdupq_n_u8(8);
|
||||
uint8x16_t z = vminq_u8(b, bound);
|
||||
return vshlq_u8(a, vreinterpretq_s8_u8(z));
|
||||
}
|
||||
|
||||
inline Vectorized<uint8_t> Vectorized<uint8_t>::set(
|
||||
const Vectorized<uint8_t>& a,
|
||||
const Vectorized<uint8_t>& b,
|
||||
uint64_t count) {
|
||||
if (count == 0) {
|
||||
return a;
|
||||
} else if (count >= 16) {
|
||||
return b;
|
||||
} else {
|
||||
// Build an array of flags: each bit of element is 1 if the corresponding
|
||||
// bit in 'mask' is set, 0 otherwise.
|
||||
uint8x16_t maskArray = {
|
||||
static_cast<uint8_t>((count >= 1LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 2LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 3LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 4LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 5LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 6LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 7LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 8LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 9LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 10LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 11LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 12LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 13LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 14LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 15LL) ? 0xFF : 0),
|
||||
0};
|
||||
|
||||
// Use BSL to select elements from b where the mask is 1, else from a
|
||||
return vbslq_u8(maskArray, b.values, a.values);
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<uint8_t> inline operator/(
|
||||
const Vectorized<uint8_t>& a,
|
||||
const Vectorized<uint8_t>& b) {
|
||||
uint8x16_t x = a;
|
||||
uint8x16_t y = b;
|
||||
return x / y;
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<uint8_t> inline clamp(
|
||||
const Vectorized<uint8_t>& a,
|
||||
const Vectorized<uint8_t>& min,
|
||||
const Vectorized<uint8_t>& max) {
|
||||
return minimum(max, maximum(min, a));
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<uint8_t> inline clamp_max(
|
||||
const Vectorized<uint8_t>& a,
|
||||
const Vectorized<uint8_t>& max) {
|
||||
return minimum(max, a);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<uint8_t> inline clamp_min(
|
||||
const Vectorized<uint8_t>& a,
|
||||
const Vectorized<uint8_t>& min) {
|
||||
return maximum(min, a);
|
||||
}
|
||||
|
||||
} // namespace CPU_CAPABILITY
|
||||
} // namespace at::vec
|
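The file removed above is the NEON Vectorized<uint8_t> specialization; its blend() builds a lane mask by expanding each bit of the compile-time mask into a full 0xFF/0x00 byte and then selects with vbslq_u8. A standalone sketch of that trick with plain NEON intrinsics (not part of the change, just an illustration of the technique):

#include <arm_neon.h>
#include <cstdint>
#include <cstdio>

// Select lane i from b when bit i of mask is set, otherwise from a, by
// widening each mask bit into an all-ones/all-zeros byte lane for vbslq_u8.
template <uint64_t mask>
uint8x16_t blend_u8(uint8x16_t a, uint8x16_t b) {
  uint8_t lanes[16];
  for (int i = 0; i < 16; ++i) {
    lanes[i] = (mask & (1ULL << i)) ? 0xFF : 0x00;  // widen bit i to a byte
  }
  const uint8x16_t lane_mask = vld1q_u8(lanes);
  return vbslq_u8(lane_mask, b, a);  // bitwise select: lane_mask ? b : a
}

int main() {
  uint8x16_t a = vdupq_n_u8(1);
  uint8x16_t b = vdupq_n_u8(9);
  uint8x16_t r = blend_u8<0x00FF>(a, b);  // low 8 lanes from b, high 8 from a
  uint8_t out[16];
  vst1q_u8(out, r);
  for (int i = 0; i < 16; ++i) {
    std::printf("%d ", out[i]);
  }
  std::printf("\n");  // prints: 9 9 9 9 9 9 9 9 1 1 1 1 1 1 1 1
  return 0;
}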
@@ -1390,7 +1390,7 @@ std::pair<Vectorized<float>, Vectorized<float>> inline convert_int8_to_float(

std::pair<Vectorized<float>, Vectorized<float>> inline convert_int8_to_float(
    at::vec::Vectorized<uint8_t> src) {
  auto u8x8 = vget_low_u8(src);
  auto u8x8 = vld1_u8(src.operator const uint8_t*());
  auto u16x8 = vmovl_u8(u8x8);
  auto u32x4_hi = vmovl_u16(vget_high_u16(u16x8));
  auto u32x4_lo = vmovl_u16(vget_low_u16(u16x8));

@@ -1412,7 +1412,7 @@ Vectorized<float> inline convert_int8_half_register_to_float(

Vectorized<float> inline convert_int8_half_register_to_float(
    at::vec::Vectorized<uint8_t> src) {
  auto u8x8 = vget_low_u8(src);
  auto u8x8 = vld1_u8(src.operator const uint8_t*());
  auto u16x8 = vmovl_u8(u8x8);
  auto u32x4_lo = vmovl_u16(vget_low_u16(u16x8));
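Both u8x8 variants shown in the hunks above feed the same widening chain: eight uint8 lanes are zero-extended to uint16, split into low/high uint32 halves, and converted to float. A self-contained sketch of that chain (assumes an AArch64 toolchain with NEON):

#include <arm_neon.h>
#include <cstdint>
#include <cstdio>

// Widen eight uint8 values to two float32x4 vectors:
// u8x8 -> u16x8 -> (low/high) u32x4 -> f32x4.
void convert_u8x8_to_float(const uint8_t* src, float* dst) {
  const uint8x8_t u8x8 = vld1_u8(src);        // load 8 bytes
  const uint16x8_t u16x8 = vmovl_u8(u8x8);    // zero-extend to u16
  const uint32x4_t u32_lo = vmovl_u16(vget_low_u16(u16x8));
  const uint32x4_t u32_hi = vmovl_u16(vget_high_u16(u16x8));
  vst1q_f32(dst, vcvtq_f32_u32(u32_lo));      // lanes 0..3
  vst1q_f32(dst + 4, vcvtq_f32_u32(u32_hi));  // lanes 4..7
}

int main() {
  const uint8_t src[8] = {0, 1, 2, 3, 250, 251, 252, 255};
  float dst[8];
  convert_u8x8_to_float(src, dst);
  for (float v : dst) {
    std::printf("%.1f ", v);  // 0.0 1.0 2.0 3.0 250.0 251.0 252.0 255.0
  }
  std::printf("\n");
  return 0;
}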
@ -272,110 +272,28 @@ cuda::blas::GEMMAndBiasActivationEpilogue activation_to_gemm_and_blas_arg(Activa
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* Checks whether DISABLE_ADDMM_CUDA_LT is set.
|
||||
* Additionally, for ROCM we test whether the architecture supports the Lt.
|
||||
*/
|
||||
static bool isGloballyDisabledAddmmCudaLt(const at::Device& device) {
|
||||
// When hipBLASLt is not supported on the architecture, return true
|
||||
#ifdef USE_ROCM
|
||||
static const std::vector<std::string> archs = {
|
||||
static bool getDisableAddmmCudaLt() {
|
||||
static const auto env_value = c10::utils::get_env("DISABLE_ADDMM_CUDA_LT");
|
||||
if (env_value == "1") {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
#ifdef USE_ROCM
|
||||
static bool isSupportedHipLtROCmArch(int index) {
|
||||
static const std::vector<std::string> archs = {
|
||||
"gfx90a", "gfx942",
|
||||
#if ROCM_VERSION >= 60300
|
||||
#if ROCM_VERSION >= 60300
|
||||
"gfx1100", "gfx1101", "gfx1200", "gfx1201", "gfx908",
|
||||
#endif
|
||||
#if ROCM_VERSION >= 70000
|
||||
#endif
|
||||
#if ROCM_VERSION >= 70000
|
||||
"gfx950", "gfx1150", "gfx1151"
|
||||
#endif
|
||||
};
|
||||
const auto is_hipblas_lt_arch_supported = at::detail::getCUDAHooks().isGPUArch(archs, device.index());
|
||||
if (!is_hipblas_lt_arch_supported) {
|
||||
return true;
|
||||
}
|
||||
#endif
|
||||
|
||||
// Check whether it is disabled in the env
|
||||
static const auto is_addmm_cuda_lt_disabled = c10::utils::get_env("DISABLE_ADDMM_CUDA_LT");
|
||||
if (is_addmm_cuda_lt_disabled == "1") {
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
/*
|
||||
* Check whether for the given input we want to enable the Lt interface
|
||||
*/
|
||||
static bool isInputCompliesAddmmCudaLt(Tensor& result, const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha) {
|
||||
// Implies 2D bias which we currently not send through Lt.
|
||||
// TODO: this check is done pre col-major input preparation,
|
||||
// so, this condition can be ralexed in cases when a col-major
|
||||
// copy of result is needed.
|
||||
if (result.is_same(self)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
#if defined(USE_ROCM) && ROCM_VERSION == 60400
|
||||
// hipblaslt TT fp32 regression on ROCm 6.4, cannot use
|
||||
const auto args = cublasCommonArgs(mat1, mat2, result);
|
||||
if (args.transa == 't' && args.transb == 't') {
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
|
||||
const auto mat1_sizes = mat1.sizes();
|
||||
const auto mat2_sizes = mat2.sizes();
|
||||
#if defined(CUDA_VERSION) || defined(USE_ROCM)
|
||||
const auto scalar_type = mat1.scalar_type();
|
||||
return (beta.toComplexDouble() == 1.0
|
||||
// self.dim() == 1 && result.dim() == 2 && self.sizes()[0] == mat2_sizes[1]
|
||||
// is to use lt interface only when self is bias.
|
||||
&& self.dim() == 1 && self.sizes()[0] == mat2_sizes[1] && self.is_contiguous()
|
||||
&& result.dim() == 2 && result.is_contiguous()
|
||||
&& ( // some dtype restrictions
|
||||
#ifndef USE_ROCM
|
||||
scalar_type == at::ScalarType::Double ||
|
||||
#endif
|
||||
scalar_type == at::ScalarType::Float ||
|
||||
scalar_type == at::ScalarType::Half ||
|
||||
scalar_type == at::ScalarType::BFloat16
|
||||
)
|
||||
&& ( // some shape/stride restrictions
|
||||
// Strangely, if mat2 has only 1 row or column, we get
|
||||
// CUBLAS_STATUS_INVALID_VALUE error from cublasLtMatmulAlgoGetHeuristic.
|
||||
// NOTE: extension to mat1 because mat1/mat2 can be swapped based off
|
||||
// their row-/col-majorness.
|
||||
mat1_sizes[0] > 1 && mat1_sizes[1] > 1 &&
|
||||
mat2_sizes[0] > 1 && mat2_sizes[1] > 1
|
||||
// The last conditions is to skip 16b transA and non-trans-B having
|
||||
// leading dim >> rows when they are sliced from a large tensor
|
||||
// see fbcode/caffe2/test/test_linalg.py:test_corner_cases_of_cublasltmatmul
|
||||
#if !(defined(CUDA_VERSION) && CUDA_VERSION >= 12010 || defined(USE_ROCM))
|
||||
// Related to avoiding the leading stride >> leading dim problematic case
|
||||
// with 16b dtypes described above. For such dtypes we only allow inputs
|
||||
// which are either row- or col-major (i.e. non-overlapping, compact memory layout).
|
||||
// In that case the leading stride will be equal to the outer dim len.
|
||||
// Why do we catch this case here? The following `prepare_matrix_for_cublas` method
|
||||
// does not modify inputs as long as there is a stride of length 1
|
||||
// and the leading stride is at least max(1, other dim length), so we might
|
||||
// end up with contiguous cols but not rows (i.e. holes between different rows)
|
||||
// and vice versa.
|
||||
mat2_sizes[0] < 65535 * 32 && mat2_sizes[1] < 65535 * 32 &&
|
||||
mat1_sizes[0] < 65535 * 32 && mat1_sizes[1] < 65535 * 32 &&
|
||||
&& (
|
||||
// filter by dtype
|
||||
(scalar_type != at::ScalarType::Half && scalar_type != at::ScalarType::BFloat16) ||
|
||||
// check mat1/mat2 is row-/col-major
|
||||
(mat1.is_non_overlapping_and_dense() && mat2.is_non_overlapping_and_dense())
|
||||
)
|
||||
#endif
|
||||
)
|
||||
);
|
||||
#endif
|
||||
|
||||
// no compliance by default
|
||||
return false;
|
||||
#endif
|
||||
};
|
||||
return at::detail::getCUDAHooks().isGPUArch(archs, index);
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename scalar_t>
|
||||
void launchTunableGemmAndBias(cublasCommonArgs &args, const Scalar& alpha, const scalar_t* bias, cuda::blas::GEMMAndBiasActivationEpilogue activation) {
|
||||
@ -417,70 +335,7 @@ void launchTunableGemmAndBias(cublasCommonArgs &args, const Scalar& alpha, const
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t, typename res_scalar_t = scalar_t>
|
||||
bool launchGemmAndBiasCublasLt(
|
||||
// args contains result which is modified
|
||||
cublasCommonArgs& args,
|
||||
const Tensor& self,
|
||||
const Scalar& alpha,
|
||||
Activation activation = Activation::None
|
||||
) {
|
||||
const auto* self_ptr = self.const_data_ptr<scalar_t>();
|
||||
|
||||
const auto tuning_ctx = at::cuda::tunable::getTuningContext();
|
||||
if (tuning_ctx->IsTunableOpEnabled()) {
|
||||
// TODO: maybe also return some success state?
|
||||
launchTunableGemmAndBias<scalar_t>(
|
||||
args, alpha, self_ptr, activation_to_gemm_and_blas_arg(activation)
|
||||
);
|
||||
return true;
|
||||
}
|
||||
|
||||
return at::cuda::blas::gemm_and_bias<scalar_t, res_scalar_t>(
|
||||
args.transa == 't',
|
||||
args.transb == 't',
|
||||
args.m,
|
||||
args.n,
|
||||
args.k,
|
||||
alpha.to<at::opmath_type<scalar_t>>(),
|
||||
args.mata->const_data_ptr<scalar_t>(),
|
||||
args.lda,
|
||||
args.matb->const_data_ptr<scalar_t>(),
|
||||
args.ldb,
|
||||
self_ptr,
|
||||
args.result->data_ptr<res_scalar_t>(),
|
||||
args.result_ld,
|
||||
activation_to_gemm_and_blas_arg(activation)
|
||||
);
|
||||
}
|
||||
|
||||
template <typename scalar_t, typename res_scalar_t = scalar_t>
|
||||
bool launchGemmCublas(
|
||||
// args contains result which is modified
|
||||
cublasCommonArgs& args,
|
||||
const Scalar& alpha,
|
||||
const Scalar& beta
|
||||
) {
|
||||
at::cuda::blas::gemm<scalar_t, res_scalar_t>(
|
||||
args.transa,
|
||||
args.transb,
|
||||
args.m,
|
||||
args.n,
|
||||
args.k,
|
||||
alpha.to<at::opmath_type<scalar_t>>(),
|
||||
args.mata->const_data_ptr<scalar_t>(),
|
||||
args.lda,
|
||||
args.matb->const_data_ptr<scalar_t>(),
|
||||
args.ldb,
|
||||
beta.to<at::opmath_type<scalar_t>>(),
|
||||
args.result->data_ptr<res_scalar_t>(),
|
||||
args.result_ld
|
||||
);
|
||||
return true; // success!
|
||||
}
|
||||
|
||||
Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha, Activation activation=Activation::None, bool disable_addmm_cuda_lt_override=false) {
|
||||
// Shape checks {
|
||||
// Make sure to keep addmm_cuda below in sync with this code; it
|
||||
// preflights a check to try to avoid actually needing to call
|
||||
// expand().
|
||||
@ -490,62 +345,105 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
|
||||
"expected mat1 and mat2 to have the same dtype, but got: ", mat1.dtype(), " != ", mat2.dtype()
|
||||
)
|
||||
|
||||
if (result.is_same(self)) {
|
||||
TORCH_CHECK(result.dim() == 2, "tensors must be 2-D");
|
||||
TORCH_CHECK(self.sizes()[0] == mat1.sizes()[0], "self dim 0 must match mat1 dim 0");
|
||||
TORCH_CHECK(self.sizes()[1] == mat2.sizes()[1], "self dim 1 must match mat2 dim 1");
|
||||
}
|
||||
// } Shape checks
|
||||
|
||||
// NOLINTNEXTLINE(*c-array*)
|
||||
TensorArg targs[]{{result, "out", 0}, {self, "self", 1}, {mat1, "mat1", 2}, {mat2, "mat2", 3}};
|
||||
checkAllSameGPU(__func__, targs);
|
||||
|
||||
// Handle whether to use the Lt interface {
|
||||
static bool persistent_disable_addmm_cuda_lt = isGloballyDisabledAddmmCudaLt(self.device());
|
||||
IntArrayRef mat1_sizes = mat1.sizes();
|
||||
IntArrayRef mat2_sizes = mat2.sizes();
|
||||
IntArrayRef self__sizes;
|
||||
bool useLtInterface = false;
|
||||
#if defined(USE_ROCM)
|
||||
// When hipBLASLt is not supported on the architecture,
|
||||
// disable_addmm_cuda_lt will always be to set to true
|
||||
static bool disable_addmm_cuda_lt =
|
||||
!isSupportedHipLtROCmArch(self.device().index()) || getDisableAddmmCudaLt();
|
||||
#else
|
||||
static bool disable_addmm_cuda_lt = getDisableAddmmCudaLt();
|
||||
#endif
|
||||
// if lt path fails, we recurse back into this function here and force the lt path to off
|
||||
// we cannot update varible disable_addmm_cuda_lt from above since it is static and would be permanent
|
||||
bool disable_addmm_cuda_lt = persistent_disable_addmm_cuda_lt || disable_addmm_cuda_lt_override;
|
||||
#ifdef USE_ROCM
|
||||
// Conditioned on the device index, which is not persistent
|
||||
disable_addmm_cuda_lt = isGloballyDisabledAddmmCudaLt(self.device()) || disable_addmm_cuda_lt;
|
||||
#endif
|
||||
// Condition on the input
|
||||
disable_addmm_cuda_lt = !isInputCompliesAddmmCudaLt(result, self, mat1, mat2, beta, alpha) || disable_addmm_cuda_lt;
|
||||
// }
|
||||
|
||||
bool disable_addmm_cuda_lt_final = disable_addmm_cuda_lt || disable_addmm_cuda_lt_override;
|
||||
#if defined(USE_ROCM) && ROCM_VERSION == 60400
|
||||
// hipblaslt TT fp32 regression on ROCm 6.4, cannot use
|
||||
cublasCommonArgs _args(mat1, mat2, result);
|
||||
if (_args.transa == 't' && _args.transb == 't') {
|
||||
disable_addmm_cuda_lt_final = true;
|
||||
}
|
||||
#endif
|
||||
at::ScalarType scalar_type = mat1.scalar_type();
|
||||
bool is_float_output_with_half_input = (scalar_type == at::ScalarType::Half || scalar_type == at::ScalarType::BFloat16) && result.scalar_type() == at::ScalarType::Float;
|
||||
c10::MaybeOwned<Tensor> self_;
|
||||
if (&result != &self) {
|
||||
#if defined(CUDA_VERSION) || defined(USE_ROCM)
|
||||
// Strangely, if mat2 has only 1 row or column, we get
|
||||
// CUBLAS_STATUS_INVALID_VALUE error from cublasLtMatmulAlgoGetHeuristic.
|
||||
// self.dim() == 1 && result.dim() == 2 && self.sizes()[0] == mat2_sizes[1]
|
||||
// is to use lt interface only when self is bias.
|
||||
// for cuda 11.4, cublasLtMatmul is activated
|
||||
// the last two conditions is to skip 16b transA and non-trans-B having
|
||||
// leading dim >> rows when they are sliced from a large tensor
|
||||
// see fbcode/caffe2/test/test_linalg.py:test_corner_cases_of_cublasltmatmul
|
||||
if (!disable_addmm_cuda_lt_final) {
|
||||
useLtInterface = beta.toComplexDouble() == 1.0 && self.dim() == 1 &&
|
||||
result.dim() == 2 && self.sizes()[0] == mat2_sizes[1] &&
|
||||
self.is_contiguous() && result.is_contiguous() &&
|
||||
#ifdef USE_ROCM
|
||||
(scalar_type == at::ScalarType::Float ||
|
||||
scalar_type == at::ScalarType::Half ||
|
||||
scalar_type == at::ScalarType::BFloat16) &&
|
||||
#else
|
||||
(scalar_type == at::ScalarType::Double ||
|
||||
scalar_type == at::ScalarType::Float ||
|
||||
scalar_type == at::ScalarType::Half ||
|
||||
scalar_type == at::ScalarType::BFloat16) &&
|
||||
#endif
|
||||
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12010 || defined(USE_ROCM))
|
||||
mat2_sizes[0] > 1 && mat2_sizes[1] > 1;
|
||||
#else
|
||||
mat2_sizes[0] > 1 && mat2_sizes[1] > 1 &&
|
||||
mat2_sizes[0] < 65535 * 32 && mat2_sizes[1] < 65535 * 32 &&
|
||||
mat1_sizes[0] < 65535 * 32 && mat1_sizes[1] < 65535 * 32 &&
|
||||
// avoid leading dim >> rows bugs
|
||||
((mat1.strides()[0] == 1 && mat1.strides()[1] == mat1_sizes[0]) ||
|
||||
(mat1.strides()[1] == 1 && mat1.strides()[0] == mat1_sizes[1]) ||
|
||||
(scalar_type != at::ScalarType::Half &&
|
||||
scalar_type != at::ScalarType::BFloat16)) &&
|
||||
((mat2.strides()[0] == 1 && mat2.strides()[1] == mat2_sizes[0]) ||
|
||||
(mat2.strides()[1] == 1 && mat2.strides()[0] == mat2_sizes[1]) ||
|
||||
(scalar_type != at::ScalarType::Half &&
|
||||
scalar_type != at::ScalarType::BFloat16));
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
if (!useLtInterface) {
|
||||
self_ = expand_size(self, {mat1_sizes[0], mat2_sizes[1]}, "addmm");
|
||||
}
|
||||
self__sizes = self_->sizes();
|
||||
} else {
|
||||
self_ = c10::MaybeOwned<Tensor>::borrowed(self);
|
||||
self__sizes = self_->sizes();
|
||||
TORCH_CHECK(result.dim() == 2, "tensors must be 2-D");
|
||||
TORCH_CHECK(self__sizes[0] == mat1_sizes[0], "self_ dim 0 must match mat1 dim 0");
|
||||
TORCH_CHECK(self__sizes[1] == mat2_sizes[1], "self_ dim 1 must match mat2 dim 1");
|
||||
}
|
||||
|
||||
// Handle result/self shapes
|
||||
if (!result.is_same(self)) {
|
||||
at::native::resize_output(result, {mat1.sizes()[0], mat2.sizes()[1]});
|
||||
|
||||
const auto self_maybe_expanded = [&]() -> c10::MaybeOwned<Tensor> {
|
||||
if (disable_addmm_cuda_lt) {
|
||||
// When in non-Lt path we do expand self even before
|
||||
// check for beta != 0.0 to make sure that
|
||||
// test_sparse_csr.py::TestSparseCSRCUDA::test_addmm_errors_*
|
||||
// runs green.
|
||||
return expand_size(self, result.sizes(), "addmm");
|
||||
}
|
||||
// copy next, should broadcast
|
||||
return c10::MaybeOwned<Tensor>::borrowed(self);
|
||||
}();
|
||||
// We copy bias when in the non-Lt path
|
||||
if (beta.toComplexDouble() != 0.0 && disable_addmm_cuda_lt) {
|
||||
// NOTE: self should broadcast over result
|
||||
at::native::copy_(result, *self_maybe_expanded);
|
||||
if (&result != &self) {
|
||||
at::native::resize_output(result, {mat1_sizes[0], mat2_sizes[1]});
|
||||
if (beta.toComplexDouble() != 0.0 && !useLtInterface) {
|
||||
at::native::copy_(result, *self_);
|
||||
}
|
||||
}
|
||||
|
||||
// Short circuit on empty result
|
||||
if (result.numel() == 0) {
|
||||
|
||||
IntArrayRef result_sizes = result.sizes();
|
||||
if ((result_sizes[0] == 0) || (result_sizes[1] == 0)) {
|
||||
return result;
|
||||
}
|
||||
|
||||
// Short circuit if the reduction dim is empty
|
||||
if (mat1.sizes()[1] == 0) {
|
||||
cublasCommonArgs args(mat1, mat2, result);
|
||||
|
||||
if (mat1.numel() == 0) {
|
||||
// By definition, when beta==0, values in self should be ignored. nans and infs
|
||||
// should not propagate
|
||||
if (beta.toComplexDouble() == 0.) {
|
||||
@ -557,64 +455,158 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
|
||||
result,
|
||||
self.expand(result.sizes()),
|
||||
at::native::scalar_tensor(
|
||||
beta,
|
||||
self.scalar_type(),
|
||||
std::nullopt /* layout */,
|
||||
at::kCPU,
|
||||
std::nullopt /* pin_memory */
|
||||
)
|
||||
);
|
||||
beta,
|
||||
self.scalar_type(),
|
||||
std::nullopt /* layout */,
|
||||
at::kCPU,
|
||||
std::nullopt /* pin_memory */));
|
||||
}
|
||||
|
||||
cublasCommonArgs args(mat1, mat2, result);
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!args.result->is_conj());
|
||||
|
||||
// The Lt path
|
||||
if (!disable_addmm_cuda_lt) {
|
||||
bool lt_success = false;
|
||||
if (useLtInterface) {
|
||||
#if defined(USE_ROCM)
|
||||
bool okay = true;
|
||||
if (is_float_output_with_half_input) {
|
||||
#ifdef USE_ROCM
|
||||
TORCH_CHECK(false, "float output with half input is not enabled for ROCm");
|
||||
#else
|
||||
if (at::cuda::tunable::getTuningContext()->IsTunableOpEnabled()) {
|
||||
TORCH_CHECK(false, "Tunable GEMM is not supported for float output with reduced float input");
|
||||
}
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(
|
||||
scalar_type,
|
||||
"addmm_cuda_lt",
|
||||
[&] {
|
||||
lt_success = launchGemmAndBiasCublasLt<scalar_t, float>(args, self, alpha, activation);
|
||||
}
|
||||
);
|
||||
#endif
|
||||
} else {
|
||||
// !is_float_output_with_half_input
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(
|
||||
at::ScalarType::Half,
|
||||
at::ScalarType::BFloat16,
|
||||
scalar_type,
|
||||
"addmm_cuda_lt",
|
||||
[&] {
|
||||
lt_success = launchGemmAndBiasCublasLt<scalar_t>(args, self, alpha, activation);
|
||||
auto tuning_ctx = at::cuda::tunable::getTuningContext();
|
||||
if (tuning_ctx->IsTunableOpEnabled()) {
|
||||
launchTunableGemmAndBias<scalar_t>(
|
||||
args,
|
||||
alpha,
|
||||
(&result != &self) ? self.const_data_ptr<scalar_t>() : nullptr,
|
||||
activation_to_gemm_and_blas_arg(activation));
|
||||
} else {
|
||||
okay = at::cuda::blas::gemm_and_bias<scalar_t>(
|
||||
args.transa == 't',
|
||||
args.transb == 't',
|
||||
args.m,
|
||||
args.n,
|
||||
args.k,
|
||||
alpha.to<at::opmath_type<scalar_t>>(),
|
||||
args.mata->const_data_ptr<scalar_t>(),
|
||||
args.lda,
|
||||
args.matb->const_data_ptr<scalar_t>(),
|
||||
args.ldb,
|
||||
// This condition is needed for mm case on ROCm for hipblasLt path.
|
||||
// Passing the bias ptr as null to avoid accuracy issues for mm case.
|
||||
(&result != &self) ? self.const_data_ptr<scalar_t>() : nullptr,
|
||||
args.result->data_ptr<scalar_t>(),
|
||||
args.result_ld,
|
||||
activation_to_gemm_and_blas_arg(activation)
|
||||
);
|
||||
}
|
||||
);
|
||||
} // end is_float_output_with_half_input
|
||||
|
||||
if (!lt_success) {
|
||||
// lt path failed; recurse but disable lt path
|
||||
});
|
||||
}
|
||||
if (!okay) {
|
||||
// lt path failed; recurse but disable lt path
|
||||
return addmm_out_cuda_impl(result, self, mat1, mat2, beta, alpha, activation, true);
|
||||
}
|
||||
// end Lt path
|
||||
} else {
|
||||
// No Lt, we use a GEMM instead
|
||||
#else
|
||||
auto activation_epilogue = activation_to_gemm_and_blas_arg(activation);
|
||||
bool okay = true;
|
||||
if (is_float_output_with_half_input) {
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(
|
||||
scalar_type,
|
||||
"addmm_cuda_lt",
|
||||
[&] {
|
||||
auto tuning_ctx = at::cuda::tunable::getTuningContext();
|
||||
if (tuning_ctx->IsTunableOpEnabled()) {
|
||||
TORCH_CHECK(false, "Tunable GEMM is not supported for float output with reduced float input");
|
||||
}
|
||||
else {
|
||||
okay = at::cuda::blas::gemm_and_bias<scalar_t, float>(
|
||||
args.transa == 't',
|
||||
args.transb == 't',
|
||||
args.m,
|
||||
args.n,
|
||||
args.k,
|
||||
alpha.to<at::opmath_type<scalar_t>>(),
|
||||
args.mata->const_data_ptr<scalar_t>(),
|
||||
args.lda,
|
||||
args.matb->const_data_ptr<scalar_t>(),
|
||||
args.ldb,
|
||||
self.const_data_ptr<scalar_t>(),
|
||||
args.result->data_ptr<float>(),
|
||||
args.result_ld,
|
||||
activation_epilogue
|
||||
);
|
||||
}});
|
||||
} else {
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(
|
||||
at::ScalarType::Half,
|
||||
at::ScalarType::BFloat16,
|
||||
scalar_type,
|
||||
"addmm_cuda_lt",
|
||||
[&] {
|
||||
auto tuning_ctx = at::cuda::tunable::getTuningContext();
|
||||
if (tuning_ctx->IsTunableOpEnabled()) {
|
||||
launchTunableGemmAndBias<scalar_t>(
|
||||
args,
|
||||
alpha,
|
||||
self.const_data_ptr<scalar_t>(),
|
||||
activation_epilogue);
|
||||
}
|
||||
else {
|
||||
okay = at::cuda::blas::gemm_and_bias<scalar_t>(
|
||||
args.transa == 't',
|
||||
args.transb == 't',
|
||||
args.m,
|
||||
args.n,
|
||||
args.k,
|
||||
alpha.to<at::opmath_type<scalar_t>>(),
|
||||
args.mata->const_data_ptr<scalar_t>(),
|
||||
args.lda,
|
||||
args.matb->const_data_ptr<scalar_t>(),
|
||||
args.ldb,
|
||||
self.const_data_ptr<scalar_t>(),
|
||||
args.result->data_ptr<scalar_t>(),
|
||||
args.result_ld,
|
||||
activation_epilogue
|
||||
);
|
||||
}});
|
||||
}
|
||||
if (!okay) {
|
||||
// lt path failed; recurse but disable lt path
|
||||
return addmm_out_cuda_impl(result, self, mat1, mat2, beta, alpha, activation, true);
|
||||
}
|
||||
#endif
|
||||
} else
|
||||
{
|
||||
if (is_float_output_with_half_input) {
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(
|
||||
scalar_type,
|
||||
"addmm_cuda",
|
||||
[&] {
|
||||
launchGemmCublas<scalar_t, float>(args, alpha, beta);
|
||||
}
|
||||
);
|
||||
using opmath_t = at::opmath_type<scalar_t>;
|
||||
opmath_t alpha_val = alpha.to<opmath_t>();
|
||||
opmath_t beta_val = beta.to<opmath_t>();
|
||||
const scalar_t* mat1_ptr = args.mata->const_data_ptr<scalar_t>();
|
||||
const scalar_t* mat2_ptr = args.matb->const_data_ptr<scalar_t>();
|
||||
|
||||
float* result_ptr = args.result->mutable_data_ptr<float>();
|
||||
at::cuda::blas::gemm<scalar_t, float>(
|
||||
args.transa,
|
||||
args.transb,
|
||||
args.m,
|
||||
args.n,
|
||||
args.k,
|
||||
alpha_val,
|
||||
mat1_ptr,
|
||||
args.lda,
|
||||
mat2_ptr,
|
||||
args.ldb,
|
||||
beta_val,
|
||||
result_ptr,
|
||||
args.result_ld);
|
||||
});
|
||||
} else {
|
||||
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
|
||||
at::ScalarType::Half,
|
||||
@ -622,12 +614,28 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
|
||||
scalar_type,
|
||||
"addmm_cuda",
|
||||
[&] {
|
||||
launchGemmCublas<scalar_t>(args, alpha, beta);
|
||||
}
|
||||
);
|
||||
using opmath_t = at::opmath_type<scalar_t>;
|
||||
opmath_t alpha_val = alpha.to<opmath_t>();
|
||||
opmath_t beta_val = beta.to<opmath_t>();
|
||||
const scalar_t* mat1_ptr = args.mata->const_data_ptr<scalar_t>();
|
||||
const scalar_t* mat2_ptr = args.matb->const_data_ptr<scalar_t>();
|
||||
scalar_t* result_ptr = args.result->mutable_data_ptr<scalar_t>();
|
||||
at::cuda::blas::gemm<scalar_t>(
|
||||
args.transa,
|
||||
args.transb,
|
||||
args.m,
|
||||
args.n,
|
||||
args.k,
|
||||
alpha_val,
|
||||
mat1_ptr,
|
||||
args.lda,
|
||||
mat2_ptr,
|
||||
args.ldb,
|
||||
beta_val,
|
||||
result_ptr,
|
||||
args.result_ld);
|
||||
});
|
||||
}
|
||||
|
||||
// Apply epilogue
|
||||
switch (activation) {
|
||||
case Activation::RELU:
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
|
||||
@ -639,14 +647,14 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
|
||||
break;
|
||||
default: break;
|
||||
}
|
||||
} // end GEMM path
|
||||
}
|
||||
|
||||
// Preprocessor gate here needs to match the inverse of the check
|
||||
// gating activation_to_gemm_and_blas_arg above; here we are manually
|
||||
// performing a post-GELU because we weren't able to use the GELU
|
||||
// epilogue above.
|
||||
#if !defined(CUDA_VERSION) && !defined(USE_ROCM)
|
||||
if (!disable_addmm_cuda_lt && activation == Activation::GELU) {
|
||||
if (useLtInterface && activation == Activation::GELU) {
|
||||
at::gelu_(const_cast<Tensor&>(*args.result), "tanh");
|
||||
}
|
||||
#endif
|
||||
|
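A thread running through the whole addmm refactor above: the cuBLASLt epilogue path is tried first, and if gemm_and_bias reports failure the function re-enters itself with the Lt override set, so the plain GEMM branch plus a manual activation epilogue runs instead. A reduced sketch of that fallback shape (hypothetical helper names, not the real ATen signatures):

// Reduced control flow of addmm_out_cuda_impl's Lt fallback: try the fused
// gemm+bias(+activation) epilogue, and on failure re-enter with the Lt path
// forced off so the plain GEMM branch runs.
struct Args { /* problem description: shapes, strides, pointers, ... */ };

bool run_gemm_and_bias_lt(const Args&);       // hypothetical: fused Lt epilogue
void run_plain_gemm(const Args&);             // hypothetical: plain cublas gemm
void apply_activation_epilogue(const Args&);  // e.g. relu_/gelu_ on the result

void addmm_impl(const Args& args, bool disable_lt_override = false) {
  // The real code also folds in env, arch, and input-compliance checks here.
  const bool use_lt = !disable_lt_override;
  if (use_lt) {
    if (!run_gemm_and_bias_lt(args)) {
      // Lt path failed; recurse once with the Lt path forced off.
      addmm_impl(args, /*disable_lt_override=*/true);
    }
    return;
  }
  run_plain_gemm(args);
  apply_activation_epilogue(args);  // activation applied manually in non-Lt path
}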
@@ -23,7 +23,7 @@ namespace at::native {

// The maximum number of threads in a block
#if defined(USE_ROCM)
constexpr int MAX_BLOCK_SIZE = 1024;
constexpr int MAX_BLOCK_SIZE = 256;
#else
constexpr int MAX_BLOCK_SIZE = 512;
#endif

@@ -33,7 +33,7 @@ constexpr unsigned MAX_GRID_SIZE = 65535u;
// Number of threads in a block given an input size up to MAX_BLOCK_SIZE
static int getNumThreads(int nElem) {
#if defined(USE_ROCM)
  int threadSizes[5] = { 64, 128, 256, 512, MAX_BLOCK_SIZE };
  int threadSizes[5] = { 16, 32, 64, 128, MAX_BLOCK_SIZE };
#else
  int threadSizes[5] = { 32, 64, 128, 256, MAX_BLOCK_SIZE };
#endif
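Only the ROCm thread-count table and its cap change in this hunk. Assuming the usual pick-the-first-bucket-that-fits behaviour of getNumThreads (its body is not shown here), the selection looks like this on the host:

#include <cstdio>

constexpr int MAX_BLOCK_SIZE = 256;  // one of the ROCm values in the hunk above

// Pick the smallest thread-count bucket that covers nElem, capping at
// MAX_BLOCK_SIZE (assumed behaviour; only the table appears in the diff).
static int getNumThreads(int nElem) {
  const int threadSizes[5] = {16, 32, 64, 128, MAX_BLOCK_SIZE};
  for (int size : threadSizes) {
    if (nElem <= size) {
      return size;
    }
  }
  return MAX_BLOCK_SIZE;
}

int main() {
  std::printf("%d %d %d\n", getNumThreads(10), getNumThreads(100), getNumThreads(5000));
  // prints: 16 128 256
  return 0;
}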
@@ -92,16 +92,6 @@ inline thrust::pair<int64_t, int64_t> get_index_mapping2d(
    output_offset + output_y * output_dim_x + output_x);
}

__device__ __forceinline__ int64_t reflect_index(int64_t x, int64_t len) {
  const int64_t two = (len - 1) * 2;
  if (two <= 0) {
    return 0;
  }
  int64_t m = x % two;
  if (m < 0) m += two;
  return (m < len) ? m : (two - m);
}

template<typename scalar_t>
__global__ void reflection_pad1d_out_kernel(
    const scalar_t * input, scalar_t * output,

@@ -116,28 +106,6 @@ __global__ void reflection_pad1d_out_kernel(
  }
}

template <typename scalar_t>
__global__ void reflection_pad1d_flat(
    const scalar_t* __restrict__ input,
    scalar_t* __restrict__ output,
    int64_t input_w, int64_t pad_l, int64_t pad_r,
    int64_t out_w, int64_t plane_count) {

  const int64_t bx = blockDim.x;
  const int64_t tx = threadIdx.x;

  const int64_t total = plane_count * out_w;
  const int64_t grid_stride = static_cast<int64_t>(bx) * gridDim.x;
  int64_t linear = static_cast<int64_t>(blockIdx.x) * bx + tx;

  for (; linear < total; linear += grid_stride) {
    const int64_t plane = linear / out_w;
    const int64_t x = linear - plane * out_w;
    const int64_t j = reflect_index(x - pad_l, input_w);
    output[plane * out_w + x] = input[plane * input_w + j];
  }
}

template <typename scalar_t>
__global__ void reflection_pad1d_backward_out_kernel(
    scalar_t * grad_input, const scalar_t * grad_output,

@@ -742,44 +710,25 @@ TORCH_IMPL_FUNC(reflection_pad1d_out_cuda)
  int64_t input_w = input_.size(dim_w);
  int64_t output_w = input_w + pad_l + pad_r;

  dim3 block_size(output_w > 256 ? 256 : output_w);
  dim3 grid_size((int)::ceil(output_w / 256.0), nplane, nbatch);

  Tensor input = input_.contiguous();

  const int block_x = static_cast<int>(std::min<int64_t>(256, std::max<int64_t>(1, output_w)));
  const cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
  const int max_x = prop->maxGridSize[0];
  const int max_y = prop->maxGridSize[1];
  const int max_z = prop->maxGridSize[2];

  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBFloat16, input.scalar_type(), "reflection_pad1d_out", [&] {
    auto stream = at::cuda::getCurrentCUDAStream();

    const int64_t gx = at::ceil_div(output_w, static_cast<int64_t>(block_x));

    const bool fits3d = (nplane <= max_y) && (nbatch <= max_z) && (gx <= max_x);

    if (fits3d) {
      dim3 block(block_x, 1, 1);
      dim3 grid(gx, static_cast<unsigned>(nplane), static_cast<unsigned>(nbatch));
      reflection_pad1d_out_kernel<scalar_t><<<grid, block, 0, stream>>>(
          input.const_data_ptr<scalar_t>(),
          output.mutable_data_ptr<scalar_t>(),
          input_w, pad_l, pad_r);
    } else {
      dim3 block(block_x, 1, 1);
      const int64_t plane_count = nplane * nbatch;
      const int64_t total_blocks = at::ceil_div(plane_count * output_w, static_cast<int64_t>(block_x));
      const int grid_x = static_cast<int>(std::min<int64_t>(max_x, std::max<int64_t>(1, total_blocks)));
      dim3 grid(grid_x, 1, 1);

      reflection_pad1d_flat<scalar_t><<<grid, block, 0, stream>>>(
          input.const_data_ptr<scalar_t>(),
          output.mutable_data_ptr<scalar_t>(),
          input_w, pad_l, pad_r, output_w, plane_count);
    }

    C10_CUDA_KERNEL_LAUNCH_CHECK();
  });
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
      kHalf, kBFloat16, input.scalar_type(), "reflection_pad1d_out_template", [&] {
        reflection_pad1d_out_kernel<<<
            grid_size,
            block_size,
            0,
            at::cuda::getCurrentCUDAStream()>>>(
            input.const_data_ptr<scalar_t>(),
            output.mutable_data_ptr<scalar_t>(),
            input_w,
            pad_l,
            pad_r);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      });
}

TORCH_IMPL_FUNC(reflection_pad1d_backward_out_cuda)(const Tensor& grad_output_,
@ -52,7 +52,7 @@ struct FusedAdagradMathFunctor {
|
||||
using opmath_t = at::opmath_type<scalar_t>;
|
||||
|
||||
C10_DEVICE __forceinline__ void operator()(
|
||||
int64_t chunk_size,
|
||||
int chunk_size,
|
||||
FusedOptimizerTensorListMetadata<3>& tl,
|
||||
const float* lr_ptr,
|
||||
const double& lr,
|
||||
@ -133,4 +133,4 @@ struct FusedAdagradMathFunctor {
|
||||
|
||||
} // namespace
|
||||
|
||||
} // namespace at::native
|
||||
} // namespace at::native
|
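A note on the hunk above: the flat kernel exists for the case where the 3D launch would not fit the device grid limits (the `fits3d` check); it then walks all `plane_count * out_w` output elements with a grid-stride loop and reflects out-of-range columns back into the input row. A minimal Python sketch of that index math, for illustration only — the reflection period `two = 2 * (length - 1)` and the `length <= 1` early return are assumptions based on the visible tail of `reflect_index`:

def reflect_index(x: int, length: int) -> int:
    # Reflect an out-of-range column x back into [0, length), mirroring the
    # device-side helper. Assumed period: two = 2 * (length - 1).
    two = 2 * (length - 1)
    if two <= 0:
        return 0
    m = x % two          # Python's % is already non-negative; the CUDA code
    if m < 0:            # needs the explicit fix-up, kept here only for parity.
        m += two
    return m if m < length else two - m

def reflection_pad1d_flat(inp, input_w, pad_l, pad_r):
    # Flattened (plane, column) traversal, analogous to the kernel's
    # grid-stride loop: every output element is computed independently.
    out_w = input_w + pad_l + pad_r
    plane_count = len(inp) // input_w
    out = [0] * (plane_count * out_w)
    for linear in range(plane_count * out_w):
        plane, x = divmod(linear, out_w)
        j = reflect_index(x - pad_l, input_w)
        out[plane * out_w + x] = inp[plane * input_w + j]
    return out

# One plane [1, 2, 3, 4] padded by (2, 2) -> [3, 2, 1, 2, 3, 4, 3, 2]
assert reflection_pad1d_flat([1, 2, 3, 4], 4, 2, 2) == [3, 2, 1, 2, 3, 4, 3, 2]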
@ -1,8 +1,8 @@
add_loop_eager,compile_time_instruction_count,3184000000,0.1
add_loop_eager,compile_time_instruction_count,3070000000,0.1

add_loop_eager_dynamic,compile_time_instruction_count,4595000000,0.1
add_loop_eager_dynamic,compile_time_instruction_count,4432000000,0.1

@ -18,7 +18,7 @@ add_loop_inductor_gpu,compile_time_instruction_count,26800000000,0.1

basic_modules_ListOfLinears_eager,compile_time_instruction_count,1096000000,0.1
basic_modules_ListOfLinears_eager,compile_time_instruction_count,1048000000,0.1

@ -26,7 +26,7 @@ basic_modules_ListOfLinears_inductor,compile_time_instruction_count,15240000000,

basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17720000000,0.1
basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17020000000,0.1

@ -34,11 +34,11 @@ basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,11090000

update_hint_regression,compile_time_instruction_count,1645000000,0.1
update_hint_regression,compile_time_instruction_count,1719000000,0.1

sum_floordiv_regression,compile_time_instruction_count,3813000000,0.1
sum_floordiv_regression,compile_time_instruction_count,3686995725,0.1

@ -50,31 +50,31 @@ symint_sum_loop,compile_time_instruction_count,4299000000,0.1

aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,1793000000,0.1
aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,1869000000,0.1

aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5120000000,0.1
aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5281000000,0.1

aotdispatcher_partitioner_cpu,compile_time_instruction_count,7936000000,0.1
aotdispatcher_partitioner_cpu,compile_time_instruction_count,8333000000,0.1

aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1848000000,0.1
aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1909000000,0.1

aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3152000000,0.1
aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3442000000,0.1

aotdispatcher_training_subclass_cpu,compile_time_instruction_count,8301000000,0.1
aotdispatcher_training_subclass_cpu,compile_time_instruction_count,9239000000,0.1

mm_loop_inductor_gpu,compile_time_instruction_count,4958000000,0.1
mm_loop_inductor_gpu,compile_time_instruction_count,4820968837,0.1

@ -82,8 +82,8 @@ mm_loop_inductor_dynamic_gpu,compile_time_instruction_count,9051000000,0.1

basic_NestedModule_eager,compile_time_instruction_count,9990000000,0.1
basic_NestedModule_eager,compile_time_instruction_count,9554000000,0.1

basic_InlineMod_eager,compile_time_instruction_count,8126000000,0.1
basic_InlineMod_eager,compile_time_instruction_count,7618000000,0.1
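The rows above come in old/new pairs: benchmark name, metric (compile-time instruction count), expected value, and what appears to be a relative noise tolerance (0.1, i.e. roughly ±10%). A small, hypothetical reader for rows in this shape — the column semantics are inferred from context, not documented here:

import csv
import io

def load_expected(text: str) -> dict:
    # Parse "name,metric,value,tolerance" rows into a lookup table.
    rows = {}
    for name, metric, value, tol in csv.reader(io.StringIO(text)):
        rows[(name, metric)] = (int(value), float(tol))
    return rows

def within_budget(measured: int, expected: int, tol: float) -> bool:
    # A 0.1 tolerance allows roughly +/-10% drift before the check trips.
    return abs(measured - expected) <= tol * expected

table = load_expected("add_loop_eager,compile_time_instruction_count,3070000000,0.1\n")
value, tol = table[("add_loop_eager", "compile_time_instruction_count")]
print(within_budget(3_100_000_000, value, tol))  # True: within 10% of 3.07e9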
@ -48,89 +48,17 @@ PyTorch,sub,"sub_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float32",short,Fa
|
||||
PyTorch,div,"div_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float32",short,False,58.529255,0.000000
|
||||
PyTorch,mul,"mul_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float32",short,False,54.645077,0.000000
|
||||
PyTorch,add,add_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,4.397014,0.000000
|
||||
PyTorch,add,add_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,7.739000,0.000000
|
||||
PyTorch,add,add_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,7.786000,0.000000
|
||||
PyTorch,add,add_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,1.911000,0.000000
|
||||
PyTorch,add,add_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,59.243500,0.000000
|
||||
PyTorch,add,add_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,105.066000,0.000000
|
||||
PyTorch,add,add_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,106.076000,0.000000
|
||||
PyTorch,add,add_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,47.225000,0.000000
|
||||
PyTorch,add,add_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,57.947691,0.000000
|
||||
PyTorch,add,add_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,107.291000,0.000000
|
||||
PyTorch,add,add_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,107.224000,0.000000
|
||||
PyTorch,add,add_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,47.912000,0.000000
|
||||
PyTorch,sub,sub_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,1.925851,0.000000
|
||||
PyTorch,sub,sub_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,8.0240000,0.000000
|
||||
PyTorch,sub,sub_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,8.069000,0.000000
|
||||
PyTorch,sub,sub_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,1.938000,0.000000
|
||||
PyTorch,sub,sub_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,57.308320,0.000000
|
||||
PyTorch,sub,sub_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,107.091000,0.000000
|
||||
PyTorch,sub,sub_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,108.710000,0.000000
|
||||
PyTorch,sub,sub_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,47.502000,0.000000
|
||||
PyTorch,sub,sub_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,57.787743,0.000000
|
||||
PyTorch,sub,sub_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,108.863000,0.000000
|
||||
PyTorch,sub,sub_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,108.939000,0.000000
|
||||
PyTorch,sub,sub_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,47.603000,0.000000
|
||||
PyTorch,div,div_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,7.978539,0.000000
|
||||
PyTorch,div,div_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,8.741000,0.000000
|
||||
PyTorch,div,div_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,8.757000,0.000000
|
||||
PyTorch,div,div_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,8.774000,0.000000
|
||||
PyTorch,div,div_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,159.754860,0.000000
|
||||
PyTorch,div,div_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,165.552000,0.000000
|
||||
PyTorch,div,div_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,165.755000,0.000000
|
||||
PyTorch,div,div_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,165.714000,0.000000
|
||||
PyTorch,div,div_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,165.360235,0.000000
|
||||
PyTorch,div,div_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,168.376000,0.000000
|
||||
PyTorch,div,div_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,169.604000,0.000000
|
||||
PyTorch,div,div_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,168.428000,0.000000
|
||||
PyTorch,mul,mul_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,3.928136,0.000000
|
||||
PyTorch,mul,mul_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,7.402000,0.000000
|
||||
PyTorch,mul,mul_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,7.567000,0.000000
|
||||
PyTorch,mul,mul_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,4.020000,0.000000
|
||||
PyTorch,mul,mul_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,56.413499,0.000000
|
||||
PyTorch,mul,mul_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,104.638000,0.000000
|
||||
PyTorch,mul,mul_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,104.335000,0.000000
|
||||
PyTorch,mul,mul_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,48.612000,0.000000
|
||||
PyTorch,mul,mul_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,55.925090,0.000000
|
||||
PyTorch,mul,mul_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,106.110000,0.000000
|
||||
PyTorch,mul,mul_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,106.389000,0.000000
|
||||
PyTorch,mul,mul_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,48.195000,0.000000
|
||||
PyTorch,asr,asr_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,1.989000,0.000000
|
||||
PyTorch,asr,asr_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,7.999000,0.000000
|
||||
PyTorch,asr,asr_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,7.939000,0.000000
|
||||
PyTorch,asr,asr_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,1.980000,0.000000
|
||||
PyTorch,asr,asr_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,54.408000,0.000000
|
||||
PyTorch,asr,asr_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,105.647000,0.000000
|
||||
PyTorch,asr,asr_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,106.476000,0.000000
|
||||
PyTorch,asr,asr_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,48.784000,0.000000
|
||||
PyTorch,asr,asr_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,55.583000,0.000000
|
||||
PyTorch,asr,asr_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,108.083000,0.000000
|
||||
PyTorch,asr,asr_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,107.663000,0.000000
|
||||
PyTorch,asr,asr_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,48.283000,0.000000
|
||||
PyTorch,lsl,lsl_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,1.986000,0.000000
|
||||
PyTorch,lsl,lsl_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,7.676000,0.000000
|
||||
PyTorch,lsl,lsl_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,7.618000,0.000000
|
||||
PyTorch,lsl,lsl_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,1.982000,0.000000
|
||||
PyTorch,lsl,lsl_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,54.698000,0.000000
|
||||
PyTorch,lsl,lsl_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,105.899000,0.000000
|
||||
PyTorch,lsl,lsl_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,106.741000,0.000000
|
||||
PyTorch,lsl,lsl_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,51.182000,0.000000
|
||||
PyTorch,lsl,lsl_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,55.290000,0.000000
|
||||
PyTorch,lsl,lsl_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,107.744000,0.000000
|
||||
PyTorch,lsl,lsl_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,107.820000,0.000000
|
||||
PyTorch,lsl,lsl_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,51.298000,0.000000
|
||||
PyTorch,xor,xor_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,1.988000,0.000000
|
||||
PyTorch,xor,xor_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,7.689000,0.000000
|
||||
PyTorch,xor,xor_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,7.695000,0.000000
|
||||
PyTorch,xor,xor_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,1.978000,0.000000
|
||||
PyTorch,xor,xor_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,54.934000,0.000000
|
||||
PyTorch,xor,xor_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,105.217000,0.000000
|
||||
PyTorch,xor,xor_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,104.215000,0.000000
|
||||
PyTorch,xor,xor_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,47.115000,0.000000
|
||||
PyTorch,xor,xor_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,55.974000,0.000000
|
||||
PyTorch,xor,xor_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,106.828000,0.000000
|
||||
PyTorch,xor,xor_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,106.879000,0.000000
|
||||
PyTorch,xor,xor_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,48.197000,0.000000
|
||||
PyTorch,logical_and,"logical_and_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.bool",short,False,78.404254,0.000000
|
||||
PyTorch,logical_and,logical_and_M1_N1_K1_cpu_dtype_onetorch.bool_dtype_twotorch.bool,short,False,5.354032,0.000000
|
||||
PyTorch,logical_and,logical_and_M64_N64_K64_cpu_dtype_onetorch.bool_dtype_twotorch.bool,short,False,54.072783,0.000000
|
||||
@ -143,9 +71,6 @@ PyTorch,baddbmm,baddbmm_B2_M1_N8_K2_cpu_dtypetorch.float32,short,False,6.631313,
|
||||
PyTorch,baddbmm,baddbmm_B2_M1_N8_K2_cpu_dtypetorch.bfloat16,short,False,6.476986,0.000000
|
||||
PyTorch,baddbmm,baddbmm_B128_M64_N32_K64_cpu_dtypetorch.float32,short,False,266.065131,0.000000
|
||||
PyTorch,baddbmm,baddbmm_B128_M64_N32_K64_cpu_dtypetorch.bfloat16,short,False,295.503063,0.000000
|
||||
PyTorch,all,all_M1_N1_K1_cpu,short,False,5.773000,0.000000
|
||||
PyTorch,all,all_M64_N64_K64_cpu,short,False,89.427000,0.000000
|
||||
PyTorch,all,all_M64_N64_K128_cpu,short,False,120.119000,0.000000
|
||||
PyTorch,cat,"cat_sizes(1,1,1)_N2_dim0_cpu",short,False,4.301950,0.000000
|
||||
PyTorch,cat,"cat_sizes(512,512,2)_N2_dim1_cpu",short,False,99.093415,0.000000
|
||||
PyTorch,cat,"cat_sizes(128,1024,2)_N2_dim1_cpu",short,False,96.771578,0.000000
|
||||
|
|
@ -71,8 +71,8 @@ binary_short_configs = op_bench.config_list(
],
cross_product_configs={
"device": ["cpu", "cuda"],
"dtype_one": [torch.int32, torch.uint8],
"dtype_two": [torch.int32, torch.uint8],
"dtype_one": [torch.int32],
"dtype_two": [torch.int32],
},
tags=["short"],
)
@ -82,8 +82,8 @@ binary_long_configs = op_bench.cross_product_configs(
N=[32, 64],
K=[256, 512],
device=["cpu", "cuda"],
dtype_one=[torch.int8, torch.int32, torch.uint8],
dtype_two=[torch.int8, torch.int32, torch.uint8],
dtype_one=[torch.int8, torch.int32],
dtype_two=[torch.int8, torch.int32],
tags=["long"],
)

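The hunk above drops `torch.uint8` from the short config's dtype axes and narrows the long config's dtype axes. As a rough illustration of what a cross-product expansion of these axes amounts to (a sketch with itertools, not the op_bench implementation; the M values below are made up since the real M axis is not shown in this hunk):

import itertools

def cross_product_configs(**axes):
    # Yield one dict per combination of the given axes.
    keys = list(axes)
    for values in itertools.product(*(axes[k] for k in keys)):
        yield dict(zip(keys, values))

configs = list(cross_product_configs(
    M=[8, 16],                               # placeholder axis, not from the hunk
    N=[32, 64],
    K=[256, 512],
    device=["cpu", "cuda"],
    dtype_one=["torch.int8", "torch.int32"],
    dtype_two=["torch.int8", "torch.int32"],
))
print(len(configs))  # 2**6 = 64 generated "long" cases

With two dtypes per axis instead of three, each (M, N, K, device) point now expands to 4 dtype pairs rather than 9.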
File diff suppressed because it is too large
@ -207,6 +207,42 @@ templates_path = [
]
# TODO: document these and remove them from here.

# Fixes the duplicated
autosummary_filename_map = {
"torch.nn.utils.prune.identity": "torch.nn.utils.prune.identity_function",
"torch.nn.utils.prune.Identity": "torch.nn.utils.prune.Identity_class",
"torch.optim.adamw.adamw": "torch.optim.adamw.adamw_function",
"torch.optim.adamw.AdamW": "torch.optim.adamw.AdamW_class",
"torch.optim.asgd.asgd": "torch.optim.asgd.asgd_function",
"torch.optim.asgd.ASGD": "torch.optim.asgd.ASGD_class",
"torch.optim.nadam.nadam": "torch.optim.nadam.nadam_function",
"torch.optim.nadam.NAdam": "torch.optim.nadam.NAdam_class",
"torch.optim.radam.radam": "torch.optim.radam.radam_function",
"torch.optim.radam.RAdam": "torch.optim.radam.RAdam_class",
"torch.optim.rmsprop.rmsprop": "torch.optim.rmsprop.rmsprop_function",
"torch.optim.rmsprop.RMSprop": "torch.optim.rmsprop.RMSprop_class",
"torch.optim.rprop.rprop": "torch.optim.rprop.rprop_function",
"torch.optim.rprop.Rprop": "torch.optim.rprop.Rprop_class",
"torch.optim.sgd.sgd": "torch.optim.sgd.sgd_function",
"torch.optim.sgd.SGD": "torch.optim.sgd.SGD_class",
"torch.optim.adadelta.adadelta": "torch.optim.adadelta.adadelta_function",
"torch.optim.adadelta.Adadelta": "torch.optim.adadelta.Adadelta_class",
"torch.optim.adagrad.adagrad": "torch.optim.adagrad.adagrad_function",
"torch.optim.adagrad.Adagrad": "torch.optim.adagrad.Adagrad_class",
"torch.optim.adam.adam": "torch.optim.adam.adam_function",
"torch.optim.adam.Adam": "torch.optim.adam.Adam_class",
"torch.optim.adamax.adamax": "torch.optim.adamax.adamax_function",
"torch.optim.adamax.Adamax": "torch.optim.adamax.Adamax_class",
"torch.mtia.stream": "torch.mtia.stream_function",
"torch.mtia.Stream": "torch.mtia.Stream_class",
"torch.cpu.stream": "torch.cpu.stream_function",
"torch.cpu.Stream": "torch.cpu.Stream_class",
"torch.cuda.stream": "torch.cuda.stream_function",
"torch.cuda.Stream": "torch.cuda.Stream_class",
"torch.xpu.stream": "torch.xpu.stream_function",
"torch.xpu.Stream": "torch.xpu.Stream_class",
}

coverage_ignore_functions = [
# torch
"typename",
@ -3193,6 +3229,11 @@ autodoc_type_aliases = {
# Enable overriding of function signatures in the first line of the docstring.
autodoc_docstring_signature = True

# Exclude inherited IntEnum methods that have RST formatting issues in their docstrings
autodoc_default_options = {
"exclude-members": "from_bytes, to_bytes",
}

# -- katex javascript in header
#
# def setup(app):

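The `autosummary_filename_map` above gives targets that differ only by case (e.g. `adamw` vs `AdamW`) distinct page names, presumably so the generated files do not collide on case-insensitive filesystems. An illustrative way to spot such pairs (not part of conf.py):

from collections import defaultdict

targets = [
    "torch.optim.adamw.adamw",
    "torch.optim.adamw.AdamW",
    "torch.cuda.stream",
    "torch.cuda.Stream",
]

by_lower = defaultdict(list)
for name in targets:
    by_lower[name.lower()].append(name)

collisions = {k: v for k, v in by_lower.items() if len(v) > 1}
print(collisions)  # both sample pairs collide once lower-cased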
@ -253,7 +253,6 @@ regular full-precision tensor.
.. autosummary::
:toctree: generated
:nosignatures:
:template: classtemplate.rst

view
as_strided

@ -8,8 +8,7 @@ class TestAutocast(TestCase):
def test_autocast_with_unsupported_type(self):
with self.assertWarnsRegex(
UserWarning,
"In openreg autocast, but the target dtype is not supported."
"openreg Autocast only supports dtypes of torch.bfloat16, torch.float16 currently.",
"In openreg autocast, but the target dtype torch.float32 is not supported.",
):
with torch.autocast(device_type="openreg", dtype=torch.float32):
_ = torch.ones(10)

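A sketch of the behavior this test asserts, using the CPU backend so it runs anywhere: requesting an unsupported autocast dtype only emits a UserWarning and silently disables autocast, it does not raise.

import warnings
import torch

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    with torch.autocast(device_type="cpu", dtype=torch.float64):
        _ = torch.ones(10)

print(any("not supported" in str(w.message) for w in caught))  # expected: True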
@ -67,21 +67,7 @@ class TestFullyShardMemory(FSDPTest):
# allocate the cuBLAS workspaces before measuring the memory usage
# since the workspace size can differ between hardwares
lin = torch.nn.Linear(768, 768, device=device_type)
# NOTE: before https://github.com/pytorch/pytorch/pull/163955,
# the input shape was (1, 768), so that the forward gemm used
# cublaslt, and the backward used cublas.
# With the aforementioned PR, and with shape (1, 768),
# the cublas path is used both in forward and in backward,
# altering peak memory usage not accounting for cublaslt.
# Here we change the input shape to (2, 768), and that swaps
# the cublas/cublaslt selection in the forward/backward,
# but that does not affect the peak memory usage stored in `base_mem_mb`.
# Reasons for the flip:
# before PR: no Lt in addmm when mat2 has nrows/ncols <= 1,
# after PR: no Lt in addmm when either mat1 or mat2 have nrows/ncols <= 1,
# since the input preparation can swap matrices based on output
# row-/col-majorness.
inp = torch.randn(2, 768, device=device_type)
inp = torch.randn(1, 768, device=device_type)
lin(inp).sum().backward()
torch.get_device_module(device_type).empty_cache()
base_mem_mb = self._get_peak_active_memory_mb()

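A sketch of the warm-up pattern the test relies on (assuming a CUDA device; the helper name below is illustrative, not the test's `_get_peak_active_memory_mb`): run one forward/backward through a small linear layer first so the cuBLAS/cuBLASLt workspaces are already allocated, then take the baseline reading.

import torch

def warmed_up_baseline_mb() -> float:
    lin = torch.nn.Linear(768, 768, device="cuda")
    inp = torch.randn(2, 768, device="cuda")   # (2, 768), as in the test above
    lin(inp).sum().backward()                  # allocates the gemm workspaces
    torch.cuda.empty_cache()
    return torch.cuda.memory_allocated() / 2**20

if torch.cuda.is_available():
    print(f"post-warm-up baseline: {warmed_up_baseline_mb():.1f} MiB")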
@ -288,18 +288,6 @@ class AnnotateTests(torch._dynamo.test_case.TestCase):
('call_function', 'mul_2', {'pp_stage': 0, 'fdsp_bucket': 0})""", # noqa: B950
)

def test_graph_break(self):
def fn(x):
with torch.fx.traceback.annotate({"pp_stage": 0}):
x = torch.sin(x)
torch._dynamo.graph_break()
x = torch.cos(x)
return x

opt_fn = torch.compile(fn, backend="eager")
x = torch.randn(10, requires_grad=True)
self.assertEqual(fn(x), opt_fn(x))


if __name__ == "__main__":
run_tests()

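The deleted test exercised a graph break inside a `torch.fx.traceback.annotate` region, which the new registry entry GB0279 (later in this diff) reports as unsupported. A sketch of the supported usage, distilled from the test above: the context tags the FX nodes traced inside it with user metadata such as a pipeline stage.

import torch
import torch.fx.traceback

def fn(x):
    with torch.fx.traceback.annotate({"pp_stage": 0}):
        x = torch.sin(x)          # nodes traced here carry {"pp_stage": 0}
    return torch.cos(x)           # no graph break inside the annotate region

opt_fn = torch.compile(fn, backend="eager")
x = torch.randn(10, requires_grad=True)
torch.testing.assert_close(fn(x), opt_fn(x))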
@ -346,7 +346,7 @@ class TestAutocastMPS(TestCase):
def test_mps_autocast_error_message(self):
with self.assertWarnsRegex(
UserWarning,
"MPS Autocast only supports dtypes of torch.bfloat16, torch.float16 currently.",
"MPS Autocast only supports dtype of torch.bfloat16 and torch.float16 currently.",
):
with torch.autocast(device_type="mps", dtype=torch.float32):
_ = torch.ones(10)

@ -6,7 +6,6 @@ import builtins
|
||||
import collections
|
||||
import contextlib
|
||||
import copy
|
||||
import gc
|
||||
import functools
|
||||
import inspect
|
||||
import io
|
||||
@ -20,7 +19,6 @@ import traceback
|
||||
import types
|
||||
import typing
|
||||
import unittest
|
||||
import weakref
|
||||
import warnings
|
||||
from math import sqrt
|
||||
from torch.multiprocessing import Process
|
||||
@ -1626,25 +1624,6 @@ class TestFX(JitTestCase):
|
||||
|
||||
self.assertTrue(neg not in relu.users)
|
||||
|
||||
@skipIfTorchDynamo("Dynamo does not free right away")
|
||||
def test_prepend_does_not_leak(self):
|
||||
g = Graph()
|
||||
x = g.placeholder("x")
|
||||
relu = g.call_function(torch.relu, (x,))
|
||||
neg = g.call_function(torch.neg, (x,))
|
||||
|
||||
relu.prepend(neg)
|
||||
|
||||
ref = weakref.ref(neg)
|
||||
g.erase_node(neg)
|
||||
del g
|
||||
del x
|
||||
del relu
|
||||
del neg
|
||||
gc.collect()
|
||||
|
||||
self.assertIsNone(ref())
|
||||
|
||||
def test_remove_uses_with_custom_filter(self):
|
||||
g: torch.fx.Graph = Graph()
|
||||
x: torch.fx.Node = g.placeholder("x")
|
||||
|
@ -7381,10 +7381,6 @@ torch.cuda.synchronize()
|
||||
@skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
|
||||
@parametrize("use_legacy_api", [True, False])
|
||||
@skipCPUIf(True, "SPDA Math NT fallback causes failure: see issue #133644")
|
||||
@unittest.skipIf(
|
||||
"RelWithAssert" in torch.__config__.show(),
|
||||
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
|
||||
)
|
||||
def test_dummy_mha_with_nt(self, device, use_legacy_api):
|
||||
bs = 3
|
||||
d1 = 2
|
||||
|
@ -8490,14 +8490,6 @@ class TestNNDeviceType(NNTestCase):
y_cuda_contig = pool(x_cuda.contiguous())
self.assertEqual(y_cuda_ch_last, y_cuda_contig)

@onlyCUDA
def test_large_reflect_pad(self, device):
# https://github.com/pytorch/pytorch/issues/165861
x = torch.rand(2**16, 2, device="cuda")
c = F.pad(x, (1, 1), mode="reflect")
c_cpu = F.pad(x.cpu(), (1, 1), mode="reflect")
self.assertEqual(c, c_cpu)

@onlyCUDA
@largeTensorTest("48GB", "cpu")
@largeTensorTest("48GB", "cuda")

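For reference, the reflect mode exercised by the regression test above behaves like this on a small CPU tensor:

import torch
import torch.nn.functional as F

x = torch.tensor([[1., 2., 3., 4.]])
print(F.pad(x, (1, 1), mode="reflect"))
# tensor([[2., 1., 2., 3., 4., 3.]])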
@ -247,10 +247,6 @@ class SparseSemiStructuredTensorCompileTest(torch._dynamo.test_case.TestCase):
|
||||
@unittest.skipIf(IS_WINDOWS, "torch.compile not supported on windows")
|
||||
@unittest.skipIf("cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS, "cusparselt not supported on this machine")
|
||||
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
|
||||
@unittest.skipIf(
|
||||
"RelWithAssert" in torch.__config__.show(),
|
||||
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
|
||||
)
|
||||
def test_sp24_compile(self) -> None:
|
||||
x = torch.randn([1024, 512], device="cuda", dtype=torch.float16, requires_grad=True)
|
||||
|
||||
@ -580,10 +576,6 @@ class TestSparseSemiStructuredTraining(TestCase):
|
||||
|
||||
@training_dtypes
|
||||
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
|
||||
@unittest.skipIf(
|
||||
"RelWithAssert" in torch.__config__.show(),
|
||||
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
|
||||
)
|
||||
def test_prune_dense_static_sort(self, dtype) -> None:
|
||||
# Ideally we would like to clone and compare, but that won't work because the sorting order will be different
|
||||
# instead we pass the pruned matrix to the CUDA implementation and preserve the sparsity pattern.
|
||||
@ -629,10 +621,6 @@ class TestSparseSemiStructuredTraining(TestCase):
|
||||
@training_dtypes
|
||||
@parametrize_backends
|
||||
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
|
||||
@unittest.skipIf(
|
||||
"RelWithAssert" in torch.__config__.show(),
|
||||
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
|
||||
)
|
||||
def test_pruning_algo_largest_abs_values_greedy(self, dtype, backend) -> None:
|
||||
inp = torch.tensor(
|
||||
[[4, 3, 2, 1], [-1, -3, 0.6, 0.5], [1, 2, 3, 4], [10, 2, -1, 5]],
|
||||
@ -670,10 +658,6 @@ class TestSparseSemiStructuredTraining(TestCase):
|
||||
@training_dtypes
|
||||
@parametrize_backends
|
||||
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
|
||||
@unittest.skipIf(
|
||||
"RelWithAssert" in torch.__config__.show(),
|
||||
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
|
||||
)
|
||||
def test_pack_both_ways_meta_correctness(self, dtype, backend) -> None:
|
||||
M, N = 128, 256
|
||||
# Construct x to make sure we always have exactly 8 elements per 4x4 tile
|
||||
@ -708,10 +692,6 @@ class TestSparseSemiStructuredTraining(TestCase):
|
||||
|
||||
@training_dtypes
|
||||
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
|
||||
@unittest.skipIf(
|
||||
"RelWithAssert" in torch.__config__.show(),
|
||||
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
|
||||
)
|
||||
def test_pack_both_ways_id(self, dtype) -> None:
|
||||
N = 512
|
||||
torch.manual_seed(0)
|
||||
@ -749,10 +729,6 @@ class TestSparseSemiStructuredTraining(TestCase):
|
||||
|
||||
@training_dtypes
|
||||
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
|
||||
@unittest.skipIf(
|
||||
"RelWithAssert" in torch.__config__.show(),
|
||||
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
|
||||
)
|
||||
def test_pack_both_ways_edge_case1(self, dtype) -> None:
|
||||
# In this case, the heuristic will keep 7 values out of 16
|
||||
# instead of 8. let's see how the kernel handles this
|
||||
@ -778,10 +754,6 @@ class TestSparseSemiStructuredTraining(TestCase):
|
||||
|
||||
@training_dtypes
|
||||
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
|
||||
@unittest.skipIf(
|
||||
"RelWithAssert" in torch.__config__.show(),
|
||||
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
|
||||
)
|
||||
def test_sp24_apply(self, dtype) -> None:
|
||||
M, N = 256, 1024
|
||||
x = torch.randn([M, N], dtype=dtype, device="cuda")
|
||||
@ -798,10 +770,6 @@ class TestSparseSemiStructuredTraining(TestCase):
|
||||
|
||||
@training_dtypes
|
||||
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
|
||||
@unittest.skipIf(
|
||||
"RelWithAssert" in torch.__config__.show(),
|
||||
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
|
||||
)
|
||||
def test_sp24_apply_dense(self, dtype) -> None:
|
||||
M, N = 256, 1024
|
||||
x = torch.randn([M, N], dtype=dtype, device="cuda")
|
||||
@ -840,10 +808,6 @@ class TestSparseSemiStructuredTraining(TestCase):
|
||||
|
||||
@training_dtypes
|
||||
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
|
||||
@unittest.skipIf(
|
||||
"RelWithAssert" in torch.__config__.show(),
|
||||
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
|
||||
)
|
||||
def test_sp24_matmuls(self, dtype) -> None:
|
||||
M, N, K = 64, 256, 1024
|
||||
a = torch.randn([M, K], device="cuda", dtype=dtype)
|
||||
@ -879,10 +843,6 @@ class TestSparseSemiStructuredTraining(TestCase):
|
||||
)
|
||||
|
||||
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
|
||||
@unittest.skipIf(
|
||||
"RelWithAssert" in torch.__config__.show(),
|
||||
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
|
||||
)
|
||||
def test_sp24_matmuls_mat_vec(self) -> None:
|
||||
a = torch.randn([64, 128], device="cuda", dtype=torch.float16)
|
||||
b = torch.randn([128], device="cuda", dtype=torch.float16)
|
||||
@ -893,10 +853,6 @@ class TestSparseSemiStructuredTraining(TestCase):
|
||||
torch.testing.assert_close(a_s @ b, (a * a_m) @ b, **atol_rtol_kw[a.dtype])
|
||||
|
||||
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
|
||||
@unittest.skipIf(
|
||||
"RelWithAssert" in torch.__config__.show(),
|
||||
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
|
||||
)
|
||||
def test_sp24_matmuls_bmm(self) -> None:
|
||||
a = torch.randn([64, 128], device="cuda", dtype=torch.float16)
|
||||
b = torch.randn([5, 6, 128], device="cuda", dtype=torch.float16)
|
||||
|
@ -2758,12 +2758,6 @@ class _NodeBase:
return_type: Any,
) -> None: ...
def _update_args_kwargs(self, args: tuple[Any, ...], kwargs: dict[str, Any]): ...
def _prepend(self, n: FxNode) -> None: ...
def _remove_from_list(self) -> None: ...
def __lt__(self, n: Self) -> _bool: ...
def __gt__(self, n: Self) -> _bool: ...
def __le__(self, n: Self) -> _bool: ...
def __ge__(self, n: Self) -> _bool: ...

class _NodeIter(Iterator[FxNode]):
def __init__(self, root: FxNode, reversed: _bool) -> None: ...

@ -2810,15 +2810,5 @@
"Ensure {user_cls.__name__} is a type of dict, OrderedDict, or defaultdict."
]
}
],
"GB0279": [
{
"Gb_type": "torch.fx.traceback.annotate escaped from compiled region",
"Context": "str(self)",
"Explanation": "Dynamo doesn't support graph break on torch.fx.traceback.annotate.",
"Hints": [
"It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues."
]
}
]
}

@ -3502,8 +3502,10 @@ class InstructionTranslatorBase(
|
||||
if isinstance(excp, Unsupported):
|
||||
excp.remove_from_stats()
|
||||
self.push(
|
||||
VariableTracker.build(self, impl_CONTAINS_OP_fallback).call_function(
|
||||
self, [left, right], {}
|
||||
self.inline_user_function_return(
|
||||
VariableTracker.build(self, impl_CONTAINS_OP_fallback),
|
||||
[left, right],
|
||||
{},
|
||||
)
|
||||
)
|
||||
if op == 1:
|
||||
|
@ -745,9 +745,9 @@ class BuiltinVariable(VariableTracker):
|
||||
)
|
||||
|
||||
def handler(tx, a, b):
|
||||
return VariableTracker.build(
|
||||
tx, polyfill_fn_mapping[op]
|
||||
).call_function(tx, [a, b], {})
|
||||
return tx.inline_user_function_return(
|
||||
VariableTracker.build(tx, polyfill_fn_mapping[op]), [a, b], {}
|
||||
)
|
||||
|
||||
result.append(((VariableTracker, VariableTracker), handler))
|
||||
return result
|
||||
@ -1559,18 +1559,19 @@ class BuiltinVariable(VariableTracker):
|
||||
)
|
||||
else:
|
||||
# Overrides for custom str method
|
||||
# Pass method as function to call tx.inline_user_function_return
|
||||
bound_method = str_method.__func__ # type: ignore[attr-defined]
|
||||
|
||||
try:
|
||||
# Only supports certain function types
|
||||
user_func_variable = VariableTracker.build(tx, bound_method)
|
||||
user_func_variable = variables.UserFunctionVariable(bound_method)
|
||||
except AssertionError:
|
||||
# Won't be able to do inline the str method, return to avoid graph break
|
||||
log.warning("Failed to create UserFunctionVariable", exc_info=True)
|
||||
return
|
||||
|
||||
# Inline the user function
|
||||
return user_func_variable.call_function(tx, [arg], {})
|
||||
return tx.inline_user_function_return(user_func_variable, [arg], {})
|
||||
elif isinstance(arg, (variables.ExceptionVariable,)):
|
||||
if len(arg.args) == 0:
|
||||
value = f"{arg.exc_type}"
|
||||
@ -1924,8 +1925,8 @@ class BuiltinVariable(VariableTracker):
|
||||
# VT(foo.__dict__). This simplifies the construction of the new
|
||||
# dict.
|
||||
args[0] = args[0].get_forwarded_dict(tx)
|
||||
return VariableTracker.build(tx, polyfills.construct_dict).call_function(
|
||||
tx,
|
||||
return tx.inline_user_function_return(
|
||||
VariableTracker.build(tx, polyfills.construct_dict),
|
||||
[VariableTracker.build(tx, user_cls), *args],
|
||||
kwargs,
|
||||
)
|
||||
@ -2021,7 +2022,7 @@ class BuiltinVariable(VariableTracker):
|
||||
):
|
||||
iter_fn = arg.var_getattr(tx, "__iter__")
|
||||
if isinstance(iter_fn, variables.UserMethodVariable):
|
||||
out = iter_fn.call_function(tx, list(args), kwargs)
|
||||
out = tx.inline_user_function_return(iter_fn, args, kwargs)
|
||||
if isinstance(out, SetVariable):
|
||||
return out
|
||||
return BuiltinVariable(set).call_set(tx, out)
|
||||
|
@ -1295,16 +1295,6 @@ class FxTracebackAnnotateVariable(ContextWrappingVariable):
|
||||
def fn_name(self):
|
||||
return "annotate"
|
||||
|
||||
def reconstruct_type(self, codegen: "PyCodegen"):
|
||||
unimplemented_v2(
|
||||
gb_type="torch.fx.traceback.annotate escaped from compiled region",
|
||||
context=str(self),
|
||||
explanation="Dynamo doesn't support graph break on torch.fx.traceback.annotate.",
|
||||
hints=[
|
||||
*graph_break_hints.SUPPORTABLE,
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class DynamoConfigPatchVariable(ContextWrappingVariable):
|
||||
"""represents torch._dynamo.patch_dynamo_config"""
|
||||
|
@ -189,8 +189,8 @@ class ItertoolsVariable(VariableTracker):
|
||||
*args, mutation_type=ValueMutationNew()
|
||||
)
|
||||
|
||||
return VariableTracker.build(tx, polyfills.repeat).call_function(
|
||||
tx, args, kwargs
|
||||
return tx.inline_user_function_return(
|
||||
VariableTracker.build(tx, polyfills.repeat), args, kwargs
|
||||
)
|
||||
elif self.value is itertools.count:
|
||||
return variables.CountIteratorVariable(
|
||||
|
@ -181,9 +181,11 @@ class BaseListVariable(VariableTracker):
|
||||
if not len(args):
|
||||
raise_args_mismatch(tx, name)
|
||||
|
||||
return VariableTracker.build(tx, polyfills.index).call_function(
|
||||
tx, [self] + list(args), kwargs
|
||||
)
|
||||
return tx.inline_user_function_return(
|
||||
VariableTracker.build(tx, polyfills.index),
|
||||
[self] + list(args),
|
||||
kwargs,
|
||||
)
|
||||
elif name == "count":
|
||||
if len(args) != 1:
|
||||
raise_args_mismatch(tx, name)
|
||||
|
@ -542,9 +542,11 @@ class NNModuleVariable(VariableTracker):
|
||||
args = [self] + args
|
||||
else:
|
||||
assert istype(fn, types.FunctionType)
|
||||
return variables.UserFunctionVariable(
|
||||
fn, source=fn_source
|
||||
).call_function(tx, args, kwargs)
|
||||
return tx.inline_user_function_return(
|
||||
variables.UserFunctionVariable(fn, source=fn_source),
|
||||
args,
|
||||
kwargs,
|
||||
)
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
@ -771,8 +773,8 @@ class NNModuleVariable(VariableTracker):
|
||||
assert isinstance(fn, types.FunctionType)
|
||||
|
||||
src = AttrSource(AttrSource(self.source, name), "__func__")
|
||||
return variables.UserFunctionVariable(fn, source=src).call_function(
|
||||
tx,
|
||||
return tx.inline_user_function_return(
|
||||
variables.UserFunctionVariable(fn, source=src),
|
||||
[self] + list(args),
|
||||
kwargs,
|
||||
)
|
||||
@ -849,8 +851,8 @@ class NNModuleVariable(VariableTracker):
|
||||
# Inline the function
|
||||
fn = getattr(module, name).__func__
|
||||
fn_source = AttrSource(AttrSource(self.source, name), "__func__")
|
||||
return variables.UserFunctionVariable(fn, source=fn_source).call_function(
|
||||
tx,
|
||||
return tx.inline_user_function_return(
|
||||
variables.UserFunctionVariable(fn, source=fn_source),
|
||||
[self] + args,
|
||||
kwargs,
|
||||
)
|
||||
@ -949,18 +951,13 @@ class UnspecializedNNModuleVariable(UserDefinedObjectVariable):
|
||||
# The program can mutate the nn module object but the saved `value`
|
||||
# will not reflect the mutations. So, trace through the `__iter__`
|
||||
# function to reflect any tracked mutations.
|
||||
|
||||
return (
|
||||
VariableTracker.build(tx, fn)
|
||||
.call_function(
|
||||
tx,
|
||||
[
|
||||
self,
|
||||
],
|
||||
{},
|
||||
)
|
||||
.unpack_var_sequence(tx)
|
||||
)
|
||||
return tx.inline_user_function_return(
|
||||
VariableTracker.build(tx, fn),
|
||||
[
|
||||
self,
|
||||
],
|
||||
{},
|
||||
).unpack_var_sequence(tx)
|
||||
|
||||
return super().unpack_var_sequence(tx)
|
||||
|
||||
|
@ -1085,8 +1085,8 @@ class TensorVariable(VariableTracker):
|
||||
if value is not None:
|
||||
from .. import polyfills
|
||||
|
||||
return VariableTracker.build(tx, polyfills.addcmul_inplace).call_function(
|
||||
tx,
|
||||
return tx.inline_user_function_return(
|
||||
VariableTracker.build(tx, polyfills.addcmul_inplace),
|
||||
[self, tensor1, tensor2, value],
|
||||
{},
|
||||
)
|
||||
|
@ -568,16 +568,16 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
|
||||
|
||||
@register(torch.ops.inductor.accumulate_grad_.default)
|
||||
def handle_accumulate_grad_(self, tx: "InstructionTranslator", *args, **kwargs):
|
||||
return VariableTracker.build(tx, polyfills.accumulate_grad).call_function(
|
||||
tx, args, kwargs
|
||||
return tx.inline_user_function_return(
|
||||
VariableTracker.build(tx, polyfills.accumulate_grad), args, kwargs
|
||||
)
|
||||
|
||||
@register(math.radians)
|
||||
def handle_radians(self, tx: "InstructionTranslator", *args, **kwargs):
|
||||
if not check_unspec_or_constant_args(args, kwargs):
|
||||
# Use polyfill to convert math.radians(x) into math.pi * x / 180.0
|
||||
return VariableTracker.build(tx, polyfills.radians).call_function(
|
||||
tx, args, kwargs
|
||||
return tx.inline_user_function_return(
|
||||
VariableTracker.build(tx, polyfills.radians), args, kwargs
|
||||
)
|
||||
|
||||
@register(torch.is_inference_mode_enabled)
|
||||
@ -829,10 +829,8 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
|
||||
_, tx: "InstructionTranslator", *args, **kwargs
|
||||
):
|
||||
if len(args) == 3 and not isinstance(args[2], ListVariable) and not kwargs:
|
||||
return VariableTracker.build(
|
||||
tx, polyfills.foreach_lerp_inplace
|
||||
).call_function(
|
||||
tx,
|
||||
return tx.inline_user_function_return(
|
||||
VariableTracker.build(tx, polyfills.foreach_lerp_inplace),
|
||||
args,
|
||||
kwargs,
|
||||
)
|
||||
@ -842,10 +840,8 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
|
||||
# In eager it's more performant to call item() from within the C op implementation
|
||||
# in compile, it's more performant to not graph break.
|
||||
if len(args) == 2 and isinstance(args[0], TensorVariable) and not kwargs:
|
||||
return VariableTracker.build(
|
||||
tx, polyfills.foreach_pow_scalar
|
||||
).call_function(
|
||||
tx,
|
||||
return tx.inline_user_function_return(
|
||||
VariableTracker.build(tx, polyfills.foreach_pow_scalar),
|
||||
args,
|
||||
kwargs,
|
||||
)
|
||||
@ -1972,8 +1968,8 @@ class FuncTorchInterpreterVariable(BaseTorchVariable):
|
||||
if name == "key":
|
||||
return variables.EnumVariable(self.value.key())
|
||||
elif name == "process":
|
||||
return VariableTracker.build(tx, self.value.process.__func__).call_function(
|
||||
tx,
|
||||
return tx.inline_user_function_return(
|
||||
variables.UserFunctionVariable(self.value.process.__func__),
|
||||
[self] + args,
|
||||
kwargs,
|
||||
)
|
||||
|
@ -59,7 +59,7 @@ from ..utils import (
|
||||
from .base import VariableTracker
|
||||
from .constant import ConstantVariable
|
||||
from .ctx_manager import GenericContextWrappingVariable
|
||||
from .functions import UserMethodVariable
|
||||
from .functions import UserFunctionVariable, UserMethodVariable
|
||||
from .lazy import LazyVariableTracker
|
||||
from .lists import TupleVariable
|
||||
from .tensor import TensorSubclassVariable, TensorVariable
|
||||
@ -620,7 +620,7 @@ class TensorWithTFOverrideVariable(TensorVariable):
|
||||
elif isinstance(attr, property):
|
||||
getter_source = AttrSource(attr_source, "fget")
|
||||
getter = attr.fget
|
||||
getter_var = VariableTracker.build(tx, getter, source=getter_source)
|
||||
getter_var = UserFunctionVariable(getter, source=getter_source)
|
||||
return getter_var.call_function(tx, [self], {})
|
||||
|
||||
elif isinstance(attr, classmethod):
|
||||
|
@ -490,8 +490,8 @@ class UserDefinedClassVariable(UserDefinedVariable):
|
||||
|
||||
return NullContextVariable(*args, **kwargs)
|
||||
elif self.value is collections.OrderedDict:
|
||||
return VariableTracker.build(tx, polyfills.construct_dict).call_function(
|
||||
tx,
|
||||
return tx.inline_user_function_return(
|
||||
VariableTracker.build(tx, polyfills.construct_dict),
|
||||
[self, *args],
|
||||
kwargs,
|
||||
)
|
||||
@ -823,10 +823,10 @@ class UserDefinedClassVariable(UserDefinedVariable):
|
||||
return variables.MappingProxyVariable(args[0])
|
||||
elif SideEffects.cls_supports_mutation_side_effects(self.value) and self.source:
|
||||
with do_not_convert_to_tracable_parameter():
|
||||
return VariableTracker.build(
|
||||
tx, polyfills.instantiate_user_defined_class_object
|
||||
).call_function(
|
||||
tx,
|
||||
return tx.inline_user_function_return(
|
||||
VariableTracker.build(
|
||||
tx, polyfills.instantiate_user_defined_class_object
|
||||
),
|
||||
[self, *args],
|
||||
kwargs,
|
||||
)
|
||||
@ -1803,8 +1803,8 @@ class SourcelessGraphModuleVariable(UserDefinedObjectVariable):
|
||||
) -> "VariableTracker":
|
||||
fn_variable = variables.UserFunctionVariable(self.value.forward.__func__)
|
||||
args = [self] + args
|
||||
return fn_variable.call_function(
|
||||
tx,
|
||||
return tx.inline_user_function_return(
|
||||
fn_variable,
|
||||
args,
|
||||
kwargs,
|
||||
)
|
||||
|
@ -3147,82 +3147,35 @@ class AlgorithmSelectorCache(PersistentCache):
|
||||
for i, x in enumerate(input_nodes)
|
||||
}
|
||||
example_inputs = list(unique_example_inputs.values())
|
||||
example_inputs_extern = []
|
||||
for input_node in input_nodes:
|
||||
if unique_example_inputs[input_node.get_name()].is_mkldnn:
|
||||
example_inputs_extern.append(
|
||||
unique_example_inputs[input_node.get_name()]
|
||||
)
|
||||
else:
|
||||
base = unique_example_inputs[input_node.get_name()]
|
||||
base = base if base._base is None else base._base
|
||||
sizes = tuple(
|
||||
V.graph.sizevars.atomically_apply_size_hint(
|
||||
size,
|
||||
example_inputs_extern = [
|
||||
(
|
||||
unique_example_inputs[input_node.get_name()]
|
||||
if unique_example_inputs[input_node.get_name()].is_mkldnn
|
||||
else torch.as_strided(
|
||||
unique_example_inputs[input_node.get_name()],
|
||||
V.graph.sizevars.size_hints(
|
||||
input_node.get_size(),
|
||||
fallback=config.unbacked_symint_fallback,
|
||||
hint_override=hint_override,
|
||||
)
|
||||
for size in input_node.get_size()
|
||||
)
|
||||
strides = tuple(
|
||||
V.graph.sizevars.atomically_apply_size_hint(
|
||||
stride,
|
||||
),
|
||||
V.graph.sizevars.size_hints(
|
||||
input_node.get_stride(),
|
||||
fallback=config.unbacked_symint_fallback,
|
||||
hint_override=hint_override,
|
||||
)
|
||||
for stride in input_node.get_stride()
|
||||
)
|
||||
storage_offset = V.graph.sizevars.atomically_apply_size_hint(
|
||||
input_node.get_layout().offset,
|
||||
fallback=config.unbacked_symint_fallback,
|
||||
hint_override=hint_override,
|
||||
)
|
||||
|
||||
# Check if the required storage size exceeds the current storage
|
||||
# to avoid illegal memory access
|
||||
needed_size = torch._prims_common.compute_required_storage_length(
|
||||
sizes, strides, storage_offset
|
||||
)
|
||||
current_size = base.storage().size()
|
||||
|
||||
if needed_size > current_size:
|
||||
# Create a new base tensor with sufficient storage
|
||||
new_base = torch.randn(
|
||||
needed_size,
|
||||
dtype=base.dtype,
|
||||
device=base.device,
|
||||
requires_grad=base.requires_grad,
|
||||
)
|
||||
base = new_base.as_strided(
|
||||
base.size(), base.stride(), base.storage_offset()
|
||||
)
|
||||
|
||||
example_inputs_extern.append(
|
||||
torch.as_strided(base, sizes, strides, storage_offset)
|
||||
),
|
||||
V.graph.sizevars.size_hint(
|
||||
input_node.get_layout().offset,
|
||||
fallback=config.unbacked_symint_fallback,
|
||||
hint_override=hint_override,
|
||||
),
|
||||
)
|
||||
)
|
||||
for input_node in input_nodes
|
||||
]
|
||||
out = cls.benchmark_example_value(layout, hint_override=hint_override)
|
||||
|
||||
# Also check the output tensor for storage size
|
||||
out_base = out if out._base is None else out._base
|
||||
out_offset = V.graph.sizevars.size_hint(layout.offset)
|
||||
needed_out_size = torch._prims_common.compute_required_storage_length(
|
||||
out.size(), out.stride(), out_offset
|
||||
out_extern = torch.as_strided(
|
||||
out, out.size(), out.stride(), V.graph.sizevars.size_hint(layout.offset)
|
||||
)
|
||||
current_out_size = out_base.storage().size()
|
||||
|
||||
if needed_out_size > current_out_size:
|
||||
# Create a new base tensor with sufficient storage
|
||||
new_out_base = torch.randn(
|
||||
needed_out_size,
|
||||
dtype=out_base.dtype,
|
||||
device=out_base.device,
|
||||
requires_grad=out_base.requires_grad,
|
||||
)
|
||||
out_base = new_out_base.as_strided(
|
||||
out_base.size(), out_base.stride(), out_base.storage_offset()
|
||||
)
|
||||
|
||||
out_extern = torch.as_strided(out_base, out.size(), out.stride(), out_offset)
|
||||
expected = None
|
||||
if VERIFY:
|
||||
choices[0].benchmark(*example_inputs_extern, out=out_extern)
|
||||
@ -3663,13 +3616,10 @@ class AlgorithmSelectorCache(PersistentCache):
|
||||
# So we need call as_strided in the end to 'view' the tensor with the correct
|
||||
# sizes/strides
|
||||
return AlgorithmSelectorCache.generate_example_value(
|
||||
tuple(
|
||||
V.graph.sizevars.atomically_apply_size_hint(
|
||||
size,
|
||||
fallback=config.unbacked_symint_fallback,
|
||||
hint_override=hint_override,
|
||||
)
|
||||
for size in node.get_size()
|
||||
V.graph.sizevars.size_hints(
|
||||
node.get_size(),
|
||||
fallback=config.unbacked_symint_fallback,
|
||||
hint_override=hint_override,
|
||||
),
|
||||
tuple(
|
||||
V.graph.sizevars.atomically_apply_size_hint(
|
||||
@ -3682,20 +3632,13 @@ class AlgorithmSelectorCache(PersistentCache):
|
||||
node.get_device(),
|
||||
node.get_dtype(),
|
||||
# pyrefly: ignore # missing-attribute
|
||||
V.graph.sizevars.atomically_apply_size_hint(
|
||||
node.layout.offset,
|
||||
node.layout.offset,
|
||||
V.graph.sizevars.size_hints(
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
V.graph.get_allocation_size(node),
|
||||
fallback=config.unbacked_symint_fallback,
|
||||
hint_override=hint_override,
|
||||
),
|
||||
tuple(
|
||||
V.graph.sizevars.atomically_apply_size_hint(
|
||||
size,
|
||||
fallback=config.unbacked_symint_fallback,
|
||||
hint_override=hint_override,
|
||||
)
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
for size in V.graph.get_allocation_size(node)
|
||||
),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
@ -230,9 +230,9 @@ class autocast:
|
||||
raise ValueError(
|
||||
f"Expected `device_type` of type `str`, got: `{type(device_type)}`"
|
||||
)
|
||||
self.fast_dtype = (
|
||||
torch.get_autocast_dtype(device_type) if dtype is None else dtype
|
||||
)
|
||||
if dtype is None:
|
||||
dtype = torch.get_autocast_dtype(device_type)
|
||||
self.fast_dtype = dtype
|
||||
if torch._jit_internal.is_scripting():
|
||||
self._enabled = enabled
|
||||
self.device = device_type
|
||||
@ -243,9 +243,6 @@ class autocast:
|
||||
raise RuntimeError(
|
||||
f"User specified an unsupported autocast device_type '{self.device}'"
|
||||
)
|
||||
|
||||
device_supported_dtypes = [torch.bfloat16, torch.float16]
|
||||
|
||||
self.custom_backend_name = torch._C._get_privateuse1_backend_name()
|
||||
if self.device == self.custom_backend_name:
|
||||
necessary_funcs = [
|
||||
@ -262,55 +259,110 @@ class autocast:
|
||||
assert hasattr(self.custom_device_mod, func), (
|
||||
message + f"But the func `{func}` is missing. \n"
|
||||
)
|
||||
device_supported_dtypes = self.custom_device_mod.get_amp_supported_dtype()
|
||||
|
||||
self._cache_enabled = (
|
||||
torch.is_autocast_cache_enabled()
|
||||
if cache_enabled is None
|
||||
else cache_enabled
|
||||
)
|
||||
self._cache_enabled = torch.is_autocast_cache_enabled()
|
||||
if (
|
||||
enabled
|
||||
and self.device == "cuda"
|
||||
and torch.cuda.amp.common.amp_definitely_not_available()
|
||||
):
|
||||
warnings.warn(
|
||||
"User provided device_type of 'cuda', but CUDA is not available. Disabling"
|
||||
)
|
||||
enabled = False
|
||||
if cache_enabled is not None:
|
||||
self._cache_enabled = cache_enabled
|
||||
|
||||
device_name = (
|
||||
self.device
|
||||
if self.device == self.custom_backend_name
|
||||
else self.device.upper()
|
||||
)
|
||||
if enabled:
|
||||
# Special case for CUDA AMP and bfloat16 support
|
||||
if self.device == "cuda":
|
||||
if torch.cuda.amp.common.amp_definitely_not_available():
|
||||
warnings.warn(
|
||||
"CUDA is not available or torch_xla is imported. Disabling autocast."
|
||||
)
|
||||
enabled = False
|
||||
elif (
|
||||
self.fast_dtype == torch.bfloat16
|
||||
and not torch.cuda.is_bf16_supported()
|
||||
):
|
||||
raise RuntimeError(
|
||||
"Current CUDA Device does not support bfloat16. Please switch dtype to float16."
|
||||
)
|
||||
elif self.fast_dtype not in device_supported_dtypes:
|
||||
error_message = (
|
||||
f"In {device_name} autocast, but the target dtype is not supported. Disabling autocast.\n"
|
||||
f"{device_name} Autocast only supports dtypes of "
|
||||
+ ", ".join(map(str, device_supported_dtypes))
|
||||
+ " currently."
|
||||
if self.device == "cpu":
|
||||
supported_dtype = [torch.bfloat16, torch.float16]
|
||||
if self.fast_dtype not in supported_dtype and enabled:
|
||||
error_message = "In CPU autocast, but the target dtype is not supported. Disabling autocast.\n"
|
||||
error_message += "CPU Autocast only supports dtype of "
|
||||
error_message += (
|
||||
", ".join(str(dtype) for dtype in supported_dtype) + " currently."
|
||||
)
|
||||
warnings.warn(error_message)
|
||||
enabled = False
|
||||
# Special case for MPS bfloat16 support on macOS < 14
|
||||
if (
|
||||
self.device == "mps"
|
||||
and self.fast_dtype == torch.bfloat16
|
||||
and not torch.backends.mps.is_macos_or_newer(14, 0)
|
||||
):
|
||||
elif self.device == "mtia":
|
||||
supported_dtype = [torch.bfloat16, torch.float16]
|
||||
if self.fast_dtype not in supported_dtype:
|
||||
error_message = "In MTIA autocast, but the target dtype is not supported. Disabling autocast.\n"
|
||||
error_message += "MTIA Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
|
||||
warnings.warn(error_message)
|
||||
enabled = False
|
||||
elif self.device == "maia":
|
||||
supported_dtype = [torch.bfloat16, torch.float16]
|
||||
if self.fast_dtype not in supported_dtype:
|
||||
error_message = "In MAIA autocast, but the target dtype is not supported. Disabling autocast.\n"
|
||||
error_message += "MAIA Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
|
||||
warnings.warn(error_message)
|
||||
enabled = False
|
||||
elif self.device == "xpu":
|
||||
supported_dtype = [torch.bfloat16, torch.float16]
|
||||
if self.fast_dtype not in supported_dtype:
|
||||
error_message = "In XPU autocast, but the target dtype is not supported. Disabling autocast.\n"
|
||||
error_message += "XPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
|
||||
warnings.warn(error_message)
|
||||
enabled = False
|
||||
elif self.device == "ipu":
|
||||
supported_dtypes = [torch.bfloat16, torch.float16]
|
||||
if self.fast_dtype not in supported_dtypes:
|
||||
error_message = "In IPU autocast, but the target dtype is not supported. Disabling autocast.\n"
|
||||
error_message += "IPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
|
||||
warnings.warn(error_message)
|
||||
enabled = False
|
||||
elif self.device == "hpu":
|
||||
supported_dtype = [torch.bfloat16, torch.float16]
|
||||
if self.fast_dtype not in supported_dtype:
|
||||
error_message = "In HPU autocast, but the target dtype is not supported. Disabling autocast.\n"
|
||||
error_message += "HPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
|
||||
warnings.warn(error_message)
|
||||
enabled = False
|
||||
elif self.device == self.custom_backend_name:
|
||||
supported_dtype = self.custom_device_mod.get_amp_supported_dtype()
|
||||
if self.fast_dtype not in supported_dtype:
|
||||
error_message = f"In {self.custom_backend_name} autocast, but the target dtype {self.fast_dtype} is not supported. "
|
||||
error_message += f"Disabling autocast.\n {self.custom_backend_name} Autocast only supports dtypes of "
|
||||
error_message += (
|
||||
", ".join(str(dtype) for dtype in supported_dtype) + " currently."
|
||||
)
|
||||
warnings.warn(error_message)
|
||||
enabled = False
|
||||
elif self.device == "cuda":
|
||||
if (
|
||||
enabled
|
||||
and self.fast_dtype == torch.bfloat16
|
||||
and not torch.cuda.is_bf16_supported()
|
||||
):
|
||||
raise RuntimeError(
|
||||
"Current CUDA Device does not support bfloat16. Please switch dtype to float16."
|
||||
)
|
||||
elif self.device == "mps":
|
||||
supported_dtype = [torch.bfloat16, torch.float16]
|
||||
if self.fast_dtype not in supported_dtype:
|
||||
error_message = (
|
||||
"In MPS autocast, but the target dtype is not supported. Disabling autocast.\n"
|
||||
"MPS Autocast only supports dtype of torch.bfloat16 and torch.float16 currently."
|
||||
)
|
||||
warnings.warn(error_message)
|
||||
enabled = False
|
||||
elif self.fast_dtype == torch.bfloat16:
|
||||
if not torch.backends.mps.is_macos_or_newer(14, 0):
|
||||
error_message = (
|
||||
"In MPS autocast, but the target dtype torch.bfloat16 is not supported "
|
||||
"on macOS versions below 14. Disabling autocast."
|
||||
)
|
||||
warnings.warn(error_message)
|
||||
enabled = False
|
||||
elif self.device == "xla":
|
||||
supported_dtype = [torch.float16, torch.bfloat16]
|
||||
if self.fast_dtype not in supported_dtype:
|
||||
error_message = "In XLA autocast, but the target dtype is not supported. Disabling autocast.\n"
|
||||
error_message += (
|
||||
"XLA Autocast only supports dtype of torch.bfloat16 currently."
|
||||
)
|
||||
warnings.warn(error_message)
|
||||
enabled = False
|
||||
self._enabled = enabled
|
||||
|
||||
def __enter__(self):
|
||||
|
@ -235,9 +235,10 @@ class _DerivedObserverOrFakeQuantize(ObserverBase):
from .utils import is_per_channel

if is_per_channel(self.qscheme):
assert self.ch_axis is not None, (
"Must provide a valid ch_axis if qscheme is per channel"
)
if self.ch_axis is None:
raise AssertionError(
"Must provide a valid ch_axis if qscheme is per channel"
)

def forward(self, x: Tensor) -> Tensor:
return x

@ -92,9 +92,10 @@ def channel_range(input, axis=0):
mins = min_over_ndim(input, axis_list)
maxs = max_over_ndim(input, axis_list)

assert mins.size(0) == input.size(axis), (
"Dimensions of resultant channel range does not match size of requested axis"
)
if mins.size(0) != input.size(axis):
raise AssertionError(
"Dimensions of resultant channel range does not match size of requested axis"
)
return maxs - mins

@ -45,7 +45,8 @@ class _LearnableFakeQuantize(torch.ao.quantization.FakeQuantizeBase):
**observer_kwargs,
):
super().__init__()
assert quant_min < quant_max, "quant_min must be strictly less than quant_max."
if quant_min >= quant_max:
raise AssertionError("quant_min must be strictly less than quant_max.")
self.quant_min = quant_min
self.quant_max = quant_max
# also pass quant_min and quant_max to observer

@ -56,19 +57,16 @@ class _LearnableFakeQuantize(torch.ao.quantization.FakeQuantizeBase):
self.scale = Parameter(torch.tensor([scale]))
self.zero_point = Parameter(torch.tensor([zero_point]))
else:
assert isinstance(channel_len, int) and channel_len > 0, (
"Channel size must be a positive integer."
)
if not (isinstance(channel_len, int) and channel_len > 0):
raise AssertionError("Channel size must be a positive integer.")
self.scale = Parameter(torch.tensor([scale] * channel_len))
self.zero_point = Parameter(torch.tensor([zero_point] * channel_len))

self.activation_post_process = observer(**observer_kwargs)
assert torch.iinfo(self.activation_post_process.dtype).min <= quant_min, (
"quant_min out of bound"
)
assert quant_max <= torch.iinfo(self.activation_post_process.dtype).max, (
"quant_max out of bound"
)
if torch.iinfo(self.activation_post_process.dtype).min > quant_min:
raise AssertionError("quant_min out of bound")
if quant_max > torch.iinfo(self.activation_post_process.dtype).max:
raise AssertionError("quant_max out of bound")
self.dtype = self.activation_post_process.dtype
self.qscheme = self.activation_post_process.qscheme
self.ch_axis = (

@ -88,9 +88,10 @@ def _fuse_linear_bn_leaky_relu(is_qat, linear, bn, leaky_relu):
>>> lr = nn.LeakyReLU(0.01)
>>> m2 = _fuse_linear_bn_leaky_relu(m1, b1, lr)
"""
assert linear.training == bn.training and bn.training == leaky_relu.training, (
"Linear, BN and LeakyReLU all must be in the same mode (train or eval)."
)
if linear.training != bn.training or bn.training != leaky_relu.training:
raise AssertionError(
"Linear, BN and LeakyReLU all must be in the same mode (train or eval)."
)

if is_qat:
raise NotImplementedError(

@ -164,10 +164,11 @@ def remove_boolean_dispatch_from_name(p) -> Any:
return "torch.nn.functional.adaptive_max_pool2d"
elif p is F.adaptive_max_pool3d:
return "torch.nn.functional.adaptive_max_pool3d"
assert "boolean_dispatch" not in str(p), (
f"{p} does not have a human readable representation in "
+ "quantization documentation"
)
if "boolean_dispatch" in str(p):
raise AssertionError(
f"{p} does not have a human readable representation in "
+ "quantization documentation"
)
return p

@ -300,7 +301,8 @@ def _get_fuser_method_in_reversed_nested_tuple_format(
The first argument of a fuser method is always `is_qat` and is not affected
in the conversion. We currently only support functions with 3 or 4 arguments.
"""
assert config.fuser_method is not None
if config.fuser_method is None:
raise AssertionError("config.fuser_method must be provided")
if config._pattern_complex_format is not None:
return config.fuser_method
if not isinstance(config.pattern, tuple):

@ -175,9 +175,10 @@ class FakeQuantize(FakeQuantizeBase):
super().__init__()
# Populate quant_min/quant_max to observer_kwargs if valid
if quant_min is not None and quant_max is not None:
assert quant_min <= quant_max, (
"quant_min must be less than or equal to quant_max"
)
if quant_min > quant_max:
raise AssertionError(
"quant_min must be less than or equal to quant_max"
)
dtype = observer_kwargs.get("dtype", torch.quint8)
if hasattr(observer, "p"):
# In case observer is _PartialWrapper, dtype can be stored in

@ -186,9 +187,11 @@ class FakeQuantize(FakeQuantizeBase):
"dtype", dtype
)
# pyrefly: ignore # bad-argument-type
assert torch.iinfo(dtype).min <= quant_min, "quant_min out of bound"
if torch.iinfo(dtype).min > quant_min:
raise AssertionError("quant_min out of bound")
# pyrefly: ignore # bad-argument-type
assert quant_max <= torch.iinfo(dtype).max, "quant_max out of bound"
if quant_max > torch.iinfo(dtype).max:
raise AssertionError("quant_max out of bound")
observer_kwargs.update({"quant_min": quant_min, "quant_max": quant_max})
observer_kwargs["is_dynamic"] = is_dynamic
self.activation_post_process = observer(**observer_kwargs)

@ -210,11 +213,12 @@ class FakeQuantize(FakeQuantizeBase):
if hasattr(self.activation_post_process, "ch_axis")
else -1
)
assert _is_per_channel(self.qscheme) or _is_per_tensor(self.qscheme), (
"Only per channel and per tensor quantization are supported in fake quantize"
+ " got qscheme: "
+ str(self.qscheme)
)
if not (_is_per_channel(self.qscheme) or _is_per_tensor(self.qscheme)):
raise AssertionError(
"Only per channel and per tensor quantization are supported in fake quantize"
+ " got qscheme: "
+ str(self.qscheme)
)
self.is_per_channel = _is_per_channel(self.qscheme)

@torch.jit.export

@ -295,7 +299,10 @@ class FakeQuantize(FakeQuantizeBase):
if name == "scale":
self.scale.resize_(val.shape)
else:
assert name == "zero_point"
if name != "zero_point":
raise AssertionError(
"Expected 'zero_point' but got different state key"
)
self.zero_point.resize_(val.shape)
# For torchscript module we need to update the attributes here since we do not
# call the `_load_from_state_dict` function defined module.py

@ -303,7 +310,10 @@ class FakeQuantize(FakeQuantizeBase):
if name == "scale":
self.scale.copy_(val)
else:
assert name == "zero_point"
if name != "zero_point":
raise AssertionError(
"Expected 'zero_point' but got different state key"
)
self.zero_point.copy_(val)
elif strict:
missing_keys.append(key)

@ -329,17 +339,19 @@ class FixedQParamsFakeQuantize(FakeQuantize):
# TODO: rename observer to observer_ctr
def __init__(self, observer):
super().__init__(observer=observer)
assert type(self.activation_post_process) is FixedQParamsObserver, (
f"{self.__class__.__name__}'s observer must be a {FixedQParamsObserver.__name__}"
)
if type(self.activation_post_process) is not FixedQParamsObserver:
raise AssertionError(
f"{self.__class__.__name__}'s observer must be a {FixedQParamsObserver.__name__}"
)
self._observer_ctr = observer
self.scale = self.activation_post_process.scale
self.zero_point = self.activation_post_process.zero_point
assert _is_per_tensor(self.qscheme), (
"Only per tensor quantization is supported"
+ " FixedQParamsFakeQuantize module, got qscheme:"
+ str(self.qscheme)
)
if not _is_per_tensor(self.qscheme):
raise AssertionError(
"Only per tensor quantization is supported"
+ " FixedQParamsFakeQuantize module, got qscheme:"
+ str(self.qscheme)
)

@torch.jit.export
def calculate_qparams(self): # type: ignore[override]

@ -382,12 +394,13 @@ class FusedMovingAvgObsFakeQuantize(FakeQuantize):
**observer_kwargs: Any,
) -> None:
super().__init__(observer, quant_min, quant_max, **observer_kwargs)
assert isinstance(
if not isinstance(
self.activation_post_process,
(MovingAverageMinMaxObserver, MovingAveragePerChannelMinMaxObserver),
), (
"Fused observer+fake_quant module only works with MovingAverageMinMaxObserver"
)
):
raise AssertionError(
"Fused observer+fake_quant module only works with MovingAverageMinMaxObserver"
)
self.register_buffer("fake_quant_enabled", torch.tensor([1], dtype=torch.long))
self.register_buffer("observer_enabled", torch.tensor([1], dtype=torch.long))
self.is_symmetric_quant = _is_symmetric_quant(

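Several of the FakeQuantize hunks above validate a user-supplied `quant_min`/`quant_max` pair against the representable range of the observer dtype via `torch.iinfo`, now raising `AssertionError` explicitly instead of asserting. A self-contained sketch of that check; the `validate_qrange` name is invented here for illustration and is not part of the patch:

import torch


def validate_qrange(quant_min: int, quant_max: int, dtype: torch.dtype = torch.quint8) -> None:
    """Raise AssertionError when [quant_min, quant_max] does not fit in dtype."""
    if quant_min > quant_max:
        raise AssertionError("quant_min must be less than or equal to quant_max")
    info = torch.iinfo(dtype)
    if info.min > quant_min:
        raise AssertionError("quant_min out of bound")
    if quant_max > info.max:
        raise AssertionError("quant_max out of bound")


validate_qrange(0, 255, torch.quint8)      # passes
# validate_qrange(-1, 255, torch.quint8)   # would raise: quant_min out of bound
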
@ -35,9 +35,10 @@ def fuse_conv_bn(is_qat, conv, bn):
>>> # xdoctest: +SKIP
>>> m2 = fuse_conv_bn(m1, b1)
"""
assert conv.training == bn.training, (
"Conv and BN both must be in the same mode (train or eval)."
)
if conv.training != bn.training:
raise AssertionError(
"Conv and BN both must be in the same mode (train or eval)."
)

fused_module_class_map = {
nn.Conv1d: nni.ConvBn1d,

@ -46,13 +47,18 @@ def fuse_conv_bn(is_qat, conv, bn):
}

if is_qat:
assert bn.num_features == conv.out_channels, (
"Output channel of Conv2d must match num_features of BatchNorm2d"
)
assert bn.affine, "Only support fusing BatchNorm2d with affine set to True"
assert bn.track_running_stats, (
"Only support fusing BatchNorm2d with tracking_running_stats set to True"
)
if bn.num_features != conv.out_channels:
raise AssertionError(
"Output channel of Conv2d must match num_features of BatchNorm2d."
)
if not bn.affine:
raise AssertionError(
"Only support fusing BatchNorm2d with affine set to True"
)
if not bn.track_running_stats:
raise AssertionError(
"Only support fusing BatchNorm2d with tracking_running_stats set to True"
)
fused_module_class = fused_module_class_map.get((type(conv)), None)
if fused_module_class is not None:
return fused_module_class(conv, bn)

@ -81,9 +87,10 @@ def fuse_conv_bn_relu(is_qat, conv, bn, relu):
>>> # xdoctest: +SKIP
>>> m2 = fuse_conv_bn_relu(m1, b1, r1)
"""
assert conv.training == bn.training == relu.training, (
"Conv and BN both must be in the same mode (train or eval)."
)
if not (conv.training == bn.training == relu.training):
raise AssertionError(
"Conv and BN both must be in the same mode (train or eval)."
)
fused_module: Optional[type[nn.Sequential]] = None
if is_qat:
map_to_fused_module_train = {

@ -91,13 +98,18 @@ def fuse_conv_bn_relu(is_qat, conv, bn, relu):
nn.Conv2d: nni.ConvBnReLU2d,
nn.Conv3d: nni.ConvBnReLU3d,
}
assert bn.num_features == conv.out_channels, (
"Output channel of Conv must match num_features of BatchNorm"
)
assert bn.affine, "Only support fusing BatchNorm with affine set to True"
assert bn.track_running_stats, (
"Only support fusing BatchNorm with tracking_running_stats set to True"
)
if bn.num_features != conv.out_channels:
raise AssertionError(
"Output channel of Conv2d must match num_features of BatchNorm2d"
)
if not bn.affine:
raise AssertionError(
"Only support fusing BatchNorm2d with affine set to True"
)
if not bn.track_running_stats:
raise AssertionError(
"Only support fusing BatchNorm2d with tracking_running_stats set to True"
)
fused_module = map_to_fused_module_train.get(type(conv), None)
if fused_module is not None:
return fused_module(conv, bn, relu)

@ -134,18 +146,24 @@ def fuse_linear_bn(is_qat, linear, bn):
>>> # xdoctest: +SKIP
>>> m2 = fuse_linear_bn(m1, b1)
"""
assert linear.training == bn.training, (
"Linear and BN both must be in the same mode (train or eval)."
)
if linear.training != bn.training:
raise AssertionError(
"Linear and BN both must be in the same mode (train or eval)."
)

if is_qat:
assert bn.num_features == linear.out_features, (
"Output features of Linear must match num_features of BatchNorm1d"
)
assert bn.affine, "Only support fusing BatchNorm1d with affine set to True"
assert bn.track_running_stats, (
"Only support fusing BatchNorm1d with tracking_running_stats set to True"
)
if bn.num_features != linear.out_features:
raise AssertionError(
"Output features of Linear must match num_features of BatchNorm1d"
)
if not bn.affine:
raise AssertionError(
"Only support fusing BatchNorm1d with affine set to True"
)
if not bn.track_running_stats:
raise AssertionError(
"Only support fusing BatchNorm1d with tracking_running_stats set to True"
)
return nni.LinearBn1d(linear, bn)
else:
return nn.utils.fusion.fuse_linear_bn_eval(linear, bn)

@ -167,9 +185,10 @@ def fuse_convtranspose_bn(is_qat, convt, bn):
>>> # xdoctest: +SKIP
>>> m2 = fuse_convtranspose_bn(m1, b1)
"""
assert convt.training == bn.training, (
"ConvTranspose and BN both must be in the same mode (train or eval)."
)
if convt.training != bn.training:
raise AssertionError(
"ConvTranspose and BN both must be in the same mode (train or eval)."
)

if is_qat:
raise Exception( # noqa: TRY002

@ -224,7 +243,8 @@ def get_fuser_method(op_list, additional_fuser_method_mapping=None):
_DEFAULT_OP_LIST_TO_FUSER_METHOD, additional_fuser_method_mapping
)
fuser_method = all_mappings.get(op_list, None)
assert fuser_method is not None, f"did not find fuser method for: {op_list} "
if fuser_method is None:
raise AssertionError(f"did not find fuser method for: {op_list} ")
return fuser_method

@ -289,5 +309,6 @@ def get_fuser_method_new(
fuser_method = fuser_method_mapping.get(op_pattern)
if fuser_method is not None:
break
assert fuser_method is not None, f"did not find fuser method for: {op_pattern} "
if fuser_method is None:
raise AssertionError(f"did not find fuser method for: {op_pattern} ")
return fuser_method

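The change repeated throughout these files is mechanical: `assert cond, msg` becomes `if not cond: raise AssertionError(msg)`. The explicit raise keeps the same exception type for existing callers but, unlike a bare assert, is not stripped when Python runs with `-O`. A hedged before/after sketch of the pattern on a made-up lookup helper (the function names below are illustrative, not from the patch):

# Before: the whole check disappears under `python -O` / PYTHONOPTIMIZE.
def get_fuser_old(mapping: dict, key):
    fuser_method = mapping.get(key)
    assert fuser_method is not None, f"did not find fuser method for: {key} "
    return fuser_method


# After: the check always runs, and callers still see an AssertionError.
def get_fuser_new(mapping: dict, key):
    fuser_method = mapping.get(key)
    if fuser_method is None:
        raise AssertionError(f"did not find fuser method for: {key} ")
    return fuser_method
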
@ -249,17 +249,17 @@ class UniformQuantizationObserverBase(ObserverBase):
)
self.reduce_range = reduce_range
self.register_buffer("eps", torch.tensor([eps], **factory_kwargs))
assert self.qscheme in (
if self.qscheme not in (
torch.per_tensor_affine,
torch.per_tensor_symmetric,
torch.per_channel_affine,
torch.per_channel_symmetric,
torch.per_channel_affine_float_qparams,
), (
"Default Observer only works for per_tensor_affine, \
per_tensor_symmetric, per_channel_affine, \
per_channel_symmetric and per_channel_float_qparams quantization scheme"
)
):
raise AssertionError(
"Default Observer only works for per_tensor_affine, per_tensor_symmetric, "
"per_channel_affine, per_channel_symmetric and per_channel_float_qparams quantization scheme"
)

_ALLOWED_DTYPES = (
torch.qint8,

@ -275,9 +275,10 @@ class UniformQuantizationObserverBase(ObserverBase):
torch.uint16,
)

assert self.dtype in _ALLOWED_DTYPES, (
f"Default Observer only works for {_ALLOWED_DTYPES} data type"
)
if self.dtype not in _ALLOWED_DTYPES:
raise AssertionError(
f"Default Observer only works for {_ALLOWED_DTYPES} data type"
)
self.has_customized_qrange = (quant_min is not None) and (quant_max is not None)
if self.has_customized_qrange:
# pyrefly: ignore # bad-argument-type

@ -336,12 +337,12 @@ class UniformQuantizationObserverBase(ObserverBase):
"""
# The variable names are prefixed with "initial" because their values (qmin and qmax) might be adjusted
# based on whether quantization range is reduced and the datatype (signed/unsigned) used by the observer.
assert quant_min <= 0 <= quant_max, (
"Used-specified quantization range must include 0."
)
assert quant_min < quant_max, (
"qmin must be strictly less than qmax for user-specified quantization range."
)
if not quant_min <= 0 <= quant_max:
raise AssertionError("Used-specified quantization range must include 0.")
if quant_min >= quant_max:
raise AssertionError(
"qmin must be strictly less than qmax for user-specified quantization range."
)

@torch.jit.export
def _calculate_qparams(

@ -1131,7 +1132,8 @@ class HistogramObserver(UniformQuantizationObserverBase):
This follows the implementation of NormMinimization::NonlinearQuantizationParamsSearch in
caffe2/quantization/server/norm_minimization.cc
"""
assert self.histogram.size()[0] == self.bins, "bins mismatch"
if self.histogram.size()[0] != self.bins:
raise AssertionError("bins mismatch")
bin_width = (self.max_val - self.min_val) / self.bins

# cumulative sum

@ -1252,8 +1254,10 @@ class HistogramObserver(UniformQuantizationObserverBase):
return transformed_orig_hist + update_hist

# We assume the update_hist is already in the target range, we will map the orig_max to it
assert update_min <= orig_min
assert update_max >= orig_max
if update_min > orig_min:
raise AssertionError("update_min must be <= orig_min")
if update_max < orig_max:
raise AssertionError("update_max must be >= orig_max")

# Now we need to turn the old_histogram, into the range of the new histogram
transformed_orig_hist = self._upscale_histogram(

@ -1273,9 +1277,8 @@ class HistogramObserver(UniformQuantizationObserverBase):
self.min_val.copy_(min_val)
self.max_val.resize_(max_val.shape)
self.max_val.copy_(max_val)
assert min_val.numel() == 1 and max_val.numel() == 1, (
"histogram min/max values must be scalar."
)
if min_val.numel() != 1 or max_val.numel() != 1:
raise AssertionError("histogram min/max values must be scalar.")
new_histogram = torch.histc(x, self.bins, min=min_val, max=max_val) # type: ignore[arg-type]
self.histogram.detach_().resize_(new_histogram.shape)
self.histogram.copy_(new_histogram)

@ -1350,10 +1353,11 @@ class HistogramObserver(UniformQuantizationObserverBase):
return torch.tensor([1.0], device=self.min_val.device.type), torch.tensor(
[0], device=self.min_val.device.type
)
assert self.bins == len(self.histogram), (
"The number of bins in histogram should be equal to the number of bins "
"supplied while making this observer"
)
if self.bins != len(self.histogram):
raise AssertionError(
"The number of bins in histogram should be equal to the number of bins "
"supplied while making this observer"
)

new_min, new_max = self._non_linear_param_search()

@ -1785,9 +1789,10 @@ def get_block_size(
input_shape: The input tensor shape possibly more than 2 dimensions
granularity: The granularity type of the quantization
"""
assert isinstance(granularity, Granularity), (
"Please provide an instance of Granularity, not subclass of it"
)
if not isinstance(granularity, Granularity):
raise AssertionError(
"Please provide an instance of Granularity, not subclass of it"
)
if isinstance(granularity, PerTensor):
return input_shape
elif isinstance(granularity, PerAxis):

@ -1797,9 +1802,10 @@ def get_block_size(
elif isinstance(granularity, PerRow):
return (1,) * (len(input_shape) - 1) + (input_shape[-1],)
elif isinstance(granularity, PerGroup):
assert len(input_shape) == 2, (
f"Expecting input shape dim to be 2 for per group quantization, gotinput shape: {input_shape}"
)
if len(input_shape) != 2:
raise AssertionError(
f"Expecting input shape dim to be 2 for per group quantization, gotinput shape: {input_shape}"
)
return (1, granularity.group_size)
elif isinstance(granularity, PerToken):
block_size = [1] * len(input_shape)

@ -1836,8 +1842,8 @@ class AffineQuantizedObserverBase(ABC, torch.nn.Module):
**kwargs,
):
super().__init__()
assert granularity is not None, "granularity is None"

if granularity is None:
raise AssertionError("granularity is None")
self.mapping_type = mapping_type
self.target_dtype = target_dtype
self.granularity = granularity

@ -1875,10 +1881,10 @@ class AffineQuantizedObserverBase(ABC, torch.nn.Module):
from torch.ao.quantization.fx.utils import create_getattr_from_value

with model.graph.inserting_before(observer_node):
assert self.block_size is not None, "Expecting block_size to be populated"
assert self.original_dtype is not None, (
"Expecting original_dtype to be populated"
)
if self.block_size is None:
raise AssertionError("Expecting block_size to be populated")
if self.original_dtype is None:
raise AssertionError("Expecting original_dtype to be populated")
if hasattr(self, "is_dynamic") and self.is_dynamic:
choose_qparams_affine = model.graph.call_function(
torch.ops.pt2e_quant.choose_qparams_affine,

@ -565,9 +565,10 @@ def _assert_valid_qconfig(qconfig: Optional[QConfig], mod: torch.nn.Module) -> N
|
||||
torch.ao.quantization.MovingAveragePerChannelMinMaxObserver,
|
||||
),
|
||||
)
|
||||
assert not is_per_channel, (
|
||||
"Per channel weight observer is not supported yet for ConvTranspose{n}d."
|
||||
)
|
||||
if is_per_channel:
|
||||
raise AssertionError(
|
||||
"Per channel weight observer is not supported yet for ConvTranspose{n}d."
|
||||
)
|
||||
|
||||
|
||||
if sys.version_info < (3, 12):
|
||||
@ -599,7 +600,8 @@ def _add_module_to_qconfig_obs_ctr(
|
||||
return qconfig
|
||||
|
||||
def get_factory_kwargs_based_on_module_device():
|
||||
assert isinstance(module, torch.nn.Module)
|
||||
if not isinstance(module, torch.nn.Module):
|
||||
raise AssertionError("module must be an instance of torch.nn.Module")
|
||||
devices = {p.device for p in module.parameters()} | {
|
||||
p.device for p in module.buffers()
|
||||
}
|
||||
@ -671,7 +673,10 @@ def qconfig_equals(q1: QConfigAny, q2: QConfigAny):
|
||||
if q1 is None or q2 is None:
|
||||
return q1 == q2
|
||||
else:
|
||||
assert q1 is not None and q2 is not None
|
||||
if q1 is None or q2 is None:
|
||||
raise AssertionError(
|
||||
"Both q1 and q2 must be non-None for qconfig comparison"
|
||||
)
|
||||
try:
|
||||
# Qconfig weight and activation can be either a partial wrapper,
|
||||
# or an observer class. Special handling is required (above) for
|
||||
|
@ -252,10 +252,11 @@ def get_static_quant_module_class(
|
||||
additional_static_quant_mapping,
|
||||
)
|
||||
static_quant_module_class = all_mappings.get(float_module_class, None)
|
||||
assert static_quant_module_class is not None, (
|
||||
f"Floating point module class {str(float_module_class)}"
|
||||
+ " does not have a corresponding quantized module class"
|
||||
)
|
||||
if static_quant_module_class is None:
|
||||
raise AssertionError(
|
||||
f"Floating point module class {str(float_module_class)}"
|
||||
+ " does not have a corresponding quantized module class"
|
||||
)
|
||||
return copy.deepcopy(static_quant_module_class)
|
||||
|
||||
|
||||
@ -272,10 +273,11 @@ def get_dynamic_quant_module_class(
|
||||
DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS, additional_dynamic_quant_mapping
|
||||
)
|
||||
dynamic_quant_module_class = all_mappings.get(float_module_class, None)
|
||||
assert dynamic_quant_module_class is not None, (
|
||||
f"Floating point module class {str(float_module_class)}"
|
||||
+ " does not have a corresponding quantized module class"
|
||||
)
|
||||
if dynamic_quant_module_class is None:
|
||||
raise AssertionError(
|
||||
f"Floating point module class {str(float_module_class)}"
|
||||
+ " does not have a corresponding quantized module class"
|
||||
)
|
||||
return copy.deepcopy(dynamic_quant_module_class)
|
||||
|
||||
|
||||
@ -344,9 +346,10 @@ def get_default_float_to_quantized_operator_mappings() -> dict[
|
||||
def get_quantized_operator(float_op: Union[Callable, str]) -> Callable:
|
||||
"""Get the quantized operator corresponding to the float operator"""
|
||||
quantized_op = DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS.get(float_op)
|
||||
assert quantized_op is not None, (
|
||||
f"Operator {str(float_op)} does not have corresponding quantized op"
|
||||
)
|
||||
if quantized_op is None:
|
||||
raise AssertionError(
|
||||
f"Operator {str(float_op)} does not have corresponding quantized op"
|
||||
)
|
||||
return quantized_op
|
||||
|
||||
|
||||
|
@ -158,9 +158,10 @@ def _observer_forward_pre_hook(self, input):
|
||||
|
||||
|
||||
def _register_activation_post_process_hook(module, pre_hook=False):
|
||||
assert hasattr(module, "activation_post_process"), (
|
||||
"Expect activation_post_process attribute already attached to the module"
|
||||
)
|
||||
if not hasattr(module, "activation_post_process"):
|
||||
raise AssertionError(
|
||||
"Expect activation_post_process attribute already attached to the module"
|
||||
)
|
||||
if pre_hook:
|
||||
module.register_forward_pre_hook(_observer_forward_pre_hook, prepend=True)
|
||||
else:
|
||||
@ -198,9 +199,10 @@ def _add_observer_(
|
||||
# respect device affinity when adding observers
|
||||
if device is None:
|
||||
devices = _get_unique_devices_(module)
|
||||
assert len(devices) <= 1, (
|
||||
f"_add_observer_ only works with cpu or single-device CUDA modules, but got devices {devices}"
|
||||
)
|
||||
if len(devices) > 1:
|
||||
raise AssertionError(
|
||||
f"_add_observer_ only works with cpu or single-device CUDA modules, but got devices {devices}"
|
||||
)
|
||||
device = next(iter(devices)) if len(devices) > 0 else None
|
||||
|
||||
def get_activation_post_process(qconfig, device, special_act_post_process=None):
|
||||
@ -243,9 +245,10 @@ def _add_observer_(
|
||||
type_before_parametrizations(child), (nnq.FloatFunctional, nnq.QFunctional)
|
||||
):
|
||||
if needs_observation(child):
|
||||
assert hasattr(child, "activation_post_process"), (
|
||||
f"functional class {type_before_parametrizations(child)} has no pre-defined `activation_post_process`"
|
||||
)
|
||||
if not hasattr(child, "activation_post_process"):
|
||||
raise AssertionError(
|
||||
f"functional class {type_before_parametrizations(child)} has no pre-defined `activation_post_process`"
|
||||
)
|
||||
child.activation_post_process = get_activation_post_process(
|
||||
child.qconfig, device
|
||||
)
|
||||
@ -584,7 +587,8 @@ def prepare_qat(model, mapping=None, inplace=False):
|
||||
is mutated
|
||||
"""
|
||||
torch._C._log_api_usage_once("quantization_api.quantize.prepare_qat")
|
||||
assert model.training, "prepare_qat only works on models in training mode"
|
||||
if not model.training:
|
||||
raise AssertionError("prepare_qat only works on models in training mode")
|
||||
if mapping is None:
|
||||
mapping = get_default_qat_module_mappings()
|
||||
|
||||
@ -760,7 +764,10 @@ def swap_module(
|
||||
elif type_before_parametrizations(mod) in mapping:
|
||||
qmod = mapping[type_before_parametrizations(mod)]
|
||||
if hasattr(qmod, "_IS_REFERENCE") and qmod._IS_REFERENCE:
|
||||
assert mod.qconfig is not None
|
||||
if mod.qconfig is None:
|
||||
raise AssertionError(
|
||||
"module qconfig must not be None when swapping to reference module"
|
||||
)
|
||||
weight_post_process = mod.qconfig.weight()
|
||||
weight_post_process(mod.weight)
|
||||
weight_qparams = get_qparam_dict(weight_post_process)
|
||||
@ -787,11 +794,13 @@ def swap_module(
|
||||
|
||||
# respect device affinity when swapping modules
|
||||
devices = _get_unique_devices_(mod)
|
||||
assert len(devices) <= 1 or (
|
||||
len(devices) == 2 and torch.device("meta") in devices
|
||||
), (
|
||||
f"swap_module only works with cpu or single-device CUDA modules, but got devices {devices}"
|
||||
)
|
||||
if not (
|
||||
len(devices) <= 1
|
||||
or (len(devices) == 2 and torch.device("meta") in devices)
|
||||
):
|
||||
raise AssertionError(
|
||||
f"swap_module only works with cpu or single-device CUDA modules, but got devices {devices}"
|
||||
)
|
||||
device = next(iter(devices)) if len(devices) > 0 else None
|
||||
if device:
|
||||
new_mod.to(device)
|
||||
|
@ -157,12 +157,12 @@ def _convert_ondevice_jit(
|
||||
model, method_name, inplace=False, debug=False, quant_type=QuantType.STATIC
|
||||
):
|
||||
_check_is_script_module(model)
|
||||
assert quant_type == QuantType.DYNAMIC, (
|
||||
"This API, while should work for static quant, is only tested for dynamic quant."
|
||||
)
|
||||
assert not method_name.startswith("observe_"), (
|
||||
"Pass in valid method to be quantized, e.g. forward"
|
||||
)
|
||||
if quant_type != QuantType.DYNAMIC:
|
||||
raise AssertionError(
|
||||
"This API, while should work for static quant, is only tested for dynamic quant."
|
||||
)
|
||||
if method_name.startswith("observe_"):
|
||||
raise AssertionError("Pass in valid method to be quantized, e.g. forward")
|
||||
observe_method_name = "observe_" + method_name
|
||||
quantize_method_name = "quantize_" + method_name
|
||||
model_c = model._c
|
||||
@ -230,12 +230,14 @@ def _quantize_jit(
|
||||
model = prepare_dynamic_jit(model, qconfig_dict, inplace)
|
||||
model = convert_dynamic_jit(model, True, debug)
|
||||
else:
|
||||
assert run_fn, (
|
||||
"Must provide calibration function for post training static quantization"
|
||||
)
|
||||
assert run_args, (
|
||||
"Must provide calibration dataset for post training static quantization"
|
||||
)
|
||||
if not run_fn:
|
||||
raise AssertionError(
|
||||
"Must provide calibration function for post training static quantization"
|
||||
)
|
||||
if not run_args:
|
||||
raise AssertionError(
|
||||
"Must provide calibration dataset for post training static quantization"
|
||||
)
|
||||
model = prepare_jit(model, qconfig_dict, inplace)
|
||||
run_fn(model, *run_args)
|
||||
model = convert_jit(model, True, debug)
|
||||
|
@ -263,7 +263,10 @@ def _is_quantized_op_pt2e(node: torch.fx.Node):
|
||||
# The node has not been annotated, directly return False
|
||||
return False
|
||||
quantization_annotation = node.meta.get(QUANT_ANNOTATION_KEY, None)
|
||||
assert isinstance(quantization_annotation, _X86InductorQuantizationAnnotation)
|
||||
if not isinstance(quantization_annotation, _X86InductorQuantizationAnnotation):
|
||||
raise AssertionError(
|
||||
"quantization_annotation must be an _X86InductorQuantizationAnnotation"
|
||||
)
|
||||
return quantization_annotation._is_output_of_quantized_pattern
|
||||
|
||||
|
||||
@ -428,20 +431,22 @@ class X86InductorQuantizer(Quantizer):
|
||||
if qat_state is None:
|
||||
qat_state = qconfig.is_qat
|
||||
else:
|
||||
assert qat_state == qconfig.is_qat, (
|
||||
f"All non-None quantization configs should have the same `is_qat`,"
|
||||
f"but got {qat_state} and {qconfig.is_qat}."
|
||||
)
|
||||
if qat_state != qconfig.is_qat:
|
||||
raise AssertionError(
|
||||
f"All non-None quantization configs should have the same `is_qat`,"
|
||||
f"but got {qat_state} and {qconfig.is_qat}."
|
||||
)
|
||||
# Query the `is_dynamic` state
|
||||
input_activation_spec = qconfig.input_activation
|
||||
if input_activation_spec is not None:
|
||||
if dynamic_state is None:
|
||||
dynamic_state = input_activation_spec.is_dynamic
|
||||
else:
|
||||
assert dynamic_state == input_activation_spec.is_dynamic, (
|
||||
f"All non-None `input_activation_spec` should have the same `is_dynamic`,"
|
||||
f"but got {dynamic_state} and {input_activation_spec.is_dynamic}."
|
||||
)
|
||||
if dynamic_state != input_activation_spec.is_dynamic:
|
||||
raise AssertionError(
|
||||
f"All non-None `input_activation_spec` should have the same `is_dynamic`,"
|
||||
f"but got {dynamic_state} and {input_activation_spec.is_dynamic}."
|
||||
)
|
||||
return _CurrentQuantizationMode(
|
||||
qat_state=qat_state, dynamic_state=dynamic_state
|
||||
)
|
||||
@ -567,10 +572,12 @@ class X86InductorQuantizer(Quantizer):
|
||||
return
|
||||
input_qspec_map = {}
|
||||
input_node = conv_node.args[0]
|
||||
assert isinstance(input_node, Node)
|
||||
if not isinstance(input_node, Node):
|
||||
raise AssertionError("input_node must be a FX Node")
|
||||
input_qspec_map[input_node] = get_input_act_qspec(quantization_config)
|
||||
weight_node = conv_node.args[1]
|
||||
assert isinstance(weight_node, Node)
|
||||
if not isinstance(weight_node, Node):
|
||||
raise AssertionError("weight_node must be a FX Node")
|
||||
input_qspec_map[weight_node] = get_weight_qspec(quantization_config)
|
||||
bias_node = None if len(conv_node.args) == 2 else conv_node.args[2]
|
||||
if isinstance(bias_node, Node):
|
||||
@ -598,18 +605,23 @@ class X86InductorQuantizer(Quantizer):
|
||||
_annotate_nodes_not_quantize(linear_node)
|
||||
return
|
||||
input_qspec_map = {}
|
||||
assert linear_node.target == torch.ops.aten.linear.default
|
||||
if linear_node.target != torch.ops.aten.linear.default:
|
||||
raise AssertionError(
|
||||
"linear_node.target must be torch.ops.aten.linear.default"
|
||||
)
|
||||
has_bias = len(linear_node.args) == 3
|
||||
input_index = 0
|
||||
weight_index = 1
|
||||
bias_index = 2
|
||||
|
||||
input_node = linear_node.args[input_index]
|
||||
assert isinstance(input_node, Node)
|
||||
if not isinstance(input_node, Node):
|
||||
raise AssertionError("input_node must be a FX Node")
|
||||
input_qspec_map[input_node] = get_input_act_qspec(quantization_config)
|
||||
|
||||
weight_node = linear_node.args[weight_index]
|
||||
assert isinstance(weight_node, Node)
|
||||
if not isinstance(weight_node, Node):
|
||||
raise AssertionError("weight_node must be a FX Node")
|
||||
input_qspec_map[weight_node] = get_weight_qspec(quantization_config)
|
||||
|
||||
bias_node = linear_node.args[bias_index] if has_bias else None
|
||||
@ -637,7 +649,8 @@ class X86InductorQuantizer(Quantizer):
|
||||
if len(partition.output_nodes) > 1:
|
||||
raise ValueError("Input partition has more than one output node")
|
||||
output_node = partition.output_nodes[0]
|
||||
assert isinstance(output_node, Node)
|
||||
if not isinstance(output_node, Node):
|
||||
raise AssertionError("output_node must be a FX Node")
|
||||
output_node_list.append(output_node)
|
||||
if len(output_node_list) != len(partition_list):
|
||||
raise ValueError(
|
||||
@ -666,7 +679,8 @@ class X86InductorQuantizer(Quantizer):
|
||||
conv_gemm_node_idx = 1
|
||||
extra_input_node_idx = 0
|
||||
extra_input_node = binary_node.args[extra_input_node_idx] # type: ignore[index]
|
||||
assert isinstance(extra_input_node, Node)
|
||||
if not isinstance(extra_input_node, Node):
|
||||
raise AssertionError("extra_input_node must be a FX Node")
|
||||
return conv_gemm_node_idx, extra_input_node_idx
|
||||
|
||||
def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
|
||||
@ -1123,7 +1137,8 @@ class X86InductorQuantizer(Quantizer):
|
||||
if conv_node != binary_node.args[conv_node_idx]:
|
||||
raise ValueError(f"{conv_node} doesn't match input of binary node")
|
||||
extra_input_node = binary_node.args[extra_input_node_idx]
|
||||
assert isinstance(conv_node, Node)
|
||||
if not isinstance(conv_node, Node):
|
||||
raise AssertionError("conv_node must be a FX Node")
|
||||
if (
|
||||
conv_node.op != "call_function"
|
||||
or conv_node.target != torch.ops.aten.conv2d.default
|
||||
@ -1237,7 +1252,8 @@ class X86InductorQuantizer(Quantizer):
|
||||
return
|
||||
|
||||
input_node = maxpool_node.args[0]
|
||||
assert isinstance(input_node, Node)
|
||||
if not isinstance(input_node, Node):
|
||||
raise AssertionError("input_node must be a FX Node")
|
||||
input_qspec_map = {}
|
||||
input_qspec_map[input_node] = get_input_act_qspec(quantization_config)
|
||||
maxpool_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
|
||||
@ -1254,11 +1270,14 @@ class X86InductorQuantizer(Quantizer):
|
||||
return
|
||||
cat_node = node
|
||||
input_nodes = cat_node.args[0]
|
||||
assert isinstance(input_nodes, Sequence)
|
||||
if not isinstance(input_nodes, Sequence):
|
||||
raise AssertionError("input_nodes must be a Sequence of FX Nodes")
|
||||
first_input_node = input_nodes[0]
|
||||
input_qspec_map = {}
|
||||
assert isinstance(first_input_node, Node)
|
||||
assert isinstance(cat_node, Node)
|
||||
if not isinstance(first_input_node, Node):
|
||||
raise AssertionError("first_input_node must be a FX Node")
|
||||
if not isinstance(cat_node, Node):
|
||||
raise AssertionError("cat_node must be a FX Node")
|
||||
input_qspec_map[first_input_node] = get_input_act_qspec(quantization_config)
|
||||
share_qparams_with_input_act0_qspec = SharedQuantizationSpec(
|
||||
(first_input_node, cat_node)
|
||||
@ -1267,7 +1286,8 @@ class X86InductorQuantizer(Quantizer):
|
||||
for input_node in input_nodes[1:]:
|
||||
if input_node not in input_qspec_map:
|
||||
# There has the case of cat same nodes: torch.cat([input0, input0], 1)
|
||||
assert isinstance(input_node, Node)
|
||||
if not isinstance(input_node, Node):
|
||||
raise AssertionError("input_node must be a FX Node")
|
||||
input_qspec_map[input_node] = share_qparams_with_input_act0_qspec
|
||||
|
||||
cat_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
|
||||
@ -1405,8 +1425,10 @@ class X86InductorQuantizer(Quantizer):
|
||||
):
|
||||
# Annotate the output_qspec of getitem_node
|
||||
input_act = maxpool_node.args[0]
|
||||
assert isinstance(input_act, Node)
|
||||
assert isinstance(maxpool_node, Node)
|
||||
if not isinstance(input_act, Node):
|
||||
raise AssertionError("input_act must be a FX Node")
|
||||
if not isinstance(maxpool_node, Node):
|
||||
raise AssertionError("maxpool_node must be a FX Node")
|
||||
edge_or_node = (input_act, maxpool_node)
|
||||
maxpool_node_quantization_annotation.output_qspec = (
|
||||
SharedQuantizationSpec(edge_or_node)
|
||||
@ -1534,7 +1556,8 @@ class X86InductorQuantizer(Quantizer):
|
||||
raise ValueError(
|
||||
f"{linear_node} doesn't match input of binary node"
|
||||
)
|
||||
assert isinstance(linear_node, Node)
|
||||
if not isinstance(linear_node, Node):
|
||||
raise AssertionError("linear_node must be a FX Node")
|
||||
if (
|
||||
linear_node.op != "call_function"
|
||||
or linear_node.target != torch.ops.aten.linear.default
|
||||
|
@ -347,9 +347,8 @@ class XNNPACKQuantizer(Quantizer):
|
||||
quantizer.set_module_name("blocks.sub"), it will quantize all supported operator/operator
|
||||
patterns in the submodule with this module name with the given `quantization_config`
|
||||
"""
|
||||
assert quantization_config is not None, (
|
||||
" quantization_config == None is not supported yet"
|
||||
)
|
||||
if quantization_config is None:
|
||||
raise AssertionError("quantization_config == None is not supported yet")
|
||||
self.module_name_config[module_name] = quantization_config
|
||||
return self
|
||||
|
||||
|
@ -121,10 +121,13 @@ def get_input_act_qspec(quantization_config: Optional[QuantizationConfig]):
|
||||
if quantization_config.input_activation is None:
|
||||
return None
|
||||
quantization_spec: QuantizationSpec = quantization_config.input_activation
|
||||
assert quantization_spec.qscheme in [
|
||||
if quantization_spec.qscheme not in [
|
||||
torch.per_tensor_affine,
|
||||
torch.per_tensor_symmetric,
|
||||
]
|
||||
]:
|
||||
raise AssertionError(
|
||||
f"Unsupported activation qscheme: {quantization_spec.qscheme}"
|
||||
)
|
||||
return quantization_spec
|
||||
|
||||
|
||||
@ -134,17 +137,21 @@ def get_output_act_qspec(quantization_config: Optional[QuantizationConfig]):
|
||||
if quantization_config.output_activation is None:
|
||||
return None
|
||||
quantization_spec: QuantizationSpec = quantization_config.output_activation
|
||||
assert quantization_spec.qscheme in [
|
||||
if quantization_spec.qscheme not in [
|
||||
torch.per_tensor_affine,
|
||||
torch.per_tensor_symmetric,
|
||||
]
|
||||
]:
|
||||
raise AssertionError(
|
||||
f"Unsupported activation qscheme: {quantization_spec.qscheme}"
|
||||
)
|
||||
return quantization_spec
|
||||
|
||||
|
||||
def get_weight_qspec(quantization_config: Optional[QuantizationConfig]):
|
||||
if quantization_config is None:
|
||||
return None
|
||||
assert quantization_config is not None
|
||||
if quantization_config is None:
|
||||
raise AssertionError("quantization_config must not be None")
|
||||
if quantization_config.weight is None:
|
||||
return None
|
||||
quantization_spec: QuantizationSpec = quantization_config.weight
|
||||
@ -162,13 +169,15 @@ def get_weight_qspec(quantization_config: Optional[QuantizationConfig]):
|
||||
def get_bias_qspec(quantization_config: Optional[QuantizationConfig]):
|
||||
if quantization_config is None:
|
||||
return None
|
||||
assert quantization_config is not None
|
||||
if quantization_config is None:
|
||||
raise AssertionError("quantization_config must not be None")
|
||||
if quantization_config.bias is None:
|
||||
return None
|
||||
quantization_spec: QuantizationSpec = quantization_config.bias
|
||||
assert quantization_spec.dtype == torch.float, (
|
||||
"Only float dtype for bias is supported for bias right now"
|
||||
)
|
||||
if quantization_spec.dtype != torch.float:
|
||||
raise AssertionError(
|
||||
"Only float dtype for bias is supported for bias right now"
|
||||
)
|
||||
return quantization_spec
|
||||
|
||||
|
||||
@ -253,11 +262,13 @@ def _annotate_linear_relu(
|
||||
|
||||
input_qspec_map = {}
|
||||
input_act = linear_node.args[0]
|
||||
assert isinstance(input_act, Node)
|
||||
if not isinstance(input_act, Node):
|
||||
raise AssertionError("input activation must be a FX Node")
|
||||
input_qspec_map[input_act] = input_act_qspec
|
||||
|
||||
weight = linear_node.args[1]
|
||||
assert isinstance(weight, Node)
|
||||
if not isinstance(weight, Node):
|
||||
raise AssertionError("weight must be a FX Node")
|
||||
input_qspec_map[weight] = weight_qspec
|
||||
|
||||
# adding weight node to the partition as well
|
||||
@ -303,11 +314,13 @@ def _annotate_conv(
|
||||
|
||||
input_qspec_map = {}
|
||||
input_act = conv_node.args[0]
|
||||
assert isinstance(input_act, Node)
|
||||
if not isinstance(input_act, Node):
|
||||
raise AssertionError("input activation must be a FX Node")
|
||||
input_qspec_map[input_act] = get_input_act_qspec(quantization_config)
|
||||
|
||||
weight = conv_node.args[1]
|
||||
assert isinstance(weight, Node)
|
||||
if not isinstance(weight, Node):
|
||||
raise AssertionError("weight must be a FX Node")
|
||||
input_qspec_map[weight] = get_weight_qspec(quantization_config)
|
||||
|
||||
# adding weight node to the partition as well
|
||||
@ -362,11 +375,13 @@ def _do_annotate_conv_relu(
|
||||
|
||||
input_qspec_map = {}
|
||||
input_act = conv_node.args[0]
|
||||
assert isinstance(input_act, Node)
|
||||
if not isinstance(input_act, Node):
|
||||
raise AssertionError("input activation must be a FX Node")
|
||||
input_qspec_map[input_act] = get_input_act_qspec(quantization_config)
|
||||
|
||||
weight = conv_node.args[1]
|
||||
assert isinstance(weight, Node)
|
||||
if not isinstance(weight, Node):
|
||||
raise AssertionError("weight must be a FX Node")
|
||||
input_qspec_map[weight] = get_weight_qspec(quantization_config)
|
||||
|
||||
# adding weight node to the partition as well
|
||||
@ -635,8 +650,10 @@ def _annotate_gru_io_only(
|
||||
# subgraph
|
||||
input_act = input_nodes[0]
|
||||
input_act_user = next(iter(input_act.users.keys()))
|
||||
assert isinstance(input_act, Node)
|
||||
assert isinstance(input_act_user, Node)
|
||||
if not isinstance(input_act, Node):
|
||||
raise AssertionError("input activation must be a FX Node")
|
||||
if not isinstance(input_act_user, Node):
|
||||
raise AssertionError("input activation user must be a FX Node")
|
||||
input_act_user.meta["quantization_annotation"] = QuantizationAnnotation(
|
||||
input_qspec_map={
|
||||
input_act: get_input_act_qspec(quantization_config),
|
||||
@ -646,8 +663,10 @@ def _annotate_gru_io_only(
|
||||
|
||||
hidden_state = input_nodes[1]
|
||||
hidden_state_user = next(iter(hidden_state.users.keys()))
|
||||
assert isinstance(hidden_state, Node)
|
||||
assert isinstance(hidden_state_user, Node)
|
||||
if not isinstance(hidden_state, Node):
|
||||
raise AssertionError("hidden state must be a FX Node")
|
||||
if not isinstance(hidden_state_user, Node):
|
||||
raise AssertionError("hidden state user must be a FX Node")
|
||||
hidden_state_user.meta["quantization_annotation"] = QuantizationAnnotation(
|
||||
input_qspec_map={
|
||||
hidden_state: get_input_act_qspec(quantization_config),
|
||||
@ -655,7 +674,8 @@ def _annotate_gru_io_only(
|
||||
_annotated=True,
|
||||
)
|
||||
|
||||
assert len(output_nodes) == 2, "expecting GRU to have two outputs"
|
||||
if len(output_nodes) != 2:
|
||||
raise AssertionError("expecting GRU to have two outputs")
|
||||
for output in output_nodes:
|
||||
output.meta["quantization_annotation"] = QuantizationAnnotation(
|
||||
output_qspec=get_output_act_qspec(quantization_config),
|
||||
@ -691,7 +711,8 @@ def _annotate_adaptive_avg_pool2d(
|
||||
|
||||
annotated_partitions.append(partition.nodes)
|
||||
input_act = pool_node.args[0]
|
||||
assert isinstance(input_act, Node)
|
||||
if not isinstance(input_act, Node):
|
||||
raise AssertionError("input activation must be a FX Node")
|
||||
|
||||
# only annotate input output sharing operator
|
||||
# when the output of the input node is annotated
|
||||
|
@ -214,7 +214,8 @@ def to_underlying_dtype(qdtype):
|
||||
torch.float8_e5m2: torch.float8_e5m2,
|
||||
torch.float8_e4m3fn: torch.float8_e4m3fn,
|
||||
}
|
||||
assert qdtype in DTYPE_MAPPING, "Unsupported dtype: " + str(qdtype)
|
||||
if qdtype not in DTYPE_MAPPING:
|
||||
raise AssertionError("Unsupported dtype: " + str(qdtype))
|
||||
return DTYPE_MAPPING[qdtype]
|
||||
|
||||
|
||||
@ -269,21 +270,24 @@ def get_swapped_custom_module_class(
|
||||
"""
|
||||
quant_type = get_quant_type(qconfig)
|
||||
class_mapping = custom_module_class_mapping.get(quant_type, {})
|
||||
assert type(custom_module) in class_mapping, (
|
||||
"did not find corresponding observed "
|
||||
f"module class for {type(custom_module)} in mapping: {class_mapping}"
|
||||
)
|
||||
if type(custom_module) not in class_mapping:
|
||||
raise AssertionError(
|
||||
"did not find corresponding observed "
|
||||
f"module class for {type(custom_module)} in mapping: {class_mapping}"
|
||||
)
|
||||
return class_mapping[type(custom_module)]
|
||||
|
||||
|
||||
def activation_dtype(qconfig):
|
||||
assert qconfig is not None
|
||||
if qconfig is None:
|
||||
raise AssertionError("qconfig must be provided to determine activation dtype")
|
||||
activation = qconfig.activation()
|
||||
return activation.dtype
|
||||
|
||||
|
||||
def weight_dtype(qconfig):
|
||||
assert qconfig is not None
|
||||
if qconfig is None:
|
||||
raise AssertionError("qconfig must be provided to determine weight dtype")
|
||||
weight = qconfig.weight()
|
||||
return weight.dtype
|
||||
|
||||
@ -377,7 +381,8 @@ def get_qconfig_dtypes(qconfig):
|
||||
r"""returns the qconfig tuple for qconfig:
|
||||
(activation_dtype, weight_dtype, activation_is_dynamic)
|
||||
"""
|
||||
assert qconfig is not None
|
||||
if qconfig is None:
|
||||
raise AssertionError("qconfig must be provided to extract dtypes")
|
||||
activation = qconfig.activation()
|
||||
weight = qconfig.weight()
|
||||
act_is_dynamic = getattr(activation, "is_dynamic", False)
|
||||
@ -385,7 +390,8 @@ def get_qconfig_dtypes(qconfig):
|
||||
|
||||
|
||||
def get_quant_type(qconfig):
|
||||
assert qconfig is not None
|
||||
if qconfig is None:
|
||||
raise AssertionError("qconfig must be provided to determine quant type")
|
||||
activation = qconfig.activation()
|
||||
weight = qconfig.weight()
|
||||
static_dtypes = [
|
||||
@ -440,11 +446,11 @@ def check_min_max_valid(min_val: torch.Tensor, max_val: torch.Tensor) -> bool:
|
||||
|
||||
return False
|
||||
|
||||
assert min_val <= max_val, f"min {min_val} should be less than max {max_val}"
|
||||
if min_val > max_val:
|
||||
raise AssertionError(f"min {min_val} should be less than max {max_val}")
|
||||
else:
|
||||
assert torch.all(min_val <= max_val), (
|
||||
f"min {min_val} should be less than max {max_val}"
|
||||
)
|
||||
if torch.any(min_val > max_val):
|
||||
raise AssertionError(f"min {min_val} should be less than max {max_val}")
|
||||
|
||||
return True
|
||||
|
||||
@ -479,13 +485,15 @@ def calculate_qmin_qmax(
|
||||
|
||||
qrange_len = initial_quant_max - initial_quant_min + 1
|
||||
if dtype in [torch.qint8, torch.int8]:
|
||||
assert 0 < qrange_len <= 256, (
|
||||
"quantization range should be positive and not exceed the maximum bit range (=256)."
|
||||
)
|
||||
if not (0 < qrange_len <= 256):
|
||||
raise AssertionError(
|
||||
"quantization range should be positive and not exceed the maximum bit range (=256)."
|
||||
)
|
||||
elif dtype in [torch.qint32, torch.int32]:
|
||||
assert 0 < qrange_len <= 2**32, (
|
||||
"quantization range should be positive and not exceed the maximum bit range (=4294967296)."
|
||||
)
|
||||
if not (0 < qrange_len <= 2**32):
|
||||
raise AssertionError(
|
||||
"quantization range should be positive and not exceed the maximum bit range (=4294967296)."
|
||||
)
|
||||
if reduce_range:
|
||||
quant_min, quant_max = quant_min // 2, quant_max // 2
|
||||
else:
|
||||
@ -633,12 +641,12 @@ def validate_qmin_qmax(quant_min: int, quant_max: int) -> None:
|
||||
"""
|
||||
# The variable names are prefixed with "initial" because their values (qmin and qmax) might be adjusted
|
||||
# based on whether quantization range is reduced and the datatype (signed/unsigned) used by the observer.
|
||||
assert quant_min <= 0 <= quant_max, (
|
||||
"Used-specified quantization range must include 0."
|
||||
)
|
||||
assert quant_min < quant_max, (
|
||||
"qmin must be strictly less than qmax for user-specified quantization range."
|
||||
)
|
||||
if not (quant_min <= 0 <= quant_max):
|
||||
raise AssertionError("Used-specified quantization range must include 0.")
|
||||
if quant_min >= quant_max:
|
||||
raise AssertionError(
|
||||
"qmin must be strictly less than qmax for user-specified quantization range."
|
||||
)
|
||||
|
||||
|
||||
# Functionally equivalent to '_calculate_qparams' in observer.py. Observers must be torchscriptable however and qscheme
|
||||
@ -810,10 +818,11 @@ def _assert_and_get_unique_device(module: torch.nn.Module) -> Any:
|
||||
)
|
||||
devices = {torch.device("cpu")}
|
||||
""
|
||||
assert len(devices) <= 1, (
|
||||
"prepare only works with cpu or single-device CUDA modules, "
|
||||
f"but got devices {devices}"
|
||||
)
|
||||
if len(devices) > 1:
|
||||
raise AssertionError(
|
||||
"prepare only works with cpu or single-device CUDA modules, "
|
||||
f"but got devices {devices}"
|
||||
)
|
||||
device = next(iter(devices)) if len(devices) > 0 else None
|
||||
return device
|
||||
|
||||
|
@ -419,7 +419,7 @@ class TORCH_API Backend : public torch::CustomClassHolder {
|
||||
}
|
||||
|
||||
// Do not call this directly, use ProcessGroup::setGroupName instead.
|
||||
virtual void setGroupUid(const std::string& pg_uid) {
|
||||
void setGroupUid(const std::string& pg_uid) {
|
||||
pg_uid_ = pg_uid;
|
||||
}
|
||||
|
||||
|
@ -1,15 +1,11 @@
|
||||
#include <torch/csrc/fx/node.h>
|
||||
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/SmallVector.h>
|
||||
#include <structmember.h>
|
||||
#include <torch/csrc/utils/object_ptr.h>
|
||||
#include <torch/csrc/utils/pythoncapi_compat.h>
|
||||
#include <algorithm>
|
||||
|
||||
namespace {
|
||||
|
||||
using NodeSortKey = c10::SmallVector<int64_t, 4>;
|
||||
struct NodeBase;
|
||||
|
||||
// Thrown to exit out of a C++ function and return an error to Python.
|
||||
@ -167,41 +163,7 @@ struct NodeBase {
|
||||
PyObject* users;
|
||||
PyObject* _repr_fn;
|
||||
PyObject* meta;
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
|
||||
alignas(NodeSortKey) char sort_key_buf[sizeof(NodeSortKey)];
|
||||
|
||||
inline NodeSortKey& sort_key() {
|
||||
return *reinterpret_cast<NodeSortKey*>(sort_key_buf);
|
||||
}
|
||||
|
||||
inline void set_prev(NodeBase* value) {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(value);
|
||||
Py_INCREF(reinterpret_cast<PyObject*>(value));
|
||||
NodeBase* old = _prev;
|
||||
_prev = value;
|
||||
Py_DECREF(reinterpret_cast<PyObject*>(old));
|
||||
}
|
||||
|
||||
inline void set_next(NodeBase* value) {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(value);
|
||||
Py_INCREF(reinterpret_cast<PyObject*>(value));
|
||||
NodeBase* old = _next;
|
||||
_next = value;
|
||||
Py_DECREF(reinterpret_cast<PyObject*>(old));
|
||||
}
|
||||
|
||||
// Equivalent to:
|
||||
// p, n = self._prev, self._next
|
||||
// p._next, n._prev = n, p
|
||||
inline void remove_from_list() {
|
||||
if (this->_prev == this && this->_next == this) {
|
||||
return;
|
||||
}
|
||||
NodeBase* p = this->_prev;
|
||||
NodeBase* n = this->_next;
|
||||
p->set_next(n);
|
||||
n->set_prev(p);
|
||||
}
|
||||
PyObject* _sort_key;
|
||||
};
|
||||
|
||||
static PyObject* NodeBase_new(
@ -211,8 +173,6 @@ static PyObject* NodeBase_new(
  PyObject* self = type->tp_alloc(type, 0);
  if (!self)
    return nullptr;
  new (reinterpret_cast<NodeBase*>(self)->sort_key_buf)
      NodeSortKey(); // placement new does not allocate
  return self;
}

@ -241,6 +201,7 @@ static int NodeBase_init_fn(NodeBase* self, PyObject* args, PyObject* kwds) {
  self->users = PyDict_New();
  self->_repr_fn = Py_NewRef(Py_None);
  self->meta = PyDict_New();
  self->_sort_key = PyTuple_New(0);
  return 0;
}

@ -260,6 +221,7 @@ static struct PyMemberDef NodeBase_members[] = {
    {"users", T_OBJECT_EX, offsetof(NodeBase, users), 0, nullptr},
    {"_repr_fn", T_OBJECT_EX, offsetof(NodeBase, _repr_fn), 0, nullptr},
    {"meta", T_OBJECT_EX, offsetof(NodeBase, meta), 0, nullptr},
    {"_sort_key", T_OBJECT_EX, offsetof(NodeBase, _sort_key), 0, nullptr},
    {nullptr} /* Sentinel */
};

@ -277,6 +239,7 @@ static int NodeBase_traverse(NodeBase* self, visitproc visit, void* arg) {
  Py_VISIT(self->users);
  Py_VISIT(self->_repr_fn);
  Py_VISIT(self->meta);
  Py_VISIT(self->_sort_key);
  return 0;
}

@ -294,12 +257,12 @@ static int NodeBase_clear(NodeBase* self) {
  Py_CLEAR(self->users);
  Py_CLEAR(self->_repr_fn);
  Py_CLEAR(self->meta);
  Py_CLEAR(self->_sort_key);
  return 0;
}

static void NodeBase_dealloc(PyObject* self) {
  PyObject_GC_UnTrack(self);
  reinterpret_cast<NodeBase*>(self)->sort_key().~NodeSortKey();
  (void)NodeBase_clear((NodeBase*)self);
  Py_TYPE(self)->tp_free(self);
}
@ -358,195 +321,15 @@ static PyObject* NodeBase__update_args_kwargs(
  }
}

static PyObject* NodeBase__remove_from_list(
    PyObject* self,
    PyObject* _ignored) {
  reinterpret_cast<NodeBase*>(self)->remove_from_list();
  Py_RETURN_NONE;
}

static PyObject* NodeBase__prepend(PyObject* self_, PyObject* arg) {
  if (self_ == arg) {
    Py_RETURN_NONE;
  }
  if (!is_node(arg)) {
    PyErr_SetString(PyExc_TypeError, "_prepend() argument must be a Node");
    return nullptr;
  }
  NodeBase* self = reinterpret_cast<NodeBase*>(self_);
  NodeBase* x = reinterpret_cast<NodeBase*>(arg);
  if (self->graph != x->graph) {
    PyErr_SetString(
        PyExc_AssertionError,
        "Attempting to move a Node into a different Graph");
    return nullptr;
  }

  x->remove_from_list();
  NodeBase* p = self->_prev;
  p->set_next(x);
  x->set_prev(p);
  x->set_next(self);
  self->set_prev(x);

  // Now compute x.sort_key()
  const NodeSortKey& psk = x->_prev->sort_key();
  const NodeSortKey& nsk = x->_next->sort_key();
  if (psk.size() > nsk.size()) {
    // prefix = psk[: len(nsk)+1]
    size_t slice_len = nsk.size() + 1;
    NodeSortKey prefix(psk.begin(), psk.begin() + slice_len);
    // last element is idx => increment by 1
    prefix.back()++;
    x->sort_key() = std::move(prefix);
  } else if (psk.size() < nsk.size()) {
    // prefix = nsk[: len(psk)+1]
    size_t slice_len = psk.size() + 1;
    NodeSortKey prefix(nsk.begin(), nsk.begin() + slice_len);
    // last element is idx => decrement by 1
    prefix.back()--;
    x->sort_key() = std::move(prefix);
  } else {
    // same length => add a 0
    x->sort_key() = psk;
    x->sort_key().emplace_back(0);
  }
  Py_RETURN_NONE;
}
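
The three branches above implement a single rule: take the shorter neighbour's length plus one elements of the longer key and nudge the final element toward the gap, or append a 0 when both keys have equal length; the same rule appears in the Python prepend further down in this diff. A standalone sketch of that rule in plain Python, using a hypothetical helper name "between" that is not part of this patch:

def between(psk: tuple, nsk: tuple) -> tuple:
    # Derive a key that sorts between the two neighbouring keys under
    # lexicographic tuple comparison (for the key shapes this scheme produces).
    if len(psk) > len(nsk):
        *prefix, idx = psk[: len(nsk) + 1]
        return (*prefix, idx + 1)
    elif len(psk) < len(nsk):
        *prefix, idx = nsk[: len(psk) + 1]
        return (*prefix, idx - 1)
    return (*psk, 0)

assert (1,) < between((1,), (2,)) < (2,)        # equal lengths -> (1, 0)
assert (1, 0) < between((1, 0), (2,)) < (2,)    # longer prev   -> (1, 1)
assert (1,) < between((1,), (1, 0)) < (1, 0)    # longer next   -> (1, -1)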

// __lt__(self, other): Return self.sort_key < other.sort_key
static PyObject* NodeBase___lt__(PyObject* self, PyObject* other) {
  // METH_O => one argument: 'other'
  if (!is_node(other)) {
    Py_RETURN_NOTIMPLEMENTED;
  }
  const NodeSortKey& lhs = reinterpret_cast<NodeBase*>(self)->sort_key();
  const NodeSortKey& rhs = reinterpret_cast<NodeBase*>(other)->sort_key();
  bool less = std::lexicographical_compare(
      lhs.begin(), lhs.end(), rhs.begin(), rhs.end());
  if (less)
    Py_RETURN_TRUE;
  Py_RETURN_FALSE;
}

// __gt__(self, other): Return self.sort_key() > other.sort_key
static PyObject* NodeBase___gt__(PyObject* self, PyObject* other) {
  if (!is_node(other)) {
    Py_RETURN_NOTIMPLEMENTED;
  }
  const NodeSortKey& lhs = reinterpret_cast<NodeBase*>(self)->sort_key();
  const NodeSortKey& rhs = reinterpret_cast<NodeBase*>(other)->sort_key();
  // "a > b" is equivalent to "b < a"
  bool greater = std::lexicographical_compare(
      rhs.begin(), rhs.end(), lhs.begin(), lhs.end());
  if (greater)
    Py_RETURN_TRUE;
  Py_RETURN_FALSE;
}

static PyObject* NodeBase___ge__(PyObject* self, PyObject* other) {
  if (self == other) {
    Py_RETURN_TRUE;
  }
  return NodeBase___gt__(self, other);
}

// __le__(self, other): Return not (self > other)
static PyObject* NodeBase___le__(PyObject* self, PyObject* other) {
  if (self == other) {
    Py_RETURN_TRUE;
  }
  return NodeBase___lt__(self, other);
}

// Convert the NodeBase::sort_key vector<long> into a Python tuple of ints
// Only used by pickle/__getstate__
static PyObject* NodeBase_get_sort_key(PyObject* self, void* /*closure*/) {
  NodeBase* node = reinterpret_cast<NodeBase*>(self);
  const NodeSortKey& vec = node->sort_key();
  Py_ssize_t n = static_cast<Py_ssize_t>(vec.size());
  THPObjectPtr tuple(PyTuple_New(n));
  if (!tuple) {
    return nullptr; // Out of memory
  }
  for (Py_ssize_t i = 0; i < n; i++) {
    PyObject* value = PyLong_FromSsize_t(vec[i]);
    if (!value) {
      return nullptr;
    }
    PyTuple_SET_ITEM(tuple.get(), i, value);
  }
  return tuple.release();
}

// Setter for NodeBase::sort_key: expects a Python tuple of ints, e.g.
// node._sort_key = (1,2,3). Only used by pickle/__setstate__
static int NodeBase_set_sort_key(
    PyObject* self,
    PyObject* value,
    void* /*closure*/) {
  NodeBase* node = reinterpret_cast<NodeBase*>(self);
  if (!PyTuple_Check(value)) {
    PyErr_SetString(PyExc_TypeError, "_sort_key must be a tuple of ints");
    return -1;
  }
  Py_ssize_t size = PyTuple_GET_SIZE(value);
  NodeSortKey new_vec;
  new_vec.reserve(size);
  for (Py_ssize_t i = 0; i < size; i++) {
    int64_t val = PyLong_AsSsize_t(PyTuple_GET_ITEM(value, i));
    if (val == -1 && PyErr_Occurred()) {
      return -1;
    }
    new_vec.emplace_back(val);
  }
  node->sort_key() = std::move(new_vec);
  return 0;
}
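
The getter and setter above keep the Python-visible view of the sort key a plain tuple of ints; per the comments, only pickle's __getstate__/__setstate__ go through them. A minimal sketch of that contract, where n is a hypothetical node from an already-constructed FX graph and not part of this patch:

state = n._sort_key        # read back as a tuple of ints, e.g. (3,) or (3, 0)
n._sort_key = state        # assignment takes the same tuple-of-ints form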

// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
static PyMethodDef NodeBase_methods[] = {
    {"_update_args_kwargs",
     (PyCFunction)(void*)(NodeBase__update_args_kwargs),
     METH_FASTCALL,
     "Internal method: do not call directly."},
    {"_remove_from_list",
     (PyCFunction)(void*)(NodeBase__remove_from_list),
     METH_NOARGS,
     "Internal method: do not call directly."},
    {"_prepend",
     (PyCFunction)(void*)(NodeBase__prepend),
     METH_O,
     "Internal method: do not call directly."},
    {"__lt__",
     (PyCFunction)(void*)NodeBase___lt__,
     METH_O,
     "Return True if self.sort_key < other.sort_key"},
    {"__gt__",
     (PyCFunction)(void*)NodeBase___gt__,
     METH_O,
     "Return True if self.sort_key > other.sort_key"},
    {"__ge__",
     (PyCFunction)(void*)NodeBase___ge__,
     METH_O,
     "Return True if self.sort_key >= other.sort_key"},
    {"__le__",
     (PyCFunction)(void*)NodeBase___le__,
     METH_O,
     "Return True if self.sort_key <= other.sort_key"},
    {nullptr, nullptr, 0, nullptr} // Sentinel
};

// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
static PyGetSetDef NodeBase_getset[] = {
    {"_sort_key", // attribute name in Python
     (getter)NodeBase_get_sort_key, // C getter function
     (setter)NodeBase_set_sort_key, // C setter function
     (char*)"The sort key as a tuple of ints", // docstring
     nullptr},
    {nullptr, nullptr, nullptr, nullptr, nullptr} // Sentinel
};

PyTypeObject NodeBaseType = {
    PyVarObject_HEAD_INIT(nullptr, 0)
    "torch._C._NodeBase", /* tp_name */
@ -578,7 +361,7 @@ PyTypeObject NodeBaseType = {
    nullptr, /* tp_iternext */
    NodeBase_methods, /* tp_methods */
    NodeBase_members, /* tp_members */
    NodeBase_getset, /* tp_getset */
    nullptr, /* tp_getset */
    nullptr, /* tp_base */
    nullptr, /* tp_dict */
    nullptr, /* tp_descr_get */

@ -385,7 +385,41 @@ class Node(_NodeBase):
        Args:
            x (Node): The node to put before this node. Must be a member of the same graph.
        """
        self._prepend(x)
        assert self.graph == x.graph, "Attempting to move a Node into a different Graph"
        if self == x:
            log.debug(
                "Trying to prepend a node to itself. This behavior has no effect on the graph."
            )
            return
        x._remove_from_list()
        p = self._prev
        p._next, x._prev = x, p
        x._next, self._prev = self, x

        # compute x._sort_key
        psk = x._prev._sort_key
        nsk = x._next._sort_key
        if len(psk) > len(nsk):
            idx: int
            *prefix, idx = psk[: len(nsk) + 1]
            x._sort_key = (*prefix, idx + 1)
        elif len(psk) < len(nsk):
            *prefix, idx = nsk[: len(psk) + 1]
            x._sort_key = (*prefix, idx - 1)
        else:  # same length, increase length by 1
            x._sort_key = (*psk, 0)

    def __gt__(self, other: "Node") -> bool:
        return self._sort_key > other._sort_key

    def __lt__(self, other: "Node") -> bool:
        return self._sort_key < other._sort_key

    def __ge__(self, other: "Node") -> bool:
        return self > other or self == other

    def __le__(self, other: "Node") -> bool:
        return self < other or self == other
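
Together with _sort_key, these comparison operators let callers ask which of two nodes comes first without walking the linked list. A small usage sketch, assuming the comparison semantics defined above and a graph produced by torch.fx.symbolic_trace (the traced function f is illustrative only, not part of this patch):

from torch.fx import symbolic_trace

def f(x):
    return x.relu() + 1

g = symbolic_trace(f).graph
a, b, *_ = g.nodes          # graph.nodes iterates in topological order
assert a < b                # an earlier node compares as smaller
assert b > a and a <= a and not (b <= a)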

    @compatibility(is_backward_compatible=True)
    def append(self, x: "Node") -> None:
@ -396,7 +430,11 @@ class Node(_NodeBase):
        Args:
            x (Node): The node to put after this node. Must be a member of the same graph.
        """
        self._next._prepend(x)
        self._next.prepend(x)

    def _remove_from_list(self) -> None:
        p, n = self._prev, self._next
        p._next, n._prev = n, p

    @property
    def args(self) -> tuple[Argument, ...]: