Compare commits


5 Commits

SHA1 Message Date
e3d00beddd Fix triu_/tril_ overlap handling 2025-10-21 07:54:24 -07:00
21131a2444 Revert "[ROCm][CI] Update rocm.yml workflow to use 1 GPU ARC runners (#165481)"
This reverts commit ffa90d46e61650834d5f926008f48f50c6a7e87a.

Reverted https://github.com/pytorch/pytorch/pull/165481 on behalf of https://github.com/jeffdaily due to timeouts after merge ([comment](https://github.com/pytorch/pytorch/pull/165481#issuecomment-3426898171))
2025-10-21 14:15:55 +00:00
1009790ad8 [pytree][dynamo] trace on native optree functions for community pytree support (#165860)
Resolves #164972

- #164972

All `torch.utils._cxx_pytree` functions are based on `optree` functions with hardcoded `none_is_leaf=True` and `namespace="torch"`. This PR changes the polyfills to target the generic `optree` functions, with those two arguments exposed as real parameters instead of hardcoded. As a result, the `torch.utils._cxx_pytree` functions remain traceable, and community code that calls `optree` directly gains Dynamo support as well.
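For illustration, a minimal sketch of the semantic gap the generic polyfills now have to respect (the example tree is made up; only `tree_leaves` is shown):

```python
import optree
import torch.utils._cxx_pytree as cxx_pytree

tree = {"a": 1, "b": None}

# torch.utils._cxx_pytree pins none_is_leaf=True and namespace="torch",
# so None comes back as a leaf:
cxx_pytree.tree_leaves(tree)  # [1, None]

# Plain optree defaults to none_is_leaf=False and namespace="", treating
# None as an internal node with no children:
optree.tree_leaves(tree)  # [1]
optree.tree_leaves(tree, none_is_leaf=True)  # [1, None]
```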

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165860
Approved by: https://github.com/Lucaskabela
2025-10-21 14:13:08 +00:00
410e6a4321 Better error handling in torch/csrc/jit/frontend/* (#165213)
Refactors error handling to use TORCH_CHECK, improving the clarity of the constants and scope-management code in several files under torch/csrc/jit/frontend/*.

Fixes parts of issue https://github.com/pytorch/pytorch/issues/148114

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165213
Approved by: https://github.com/FFFrog, https://github.com/albanD
2025-10-21 13:54:59 +00:00
23c55c5b66 [Code Clean] Replace assert statements with explicit if/raise patterns (#165735)
Fix part of #164878

Replace 75 assert statements with explicit if/raise patterns in `torch/ao/ns` (a condensed before/after sketch follows the list), including:

- `torch/ao/ns/_numeric_suite_fx.py`  - 5 asserts

- `torch/ao/ns/fx/graph_matcher.py` - 6 asserts

- `torch/ao/ns/fx/graph_passes.py` - 12 asserts

- `torch/ao/ns/fx/n_shadows_utils.py` - 20 asserts

- `torch/ao/ns/fx/pattern_utils.py` - 2 asserts

- `torch/ao/ns/fx/utils.py` - 21 asserts

- `torch/ao/ns/fx/weight_utils.py` - 19 asserts
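The pattern applied throughout (a condensed before/after; the helper name and the specific check are illustrative, mirroring the diffs below):

```python
from torch.fx import Node

# Before: the check disappears entirely under `python -O`.
def get_target_before(node: Node) -> str:
    assert isinstance(node.target, str), f"Expected str, got {type(node.target)}"
    return node.target

# After: the check always runs; raising AssertionError keeps the observable
# failure type unchanged for callers that catch it.
def get_target_after(node: Node) -> str:
    if not isinstance(node.target, str):
        raise AssertionError(f"Expected str, got {type(node.target)}")
    return node.target
```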

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165735
Approved by: https://github.com/albanD
2025-10-21 11:21:57 +00:00
44 changed files with 486 additions and 751 deletions

View File

@ -86,10 +86,6 @@ else
fi
fi
if [[ "$BUILD_ENVIRONMENT" == *zen* ]]; then
export USE_ZENDNN=1
fi
if [[ "$BUILD_ENVIRONMENT" == *aarch64* ]]; then
export USE_MKLDNN=1
export USE_MKLDNN_ACL=1

View File

@ -54,17 +54,12 @@ self-hosted-runner:
- windows-11-arm64
- windows-11-arm64-preview
# Organization-wide AMD-hosted runners
# MI2xx non-ARC runners
# MI2xx runners
- linux.rocm.gpu
- linux.rocm.gpu.mi250
- linux.rocm.gpu.2
- linux.rocm.gpu.4
- linux.rocm.gpu.mi250
- linux.rocm.gpu.gfx1100
# MI2xx ARC runners
- linux.rocm.gpu.mi250.1
- linux.rocm.gpu.mi250.2
- linux.rocm.gpu.mi250.4
# gfx942 ARC runners
# gfx942 runners
- linux.rocm.gpu.gfx942.1
- linux.rocm.gpu.gfx942.2
- linux.rocm.gpu.gfx942.4

View File

@ -80,7 +80,7 @@ jobs:
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-jammy-zen-py3.10-gcc11-build
build-environment: linux-jammy-py3.10-gcc11-build
docker-image-name: ci-image:pytorch-linux-jammy-py3-gcc11-inductor-benchmarks
test-matrix: |
{ include: [
@ -106,7 +106,7 @@ jobs:
needs: inductor-build
if: github.event.schedule == '0 7 * * *'
with:
build-environment: linux-jammy-zen-py3.10-gcc11-build
build-environment: linux-jammy-py3.10-gcc11-build
dashboard-tag: training-false-inference-true-default-true-dynamic-true-cppwrapper-true-aotinductor-true-freezing-true
docker-image: ${{ needs.inductor-build.outputs.docker-image }}
test-matrix: ${{ needs.inductor-build.outputs.test-matrix }}
@ -122,7 +122,7 @@ jobs:
uses: ./.github/workflows/_linux-test.yml
needs: inductor-build
with:
build-environment: linux-jammy-zen-py3.10-gcc11-build
build-environment: linux-jammy-py3.10-gcc11-build
dashboard-tag: training-${{ inputs.training || 'false' }}-inference-${{ inputs.inference || 'true' }}-default-${{ inputs.default || 'true' }}-dynamic-${{ inputs.dynamic || 'true' }}-cppwrapper-${{ inputs.cppwrapper || 'true' }}-aotinductor-${{ inputs.aotinductor || 'true' }}-freezing-${{ inputs.freezing || 'true' }}
docker-image: ${{ needs.inductor-build.outputs.docker-image }}
test-matrix: ${{ needs.inductor-build.outputs.test-matrix }}

View File

@ -36,12 +36,12 @@ jobs:
sync-tag: rocm-build
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.mi250.1" },
{ config: "default", shard: 2, num_shards: 6, runner: "linux.rocm.gpu.mi250.1" },
{ config: "default", shard: 3, num_shards: 6, runner: "linux.rocm.gpu.mi250.1" },
{ config: "default", shard: 4, num_shards: 6, runner: "linux.rocm.gpu.mi250.1" },
{ config: "default", shard: 5, num_shards: 6, runner: "linux.rocm.gpu.mi250.1" },
{ config: "default", shard: 6, num_shards: 6, runner: "linux.rocm.gpu.mi250.1" },
{ config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.2" },
{ config: "default", shard: 2, num_shards: 6, runner: "linux.rocm.gpu.2" },
{ config: "default", shard: 3, num_shards: 6, runner: "linux.rocm.gpu.2" },
{ config: "default", shard: 4, num_shards: 6, runner: "linux.rocm.gpu.2" },
{ config: "default", shard: 5, num_shards: 6, runner: "linux.rocm.gpu.2" },
{ config: "default", shard: 6, num_shards: 6, runner: "linux.rocm.gpu.2" },
]}
secrets: inherit

.gitmodules (vendored): 3 changed lines
View File

@ -132,6 +132,3 @@
[submodule "third_party/aiter"]
path = third_party/aiter
url = https://github.com/ROCm/aiter.git
[submodule "third_party/ZenDNN"]
path = third_party/ZenDNN
url = https://github.com/amd/ZenDNN.git

View File

@ -82,7 +82,6 @@ include_patterns = [
'aten/src/ATen/native/mkldnn/xpu/**/*.cpp',
'aten/src/ATen/native/Tensor*.h',
'aten/src/ATen/native/Tensor*.cpp',
'aten/src/ATen/native/zendnn/*.*',
'c10/**/*.h',
'c10/**/*.cpp',
'torch/csrc/**/*.h',

View File

@ -205,11 +205,6 @@ filegroup(
srcs = glob(["aten/src/ATen/native/xnnpack/*.cpp"]),
)
filegroup(
name = "aten_native_zendnn_cpp",
srcs = glob(["aten/src/ATen/native/zendnn/*.cpp"]),
)
filegroup(
name = "aten_base_vulkan",
srcs = glob(["aten/src/ATen/vulkan/*.cpp"]),
@ -290,7 +285,6 @@ header_template_rule(
"@AT_BLAS_USE_CBLAS_DOT@": "1",
"@AT_KLEIDIAI_ENABLED@": "0",
"@AT_USE_EIGEN_SPARSE@": "0",
"@AT_ZENDNN_ENABLED@": "0",
},
)
@ -371,7 +365,6 @@ cc_library(
":aten_native_sparse_cpp",
":aten_native_transformers_cpp",
":aten_native_xnnpack",
":aten_native_zendnn_cpp",
":aten_src_ATen_config",
] + generated_cpu_cpp + aten_ufunc_generated_cpu_sources("aten/src/ATen/{}"),
copts = ATEN_COPTS,

View File

@ -326,21 +326,6 @@ set(MKLDNN_ENABLE_CONCURRENT_EXEC ${USE_MKLDNN})
cmake_dependent_option(USE_MKLDNN_CBLAS "Use CBLAS in MKLDNN" OFF "USE_MKLDNN"
OFF)
option(USE_STATIC_MKL "Prefer to link with MKL statically (Unix only)" OFF)
# currently ZenDNN is kept off and enabled only through user setting on X86_64/AMD64
option(USE_ZENDNN
"Build with ZENDNN support"
OFF)
if(USE_ZENDNN AND NOT CPU_INTEL)
message(WARNING
"USE_ZENDNN was requested, but the target processor "
"(${CMAKE_SYSTEM_PROCESSOR}) is not AMD64/x86_64. "
"ZENDNN support will be disabled.")
# Switch it off in the cache so the GUI / subsequent runs see the change
set(USE_ZENDNN OFF CACHE BOOL "Build with ZENDNN support" FORCE)
endif()
cmake_dependent_option(
USE_MPI "Use MPI for Caffe2. Only available if USE_DISTRIBUTED is on." ON
"USE_DISTRIBUTED" OFF)
@ -1352,7 +1337,6 @@ if(BUILD_SHARED_LIBS)
${PROJECT_SOURCE_DIR}/cmake/public/gflags.cmake
${PROJECT_SOURCE_DIR}/cmake/public/mkl.cmake
${PROJECT_SOURCE_DIR}/cmake/public/mkldnn.cmake
${PROJECT_SOURCE_DIR}/cmake/public/zendnn.cmake
${PROJECT_SOURCE_DIR}/cmake/public/protobuf.cmake
${PROJECT_SOURCE_DIR}/cmake/public/utils.cmake
${PROJECT_SOURCE_DIR}/cmake/public/LoadHIP.cmake

View File

@ -93,7 +93,6 @@ file(GLOB mkldnn_xpu_cpp "native/mkldnn/xpu/*.cpp" "native/mkldnn/xpu/detail/*.c
file(GLOB native_cpp "native/*.cpp")
file(GLOB native_mkl_cpp "native/mkl/*.cpp")
file(GLOB native_mkldnn_cpp "native/mkldnn/*.cpp")
file(GLOB native_zendnn_cpp "native/zendnn/*.cpp")
file(GLOB vulkan_cpp "vulkan/*.cpp")
file(GLOB native_vulkan_cpp "native/vulkan/*.cpp" "native/vulkan/api/*.cpp" "native/vulkan/impl/*.cpp" "native/vulkan/ops/*.cpp")
@ -377,7 +376,7 @@ if(BUILD_LITE_INTERPRETER)
append_filelist("aten_native_source_non_codegen_list" all_cpu_cpp)
else()
set(
all_cpu_cpp ${base_cpp} ${ATen_CORE_SRCS} ${native_cpp} ${native_zendnn_cpp}
all_cpu_cpp ${base_cpp} ${ATen_CORE_SRCS} ${native_cpp}
${native_ao_sparse_cpp} ${native_sparse_cpp} ${native_nested_cpp}
${native_quantized_cpp} ${native_mkl_cpp} ${native_mkldnn_cpp}
${native_transformers_cpp}

View File

@ -21,4 +21,3 @@
#define AT_BLAS_USE_CBLAS_DOT() @AT_BLAS_USE_CBLAS_DOT@
#define AT_KLEIDIAI_ENABLED() @AT_KLEIDIAI_ENABLED@
#define AT_USE_EIGEN_SPARSE() @AT_USE_EIGEN_SPARSE@
#define AT_ZENDNN_ENABLED() @AT_ZENDNN_ENABLED@

View File

@ -681,14 +681,6 @@ bool Context::hasEigenSparse() {
#endif
}
bool Context::hasZenDNN() {
#if AT_ZENDNN_ENABLED()
return true;
#else
return false;
#endif
}
at::QEngine Context::qEngine() const {
static auto _quantized_engine = []() {
at::QEngine qengine = at::kNoQEngine;

View File

@ -147,7 +147,6 @@ class TORCH_API Context {
static bool hasMKL();
static bool hasKleidiAI();
static bool hasLAPACK();
static bool hasZenDNN();
static bool hasMKLDNN();
static bool ckSupported();
static bool hasEigenSparse();
@ -627,10 +626,6 @@ inline bool hasEigenSparse() {
return globalContext().hasEigenSparse();
}
inline bool hasZenDNN() {
return globalContext().hasZenDNN();
}
inline bool hasMAGMA() {
return globalContext().hasMAGMA();
}

View File

@ -141,6 +141,8 @@ void compute_triu_tril(const Tensor& self, int64_t k, const Tensor &result) {
return;
}
checkTrilTriuMemoryOverlap(result, self);
bool inplace_op = self.is_same(result);
bool inplace_update = false;

View File

@ -1,3 +1,4 @@
#include <ATen/MemoryOverlap.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/LinearAlgebraUtils.h>
@ -54,4 +55,13 @@ static inline std::tuple<bool, Tensor> checkTrilTriuBatchContiguous(const Tensor
return std::make_tuple(true, tensor);
}
static inline void checkTrilTriuMemoryOverlap(const Tensor& result, const Tensor& self) {
if (result.is_same(self)) {
at::assert_no_internal_overlap(result);
} else {
at::assert_no_internal_overlap(result);
at::assert_no_overlap(result, self);
}
}
} // namespace at::native
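From the Python API, the new overlap check surfaces like this (a sketch mirroring the regression test added below):

```python
import torch

base = torch.rand(())
expanded = base.expand(3, 3)  # all 9 elements alias one storage location

# In-place triu_/tril_ on a self-overlapping tensor now raise up front
# instead of writing through the aliased memory:
expanded.triu_(1)
# RuntimeError: unsupported operation: more than one element of the
# written-to tensor refers to a single memory location. Please clone() the
# tensor before performing the operation.
```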

View File

@ -5,6 +5,7 @@
#include <ATen/Dispatch.h>
#include <ATen/MemoryOverlap.h>
#include <ATen/native/Resize.h>
#include <ATen/native/TriangularOpsUtils.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
@ -110,6 +111,8 @@ __global__ void triu_tril_kernel(
template <bool upper>
void triu_tril_cuda_template(const Tensor& result, const Tensor& self, int64_t k, const char* name) {
checkTrilTriuMemoryOverlap(result, self);
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
at::ScalarType::ComplexHalf,
at::ScalarType::Half,

View File

@ -1,5 +0,0 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/Config.h>
#include <ATen/Context.h>
#include <ATen/core/Tensor.h>
#include <ATen/record_function.h>

View File

@ -1162,9 +1162,6 @@ def define_buck_targets(
"--replace",
"@AT_USE_EIGEN_SPARSE@",
"0",
"--replace",
"@AT_ZENDNN_ENABLED@",
"0",
]),
outs = {
"Config.h": ["Config.h"],

View File

@ -1177,7 +1177,6 @@ aten_cpu_source_non_codegen_list = [
"aten/src/ATen/native/ComparisonUtils.cpp",
"aten/src/ATen/native/DispatchStub.cpp",
"aten/src/ATen/native/UpSample.cpp",
"aten/src/ATen/native/zendnn/Matmul.cpp",
"aten/src/ATen/native/mkldnn/BinaryOps.cpp",
"aten/src/ATen/native/mkldnn/Conv.cpp",
"aten/src/ATen/native/mkldnn/ConvPrepack.cpp",

View File

@ -67,5 +67,4 @@
{"USE_CUSPARSELT", "${USE_CUSPARSELT}"}, \
{"USE_XPU", "${USE_XPU}"}, \
{"USE_XCCL", "${USE_XCCL}"}, \
{"USE_ZENDNN", "${USE_ZENDNN}"} \
}

View File

@ -117,10 +117,6 @@ if(@USE_MKLDNN@)
include("${CMAKE_CURRENT_LIST_DIR}/public/mkldnn.cmake")
endif()
if(@USE_ZENDNN@)
include("${CMAKE_CURRENT_LIST_DIR}/public/zendnn.cmake")
endif()
# import targets
include ("${CMAKE_CURRENT_LIST_DIR}/Caffe2Targets.cmake")

View File

@ -162,7 +162,6 @@ set(AT_MKLDNN_ENABLED 0)
set(AT_MKL_ENABLED 0)
set(AT_KLEIDIAI_ENABLED 0)
set(AT_USE_EIGEN_SPARSE 0)
set(AT_ZENDNN_ENABLED 0)
# setting default preferred BLAS options if not already present.
if(NOT INTERN_BUILD_MOBILE)
set(BLAS "MKL" CACHE STRING "Selected BLAS library")
@ -1510,32 +1509,6 @@ if(NOT INTERN_BUILD_MOBILE)
message("disabling MKLDNN because USE_MKLDNN is not set")
endif()
if(USE_ZENDNN)
if(NOT (CMAKE_SYSTEM_NAME MATCHES "Linux"))
message(WARNING
"USE_ZENDNN is currently only supported on Linux. Detected platform: ${CMAKE_SYSTEM_NAME}. Disabling ZenDNN support.")
set(USE_ZENDNN OFF)
elseif(NOT CMAKE_SIZEOF_VOID_P EQUAL 8)
message(WARNING
"x64 operating system is required for ZenDNN. "
"ZenDNN codebase will not be compiled."
"Turn this warning off by USE_ZENDNN=OFF.")
set(USE_ZENDNN OFF)
else()
include(${CMAKE_CURRENT_LIST_DIR}/public/zendnn.cmake)
if(ZENDNN_FOUND)
set(AT_ZENDNN_ENABLED 1)
# Add to Caffe2 private dependencies
list(APPEND Caffe2_DEPENDENCY_LIBS zendnnl::zendnnl_archive)
else()
message(WARNING "ZENDNN could not be found.")
caffe2_update_option(USE_ZENDNN OFF)
endif()
endif()
else()
message(STATUS "disabling ZENDNN because USE_ZENDNN is not set")
endif()
if(USE_KLEIDIAI)
set(TEMP_BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS})
set(BUILD_SHARED_LIBS OFF CACHE BOOL "Build shared libs" FORCE)

View File

@ -1,402 +0,0 @@
include_guard(GLOBAL)
include(ExternalProject)
# declare a zendnnl dependency
macro(zendnnl_add_dependency )
set(options INCLUDE_ONLY)
set(oneValueArgs NAME PATH LIB_SUFFIX INCLUDE_SUFFIX ARCHIVE_FILE ALIAS)
set(multiValueArgs DEPENDS)
cmake_parse_arguments(_zad "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
string(TOUPPER ${_zad_NAME} _ZAD_UNAME)
if(DEFINED _zad_INCLUDE_SUFFIX)
set(ZENDNNL_${_ZAD_UNAME}_INC_DIR "${_zad_PATH}/${_zad_INCLUDE_SUFFIX}")
else()
set(ZENDNNL_${_ZAD_UNAME}_INC_DIR "${_zad_PATH}/include")
endif()
if(DEFINED _zad_LIB_SUFFIX)
set(ZENDNNL_${_ZAD_UNAME}_LIB_DIR "${_zad_PATH}/${_zad_LIB_SUFFIX}")
else()
set(ZENDNNL_${_ZAD_UNAME}_LIB_DIR "${_zad_PATH}/lib")
endif()
if(NOT EXISTS ${ZENDNNL_${_ZAD_UNAME}_INC_DIR})
file(MAKE_DIRECTORY ${ZENDNNL_${_ZAD_UNAME}_INC_DIR})
endif()
if(${_zad_INCLUDE_ONLY})
add_library(zendnnl_${_zad_NAME}_deps INTERFACE IMPORTED GLOBAL)
#add_dependencies(zendnnl_${_zad_NAME}_deps ${_zad_DEPENDS})
set_target_properties(zendnnl_${_zad_NAME}_deps
PROPERTIES
INTERFACE_INCLUDE_DIRECTORIES "${ZENDNNL_${_ZAD_UNAME}_INC_DIR}")
else()
add_library(zendnnl_${_zad_NAME}_deps STATIC IMPORTED GLOBAL)
#add_dependencies(zendnnl_${_zad_NAME}_deps ${_zad_DEPENDS})
set_target_properties(zendnnl_${_zad_NAME}_deps
PROPERTIES
IMPORTED_LOCATION "${ZENDNNL_${_ZAD_UNAME}_LIB_DIR}/${_zad_ARCHIVE_FILE}"
INCLUDE_DIRECTORIES "${ZENDNNL_${_ZAD_UNAME}_INC_DIR}"
INTERFACE_INCLUDE_DIRECTORIES "${ZENDNNL_${_ZAD_UNAME}_INC_DIR}")
endif()
add_library(${_zad_ALIAS} ALIAS zendnnl_${_zad_NAME}_deps)
list(APPEND ZNL_BYPRODUCTS "${ZENDNNL_${_ZAD_UNAME}_LIB_DIR}/${_zad_ARCHIVE_FILE}")
endmacro()
macro(zendnnl_add_option )
set(options EXECLUDE_FROM_COMMAND_LIST FORCE)
set(oneValueArgs NAME VALUE TYPE CACHE_STRING COMMAND_LIST)
set(multiValueArgs "")
cmake_parse_arguments(_zao "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
if(${_zao_FORCE})
set(${_zao_NAME} ${_zao_VALUE} CACHE ${_zao_TYPE} ${_zao_CACHE_STRING} FORCE)
else()
set(${_zao_NAME} ${_zao_VALUE} CACHE ${_zao_TYPE} ${_zao_CACHE_STRING})
endif()
if (NOT ${_zao_EXECLUDE_FROM_COMMAND_LIST})
list(APPEND ${_zao_COMMAND_LIST} "-D${_zao_NAME}:${_zao_TYPE}=${_zao_VALUE}")
endif()
endmacro()
message(AUTHOR_WARNING "(ZENDNNL) please ensure all zendnnl variables are set properly.")
if(NOT ZENDNN_FOUND)
# find openmp
find_package(OpenMP REQUIRED QUIET)
# set zendnnl source dir, where zendnnl has been downloaded.
zendnnl_add_option(NAME ZENDNNL_SOURCE_DIR
VALUE ${PROJECT_SOURCE_DIR}/third_party/ZenDNN
TYPE PATH
CACHE_STRING "zendnnl_source_dir"
COMMAND_LIST ZNL_CMAKE_ARGS)
# set zendnnl binary dir, if unsure set ${CMAKE_CURRENT_BINARY_DIR}/zendnnl.
zendnnl_add_option(NAME ZENDNNL_BINARY_DIR
VALUE ${ZENDNNL_SOURCE_DIR}/build
TYPE PATH
CACHE_STRING "zendnnl_binary_dir"
COMMAND_LIST ZNL_CMAKE_ARGS)
# set zendnnl install dir, if unsure set ${CMAKE_INSTALL_PREFIX}/zendnnl.
zendnnl_add_option(NAME ZENDNNL_INSTALL_PREFIX
VALUE ${ZENDNNL_BINARY_DIR}/install
TYPE PATH
CACHE_STRING "zendnnl_install_dir"
COMMAND_LIST ZNL_CMAKE_ARGS)
## general zendnnl options
# set ZenDNNL framework build, this should on ON to avoid standalone build.
zendnnl_add_option(NAME ZENDNNL_FWK_BUILD
VALUE ON
TYPE BOOL
CACHE_STRING "zendnnl framework build"
COMMAND_LIST ZNL_CMAKE_ARGS)
# set zendnnl build option, default is Release.
zendnnl_add_option(NAME ZENDNNL_BUILD_TYPE
VALUE "Release"
TYPE STRING
CACHE_STRING "zendnnl build type"
COMMAND_LIST ZNL_CMAKE_ARGS)
# set zendnnl log level.
zendnnl_add_option(NAME ZENDNNL_MESSAGE_LOG_LEVEL
VALUE "DEBUG"
TYPE STRING
CACHE_STRING "zendnnl message log level"
COMMAND_LIST ZNL_CMAKE_ARGS)
# set zendnnl verbose makefile option.
zendnnl_add_option(NAME ZENDNNL_VERBOSE_MAKEFILE
VALUE ON
TYPE BOOL
CACHE_STRING "zendnnl verbose makefile"
COMMAND_LIST ZNL_CMAKE_ARGS)
## components options
# set building zendnnl examples, default os OFF.
zendnnl_add_option(NAME ZENDNNL_BUILD_EXAMPLES
VALUE OFF
TYPE BOOL
CACHE_STRING "build zendnnl examples"
COMMAND_LIST ZNL_CMAKE_ARGS)
# set building zendnnl gtests, default os OFF.
zendnnl_add_option(NAME ZENDNNL_BUILD_GTEST
VALUE OFF
TYPE BOOL
CACHE_STRING "build zendnnl gtests"
COMMAND_LIST ZNL_CMAKE_ARGS)
# set building zendnnl doxygen documentation, default os OFF.
zendnnl_add_option(NAME ZENDNNL_BUILD_DOXYGEN
VALUE OFF
TYPE BOOL
CACHE_STRING "build zendnnl doxygen documentation"
COMMAND_LIST ZNL_CMAKE_ARGS)
# set building zendnnl benchmarking tool, default os OFF.
zendnnl_add_option(NAME ZENDNNL_BUILD_BENCHDNN
VALUE OFF
TYPE BOOL
CACHE_STRING "build zendnnl benchdnn"
COMMAND_LIST ZNL_CMAKE_ARGS)
# set zendnnl code coverage option, default os OFF.
zendnnl_add_option(NAME ZENDNNL_CODE_COVERAGE
VALUE OFF
TYPE BOOL
CACHE_STRING "build zendnnl code coverage"
COMMAND_LIST ZNL_CMAKE_ARGS)
## dependencies
# set if zendnnl depends on amdblis. this should bf OFF only if
# aocldlp dependency is ON.
zendnnl_add_option(NAME ZENDNNL_DEPENDS_AMDBLIS
VALUE OFF
TYPE BOOL
CACHE_STRING "zendnnl amdblis dependency"
COMMAND_LIST ZNL_CMAKE_ARGS)
# set if zendnnl depends on aocldlp. this should bf ON only if
# amdblis dependency is OFF.
zendnnl_add_option(NAME ZENDNNL_DEPENDS_AOCLDLP
VALUE ON
TYPE BOOL
CACHE_STRING "zendnnl aocldlp dependency"
COMMAND_LIST ZNL_CMAKE_ARGS)
# set if zendnnl depends on onednn, default is OFF.
zendnnl_add_option(NAME ZENDNNL_DEPENDS_ONEDNN
VALUE OFF
TYPE BOOL
CACHE_STRING "zendnnl onednn dependency"
COMMAND_LIST ZNL_CMAKE_ARGS)
# set if zendnnl depends on libxsmm, default is OFF.
zendnnl_add_option(NAME ZENDNNL_DEPENDS_LIBXSMM
VALUE ON
TYPE BOOL
CACHE_STRING "zendnnl libxsmm dependency"
COMMAND_LIST ZNL_CMAKE_ARGS)
# set path of amdblis if amdblis is injected. if the framework
# does not inject it, set it to "" (empty string).
zendnnl_add_option(NAME ZENDNNL_AMDBLIS_FWK_DIR
VALUE ""
TYPE PATH
CACHE_STRING "zendnnl amdblis framework path"
COMMAND_LIST ZNL_CMAKE_ARGS)
# set path of aocldlp if aocldlp is injected. if the framework
# does not inject it, set it to "" (empty string).
zendnnl_add_option(NAME ZENDNNL_AOCLDLP_FWK_DIR
VALUE ""
TYPE PATH
CACHE_STRING "zendnnl aocldlp framework path"
COMMAND_LIST ZNL_CMAKE_ARGS)
# set path of onednn if onednn is injected. if the framework
# does not inject it, set it to "" (empty string).
zendnnl_add_option(NAME ZENDNNL_ONEDNN_FWK_DIR
VALUE ""
TYPE PATH
CACHE_STRING "zendnnl onednnn framework path"
COMMAND_LIST ZNL_CMAKE_ARGS)
# set path of libxsmm if libxsmm is injected. if the framework
# does not inject it, set it to "" (empty string).
zendnnl_add_option(NAME ZENDNNL_LIBXSMM_FWK_DIR
VALUE ""
TYPE PATH
CACHE_STRING "zendnnl libxsmm framework path"
COMMAND_LIST ZNL_CMAKE_ARGS)
# try to find pre-built package
set(zendnnl_ROOT "${ZENDNNL_INSTALL_PREFIX}/zendnnl")
set(zendnnl_DIR "${zendnnl_ROOT}/lib/cmake")
find_package(zendnnl QUIET)
if(zendnnl_FOUND)
message(STATUS "(ZENDNNL) ZENDNNL FOUND AT ${zendnnl_ROOT}")
message(STATUS "(ZENDNNL) if zendnnl options are changed from previous build,")
message(STATUS "(ZENDNNL) they will not be reflected")
message(STATUS "(ZENDNNL) If options are changed, please do a clean build.")
if(TARGET zendnnl::zendnnl_archive)
set_target_properties(zendnnl::zendnnl_archive
PROPERTIES IMPORTED_GLOBAL ON)
else()
message(FATAL_ERROR "(ZENDNNL) zendnnl installation does not have imported target zendnnl::zendnnl_archive")
endif()
else()
message(STATUS "(ZENDNNL) ZENDNNL NOT FOUND, will be built as an external project.")
# declare zendnnl library
set(ZENDNNL_LIBRARY_INC_DIR "${ZENDNNL_INSTALL_PREFIX}/zendnnl/include")
set(ZENDNNL_LIBRARY_LIB_DIR "${ZENDNNL_INSTALL_PREFIX}/zendnnl/lib")
if(NOT EXISTS ${ZENDNNL_LIBRARY_INC_DIR})
file(MAKE_DIRECTORY ${ZENDNNL_LIBRARY_INC_DIR})
endif()
add_library(zendnnl_library STATIC IMPORTED GLOBAL)
add_dependencies(zendnnl_library fwk_zendnnl)
set_target_properties(zendnnl_library
PROPERTIES
IMPORTED_LOCATION "${ZENDNNL_LIBRARY_LIB_DIR}/libzendnnl_archive.a"
INCLUDE_DIRECTORIES "${ZENDNNL_LIBRARY_INC_DIR}"
INTERFACE_INCLUDE_DIRECTORIES "${ZENDNNL_LIBRARY_INC_DIR}")
target_link_options(zendnnl_library INTERFACE "-fopenmp")
target_link_libraries(zendnnl_library
INTERFACE OpenMP::OpenMP_CXX
INTERFACE ${CMAKE_DL_LIBS})
add_library(zendnnl::zendnnl_archive ALIAS zendnnl_library)
list(APPEND ZNL_BYPRODUCTS "${ZENDNNL_LIBRARY_LIB_DIR}/libzendnnl_archive.a")
# declare all dependencies
# json dependency
zendnnl_add_dependency(NAME json
PATH "${ZENDNNL_INSTALL_PREFIX}/deps/json"
ALIAS "nlohmann_json::nlohmann_json"
INCLUDE_ONLY)
target_link_libraries(zendnnl_library INTERFACE nlohmann_json::nlohmann_json)
# aoclutils dependency
if (DEFINED ENV{ZENDNNL_MANYLINUX_BUILD})
zendnnl_add_dependency(NAME aoclutils
PATH "${ZENDNNL_INSTALL_PREFIX}/deps/aoclutils"
LIB_SUFFIX lib64
ARCHIVE_FILE "libaoclutils.a"
ALIAS "au::aoclutils")
target_link_libraries(zendnnl_library INTERFACE au::aoclutils)
zendnnl_add_dependency(NAME aucpuid
PATH "${ZENDNNL_INSTALL_PREFIX}/deps/aoclutils"
LIB_SUFFIX lib64
ARCHIVE_FILE "libau_cpuid.a"
ALIAS "au::au_cpuid")
target_link_libraries(zendnnl_library INTERFACE au::au_cpuid)
else()
zendnnl_add_dependency(NAME aoclutils
PATH "${ZENDNNL_INSTALL_PREFIX}/deps/aoclutils"
ARCHIVE_FILE "libaoclutils.a"
ALIAS "au::aoclutils")
target_link_libraries(zendnnl_library INTERFACE au::aoclutils)
zendnnl_add_dependency(NAME aucpuid
PATH "${ZENDNNL_INSTALL_PREFIX}/deps/aoclutils"
ARCHIVE_FILE "libau_cpuid.a"
ALIAS "au::au_cpuid")
target_link_libraries(zendnnl_library INTERFACE au::au_cpuid)
endif()
# amdblis dependency
if (ZENDNNL_DEPENDS_AMDBLIS)
zendnnl_add_dependency(NAME amdblis
PATH "${ZENDNNL_INSTALL_PREFIX}/deps/amdblis"
ARCHIVE_FILE "libblis-mt.a"
ALIAS "amdblis::amdblis_archive")
target_link_libraries(zendnnl_library INTERFACE amdblis::amdblis_archive)
endif()
if (ZENDNNL_DEPENDS_AOCLDLP)
zendnnl_add_dependency(NAME aocldlp
PATH "${ZENDNNL_INSTALL_PREFIX}/deps/aocldlp"
ARCHIVE_FILE "libaocl-dlp.a"
ALIAS "aocldlp::aocl_dlp_static")
target_link_libraries(zendnnl_library INTERFACE aocldlp::aocl_dlp_static)
endif()
if (ZENDNNL_DEPENDS_ONEDNN)
zendnnl_add_dependency(NAME onednn
PATH "${ZENDNNL_INSTALL_PREFIX}/deps/onednn"
ARCHIVE_FILE "libdnnl.a"
ALIAS "DNNL::dnnl")
target_link_libraries(zendnnl_library INTERFACE DNNL::dnnl)
endif()
# libxsmm dependency
if (ZENDNNL_DEPENDS_LIBXSMM)
zendnnl_add_dependency(NAME libxsmm
PATH "${ZENDNNL_INSTALL_PREFIX}/deps/libxsmm"
ARCHIVE_FILE "libxsmm.a"
ALIAS "libxsmm::libxsmm_archive")
target_link_libraries(zendnnl_library INTERFACE libxsmm::libxsmm_archive)
endif()
message(STATUS "(ZENDNNL) ZNL_BYPRODUCTS=${ZNL_BYPRODUCTS}")
message(STATUS "(ZENDNNL) ZNL_CMAKE_ARGS=${ZNL_CMAKE_ARGS}")
ExternalProject_ADD(fwk_zendnnl
SOURCE_DIR "${ZENDNNL_SOURCE_DIR}"
BINARY_DIR "${ZENDNNL_BINARY_DIR}"
CMAKE_ARGS "${ZNL_CMAKE_ARGS}"
BUILD_COMMAND cmake --build . --target all -j
INSTALL_COMMAND ""
BUILD_BYPRODUCTS ${ZNL_BYPRODUCTS})
list(APPEND ZENDNNL_CLEAN_FILES "${ZENDNNL_BINARY_DIR}")
list(APPEND ZENDNNL_CLEAN_FILES "${ZENDNNL_INSTALL_PREFIX}")
set_target_properties(fwk_zendnnl
PROPERTIES
ADDITIONAL_CLEAN_FILES "${ZENDNNL_CLEAN_FILES}")
# framework dependencies
# add_dependencies(fwk_zendnnl <injected dependency targets>)
get_target_property(FWK_ZENDNNL_DEPENDS fwk_zendnnl MANUALLY_ADDED_DEPENDENCIES)
if(${FWK_ZENDNNL_DEPENDS} STREQUAL "FWK_ZENDNNL_DEPENDS-NOTFOUND")
message(AUTHOR_WARNING "(ZENDNNL) please ensure fwk_zendnnl depends on injected dependencies targets")
else()
message(STATUS "fwk_zendnnl dependencies : ${FWK_ZENDNNL_DEPENDS}")
endif()
# make library and its dependencies depend on fwk_zendnnl
add_dependencies(zendnnl_library fwk_zendnnl)
add_dependencies(zendnnl_json_deps fwk_zendnnl)
add_dependencies(zendnnl_aoclutils_deps fwk_zendnnl)
add_dependencies(zendnnl_aucpuid_deps fwk_zendnnl)
if(ZENDNNL_DEPENDS_AMDBLIS)
add_dependencies(zendnnl_amdblis_deps fwk_zendnnl)
endif()
if(ZENDNNL_DEPENDS_AOCLDLP)
add_dependencies(zendnnl_aocldlp_deps fwk_zendnnl)
endif()
if(ZENDNNL_DEPENDS_ONEDNN)
add_dependencies(zendnnl_onednn_deps fwk_zendnnl)
endif()
if(ZENDNNL_DEPENDS_LIBXSMM)
add_dependencies(zendnnl_libxsmm_deps fwk_zendnnl)
endif()
endif()
set(ZENDNN_FOUND TRUE)
endif(NOT ZENDNN_FOUND)

View File

@ -148,7 +148,6 @@ function(caffe2_print_configuration_summary)
message(STATUS " USE_PYTORCH_METAL_EXPORT : ${USE_PYTORCH_METAL_EXPORT}")
message(STATUS " USE_MPS : ${USE_MPS}")
message(STATUS " CAN_COMPILE_METAL : ${CAN_COMPILE_METAL}")
message(STATUS " USE_ZENDNN : ${USE_ZENDNN}")
message(STATUS " USE_MKL : ${CAFFE2_USE_MKL}")
if(${CAFFE2_USE_MKL})
message(STATUS " USE_STATIC_MKL : ${USE_STATIC_MKL}")

View File

@ -1,8 +0,0 @@
if(NOT EXISTS ${PROJECT_SOURCE_DIR}/third_party/ZenDNN)
message(WARNING "(ZENDNNL) Library not found at ${PROJECT_SOURCE_DIR}/third_party/ZenDNN")
else()
find_package(ZENDNN QUIET)
if(ZENDNN_FOUND)
message(STATUS, "(ZENDNN) ZenDNN library was built successfully.")
endif(ZENDNN_FOUND)
endif()

View File

@ -67,9 +67,6 @@
# USE_NUMPY=0
# disables the NumPy build
#
# USE_ZENDNN=0
# disables the ZenDNN build
#
# BUILD_TEST=0
# disables the test build
#
@ -1193,10 +1190,6 @@ class build_ext(setuptools.command.build_ext.build_ext):
report("-- Not using CBLAS in MKLDNN")
else:
report("-- Not using MKLDNN")
if cmake_cache_vars["USE_ZENDNN"]:
report("-- Using ZENDNN")
else:
report("-- Not using ZENDNN")
if cmake_cache_vars["USE_NCCL"] and cmake_cache_vars["USE_SYSTEM_NCCL"]:
report(
"-- Using system provided NCCL library at "

View File

@ -424,7 +424,7 @@ from user code:
@torch.compile(backend="eager")
def fn(x):
d = {"a": 1}
optree.tree_flatten(d)
optree.tree_flatten_with_path(d)
return torch.sin(x)
fn(torch.randn(4))
@ -434,10 +434,10 @@ from user code:
first_graph_break,
"""\
Attempted to call function marked as skipped
Explanation: Dynamo cannot trace optree C/C++ function optree._C.PyCapsule.flatten.
Explanation: Dynamo cannot trace optree C/C++ function optree._C.PyCapsule.flatten_with_path.
Hint: Consider using torch.utils._pytree - https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py
Developer debug context: module: optree._C, qualname: PyCapsule.flatten, skip reason: <missing reason>
Developer debug context: module: optree._C, qualname: PyCapsule.flatten_with_path, skip reason: <missing reason>
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0007.html""",
)

View File

@ -110,6 +110,7 @@ if python_pytree._cxx_pytree_dynamo_traceable:
import torch.utils._cxx_pytree as cxx_pytree
pytree_modules["cxx"] = cxx_pytree
pytree_modules["native_optree"] = cxx_pytree.optree
else:
cxx_pytree = None
@ -12862,6 +12863,9 @@ class MiscTestsPyTree(torch._inductor.test_case.TestCase):
def fn(xs):
flat_xs, spec = pytree.tree_flatten(xs)
res = [x.clone() for x in flat_xs]
if pytree.__name__ == "optree":
# The treespec argument comes first in OpTree / JAX PyTree
return pytree.tree_unflatten(spec, res)
return pytree.tree_unflatten(res, spec)
xs = [torch.tensor(i) for i in range(3)]
@ -12876,6 +12880,9 @@ class MiscTestsPyTree(torch._inductor.test_case.TestCase):
def fn(xs):
flat_xs, spec = pytree.tree_flatten(xs)
res = [x.clone() for x in flat_xs]
if pytree.__name__ == "optree":
# The treespec argument comes first in OpTree / JAX PyTree
return pytree.tree_unflatten(spec, res)
return pytree.tree_unflatten(res, spec)
xs = [torch.tensor(i) for i in range(3)]
@ -12893,6 +12900,9 @@ class MiscTestsPyTree(torch._inductor.test_case.TestCase):
def fn(xs):
flat_xs, spec = pytree.tree_flatten(xs)
res = [x.clone() for x in flat_xs]
if pytree.__name__ == "optree":
# The treespec argument comes first in OpTree / JAX PyTree
return pytree.tree_unflatten(spec, res)
return pytree.tree_unflatten(res, spec)
xs = [torch.tensor(i) for i in range(3)]
@ -12910,6 +12920,9 @@ class MiscTestsPyTree(torch._inductor.test_case.TestCase):
def fn(xs):
flat_xs, spec = pytree.tree_flatten(xs)
res = [x.clone() for x in flat_xs]
if pytree.__name__ == "optree":
# The treespec argument comes first in OpTree / JAX PyTree
return pytree.tree_unflatten(spec, res)
return pytree.tree_unflatten(res, spec)
xs = [torch.tensor(i) for i in range(3)]
@ -12931,6 +12944,9 @@ class MiscTestsPyTree(torch._inductor.test_case.TestCase):
def fn(xs):
flat_xs, spec = pytree.tree_flatten(xs)
res = [x.clone() for x in flat_xs]
if pytree.__name__ == "optree":
# The treespec argument comes first in OpTree / JAX PyTree
return pytree.tree_unflatten(spec, res)
return pytree.tree_unflatten(res, spec)
xs = [torch.tensor(i) for i in range(3)]
@ -13032,7 +13048,13 @@ class MiscTestsPyTree(torch._inductor.test_case.TestCase):
torch.ones(3, 2),
1,
]
new_tree = pytree.tree_unflatten(new_leaves, treespec)
if pytree.__name__ == "optree":
# `None` is a internal node rather than leaf in default OpTree / JAX PyTree
new_leaves.pop()
# The treespec argument comes first in OpTree / JAX PyTree
new_tree = pytree.tree_unflatten(treespec, new_leaves)
else:
new_tree = pytree.tree_unflatten(new_leaves, treespec)
return leaves, new_tree
x = torch.randn(3, 2)
@ -13087,6 +13109,10 @@ class MiscTestsPyTree(torch._inductor.test_case.TestCase):
@parametrize_pytree_module
def test_pytree_tree_map_only(self, pytree):
if not callable(getattr(pytree, "tree_map_only", None)):
# OpTree and JAX PyTree do not have `tree_map_only`
return
def fn(xs):
def mapper(x):
return x.clone()

View File

@ -9986,6 +9986,20 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
self.assertEqual(result_triu_min, expected_triu_min)
self.assertEqual(result_tril_min, expected_tril_min)
@dtypes(torch.float)
def test_triu_tril_inplace_memory_overlap(self, device, dtype):
base = torch.rand((), dtype=dtype, device=device)
expanded = base.expand(3, 3)
msg = (
"unsupported operation: more than one element of the written-to tensor "
"refers to a single memory location. Please clone() the tensor before "
"performing the operation."
)
with self.assertRaisesRegex(RuntimeError, msg):
expanded.triu_(1)
with self.assertRaisesRegex(RuntimeError, msg):
expanded.tril_(-1)
@dtypes(torch.float, torch.double)
@precisionOverride({torch.float32: 1e-4})
def test_1_sized_with_0_strided(self, device, dtype):

View File

@ -140,7 +140,6 @@ class TestPublicBindings(TestCase):
"has_mps",
"has_openmp",
"has_spectral",
"has_zendnn",
"iinfo",
"import_ir_module_from_buffer",
"import_ir_module",

third_party/ZenDNN (vendored): 1 changed line

Submodule third_party/ZenDNN deleted from af92954683

View File

@ -6,7 +6,7 @@ from __future__ import annotations
from collections import deque
from dataclasses import dataclass, field
from typing import Any, Callable, Literal, TYPE_CHECKING
from typing import Any, Callable, TYPE_CHECKING
from typing_extensions import TypeIs
import torch.utils._pytree as python_pytree
@ -28,7 +28,7 @@ if python_pytree._cxx_pytree_dynamo_traceable:
import optree
import optree._C
import torch.utils._cxx_pytree as cxx_pytree
import torch.utils._cxx_pytree as cxx_pytree # noqa: F401
if TYPE_CHECKING:
from torch.utils._cxx_pytree import PyTree
@ -64,45 +64,69 @@ if python_pytree._cxx_pytree_dynamo_traceable:
del __func
del __name
@substitute_in_graph(cxx_pytree.tree_is_leaf, can_constant_fold_through=True)
@substitute_in_graph(optree.tree_is_leaf, can_constant_fold_through=True)
def tree_is_leaf(
tree: PyTree,
/,
is_leaf: Callable[[PyTree], bool] | None = None,
*,
none_is_leaf: bool = False,
namespace: str = "",
) -> bool:
if tree is None or (is_leaf is not None and is_leaf(tree)):
if (tree is None and none_is_leaf) or (is_leaf is not None and is_leaf(tree)):
return True
if optree.register_pytree_node.get(type(tree), namespace="torch") is None: # type: ignore[attr-defined]
if optree.register_pytree_node.get(type(tree), namespace=namespace) is None: # type: ignore[attr-defined]
return True
return False
@substitute_in_graph(cxx_pytree.tree_iter, can_constant_fold_through=False)
@substitute_in_graph(optree.tree_iter, can_constant_fold_through=False)
def tree_iter(
tree: PyTree,
/,
is_leaf: Callable[[PyTree], bool] | None = None,
*,
none_is_leaf: bool = False,
namespace: str = "",
) -> Iterable[Any]:
stack = [tree]
while stack:
node = stack.pop()
if tree_is_leaf(node, is_leaf=is_leaf):
if tree_is_leaf(
node,
is_leaf=is_leaf,
none_is_leaf=none_is_leaf,
namespace=namespace,
):
yield node
continue
children, *_ = optree.tree_flatten_one_level(
node,
is_leaf=is_leaf,
none_is_leaf=True,
namespace="torch",
none_is_leaf=none_is_leaf,
namespace=namespace,
)
stack.extend(reversed(children))
__all__ += ["tree_iter"]
@substitute_in_graph(cxx_pytree.tree_leaves, can_constant_fold_through=True)
@substitute_in_graph(optree.tree_leaves, can_constant_fold_through=True)
def tree_leaves(
tree: PyTree,
/,
is_leaf: Callable[[PyTree], bool] | None = None,
*,
none_is_leaf: bool = False,
namespace: str = "",
) -> list[Any]:
return list(tree_iter(tree, is_leaf=is_leaf))
return list(
tree_iter(
tree,
is_leaf=is_leaf,
none_is_leaf=none_is_leaf,
namespace=namespace,
)
)
__all__ += ["tree_leaves"]
@ -127,12 +151,12 @@ if python_pytree._cxx_pytree_dynamo_traceable:
_metadata: Any
_entries: tuple[Any, ...]
_unflatten_func: Callable[[Any | None, Iterable[PyTree]], PyTree] | None
none_is_leaf: bool
namespace: str
num_nodes: int = field(init=False)
num_leaves: int = field(init=False)
num_children: int = field(init=False)
none_is_leaf: Literal[True] = field(init=False)
namespace: Literal["torch"] = field(init=False)
def __post_init__(self) -> None:
if self._type is None:
@ -152,8 +176,6 @@ if python_pytree._cxx_pytree_dynamo_traceable:
object.__setattr__(self, "num_nodes", num_nodes)
object.__setattr__(self, "num_leaves", num_leaves)
object.__setattr__(self, "num_children", num_children)
object.__setattr__(self, "none_is_leaf", True)
object.__setattr__(self, "namespace", "torch")
def __repr__(self) -> str:
def helper(treespec: PyTreeSpec) -> str:
@ -168,6 +190,7 @@ if python_pytree._cxx_pytree_dynamo_traceable:
]
if (
treespec.type in BUILTIN_TYPES
or (treespec.type is type(None) and not self.none_is_leaf)
or optree.is_namedtuple_class(treespec.type)
or optree.is_structseq_class(treespec.type)
):
@ -181,9 +204,12 @@ if python_pytree._cxx_pytree_dynamo_traceable:
f"[{', '.join(children_representations)}])"
)
return (
f"PyTreeSpec({helper(self)}, NoneIsLeaf, namespace={self.namespace!r})"
)
inner = [
str(helper(self)),
*(["NoneIsLeaf"] if self.none_is_leaf else []),
f"namespace={self.namespace!r}",
]
return f"PyTreeSpec({', '.join(inner)})"
def __len__(self) -> int:
return self.num_leaves
@ -228,8 +254,8 @@ if python_pytree._cxx_pytree_dynamo_traceable:
children, metadata, *_ = optree.tree_flatten_one_level(
node,
none_is_leaf=True,
namespace="torch",
none_is_leaf=self.none_is_leaf,
namespace=self.namespace,
)
if len(children) != treespec.num_children:
raise ValueError(
@ -277,8 +303,8 @@ if python_pytree._cxx_pytree_dynamo_traceable:
# node_type is treespec.type
children, metadata, *_ = optree.tree_flatten_one_level(
node,
none_is_leaf=True,
namespace="torch",
none_is_leaf=self.none_is_leaf,
namespace=self.namespace,
)
if (
node_type
@ -320,25 +346,40 @@ if python_pytree._cxx_pytree_dynamo_traceable:
assert callable(self._unflatten_func)
return self._unflatten_func(self._metadata, subtrees)
_LEAF_SPEC = PyTreeSpec((), None, None, (), None)
def _is_pytreespec_instance(obj: Any, /) -> TypeIs[PyTreeSpec]:
return isinstance(obj, PyTreeSpec)
@substitute_in_graph( # type: ignore[arg-type]
cxx_pytree.tree_flatten,
optree.tree_flatten,
# We need to disable constant folding here because we want the function to reference the
# PyTreeSpec class defined above, not the one in the C++ module.
can_constant_fold_through=False,
)
def tree_flatten(
tree: PyTree,
/,
is_leaf: Callable[[PyTree], bool] | None = None,
*,
none_is_leaf: bool = False,
namespace: str = "",
) -> tuple[list[Any], PyTreeSpec]:
def helper(node: PyTree, leaves: list[Any]) -> PyTreeSpec:
if tree_is_leaf(node, is_leaf=is_leaf):
if tree_is_leaf(
node,
is_leaf=is_leaf,
none_is_leaf=none_is_leaf,
namespace=namespace,
):
leaves.append(node)
return _LEAF_SPEC
return PyTreeSpec(
(),
None,
None,
(),
None,
none_is_leaf=none_is_leaf,
namespace=namespace,
)
(
children,
@ -348,13 +389,21 @@ if python_pytree._cxx_pytree_dynamo_traceable:
) = optree.tree_flatten_one_level(
node,
is_leaf=is_leaf,
none_is_leaf=True,
namespace="torch",
none_is_leaf=none_is_leaf,
namespace=namespace,
)
# Recursively flatten the children
subspecs = tuple(helper(child, leaves) for child in children)
return PyTreeSpec(subspecs, type(node), metadata, entries, unflatten_func) # type: ignore[arg-type]
return PyTreeSpec(
subspecs,
type(node),
metadata,
entries,
unflatten_func,
none_is_leaf=none_is_leaf,
namespace=namespace,
) # type: ignore[arg-type]
leaves: list[Any] = []
treespec = helper(tree, leaves)
@ -363,26 +412,35 @@ if python_pytree._cxx_pytree_dynamo_traceable:
__all__ += ["tree_flatten"]
@substitute_in_graph( # type: ignore[arg-type]
cxx_pytree.tree_structure,
optree.tree_structure,
# We need to disable constant folding here because we want the function to reference the
# PyTreeSpec class defined above, not the one in the C++ module.
can_constant_fold_through=False,
)
def tree_structure(
tree: PyTree,
/,
is_leaf: Callable[[PyTree], bool] | None = None,
*,
none_is_leaf: bool = False,
namespace: str = "",
) -> PyTreeSpec:
return tree_flatten(tree, is_leaf=is_leaf)[1] # type: ignore[return-value]
return tree_flatten( # type: ignore[return-value]
tree,
is_leaf=is_leaf,
none_is_leaf=none_is_leaf,
namespace=namespace,
)[1]
__all__ += ["tree_structure"]
@substitute_in_graph( # type: ignore[arg-type]
cxx_pytree.tree_unflatten,
optree.tree_unflatten,
# We need to disable constant folding here because we want the function to reference the
# PyTreeSpec class defined above, not the one in the C++ module.
can_constant_fold_through=False,
)
def tree_unflatten(leaves: Iterable[Any], treespec: PyTreeSpec) -> PyTree:
def tree_unflatten(treespec: PyTreeSpec, leaves: Iterable[Any]) -> PyTree:
if not _is_pytreespec_instance(treespec):
raise TypeError(
f"tree_unflatten(leaves, treespec): Expected `treespec` to be instance of "
@ -392,29 +450,57 @@ if python_pytree._cxx_pytree_dynamo_traceable:
__all__ += ["tree_unflatten"]
@substitute_in_graph(cxx_pytree.tree_map, can_constant_fold_through=True)
@substitute_in_graph(optree.tree_map, can_constant_fold_through=True)
def tree_map(
func: Callable[..., Any],
tree: PyTree,
/,
*rests: PyTree,
is_leaf: Callable[[PyTree], bool] | None = None,
none_is_leaf: bool = False,
namespace: str = "",
) -> PyTree:
leaves, treespec = tree_flatten(tree, is_leaf=is_leaf)
leaves, treespec = tree_flatten(
tree,
is_leaf=is_leaf,
none_is_leaf=none_is_leaf,
namespace=namespace,
)
flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
return treespec.unflatten(map(func, *flat_args))
__all__ += ["tree_map"]
@substitute_in_graph(cxx_pytree.tree_map_, can_constant_fold_through=True)
@substitute_in_graph(optree.tree_map_, can_constant_fold_through=True)
def tree_map_(
func: Callable[..., Any],
tree: PyTree,
/,
*rests: PyTree,
is_leaf: Callable[[PyTree], bool] | None = None,
none_is_leaf: bool = False,
namespace: str = "",
) -> PyTree:
leaves, treespec = tree_flatten(tree, is_leaf=is_leaf)
leaves, treespec = tree_flatten(
tree,
is_leaf=is_leaf,
none_is_leaf=none_is_leaf,
namespace=namespace,
)
flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
deque(map(func, *flat_args), maxlen=0) # consume and exhaust the iterable
return tree
__all__ += ["tree_map_"]
_none_unflatten = optree.register_pytree_node.get(type(None)).unflatten_func # type: ignore[union-attr]
@substitute_in_graph( # type: ignore[arg-type]
_none_unflatten,
can_constant_fold_through=True,
skip_signature_check=True,
)
def none_unflatten(_: None, children: Iterable[Any], /) -> None:
if len(list(children)) != 0:
raise ValueError("Expected no children.")
return None

View File

@ -264,7 +264,8 @@ class OutputComparisonLogger(OutputLogger):
# fmt: on
if not self.enabled:
return x
assert isinstance(x, torch.Tensor), "non-tensor inputs not yet supported"
if not isinstance(x, torch.Tensor):
raise AssertionError("non-tensor inputs not yet supported")
if self.save_activations:
# save the activation, for debugging
self.stats.append(x.detach())
@ -595,9 +596,8 @@ def _extract_logger_info_one_model(
key = mod.ref_name
if key not in results:
results[key] = {}
assert mod.model_name not in results[key], (
f"{mod.model_name} is already present in results"
)
if mod.model_name in results[key]:
raise AssertionError(f"{mod.model_name} is already present in results")
if mod.results_type not in results[key]:
results[key][mod.results_type] = {}
if mod.model_name not in results[key][mod.results_type]:
@ -809,12 +809,10 @@ def extend_logger_results_with_comparison(
"""
for results_type_to_results in results.values():
for model_name_to_results in results_type_to_results.values():
assert model_name_1 in model_name_to_results, (
f"{model_name_1} not found in results"
)
assert model_name_2 in model_name_to_results, (
f"{model_name_2} not found in results"
)
if model_name_1 not in model_name_to_results:
raise AssertionError(f"{model_name_1} not found in results")
if model_name_2 not in model_name_to_results:
raise AssertionError(f"{model_name_2} not found in results")
results_1 = model_name_to_results[model_name_1]
results_2 = model_name_to_results[model_name_2]
@ -832,7 +830,8 @@ def extend_logger_results_with_comparison(
):
result_1 = cur_result_1
break
assert result_1 is not None
if result_1 is None:
raise AssertionError("Expected result_1 to be not None")
values_1 = result_1["values"]
values_2 = result_2["values"]

View File

@ -150,7 +150,8 @@ class _NSGraphMatchableSubgraphsIterator:
if node.op == "call_function":
return node.target not in self.non_matchable_functions
elif node.op == "call_module":
assert isinstance(node.target, str)
if not isinstance(node.target, str):
raise AssertionError(f"Expected str, got {type(node.target)}")
target_mod = getattr_from_fqn(self.gm, node.target)
return not any(
isinstance(target_mod, t) # type: ignore[arg-type]
@ -228,16 +229,19 @@ def _get_subgraph_relationship_type(
else:
return SubgraphTypeRelationship.NOT_RELATED
elif node_a.op == "call_module":
assert (
subgraph_a.base_op_node == subgraph_a.start_node
and subgraph_b.base_op_node == subgraph_b.start_node
), (
"Matching call_module patterns where base_op_node != start_node is not supported yet"
)
if (
subgraph_a.base_op_node != subgraph_a.start_node
or subgraph_b.base_op_node != subgraph_b.start_node
):
raise AssertionError(
"Matching call_module patterns where base_op_node != start_node is not supported yet"
)
# for call_module, we need to look up the modules to do the type check
assert isinstance(node_a.target, str)
if not isinstance(node_a.target, str):
raise AssertionError(f"Expected str, got {type(node_a.target)}")
mod_a = getattr_from_fqn(gm_a, node_a.target)
assert isinstance(node_b.target, str)
if not isinstance(node_b.target, str):
raise AssertionError(f"Expected str, got {type(node_b.target)}")
mod_b = getattr_from_fqn(gm_b, node_b.target)
key = (type(mod_a), type(mod_b))
@ -312,7 +316,8 @@ def _get_node_target_type(node: Node, gm: GraphModule) -> Optional[NSNodeTargetT
if node.op in ("call_function", "call_method"):
return node.target
elif node.op == "call_module":
assert isinstance(node.target, str)
if not isinstance(node.target, str):
raise AssertionError(f"Expected str, got {type(node.target)}")
mod = getattr_from_fqn(gm, node.target)
return type(mod)
return None
@ -452,9 +457,10 @@ of subgraphs, and each pair of subgraphs is related to each other."""
key_name_b = _get_name_for_subgraph(
cur_subgraph_b, gm_b, base_name_to_sets_of_related_ops, existing_names_b
)
assert key_name_a == key_name_b, (
f"Subgraph names {key_name_a} and {key_name_b} do not match"
)
if key_name_a != key_name_b:
raise AssertionError(
f"Subgraph names {key_name_a} and {key_name_b} do not match"
)
results[key_name_a] = (cur_subgraph_a, cur_subgraph_b)
continue
elif cur_subgraph_a is None and cur_subgraph_b is None:

View File

@ -32,7 +32,8 @@ def _maybe_get_fqn(node: Node, gm: GraphModule) -> Optional[str]:
# an observer, get the fqn of the node being observed.
node_to_use_for_fqn = node
if node.op == "call_module":
assert isinstance(node.target, str)
if not isinstance(node.target, str):
raise AssertionError(f"Expected str, got {type(node.target)}")
module = getattr_from_fqn(gm, node.target)
if _is_activation_post_process(module):
node_to_use_for_fqn = get_normalized_nth_input(node, gm, 0)
@ -348,7 +349,8 @@ def _insert_dtype_cast_after_node(
new_dtype_cast_name,
)
else:
assert dtype_cast_mod_cls
if not dtype_cast_mod_cls:
raise AssertionError("Expected dtype_cast_mod_cls to be not None")
dtype_cast_mod = dtype_cast_mod_cls()
setattr(gm_b, new_dtype_cast_name, dtype_cast_mod)
return graph_c.create_node(
@ -373,7 +375,8 @@ def _insert_dtype_cast_after_node(
)
results.append(new_dtype_cast_node)
else:
assert dtype_cast_mod_cls
if not dtype_cast_mod_cls:
raise AssertionError("Expected dtype_cast_mod_cls to be not None")
dtype_cast_mod = dtype_cast_mod_cls()
setattr(gm_b, new_dtype_cast_name, dtype_cast_mod)
new_dtype_cast_node = graph_c.create_node(
@ -412,10 +415,8 @@ def _copy_node_from_a_to_c(
)
return node_a_copy
elif node_a.op == "call_method":
assert node_a.target in (
"dequantize",
"to",
), f"target {node_a.target} is not implemented"
if node_a.target not in ("dequantize", "to"):
raise AssertionError(f"target {node_a.target} is not implemented")
if node_a.target == "dequantize":
arg_copy = _copy_node_from_a_to_c(
get_normalized_nth_input(node_a, gm_a, 0), gm_a, gm_b, graph_c
@ -535,7 +536,8 @@ def _insert_copy_of_subgraph_a_after_input_node_c(
"""
TODO(before land): real docblock
"""
assert isinstance(input_node_c, (Node, list))
if not isinstance(input_node_c, (Node, list)):
raise AssertionError(f"Expected Node or list, got {type(input_node_c)}")
# create a sequential list of the subgraphs' nodes from start to end,
# because we need to add the nodes to graph C in non-reverse order
@ -621,7 +623,8 @@ def _insert_copy_of_node_a_after_input_node_c(
if isinstance(input_node_c, Node):
graph_c = input_node_c.graph
else:
assert isinstance(input_node_c, list)
if not isinstance(input_node_c, list):
raise AssertionError(f"Expected list, got {type(input_node_c)}")
graph_c = input_node_c[0].graph
norm_args_kwargs = node_a.normalized_arguments(
@ -645,9 +648,10 @@ def _insert_copy_of_node_a_after_input_node_c(
return arg
elif isinstance(kwarg_val, (list, tuple)):
for el in kwarg_val:
assert not isinstance(el, Node), (
"handling of Node inside list is not implemented"
)
if isinstance(el, Node):
raise AssertionError(
"handling of Node inside list is not implemented"
)
return arg
else:
raise AssertionError(
@ -684,7 +688,8 @@ def _insert_copy_of_node_a_after_input_node_c(
# if target is a module, we point to the module from gm_b
new_mod_copy_name = get_new_attr_name_with_prefix(node_name_prefix)(gm_b)
# fetch the corresponding module from gm_a
assert isinstance(node_a.target, str)
if not isinstance(node_a.target, str):
raise AssertionError(f"Expected str, got {type(node_a.target)}")
mod_a = getattr_from_fqn(gm_a, node_a.target)
setattr(gm_b, new_mod_copy_name, mod_a)
node_a_shadows_c = graph_c.create_node(
@ -696,7 +701,8 @@ def _insert_copy_of_node_a_after_input_node_c(
)
return node_a_shadows_c
else:
assert node_a.op in ("call_function", "call_method")
if node_a.op not in ("call_function", "call_method"):
raise AssertionError(f"Unexpected op: {node_a.op}")
node_a_shadows_c = graph_c.create_node(
node_a.op,
node_a.target,
@ -791,7 +797,8 @@ def create_a_shadows_b(
ref_node_type_b,
) = start_node_b_to_matched_subgraph_a_and_name[node_b]
else:
assert node_b_is_end_node
if not node_b_is_end_node:
raise AssertionError("Expected node_b_is_end_node to be not false")
(
subgraph_a,
ref_name,
@ -1001,7 +1008,10 @@ def create_a_shadows_b(
)
input_logger: Union[Node, list[Node]] = dtype_cast_node
else:
assert isinstance(dtype_cast_node, list)
if not isinstance(dtype_cast_node, list):
raise AssertionError(
f"Expected list, got {type(dtype_cast_node)}"
)
new_loggers = []
for dtype_cast_idx, dtype_cast_node_inner in enumerate(
dtype_cast_node
@ -1083,7 +1093,10 @@ def create_a_shadows_b(
input_logger_mod.ref_node_name = cur_node.name
else:
# pyrefly: ignore # unbound-name
assert isinstance(input_logger, list)
if not isinstance(input_logger, list):
raise AssertionError(
f"Expected list, got {type(input_logger)}"
)
# pyrefly: ignore # unbound-name
for input_logger_inner in input_logger:
input_logger_mod = getattr(gm_b, input_logger_inner.name)

View File

@ -144,9 +144,11 @@ def _get_dedup_subgraphs(matches: dict[str, _MatchResult]) -> dict[str, list[Nod
seen_nodes.add(node_or_tuple)
else:
assert isinstance(node_or_tuple, tuple)
if not isinstance(node_or_tuple, tuple):
raise AssertionError(f"Expected tuple, got {type(node_or_tuple)}")
for node in node_or_tuple:
assert isinstance(node, Node)
if not isinstance(node, Node):
raise AssertionError(f"Expected Node, got {type(node)}")
if node in seen_nodes:
was_seen = True
seen_nodes.add(node)
@ -160,7 +162,10 @@ def _get_dedup_subgraphs(matches: dict[str, _MatchResult]) -> dict[str, list[Nod
if len(cur_match[1]) == 1:
list_of_nodes = cur_match[1]
else:
assert len(cur_match[1]) == 2
if len(cur_match[1]) != 2:
raise ValueError(
f"Expected cur_match[1] to have length 2, got {len(cur_match[1])}"
)
# either (a, b), or ((a, b), c) or (c, (a, b))
# cannot make any assumptions on order, not clear what the
# _find_matches function is doing to populate this
@ -181,13 +186,12 @@ def _get_dedup_subgraphs(matches: dict[str, _MatchResult]) -> dict[str, list[Nod
last_node = n
else:
mid_node = n
assert (
first_node is not None
and mid_node is not None
and last_node is not None
)
assert mid_node.args[0] is first_node
assert last_node.args[0] is mid_node
if first_node is None or mid_node is None or last_node is None:
raise AssertionError("Expected all nodes to be non-None")
if mid_node.args[0] is not first_node:
raise AssertionError("Expected mid_node.args[0] to be first_node")
if last_node.args[0] is not mid_node:
raise AssertionError("Expected last_node.args[0] to be mid_node")
return [last_node, mid_node, first_node]
if isinstance(cur_match[1][0], Node) and isinstance(cur_match[1][1], Node):
@ -377,7 +381,10 @@ def create_submodule_from_subgraph(
# the current implementation is simplistic and cannot handle
# ops with two or more arguments which need to be passed from
# the previous op, so we assert them out
assert cur_node_orig.target not in BINARY_FUNCTIONS
if cur_node_orig.target in BINARY_FUNCTIONS:
raise AssertionError(
f"Unexpected binary function target: {cur_node_orig.target}"
)
# at this point in the code, cur_node_copy is pointing to the copy
# of the previous node
@ -435,9 +442,10 @@ def create_submodule_from_subgraph(
break
# go to next node
assert len(cur_node_orig.users.keys()) == 1, (
f"{cur_node_orig} has more than 1 users, not supported yet"
)
if len(cur_node_orig.users.keys()) != 1:
raise AssertionError(
f"{cur_node_orig} has more than 1 users, not supported yet"
)
cur_node_orig = next(iter(cur_node_orig.users.keys()))
cur_iteration += 1
if cur_iteration > iteration_limit:
@ -494,7 +502,8 @@ def create_one_transformed_and_logged_copy_of_subgraph(
)
attr_name = _get_attr_name(subgraph_idx, subgraph_candidate_idx)
assert not hasattr(mt, attr_name)
if hasattr(mt, attr_name):
raise AssertionError(f"Unexpected attribute '{attr_name}' found in {mt}")
setattr(mt, attr_name, logger_mod_orig)
with mt.graph.inserting_after(last_node):
new_node = mt.graph.call_module(attr_name, args=(last_node,), kwargs={})
@ -537,9 +546,10 @@ def create_one_transformed_and_logged_copy_of_subgraph(
"prepare_custom_config",
"qconfig_mapping",
]:
assert kwarg_name not in custom_prepare_kwargs, (
f"cannot specify {kwarg_name} in custom_prepare_kwargs"
)
if kwarg_name in custom_prepare_kwargs:
raise AssertionError(
f"cannot specify {kwarg_name} in custom_prepare_kwargs"
)
prepare_kwargs: dict[str, Any] = {
"example_inputs": example_inputs,
"qconfig_mapping": qconfig_mapping,
@ -551,7 +561,8 @@ def create_one_transformed_and_logged_copy_of_subgraph(
# attach the wrapper to the model
attr_name = _get_attr_wrapper_name(subgraph_idx, subgraph_candidate_idx)
assert not hasattr(mt, attr_name)
if hasattr(mt, attr_name):
raise AssertionError(f"Unexpected attribute '{attr_name}' found in {mt}")
setattr(mt, attr_name, orig_mod_copy_wrapped)
# add a call to the wrapper module from the parent graph
@ -600,7 +611,8 @@ def create_one_transformed_and_logged_copy_of_subgraph(
)
attr_name = _get_attr_name(subgraph_idx, subgraph_candidate_idx)
assert not hasattr(mt, attr_name)
if hasattr(mt, attr_name):
raise AssertionError(f"Unexpected attribute '{attr_name}' found in {mt}")
setattr(mt, attr_name, logger_mod_orig)
with mt.graph.inserting_after(new_node):
logger = mt.graph.call_module(
@ -824,7 +836,8 @@ def create_add_loggers_graph(
):
new_shadow_mod = maybe_shadow_mod
break
assert new_shadow_mod is not None
if new_shadow_mod is None:
raise AssertionError("Expected new_shadow_mod to be non-None")
orig_first_node_to_shadow_in_node[first_node] = new_shadow_mod
orig_first_node_to_shadow_out_node[first_node] = new_shadow_mod
@ -850,7 +863,10 @@ def create_add_loggers_graph(
fqn,
)
attr_name = _get_attr_name(cur_subgraph_idx, subgraph_candidate_idx)
assert not hasattr(model, attr_name)
if hasattr(model, attr_name):
raise AssertionError(
f"Unexpected attribute '{attr_name}' found in {model}"
)
setattr(model, attr_name, logger_mod_orig)
insertion_point = last_node
with model.graph.inserting_after(insertion_point):
@ -887,9 +903,15 @@ def create_add_loggers_graph(
# since now only linear subgraphs are supported, all nodes
# except the last one must have only one user
if cur_node_orig != last_node:
assert len(cur_node_orig.users.keys()) == 1
if len(cur_node_orig.users.keys()) != 1:
raise AssertionError(
f"Expected exactly 1, but got {len(cur_node_orig.users)}"
)
cur_node_orig = next(iter(cur_node_orig.users.keys()))
assert not cur_node_orig.name.startswith(SHADOW_NODE_NAME_PREFIX)
if cur_node_orig.name.startswith(SHADOW_NODE_NAME_PREFIX):
raise AssertionError(
"cur_node_orig should not start with SHADOW_NODE_NAME_PREFIX"
)
insertion_point = cur_node_copy
# add a comparison logger after last_node's copy
@ -905,7 +927,10 @@ def create_add_loggers_graph(
fqn,
)
attr_name = _get_attr_name(cur_subgraph_idx, subgraph_candidate_idx)
assert not hasattr(model, attr_name)
if hasattr(model, attr_name):
raise AssertionError(
f"Unexpected attribute '{attr_name}' found in {model}"
)
setattr(model, attr_name, logger_mod_orig)
with model.graph.inserting_after(insertion_point):
logger = model.graph.call_module(
@ -979,7 +1004,8 @@ def create_add_loggers_graph(
return prev_shadow_output
cur_shadow_input = orig_first_node_to_shadow_in_node[first_node]
assert cur_shadow_input is not None
if cur_shadow_input is None:
raise AssertionError("Expected cur_shadow_input to be non-None")
cur_shadow_input.args = tree_map(
maybe_remap_node_to_shadow, cur_shadow_input.args
)
@ -1019,7 +1045,8 @@ def _get_weight_info_from_shadow_wrapper(shadow_wrapper: torch.nn.Module):
# we have `w2_0`, and are navigating this subgraph
# to get `_input_scale_1` and `_input_zero_point_1`
assert len(shadow_n.users) == 1
if len(shadow_n.users) != 1:
raise AssertionError(f"Expected exactly 1, got {len(shadow_n.users)}")
quant_node = next(iter(shadow_n.users.keys()))
new_args: Any = None
if quant_node.target == torch.quantize_per_channel:
@ -1028,7 +1055,10 @@ def _get_weight_info_from_shadow_wrapper(shadow_wrapper: torch.nn.Module):
zp_val = getattr_from_fqn(shadow_wrapper, zp_node.target)
new_args = (scale_val, zp_val, axis, dtype)
else:
assert quant_node.target == torch.quantize_per_tensor
if quant_node.target != torch.quantize_per_tensor:
raise AssertionError(
f"Expected torch.quantize_per_tensor, but got {quant_node.target}"
)
_weight, scale_node, zp_node, dtype = quant_node.args
scale_val = getattr_from_fqn(shadow_wrapper, scale_node.target)
zp_val = getattr_from_fqn(shadow_wrapper, zp_node.target)
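The conversion pattern in this file is mechanical but has a real behavioral motivation: `assert` statements are stripped entirely when Python runs with `-O`, so an explicit `if`/`raise` keeps the check active in optimized mode, and raising `AssertionError` rather than another exception type presumably keeps any existing `except AssertionError` handlers working. A minimal self-contained sketch of the shape, with `mt` and `attr_name` standing in for the real arguments:

import types

mt = types.SimpleNamespace()       # stand-in for the traced GraphModule
attr_name = "shadow_wrapper_0_1"   # stand-in for _get_attr_wrapper_name(...)

# before: silently skipped under `python -O`
# assert not hasattr(mt, attr_name)

# after: always enforced, same exception type
if hasattr(mt, attr_name):
    raise AssertionError(f"Unexpected attribute '{attr_name}' found in {mt}")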

View File

@ -167,7 +167,8 @@ def end_node_matches_reversed_fusion(
elif cur_node.op == "call_module":
fusion_el_is_mod = isinstance(cur_fusion_el, type)
if fusion_el_is_mod:
assert isinstance(cur_node.target, str)
if not isinstance(cur_node.target, str):
raise AssertionError(f"Expected str, got {type(cur_node.target)}")
target_mod = getattr_from_fqn(gm, cur_node.target)
if not isinstance(cur_fusion_el, type):
return False
@ -190,7 +191,10 @@ def end_node_matches_reversed_fusion(
if cur_node.target != cur_fusion_el:
return False
else:
assert isinstance(cur_fusion_el, tuple)
if not isinstance(cur_fusion_el, tuple):
raise AssertionError(
f"Expected tuple, got {type(cur_fusion_el)}"
)
if cur_node.target != cur_fusion_el[0]:
return False
elif len(cur_node.args) < 2:
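One property of these `isinstance` conversions worth noting: static type checkers narrow through an `if not isinstance(...): raise` guard exactly as they do through `assert isinstance(...)`, because the `raise` terminates the failing branch. A small sketch that mypy accepts (the `Node` import is the real `torch.fx` one; `target_of` is a hypothetical helper for illustration):

from torch.fx import Node

def target_of(maybe_node: object) -> str:
    if not isinstance(maybe_node, Node):
        raise AssertionError(f"Expected Node, got {type(maybe_node)}")
    # maybe_node is narrowed to Node here, so .target is well-typed
    target = maybe_node.target
    if not isinstance(target, str):
        raise AssertionError(f"Expected str, got {type(target)}")
    return target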

View File

@ -61,7 +61,8 @@ def get_node_first_input_and_output_type(
return (NodeInputOrOutputType.INT8, NodeInputOrOutputType.INT8)
elif node.target in FUNS_IO_TYPE_FP32_OR_INT8:
first_arg = get_normalized_nth_input(node, gm, 0)
assert isinstance(first_arg, Node)
if not isinstance(first_arg, Node):
raise AssertionError(f"Expected Node, got {type(first_arg)}")
(
_prev_node_input_type,
prev_node_output_type,
@ -73,8 +74,11 @@ def get_node_first_input_and_output_type(
return (NodeInputOrOutputType.UNKNOWN, NodeInputOrOutputType.UNKNOWN)
elif node.op == "call_module":
assert node.op == "call_module"
assert isinstance(node.target, str)
if node.op != "call_module":
raise AssertionError(f"Expected call_module, got '{node.op}'")
if not isinstance(node.target, str):
raise AssertionError(f"Expected str, but got {type(node.target)}")
mod = getattr_from_fqn(gm, node.target)
is_known_fp32_or_int8_input_module = any(
isinstance(mod, target_type) # type: ignore[arg-type]
@ -87,7 +91,8 @@ def get_node_first_input_and_output_type(
# A logger or observer's input and output type is the output
# type of the preceding node.
first_arg = get_normalized_nth_input(node, gm, 0)
assert isinstance(first_arg, Node)
if not isinstance(first_arg, Node):
raise AssertionError(f"Expected Node, got {type(first_arg)}")
(
_prev_node_input_type,
prev_node_output_type,
@ -116,7 +121,8 @@ def get_node_first_input_and_output_type(
# So, we look up the output type of the previous node and return that
# as the input type of this node instance.
prev_node = get_normalized_nth_input(node, gm, 0)
assert isinstance(prev_node, Node)
if not isinstance(prev_node, Node):
raise AssertionError(f"Expected Node, got {type(prev_node)}")
(
_prev_node_input_type,
prev_node_output_type,
@ -131,7 +137,8 @@ def get_node_first_input_and_output_type(
# as the input type of this node instance. We also look up the target
# of to and return the correct output type.
prev_node = get_normalized_nth_input(node, gm, 0)
assert isinstance(prev_node, Node)
if not isinstance(prev_node, Node):
raise AssertionError(f"Expected Node, got {type(prev_node)}")
(
_prev_node_input_type,
prev_node_output_type,
@ -140,15 +147,17 @@ def get_node_first_input_and_output_type(
)
cur_node_dtype_target = get_normalized_nth_input(node, gm, 1)
assert cur_node_dtype_target is torch.float16, (
f"{cur_node_dtype_target} handling needs to be added"
)
if cur_node_dtype_target is not torch.float16:
raise AssertionError(
f"{cur_node_dtype_target} handling needs to be added"
)
return (prev_node_output_type, NodeInputOrOutputType.FP16)
elif node.target in METHS_IO_TYPE_FP32_OR_INT8:
first_arg = get_normalized_nth_input(node, gm, 0)
assert isinstance(first_arg, Node)
if not isinstance(first_arg, Node):
raise AssertionError(f"Expected Node, got {type(first_arg)}")
(
_prev_node_input_type,
prev_node_output_type,
@ -181,8 +190,14 @@ def get_node_input_qparams(
def _get_scale_zp_from_function_args(node, gm, scale_arg_idx, zp_arg_idx):
scale_node = get_normalized_nth_input(node, gm, scale_arg_idx)
zp_node = get_normalized_nth_input(node, gm, zp_arg_idx)
assert isinstance(scale_node, Node) and isinstance(scale_node.target, str)
assert isinstance(zp_node, Node) and isinstance(zp_node.target, str)
if not isinstance(scale_node, Node):
raise AssertionError(f"Expected Node, got {type(scale_node)}")
if not isinstance(scale_node.target, str):
raise AssertionError(f"Expected str, got {type(scale_node.target)}")
if not isinstance(zp_node, Node):
raise AssertionError(f"Expected Node, got {type(zp_node)}")
if not isinstance(zp_node.target, str):
raise AssertionError(f"Expected str, got {type(zp_node.target)}")
scale_obj = getattr_from_fqn(gm, scale_node.target)
zp_obj = getattr_from_fqn(gm, zp_node.target)
return (scale_obj, zp_obj)
@ -200,7 +215,8 @@ def get_node_input_qparams(
elif prev_node.op == "call_module":
# get type of the module
assert isinstance(prev_node.target, str)
if not isinstance(prev_node.target, str):
raise AssertionError(f"Expected str, got {type(prev_node.target)}")
module_obj = getattr_from_fqn(gm, prev_node.target)
if isinstance(
module_obj,
@ -259,15 +275,24 @@ def return_first_non_observer_node(
if node.op == "call_module":
node_obj = getattr_from_fqn(gm, node.target) # type: ignore[arg-type]
if _is_activation_post_process(node_obj):
assert len(node.args) == 1
assert isinstance(node.args[0], Node)
if len(node.args) != 1:
raise AssertionError(
f"Expected node.args to have length 1, got {len(node.args)}"
)
if not isinstance(node.args[0], Node):
raise AssertionError(f"Expected Node, got {type(node.args[0])}")
node = node.args[0]
# code duplication intended, not worth refactoring
assert isinstance(node.target, str)
if not isinstance(node.target, str):
raise AssertionError(f"Expected str, got {type(node.target)}")
node_obj = getattr_from_fqn(gm, node.target)
if _is_activation_post_process(node_obj):
assert len(node.args) == 1
assert isinstance(node.args[0], Node)
if len(node.args) != 1:
raise AssertionError(
f"Expected node.args to have length 1, got {len(node.args)}"
)
if not isinstance(node.args[0], Node):
raise AssertionError(f"Expected Node, got {type(node.args[0])}")
node = node.args[0]
return node
@ -331,7 +356,8 @@ def get_target_type_str(node: Node, gm: GraphModule) -> str:
if node.op in ("call_function", "call_method"):
target_type = torch.typename(node.target)
elif node.op == "call_module":
assert isinstance(node.target, str)
if not isinstance(node.target, str):
raise AssertionError(f"Expected str, got {type(node.target)}")
target_mod = getattr_from_fqn(gm, node.target)
target_type = torch.typename(target_mod)
return target_type
@ -365,7 +391,8 @@ def rekey_logger_info_on_node_name_of_model(
for model_name_to_results in result_type_to_results.values():
for cur_model_name, list_of_results in model_name_to_results.items():
if cur_model_name == model_name:
assert len(list_of_results)
if len(list_of_results) == 0:
raise AssertionError("Expected list_of_results to be not empty")
new_layer_name = list_of_results[0]["ref_node_name"]
else:
continue
@ -519,14 +546,20 @@ def get_normalized_nth_input(node: Node, gm: GraphModule, idx: int) -> Node:
)
if norm_args_and_kwargs is not None:
norm_args, norm_kwargs = norm_args_and_kwargs
assert len(norm_args) + len(norm_kwargs) > idx
if len(norm_args) + len(norm_kwargs) <= idx:
raise AssertionError(
f"Index {idx} out of range: total = {len(norm_args) + len(norm_kwargs)}"
)
if idx < len(norm_args):
return norm_args[idx]
else:
# note: in Python 3.7+ dicts are ordered
return list(norm_kwargs.values())[idx]
else:
assert len(node.args) + len(node.kwargs) > idx
if len(node.args) + len(node.kwargs) <= idx:
raise AssertionError(
f"Index {idx} out of range: total = {len(node.args) + len(node.kwargs)}"
)
if idx < len(node.args):
return node.args[idx] # type: ignore[return-value]
else:
@ -536,7 +569,10 @@ def get_normalized_nth_input(node: Node, gm: GraphModule, idx: int) -> Node:
# this RuntimeError happens when node argument normalization
# requires typehints to proceed, such as for torch.add where
# either the first, second or both arguments could be tensors
assert len(node.args) + len(node.kwargs) > idx
if len(node.args) + len(node.kwargs) <= idx:
raise AssertionError(
f"Index {idx} out of range: total = {len(node.args) + len(node.kwargs)}"
) from None
if idx < len(node.args):
return node.args[idx] # type: ignore[return-value]
else:
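The indexing convention in `get_normalized_nth_input` treats positional and keyword arguments as one flat, ordered sequence. A simplified standalone sketch of that lookup (`nth_input` is a hypothetical helper, not part of the API; the in-tree code indexes the kwargs list with the flat `idx`, which presumably relies on normalization leaving either all-positional or all-keyword inputs, while this sketch offsets by `len(args)` so mixed inputs also index correctly):

from typing import Any

def nth_input(args: tuple, kwargs: dict[str, Any], idx: int) -> Any:
    total = len(args) + len(kwargs)
    if total <= idx:
        raise AssertionError(f"Index {idx} out of range: total = {total}")
    if idx < len(args):
        return args[idx]
    # dicts preserve insertion order (Python 3.7+), so this is deterministic
    return list(kwargs.values())[idx - len(args)]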

View File

@ -77,7 +77,8 @@ def get_lstm_mod_weights(mod: nn.Module) -> list[torch.Tensor]:
res.append(param_value)
return res
else:
assert isinstance(mod, nnqd.LSTM), f"type {type(mod)} not handled yet"
if not isinstance(mod, nnqd.LSTM):
raise AssertionError(f"type {type(mod)} not handled yet")
res = []
for weight_value in mod._all_weight_values:
res.append(
@ -92,10 +93,13 @@ def get_lstm_mod_weights(mod: nn.Module) -> list[torch.Tensor]:
def get_conv_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor:
# traverse backwards from the weight arg, accounting for any observers
weight_arg_node = node.args[1]
assert isinstance(weight_arg_node, Node)
if not isinstance(weight_arg_node, Node):
raise AssertionError(f"Expected Node, got {type(weight_arg_node)}")
weight_node = return_first_non_observer_node(weight_arg_node, gm)
assert isinstance(weight_node, Node)
assert weight_node.op == "get_attr"
if not isinstance(weight_node, Node):
raise AssertionError(f"Expected Node, got {type(weight_node)}")
if weight_node.op != "get_attr":
raise AssertionError(f"Expected get_attr, got {weight_node.op}")
weight = getattr_from_fqn(gm, weight_node.target) # type: ignore[arg-type]
return weight.detach()
@ -103,8 +107,10 @@ def get_conv_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor:
def get_qconv_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor:
# qconv state is arg 1
qconv_state_node = node.args[1]
assert isinstance(qconv_state_node, Node)
assert qconv_state_node.op == "get_attr"
if not isinstance(qconv_state_node, Node):
raise AssertionError(f"Expected Node, got {type(qconv_state_node)}")
if qconv_state_node.op != "get_attr":
raise AssertionError(f"Expected get_attr, got {qconv_state_node.op}")
qconv_state_obj = getattr_from_fqn(gm, qconv_state_node.target) # type: ignore[arg-type]
return qconv_state_obj.weight()
@ -115,34 +121,44 @@ def get_linear_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor:
# weight -> obs -> linear
# weight -> to(torch.float16) -> dequantize -> linear
linear_second_arg = node.args[1]
assert isinstance(linear_second_arg, Node)
if not isinstance(linear_second_arg, Node):
raise AssertionError(f"Expected Node, got {type(linear_second_arg)}")
if linear_second_arg.op == "call_module":
# weight -> obs -> linear
weight_arg_node = node.args[1]
assert isinstance(weight_arg_node, Node)
if not isinstance(weight_arg_node, Node):
raise AssertionError(f"Expected Node, got {type(weight_arg_node)}")
weight_node = weight_arg_node.args[0]
assert isinstance(weight_node, Node)
assert weight_node.op == "get_attr"
if not isinstance(weight_node, Node):
raise AssertionError(f"Expected Node, got {type(weight_node)}")
if weight_node.op != "get_attr":
raise AssertionError(f"Expected get_attr, got {weight_node.op}")
weight = getattr_from_fqn(gm, weight_node.target) # type: ignore[arg-type]
return weight.detach()
elif linear_second_arg.op == "call_method":
# weight -> to(torch.float16) -> dequantize -> linear
assert linear_second_arg.op == "call_method"
if linear_second_arg.op != "call_method":
raise AssertionError(f"Expected call_method, got {linear_second_arg.op}")
dequant_node = node.args[1]
assert isinstance(dequant_node, Node)
if not isinstance(dequant_node, Node):
raise AssertionError(f"Expected Node, got {type(dequant_node)}")
to_fp16_node = dequant_node.args[0]
assert isinstance(to_fp16_node, Node)
if not isinstance(to_fp16_node, Node):
raise AssertionError(f"Expected Node, got {type(to_fp16_node)}")
# extract the dtype, so we can cast to it before returning
target_dtype = to_fp16_node.args[1]
weight_node = to_fp16_node.args[0]
assert isinstance(weight_node, Node)
assert weight_node.op == "get_attr"
if not isinstance(weight_node, Node):
raise AssertionError(f"Expected Node, got {type(weight_node)}")
if weight_node.op != "get_attr":
raise AssertionError(f"Expected get_attr, got {weight_node.op}")
weight = getattr_from_fqn(gm, weight_node.target) # type: ignore[arg-type]
# return the weight with fp16 cast
return weight.detach().to(target_dtype)
else:
assert linear_second_arg.op == "get_attr"
if linear_second_arg.op != "get_attr":
raise AssertionError(f"Expected get_attr, got {linear_second_arg.op}")
weight = getattr_from_fqn(gm, linear_second_arg.target) # type: ignore[arg-type]
return weight.detach()
@ -150,8 +166,10 @@ def get_linear_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor:
def get_qlinear_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor:
# packed weight is arg 1
packed_weight_node = node.args[1]
assert isinstance(packed_weight_node, Node)
assert packed_weight_node.op == "get_attr"
if not isinstance(packed_weight_node, Node):
raise AssertionError(f"Expected Node, got {type(packed_weight_node)}")
if packed_weight_node.op != "get_attr":
raise AssertionError(f"Expected get_attr, got {packed_weight_node.op}")
packed_weight = getattr_from_fqn(gm, packed_weight_node.target) # type: ignore[arg-type]
# TODO(future PR): why does packed_weight.unpack() not work?
(weight, _bias), _name = packed_weight.__getstate__()
@ -264,7 +282,8 @@ def extract_weight_from_node(
elif node.op == "call_module":
# for call_module, we need to look up the modules to do the type check
assert isinstance(node.target, str)
if not isinstance(node.target, str):
raise AssertionError(f"Expected str, got {type(node.target)}")
mod = getattr_from_fqn(gm, node.target)
module_mapping = op_to_type_to_weight_extraction_fn["call_module"]
for target_mod_type, weight_extraction_fn in module_mapping.items():
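`get_linear_fun_weight` above distinguishes three producer patterns for the weight argument of a functional linear, and the dispatch reduces to the `op` of `node.args[1]`. A condensed sketch of that decision (a hypothetical classifier, not the real helper):

from torch.fx import Node

def weight_source_kind(weight_arg: Node) -> str:
    if weight_arg.op == "call_module":   # weight -> observer -> linear
        return "observed"
    if weight_arg.op == "call_method":   # weight -> to(torch.float16) -> dequantize -> linear
        return "fp16_emulated"
    if weight_arg.op == "get_attr":      # weight -> linear
        return "plain"
    raise AssertionError(f"Expected get_attr, got {weight_arg.op}")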

View File

@ -2268,8 +2268,6 @@ Call this whenever a new thread is created in order to propagate values from
set_module_attr("has_lapack", at::hasLAPACK() ? Py_True : Py_False));
ASSERT_TRUE(set_module_attr(
"_has_eigen_sparse", at::hasEigenSparse() ? Py_True : Py_False));
ASSERT_TRUE(
set_module_attr("has_zendnn", at::hasZenDNN() ? Py_True : Py_False));
py_module.def("_valgrind_supported_platform", []() {
#if defined(USE_VALGRIND)

View File

@ -3259,7 +3259,7 @@ struct to_ir {
case TK_IN:
return aten::__contains__;
default:
throw std::runtime_error("unknown kind " + std::to_string(kind));
TORCH_CHECK(false, "unknown kind ", kind);
}
}
@ -3306,7 +3306,7 @@ struct to_ir {
case TK_RSHIFT:
return "__rshift__";
default:
throw std::runtime_error("unknown kind " + std::to_string(kind));
TORCH_CHECK(false, "unknown kind ", kind);
}
}
@ -4120,8 +4120,7 @@ struct to_ir {
} else if (kind == aten::ge) {
return aten::le;
}
throw std::runtime_error(
"reverseComparision: unsupported NodeKind. File a bug");
TORCH_CHECK(false, "reverseComparision: unsupported NodeKind. File a bug");
}
// any expression that can produce a SugaredValue is handled here

View File

@ -94,7 +94,7 @@ C10_EXPORT std::string kindToString(int kind) {
TC_FORALL_TOKEN_KINDS(DEFINE_CASE)
#undef DEFINE_CASE
default:
throw std::runtime_error("Unknown kind: " + std::to_string(kind));
TORCH_CHECK(false, "Unknown kind: ", kind);
}
}

View File

@ -167,12 +167,12 @@ Value* TracingState::getValue(const IValue& var) {
// Didn't find it. Bake in a constant
if (ten.requires_grad()) {
pauseTracing();
std::ostringstream oss;
oss << "Cannot insert a Tensor that requires grad as a constant. "
<< "Consider making it a parameter or input, or detaching the gradient\n"
<< "Tensor:\n"
<< ten;
throw std::runtime_error(oss.str());
TORCH_CHECK(
false,
"Cannot insert a Tensor that requires grad as a constant. ",
"Consider making it a parameter or input, or detaching the gradient\n",
"Tensor:\n",
ten);
}
Value* constant = graph->insertConstant(ten);
@ -208,15 +208,19 @@ Value* TracingState::getValue(const IValue& var) {
}
}
std::ostringstream oss;
if (var.isFuture()) {
oss << "Tried to trace Future or Object that the tracer was not aware of.";
TORCH_CHECK(
false,
"Tried to trace Future or Object that the tracer was not aware of.");
} else {
oss << "Tried to trace " << var
<< " but it is not part of the active trace. Modules that are called during a trace"
<< " must be registered as submodules of the thing being traced.";
TORCH_CHECK(
false,
"Tried to trace ",
var,
" but it is not part of the active trace. Modules that are called during a trace",
" must be registered as submodules of the thing being traced.");
}
throw std::runtime_error(oss.str());
} else {
// If the values are non-tensors, we try to create constants
// and bake those constants into the traced graph
@ -225,11 +229,12 @@ Value* TracingState::getValue(const IValue& var) {
recordSourceLocation(constant.value()->node());
return *constant;
}
std::ostringstream os;
os << "Tracer cannot get value trace for type " << var.tagKind() << ". "
<< "The below value could not be materialized as a constant:\n"
<< var;
throw std::runtime_error(os.str());
TORCH_CHECK(
false,
"Tracer cannot get value trace for type ",
var.tagKind(),
". The below value could not be materialized as a constant:\n",
var);
}
}
bool TracingState::hasValue(const IValue& var) const {
@ -252,15 +257,14 @@ Value* TracingState::getOutput(const IValue& iv, size_t i) {
auto& value_map = getTracingState()->env_stack.back();
auto it = value_map.find(iv);
if (it == value_map.end()) {
std::ostringstream os;
os << "output " << i << " (" << var
<< ") of traced region did not have observable "
<< "data dependence with trace inputs; this probably indicates your "
"program "
<< "cannot be understood by the tracer.";
throw std::runtime_error(os.str());
}
TORCH_CHECK(
it != value_map.end(),
"output ",
i,
" (",
var,
") of traced region did not have observable data dependence with trace inputs; ",
"this probably indicates your program cannot be understood by the tracer.");
return it->second;
} else if (iv.isTensorList()) {
if (tracing_mode_strict) {
@ -281,11 +285,10 @@ Value* TracingState::getOutput(const IValue& iv, size_t i) {
graph->insertNode(tuple_node);
return tuple_node->output();
} else if (iv.isGenericDict()) {
if (tracing_mode_strict) {
throw std::runtime_error(
"Encountering a dict at the output of the tracer" +
std::string(STRICT_TRACER_MSG));
}
TORCH_CHECK(
!tracing_mode_strict,
"Encountering a dict at the output of the tracer",
STRICT_TRACER_MSG);
auto dict = iv.toGenericDict();
TypePtr key_type = dict.keyType();
TypePtr value_type = dict.valueType();
@ -304,15 +307,15 @@ Value* TracingState::getOutput(const IValue& iv, size_t i) {
}
}
}
if (!key_type_valid || !value_type_valid) {
std::ostringstream os;
os << "output " << i << " (" << dict << ") of traced region "
<< "cannot be understood by the tracer, only outputs matching"
<< "dict[Union[str, Tensor], Union[Tensor, Tuple[Tensor, ...]]] "
<< "can be a dictionary output of a traced function";
throw std::runtime_error(os.str());
}
TORCH_CHECK(
key_type_valid && value_type_valid,
"output ",
i,
" (",
dict,
") of traced region cannot be understood by the tracer, only outputs matching ",
"dict[Union[str, Tensor], Union[Tensor, Tuple[Tensor, ...]]] ",
"can be a dictionary output of a traced function");
std::vector<Value*> keys;
std::vector<Value*> values;
for (const auto& entry : dict) {
@ -598,10 +601,11 @@ void TracingState::setValue(const IValue& v, Value* value) {
setValue(entry.value(), static_value);
}
} else {
std::ostringstream os;
os << "Tracer cannot set value trace for type " << v.tagKind() << ". "
<< "Supported types are tensor, tensor list, and tuple of tensors.";
throw std::runtime_error(os.str());
TORCH_CHECK(
false,
"Tracer cannot set value trace for type ",
v.tagKind(),
". Supported types are tensor, tensor list, and tuple of tensors.");
}
}
@ -801,11 +805,10 @@ void addInputs(Node* n, const char* name, at::IntArrayRef value) {
recordSourceLocation(info[i]->node());
}
for (jit::Value* v : info) {
if (*v->type() != *jit::IntType::get()) {
throw std::runtime_error(
"Type mismatch in setposattr for IntArrayRef. Check that your program "
"is valid without tracing, and please file a bug report if it is.");
}
TORCH_CHECK(
*v->type() == *jit::IntType::get(),
"Type mismatch in setposattr for IntArrayRef. Check that your program "
"is valid without tracing, and please file a bug report if it is.");
}
n->addInput(
g->insertNode(g->createList(jit::IntType::get(), info))->output());
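Two details make the `TORCH_CHECK` form above arguably better than the hand-rolled `ostringstream` + `throw` it replaces: the message arguments are only evaluated when the check actually fails, and anything with an `operator<<` overload (including `at::Tensor` and `IValue`) can be passed directly as a message piece. A minimal sketch under those assumptions:

#include <ATen/ATen.h>
#include <c10/util/Exception.h>

void require_no_grad(const at::Tensor& ten) {
  TORCH_CHECK(
      !ten.requires_grad(),
      "Cannot insert a Tensor that requires grad as a constant. Tensor:\n",
      ten); // the tensor is only stringified if the check fires
}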

View File

@ -5,6 +5,7 @@
#include <unordered_map>
#include <vector>
#include <c10/util/Exception.h>
#include <c10/util/SmallVector.h>
#include <c10/util/intrusive_ptr.h>
#include <torch/csrc/jit/frontend/lexer.h>
@ -37,10 +38,10 @@ struct Tree : c10::intrusive_ptr_target {
return true;
}
virtual const SourceRange& range() const {
throw std::runtime_error("is an Atom");
TORCH_CHECK(false, "is an Atom");
}
virtual const std::string& stringValue() const {
throw std::runtime_error("stringValue can only be called on TK_STRING");
TORCH_CHECK(false, "stringValue can only be called on TK_STRING");
}
virtual const TreeList& trees() const {
static const TreeList empty_trees = {};
@ -79,13 +80,16 @@ struct Tree : c10::intrusive_ptr_target {
int lineno,
size_t expected_subtrees,
bool allow_more) const {
if (kind() != k) {
std::stringstream ss;
ss << filename << ":" << lineno << ": expecting kind '" << kindToString(k)
<< "' but found '" << kindToString(kind()) << "'\n";
range().highlight(ss);
throw std::runtime_error(ss.str());
}
TORCH_CHECK(
kind() == k,
filename,
":",
lineno,
": expecting kind '",
kindToString(k),
"' but found '",
kindToString(kind()),
"'\n");
if (trees().size() < expected_subtrees ||
(!allow_more && trees().size() != expected_subtrees)) {
std::stringstream ss;
@ -93,7 +97,7 @@ struct Tree : c10::intrusive_ptr_target {
<< expected_subtrees << " subtrees, but found only " << trees().size()
<< "\n";
range().highlight(ss);
throw std::runtime_error(ss.str());
TORCH_CHECK(false, ss.str());
}
}
~Tree() override = default;
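The remaining `std::stringstream` in this hunk is deliberate: when part of the message comes from a helper that writes into a stream (`range().highlight(ss)` here), the text is assembled first and handed to `TORCH_CHECK` as a single argument. A sketch of that mixed form, with the highlight call elided:

#include <sstream>
#include <c10/util/Exception.h>

void check_subtrees(size_t found, size_t expected) {
  if (found < expected) {
    std::stringstream ss;
    ss << "expected at least " << expected << " subtrees, but found only "
       << found << "\n";
    // range().highlight(ss) would append source context here
    TORCH_CHECK(false, ss.str());
  }
}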

View File

@ -367,11 +367,6 @@ def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree:
The reconstructed pytree, containing the ``leaves`` placed in the structure described by
``treespec``.
"""
if not _is_pytreespec_instance(treespec):
raise TypeError(
f"tree_unflatten(leaves, treespec): Expected `treespec` to be instance of "
f"PyTreeSpec but got item of type {type(treespec)}."
)
return optree.tree_unflatten(treespec, leaves) # type: ignore[arg-type]
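The deleted guard is not load-bearing: `optree.tree_unflatten` performs its own validation on `treespec` (the C-accelerated binding rejects non-`PyTreeSpec` inputs), so the Python-level check added little beyond a slightly friendlier message. One detail worth keeping in mind when reading this wrapper is the argument order — `optree` takes the treespec first, while this module's public signature keeps the `(leaves, treespec)` order. A small round-trip sketch through `optree` directly, assuming `optree` is installed:

import optree

leaves, treespec = optree.tree_flatten({"a": 1, "b": (2, 3)})
assert leaves == [1, 2, 3]
assert optree.tree_unflatten(treespec, leaves) == {"a": 1, "b": (2, 3)}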