Compare commits

..

3 Commits

Author SHA1 Message Date
5b6cc8215f Change python doc push script to print the undocumented modules 2025-10-21 12:30:49 -07:00
1c43c9cfd0 Update 2025-10-21 12:30:49 -07:00
102e0d5437 Test 2025-10-21 12:30:49 -07:00
59 changed files with 1643 additions and 2463 deletions

View File

@@ -1,15 +1,11 @@
sphinx==5.3.0
sphinx==7.2.6
#Description: This is used to generate PyTorch docs
#Pinned versions: 5.3.0
#Pinned versions: 7.2.6
standard-imghdr==3.13.0; python_version >= "3.13"
#Description: This is needed by Sphinx, so it needs to be added here.
# The reasons are as follows:
# 1) This module has been removed from the Python standard library since Python 3.13(https://peps.python.org/pep-0594/#imghdr);
# 2) The current version of Sphinx (5.3.0) is not compatible with Python 3.13.
# Once Sphinx is upgraded to a version compatible with Python 3.13 or later, we can remove this dependency.
pytorch_sphinx_theme2==0.1.0
#Description: This is needed to generate PyTorch docs
#Pinned versions: 0.1.0
-e git+https://github.com/pytorch/pytorch_sphinx_theme.git@71e55749be14ceb56e7f8211a9fb649866b87ad4#egg=pytorch_sphinx_theme2
# TODO: sphinxcontrib.katex 0.9.0 adds a local KaTeX server to speed up pre-rendering
# but it doesn't seem to work and hangs around idly. The initial thought is that it is probably
# something related to Docker setup. We can investigate this later.
@@ -36,17 +32,17 @@ tensorboard==2.18.0 ; python_version >= "3.13"
#Description: This is used to generate PyTorch docs
#Pinned versions: 2.13.0
breathe==4.34.0
breathe==4.36.0
#Description: This is used to generate PyTorch C++ docs
#Pinned versions: 4.34.0
#Pinned versions: 4.36.0
exhale==0.2.3
exhale==0.3.7
#Description: This is used to generate PyTorch C++ docs
#Pinned versions: 0.2.3
#Pinned versions: 0.3.7
docutils==0.16
docutils==0.20
#Description: This is used to generate PyTorch C++ docs
#Pinned versions: 0.16
#Pinned versions: 0.20
bs4==0.0.1
#Description: This is used to generate PyTorch C++ docs
@@ -56,13 +52,13 @@ IPython==8.12.0
#Description: This is used to generate PyTorch functorch docs
#Pinned versions: 8.12.0
myst-nb==0.17.2
myst-nb==1.3.0
#Description: This is used to generate PyTorch functorch and torch.compile docs.
#Pinned versions: 0.17.2
#Pinned versions: 1.3.0
# The following are required to build torch.distributed.elastic.rendezvous.etcd* docs
python-etcd==0.4.5
sphinx-copybutton==0.5.0
sphinx-design==0.4.0
sphinx-design==0.6.1
sphinxcontrib-mermaid==1.0.0
myst-parser==0.18.1
myst-parser==4.0.1

View File

@@ -6,7 +6,7 @@ dependencies = [
"GitPython==3.1.45",
"docker==7.1.0",
"pytest==7.3.2",
"uv==0.9.5"
"uv==0.8.6"
]
[tool.setuptools]

View File

@@ -102,8 +102,18 @@ if [ "$is_main_doc" = true ]; then
echo coverage output not found
exit 1
elif [ $undocumented -gt 0 ]; then
echo undocumented objects found:
echo "======================================"
echo "ERROR: $undocumented undocumented objects found!"
echo "======================================"
echo ""
echo "Full coverage report:"
cat build/coverage/python.txt
echo ""
echo "======================================"
echo "Undocumented modules/objects (lines after TOTAL):"
tail -n +$((lines - undocumented + 1)) build/coverage/python.txt
echo "======================================"
echo ""
echo "Make sure you've updated relevant .rsts in docs/source!"
echo "You can reproduce locally by running 'cd docs && make coverage && cat build/coverage/python.txt'"
exit 1
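For reference, tail -n +$((lines - undocumented + 1)) prints the last $undocumented lines of the coverage report, i.e. the per-module/object entries Sphinx appends after the TOTAL row. A minimal C++ sketch of the same extraction, assuming the report sits at build/coverage/python.txt and the undocumented count has already been computed elsewhere (the count below is a placeholder):
#include <fstream>
#include <iostream>
#include <string>
#include <vector>
// Print the last `undocumented` lines of a Sphinx coverage report,
// mirroring `tail -n +$((lines - undocumented + 1)) build/coverage/python.txt`.
int main() {
    const std::string path = "build/coverage/python.txt";  // assumed location
    const std::size_t undocumented = 3;                    // placeholder count
    std::vector<std::string> all_lines;
    std::ifstream in(path);
    for (std::string line; std::getline(in, line);) {
        all_lines.push_back(line);
    }
    const std::size_t start =
        all_lines.size() > undocumented ? all_lines.size() - undocumented : 0;
    for (std::size_t i = start; i < all_lines.size(); ++i) {
        std::cout << all_lines[i] << '\n';
    }
    return 0;
}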

View File

@@ -147,16 +147,15 @@ jobs:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-debug
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9
cuda-arch-list: 8.9
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 2, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 3, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 4, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 5, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 6, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 7, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 1, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 2, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 3, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 4, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 5, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 6, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 7, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
]}
secrets: inherit

View File

@@ -58,10 +58,8 @@ jobs:
else
COMMIT_SHA="${{ github.sha }}"
fi
{
echo "sha=${COMMIT_SHA}"
echo "tag_name=trunk/${COMMIT_SHA}"
} >> "${GITHUB_OUTPUT}"
echo "sha=${COMMIT_SHA}" >> "${GITHUB_OUTPUT}"
echo "tag_name=trunk/${COMMIT_SHA}" >> "${GITHUB_OUTPUT}"
- name: Validate commit SHA
run: |
@@ -89,7 +87,7 @@ jobs:
echo "✅ Commit ${COMMIT_SHA} is valid (automatic push trigger)"
fi
- name: Create and push tag(s) with retry
- name: Create and push tag with retry
id: check_tag
env:
TAG_NAME: ${{ steps.commit.outputs.tag_name }}
@@ -114,23 +112,14 @@ jobs:
return 1
}
# Counters for summary reporting
created_count=0
skipped_count=0
failed_count=0
# Exit early if tag already exists
if check_tag_exists; then
echo "✅ Tag already exists - no action needed"
echo "exists=true" >> "${GITHUB_OUTPUT}"
exit 0
fi
# Always write outputs once on exit
finish() {
set +e
if [ -n "${GITHUB_OUTPUT:-}" ]; then
{
echo "created_count=${created_count}"
echo "skipped_count=${skipped_count}"
echo "failed_count=${failed_count}"
} >> "${GITHUB_OUTPUT}"
fi
}
trap finish EXIT
echo "Tag ${TAG_NAME} does not exist, proceeding with creation"
# Retry configuration
MAX_RETRIES=5
@@ -205,111 +194,31 @@ jobs:
}
}
# New behavior for push events: enumerate commits in the push and tag each one.
# For workflow_dispatch, retain existing single-SHA behavior.
# Always fetch tags once up front to improve idempotency in loops
git fetch origin --tags --quiet || true
if [ "${{ github.event_name }}" = "push" ]; then
BEFORE_SHA="${{ github.event.before }}"
AFTER_SHA="${{ github.sha }}" # same as event.after
# List commits introduced by this push (old..new), oldest first for stable ordering
commits_file="$(mktemp)"
git rev-list --reverse "${BEFORE_SHA}..${AFTER_SHA}" > "${commits_file}"
if [ ! -s "${commits_file}" ]; then
echo "No new commits found between ${BEFORE_SHA}..${AFTER_SHA}; nothing to tag."
rm -f "${commits_file}"
exit 0
fi
commit_count="$(wc -l < "${commits_file}" | tr -d ' ')"
echo "Found ${commit_count} commit(s) to tag for push:"
while IFS= read -r sha; do
printf ' %s\n' "${sha}"
done < "${commits_file}"
while IFS= read -r sha; do
TAG_NAME="trunk/${sha}"
COMMIT_SHA="${sha}"
# If tag already exists locally or remotely, skip (idempotent)
if check_tag_exists; then
echo "✅ Tag ${TAG_NAME} already exists - skipping"
skipped_count=$((skipped_count + 1))
continue
fi
echo "Tag ${TAG_NAME} does not exist, proceeding with creation"
if retry_with_backoff "tag_with_retry" "Creating tag ${TAG_NAME} for commit ${COMMIT_SHA}"; then
created_count=$((created_count + 1))
else
echo "Tag creation failed after all retry attempts for ${TAG_NAME}"
failed_count=$((failed_count + 1))
fi
done < "${commits_file}"
rm -f "${commits_file}"
if [ "${failed_count}" -gt 0 ]; then
exit 1
fi
# Execute with retry
if retry_with_backoff "tag_with_retry" "Creating tag ${TAG_NAME} for commit ${COMMIT_SHA}"; then
echo "exists=false" >> "${GITHUB_OUTPUT}"
exit 0
else
# workflow_dispatch path (single SHA tagging preserved)
# Exit early if tag already exists
if check_tag_exists; then
echo "✅ Tag already exists - no action needed"
skipped_count=1
exit 0
fi
echo "Tag ${TAG_NAME} does not exist, proceeding with creation"
if retry_with_backoff "tag_with_retry" "Creating tag ${TAG_NAME} for commit ${COMMIT_SHA}"; then
created_count=1
exit 0
else
echo "Tag creation failed after all retry attempts"
failed_count=1
exit 1
fi
echo "Tag creation failed after all retry attempts"
exit 1
fi
- name: Tag creation summary
if: always()
run: |
if [ "${{ github.event_name }}" = "push" ]; then
echo "Trigger: push on main"
echo "Created: ${{ steps.check_tag.outputs.created_count }}"
echo "Skipped (already existed): ${{ steps.check_tag.outputs.skipped_count }}"
echo "Failed: ${{ steps.check_tag.outputs.failed_count }}"
if [ "${{ steps.check_tag.outputs.failed_count }}" = "0" ]; then
echo "✅ Completed tagging for push range ${{ github.event.before }}..${{ github.sha }}"
else
echo "❌ Some tags failed to create for push range ${{ github.event.before }}..${{ github.sha }}"
fi
if [ "${{ steps.check_tag.outputs.exists }}" = "true" ]; then
echo "✅ Tag ${{ steps.commit.outputs.tag_name }} already existed - no action needed"
elif [ "${{ job.status }}" = "success" ]; then
echo "✅ Successfully created tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
else
if [ "${{ steps.check_tag.outputs.failed_count }}" = "0" ]; then
if [ "${{ steps.check_tag.outputs.created_count }}" = "0" ]; then
echo "✅ Tag ${{ steps.commit.outputs.tag_name }} already existed - no action needed"
else
echo "✅ Successfully created tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
fi
else
echo "❌ Failed to create tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
fi
echo ""
echo "Tag details:"
echo " Name: ${{ steps.commit.outputs.tag_name }}"
echo " Commit: ${{ steps.commit.outputs.sha }}"
echo " Trigger: ${{ github.event_name }}"
if [ -n "${{ github.event.inputs.commit_sha }}" ]; then
echo " Manual commit: ${{ github.event.inputs.commit_sha }}"
fi
echo "❌ Failed to create tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
fi
echo ""
echo "Tag details:"
echo " Name: ${{ steps.commit.outputs.tag_name }}"
echo " Commit: ${{ steps.commit.outputs.sha }}"
echo " Trigger: ${{ github.event_name }}"
if [ -n "${{ github.event.inputs.commit_sha }}" ]; then
echo " Manual commit: ${{ github.event.inputs.commit_sha }}"
fi
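The "Create and push tag with retry" step above relies on a retry_with_backoff helper with MAX_RETRIES=5; the helper body and its delay schedule fall outside this hunk. A minimal C++ sketch of the retry-with-exponential-backoff pattern it names, with placeholder delays and a stand-in push_tag action (both assumptions, not taken from the workflow):
#include <chrono>
#include <functional>
#include <iostream>
#include <thread>
// Retry `action` up to max_retries times, doubling the delay after each failure.
bool retry_with_backoff(const std::function<bool()>& action,
                        int max_retries = 5,
                        std::chrono::seconds base_delay = std::chrono::seconds(2)) {
    auto delay = base_delay;
    for (int attempt = 1; attempt <= max_retries; ++attempt) {
        if (action()) {
            return true;  // success
        }
        std::cerr << "attempt " << attempt << " failed";
        if (attempt < max_retries) {
            std::cerr << ", retrying in " << delay.count() << "s";
            std::this_thread::sleep_for(delay);
            delay *= 2;  // exponential backoff
        }
        std::cerr << '\n';
    }
    return false;  // exhausted all retries
}
int main() {
    // Hypothetical action standing in for `git push origin <tag>`.
    auto push_tag = [] { return false; };
    return retry_with_backoff(push_tag) ? 0 : 1;
}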

View File

@@ -9,7 +9,6 @@
#include <ATen/cpu/vec/vec128/vec128_float_neon.h>
#include <ATen/cpu/vec/vec128/vec128_half_neon.h>
#include <ATen/cpu/vec/vec128/vec128_int_aarch64.h>
#include <ATen/cpu/vec/vec128/vec128_uint_aarch64.h>
#endif
#include <ATen/cpu/vec/vec128/vec128_convert.h>

View File

@@ -1,378 +0,0 @@
#pragma once
#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#include <c10/macros/Macros.h>
#include <c10/util/irange.h>
namespace at::vec {
// Note [CPU_CAPABILITY namespace]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// This header, and all of its subheaders, will be compiled with
// different architecture flags for each supported set of vector
// intrinsics. So we need to make sure they aren't inadvertently
// linked together. We do this by declaring objects in an `inline
// namespace` which changes the name mangling, but can still be
// accessed as `at::vec`.
inline namespace CPU_CAPABILITY {
#define VEC_UINT_NEON_TEMPLATE(vl, bit) \
template <> \
struct is_vec_specialized_for<uint##bit##_t> : std::bool_constant<true> {}; \
\
template <> \
class Vectorized<uint##bit##_t> { \
using neon_type = uint##bit##x##vl##_t; \
\
private: \
neon_type values; \
\
public: \
using value_type = uint##bit##_t; \
using size_type = int; \
static constexpr size_type size() { \
return vl; \
} \
Vectorized() { \
values = vdupq_n_u##bit(0); \
} \
Vectorized(neon_type v) : values(v) {} \
Vectorized(uint##bit##_t val); \
template < \
typename... Args, \
typename = std::enable_if_t<(sizeof...(Args) == size())>> \
Vectorized(Args... vals) { \
__at_align__ uint##bit##_t buffer[size()] = {vals...}; \
values = vld1q_u##bit(buffer); \
} \
operator neon_type() const { \
return values; \
} \
static Vectorized<uint##bit##_t> loadu( \
const void* ptr, \
uint64_t count = size()); \
void store(void* ptr, uint64_t count = size()) const; \
template <uint64_t mask> \
static Vectorized<uint##bit##_t> blend( \
const Vectorized<uint##bit##_t>& a, \
const Vectorized<uint##bit##_t>& b); \
static Vectorized<uint##bit##_t> blendv( \
const Vectorized<uint##bit##_t>& a, \
const Vectorized<uint##bit##_t>& b, \
const Vectorized<uint##bit##_t>& mask_) { \
return vbslq_u##bit(mask_.values, b, a); \
} \
template <typename step_t> \
static Vectorized<uint##bit##_t> arange( \
value_type base = 0, \
step_t step = static_cast<step_t>(1)); \
static Vectorized<uint##bit##_t> set( \
const Vectorized<uint##bit##_t>& a, \
const Vectorized<uint##bit##_t>& b, \
uint64_t count = size()); \
const uint##bit##_t& operator[](uint idx) const = delete; \
uint##bit##_t& operator[](uint idx) = delete; \
Vectorized<uint##bit##_t> abs() const { \
return values; \
} \
Vectorized<uint##bit##_t> real() const { \
return values; \
} \
Vectorized<uint##bit##_t> imag() const { \
return vdupq_n_u##bit(0); \
} \
Vectorized<uint##bit##_t> conj() const { \
return values; \
} \
Vectorized<uint##bit##_t> neg() const { \
return vreinterpretq_u##bit##_s##bit( \
vnegq_s##bit(vreinterpretq_s##bit##_u##bit(values))); \
} \
uint##bit##_t reduce_add() const { \
return vaddvq_u##bit(values); \
} \
uint##bit##_t reduce_max() const; \
Vectorized<uint##bit##_t> operator==( \
const Vectorized<uint##bit##_t>& other) const { \
return Vectorized<value_type>(vceqq_u##bit(values, other.values)); \
} \
Vectorized<uint##bit##_t> operator!=( \
const Vectorized<uint##bit##_t>& other) const; \
Vectorized<uint##bit##_t> operator<( \
const Vectorized<uint##bit##_t>& other) const { \
return Vectorized<value_type>(vcltq_u##bit(values, other.values)); \
} \
Vectorized<uint##bit##_t> operator<=( \
const Vectorized<uint##bit##_t>& other) const { \
return Vectorized<value_type>(vcleq_u##bit(values, other.values)); \
} \
Vectorized<uint##bit##_t> operator>( \
const Vectorized<uint##bit##_t>& other) const { \
return Vectorized<value_type>(vcgtq_u##bit(values, other.values)); \
} \
Vectorized<uint##bit##_t> operator>=( \
const Vectorized<uint##bit##_t>& other) const { \
return Vectorized<value_type>(vcgeq_u##bit(values, other.values)); \
} \
Vectorized<uint##bit##_t> eq( \
const Vectorized<uint##bit##_t>& other) const; \
Vectorized<uint##bit##_t> ne( \
const Vectorized<uint##bit##_t>& other) const; \
Vectorized<uint##bit##_t> gt( \
const Vectorized<uint##bit##_t>& other) const; \
Vectorized<uint##bit##_t> ge( \
const Vectorized<uint##bit##_t>& other) const; \
Vectorized<uint##bit##_t> lt( \
const Vectorized<uint##bit##_t>& other) const; \
Vectorized<uint##bit##_t> le( \
const Vectorized<uint##bit##_t>& other) const; \
}; \
template <> \
Vectorized<uint##bit##_t> inline operator+( \
const Vectorized<uint##bit##_t>& a, \
const Vectorized<uint##bit##_t>& b) { \
return vaddq_u##bit(a, b); \
} \
template <> \
Vectorized<uint##bit##_t> inline operator-( \
const Vectorized<uint##bit##_t>& a, \
const Vectorized<uint##bit##_t>& b) { \
return vsubq_u##bit(a, b); \
} \
template <> \
Vectorized<uint##bit##_t> inline operator&( \
const Vectorized<uint##bit##_t>& a, \
const Vectorized<uint##bit##_t>& b) { \
return vandq_u##bit(a, b); \
} \
template <> \
Vectorized<uint##bit##_t> inline operator|( \
const Vectorized<uint##bit##_t>& a, \
const Vectorized<uint##bit##_t>& b) { \
return vorrq_u##bit(a, b); \
} \
template <> \
Vectorized<uint##bit##_t> inline operator^( \
const Vectorized<uint##bit##_t>& a, \
const Vectorized<uint##bit##_t>& b) { \
return veorq_u##bit(a, b); \
} \
Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::eq( \
const Vectorized<uint##bit##_t>& other) const { \
return (*this == other) & Vectorized<uint##bit##_t>(1); \
} \
Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::ne( \
const Vectorized<uint##bit##_t>& other) const { \
return (*this != other) & Vectorized<uint##bit##_t>(1); \
} \
Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::gt( \
const Vectorized<uint##bit##_t>& other) const { \
return (*this > other) & Vectorized<uint##bit##_t>(1); \
} \
Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::ge( \
const Vectorized<uint##bit##_t>& other) const { \
return (*this >= other) & Vectorized<uint##bit##_t>(1); \
} \
Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::lt( \
const Vectorized<uint##bit##_t>& other) const { \
return (*this < other) & Vectorized<uint##bit##_t>(1); \
} \
Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::le( \
const Vectorized<uint##bit##_t>& other) const { \
return (*this <= other) & Vectorized<uint##bit##_t>(1); \
}
VEC_UINT_NEON_TEMPLATE(16, 8)
inline uint8_t Vectorized<uint8_t>::reduce_max() const {
return vmaxvq_u8(values);
}
template <>
Vectorized<uint8_t> inline operator*(
const Vectorized<uint8_t>& a,
const Vectorized<uint8_t>& b) {
return vmulq_u8(a, b);
}
template <>
inline Vectorized<uint8_t> operator~(const Vectorized<uint8_t>& a) {
return vmvnq_u8(a);
}
inline Vectorized<uint8_t> Vectorized<uint8_t>::operator!=(
const Vectorized<uint8_t>& other) const {
return ~(*this == other);
}
template <>
Vectorized<uint8_t> inline minimum(
const Vectorized<uint8_t>& a,
const Vectorized<uint8_t>& b) {
return vminq_u8(a, b);
}
template <>
Vectorized<uint8_t> inline maximum(
const Vectorized<uint8_t>& a,
const Vectorized<uint8_t>& b) {
return vmaxq_u8(a, b);
}
template <uint64_t mask>
Vectorized<uint8_t> Vectorized<uint8_t>::blend(
const Vectorized<uint8_t>& a,
const Vectorized<uint8_t>& b) {
// Build an array of flags: each element is all-ones (0xFF) if the corresponding
// bit in 'mask' is set, 0 otherwise.
uint8x16_t maskArray = {
(mask & 1LL) ? 0xFF : 0,
(mask & 2LL) ? 0xFF : 0,
(mask & 4LL) ? 0xFF : 0,
(mask & 8LL) ? 0xFF : 0,
(mask & 16LL) ? 0xFF : 0,
(mask & 32LL) ? 0xFF : 0,
(mask & 64LL) ? 0xFF : 0,
(mask & 128LL) ? 0xFF : 0,
(mask & 256LL) ? 0xFF : 0,
(mask & 512LL) ? 0xFF : 0,
(mask & 1024LL) ? 0xFF : 0,
(mask & 2048LL) ? 0xFF : 0,
(mask & 4096LL) ? 0xFF : 0,
(mask & 8192LL) ? 0xFF : 0,
(mask & 16384LL) ? 0xFF : 0,
(mask & 32768LL) ? 0xFF : 0};
// Use BSL to select elements from b where the mask is 1, else from a
return vbslq_u8(maskArray, b.values, a.values);
}
#define VEC_UINT_NEON_OPS(vl, bit) \
inline Vectorized<uint##bit##_t>::Vectorized(uint##bit##_t val) { \
values = vdupq_n_u##bit(val); \
} \
inline Vectorized<uint##bit##_t> Vectorized<uint##bit##_t>::loadu( \
const void* ptr, uint64_t count) { \
if (count == size()) { \
return vld1q_u##bit(reinterpret_cast<const uint##bit##_t*>(ptr)); \
} else { \
__at_align__ uint##bit##_t tmp_values[size()]; \
for (const auto i : c10::irange(size())) { \
tmp_values[i] = 0; \
} \
std::memcpy( \
tmp_values, \
reinterpret_cast<const uint##bit##_t*>(ptr), \
count * sizeof(uint##bit##_t)); \
return vld1q_u##bit(reinterpret_cast<const uint##bit##_t*>(tmp_values)); \
} \
} \
inline void Vectorized<uint##bit##_t>::store(void* ptr, uint64_t count) \
const { \
if (count == size()) { \
vst1q_u##bit(reinterpret_cast<uint##bit##_t*>(ptr), values); \
} else { \
uint##bit##_t tmp_values[size()]; \
vst1q_u##bit(reinterpret_cast<uint##bit##_t*>(tmp_values), values); \
std::memcpy(ptr, tmp_values, count * sizeof(uint##bit##_t)); \
} \
}
VEC_UINT_NEON_OPS(16, 8)
template <typename step_t>
inline Vectorized<uint8_t> Vectorized<uint8_t>::arange(
uint8_t base,
step_t step) {
const Vectorized<uint8_t> base_vec(base);
const Vectorized<uint8_t> step_vec(step);
const uint8x16_t step_sizes = {
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
return vmlaq_u8(base_vec, step_sizes, step_vec);
}
template <>
Vectorized<uint8_t> inline operator>>(
const Vectorized<uint8_t>& a,
const Vectorized<uint8_t>& b) {
uint8x16_t x = a;
uint8x16_t bound = vdupq_n_u8(8);
uint8x16_t z = vminq_u8(b, bound);
return x >> z;
}
template <>
Vectorized<uint8_t> inline operator<<(
const Vectorized<uint8_t>& a,
const Vectorized<uint8_t>& b) {
uint8x16_t bound = vdupq_n_u8(8);
uint8x16_t z = vminq_u8(b, bound);
return vshlq_u8(a, vreinterpretq_s8_u8(z));
}
inline Vectorized<uint8_t> Vectorized<uint8_t>::set(
const Vectorized<uint8_t>& a,
const Vectorized<uint8_t>& b,
uint64_t count) {
if (count == 0) {
return a;
} else if (count >= 16) {
return b;
} else {
// Build an array of flags: element i is all-ones (0xFF) if i < count,
// 0 otherwise.
uint8x16_t maskArray = {
static_cast<uint8_t>((count >= 1LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 2LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 3LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 4LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 5LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 6LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 7LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 8LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 9LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 10LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 11LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 12LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 13LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 14LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 15LL) ? 0xFF : 0),
0};
// Use BSL to select elements from b where the mask is 1, else from a
return vbslq_u8(maskArray, b.values, a.values);
}
}
template <>
Vectorized<uint8_t> inline operator/(
const Vectorized<uint8_t>& a,
const Vectorized<uint8_t>& b) {
uint8x16_t x = a;
uint8x16_t y = b;
return x / y;
}
template <>
Vectorized<uint8_t> inline clamp(
const Vectorized<uint8_t>& a,
const Vectorized<uint8_t>& min,
const Vectorized<uint8_t>& max) {
return minimum(max, maximum(min, a));
}
template <>
Vectorized<uint8_t> inline clamp_max(
const Vectorized<uint8_t>& a,
const Vectorized<uint8_t>& max) {
return minimum(max, a);
}
template <>
Vectorized<uint8_t> inline clamp_min(
const Vectorized<uint8_t>& a,
const Vectorized<uint8_t>& min) {
return maximum(min, a);
}
} // namespace CPU_CAPABILITY
} // namespace at::vec
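For reference, both blend and set in the removed header build a per-lane byte mask and then select between two registers with vbslq_u8. A scalar C++ sketch of the same selection semantics over 16 lanes, using plain arrays instead of NEON registers (illustrative only):
#include <array>
#include <cstdint>
using Lanes = std::array<uint8_t, 16>;
// blend<mask>: lane i comes from b if bit i of the compile-time mask is set, else from a.
template <uint64_t mask>
Lanes blend(const Lanes& a, const Lanes& b) {
    Lanes out{};
    for (int i = 0; i < 16; ++i) {
        out[i] = ((mask >> i) & 1) ? b[i] : a[i];
    }
    return out;
}
// set(a, b, count): the first `count` lanes come from b, the rest from a,
// matching the (count >= 1), (count >= 2), ... comparisons used to build maskArray above.
Lanes set(const Lanes& a, const Lanes& b, uint64_t count) {
    Lanes out{};
    for (uint64_t i = 0; i < 16; ++i) {
        out[i] = (i < count) ? b[i] : a[i];
    }
    return out;
}
int main() {
    Lanes a{}, b{};
    a.fill(1);
    b.fill(2);
    Lanes r = blend<0b0000000000001111>(a, b);  // lanes 0..3 from b, rest from a
    return (r[0] == 2 && r[4] == 1) ? 0 : 1;
}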

View File

@@ -1390,7 +1390,7 @@ std::pair<Vectorized<float>, Vectorized<float>> inline convert_int8_to_float(
std::pair<Vectorized<float>, Vectorized<float>> inline convert_int8_to_float(
at::vec::Vectorized<uint8_t> src) {
auto u8x8 = vget_low_u8(src);
auto u8x8 = vld1_u8(src.operator const uint8_t*());
auto u16x8 = vmovl_u8(u8x8);
auto u32x4_hi = vmovl_u16(vget_high_u16(u16x8));
auto u32x4_lo = vmovl_u16(vget_low_u16(u16x8));
@@ -1412,7 +1412,7 @@ Vectorized<float> inline convert_int8_half_register_to_float(
Vectorized<float> inline convert_int8_half_register_to_float(
at::vec::Vectorized<uint8_t> src) {
auto u8x8 = vget_low_u8(src);
auto u8x8 = vld1_u8(src.operator const uint8_t*());
auto u16x8 = vmovl_u8(u8x8);
auto u32x4_lo = vmovl_u16(vget_low_u16(u16x8));
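Both the old and new forms of convert_int8_to_float consume the low 8 uint8_t lanes of src and widen them u8 -> u16 -> u32 before the final conversion to float, yielding two groups of four floats. A scalar C++ sketch of that widening, with plain arrays standing in for NEON registers (illustrative only):
#include <array>
#include <cstdint>
#include <utility>
// Widen the low 8 uint8 lanes to two groups of four floats,
// mirroring the vmovl_u8 / vmovl_u16 widening steps shown above.
std::pair<std::array<float, 4>, std::array<float, 4>>
convert_uint8_to_float(const std::array<uint8_t, 16>& src) {
    std::array<float, 4> lo{};
    std::array<float, 4> hi{};
    for (int i = 0; i < 4; ++i) {
        lo[i] = static_cast<float>(src[i]);      // lanes 0..3
        hi[i] = static_cast<float>(src[i + 4]);  // lanes 4..7
    }
    return {lo, hi};
}
int main() {
    std::array<uint8_t, 16> src{};
    src[0] = 7;
    auto [lo, hi] = convert_uint8_to_float(src);
    return (lo[0] == 7.0f && hi[3] == 0.0f) ? 0 : 1;
}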

View File

@@ -272,110 +272,28 @@ cuda::blas::GEMMAndBiasActivationEpilogue activation_to_gemm_and_blas_arg(Activa
}
}
/*
* Checks whether DISABLE_ADDMM_CUDA_LT is set.
* Additionally, for ROCM we test whether the architecture supports the Lt.
*/
static bool isGloballyDisabledAddmmCudaLt(const at::Device& device) {
// When hipBLASLt is not supported on the architecture, return true
#ifdef USE_ROCM
static const std::vector<std::string> archs = {
static bool getDisableAddmmCudaLt() {
static const auto env_value = c10::utils::get_env("DISABLE_ADDMM_CUDA_LT");
if (env_value == "1") {
return true;
}
return false;
}
#ifdef USE_ROCM
static bool isSupportedHipLtROCmArch(int index) {
static const std::vector<std::string> archs = {
"gfx90a", "gfx942",
#if ROCM_VERSION >= 60300
#if ROCM_VERSION >= 60300
"gfx1100", "gfx1101", "gfx1200", "gfx1201", "gfx908",
#endif
#if ROCM_VERSION >= 70000
#endif
#if ROCM_VERSION >= 70000
"gfx950", "gfx1150", "gfx1151"
#endif
};
const auto is_hipblas_lt_arch_supported = at::detail::getCUDAHooks().isGPUArch(archs, device.index());
if (!is_hipblas_lt_arch_supported) {
return true;
}
#endif
// Check whether it is disabled in the env
static const auto is_addmm_cuda_lt_disabled = c10::utils::get_env("DISABLE_ADDMM_CUDA_LT");
if (is_addmm_cuda_lt_disabled == "1") {
return true;
}
return false;
}
/*
* Check whether for the given input we want to enable the Lt interface
*/
static bool isInputCompliesAddmmCudaLt(Tensor& result, const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha) {
// Implies 2D bias which we currently do not send through Lt.
// TODO: this check is done pre col-major input preparation,
// so, this condition can be relaxed in cases when a col-major
// copy of result is needed.
if (result.is_same(self)) {
return false;
}
#if defined(USE_ROCM) && ROCM_VERSION == 60400
// hipblaslt TT fp32 regression on ROCm 6.4, cannot use
const auto args = cublasCommonArgs(mat1, mat2, result);
if (args.transa == 't' && args.transb == 't') {
return false;
}
#endif
const auto mat1_sizes = mat1.sizes();
const auto mat2_sizes = mat2.sizes();
#if defined(CUDA_VERSION) || defined(USE_ROCM)
const auto scalar_type = mat1.scalar_type();
return (beta.toComplexDouble() == 1.0
// self.dim() == 1 && result.dim() == 2 && self.sizes()[0] == mat2_sizes[1]
// is to use lt interface only when self is bias.
&& self.dim() == 1 && self.sizes()[0] == mat2_sizes[1] && self.is_contiguous()
&& result.dim() == 2 && result.is_contiguous()
&& ( // some dtype restrictions
#ifndef USE_ROCM
scalar_type == at::ScalarType::Double ||
#endif
scalar_type == at::ScalarType::Float ||
scalar_type == at::ScalarType::Half ||
scalar_type == at::ScalarType::BFloat16
)
&& ( // some shape/stride restrictions
// Strangely, if mat2 has only 1 row or column, we get
// CUBLAS_STATUS_INVALID_VALUE error from cublasLtMatmulAlgoGetHeuristic.
// NOTE: extension to mat1 because mat1/mat2 can be swapped based off
// their row-/col-majorness.
mat1_sizes[0] > 1 && mat1_sizes[1] > 1 &&
mat2_sizes[0] > 1 && mat2_sizes[1] > 1
// The last conditions are to skip 16b transA and non-trans-B having
// leading dim >> rows when they are sliced from a large tensor
// see fbcode/caffe2/test/test_linalg.py:test_corner_cases_of_cublasltmatmul
#if !(defined(CUDA_VERSION) && CUDA_VERSION >= 12010 || defined(USE_ROCM))
// Related to avoiding the leading stride >> leading dim problematic case
// with 16b dtypes described above. For such dtypes we only allow inputs
// which are either row- or col-major (i.e. non-overlapping, compact memory layout).
// In that case the leading stride will be equal to the outer dim len.
// Why do we catch this case here? The following `prepare_matrix_for_cublas` method
// does not modify inputs as long as there is a stride of length 1
// and the leading stride is at least max(1, other dim length), so we might
// end up with contiguous cols but not rows (i.e. holes between different rows)
// and vice versa.
mat2_sizes[0] < 65535 * 32 && mat2_sizes[1] < 65535 * 32 &&
mat1_sizes[0] < 65535 * 32 && mat1_sizes[1] < 65535 * 32 &&
&& (
// filter by dtype
(scalar_type != at::ScalarType::Half && scalar_type != at::ScalarType::BFloat16) ||
// check mat1/mat2 is row-/col-major
(mat1.is_non_overlapping_and_dense() && mat2.is_non_overlapping_and_dense())
)
#endif
)
);
#endif
// no compliance by default
return false;
#endif
};
return at::detail::getCUDAHooks().isGPUArch(archs, index);
}
#endif
template <typename scalar_t>
void launchTunableGemmAndBias(cublasCommonArgs &args, const Scalar& alpha, const scalar_t* bias, cuda::blas::GEMMAndBiasActivationEpilogue activation) {
@@ -417,70 +335,7 @@ void launchTunableGemmAndBias(cublasCommonArgs &args, const Scalar& alpha, const
}
}
template <typename scalar_t, typename res_scalar_t = scalar_t>
bool launchGemmAndBiasCublasLt(
// args contains result which is modified
cublasCommonArgs& args,
const Tensor& self,
const Scalar& alpha,
Activation activation = Activation::None
) {
const auto* self_ptr = self.const_data_ptr<scalar_t>();
const auto tuning_ctx = at::cuda::tunable::getTuningContext();
if (tuning_ctx->IsTunableOpEnabled()) {
// TODO: maybe also return some success state?
launchTunableGemmAndBias<scalar_t>(
args, alpha, self_ptr, activation_to_gemm_and_blas_arg(activation)
);
return true;
}
return at::cuda::blas::gemm_and_bias<scalar_t, res_scalar_t>(
args.transa == 't',
args.transb == 't',
args.m,
args.n,
args.k,
alpha.to<at::opmath_type<scalar_t>>(),
args.mata->const_data_ptr<scalar_t>(),
args.lda,
args.matb->const_data_ptr<scalar_t>(),
args.ldb,
self_ptr,
args.result->data_ptr<res_scalar_t>(),
args.result_ld,
activation_to_gemm_and_blas_arg(activation)
);
}
template <typename scalar_t, typename res_scalar_t = scalar_t>
bool launchGemmCublas(
// args contains result which is modified
cublasCommonArgs& args,
const Scalar& alpha,
const Scalar& beta
) {
at::cuda::blas::gemm<scalar_t, res_scalar_t>(
args.transa,
args.transb,
args.m,
args.n,
args.k,
alpha.to<at::opmath_type<scalar_t>>(),
args.mata->const_data_ptr<scalar_t>(),
args.lda,
args.matb->const_data_ptr<scalar_t>(),
args.ldb,
beta.to<at::opmath_type<scalar_t>>(),
args.result->data_ptr<res_scalar_t>(),
args.result_ld
);
return true; // success!
}
Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha, Activation activation=Activation::None, bool disable_addmm_cuda_lt_override=false) {
// Shape checks {
// Make sure to keep addmm_cuda below in sync with this code; it
// preflights a check to try to avoid actually needing to call
// expand().
@@ -490,62 +345,105 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
"expected mat1 and mat2 to have the same dtype, but got: ", mat1.dtype(), " != ", mat2.dtype()
)
if (result.is_same(self)) {
TORCH_CHECK(result.dim() == 2, "tensors must be 2-D");
TORCH_CHECK(self.sizes()[0] == mat1.sizes()[0], "self dim 0 must match mat1 dim 0");
TORCH_CHECK(self.sizes()[1] == mat2.sizes()[1], "self dim 1 must match mat2 dim 1");
}
// } Shape checks
// NOLINTNEXTLINE(*c-array*)
TensorArg targs[]{{result, "out", 0}, {self, "self", 1}, {mat1, "mat1", 2}, {mat2, "mat2", 3}};
checkAllSameGPU(__func__, targs);
// Handle whether to use the Lt interface {
static bool persistent_disable_addmm_cuda_lt = isGloballyDisabledAddmmCudaLt(self.device());
IntArrayRef mat1_sizes = mat1.sizes();
IntArrayRef mat2_sizes = mat2.sizes();
IntArrayRef self__sizes;
bool useLtInterface = false;
#if defined(USE_ROCM)
// When hipBLASLt is not supported on the architecture,
// disable_addmm_cuda_lt will always be set to true
static bool disable_addmm_cuda_lt =
!isSupportedHipLtROCmArch(self.device().index()) || getDisableAddmmCudaLt();
#else
static bool disable_addmm_cuda_lt = getDisableAddmmCudaLt();
#endif
// if the lt path fails, we recurse back into this function here and force the lt path off
// we cannot update the variable disable_addmm_cuda_lt from above since it is static and would be permanent
bool disable_addmm_cuda_lt = persistent_disable_addmm_cuda_lt || disable_addmm_cuda_lt_override;
#ifdef USE_ROCM
// Conditioned on the device index, which is not persistent
disable_addmm_cuda_lt = isGloballyDisabledAddmmCudaLt(self.device()) || disable_addmm_cuda_lt;
#endif
// Condition on the input
disable_addmm_cuda_lt = !isInputCompliesAddmmCudaLt(result, self, mat1, mat2, beta, alpha) || disable_addmm_cuda_lt;
// }
bool disable_addmm_cuda_lt_final = disable_addmm_cuda_lt || disable_addmm_cuda_lt_override;
#if defined(USE_ROCM) && ROCM_VERSION == 60400
// hipblaslt TT fp32 regression on ROCm 6.4, cannot use
cublasCommonArgs _args(mat1, mat2, result);
if (_args.transa == 't' && _args.transb == 't') {
disable_addmm_cuda_lt_final = true;
}
#endif
at::ScalarType scalar_type = mat1.scalar_type();
bool is_float_output_with_half_input = (scalar_type == at::ScalarType::Half || scalar_type == at::ScalarType::BFloat16) && result.scalar_type() == at::ScalarType::Float;
c10::MaybeOwned<Tensor> self_;
if (&result != &self) {
#if defined(CUDA_VERSION) || defined(USE_ROCM)
// Strangely, if mat2 has only 1 row or column, we get
// CUBLAS_STATUS_INVALID_VALUE error from cublasLtMatmulAlgoGetHeuristic.
// self.dim() == 1 && result.dim() == 2 && self.sizes()[0] == mat2_sizes[1]
// is to use lt interface only when self is bias.
// for cuda 11.4, cublasLtMatmul is activated
// the last two conditions are to skip 16b transA and non-trans-B having
// leading dim >> rows when they are sliced from a large tensor
// see fbcode/caffe2/test/test_linalg.py:test_corner_cases_of_cublasltmatmul
if (!disable_addmm_cuda_lt_final) {
useLtInterface = beta.toComplexDouble() == 1.0 && self.dim() == 1 &&
result.dim() == 2 && self.sizes()[0] == mat2_sizes[1] &&
self.is_contiguous() && result.is_contiguous() &&
#ifdef USE_ROCM
(scalar_type == at::ScalarType::Float ||
scalar_type == at::ScalarType::Half ||
scalar_type == at::ScalarType::BFloat16) &&
#else
(scalar_type == at::ScalarType::Double ||
scalar_type == at::ScalarType::Float ||
scalar_type == at::ScalarType::Half ||
scalar_type == at::ScalarType::BFloat16) &&
#endif
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12010 || defined(USE_ROCM))
mat2_sizes[0] > 1 && mat2_sizes[1] > 1;
#else
mat2_sizes[0] > 1 && mat2_sizes[1] > 1 &&
mat2_sizes[0] < 65535 * 32 && mat2_sizes[1] < 65535 * 32 &&
mat1_sizes[0] < 65535 * 32 && mat1_sizes[1] < 65535 * 32 &&
// avoid leading dim >> rows bugs
((mat1.strides()[0] == 1 && mat1.strides()[1] == mat1_sizes[0]) ||
(mat1.strides()[1] == 1 && mat1.strides()[0] == mat1_sizes[1]) ||
(scalar_type != at::ScalarType::Half &&
scalar_type != at::ScalarType::BFloat16)) &&
((mat2.strides()[0] == 1 && mat2.strides()[1] == mat2_sizes[0]) ||
(mat2.strides()[1] == 1 && mat2.strides()[0] == mat2_sizes[1]) ||
(scalar_type != at::ScalarType::Half &&
scalar_type != at::ScalarType::BFloat16));
#endif
}
#endif
if (!useLtInterface) {
self_ = expand_size(self, {mat1_sizes[0], mat2_sizes[1]}, "addmm");
}
self__sizes = self_->sizes();
} else {
self_ = c10::MaybeOwned<Tensor>::borrowed(self);
self__sizes = self_->sizes();
TORCH_CHECK(result.dim() == 2, "tensors must be 2-D");
TORCH_CHECK(self__sizes[0] == mat1_sizes[0], "self_ dim 0 must match mat1 dim 0");
TORCH_CHECK(self__sizes[1] == mat2_sizes[1], "self_ dim 1 must match mat2 dim 1");
}
// Handle result/self shapes
if (!result.is_same(self)) {
at::native::resize_output(result, {mat1.sizes()[0], mat2.sizes()[1]});
const auto self_maybe_expanded = [&]() -> c10::MaybeOwned<Tensor> {
if (disable_addmm_cuda_lt) {
// When in the non-Lt path we expand self even before
// checking for beta != 0.0 to make sure that
// test_sparse_csr.py::TestSparseCSRCUDA::test_addmm_errors_*
// runs green.
return expand_size(self, result.sizes(), "addmm");
}
// copy next, should broadcast
return c10::MaybeOwned<Tensor>::borrowed(self);
}();
// We copy bias when in the non-Lt path
if (beta.toComplexDouble() != 0.0 && disable_addmm_cuda_lt) {
// NOTE: self should broadcast over result
at::native::copy_(result, *self_maybe_expanded);
if (&result != &self) {
at::native::resize_output(result, {mat1_sizes[0], mat2_sizes[1]});
if (beta.toComplexDouble() != 0.0 && !useLtInterface) {
at::native::copy_(result, *self_);
}
}
// Short circuit on empty result
if (result.numel() == 0) {
IntArrayRef result_sizes = result.sizes();
if ((result_sizes[0] == 0) || (result_sizes[1] == 0)) {
return result;
}
// Short circuit if the reduction dim is empty
if (mat1.sizes()[1] == 0) {
cublasCommonArgs args(mat1, mat2, result);
if (mat1.numel() == 0) {
// By definition, when beta==0, values in self should be ignored. nans and infs
// should not propagate
if (beta.toComplexDouble() == 0.) {
@@ -557,64 +455,158 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
result,
self.expand(result.sizes()),
at::native::scalar_tensor(
beta,
self.scalar_type(),
std::nullopt /* layout */,
at::kCPU,
std::nullopt /* pin_memory */
)
);
beta,
self.scalar_type(),
std::nullopt /* layout */,
at::kCPU,
std::nullopt /* pin_memory */));
}
cublasCommonArgs args(mat1, mat2, result);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!args.result->is_conj());
// The Lt path
if (!disable_addmm_cuda_lt) {
bool lt_success = false;
if (useLtInterface) {
#if defined(USE_ROCM)
bool okay = true;
if (is_float_output_with_half_input) {
#ifdef USE_ROCM
TORCH_CHECK(false, "float output with half input is not enabled for ROCm");
#else
if (at::cuda::tunable::getTuningContext()->IsTunableOpEnabled()) {
TORCH_CHECK(false, "Tunable GEMM is not supported for float output with reduced float input");
}
AT_DISPATCH_REDUCED_FLOATING_TYPES(
scalar_type,
"addmm_cuda_lt",
[&] {
lt_success = launchGemmAndBiasCublasLt<scalar_t, float>(args, self, alpha, activation);
}
);
#endif
} else {
// !is_float_output_with_half_input
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
scalar_type,
"addmm_cuda_lt",
[&] {
lt_success = launchGemmAndBiasCublasLt<scalar_t>(args, self, alpha, activation);
auto tuning_ctx = at::cuda::tunable::getTuningContext();
if (tuning_ctx->IsTunableOpEnabled()) {
launchTunableGemmAndBias<scalar_t>(
args,
alpha,
(&result != &self) ? self.const_data_ptr<scalar_t>() : nullptr,
activation_to_gemm_and_blas_arg(activation));
} else {
okay = at::cuda::blas::gemm_and_bias<scalar_t>(
args.transa == 't',
args.transb == 't',
args.m,
args.n,
args.k,
alpha.to<at::opmath_type<scalar_t>>(),
args.mata->const_data_ptr<scalar_t>(),
args.lda,
args.matb->const_data_ptr<scalar_t>(),
args.ldb,
// This condition is needed for mm case on ROCm for hipblasLt path.
// Passing the bias ptr as null to avoid accuracy issues for mm case.
(&result != &self) ? self.const_data_ptr<scalar_t>() : nullptr,
args.result->data_ptr<scalar_t>(),
args.result_ld,
activation_to_gemm_and_blas_arg(activation)
);
}
);
} // end is_float_output_with_half_input
if (!lt_success) {
// lt path failed; recurse but disable lt path
});
}
if (!okay) {
// lt path failed; recurse but disable lt path
return addmm_out_cuda_impl(result, self, mat1, mat2, beta, alpha, activation, true);
}
// end Lt path
} else {
// No Lt, we use a GEMM instead
#else
auto activation_epilogue = activation_to_gemm_and_blas_arg(activation);
bool okay = true;
if (is_float_output_with_half_input) {
AT_DISPATCH_REDUCED_FLOATING_TYPES(
scalar_type,
"addmm_cuda_lt",
[&] {
auto tuning_ctx = at::cuda::tunable::getTuningContext();
if (tuning_ctx->IsTunableOpEnabled()) {
TORCH_CHECK(false, "Tunable GEMM is not supported for float output with reduced float input");
}
else {
okay = at::cuda::blas::gemm_and_bias<scalar_t, float>(
args.transa == 't',
args.transb == 't',
args.m,
args.n,
args.k,
alpha.to<at::opmath_type<scalar_t>>(),
args.mata->const_data_ptr<scalar_t>(),
args.lda,
args.matb->const_data_ptr<scalar_t>(),
args.ldb,
self.const_data_ptr<scalar_t>(),
args.result->data_ptr<float>(),
args.result_ld,
activation_epilogue
);
}});
} else {
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
scalar_type,
"addmm_cuda_lt",
[&] {
auto tuning_ctx = at::cuda::tunable::getTuningContext();
if (tuning_ctx->IsTunableOpEnabled()) {
launchTunableGemmAndBias<scalar_t>(
args,
alpha,
self.const_data_ptr<scalar_t>(),
activation_epilogue);
}
else {
okay = at::cuda::blas::gemm_and_bias<scalar_t>(
args.transa == 't',
args.transb == 't',
args.m,
args.n,
args.k,
alpha.to<at::opmath_type<scalar_t>>(),
args.mata->const_data_ptr<scalar_t>(),
args.lda,
args.matb->const_data_ptr<scalar_t>(),
args.ldb,
self.const_data_ptr<scalar_t>(),
args.result->data_ptr<scalar_t>(),
args.result_ld,
activation_epilogue
);
}});
}
if (!okay) {
// lt path failed; recurse but disable lt path
return addmm_out_cuda_impl(result, self, mat1, mat2, beta, alpha, activation, true);
}
#endif
} else
{
if (is_float_output_with_half_input) {
AT_DISPATCH_REDUCED_FLOATING_TYPES(
scalar_type,
"addmm_cuda",
[&] {
launchGemmCublas<scalar_t, float>(args, alpha, beta);
}
);
using opmath_t = at::opmath_type<scalar_t>;
opmath_t alpha_val = alpha.to<opmath_t>();
opmath_t beta_val = beta.to<opmath_t>();
const scalar_t* mat1_ptr = args.mata->const_data_ptr<scalar_t>();
const scalar_t* mat2_ptr = args.matb->const_data_ptr<scalar_t>();
float* result_ptr = args.result->mutable_data_ptr<float>();
at::cuda::blas::gemm<scalar_t, float>(
args.transa,
args.transb,
args.m,
args.n,
args.k,
alpha_val,
mat1_ptr,
args.lda,
mat2_ptr,
args.ldb,
beta_val,
result_ptr,
args.result_ld);
});
} else {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
at::ScalarType::Half,
@@ -622,12 +614,28 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
scalar_type,
"addmm_cuda",
[&] {
launchGemmCublas<scalar_t>(args, alpha, beta);
}
);
using opmath_t = at::opmath_type<scalar_t>;
opmath_t alpha_val = alpha.to<opmath_t>();
opmath_t beta_val = beta.to<opmath_t>();
const scalar_t* mat1_ptr = args.mata->const_data_ptr<scalar_t>();
const scalar_t* mat2_ptr = args.matb->const_data_ptr<scalar_t>();
scalar_t* result_ptr = args.result->mutable_data_ptr<scalar_t>();
at::cuda::blas::gemm<scalar_t>(
args.transa,
args.transb,
args.m,
args.n,
args.k,
alpha_val,
mat1_ptr,
args.lda,
mat2_ptr,
args.ldb,
beta_val,
result_ptr,
args.result_ld);
});
}
// Apply epilogue
switch (activation) {
case Activation::RELU:
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
@@ -639,14 +647,14 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
break;
default: break;
}
} // end GEMM path
}
// Preprocessor gate here needs to match the inverse of the check
// gating activation_to_gemm_and_blas_arg above; here we are manually
// performing a post-GELU because we weren't able to use the GELU
// epilogue above.
#if !defined(CUDA_VERSION) && !defined(USE_ROCM)
if (!disable_addmm_cuda_lt && activation == Activation::GELU) {
if (useLtInterface && activation == Activation::GELU) {
at::gelu_(const_cast<Tensor&>(*args.result), "tanh");
}
#endif
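The refactor above splits the Lt decision into a global gate (the DISABLE_ADDMM_CUDA_LT env var and, on ROCm, the architecture check) and a per-input gate, isInputCompliesAddmmCudaLt. A condensed C++ sketch of the per-input checks as they read above; the TensorInfo struct is an illustrative stand-in rather than the real Tensor API, and the CUDA/ROCm version-specific stride and size limits are omitted:
#include <cstdint>
enum class DType { Double, Float, Half, BFloat16, Other };
// Illustrative stand-in for the handful of tensor properties the gate inspects.
struct TensorInfo {
    int dim;
    int64_t size0, size1;
    bool contiguous;
    DType dtype;
};
// Simplified per-input gate: bias must be a contiguous 1-D vector matching mat2's
// columns, result must be a contiguous 2-D output, beta must be 1, the dtype must be
// one gemm_and_bias supports (Double is CUDA-only in the real check), and neither
// matrix may have a unit-sized dimension, which trips cublasLtMatmulAlgoGetHeuristic.
bool input_complies_addmm_cuda_lt(const TensorInfo& result,
                                  const TensorInfo& self,
                                  const TensorInfo& mat1,
                                  const TensorInfo& mat2,
                                  double beta) {
    const bool dtype_ok = mat1.dtype == DType::Float || mat1.dtype == DType::Half ||
                          mat1.dtype == DType::BFloat16 || mat1.dtype == DType::Double;
    return beta == 1.0 &&
           self.dim == 1 && self.size0 == mat2.size1 && self.contiguous &&
           result.dim == 2 && result.contiguous &&
           dtype_ok &&
           mat1.size0 > 1 && mat1.size1 > 1 &&
           mat2.size0 > 1 && mat2.size1 > 1;
}
int main() {
    TensorInfo bias{1, 64, 0, true, DType::Float};
    TensorInfo out{2, 32, 64, true, DType::Float};
    TensorInfo m1{2, 32, 16, true, DType::Float};
    TensorInfo m2{2, 16, 64, true, DType::Float};
    return input_complies_addmm_cuda_lt(out, bias, m1, m2, /*beta=*/1.0) ? 0 : 1;
}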

View File

@@ -23,7 +23,7 @@ namespace at::native {
// The maximum number of threads in a block
#if defined(USE_ROCM)
constexpr int MAX_BLOCK_SIZE = 1024;
constexpr int MAX_BLOCK_SIZE = 256;
#else
constexpr int MAX_BLOCK_SIZE = 512;
#endif
@@ -33,7 +33,7 @@ constexpr unsigned MAX_GRID_SIZE = 65535u;
// Number of threads in a block given an input size up to MAX_BLOCK_SIZE
static int getNumThreads(int nElem) {
#if defined(USE_ROCM)
int threadSizes[5] = { 64, 128, 256, 512, MAX_BLOCK_SIZE };
int threadSizes[5] = { 16, 32, 64, 128, MAX_BLOCK_SIZE };
#else
int threadSizes[5] = { 32, 64, 128, 256, MAX_BLOCK_SIZE };
#endif
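The hunk above only shrinks the ROCm thread-size table; the selection loop itself is outside the hunk. A plausible C++ sketch, assuming getNumThreads picks the smallest table entry that covers nElem and otherwise falls back to the maximum:
#include <cstdint>
constexpr int MAX_BLOCK_SIZE_SKETCH = 256;  // ROCm value after this change
// Pick the smallest block size in the table that can cover nElem;
// fall back to the maximum block size for larger inputs.
int get_num_threads_sketch(int nElem) {
    const int threadSizes[5] = {16, 32, 64, 128, MAX_BLOCK_SIZE_SKETCH};
    for (int size : threadSizes) {
        if (nElem <= size) {
            return size;
        }
    }
    return MAX_BLOCK_SIZE_SKETCH;
}
int main() {
    return get_num_threads_sketch(100) == 128 ? 0 : 1;  // 100 elements -> 128 threads
}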

View File

@@ -92,16 +92,6 @@ inline thrust::pair<int64_t, int64_t> get_index_mapping2d(
output_offset + output_y * output_dim_x + output_x);
}
__device__ __forceinline__ int64_t reflect_index(int64_t x, int64_t len) {
const int64_t two = (len - 1) * 2;
if (two <= 0) {
return 0;
}
int64_t m = x % two;
if (m < 0) m += two;
return (m < len) ? m : (two - m);
}
template<typename scalar_t>
__global__ void reflection_pad1d_out_kernel(
const scalar_t * input, scalar_t * output,
@@ -116,28 +106,6 @@ __global__ void reflection_pad1d_out_kernel(
}
}
template <typename scalar_t>
__global__ void reflection_pad1d_flat(
const scalar_t* __restrict__ input,
scalar_t* __restrict__ output,
int64_t input_w, int64_t pad_l, int64_t pad_r,
int64_t out_w, int64_t plane_count) {
const int64_t bx = blockDim.x;
const int64_t tx = threadIdx.x;
const int64_t total = plane_count * out_w;
const int64_t grid_stride = static_cast<int64_t>(bx) * gridDim.x;
int64_t linear = static_cast<int64_t>(blockIdx.x) * bx + tx;
for (; linear < total; linear += grid_stride) {
const int64_t plane = linear / out_w;
const int64_t x = linear - plane * out_w;
const int64_t j = reflect_index(x - pad_l, input_w);
output[plane * out_w + x] = input[plane * input_w + j];
}
}
template <typename scalar_t>
__global__ void reflection_pad1d_backward_out_kernel(
scalar_t * grad_input, const scalar_t * grad_output,
@@ -742,44 +710,25 @@ TORCH_IMPL_FUNC(reflection_pad1d_out_cuda)
int64_t input_w = input_.size(dim_w);
int64_t output_w = input_w + pad_l + pad_r;
dim3 block_size(output_w > 256 ? 256 : output_w);
dim3 grid_size((int)::ceil(output_w / 256.0), nplane, nbatch);
Tensor input = input_.contiguous();
const int block_x = static_cast<int>(std::min<int64_t>(256, std::max<int64_t>(1, output_w)));
const cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
const int max_x = prop->maxGridSize[0];
const int max_y = prop->maxGridSize[1];
const int max_z = prop->maxGridSize[2];
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBFloat16, input.scalar_type(), "reflection_pad1d_out", [&] {
auto stream = at::cuda::getCurrentCUDAStream();
const int64_t gx = at::ceil_div(output_w, static_cast<int64_t>(block_x));
const bool fits3d = (nplane <= max_y) && (nbatch <= max_z) && (gx <= max_x);
if (fits3d) {
dim3 block(block_x, 1, 1);
dim3 grid(gx, static_cast<unsigned>(nplane), static_cast<unsigned>(nbatch));
reflection_pad1d_out_kernel<scalar_t><<<grid, block, 0, stream>>>(
input.const_data_ptr<scalar_t>(),
output.mutable_data_ptr<scalar_t>(),
input_w, pad_l, pad_r);
} else {
dim3 block(block_x, 1, 1);
const int64_t plane_count = nplane * nbatch;
const int64_t total_blocks = at::ceil_div(plane_count * output_w, static_cast<int64_t>(block_x));
const int grid_x = static_cast<int>(std::min<int64_t>(max_x, std::max<int64_t>(1, total_blocks)));
dim3 grid(grid_x, 1, 1);
reflection_pad1d_flat<scalar_t><<<grid, block, 0, stream>>>(
input.const_data_ptr<scalar_t>(),
output.mutable_data_ptr<scalar_t>(),
input_w, pad_l, pad_r, output_w, plane_count);
}
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
kHalf, kBFloat16, input.scalar_type(), "reflection_pad1d_out_template", [&] {
reflection_pad1d_out_kernel<<<
grid_size,
block_size,
0,
at::cuda::getCurrentCUDAStream()>>>(
input.const_data_ptr<scalar_t>(),
output.mutable_data_ptr<scalar_t>(),
input_w,
pad_l,
pad_r);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
}
TORCH_IMPL_FUNC(reflection_pad1d_backward_out_cuda)(const Tensor& grad_output_,
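The removed reflection_pad1d_flat kernel mapped each output column x back to an input column via reflect_index(x - pad_l, input_w). A small host-side C++ sketch of that index mapping applied to one row, using the same arithmetic as the deleted helper (illustrative only):
#include <cstdint>
#include <iostream>
#include <vector>
// Reflect a possibly negative or out-of-range index into [0, len).
int64_t reflect_index(int64_t x, int64_t len) {
    const int64_t two = (len - 1) * 2;
    if (two <= 0) {
        return 0;
    }
    int64_t m = x % two;
    if (m < 0) {
        m += two;
    }
    return (m < len) ? m : (two - m);
}
int main() {
    const std::vector<int> input = {0, 1, 2, 3};   // input_w = 4
    const int64_t pad_l = 2, pad_r = 2;
    const int64_t out_w = static_cast<int64_t>(input.size()) + pad_l + pad_r;
    for (int64_t x = 0; x < out_w; ++x) {
        const int64_t j = reflect_index(x - pad_l, static_cast<int64_t>(input.size()));
        std::cout << input[j] << ' ';              // prints: 2 1 0 1 2 3 2 1
    }
    std::cout << '\n';
    return 0;
}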

View File

@@ -52,7 +52,7 @@ struct FusedAdagradMathFunctor {
using opmath_t = at::opmath_type<scalar_t>;
C10_DEVICE __forceinline__ void operator()(
int64_t chunk_size,
int chunk_size,
FusedOptimizerTensorListMetadata<3>& tl,
const float* lr_ptr,
const double& lr,
@@ -133,4 +133,4 @@ struct FusedAdagradMathFunctor {
} // namespace
} // namespace at::native
} // namespace at::native

View File

@@ -1,8 +1,8 @@
add_loop_eager,compile_time_instruction_count,3184000000,0.1
add_loop_eager,compile_time_instruction_count,3070000000,0.1
add_loop_eager_dynamic,compile_time_instruction_count,4595000000,0.1
add_loop_eager_dynamic,compile_time_instruction_count,4432000000,0.1
@@ -18,7 +18,7 @@ add_loop_inductor_gpu,compile_time_instruction_count,26800000000,0.1
basic_modules_ListOfLinears_eager,compile_time_instruction_count,1096000000,0.1
basic_modules_ListOfLinears_eager,compile_time_instruction_count,1048000000,0.1
@@ -26,7 +26,7 @@ basic_modules_ListOfLinears_inductor,compile_time_instruction_count,15240000000,
basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17720000000,0.1
basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17020000000,0.1
@@ -34,11 +34,11 @@ basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,11090000
update_hint_regression,compile_time_instruction_count,1645000000,0.1
update_hint_regression,compile_time_instruction_count,1719000000,0.1
sum_floordiv_regression,compile_time_instruction_count,3813000000,0.1
sum_floordiv_regression,compile_time_instruction_count,3686995725,0.1
@@ -50,31 +50,31 @@ symint_sum_loop,compile_time_instruction_count,4299000000,0.1
aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,1793000000,0.1
aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,1869000000,0.1
aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5120000000,0.1
aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5281000000,0.1
aotdispatcher_partitioner_cpu,compile_time_instruction_count,7936000000,0.1
aotdispatcher_partitioner_cpu,compile_time_instruction_count,8333000000,0.1
aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1848000000,0.1
aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1909000000,0.1
aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3152000000,0.1
aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3442000000,0.1
aotdispatcher_training_subclass_cpu,compile_time_instruction_count,8301000000,0.1
aotdispatcher_training_subclass_cpu,compile_time_instruction_count,9239000000,0.1
mm_loop_inductor_gpu,compile_time_instruction_count,4958000000,0.1
mm_loop_inductor_gpu,compile_time_instruction_count,4820968837,0.1
@@ -82,8 +82,8 @@ mm_loop_inductor_dynamic_gpu,compile_time_instruction_count,9051000000,0.1
basic_NestedModule_eager,compile_time_instruction_count,9990000000,0.1
basic_NestedModule_eager,compile_time_instruction_count,9554000000,0.1
basic_InlineMod_eager,compile_time_instruction_count,8126000000,0.1
basic_InlineMod_eager,compile_time_instruction_count,7618000000,0.1


View File

@@ -48,89 +48,17 @@ PyTorch,sub,"sub_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float32",short,Fa
PyTorch,div,"div_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float32",short,False,58.529255,0.000000
PyTorch,mul,"mul_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float32",short,False,54.645077,0.000000
PyTorch,add,add_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,4.397014,0.000000
PyTorch,add,add_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,7.739000,0.000000
PyTorch,add,add_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,7.786000,0.000000
PyTorch,add,add_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,1.911000,0.000000
PyTorch,add,add_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,59.243500,0.000000
PyTorch,add,add_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,105.066000,0.000000
PyTorch,add,add_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,106.076000,0.000000
PyTorch,add,add_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,47.225000,0.000000
PyTorch,add,add_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,57.947691,0.000000
PyTorch,add,add_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,107.291000,0.000000
PyTorch,add,add_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,107.224000,0.000000
PyTorch,add,add_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,47.912000,0.000000
PyTorch,sub,sub_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,1.925851,0.000000
PyTorch,sub,sub_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,8.0240000,0.000000
PyTorch,sub,sub_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,8.069000,0.000000
PyTorch,sub,sub_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,1.938000,0.000000
PyTorch,sub,sub_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,57.308320,0.000000
PyTorch,sub,sub_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,107.091000,0.000000
PyTorch,sub,sub_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,108.710000,0.000000
PyTorch,sub,sub_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,47.502000,0.000000
PyTorch,sub,sub_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,57.787743,0.000000
PyTorch,sub,sub_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,108.863000,0.000000
PyTorch,sub,sub_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,108.939000,0.000000
PyTorch,sub,sub_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,47.603000,0.000000
PyTorch,div,div_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,7.978539,0.000000
PyTorch,div,div_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,8.741000,0.000000
PyTorch,div,div_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,8.757000,0.000000
PyTorch,div,div_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,8.774000,0.000000
PyTorch,div,div_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,159.754860,0.000000
PyTorch,div,div_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,165.552000,0.000000
PyTorch,div,div_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,165.755000,0.000000
PyTorch,div,div_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,165.714000,0.000000
PyTorch,div,div_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,165.360235,0.000000
PyTorch,div,div_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,168.376000,0.000000
PyTorch,div,div_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,169.604000,0.000000
PyTorch,div,div_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,168.428000,0.000000
PyTorch,mul,mul_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,3.928136,0.000000
PyTorch,mul,mul_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,7.402000,0.000000
PyTorch,mul,mul_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,7.567000,0.000000
PyTorch,mul,mul_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,4.020000,0.000000
PyTorch,mul,mul_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,56.413499,0.000000
PyTorch,mul,mul_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,104.638000,0.000000
PyTorch,mul,mul_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,104.335000,0.000000
PyTorch,mul,mul_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,48.612000,0.000000
PyTorch,mul,mul_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,55.925090,0.000000
PyTorch,mul,mul_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,106.110000,0.000000
PyTorch,mul,mul_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,106.389000,0.000000
PyTorch,mul,mul_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,48.195000,0.000000
PyTorch,asr,asr_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,1.989000,0.000000
PyTorch,asr,asr_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,7.999000,0.000000
PyTorch,asr,asr_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,7.939000,0.000000
PyTorch,asr,asr_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,1.980000,0.000000
PyTorch,asr,asr_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,54.408000,0.000000
PyTorch,asr,asr_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,105.647000,0.000000
PyTorch,asr,asr_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,106.476000,0.000000
PyTorch,asr,asr_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,48.784000,0.000000
PyTorch,asr,asr_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,55.583000,0.000000
PyTorch,asr,asr_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,108.083000,0.000000
PyTorch,asr,asr_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,107.663000,0.000000
PyTorch,asr,asr_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,48.283000,0.000000
PyTorch,lsl,lsl_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,1.986000,0.000000
PyTorch,lsl,lsl_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,7.676000,0.000000
PyTorch,lsl,lsl_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,7.618000,0.000000
PyTorch,lsl,lsl_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,1.982000,0.000000
PyTorch,lsl,lsl_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,54.698000,0.000000
PyTorch,lsl,lsl_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,105.899000,0.000000
PyTorch,lsl,lsl_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,106.741000,0.000000
PyTorch,lsl,lsl_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,51.182000,0.000000
PyTorch,lsl,lsl_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,55.290000,0.000000
PyTorch,lsl,lsl_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,107.744000,0.000000
PyTorch,lsl,lsl_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,107.820000,0.000000
PyTorch,lsl,lsl_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,51.298000,0.000000
PyTorch,xor,xor_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,1.988000,0.000000
PyTorch,xor,xor_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,7.689000,0.000000
PyTorch,xor,xor_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,7.695000,0.000000
PyTorch,xor,xor_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,1.978000,0.000000
PyTorch,xor,xor_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,54.934000,0.000000
PyTorch,xor,xor_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,105.217000,0.000000
PyTorch,xor,xor_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,104.215000,0.000000
PyTorch,xor,xor_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,47.115000,0.000000
PyTorch,xor,xor_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,55.974000,0.000000
PyTorch,xor,xor_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,106.828000,0.000000
PyTorch,xor,xor_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,106.879000,0.000000
PyTorch,xor,xor_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,48.197000,0.000000
PyTorch,logical_and,"logical_and_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.bool",short,False,78.404254,0.000000
PyTorch,logical_and,logical_and_M1_N1_K1_cpu_dtype_onetorch.bool_dtype_twotorch.bool,short,False,5.354032,0.000000
PyTorch,logical_and,logical_and_M64_N64_K64_cpu_dtype_onetorch.bool_dtype_twotorch.bool,short,False,54.072783,0.000000
@ -143,9 +71,6 @@ PyTorch,baddbmm,baddbmm_B2_M1_N8_K2_cpu_dtypetorch.float32,short,False,6.631313,
PyTorch,baddbmm,baddbmm_B2_M1_N8_K2_cpu_dtypetorch.bfloat16,short,False,6.476986,0.000000
PyTorch,baddbmm,baddbmm_B128_M64_N32_K64_cpu_dtypetorch.float32,short,False,266.065131,0.000000
PyTorch,baddbmm,baddbmm_B128_M64_N32_K64_cpu_dtypetorch.bfloat16,short,False,295.503063,0.000000
PyTorch,all,all_M1_N1_K1_cpu,short,False,5.773000,0.000000
PyTorch,all,all_M64_N64_K64_cpu,short,False,89.427000,0.000000
PyTorch,all,all_M64_N64_K128_cpu,short,False,120.119000,0.000000
PyTorch,cat,"cat_sizes(1,1,1)_N2_dim0_cpu",short,False,4.301950,0.000000
PyTorch,cat,"cat_sizes(512,512,2)_N2_dim1_cpu",short,False,99.093415,0.000000
PyTorch,cat,"cat_sizes(128,1024,2)_N2_dim1_cpu",short,False,96.771578,0.000000

View File

@ -71,8 +71,8 @@ binary_short_configs = op_bench.config_list(
],
cross_product_configs={
"device": ["cpu", "cuda"],
"dtype_one": [torch.int32, torch.uint8],
"dtype_two": [torch.int32, torch.uint8],
"dtype_one": [torch.int32],
"dtype_two": [torch.int32],
},
tags=["short"],
)
@ -82,8 +82,8 @@ binary_long_configs = op_bench.cross_product_configs(
N=[32, 64],
K=[256, 512],
device=["cpu", "cuda"],
dtype_one=[torch.int8, torch.int32, torch.uint8],
dtype_two=[torch.int8, torch.int32, torch.uint8],
dtype_one=[torch.int8, torch.int32],
dtype_two=[torch.int8, torch.int32],
tags=["long"],
)
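
For context on how these operator_benchmark configs are consumed, here is a minimal sketch of a matching test file; the AddBenchmark class, the integer inputs, and the use of torch.add are illustrative assumptions, not the exact file being changed in this diff:

import operator_benchmark as op_bench
import torch

# Short configs mirroring the reduced dtype matrix above (illustrative values).
binary_short_configs = op_bench.config_list(
    attr_names=["M", "N", "K"],
    attrs=[[1, 1, 1], [64, 64, 64], [64, 64, 128]],
    cross_product_configs={
        "device": ["cpu"],
        "dtype_one": [torch.int32],
        "dtype_two": [torch.int32],
    },
    tags=["short"],
)

class AddBenchmark(op_bench.TorchBenchmarkBase):
    def init(self, M, N, K, device, dtype_one, dtype_two):
        # Integer inputs matching the dtype_one/dtype_two attributes above.
        self.inputs = {
            "input_one": torch.randint(0, 100, (M, N, K), device=device).to(dtype_one),
            "input_two": torch.randint(0, 100, (M, N, K), device=device).to(dtype_two),
        }
        self.set_module_name("add")

    def forward(self, input_one, input_two):
        return torch.add(input_one, input_two)

op_bench.generate_pt_test(binary_short_configs, AddBenchmark)

if __name__ == "__main__":
    op_bench.benchmark_runner.main()

Each generated test emits one result row per (M, N, K, device, dtype_one, dtype_two) combination, which is where CSV rows like the add/sub/mul entries above come from.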

View File

@ -207,6 +207,42 @@ templates_path = [
]
# TODO: document these and remove them from here.
# Fixes the duplicated autosummary output files generated for API names that differ only by case.
autosummary_filename_map = {
"torch.nn.utils.prune.identity": "torch.nn.utils.prune.identity_function",
"torch.nn.utils.prune.Identity": "torch.nn.utils.prune.Identity_class",
"torch.optim.adamw.adamw": "torch.optim.adamw.adamw_function",
"torch.optim.adamw.AdamW": "torch.optim.adamw.AdamW_class",
"torch.optim.asgd.asgd": "torch.optim.asgd.asgd_function",
"torch.optim.asgd.ASGD": "torch.optim.asgd.ASGD_class",
"torch.optim.nadam.nadam": "torch.optim.nadam.nadam_function",
"torch.optim.nadam.NAdam": "torch.optim.nadam.NAdam_class",
"torch.optim.radam.radam": "torch.optim.radam.radam_function",
"torch.optim.radam.RAdam": "torch.optim.radam.RAdam_class",
"torch.optim.rmsprop.rmsprop": "torch.optim.rmsprop.rmsprop_function",
"torch.optim.rmsprop.RMSprop": "torch.optim.rmsprop.RMSprop_class",
"torch.optim.rprop.rprop": "torch.optim.rprop.rprop_function",
"torch.optim.rprop.Rprop": "torch.optim.rprop.Rprop_class",
"torch.optim.sgd.sgd": "torch.optim.sgd.sgd_function",
"torch.optim.sgd.SGD": "torch.optim.sgd.SGD_class",
"torch.optim.adadelta.adadelta": "torch.optim.adadelta.adadelta_function",
"torch.optim.adadelta.Adadelta": "torch.optim.adadelta.Adadelta_class",
"torch.optim.adagrad.adagrad": "torch.optim.adagrad.adagrad_function",
"torch.optim.adagrad.Adagrad": "torch.optim.adagrad.Adagrad_class",
"torch.optim.adam.adam": "torch.optim.adam.adam_function",
"torch.optim.adam.Adam": "torch.optim.adam.Adam_class",
"torch.optim.adamax.adamax": "torch.optim.adamax.adamax_function",
"torch.optim.adamax.Adamax": "torch.optim.adamax.Adamax_class",
"torch.mtia.stream": "torch.mtia.stream_function",
"torch.mtia.Stream": "torch.mtia.Stream_class",
"torch.cpu.stream": "torch.cpu.stream_function",
"torch.cpu.Stream": "torch.cpu.Stream_class",
"torch.cuda.stream": "torch.cuda.stream_function",
"torch.cuda.Stream": "torch.cuda.Stream_class",
"torch.xpu.stream": "torch.xpu.stream_function",
"torch.xpu.Stream": "torch.xpu.Stream_class",
}
coverage_ignore_functions = [
# torch
"typename",
@ -3193,6 +3229,11 @@ autodoc_type_aliases = {
# Enable overriding of function signatures in the first line of the docstring.
autodoc_docstring_signature = True
# Exclude inherited IntEnum methods that have RST formatting issues in their docstrings
autodoc_default_options = {
"exclude-members": "from_bytes, to_bytes",
}
# -- katex javascript in header
#
# def setup(app):

View File

@ -253,7 +253,6 @@ regular full-precision tensor.
.. autosummary::
:toctree: generated
:nosignatures:
:template: classtemplate.rst
view
as_strided

View File

@ -8,8 +8,7 @@ class TestAutocast(TestCase):
def test_autocast_with_unsupported_type(self):
with self.assertWarnsRegex(
UserWarning,
"In openreg autocast, but the target dtype is not supported."
"openreg Autocast only supports dtypes of torch.bfloat16, torch.float16 currently.",
"In openreg autocast, but the target dtype torch.float32 is not supported.",
):
with torch.autocast(device_type="openreg", dtype=torch.float32):
_ = torch.ones(10)

View File

@ -67,21 +67,7 @@ class TestFullyShardMemory(FSDPTest):
# allocate the cuBLAS workspaces before measuring the memory usage
# since the workspace size can differ across hardware
lin = torch.nn.Linear(768, 768, device=device_type)
# NOTE: before https://github.com/pytorch/pytorch/pull/163955,
# the input shape was (1, 768), so that the forward gemm used
# cublaslt, and the backward used cublas.
# With the aforementioned PR, and with shape (1, 768),
# the cublas path is used both in forward and in backward,
# altering peak memory usage not accounting for cublaslt.
# Here we change the input shape to (2, 768), and that swaps
# the cublas/cublaslt selection in the forward/backward,
# but that does not affect the peak memory usage stored in `base_mem_mb`.
# Reasons for the flip:
# before PR: no Lt in addmm when mat2 has nrows/ncols <= 1,
# after PR: no Lt in addmm when either mat1 or mat2 have nrows/ncols <= 1,
# since the input preparation can swap matrices based on output
# row-/col-majorness.
inp = torch.randn(2, 768, device=device_type)
inp = torch.randn(1, 768, device=device_type)
lin(inp).sum().backward()
torch.get_device_module(device_type).empty_cache()
base_mem_mb = self._get_peak_active_memory_mb()
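
As an aside, a minimal sketch of the warm-up-then-measure pattern this test relies on, using the public CUDA memory-stats APIs rather than the test's own _get_peak_active_memory_mb helper; the Linear module and shapes here are illustrative:

import torch

if torch.cuda.is_available():
    lin = torch.nn.Linear(768, 768, device="cuda")
    inp = torch.randn(1, 768, device="cuda")
    # Warm-up forward/backward so cuBLAS workspaces are allocated up front
    # and do not show up in the measurement below.
    lin(inp).sum().backward()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    # Measured region.
    lin(inp).sum().backward()
    peak_mb = torch.cuda.max_memory_allocated() / 1024 / 1024
    print(f"peak allocated memory: {peak_mb:.1f} MiB")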

View File

@ -288,18 +288,6 @@ class AnnotateTests(torch._dynamo.test_case.TestCase):
('call_function', 'mul_2', {'pp_stage': 0, 'fdsp_bucket': 0})""", # noqa: B950
)
def test_graph_break(self):
def fn(x):
with torch.fx.traceback.annotate({"pp_stage": 0}):
x = torch.sin(x)
torch._dynamo.graph_break()
x = torch.cos(x)
return x
opt_fn = torch.compile(fn, backend="eager")
x = torch.randn(10, requires_grad=True)
self.assertEqual(fn(x), opt_fn(x))
if __name__ == "__main__":
run_tests()

View File

@ -346,7 +346,7 @@ class TestAutocastMPS(TestCase):
def test_mps_autocast_error_message(self):
with self.assertWarnsRegex(
UserWarning,
"MPS Autocast only supports dtypes of torch.bfloat16, torch.float16 currently.",
"MPS Autocast only supports dtype of torch.bfloat16 and torch.float16 currently.",
):
with torch.autocast(device_type="mps", dtype=torch.float32):
_ = torch.ones(10)

View File

@ -6,7 +6,6 @@ import builtins
import collections
import contextlib
import copy
import gc
import functools
import inspect
import io
@ -20,7 +19,6 @@ import traceback
import types
import typing
import unittest
import weakref
import warnings
from math import sqrt
from torch.multiprocessing import Process
@ -1626,25 +1624,6 @@ class TestFX(JitTestCase):
self.assertTrue(neg not in relu.users)
@skipIfTorchDynamo("Dynamo does not free right away")
def test_prepend_does_not_leak(self):
g = Graph()
x = g.placeholder("x")
relu = g.call_function(torch.relu, (x,))
neg = g.call_function(torch.neg, (x,))
relu.prepend(neg)
ref = weakref.ref(neg)
g.erase_node(neg)
del g
del x
del relu
del neg
gc.collect()
self.assertIsNone(ref())
def test_remove_uses_with_custom_filter(self):
g: torch.fx.Graph = Graph()
x: torch.fx.Node = g.placeholder("x")

View File

@ -7381,10 +7381,6 @@ torch.cuda.synchronize()
@skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
@parametrize("use_legacy_api", [True, False])
@skipCPUIf(True, "SPDA Math NT fallback causes failure: see issue #133644")
@unittest.skipIf(
"RelWithAssert" in torch.__config__.show(),
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
)
def test_dummy_mha_with_nt(self, device, use_legacy_api):
bs = 3
d1 = 2

View File

@ -8490,14 +8490,6 @@ class TestNNDeviceType(NNTestCase):
y_cuda_contig = pool(x_cuda.contiguous())
self.assertEqual(y_cuda_ch_last, y_cuda_contig)
@onlyCUDA
def test_large_reflect_pad(self, device):
# https://github.com/pytorch/pytorch/issues/165861
x = torch.rand(2**16, 2, device="cuda")
c = F.pad(x, (1, 1), mode="reflect")
c_cpu = F.pad(x.cpu(), (1, 1), mode="reflect")
self.assertEqual(c, c_cpu)
@onlyCUDA
@largeTensorTest("48GB", "cpu")
@largeTensorTest("48GB", "cuda")

View File

@ -247,10 +247,6 @@ class SparseSemiStructuredTensorCompileTest(torch._dynamo.test_case.TestCase):
@unittest.skipIf(IS_WINDOWS, "torch.compile not supported on windows")
@unittest.skipIf("cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS, "cusparselt not supported on this machine")
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
@unittest.skipIf(
"RelWithAssert" in torch.__config__.show(),
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
)
def test_sp24_compile(self) -> None:
x = torch.randn([1024, 512], device="cuda", dtype=torch.float16, requires_grad=True)
@ -580,10 +576,6 @@ class TestSparseSemiStructuredTraining(TestCase):
@training_dtypes
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
@unittest.skipIf(
"RelWithAssert" in torch.__config__.show(),
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
)
def test_prune_dense_static_sort(self, dtype) -> None:
# Ideally we would like to clone and compare, but that won't work because the sorting order will be different
# instead we pass the pruned matrix to the CUDA implementation and preserve the sparsity pattern.
@ -629,10 +621,6 @@ class TestSparseSemiStructuredTraining(TestCase):
@training_dtypes
@parametrize_backends
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
@unittest.skipIf(
"RelWithAssert" in torch.__config__.show(),
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
)
def test_pruning_algo_largest_abs_values_greedy(self, dtype, backend) -> None:
inp = torch.tensor(
[[4, 3, 2, 1], [-1, -3, 0.6, 0.5], [1, 2, 3, 4], [10, 2, -1, 5]],
@ -670,10 +658,6 @@ class TestSparseSemiStructuredTraining(TestCase):
@training_dtypes
@parametrize_backends
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
@unittest.skipIf(
"RelWithAssert" in torch.__config__.show(),
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
)
def test_pack_both_ways_meta_correctness(self, dtype, backend) -> None:
M, N = 128, 256
# Construct x to make sure we always have exactly 8 elements per 4x4 tile
@ -708,10 +692,6 @@ class TestSparseSemiStructuredTraining(TestCase):
@training_dtypes
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
@unittest.skipIf(
"RelWithAssert" in torch.__config__.show(),
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
)
def test_pack_both_ways_id(self, dtype) -> None:
N = 512
torch.manual_seed(0)
@ -749,10 +729,6 @@ class TestSparseSemiStructuredTraining(TestCase):
@training_dtypes
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
@unittest.skipIf(
"RelWithAssert" in torch.__config__.show(),
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
)
def test_pack_both_ways_edge_case1(self, dtype) -> None:
# In this case, the heuristic will keep 7 values out of 16
# instead of 8. let's see how the kernel handles this
@ -778,10 +754,6 @@ class TestSparseSemiStructuredTraining(TestCase):
@training_dtypes
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
@unittest.skipIf(
"RelWithAssert" in torch.__config__.show(),
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
)
def test_sp24_apply(self, dtype) -> None:
M, N = 256, 1024
x = torch.randn([M, N], dtype=dtype, device="cuda")
@ -798,10 +770,6 @@ class TestSparseSemiStructuredTraining(TestCase):
@training_dtypes
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
@unittest.skipIf(
"RelWithAssert" in torch.__config__.show(),
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
)
def test_sp24_apply_dense(self, dtype) -> None:
M, N = 256, 1024
x = torch.randn([M, N], dtype=dtype, device="cuda")
@ -840,10 +808,6 @@ class TestSparseSemiStructuredTraining(TestCase):
@training_dtypes
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
@unittest.skipIf(
"RelWithAssert" in torch.__config__.show(),
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
)
def test_sp24_matmuls(self, dtype) -> None:
M, N, K = 64, 256, 1024
a = torch.randn([M, K], device="cuda", dtype=dtype)
@ -879,10 +843,6 @@ class TestSparseSemiStructuredTraining(TestCase):
)
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
@unittest.skipIf(
"RelWithAssert" in torch.__config__.show(),
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
)
def test_sp24_matmuls_mat_vec(self) -> None:
a = torch.randn([64, 128], device="cuda", dtype=torch.float16)
b = torch.randn([128], device="cuda", dtype=torch.float16)
@ -893,10 +853,6 @@ class TestSparseSemiStructuredTraining(TestCase):
torch.testing.assert_close(a_s @ b, (a * a_m) @ b, **atol_rtol_kw[a.dtype])
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
@unittest.skipIf(
"RelWithAssert" in torch.__config__.show(),
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
)
def test_sp24_matmuls_bmm(self) -> None:
a = torch.randn([64, 128], device="cuda", dtype=torch.float16)
b = torch.randn([5, 6, 128], device="cuda", dtype=torch.float16)

View File

@ -2758,12 +2758,6 @@ class _NodeBase:
return_type: Any,
) -> None: ...
def _update_args_kwargs(self, args: tuple[Any, ...], kwargs: dict[str, Any]): ...
def _prepend(self, n: FxNode) -> None: ...
def _remove_from_list(self) -> None: ...
def __lt__(self, n: Self) -> _bool: ...
def __gt__(self, n: Self) -> _bool: ...
def __le__(self, n: Self) -> _bool: ...
def __ge__(self, n: Self) -> _bool: ...
class _NodeIter(Iterator[FxNode]):
def __init__(self, root: FxNode, reversed: _bool) -> None: ...

View File

@ -2810,15 +2810,5 @@
"Ensure {user_cls.__name__} is a type of dict, OrderedDict, or defaultdict."
]
}
],
"GB0279": [
{
"Gb_type": "torch.fx.traceback.annotate escaped from compiled region",
"Context": "str(self)",
"Explanation": "Dynamo doesn't support graph break on torch.fx.traceback.annotate.",
"Hints": [
"It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues."
]
}
]
}

View File

@ -3502,8 +3502,10 @@ class InstructionTranslatorBase(
if isinstance(excp, Unsupported):
excp.remove_from_stats()
self.push(
VariableTracker.build(self, impl_CONTAINS_OP_fallback).call_function(
self, [left, right], {}
self.inline_user_function_return(
VariableTracker.build(self, impl_CONTAINS_OP_fallback),
[left, right],
{},
)
)
if op == 1:

View File

@ -745,9 +745,9 @@ class BuiltinVariable(VariableTracker):
)
def handler(tx, a, b):
return VariableTracker.build(
tx, polyfill_fn_mapping[op]
).call_function(tx, [a, b], {})
return tx.inline_user_function_return(
VariableTracker.build(tx, polyfill_fn_mapping[op]), [a, b], {}
)
result.append(((VariableTracker, VariableTracker), handler))
return result
@ -1559,18 +1559,19 @@ class BuiltinVariable(VariableTracker):
)
else:
# Overrides for custom str method
# Pass method as function to call tx.inline_user_function_return
bound_method = str_method.__func__ # type: ignore[attr-defined]
try:
# Only supports certain function types
user_func_variable = VariableTracker.build(tx, bound_method)
user_func_variable = variables.UserFunctionVariable(bound_method)
except AssertionError:
# Won't be able to inline the str method, return to avoid graph break
log.warning("Failed to create UserFunctionVariable", exc_info=True)
return
# Inline the user function
return user_func_variable.call_function(tx, [arg], {})
return tx.inline_user_function_return(user_func_variable, [arg], {})
elif isinstance(arg, (variables.ExceptionVariable,)):
if len(arg.args) == 0:
value = f"{arg.exc_type}"
@ -1924,8 +1925,8 @@ class BuiltinVariable(VariableTracker):
# VT(foo.__dict__). This simplifies the construction of the new
# dict.
args[0] = args[0].get_forwarded_dict(tx)
return VariableTracker.build(tx, polyfills.construct_dict).call_function(
tx,
return tx.inline_user_function_return(
VariableTracker.build(tx, polyfills.construct_dict),
[VariableTracker.build(tx, user_cls), *args],
kwargs,
)
@ -2021,7 +2022,7 @@ class BuiltinVariable(VariableTracker):
):
iter_fn = arg.var_getattr(tx, "__iter__")
if isinstance(iter_fn, variables.UserMethodVariable):
out = iter_fn.call_function(tx, list(args), kwargs)
out = tx.inline_user_function_return(iter_fn, args, kwargs)
if isinstance(out, SetVariable):
return out
return BuiltinVariable(set).call_set(tx, out)

View File

@ -1295,16 +1295,6 @@ class FxTracebackAnnotateVariable(ContextWrappingVariable):
def fn_name(self):
return "annotate"
def reconstruct_type(self, codegen: "PyCodegen"):
unimplemented_v2(
gb_type="torch.fx.traceback.annotate escaped from compiled region",
context=str(self),
explanation="Dynamo doesn't support graph break on torch.fx.traceback.annotate.",
hints=[
*graph_break_hints.SUPPORTABLE,
],
)
class DynamoConfigPatchVariable(ContextWrappingVariable):
"""represents torch._dynamo.patch_dynamo_config"""

View File

@ -189,8 +189,8 @@ class ItertoolsVariable(VariableTracker):
*args, mutation_type=ValueMutationNew()
)
return VariableTracker.build(tx, polyfills.repeat).call_function(
tx, args, kwargs
return tx.inline_user_function_return(
VariableTracker.build(tx, polyfills.repeat), args, kwargs
)
elif self.value is itertools.count:
return variables.CountIteratorVariable(

View File

@ -181,9 +181,11 @@ class BaseListVariable(VariableTracker):
if not len(args):
raise_args_mismatch(tx, name)
return VariableTracker.build(tx, polyfills.index).call_function(
tx, [self] + list(args), kwargs
)
return tx.inline_user_function_return(
VariableTracker.build(tx, polyfills.index),
[self] + list(args),
kwargs,
)
elif name == "count":
if len(args) != 1:
raise_args_mismatch(tx, name)

View File

@ -542,9 +542,11 @@ class NNModuleVariable(VariableTracker):
args = [self] + args
else:
assert istype(fn, types.FunctionType)
return variables.UserFunctionVariable(
fn, source=fn_source
).call_function(tx, args, kwargs)
return tx.inline_user_function_return(
variables.UserFunctionVariable(fn, source=fn_source),
args,
kwargs,
)
def call_method(
self,
@ -771,8 +773,8 @@ class NNModuleVariable(VariableTracker):
assert isinstance(fn, types.FunctionType)
src = AttrSource(AttrSource(self.source, name), "__func__")
return variables.UserFunctionVariable(fn, source=src).call_function(
tx,
return tx.inline_user_function_return(
variables.UserFunctionVariable(fn, source=src),
[self] + list(args),
kwargs,
)
@ -849,8 +851,8 @@ class NNModuleVariable(VariableTracker):
# Inline the function
fn = getattr(module, name).__func__
fn_source = AttrSource(AttrSource(self.source, name), "__func__")
return variables.UserFunctionVariable(fn, source=fn_source).call_function(
tx,
return tx.inline_user_function_return(
variables.UserFunctionVariable(fn, source=fn_source),
[self] + args,
kwargs,
)
@ -949,18 +951,13 @@ class UnspecializedNNModuleVariable(UserDefinedObjectVariable):
# The program can mutate the nn module object but the saved `value`
# will not reflect the mutations. So, trace through the `__iter__`
# function to reflect any tracked mutations.
return (
VariableTracker.build(tx, fn)
.call_function(
tx,
[
self,
],
{},
)
.unpack_var_sequence(tx)
)
return tx.inline_user_function_return(
VariableTracker.build(tx, fn),
[
self,
],
{},
).unpack_var_sequence(tx)
return super().unpack_var_sequence(tx)

View File

@ -1085,8 +1085,8 @@ class TensorVariable(VariableTracker):
if value is not None:
from .. import polyfills
return VariableTracker.build(tx, polyfills.addcmul_inplace).call_function(
tx,
return tx.inline_user_function_return(
VariableTracker.build(tx, polyfills.addcmul_inplace),
[self, tensor1, tensor2, value],
{},
)

View File

@ -568,16 +568,16 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
@register(torch.ops.inductor.accumulate_grad_.default)
def handle_accumulate_grad_(self, tx: "InstructionTranslator", *args, **kwargs):
return VariableTracker.build(tx, polyfills.accumulate_grad).call_function(
tx, args, kwargs
return tx.inline_user_function_return(
VariableTracker.build(tx, polyfills.accumulate_grad), args, kwargs
)
@register(math.radians)
def handle_radians(self, tx: "InstructionTranslator", *args, **kwargs):
if not check_unspec_or_constant_args(args, kwargs):
# Use polyfill to convert math.radians(x) into math.pi * x / 180.0
return VariableTracker.build(tx, polyfills.radians).call_function(
tx, args, kwargs
return tx.inline_user_function_return(
VariableTracker.build(tx, polyfills.radians), args, kwargs
)
@register(torch.is_inference_mode_enabled)
@ -829,10 +829,8 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
_, tx: "InstructionTranslator", *args, **kwargs
):
if len(args) == 3 and not isinstance(args[2], ListVariable) and not kwargs:
return VariableTracker.build(
tx, polyfills.foreach_lerp_inplace
).call_function(
tx,
return tx.inline_user_function_return(
VariableTracker.build(tx, polyfills.foreach_lerp_inplace),
args,
kwargs,
)
@ -842,10 +840,8 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
# In eager it's more performant to call item() from within the C op implementation
# in compile, it's more performant to not graph break.
if len(args) == 2 and isinstance(args[0], TensorVariable) and not kwargs:
return VariableTracker.build(
tx, polyfills.foreach_pow_scalar
).call_function(
tx,
return tx.inline_user_function_return(
VariableTracker.build(tx, polyfills.foreach_pow_scalar),
args,
kwargs,
)
@ -1972,8 +1968,8 @@ class FuncTorchInterpreterVariable(BaseTorchVariable):
if name == "key":
return variables.EnumVariable(self.value.key())
elif name == "process":
return VariableTracker.build(tx, self.value.process.__func__).call_function(
tx,
return tx.inline_user_function_return(
variables.UserFunctionVariable(self.value.process.__func__),
[self] + args,
kwargs,
)

View File

@ -59,7 +59,7 @@ from ..utils import (
from .base import VariableTracker
from .constant import ConstantVariable
from .ctx_manager import GenericContextWrappingVariable
from .functions import UserMethodVariable
from .functions import UserFunctionVariable, UserMethodVariable
from .lazy import LazyVariableTracker
from .lists import TupleVariable
from .tensor import TensorSubclassVariable, TensorVariable
@ -620,7 +620,7 @@ class TensorWithTFOverrideVariable(TensorVariable):
elif isinstance(attr, property):
getter_source = AttrSource(attr_source, "fget")
getter = attr.fget
getter_var = VariableTracker.build(tx, getter, source=getter_source)
getter_var = UserFunctionVariable(getter, source=getter_source)
return getter_var.call_function(tx, [self], {})
elif isinstance(attr, classmethod):

View File

@ -490,8 +490,8 @@ class UserDefinedClassVariable(UserDefinedVariable):
return NullContextVariable(*args, **kwargs)
elif self.value is collections.OrderedDict:
return VariableTracker.build(tx, polyfills.construct_dict).call_function(
tx,
return tx.inline_user_function_return(
VariableTracker.build(tx, polyfills.construct_dict),
[self, *args],
kwargs,
)
@ -823,10 +823,10 @@ class UserDefinedClassVariable(UserDefinedVariable):
return variables.MappingProxyVariable(args[0])
elif SideEffects.cls_supports_mutation_side_effects(self.value) and self.source:
with do_not_convert_to_tracable_parameter():
return VariableTracker.build(
tx, polyfills.instantiate_user_defined_class_object
).call_function(
tx,
return tx.inline_user_function_return(
VariableTracker.build(
tx, polyfills.instantiate_user_defined_class_object
),
[self, *args],
kwargs,
)
@ -1803,8 +1803,8 @@ class SourcelessGraphModuleVariable(UserDefinedObjectVariable):
) -> "VariableTracker":
fn_variable = variables.UserFunctionVariable(self.value.forward.__func__)
args = [self] + args
return fn_variable.call_function(
tx,
return tx.inline_user_function_return(
fn_variable,
args,
kwargs,
)

View File

@ -3147,82 +3147,35 @@ class AlgorithmSelectorCache(PersistentCache):
for i, x in enumerate(input_nodes)
}
example_inputs = list(unique_example_inputs.values())
example_inputs_extern = []
for input_node in input_nodes:
if unique_example_inputs[input_node.get_name()].is_mkldnn:
example_inputs_extern.append(
unique_example_inputs[input_node.get_name()]
)
else:
base = unique_example_inputs[input_node.get_name()]
base = base if base._base is None else base._base
sizes = tuple(
V.graph.sizevars.atomically_apply_size_hint(
size,
example_inputs_extern = [
(
unique_example_inputs[input_node.get_name()]
if unique_example_inputs[input_node.get_name()].is_mkldnn
else torch.as_strided(
unique_example_inputs[input_node.get_name()],
V.graph.sizevars.size_hints(
input_node.get_size(),
fallback=config.unbacked_symint_fallback,
hint_override=hint_override,
)
for size in input_node.get_size()
)
strides = tuple(
V.graph.sizevars.atomically_apply_size_hint(
stride,
),
V.graph.sizevars.size_hints(
input_node.get_stride(),
fallback=config.unbacked_symint_fallback,
hint_override=hint_override,
)
for stride in input_node.get_stride()
)
storage_offset = V.graph.sizevars.atomically_apply_size_hint(
input_node.get_layout().offset,
fallback=config.unbacked_symint_fallback,
hint_override=hint_override,
)
# Check if the required storage size exceeds the current storage
# to avoid illegal memory access
needed_size = torch._prims_common.compute_required_storage_length(
sizes, strides, storage_offset
)
current_size = base.storage().size()
if needed_size > current_size:
# Create a new base tensor with sufficient storage
new_base = torch.randn(
needed_size,
dtype=base.dtype,
device=base.device,
requires_grad=base.requires_grad,
)
base = new_base.as_strided(
base.size(), base.stride(), base.storage_offset()
)
example_inputs_extern.append(
torch.as_strided(base, sizes, strides, storage_offset)
),
V.graph.sizevars.size_hint(
input_node.get_layout().offset,
fallback=config.unbacked_symint_fallback,
hint_override=hint_override,
),
)
)
for input_node in input_nodes
]
out = cls.benchmark_example_value(layout, hint_override=hint_override)
# Also check the output tensor for storage size
out_base = out if out._base is None else out._base
out_offset = V.graph.sizevars.size_hint(layout.offset)
needed_out_size = torch._prims_common.compute_required_storage_length(
out.size(), out.stride(), out_offset
out_extern = torch.as_strided(
out, out.size(), out.stride(), V.graph.sizevars.size_hint(layout.offset)
)
current_out_size = out_base.storage().size()
if needed_out_size > current_out_size:
# Create a new base tensor with sufficient storage
new_out_base = torch.randn(
needed_out_size,
dtype=out_base.dtype,
device=out_base.device,
requires_grad=out_base.requires_grad,
)
out_base = new_out_base.as_strided(
out_base.size(), out_base.stride(), out_base.storage_offset()
)
out_extern = torch.as_strided(out_base, out.size(), out.stride(), out_offset)
expected = None
if VERIFY:
choices[0].benchmark(*example_inputs_extern, out=out_extern)
@ -3663,13 +3616,10 @@ class AlgorithmSelectorCache(PersistentCache):
# So we need to call as_strided in the end to 'view' the tensor with the correct
# sizes/strides
return AlgorithmSelectorCache.generate_example_value(
tuple(
V.graph.sizevars.atomically_apply_size_hint(
size,
fallback=config.unbacked_symint_fallback,
hint_override=hint_override,
)
for size in node.get_size()
V.graph.sizevars.size_hints(
node.get_size(),
fallback=config.unbacked_symint_fallback,
hint_override=hint_override,
),
tuple(
V.graph.sizevars.atomically_apply_size_hint(
@ -3682,20 +3632,13 @@ class AlgorithmSelectorCache(PersistentCache):
node.get_device(),
node.get_dtype(),
# pyrefly: ignore # missing-attribute
V.graph.sizevars.atomically_apply_size_hint(
node.layout.offset,
node.layout.offset,
V.graph.sizevars.size_hints(
# pyrefly: ignore # bad-argument-type
V.graph.get_allocation_size(node),
fallback=config.unbacked_symint_fallback,
hint_override=hint_override,
),
tuple(
V.graph.sizevars.atomically_apply_size_hint(
size,
fallback=config.unbacked_symint_fallback,
hint_override=hint_override,
)
# pyrefly: ignore # bad-argument-type
for size in V.graph.get_allocation_size(node)
),
)
@staticmethod
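
A standalone sketch of the storage-size guard being added above, with illustrative sizes/strides; it checks that a base tensor's storage can hold the requested as_strided view before creating it:

import torch
from torch._prims_common import compute_required_storage_length

base = torch.randn(16)
sizes, strides, storage_offset = (4, 4), (8, 1), 0
# Smallest storage (in elements) that the requested view can legally address.
needed = compute_required_storage_length(sizes, strides, storage_offset)  # 28 here
if needed > base.numel():
    # numel() stands in for the storage size because base is freshly
    # allocated, contiguous, and has zero offset. Re-allocate a larger base
    # with the same dtype/device so the strided view never reads out of bounds.
    base = torch.randn(needed, dtype=base.dtype, device=base.device)
view = torch.as_strided(base, sizes, strides, storage_offset)
print(view.shape)  # torch.Size([4, 4])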

View File

@ -230,9 +230,9 @@ class autocast:
raise ValueError(
f"Expected `device_type` of type `str`, got: `{type(device_type)}`"
)
self.fast_dtype = (
torch.get_autocast_dtype(device_type) if dtype is None else dtype
)
if dtype is None:
dtype = torch.get_autocast_dtype(device_type)
self.fast_dtype = dtype
if torch._jit_internal.is_scripting():
self._enabled = enabled
self.device = device_type
@ -243,9 +243,6 @@ class autocast:
raise RuntimeError(
f"User specified an unsupported autocast device_type '{self.device}'"
)
device_supported_dtypes = [torch.bfloat16, torch.float16]
self.custom_backend_name = torch._C._get_privateuse1_backend_name()
if self.device == self.custom_backend_name:
necessary_funcs = [
@ -262,55 +259,110 @@ class autocast:
assert hasattr(self.custom_device_mod, func), (
message + f"But the func `{func}` is missing. \n"
)
device_supported_dtypes = self.custom_device_mod.get_amp_supported_dtype()
self._cache_enabled = (
torch.is_autocast_cache_enabled()
if cache_enabled is None
else cache_enabled
)
self._cache_enabled = torch.is_autocast_cache_enabled()
if (
enabled
and self.device == "cuda"
and torch.cuda.amp.common.amp_definitely_not_available()
):
warnings.warn(
"User provided device_type of 'cuda', but CUDA is not available. Disabling"
)
enabled = False
if cache_enabled is not None:
self._cache_enabled = cache_enabled
device_name = (
self.device
if self.device == self.custom_backend_name
else self.device.upper()
)
if enabled:
# Special case for CUDA AMP and bfloat16 support
if self.device == "cuda":
if torch.cuda.amp.common.amp_definitely_not_available():
warnings.warn(
"CUDA is not available or torch_xla is imported. Disabling autocast."
)
enabled = False
elif (
self.fast_dtype == torch.bfloat16
and not torch.cuda.is_bf16_supported()
):
raise RuntimeError(
"Current CUDA Device does not support bfloat16. Please switch dtype to float16."
)
elif self.fast_dtype not in device_supported_dtypes:
error_message = (
f"In {device_name} autocast, but the target dtype is not supported. Disabling autocast.\n"
f"{device_name} Autocast only supports dtypes of "
+ ", ".join(map(str, device_supported_dtypes))
+ " currently."
if self.device == "cpu":
supported_dtype = [torch.bfloat16, torch.float16]
if self.fast_dtype not in supported_dtype and enabled:
error_message = "In CPU autocast, but the target dtype is not supported. Disabling autocast.\n"
error_message += "CPU Autocast only supports dtype of "
error_message += (
", ".join(str(dtype) for dtype in supported_dtype) + " currently."
)
warnings.warn(error_message)
enabled = False
# Special case for MPS bfloat16 support on macOS < 14
if (
self.device == "mps"
and self.fast_dtype == torch.bfloat16
and not torch.backends.mps.is_macos_or_newer(14, 0)
):
elif self.device == "mtia":
supported_dtype = [torch.bfloat16, torch.float16]
if self.fast_dtype not in supported_dtype:
error_message = "In MTIA autocast, but the target dtype is not supported. Disabling autocast.\n"
error_message += "MTIA Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
warnings.warn(error_message)
enabled = False
elif self.device == "maia":
supported_dtype = [torch.bfloat16, torch.float16]
if self.fast_dtype not in supported_dtype:
error_message = "In MAIA autocast, but the target dtype is not supported. Disabling autocast.\n"
error_message += "MAIA Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
warnings.warn(error_message)
enabled = False
elif self.device == "xpu":
supported_dtype = [torch.bfloat16, torch.float16]
if self.fast_dtype not in supported_dtype:
error_message = "In XPU autocast, but the target dtype is not supported. Disabling autocast.\n"
error_message += "XPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
warnings.warn(error_message)
enabled = False
elif self.device == "ipu":
supported_dtypes = [torch.bfloat16, torch.float16]
if self.fast_dtype not in supported_dtypes:
error_message = "In IPU autocast, but the target dtype is not supported. Disabling autocast.\n"
error_message += "IPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
warnings.warn(error_message)
enabled = False
elif self.device == "hpu":
supported_dtype = [torch.bfloat16, torch.float16]
if self.fast_dtype not in supported_dtype:
error_message = "In HPU autocast, but the target dtype is not supported. Disabling autocast.\n"
error_message += "HPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
warnings.warn(error_message)
enabled = False
elif self.device == self.custom_backend_name:
supported_dtype = self.custom_device_mod.get_amp_supported_dtype()
if self.fast_dtype not in supported_dtype:
error_message = f"In {self.custom_backend_name} autocast, but the target dtype {self.fast_dtype} is not supported. "
error_message += f"Disabling autocast.\n {self.custom_backend_name} Autocast only supports dtypes of "
error_message += (
", ".join(str(dtype) for dtype in supported_dtype) + " currently."
)
warnings.warn(error_message)
enabled = False
elif self.device == "cuda":
if (
enabled
and self.fast_dtype == torch.bfloat16
and not torch.cuda.is_bf16_supported()
):
raise RuntimeError(
"Current CUDA Device does not support bfloat16. Please switch dtype to float16."
)
elif self.device == "mps":
supported_dtype = [torch.bfloat16, torch.float16]
if self.fast_dtype not in supported_dtype:
error_message = (
"In MPS autocast, but the target dtype is not supported. Disabling autocast.\n"
"MPS Autocast only supports dtype of torch.bfloat16 and torch.float16 currently."
)
warnings.warn(error_message)
enabled = False
elif self.fast_dtype == torch.bfloat16:
if not torch.backends.mps.is_macos_or_newer(14, 0):
error_message = (
"In MPS autocast, but the target dtype torch.bfloat16 is not supported "
"on macOS versions below 14. Disabling autocast."
)
warnings.warn(error_message)
enabled = False
elif self.device == "xla":
supported_dtype = [torch.float16, torch.bfloat16]
if self.fast_dtype not in supported_dtype:
error_message = "In XLA autocast, but the target dtype is not supported. Disabling autocast.\n"
error_message += (
"XLA Autocast only supports dtype of torch.bfloat16 currently."
)
warnings.warn(error_message)
enabled = False
self._enabled = enabled
def __enter__(self):
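
For reference, a small usage sketch of the behaviour this constructor enforces: a supported dtype enables autocast for the block, while an unsupported dtype only warns and disables it (CPU and the chosen dtypes are illustrative):

import torch

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    a = torch.randn(8, 8)
    b = torch.randn(8, 8)
    c = a @ b  # matmul is autocast-eligible, so it runs in bfloat16
print(c.dtype)  # torch.bfloat16

# An unsupported dtype does not raise; it emits a UserWarning and autocast
# stays disabled for the block.
with torch.autocast(device_type="cpu", dtype=torch.float64):
    d = torch.randn(8, 8) @ torch.randn(8, 8)
print(d.dtype)  # torch.float32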

View File

@ -235,9 +235,10 @@ class _DerivedObserverOrFakeQuantize(ObserverBase):
from .utils import is_per_channel
if is_per_channel(self.qscheme):
assert self.ch_axis is not None, (
"Must provide a valid ch_axis if qscheme is per channel"
)
if self.ch_axis is None:
raise AssertionError(
"Must provide a valid ch_axis if qscheme is per channel"
)
def forward(self, x: Tensor) -> Tensor:
return x

View File

@ -92,9 +92,10 @@ def channel_range(input, axis=0):
mins = min_over_ndim(input, axis_list)
maxs = max_over_ndim(input, axis_list)
assert mins.size(0) == input.size(axis), (
"Dimensions of resultant channel range does not match size of requested axis"
)
if mins.size(0) != input.size(axis):
raise AssertionError(
"Dimensions of resultant channel range does not match size of requested axis"
)
return maxs - mins
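
This hunk and the quantization hunks below all apply the same conversion: bare assert statements become explicit AssertionError raises so the checks still run under python -O. A minimal illustration, not taken from the diff:

def check_channel_axis(ch_axis):
    # Survives `python -O`: the check is ordinary control flow.
    if ch_axis is None:
        raise AssertionError("Must provide a valid ch_axis if qscheme is per channel")

def check_channel_axis_assert(ch_axis):
    # Stripped under `python -O`, because assert statements are removed
    # when __debug__ is False.
    assert ch_axis is not None, "Must provide a valid ch_axis if qscheme is per channel"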

View File

@ -45,7 +45,8 @@ class _LearnableFakeQuantize(torch.ao.quantization.FakeQuantizeBase):
**observer_kwargs,
):
super().__init__()
assert quant_min < quant_max, "quant_min must be strictly less than quant_max."
if quant_min >= quant_max:
raise AssertionError("quant_min must be strictly less than quant_max.")
self.quant_min = quant_min
self.quant_max = quant_max
# also pass quant_min and quant_max to observer
@ -56,19 +57,16 @@ class _LearnableFakeQuantize(torch.ao.quantization.FakeQuantizeBase):
self.scale = Parameter(torch.tensor([scale]))
self.zero_point = Parameter(torch.tensor([zero_point]))
else:
assert isinstance(channel_len, int) and channel_len > 0, (
"Channel size must be a positive integer."
)
if not (isinstance(channel_len, int) and channel_len > 0):
raise AssertionError("Channel size must be a positive integer.")
self.scale = Parameter(torch.tensor([scale] * channel_len))
self.zero_point = Parameter(torch.tensor([zero_point] * channel_len))
self.activation_post_process = observer(**observer_kwargs)
assert torch.iinfo(self.activation_post_process.dtype).min <= quant_min, (
"quant_min out of bound"
)
assert quant_max <= torch.iinfo(self.activation_post_process.dtype).max, (
"quant_max out of bound"
)
if torch.iinfo(self.activation_post_process.dtype).min > quant_min:
raise AssertionError("quant_min out of bound")
if quant_max > torch.iinfo(self.activation_post_process.dtype).max:
raise AssertionError("quant_max out of bound")
self.dtype = self.activation_post_process.dtype
self.qscheme = self.activation_post_process.qscheme
self.ch_axis = (

View File

@ -88,9 +88,10 @@ def _fuse_linear_bn_leaky_relu(is_qat, linear, bn, leaky_relu):
>>> lr = nn.LeakyReLU(0.01)
>>> m2 = _fuse_linear_bn_leaky_relu(m1, b1, lr)
"""
assert linear.training == bn.training and bn.training == leaky_relu.training, (
"Linear, BN and LeakyReLU all must be in the same mode (train or eval)."
)
if linear.training != bn.training or bn.training != leaky_relu.training:
raise AssertionError(
"Linear, BN and LeakyReLU all must be in the same mode (train or eval)."
)
if is_qat:
raise NotImplementedError(

View File

@ -164,10 +164,11 @@ def remove_boolean_dispatch_from_name(p) -> Any:
return "torch.nn.functional.adaptive_max_pool2d"
elif p is F.adaptive_max_pool3d:
return "torch.nn.functional.adaptive_max_pool3d"
assert "boolean_dispatch" not in str(p), (
f"{p} does not have a human readable representation in "
+ "quantization documentation"
)
if "boolean_dispatch" in str(p):
raise AssertionError(
f"{p} does not have a human readable representation in "
+ "quantization documentation"
)
return p
@ -300,7 +301,8 @@ def _get_fuser_method_in_reversed_nested_tuple_format(
The first argument of a fuser method is always `is_qat` and is not affected
in the conversion. We currently only support functions with 3 or 4 arguments.
"""
assert config.fuser_method is not None
if config.fuser_method is None:
raise AssertionError("config.fuser_method must be provided")
if config._pattern_complex_format is not None:
return config.fuser_method
if not isinstance(config.pattern, tuple):

View File

@ -175,9 +175,10 @@ class FakeQuantize(FakeQuantizeBase):
super().__init__()
# Populate quant_min/quant_max to observer_kwargs if valid
if quant_min is not None and quant_max is not None:
assert quant_min <= quant_max, (
"quant_min must be less than or equal to quant_max"
)
if quant_min > quant_max:
raise AssertionError(
"quant_min must be less than or equal to quant_max"
)
dtype = observer_kwargs.get("dtype", torch.quint8)
if hasattr(observer, "p"):
# In case observer is _PartialWrapper, dtype can be stored in
@ -186,9 +187,11 @@ class FakeQuantize(FakeQuantizeBase):
"dtype", dtype
)
# pyrefly: ignore # bad-argument-type
assert torch.iinfo(dtype).min <= quant_min, "quant_min out of bound"
if torch.iinfo(dtype).min > quant_min:
raise AssertionError("quant_min out of bound")
# pyrefly: ignore # bad-argument-type
assert quant_max <= torch.iinfo(dtype).max, "quant_max out of bound"
if quant_max > torch.iinfo(dtype).max:
raise AssertionError("quant_max out of bound")
observer_kwargs.update({"quant_min": quant_min, "quant_max": quant_max})
observer_kwargs["is_dynamic"] = is_dynamic
self.activation_post_process = observer(**observer_kwargs)
@ -210,11 +213,12 @@ class FakeQuantize(FakeQuantizeBase):
if hasattr(self.activation_post_process, "ch_axis")
else -1
)
assert _is_per_channel(self.qscheme) or _is_per_tensor(self.qscheme), (
"Only per channel and per tensor quantization are supported in fake quantize"
+ " got qscheme: "
+ str(self.qscheme)
)
if not (_is_per_channel(self.qscheme) or _is_per_tensor(self.qscheme)):
raise AssertionError(
"Only per channel and per tensor quantization are supported in fake quantize"
+ " got qscheme: "
+ str(self.qscheme)
)
self.is_per_channel = _is_per_channel(self.qscheme)
@torch.jit.export
@ -295,7 +299,10 @@ class FakeQuantize(FakeQuantizeBase):
if name == "scale":
self.scale.resize_(val.shape)
else:
assert name == "zero_point"
if name != "zero_point":
raise AssertionError(
"Expected 'zero_point' but got different state key"
)
self.zero_point.resize_(val.shape)
# For torchscript module we need to update the attributes here since we do not
# call the `_load_from_state_dict` function defined in module.py
@ -303,7 +310,10 @@ class FakeQuantize(FakeQuantizeBase):
if name == "scale":
self.scale.copy_(val)
else:
assert name == "zero_point"
if name != "zero_point":
raise AssertionError(
"Expected 'zero_point' but got different state key"
)
self.zero_point.copy_(val)
elif strict:
missing_keys.append(key)
@ -329,17 +339,19 @@ class FixedQParamsFakeQuantize(FakeQuantize):
# TODO: rename observer to observer_ctr
def __init__(self, observer):
super().__init__(observer=observer)
assert type(self.activation_post_process) is FixedQParamsObserver, (
f"{self.__class__.__name__}'s observer must be a {FixedQParamsObserver.__name__}"
)
if type(self.activation_post_process) is not FixedQParamsObserver:
raise AssertionError(
f"{self.__class__.__name__}'s observer must be a {FixedQParamsObserver.__name__}"
)
self._observer_ctr = observer
self.scale = self.activation_post_process.scale
self.zero_point = self.activation_post_process.zero_point
assert _is_per_tensor(self.qscheme), (
"Only per tensor quantization is supported"
+ " FixedQParamsFakeQuantize module, got qscheme:"
+ str(self.qscheme)
)
if not _is_per_tensor(self.qscheme):
raise AssertionError(
"Only per tensor quantization is supported"
+ " FixedQParamsFakeQuantize module, got qscheme:"
+ str(self.qscheme)
)
@torch.jit.export
def calculate_qparams(self): # type: ignore[override]
@ -382,12 +394,13 @@ class FusedMovingAvgObsFakeQuantize(FakeQuantize):
**observer_kwargs: Any,
) -> None:
super().__init__(observer, quant_min, quant_max, **observer_kwargs)
assert isinstance(
if not isinstance(
self.activation_post_process,
(MovingAverageMinMaxObserver, MovingAveragePerChannelMinMaxObserver),
), (
"Fused observer+fake_quant module only works with MovingAverageMinMaxObserver"
)
):
raise AssertionError(
"Fused observer+fake_quant module only works with MovingAverageMinMaxObserver"
)
self.register_buffer("fake_quant_enabled", torch.tensor([1], dtype=torch.long))
self.register_buffer("observer_enabled", torch.tensor([1], dtype=torch.long))
self.is_symmetric_quant = _is_symmetric_quant(
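
The hunks above, and most of the ones that follow, apply one mechanical rewrite: `assert cond, msg` becomes `if not cond: raise AssertionError(msg)`, so the validation still runs when Python is started with `-O` (which strips `assert` statements) while keeping the exception type existing callers may already catch. A minimal sketch of the pattern, using a hypothetical `_check_qrange` helper rather than any function from this diff:

# Sketch of the assert -> raise AssertionError rewrite (hypothetical helper).
def _check_qrange(quant_min: int, quant_max: int) -> None:
    # Before: assert quant_min <= quant_max, "quant_min must be less than or equal to quant_max"
    # After: the check also runs under `python -O`, which removes assert statements.
    if quant_min > quant_max:
        raise AssertionError("quant_min must be less than or equal to quant_max")

_check_qrange(0, 255)    # passes silently
# _check_qrange(10, 5)   # would raise AssertionError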

View File

@ -35,9 +35,10 @@ def fuse_conv_bn(is_qat, conv, bn):
>>> # xdoctest: +SKIP
>>> m2 = fuse_conv_bn(m1, b1)
"""
assert conv.training == bn.training, (
"Conv and BN both must be in the same mode (train or eval)."
)
if conv.training != bn.training:
raise AssertionError(
"Conv and BN both must be in the same mode (train or eval)."
)
fused_module_class_map = {
nn.Conv1d: nni.ConvBn1d,
@ -46,13 +47,18 @@ def fuse_conv_bn(is_qat, conv, bn):
}
if is_qat:
assert bn.num_features == conv.out_channels, (
"Output channel of Conv2d must match num_features of BatchNorm2d"
)
assert bn.affine, "Only support fusing BatchNorm2d with affine set to True"
assert bn.track_running_stats, (
"Only support fusing BatchNorm2d with tracking_running_stats set to True"
)
if bn.num_features != conv.out_channels:
raise AssertionError(
"Output channel of Conv2d must match num_features of BatchNorm2d."
)
if not bn.affine:
raise AssertionError(
"Only support fusing BatchNorm2d with affine set to True"
)
if not bn.track_running_stats:
raise AssertionError(
"Only support fusing BatchNorm2d with tracking_running_stats set to True"
)
fused_module_class = fused_module_class_map.get((type(conv)), None)
if fused_module_class is not None:
return fused_module_class(conv, bn)
@ -81,9 +87,10 @@ def fuse_conv_bn_relu(is_qat, conv, bn, relu):
>>> # xdoctest: +SKIP
>>> m2 = fuse_conv_bn_relu(m1, b1, r1)
"""
assert conv.training == bn.training == relu.training, (
"Conv and BN both must be in the same mode (train or eval)."
)
if not (conv.training == bn.training == relu.training):
raise AssertionError(
"Conv and BN both must be in the same mode (train or eval)."
)
fused_module: Optional[type[nn.Sequential]] = None
if is_qat:
map_to_fused_module_train = {
@ -91,13 +98,18 @@ def fuse_conv_bn_relu(is_qat, conv, bn, relu):
nn.Conv2d: nni.ConvBnReLU2d,
nn.Conv3d: nni.ConvBnReLU3d,
}
assert bn.num_features == conv.out_channels, (
"Output channel of Conv must match num_features of BatchNorm"
)
assert bn.affine, "Only support fusing BatchNorm with affine set to True"
assert bn.track_running_stats, (
"Only support fusing BatchNorm with tracking_running_stats set to True"
)
if bn.num_features != conv.out_channels:
raise AssertionError(
"Output channel of Conv2d must match num_features of BatchNorm2d"
)
if not bn.affine:
raise AssertionError(
"Only support fusing BatchNorm2d with affine set to True"
)
if not bn.track_running_stats:
raise AssertionError(
"Only support fusing BatchNorm2d with tracking_running_stats set to True"
)
fused_module = map_to_fused_module_train.get(type(conv), None)
if fused_module is not None:
return fused_module(conv, bn, relu)
@ -134,18 +146,24 @@ def fuse_linear_bn(is_qat, linear, bn):
>>> # xdoctest: +SKIP
>>> m2 = fuse_linear_bn(m1, b1)
"""
assert linear.training == bn.training, (
"Linear and BN both must be in the same mode (train or eval)."
)
if linear.training != bn.training:
raise AssertionError(
"Linear and BN both must be in the same mode (train or eval)."
)
if is_qat:
assert bn.num_features == linear.out_features, (
"Output features of Linear must match num_features of BatchNorm1d"
)
assert bn.affine, "Only support fusing BatchNorm1d with affine set to True"
assert bn.track_running_stats, (
"Only support fusing BatchNorm1d with tracking_running_stats set to True"
)
if bn.num_features != linear.out_features:
raise AssertionError(
"Output features of Linear must match num_features of BatchNorm1d"
)
if not bn.affine:
raise AssertionError(
"Only support fusing BatchNorm1d with affine set to True"
)
if not bn.track_running_stats:
raise AssertionError(
"Only support fusing BatchNorm1d with tracking_running_stats set to True"
)
return nni.LinearBn1d(linear, bn)
else:
return nn.utils.fusion.fuse_linear_bn_eval(linear, bn)
@ -167,9 +185,10 @@ def fuse_convtranspose_bn(is_qat, convt, bn):
>>> # xdoctest: +SKIP
>>> m2 = fuse_convtranspose_bn(m1, b1)
"""
assert convt.training == bn.training, (
"ConvTranspose and BN both must be in the same mode (train or eval)."
)
if convt.training != bn.training:
raise AssertionError(
"ConvTranspose and BN both must be in the same mode (train or eval)."
)
if is_qat:
raise Exception( # noqa: TRY002
@ -224,7 +243,8 @@ def get_fuser_method(op_list, additional_fuser_method_mapping=None):
_DEFAULT_OP_LIST_TO_FUSER_METHOD, additional_fuser_method_mapping
)
fuser_method = all_mappings.get(op_list, None)
assert fuser_method is not None, f"did not find fuser method for: {op_list} "
if fuser_method is None:
raise AssertionError(f"did not find fuser method for: {op_list} ")
return fuser_method
@ -289,5 +309,6 @@ def get_fuser_method_new(
fuser_method = fuser_method_mapping.get(op_pattern)
if fuser_method is not None:
break
assert fuser_method is not None, f"did not find fuser method for: {op_pattern} "
if fuser_method is None:
raise AssertionError(f"did not find fuser method for: {op_pattern} ")
return fuser_method
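
For context, the lookup these checks guard can be exercised directly. A hedged usage sketch, assuming the default op-list mapping still contains the Conv2d+BatchNorm2d entry and that `(nn.Sigmoid, nn.Tanh)` is a hypothetical pattern absent from it:

# Usage sketch: a known pattern resolves to a fuser method, an unknown one raises.
import torch.nn as nn
from torch.ao.quantization.fuser_method_mappings import get_fuser_method

fuser = get_fuser_method((nn.Conv2d, nn.BatchNorm2d))  # e.g. fuse_conv_bn
try:
    get_fuser_method((nn.Sigmoid, nn.Tanh))            # assumed not in the default mapping
except AssertionError as e:
    print(e)  # "did not find fuser method for: ..."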

View File

@ -249,17 +249,17 @@ class UniformQuantizationObserverBase(ObserverBase):
)
self.reduce_range = reduce_range
self.register_buffer("eps", torch.tensor([eps], **factory_kwargs))
assert self.qscheme in (
if self.qscheme not in (
torch.per_tensor_affine,
torch.per_tensor_symmetric,
torch.per_channel_affine,
torch.per_channel_symmetric,
torch.per_channel_affine_float_qparams,
), (
"Default Observer only works for per_tensor_affine, \
per_tensor_symmetric, per_channel_affine, \
per_channel_symmetric and per_channel_float_qparams quantization scheme"
)
):
raise AssertionError(
"Default Observer only works for per_tensor_affine, per_tensor_symmetric, "
"per_channel_affine, per_channel_symmetric and per_channel_float_qparams quantization scheme"
)
_ALLOWED_DTYPES = (
torch.qint8,
@ -275,9 +275,10 @@ class UniformQuantizationObserverBase(ObserverBase):
torch.uint16,
)
assert self.dtype in _ALLOWED_DTYPES, (
f"Default Observer only works for {_ALLOWED_DTYPES} data type"
)
if self.dtype not in _ALLOWED_DTYPES:
raise AssertionError(
f"Default Observer only works for {_ALLOWED_DTYPES} data type"
)
self.has_customized_qrange = (quant_min is not None) and (quant_max is not None)
if self.has_customized_qrange:
# pyrefly: ignore # bad-argument-type
@ -336,12 +337,12 @@ class UniformQuantizationObserverBase(ObserverBase):
"""
# The variable names are prefixed with "initial" because their values (qmin and qmax) might be adjusted
# based on whether quantization range is reduced and the datatype (signed/unsigned) used by the observer.
assert quant_min <= 0 <= quant_max, (
"Used-specified quantization range must include 0."
)
assert quant_min < quant_max, (
"qmin must be strictly less than qmax for user-specified quantization range."
)
if not quant_min <= 0 <= quant_max:
raise AssertionError("Used-specified quantization range must include 0.")
if quant_min >= quant_max:
raise AssertionError(
"qmin must be strictly less than qmax for user-specified quantization range."
)
@torch.jit.export
def _calculate_qparams(
@ -1131,7 +1132,8 @@ class HistogramObserver(UniformQuantizationObserverBase):
This follows the implementation of NormMinimization::NonlinearQuantizationParamsSearch in
caffe2/quantization/server/norm_minimization.cc
"""
assert self.histogram.size()[0] == self.bins, "bins mismatch"
if self.histogram.size()[0] != self.bins:
raise AssertionError("bins mismatch")
bin_width = (self.max_val - self.min_val) / self.bins
# cumulative sum
@ -1252,8 +1254,10 @@ class HistogramObserver(UniformQuantizationObserverBase):
return transformed_orig_hist + update_hist
# We assume the update_hist is already in the target range, we will map the orig_max to it
assert update_min <= orig_min
assert update_max >= orig_max
if update_min > orig_min:
raise AssertionError("update_min must be <= orig_min")
if update_max < orig_max:
raise AssertionError("update_max must be >= orig_max")
# Now we need to turn the old_histogram, into the range of the new histogram
transformed_orig_hist = self._upscale_histogram(
@ -1273,9 +1277,8 @@ class HistogramObserver(UniformQuantizationObserverBase):
self.min_val.copy_(min_val)
self.max_val.resize_(max_val.shape)
self.max_val.copy_(max_val)
assert min_val.numel() == 1 and max_val.numel() == 1, (
"histogram min/max values must be scalar."
)
if min_val.numel() != 1 or max_val.numel() != 1:
raise AssertionError("histogram min/max values must be scalar.")
new_histogram = torch.histc(x, self.bins, min=min_val, max=max_val) # type: ignore[arg-type]
self.histogram.detach_().resize_(new_histogram.shape)
self.histogram.copy_(new_histogram)
@ -1350,10 +1353,11 @@ class HistogramObserver(UniformQuantizationObserverBase):
return torch.tensor([1.0], device=self.min_val.device.type), torch.tensor(
[0], device=self.min_val.device.type
)
assert self.bins == len(self.histogram), (
"The number of bins in histogram should be equal to the number of bins "
"supplied while making this observer"
)
if self.bins != len(self.histogram):
raise AssertionError(
"The number of bins in histogram should be equal to the number of bins "
"supplied while making this observer"
)
new_min, new_max = self._non_linear_param_search()
@ -1785,9 +1789,10 @@ def get_block_size(
input_shape: The input tensor shape possibly more than 2 dimensions
granularity: The granularity type of the quantization
"""
assert isinstance(granularity, Granularity), (
"Please provide an instance of Granularity, not subclass of it"
)
if not isinstance(granularity, Granularity):
raise AssertionError(
"Please provide an instance of Granularity, not subclass of it"
)
if isinstance(granularity, PerTensor):
return input_shape
elif isinstance(granularity, PerAxis):
@ -1797,9 +1802,10 @@ def get_block_size(
elif isinstance(granularity, PerRow):
return (1,) * (len(input_shape) - 1) + (input_shape[-1],)
elif isinstance(granularity, PerGroup):
assert len(input_shape) == 2, (
f"Expecting input shape dim to be 2 for per group quantization, gotinput shape: {input_shape}"
)
if len(input_shape) != 2:
raise AssertionError(
f"Expecting input shape dim to be 2 for per group quantization, gotinput shape: {input_shape}"
)
return (1, granularity.group_size)
elif isinstance(granularity, PerToken):
block_size = [1] * len(input_shape)
@ -1836,8 +1842,8 @@ class AffineQuantizedObserverBase(ABC, torch.nn.Module):
**kwargs,
):
super().__init__()
assert granularity is not None, "granularity is None"
if granularity is None:
raise AssertionError("granularity is None")
self.mapping_type = mapping_type
self.target_dtype = target_dtype
self.granularity = granularity
@ -1875,10 +1881,10 @@ class AffineQuantizedObserverBase(ABC, torch.nn.Module):
from torch.ao.quantization.fx.utils import create_getattr_from_value
with model.graph.inserting_before(observer_node):
assert self.block_size is not None, "Expecting block_size to be populated"
assert self.original_dtype is not None, (
"Expecting original_dtype to be populated"
)
if self.block_size is None:
raise AssertionError("Expecting block_size to be populated")
if self.original_dtype is None:
raise AssertionError("Expecting original_dtype to be populated")
if hasattr(self, "is_dynamic") and self.is_dynamic:
choose_qparams_affine = model.graph.call_function(
torch.ops.pt2e_quant.choose_qparams_affine,
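
As a quick reference for the `get_block_size` branches earlier in this file's hunks, a worked example of the block shapes it would return for a hypothetical 2-D weight of shape (32, 128); only the branches visible above are listed:

# Hypothetical input_shape = (32, 128); values follow the visible branches above.
# PerTensor               -> (32, 128)  # one block covering the whole tensor
# PerRow                  -> (1, 128)   # all leading dims collapsed to 1
# PerGroup(group_size=32) -> (1, 32)    # requires 2-D input; groups along the last dim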

View File

@ -565,9 +565,10 @@ def _assert_valid_qconfig(qconfig: Optional[QConfig], mod: torch.nn.Module) -> N
torch.ao.quantization.MovingAveragePerChannelMinMaxObserver,
),
)
assert not is_per_channel, (
"Per channel weight observer is not supported yet for ConvTranspose{n}d."
)
if is_per_channel:
raise AssertionError(
"Per channel weight observer is not supported yet for ConvTranspose{n}d."
)
if sys.version_info < (3, 12):
@ -599,7 +600,8 @@ def _add_module_to_qconfig_obs_ctr(
return qconfig
def get_factory_kwargs_based_on_module_device():
assert isinstance(module, torch.nn.Module)
if not isinstance(module, torch.nn.Module):
raise AssertionError("module must be an instance of torch.nn.Module")
devices = {p.device for p in module.parameters()} | {
p.device for p in module.buffers()
}
@ -671,7 +673,10 @@ def qconfig_equals(q1: QConfigAny, q2: QConfigAny):
if q1 is None or q2 is None:
return q1 == q2
else:
assert q1 is not None and q2 is not None
if q1 is None or q2 is None:
raise AssertionError(
"Both q1 and q2 must be non-None for qconfig comparison"
)
try:
# Qconfig weight and activation can be either a partial wrapper,
# or an observer class. Special handling is required (above) for

View File

@ -252,10 +252,11 @@ def get_static_quant_module_class(
additional_static_quant_mapping,
)
static_quant_module_class = all_mappings.get(float_module_class, None)
assert static_quant_module_class is not None, (
f"Floating point module class {str(float_module_class)}"
+ " does not have a corresponding quantized module class"
)
if static_quant_module_class is None:
raise AssertionError(
f"Floating point module class {str(float_module_class)}"
+ " does not have a corresponding quantized module class"
)
return copy.deepcopy(static_quant_module_class)
@ -272,10 +273,11 @@ def get_dynamic_quant_module_class(
DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS, additional_dynamic_quant_mapping
)
dynamic_quant_module_class = all_mappings.get(float_module_class, None)
assert dynamic_quant_module_class is not None, (
f"Floating point module class {str(float_module_class)}"
+ " does not have a corresponding quantized module class"
)
if dynamic_quant_module_class is None:
raise AssertionError(
f"Floating point module class {str(float_module_class)}"
+ " does not have a corresponding quantized module class"
)
return copy.deepcopy(dynamic_quant_module_class)
@ -344,9 +346,10 @@ def get_default_float_to_quantized_operator_mappings() -> dict[
def get_quantized_operator(float_op: Union[Callable, str]) -> Callable:
"""Get the quantized operator corresponding to the float operator"""
quantized_op = DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS.get(float_op)
assert quantized_op is not None, (
f"Operator {str(float_op)} does not have corresponding quantized op"
)
if quantized_op is None:
raise AssertionError(
f"Operator {str(float_op)} does not have corresponding quantized op"
)
return quantized_op

View File

@ -158,9 +158,10 @@ def _observer_forward_pre_hook(self, input):
def _register_activation_post_process_hook(module, pre_hook=False):
assert hasattr(module, "activation_post_process"), (
"Expect activation_post_process attribute already attached to the module"
)
if not hasattr(module, "activation_post_process"):
raise AssertionError(
"Expect activation_post_process attribute already attached to the module"
)
if pre_hook:
module.register_forward_pre_hook(_observer_forward_pre_hook, prepend=True)
else:
@ -198,9 +199,10 @@ def _add_observer_(
# respect device affinity when adding observers
if device is None:
devices = _get_unique_devices_(module)
assert len(devices) <= 1, (
f"_add_observer_ only works with cpu or single-device CUDA modules, but got devices {devices}"
)
if len(devices) > 1:
raise AssertionError(
f"_add_observer_ only works with cpu or single-device CUDA modules, but got devices {devices}"
)
device = next(iter(devices)) if len(devices) > 0 else None
def get_activation_post_process(qconfig, device, special_act_post_process=None):
@ -243,9 +245,10 @@ def _add_observer_(
type_before_parametrizations(child), (nnq.FloatFunctional, nnq.QFunctional)
):
if needs_observation(child):
assert hasattr(child, "activation_post_process"), (
f"functional class {type_before_parametrizations(child)} has no pre-defined `activation_post_process`"
)
if not hasattr(child, "activation_post_process"):
raise AssertionError(
f"functional class {type_before_parametrizations(child)} has no pre-defined `activation_post_process`"
)
child.activation_post_process = get_activation_post_process(
child.qconfig, device
)
@ -584,7 +587,8 @@ def prepare_qat(model, mapping=None, inplace=False):
is mutated
"""
torch._C._log_api_usage_once("quantization_api.quantize.prepare_qat")
assert model.training, "prepare_qat only works on models in training mode"
if not model.training:
raise AssertionError("prepare_qat only works on models in training mode")
if mapping is None:
mapping = get_default_qat_module_mappings()
@ -760,7 +764,10 @@ def swap_module(
elif type_before_parametrizations(mod) in mapping:
qmod = mapping[type_before_parametrizations(mod)]
if hasattr(qmod, "_IS_REFERENCE") and qmod._IS_REFERENCE:
assert mod.qconfig is not None
if mod.qconfig is None:
raise AssertionError(
"module qconfig must not be None when swapping to reference module"
)
weight_post_process = mod.qconfig.weight()
weight_post_process(mod.weight)
weight_qparams = get_qparam_dict(weight_post_process)
@ -787,11 +794,13 @@ def swap_module(
# respect device affinity when swapping modules
devices = _get_unique_devices_(mod)
assert len(devices) <= 1 or (
len(devices) == 2 and torch.device("meta") in devices
), (
f"swap_module only works with cpu or single-device CUDA modules, but got devices {devices}"
)
if not (
len(devices) <= 1
or (len(devices) == 2 and torch.device("meta") in devices)
):
raise AssertionError(
f"swap_module only works with cpu or single-device CUDA modules, but got devices {devices}"
)
device = next(iter(devices)) if len(devices) > 0 else None
if device:
new_mod.to(device)

View File

@ -157,12 +157,12 @@ def _convert_ondevice_jit(
model, method_name, inplace=False, debug=False, quant_type=QuantType.STATIC
):
_check_is_script_module(model)
assert quant_type == QuantType.DYNAMIC, (
"This API, while should work for static quant, is only tested for dynamic quant."
)
assert not method_name.startswith("observe_"), (
"Pass in valid method to be quantized, e.g. forward"
)
if quant_type != QuantType.DYNAMIC:
raise AssertionError(
"This API, while should work for static quant, is only tested for dynamic quant."
)
if method_name.startswith("observe_"):
raise AssertionError("Pass in valid method to be quantized, e.g. forward")
observe_method_name = "observe_" + method_name
quantize_method_name = "quantize_" + method_name
model_c = model._c
@ -230,12 +230,14 @@ def _quantize_jit(
model = prepare_dynamic_jit(model, qconfig_dict, inplace)
model = convert_dynamic_jit(model, True, debug)
else:
assert run_fn, (
"Must provide calibration function for post training static quantization"
)
assert run_args, (
"Must provide calibration dataset for post training static quantization"
)
if not run_fn:
raise AssertionError(
"Must provide calibration function for post training static quantization"
)
if not run_args:
raise AssertionError(
"Must provide calibration dataset for post training static quantization"
)
model = prepare_jit(model, qconfig_dict, inplace)
run_fn(model, *run_args)
model = convert_jit(model, True, debug)

View File

@ -263,7 +263,10 @@ def _is_quantized_op_pt2e(node: torch.fx.Node):
# The node has not been annotated, directly return False
return False
quantization_annotation = node.meta.get(QUANT_ANNOTATION_KEY, None)
assert isinstance(quantization_annotation, _X86InductorQuantizationAnnotation)
if not isinstance(quantization_annotation, _X86InductorQuantizationAnnotation):
raise AssertionError(
"quantization_annotation must be an _X86InductorQuantizationAnnotation"
)
return quantization_annotation._is_output_of_quantized_pattern
@ -428,20 +431,22 @@ class X86InductorQuantizer(Quantizer):
if qat_state is None:
qat_state = qconfig.is_qat
else:
assert qat_state == qconfig.is_qat, (
f"All non-None quantization configs should have the same `is_qat`,"
f"but got {qat_state} and {qconfig.is_qat}."
)
if qat_state != qconfig.is_qat:
raise AssertionError(
f"All non-None quantization configs should have the same `is_qat`,"
f"but got {qat_state} and {qconfig.is_qat}."
)
# Query the `is_dynamic` state
input_activation_spec = qconfig.input_activation
if input_activation_spec is not None:
if dynamic_state is None:
dynamic_state = input_activation_spec.is_dynamic
else:
assert dynamic_state == input_activation_spec.is_dynamic, (
f"All non-None `input_activation_spec` should have the same `is_dynamic`,"
f"but got {dynamic_state} and {input_activation_spec.is_dynamic}."
)
if dynamic_state != input_activation_spec.is_dynamic:
raise AssertionError(
f"All non-None `input_activation_spec` should have the same `is_dynamic`,"
f"but got {dynamic_state} and {input_activation_spec.is_dynamic}."
)
return _CurrentQuantizationMode(
qat_state=qat_state, dynamic_state=dynamic_state
)
@ -567,10 +572,12 @@ class X86InductorQuantizer(Quantizer):
return
input_qspec_map = {}
input_node = conv_node.args[0]
assert isinstance(input_node, Node)
if not isinstance(input_node, Node):
raise AssertionError("input_node must be a FX Node")
input_qspec_map[input_node] = get_input_act_qspec(quantization_config)
weight_node = conv_node.args[1]
assert isinstance(weight_node, Node)
if not isinstance(weight_node, Node):
raise AssertionError("weight_node must be a FX Node")
input_qspec_map[weight_node] = get_weight_qspec(quantization_config)
bias_node = None if len(conv_node.args) == 2 else conv_node.args[2]
if isinstance(bias_node, Node):
@ -598,18 +605,23 @@ class X86InductorQuantizer(Quantizer):
_annotate_nodes_not_quantize(linear_node)
return
input_qspec_map = {}
assert linear_node.target == torch.ops.aten.linear.default
if linear_node.target != torch.ops.aten.linear.default:
raise AssertionError(
"linear_node.target must be torch.ops.aten.linear.default"
)
has_bias = len(linear_node.args) == 3
input_index = 0
weight_index = 1
bias_index = 2
input_node = linear_node.args[input_index]
assert isinstance(input_node, Node)
if not isinstance(input_node, Node):
raise AssertionError("input_node must be a FX Node")
input_qspec_map[input_node] = get_input_act_qspec(quantization_config)
weight_node = linear_node.args[weight_index]
assert isinstance(weight_node, Node)
if not isinstance(weight_node, Node):
raise AssertionError("weight_node must be a FX Node")
input_qspec_map[weight_node] = get_weight_qspec(quantization_config)
bias_node = linear_node.args[bias_index] if has_bias else None
@ -637,7 +649,8 @@ class X86InductorQuantizer(Quantizer):
if len(partition.output_nodes) > 1:
raise ValueError("Input partition has more than one output node")
output_node = partition.output_nodes[0]
assert isinstance(output_node, Node)
if not isinstance(output_node, Node):
raise AssertionError("output_node must be a FX Node")
output_node_list.append(output_node)
if len(output_node_list) != len(partition_list):
raise ValueError(
@ -666,7 +679,8 @@ class X86InductorQuantizer(Quantizer):
conv_gemm_node_idx = 1
extra_input_node_idx = 0
extra_input_node = binary_node.args[extra_input_node_idx] # type: ignore[index]
assert isinstance(extra_input_node, Node)
if not isinstance(extra_input_node, Node):
raise AssertionError("extra_input_node must be a FX Node")
return conv_gemm_node_idx, extra_input_node_idx
def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
@ -1123,7 +1137,8 @@ class X86InductorQuantizer(Quantizer):
if conv_node != binary_node.args[conv_node_idx]:
raise ValueError(f"{conv_node} doesn't match input of binary node")
extra_input_node = binary_node.args[extra_input_node_idx]
assert isinstance(conv_node, Node)
if not isinstance(conv_node, Node):
raise AssertionError("conv_node must be a FX Node")
if (
conv_node.op != "call_function"
or conv_node.target != torch.ops.aten.conv2d.default
@ -1237,7 +1252,8 @@ class X86InductorQuantizer(Quantizer):
return
input_node = maxpool_node.args[0]
assert isinstance(input_node, Node)
if not isinstance(input_node, Node):
raise AssertionError("input_node must be a FX Node")
input_qspec_map = {}
input_qspec_map[input_node] = get_input_act_qspec(quantization_config)
maxpool_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
@ -1254,11 +1270,14 @@ class X86InductorQuantizer(Quantizer):
return
cat_node = node
input_nodes = cat_node.args[0]
assert isinstance(input_nodes, Sequence)
if not isinstance(input_nodes, Sequence):
raise AssertionError("input_nodes must be a Sequence of FX Nodes")
first_input_node = input_nodes[0]
input_qspec_map = {}
assert isinstance(first_input_node, Node)
assert isinstance(cat_node, Node)
if not isinstance(first_input_node, Node):
raise AssertionError("first_input_node must be a FX Node")
if not isinstance(cat_node, Node):
raise AssertionError("cat_node must be a FX Node")
input_qspec_map[first_input_node] = get_input_act_qspec(quantization_config)
share_qparams_with_input_act0_qspec = SharedQuantizationSpec(
(first_input_node, cat_node)
@ -1267,7 +1286,8 @@ class X86InductorQuantizer(Quantizer):
for input_node in input_nodes[1:]:
if input_node not in input_qspec_map:
# There has the case of cat same nodes: torch.cat([input0, input0], 1)
assert isinstance(input_node, Node)
if not isinstance(input_node, Node):
raise AssertionError("input_node must be a FX Node")
input_qspec_map[input_node] = share_qparams_with_input_act0_qspec
cat_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
@ -1405,8 +1425,10 @@ class X86InductorQuantizer(Quantizer):
):
# Annotate the output_qspec of getitem_node
input_act = maxpool_node.args[0]
assert isinstance(input_act, Node)
assert isinstance(maxpool_node, Node)
if not isinstance(input_act, Node):
raise AssertionError("input_act must be a FX Node")
if not isinstance(maxpool_node, Node):
raise AssertionError("maxpool_node must be a FX Node")
edge_or_node = (input_act, maxpool_node)
maxpool_node_quantization_annotation.output_qspec = (
SharedQuantizationSpec(edge_or_node)
@ -1534,7 +1556,8 @@ class X86InductorQuantizer(Quantizer):
raise ValueError(
f"{linear_node} doesn't match input of binary node"
)
assert isinstance(linear_node, Node)
if not isinstance(linear_node, Node):
raise AssertionError("linear_node must be a FX Node")
if (
linear_node.op != "call_function"
or linear_node.target != torch.ops.aten.linear.default

View File

@ -347,9 +347,8 @@ class XNNPACKQuantizer(Quantizer):
quantizer.set_module_name("blocks.sub"), it will quantize all supported operator/operator
patterns in the submodule with this module name with the given `quantization_config`
"""
assert quantization_config is not None, (
" quantization_config == None is not supported yet"
)
if quantization_config is None:
raise AssertionError("quantization_config == None is not supported yet")
self.module_name_config[module_name] = quantization_config
return self

View File

@ -121,10 +121,13 @@ def get_input_act_qspec(quantization_config: Optional[QuantizationConfig]):
if quantization_config.input_activation is None:
return None
quantization_spec: QuantizationSpec = quantization_config.input_activation
assert quantization_spec.qscheme in [
if quantization_spec.qscheme not in [
torch.per_tensor_affine,
torch.per_tensor_symmetric,
]
]:
raise AssertionError(
f"Unsupported activation qscheme: {quantization_spec.qscheme}"
)
return quantization_spec
@ -134,17 +137,21 @@ def get_output_act_qspec(quantization_config: Optional[QuantizationConfig]):
if quantization_config.output_activation is None:
return None
quantization_spec: QuantizationSpec = quantization_config.output_activation
assert quantization_spec.qscheme in [
if quantization_spec.qscheme not in [
torch.per_tensor_affine,
torch.per_tensor_symmetric,
]
]:
raise AssertionError(
f"Unsupported activation qscheme: {quantization_spec.qscheme}"
)
return quantization_spec
def get_weight_qspec(quantization_config: Optional[QuantizationConfig]):
if quantization_config is None:
return None
assert quantization_config is not None
if quantization_config is None:
raise AssertionError("quantization_config must not be None")
if quantization_config.weight is None:
return None
quantization_spec: QuantizationSpec = quantization_config.weight
@ -162,13 +169,15 @@ def get_weight_qspec(quantization_config: Optional[QuantizationConfig]):
def get_bias_qspec(quantization_config: Optional[QuantizationConfig]):
if quantization_config is None:
return None
assert quantization_config is not None
if quantization_config is None:
raise AssertionError("quantization_config must not be None")
if quantization_config.bias is None:
return None
quantization_spec: QuantizationSpec = quantization_config.bias
assert quantization_spec.dtype == torch.float, (
"Only float dtype for bias is supported for bias right now"
)
if quantization_spec.dtype != torch.float:
raise AssertionError(
"Only float dtype for bias is supported for bias right now"
)
return quantization_spec
@ -253,11 +262,13 @@ def _annotate_linear_relu(
input_qspec_map = {}
input_act = linear_node.args[0]
assert isinstance(input_act, Node)
if not isinstance(input_act, Node):
raise AssertionError("input activation must be a FX Node")
input_qspec_map[input_act] = input_act_qspec
weight = linear_node.args[1]
assert isinstance(weight, Node)
if not isinstance(weight, Node):
raise AssertionError("weight must be a FX Node")
input_qspec_map[weight] = weight_qspec
# adding weight node to the partition as well
@ -303,11 +314,13 @@ def _annotate_conv(
input_qspec_map = {}
input_act = conv_node.args[0]
assert isinstance(input_act, Node)
if not isinstance(input_act, Node):
raise AssertionError("input activation must be a FX Node")
input_qspec_map[input_act] = get_input_act_qspec(quantization_config)
weight = conv_node.args[1]
assert isinstance(weight, Node)
if not isinstance(weight, Node):
raise AssertionError("weight must be a FX Node")
input_qspec_map[weight] = get_weight_qspec(quantization_config)
# adding weight node to the partition as well
@ -362,11 +375,13 @@ def _do_annotate_conv_relu(
input_qspec_map = {}
input_act = conv_node.args[0]
assert isinstance(input_act, Node)
if not isinstance(input_act, Node):
raise AssertionError("input activation must be a FX Node")
input_qspec_map[input_act] = get_input_act_qspec(quantization_config)
weight = conv_node.args[1]
assert isinstance(weight, Node)
if not isinstance(weight, Node):
raise AssertionError("weight must be a FX Node")
input_qspec_map[weight] = get_weight_qspec(quantization_config)
# adding weight node to the partition as well
@ -635,8 +650,10 @@ def _annotate_gru_io_only(
# subgraph
input_act = input_nodes[0]
input_act_user = next(iter(input_act.users.keys()))
assert isinstance(input_act, Node)
assert isinstance(input_act_user, Node)
if not isinstance(input_act, Node):
raise AssertionError("input activation must be a FX Node")
if not isinstance(input_act_user, Node):
raise AssertionError("input activation user must be a FX Node")
input_act_user.meta["quantization_annotation"] = QuantizationAnnotation(
input_qspec_map={
input_act: get_input_act_qspec(quantization_config),
@ -646,8 +663,10 @@ def _annotate_gru_io_only(
hidden_state = input_nodes[1]
hidden_state_user = next(iter(hidden_state.users.keys()))
assert isinstance(hidden_state, Node)
assert isinstance(hidden_state_user, Node)
if not isinstance(hidden_state, Node):
raise AssertionError("hidden state must be a FX Node")
if not isinstance(hidden_state_user, Node):
raise AssertionError("hidden state user must be a FX Node")
hidden_state_user.meta["quantization_annotation"] = QuantizationAnnotation(
input_qspec_map={
hidden_state: get_input_act_qspec(quantization_config),
@ -655,7 +674,8 @@ def _annotate_gru_io_only(
_annotated=True,
)
assert len(output_nodes) == 2, "expecting GRU to have two outputs"
if len(output_nodes) != 2:
raise AssertionError("expecting GRU to have two outputs")
for output in output_nodes:
output.meta["quantization_annotation"] = QuantizationAnnotation(
output_qspec=get_output_act_qspec(quantization_config),
@ -691,7 +711,8 @@ def _annotate_adaptive_avg_pool2d(
annotated_partitions.append(partition.nodes)
input_act = pool_node.args[0]
assert isinstance(input_act, Node)
if not isinstance(input_act, Node):
raise AssertionError("input activation must be a FX Node")
# only annotate input output sharing operator
# when the output of the input node is annotated

View File

@ -214,7 +214,8 @@ def to_underlying_dtype(qdtype):
torch.float8_e5m2: torch.float8_e5m2,
torch.float8_e4m3fn: torch.float8_e4m3fn,
}
assert qdtype in DTYPE_MAPPING, "Unsupported dtype: " + str(qdtype)
if qdtype not in DTYPE_MAPPING:
raise AssertionError("Unsupported dtype: " + str(qdtype))
return DTYPE_MAPPING[qdtype]
@ -269,21 +270,24 @@ def get_swapped_custom_module_class(
"""
quant_type = get_quant_type(qconfig)
class_mapping = custom_module_class_mapping.get(quant_type, {})
assert type(custom_module) in class_mapping, (
"did not find corresponding observed "
f"module class for {type(custom_module)} in mapping: {class_mapping}"
)
if type(custom_module) not in class_mapping:
raise AssertionError(
"did not find corresponding observed "
f"module class for {type(custom_module)} in mapping: {class_mapping}"
)
return class_mapping[type(custom_module)]
def activation_dtype(qconfig):
assert qconfig is not None
if qconfig is None:
raise AssertionError("qconfig must be provided to determine activation dtype")
activation = qconfig.activation()
return activation.dtype
def weight_dtype(qconfig):
assert qconfig is not None
if qconfig is None:
raise AssertionError("qconfig must be provided to determine weight dtype")
weight = qconfig.weight()
return weight.dtype
@ -377,7 +381,8 @@ def get_qconfig_dtypes(qconfig):
r"""returns the qconfig tuple for qconfig:
(activation_dtype, weight_dtype, activation_is_dynamic)
"""
assert qconfig is not None
if qconfig is None:
raise AssertionError("qconfig must be provided to extract dtypes")
activation = qconfig.activation()
weight = qconfig.weight()
act_is_dynamic = getattr(activation, "is_dynamic", False)
@ -385,7 +390,8 @@ def get_qconfig_dtypes(qconfig):
def get_quant_type(qconfig):
assert qconfig is not None
if qconfig is None:
raise AssertionError("qconfig must be provided to determine quant type")
activation = qconfig.activation()
weight = qconfig.weight()
static_dtypes = [
@ -440,11 +446,11 @@ def check_min_max_valid(min_val: torch.Tensor, max_val: torch.Tensor) -> bool:
return False
assert min_val <= max_val, f"min {min_val} should be less than max {max_val}"
if min_val > max_val:
raise AssertionError(f"min {min_val} should be less than max {max_val}")
else:
assert torch.all(min_val <= max_val), (
f"min {min_val} should be less than max {max_val}"
)
if torch.any(min_val > max_val):
raise AssertionError(f"min {min_val} should be less than max {max_val}")
return True
@ -479,13 +485,15 @@ def calculate_qmin_qmax(
qrange_len = initial_quant_max - initial_quant_min + 1
if dtype in [torch.qint8, torch.int8]:
assert 0 < qrange_len <= 256, (
"quantization range should be positive and not exceed the maximum bit range (=256)."
)
if not (0 < qrange_len <= 256):
raise AssertionError(
"quantization range should be positive and not exceed the maximum bit range (=256)."
)
elif dtype in [torch.qint32, torch.int32]:
assert 0 < qrange_len <= 2**32, (
"quantization range should be positive and not exceed the maximum bit range (=4294967296)."
)
if not (0 < qrange_len <= 2**32):
raise AssertionError(
"quantization range should be positive and not exceed the maximum bit range (=4294967296)."
)
if reduce_range:
quant_min, quant_max = quant_min // 2, quant_max // 2
else:
@ -633,12 +641,12 @@ def validate_qmin_qmax(quant_min: int, quant_max: int) -> None:
"""
# The variable names are prefixed with "initial" because their values (qmin and qmax) might be adjusted
# based on whether quantization range is reduced and the datatype (signed/unsigned) used by the observer.
assert quant_min <= 0 <= quant_max, (
"Used-specified quantization range must include 0."
)
assert quant_min < quant_max, (
"qmin must be strictly less than qmax for user-specified quantization range."
)
if not (quant_min <= 0 <= quant_max):
raise AssertionError("Used-specified quantization range must include 0.")
if quant_min >= quant_max:
raise AssertionError(
"qmin must be strictly less than qmax for user-specified quantization range."
)
# Functionally equivalent to '_calculate_qparams' in observer.py. Observers must be torchscriptable however and qscheme
@ -810,10 +818,11 @@ def _assert_and_get_unique_device(module: torch.nn.Module) -> Any:
)
devices = {torch.device("cpu")}
""
assert len(devices) <= 1, (
"prepare only works with cpu or single-device CUDA modules, "
f"but got devices {devices}"
)
if len(devices) > 1:
raise AssertionError(
"prepare only works with cpu or single-device CUDA modules, "
f"but got devices {devices}"
)
device = next(iter(devices)) if len(devices) > 0 else None
return device
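
A small worked example of the range arithmetic checked in `calculate_qmin_qmax` above, with hypothetical user-specified values for the qint8 branch:

# Hypothetical custom range for qint8.
initial_quant_min, initial_quant_max = -128, 127
qrange_len = initial_quant_max - initial_quant_min + 1  # 256 -> passes 0 < qrange_len <= 256
quant_min, quant_max = initial_quant_min, initial_quant_max
reduce_range = True
if reduce_range:
    quant_min, quant_max = quant_min // 2, quant_max // 2  # (-64, 63)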

View File

@ -419,7 +419,7 @@ class TORCH_API Backend : public torch::CustomClassHolder {
}
// Do not call this directly, use ProcessGroup::setGroupName instead.
virtual void setGroupUid(const std::string& pg_uid) {
void setGroupUid(const std::string& pg_uid) {
pg_uid_ = pg_uid;
}

View File

@ -1,15 +1,11 @@
#include <torch/csrc/fx/node.h>
#include <c10/util/Exception.h>
#include <c10/util/SmallVector.h>
#include <structmember.h>
#include <torch/csrc/utils/object_ptr.h>
#include <torch/csrc/utils/pythoncapi_compat.h>
#include <algorithm>
namespace {
using NodeSortKey = c10::SmallVector<int64_t, 4>;
struct NodeBase;
// Thrown to exit out of a C++ function and return an error to Python.
@ -167,41 +163,7 @@ struct NodeBase {
PyObject* users;
PyObject* _repr_fn;
PyObject* meta;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
alignas(NodeSortKey) char sort_key_buf[sizeof(NodeSortKey)];
inline NodeSortKey& sort_key() {
return *reinterpret_cast<NodeSortKey*>(sort_key_buf);
}
inline void set_prev(NodeBase* value) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(value);
Py_INCREF(reinterpret_cast<PyObject*>(value));
NodeBase* old = _prev;
_prev = value;
Py_DECREF(reinterpret_cast<PyObject*>(old));
}
inline void set_next(NodeBase* value) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(value);
Py_INCREF(reinterpret_cast<PyObject*>(value));
NodeBase* old = _next;
_next = value;
Py_DECREF(reinterpret_cast<PyObject*>(old));
}
// Equivalent to:
// p, n = self._prev, self._next
// p._next, n._prev = n, p
inline void remove_from_list() {
if (this->_prev == this && this->_next == this) {
return;
}
NodeBase* p = this->_prev;
NodeBase* n = this->_next;
p->set_next(n);
n->set_prev(p);
}
PyObject* _sort_key;
};
static PyObject* NodeBase_new(
@ -211,8 +173,6 @@ static PyObject* NodeBase_new(
PyObject* self = type->tp_alloc(type, 0);
if (!self)
return nullptr;
new (reinterpret_cast<NodeBase*>(self)->sort_key_buf)
NodeSortKey(); // placement new does not allocate
return self;
}
@ -241,6 +201,7 @@ static int NodeBase_init_fn(NodeBase* self, PyObject* args, PyObject* kwds) {
self->users = PyDict_New();
self->_repr_fn = Py_NewRef(Py_None);
self->meta = PyDict_New();
self->_sort_key = PyTuple_New(0);
return 0;
}
@ -260,6 +221,7 @@ static struct PyMemberDef NodeBase_members[] = {
{"users", T_OBJECT_EX, offsetof(NodeBase, users), 0, nullptr},
{"_repr_fn", T_OBJECT_EX, offsetof(NodeBase, _repr_fn), 0, nullptr},
{"meta", T_OBJECT_EX, offsetof(NodeBase, meta), 0, nullptr},
{"_sort_key", T_OBJECT_EX, offsetof(NodeBase, _sort_key), 0, nullptr},
{nullptr} /* Sentinel */
};
@ -277,6 +239,7 @@ static int NodeBase_traverse(NodeBase* self, visitproc visit, void* arg) {
Py_VISIT(self->users);
Py_VISIT(self->_repr_fn);
Py_VISIT(self->meta);
Py_VISIT(self->_sort_key);
return 0;
}
@ -294,12 +257,12 @@ static int NodeBase_clear(NodeBase* self) {
Py_CLEAR(self->users);
Py_CLEAR(self->_repr_fn);
Py_CLEAR(self->meta);
Py_CLEAR(self->_sort_key);
return 0;
}
static void NodeBase_dealloc(PyObject* self) {
PyObject_GC_UnTrack(self);
reinterpret_cast<NodeBase*>(self)->sort_key().~NodeSortKey();
(void)NodeBase_clear((NodeBase*)self);
Py_TYPE(self)->tp_free(self);
}
@ -358,195 +321,15 @@ static PyObject* NodeBase__update_args_kwargs(
}
}
static PyObject* NodeBase__remove_from_list(
PyObject* self,
PyObject* _ignored) {
reinterpret_cast<NodeBase*>(self)->remove_from_list();
Py_RETURN_NONE;
}
static PyObject* NodeBase__prepend(PyObject* self_, PyObject* arg) {
if (self_ == arg) {
Py_RETURN_NONE;
}
if (!is_node(arg)) {
PyErr_SetString(PyExc_TypeError, "_prepend() argument must be a Node");
return nullptr;
}
NodeBase* self = reinterpret_cast<NodeBase*>(self_);
NodeBase* x = reinterpret_cast<NodeBase*>(arg);
if (self->graph != x->graph) {
PyErr_SetString(
PyExc_AssertionError,
"Attempting to move a Node into a different Graph");
return nullptr;
}
x->remove_from_list();
NodeBase* p = self->_prev;
p->set_next(x);
x->set_prev(p);
x->set_next(self);
self->set_prev(x);
// Now compute x.sort_key()
const NodeSortKey& psk = x->_prev->sort_key();
const NodeSortKey& nsk = x->_next->sort_key();
if (psk.size() > nsk.size()) {
// prefix = psk[: len(nsk)+1]
size_t slice_len = nsk.size() + 1;
NodeSortKey prefix(psk.begin(), psk.begin() + slice_len);
// last element is idx => increment by 1
prefix.back()++;
x->sort_key() = std::move(prefix);
} else if (psk.size() < nsk.size()) {
// prefix = nsk[: len(psk)+1]
size_t slice_len = psk.size() + 1;
NodeSortKey prefix(nsk.begin(), nsk.begin() + slice_len);
// last element is idx => decrement by 1
prefix.back()--;
x->sort_key() = std::move(prefix);
} else {
// same length => add a 0
x->sort_key() = psk;
x->sort_key().emplace_back(0);
}
Py_RETURN_NONE;
}
// __lt__(self, other): Return self.sort_key < other.sort_key
static PyObject* NodeBase___lt__(PyObject* self, PyObject* other) {
// METH_O => one argument: 'other'
if (!is_node(other)) {
Py_RETURN_NOTIMPLEMENTED;
}
const NodeSortKey& lhs = reinterpret_cast<NodeBase*>(self)->sort_key();
const NodeSortKey& rhs = reinterpret_cast<NodeBase*>(other)->sort_key();
bool less = std::lexicographical_compare(
lhs.begin(), lhs.end(), rhs.begin(), rhs.end());
if (less)
Py_RETURN_TRUE;
Py_RETURN_FALSE;
}
// __gt__(self, other): Return self.sort_key() > other.sort_key
static PyObject* NodeBase___gt__(PyObject* self, PyObject* other) {
if (!is_node(other)) {
Py_RETURN_NOTIMPLEMENTED;
}
const NodeSortKey& lhs = reinterpret_cast<NodeBase*>(self)->sort_key();
const NodeSortKey& rhs = reinterpret_cast<NodeBase*>(other)->sort_key();
// "a > b" is equivalent to "b < a"
bool greater = std::lexicographical_compare(
rhs.begin(), rhs.end(), lhs.begin(), lhs.end());
if (greater)
Py_RETURN_TRUE;
Py_RETURN_FALSE;
}
static PyObject* NodeBase___ge__(PyObject* self, PyObject* other) {
if (self == other) {
Py_RETURN_TRUE;
}
return NodeBase___gt__(self, other);
}
// __le__(self, other): Return not (self > other)
static PyObject* NodeBase___le__(PyObject* self, PyObject* other) {
if (self == other) {
Py_RETURN_TRUE;
}
return NodeBase___lt__(self, other);
}
// Convert the NodeBase::sort_key vector<long> into a Python tuple of ints
// Only used by pickle/__getstate__
static PyObject* NodeBase_get_sort_key(PyObject* self, void* /*closure*/) {
NodeBase* node = reinterpret_cast<NodeBase*>(self);
const NodeSortKey& vec = node->sort_key();
Py_ssize_t n = static_cast<Py_ssize_t>(vec.size());
THPObjectPtr tuple(PyTuple_New(n));
if (!tuple) {
return nullptr; // Out of memory
}
for (Py_ssize_t i = 0; i < n; i++) {
PyObject* value = PyLong_FromSsize_t(vec[i]);
if (!value) {
return nullptr;
}
PyTuple_SET_ITEM(tuple.get(), i, value);
}
return tuple.release();
}
// Setter for NodeBase::sort_key: expects a Python tuple of ints, e.g.
// node._sort_key = (1,2,3) Only used by pickle/__setstate__
static int NodeBase_set_sort_key(
PyObject* self,
PyObject* value,
void* /*closure*/) {
NodeBase* node = reinterpret_cast<NodeBase*>(self);
if (!PyTuple_Check(value)) {
PyErr_SetString(PyExc_TypeError, "_sort_key must be an tuple of ints");
return -1;
}
Py_ssize_t size = PyTuple_GET_SIZE(value);
NodeSortKey new_vec;
new_vec.reserve(size);
for (Py_ssize_t i = 0; i < size; i++) {
int64_t val = PyLong_AsSsize_t(PyTuple_GET_ITEM(value, i));
if (val == -1 && PyErr_Occurred()) {
return -1;
}
new_vec.emplace_back(val);
}
node->sort_key() = std::move(new_vec);
return 0;
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
static PyMethodDef NodeBase_methods[] = {
{"_update_args_kwargs",
(PyCFunction)(void*)(NodeBase__update_args_kwargs),
METH_FASTCALL,
"Internal method: do not call directly."},
{"_remove_from_list",
(PyCFunction)(void*)(NodeBase__remove_from_list),
METH_NOARGS,
"Internal method: do not call directly."},
{"_prepend",
(PyCFunction)(void*)(NodeBase__prepend),
METH_O,
"Internal method: do not call directly."},
{"__lt__",
(PyCFunction)(void*)NodeBase___lt__,
METH_O,
"Return True if self.sort_key < other.sort_key"},
{"__gt__",
(PyCFunction)(void*)NodeBase___gt__,
METH_O,
"Return True if self.sort_key > other.sort_key"},
{"__ge__",
(PyCFunction)(void*)NodeBase___ge__,
METH_O,
"Return True if self.sort_key >= other.sort_key"},
{"__le__",
(PyCFunction)(void*)NodeBase___le__,
METH_O,
"Return True if self.sort_key <= other.sort_key"},
{nullptr, nullptr, 0, nullptr} // Sentinel
};
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
static PyGetSetDef NodeBase_getset[] = {
{"_sort_key", // attribute name in Python
(getter)NodeBase_get_sort_key, // C getter function
(setter)NodeBase_set_sort_key, // C setter function
(char*)"The sort key as a tuple of ints", // docstring
nullptr},
{nullptr, nullptr, nullptr, nullptr, nullptr} // Sentinel
};
PyTypeObject NodeBaseType = {
PyVarObject_HEAD_INIT(nullptr, 0)
"torch._C._NodeBase", /* tp_name */
@ -578,7 +361,7 @@ PyTypeObject NodeBaseType = {
nullptr, /* tp_iternext */
NodeBase_methods, /* tp_methods */
NodeBase_members, /* tp_members */
NodeBase_getset, /* tp_getset */
nullptr, /* tp_getset */
nullptr, /* tp_base */
nullptr, /* tp_dict */
nullptr, /* tp_descr_get */

View File

@ -385,7 +385,41 @@ class Node(_NodeBase):
Args:
x (Node): The node to put before this node. Must be a member of the same graph.
"""
self._prepend(x)
assert self.graph == x.graph, "Attempting to move a Node into a different Graph"
if self == x:
log.debug(
"Trying to prepend a node to itself. This behavior has no effect on the graph."
)
return
x._remove_from_list()
p = self._prev
p._next, x._prev = x, p
x._next, self._prev = self, x
# compute x._sort_key
psk = x._prev._sort_key
nsk = x._next._sort_key
if len(psk) > len(nsk):
idx: int
*prefix, idx = psk[: len(nsk) + 1]
x._sort_key = (*prefix, idx + 1)
elif len(psk) < len(nsk):
*prefix, idx = nsk[: len(psk) + 1]
x._sort_key = (*prefix, idx - 1)
else: # same length, increase length by 1
x._sort_key = (*psk, 0)
def __gt__(self, other: "Node") -> bool:
return self._sort_key > other._sort_key
def __lt__(self, other: "Node") -> bool:
return self._sort_key < other._sort_key
def __ge__(self, other: "Node") -> bool:
return self > other or self == other
def __le__(self, other: "Node") -> bool:
return self < other or self == other
@compatibility(is_backward_compatible=True)
def append(self, x: "Node") -> None:
@ -396,7 +430,11 @@ class Node(_NodeBase):
Args:
x (Node): The node to put after this node. Must be a member of the same graph.
"""
self._next._prepend(x)
self._next.prepend(x)
def _remove_from_list(self) -> None:
p, n = self._prev, self._next
p._next, n._prev = n, p
@property
def args(self) -> tuple[Argument, ...]:
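
The tuple-based `_sort_key` restored above orders nodes by lexicographic comparison (see `__lt__`/`__gt__`). A minimal sketch of the three branches in `prepend` that derive a key strictly between the previous and next nodes' keys; the key values are hypothetical:

# Sketch mirroring the sort-key computation in Node.prepend above.
def new_sort_key(psk: tuple, nsk: tuple) -> tuple:
    if len(psk) > len(nsk):
        *prefix, idx = psk[: len(nsk) + 1]
        return (*prefix, idx + 1)
    if len(psk) < len(nsk):
        *prefix, idx = nsk[: len(psk) + 1]
        return (*prefix, idx - 1)
    return (*psk, 0)  # same length: extend prev's key with a 0

assert (1,) < new_sort_key((1,), (2,)) < (2,)        # -> (1, 0)
assert (1, 5) < new_sort_key((1, 5), (2,)) < (2,)    # -> (1, 6)
assert (1,) < new_sort_key((1,), (1, 5)) < (1, 5)    # -> (1, 4)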