mirror of https://github.com/pytorch/pytorch.git
synced 2025-10-26 08:34:52 +08:00

Compare commits: 5 commits
gh/naveent
...
codex/add-

| Author | SHA1 | Date |
|---|---|---|
| | e3d00beddd | |
| | 21131a2444 | |
| | 1009790ad8 | |
| | 410e6a4321 | |
| | 23c55c5b66 | |
@@ -86,10 +86,6 @@ else
   fi
 fi
 
-if [[ "$BUILD_ENVIRONMENT" == *zen* ]]; then
-  export USE_ZENDNN=1
-fi
-
 if [[ "$BUILD_ENVIRONMENT" == *aarch64* ]]; then
   export USE_MKLDNN=1
   export USE_MKLDNN_ACL=1
.github/actionlint.yaml (vendored): 11 changed lines
@@ -54,17 +54,12 @@ self-hosted-runner:
  - windows-11-arm64
  - windows-11-arm64-preview
  # Organization-wide AMD-hosted runners
  # MI2xx non-ARC runners
  # MI2xx runners
  - linux.rocm.gpu
  - linux.rocm.gpu.mi250
  - linux.rocm.gpu.2
  - linux.rocm.gpu.4
  - linux.rocm.gpu.mi250
  - linux.rocm.gpu.gfx1100
  # MI2xx ARC runners
  - linux.rocm.gpu.mi250.1
  - linux.rocm.gpu.mi250.2
  - linux.rocm.gpu.mi250.4
  # gfx942 ARC runners
  # gfx942 runners
  - linux.rocm.gpu.gfx942.1
  - linux.rocm.gpu.gfx942.2
  - linux.rocm.gpu.gfx942.4
@@ -80,7 +80,7 @@ jobs:
     needs: get-label-type
     with:
       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
-      build-environment: linux-jammy-zen-py3.10-gcc11-build
+      build-environment: linux-jammy-py3.10-gcc11-build
       docker-image-name: ci-image:pytorch-linux-jammy-py3-gcc11-inductor-benchmarks
       test-matrix: |
         { include: [
@@ -106,7 +106,7 @@ jobs:
     needs: inductor-build
     if: github.event.schedule == '0 7 * * *'
     with:
-      build-environment: linux-jammy-zen-py3.10-gcc11-build
+      build-environment: linux-jammy-py3.10-gcc11-build
       dashboard-tag: training-false-inference-true-default-true-dynamic-true-cppwrapper-true-aotinductor-true-freezing-true
       docker-image: ${{ needs.inductor-build.outputs.docker-image }}
       test-matrix: ${{ needs.inductor-build.outputs.test-matrix }}
@@ -122,7 +122,7 @@ jobs:
     uses: ./.github/workflows/_linux-test.yml
     needs: inductor-build
     with:
-      build-environment: linux-jammy-zen-py3.10-gcc11-build
+      build-environment: linux-jammy-py3.10-gcc11-build
       dashboard-tag: training-${{ inputs.training || 'false' }}-inference-${{ inputs.inference || 'true' }}-default-${{ inputs.default || 'true' }}-dynamic-${{ inputs.dynamic || 'true' }}-cppwrapper-${{ inputs.cppwrapper || 'true' }}-aotinductor-${{ inputs.aotinductor || 'true' }}-freezing-${{ inputs.freezing || 'true' }}
       docker-image: ${{ needs.inductor-build.outputs.docker-image }}
       test-matrix: ${{ needs.inductor-build.outputs.test-matrix }}
.github/workflows/rocm.yml (vendored): 12 changed lines
@@ -36,12 +36,12 @@ jobs:
     sync-tag: rocm-build
     test-matrix: |
       { include: [
-        { config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.mi250.1" },
-        { config: "default", shard: 2, num_shards: 6, runner: "linux.rocm.gpu.mi250.1" },
-        { config: "default", shard: 3, num_shards: 6, runner: "linux.rocm.gpu.mi250.1" },
-        { config: "default", shard: 4, num_shards: 6, runner: "linux.rocm.gpu.mi250.1" },
-        { config: "default", shard: 5, num_shards: 6, runner: "linux.rocm.gpu.mi250.1" },
-        { config: "default", shard: 6, num_shards: 6, runner: "linux.rocm.gpu.mi250.1" },
+        { config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.2" },
+        { config: "default", shard: 2, num_shards: 6, runner: "linux.rocm.gpu.2" },
+        { config: "default", shard: 3, num_shards: 6, runner: "linux.rocm.gpu.2" },
+        { config: "default", shard: 4, num_shards: 6, runner: "linux.rocm.gpu.2" },
+        { config: "default", shard: 5, num_shards: 6, runner: "linux.rocm.gpu.2" },
+        { config: "default", shard: 6, num_shards: 6, runner: "linux.rocm.gpu.2" },
       ]}
     secrets: inherit
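The matrix above keeps the 6-way default sharding and only swaps the runner label from the MI250 ARC pool to linux.rocm.gpu.2. As a minimal sketch of the shard/num_shards contract these entries assume (the helper name and the round-robin split are illustrative, not PyTorch's actual CI partitioner):

# Illustrative only: how a `shard`/`num_shards` pair is typically consumed.
def select_shard(tests: list[str], shard: int, num_shards: int) -> list[str]:
    # `shard` is 1-based in the matrix entries above.
    return tests[shard - 1 :: num_shards]

tests = [f"test_{i}" for i in range(20)]
# Every test lands in exactly one of the 6 shards.
covered = sorted(sum((select_shard(tests, s, 6) for s in range(1, 7)), []))
assert covered == sorted(tests)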
.gitmodules (vendored): 3 changed lines
@@ -132,6 +132,3 @@
 [submodule "third_party/aiter"]
 	path = third_party/aiter
 	url = https://github.com/ROCm/aiter.git
-[submodule "third_party/ZenDNN"]
-	path = third_party/ZenDNN
-	url = https://github.com/amd/ZenDNN.git
@@ -82,7 +82,6 @@ include_patterns = [
     'aten/src/ATen/native/mkldnn/xpu/**/*.cpp',
     'aten/src/ATen/native/Tensor*.h',
     'aten/src/ATen/native/Tensor*.cpp',
-    'aten/src/ATen/native/zendnn/*.*',
     'c10/**/*.h',
     'c10/**/*.cpp',
     'torch/csrc/**/*.h',
@@ -205,11 +205,6 @@ filegroup(
     srcs = glob(["aten/src/ATen/native/xnnpack/*.cpp"]),
 )
 
-filegroup(
-    name = "aten_native_zendnn_cpp",
-    srcs = glob(["aten/src/ATen/native/zendnn/*.cpp"]),
-)
-
 filegroup(
     name = "aten_base_vulkan",
     srcs = glob(["aten/src/ATen/vulkan/*.cpp"]),
@@ -290,7 +285,6 @@ header_template_rule(
         "@AT_BLAS_USE_CBLAS_DOT@": "1",
         "@AT_KLEIDIAI_ENABLED@": "0",
         "@AT_USE_EIGEN_SPARSE@": "0",
-        "@AT_ZENDNN_ENABLED@": "0",
     },
 )
@@ -371,7 +365,6 @@ cc_library(
         ":aten_native_sparse_cpp",
         ":aten_native_transformers_cpp",
         ":aten_native_xnnpack",
-        ":aten_native_zendnn_cpp",
         ":aten_src_ATen_config",
     ] + generated_cpu_cpp + aten_ufunc_generated_cpu_sources("aten/src/ATen/{}"),
     copts = ATEN_COPTS,
@@ -326,21 +326,6 @@ set(MKLDNN_ENABLE_CONCURRENT_EXEC ${USE_MKLDNN})
 cmake_dependent_option(USE_MKLDNN_CBLAS "Use CBLAS in MKLDNN" OFF "USE_MKLDNN"
   OFF)
 option(USE_STATIC_MKL "Prefer to link with MKL statically (Unix only)" OFF)
-
-# currently ZenDNN is kept off and enabled only through user setting on X86_64/AMD64
-option(USE_ZENDNN
-  "Build with ZENDNN support"
-  OFF)
-if(USE_ZENDNN AND NOT CPU_INTEL)
-  message(WARNING
-    "USE_ZENDNN was requested, but the target processor "
-    "(${CMAKE_SYSTEM_PROCESSOR}) is not AMD64/x86_64. "
-    "ZENDNN support will be disabled.")
-
-  # Switch it off in the cache so the GUI / subsequent runs see the change
-  set(USE_ZENDNN OFF CACHE BOOL "Build with ZENDNN support" FORCE)
-endif()
-
 cmake_dependent_option(
   USE_MPI "Use MPI for Caffe2. Only available if USE_DISTRIBUTED is on." ON
   "USE_DISTRIBUTED" OFF)
@@ -1352,7 +1337,6 @@ if(BUILD_SHARED_LIBS)
       ${PROJECT_SOURCE_DIR}/cmake/public/gflags.cmake
       ${PROJECT_SOURCE_DIR}/cmake/public/mkl.cmake
       ${PROJECT_SOURCE_DIR}/cmake/public/mkldnn.cmake
-      ${PROJECT_SOURCE_DIR}/cmake/public/zendnn.cmake
      ${PROJECT_SOURCE_DIR}/cmake/public/protobuf.cmake
      ${PROJECT_SOURCE_DIR}/cmake/public/utils.cmake
      ${PROJECT_SOURCE_DIR}/cmake/public/LoadHIP.cmake
@@ -93,7 +93,6 @@ file(GLOB mkldnn_xpu_cpp "native/mkldnn/xpu/*.cpp" "native/mkldnn/xpu/detail/*.c
 file(GLOB native_cpp "native/*.cpp")
 file(GLOB native_mkl_cpp "native/mkl/*.cpp")
 file(GLOB native_mkldnn_cpp "native/mkldnn/*.cpp")
-file(GLOB native_zendnn_cpp "native/zendnn/*.cpp")
 file(GLOB vulkan_cpp "vulkan/*.cpp")
 file(GLOB native_vulkan_cpp "native/vulkan/*.cpp" "native/vulkan/api/*.cpp" "native/vulkan/impl/*.cpp" "native/vulkan/ops/*.cpp")
@@ -377,7 +376,7 @@ if(BUILD_LITE_INTERPRETER)
   append_filelist("aten_native_source_non_codegen_list" all_cpu_cpp)
 else()
   set(
-    all_cpu_cpp ${base_cpp} ${ATen_CORE_SRCS} ${native_cpp} ${native_zendnn_cpp}
+    all_cpu_cpp ${base_cpp} ${ATen_CORE_SRCS} ${native_cpp}
     ${native_ao_sparse_cpp} ${native_sparse_cpp} ${native_nested_cpp}
     ${native_quantized_cpp} ${native_mkl_cpp} ${native_mkldnn_cpp}
     ${native_transformers_cpp}
@@ -21,4 +21,3 @@
 #define AT_BLAS_USE_CBLAS_DOT() @AT_BLAS_USE_CBLAS_DOT@
 #define AT_KLEIDIAI_ENABLED() @AT_KLEIDIAI_ENABLED@
 #define AT_USE_EIGEN_SPARSE() @AT_USE_EIGEN_SPARSE@
-#define AT_ZENDNN_ENABLED() @AT_ZENDNN_ENABLED@
@@ -681,14 +681,6 @@ bool Context::hasEigenSparse() {
 #endif
 }
 
-bool Context::hasZenDNN() {
-#if AT_ZENDNN_ENABLED()
-  return true;
-#else
-  return false;
-#endif
-}
-
 at::QEngine Context::qEngine() const {
   static auto _quantized_engine = []() {
     at::QEngine qengine = at::kNoQEngine;
@@ -147,7 +147,6 @@ class TORCH_API Context {
   static bool hasMKL();
   static bool hasKleidiAI();
   static bool hasLAPACK();
-  static bool hasZenDNN();
   static bool hasMKLDNN();
   static bool ckSupported();
   static bool hasEigenSparse();
@@ -627,10 +626,6 @@ inline bool hasEigenSparse() {
   return globalContext().hasEigenSparse();
 }
 
-inline bool hasZenDNN() {
-  return globalContext().hasZenDNN();
-}
-
 inline bool hasMAGMA() {
   return globalContext().hasMAGMA();
 }
@@ -141,6 +141,8 @@ void compute_triu_tril(const Tensor& self, int64_t k, const Tensor &result) {
     return;
   }
 
+  checkTrilTriuMemoryOverlap(result, self);
+
   bool inplace_op = self.is_same(result);
 
   bool inplace_update = false;
@@ -1,3 +1,4 @@
+#include <ATen/MemoryOverlap.h>
 #include <ATen/core/Tensor.h>
 #include <ATen/native/LinearAlgebraUtils.h>
 
@@ -54,4 +55,13 @@ static inline std::tuple<bool, Tensor> checkTrilTriuBatchContiguous(const Tensor
   return std::make_tuple(true, tensor);
 }
 
+static inline void checkTrilTriuMemoryOverlap(const Tensor& result, const Tensor& self) {
+  if (result.is_same(self)) {
+    at::assert_no_internal_overlap(result);
+  } else {
+    at::assert_no_internal_overlap(result);
+    at::assert_no_overlap(result, self);
+  }
+}
+
 } // namespace at::native
@@ -5,6 +5,7 @@
 #include <ATen/Dispatch.h>
 #include <ATen/MemoryOverlap.h>
 #include <ATen/native/Resize.h>
+#include <ATen/native/TriangularOpsUtils.h>
 
 #ifndef AT_PER_OPERATOR_HEADERS
 #include <ATen/Functions.h>
@@ -110,6 +111,8 @@ __global__ void triu_tril_kernel(
 
 template <bool upper>
 void triu_tril_cuda_template(const Tensor& result, const Tensor& self, int64_t k, const char* name) {
+  checkTrilTriuMemoryOverlap(result, self);
+
   AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
       at::ScalarType::ComplexHalf,
       at::ScalarType::Half,
@@ -1,5 +0,0 @@
-#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
-#include <ATen/Config.h>
-#include <ATen/Context.h>
-#include <ATen/core/Tensor.h>
-#include <ATen/record_function.h>
@@ -1162,9 +1162,6 @@ def define_buck_targets(
             "--replace",
             "@AT_USE_EIGEN_SPARSE@",
             "0",
-            "--replace",
-            "@AT_ZENDNN_ENABLED@",
-            "0",
         ]),
         outs = {
             "Config.h": ["Config.h"],
@@ -1177,7 +1177,6 @@ aten_cpu_source_non_codegen_list = [
     "aten/src/ATen/native/ComparisonUtils.cpp",
     "aten/src/ATen/native/DispatchStub.cpp",
     "aten/src/ATen/native/UpSample.cpp",
-    "aten/src/ATen/native/zendnn/Matmul.cpp",
     "aten/src/ATen/native/mkldnn/BinaryOps.cpp",
     "aten/src/ATen/native/mkldnn/Conv.cpp",
     "aten/src/ATen/native/mkldnn/ConvPrepack.cpp",
@@ -67,5 +67,4 @@
     {"USE_CUSPARSELT", "${USE_CUSPARSELT}"}, \
     {"USE_XPU", "${USE_XPU}"}, \
     {"USE_XCCL", "${USE_XCCL}"}, \
-    {"USE_ZENDNN", "${USE_ZENDNN}"} \
   }
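This macro table appears to feed the build-settings string that PyTorch exposes at runtime, so removing the entry also drops USE_ZENDNN from that report. A quick way to inspect the compile-time flags from Python (torch.__config__.show() is a real API; exactly which entries appear depends on the build):

import torch

# Prints compile-time build settings; after this change the
# "USE_ZENDNN" entry no longer shows up in the output.
print(torch.__config__.show())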
@@ -117,10 +117,6 @@ if(@USE_MKLDNN@)
   include("${CMAKE_CURRENT_LIST_DIR}/public/mkldnn.cmake")
 endif()
 
-if(@USE_ZENDNN@)
-  include("${CMAKE_CURRENT_LIST_DIR}/public/zendnn.cmake")
-endif()
-
 # import targets
 include ("${CMAKE_CURRENT_LIST_DIR}/Caffe2Targets.cmake")
@@ -162,7 +162,6 @@ set(AT_MKLDNN_ENABLED 0)
 set(AT_MKL_ENABLED 0)
 set(AT_KLEIDIAI_ENABLED 0)
 set(AT_USE_EIGEN_SPARSE 0)
-set(AT_ZENDNN_ENABLED 0)
 # setting default preferred BLAS options if not already present.
 if(NOT INTERN_BUILD_MOBILE)
   set(BLAS "MKL" CACHE STRING "Selected BLAS library")
@@ -1510,32 +1509,6 @@ if(NOT INTERN_BUILD_MOBILE)
     message("disabling MKLDNN because USE_MKLDNN is not set")
   endif()
 
-  if(USE_ZENDNN)
-    if(NOT (CMAKE_SYSTEM_NAME MATCHES "Linux"))
-      message(WARNING
-        "USE_ZENDNN is currently only supported on Linux. Detected platform: ${CMAKE_SYSTEM_NAME}. Disabling ZenDNN support.")
-      set(USE_ZENDNN OFF)
-    elseif(NOT CMAKE_SIZEOF_VOID_P EQUAL 8)
-      message(WARNING
-        "x64 operating system is required for ZenDNN. "
-        "ZenDNN codebase will not be compiled."
-        "Turn this warning off by USE_ZENDNN=OFF.")
-      set(USE_ZENDNN OFF)
-    else()
-      include(${CMAKE_CURRENT_LIST_DIR}/public/zendnn.cmake)
-      if(ZENDNN_FOUND)
-        set(AT_ZENDNN_ENABLED 1)
-        # Add to Caffe2 private dependencies
-        list(APPEND Caffe2_DEPENDENCY_LIBS zendnnl::zendnnl_archive)
-      else()
-        message(WARNING "ZENDNN could not be found.")
-        caffe2_update_option(USE_ZENDNN OFF)
-      endif()
-    endif()
-  else()
-    message(STATUS "disabling ZENDNN because USE_ZENDNN is not set")
-  endif()
-
   if(USE_KLEIDIAI)
     set(TEMP_BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS})
     set(BUILD_SHARED_LIBS OFF CACHE BOOL "Build shared libs" FORCE)
@@ -1,402 +0,0 @@
-include_guard(GLOBAL)
-include(ExternalProject)
-
-# declare a zendnnl dependency
-macro(zendnnl_add_dependency )
-  set(options INCLUDE_ONLY)
-  set(oneValueArgs NAME PATH LIB_SUFFIX INCLUDE_SUFFIX ARCHIVE_FILE ALIAS)
-  set(multiValueArgs DEPENDS)
-  cmake_parse_arguments(_zad "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
-
-  string(TOUPPER ${_zad_NAME} _ZAD_UNAME)
-
-  if(DEFINED _zad_INCLUDE_SUFFIX)
-    set(ZENDNNL_${_ZAD_UNAME}_INC_DIR "${_zad_PATH}/${_zad_INCLUDE_SUFFIX}")
-  else()
-    set(ZENDNNL_${_ZAD_UNAME}_INC_DIR "${_zad_PATH}/include")
-  endif()
-
-  if(DEFINED _zad_LIB_SUFFIX)
-    set(ZENDNNL_${_ZAD_UNAME}_LIB_DIR "${_zad_PATH}/${_zad_LIB_SUFFIX}")
-  else()
-    set(ZENDNNL_${_ZAD_UNAME}_LIB_DIR "${_zad_PATH}/lib")
-  endif()
-
-  if(NOT EXISTS ${ZENDNNL_${_ZAD_UNAME}_INC_DIR})
-    file(MAKE_DIRECTORY ${ZENDNNL_${_ZAD_UNAME}_INC_DIR})
-  endif()
-
-  if(${_zad_INCLUDE_ONLY})
-    add_library(zendnnl_${_zad_NAME}_deps INTERFACE IMPORTED GLOBAL)
-    #add_dependencies(zendnnl_${_zad_NAME}_deps ${_zad_DEPENDS})
-
-    set_target_properties(zendnnl_${_zad_NAME}_deps
-      PROPERTIES
-      INTERFACE_INCLUDE_DIRECTORIES "${ZENDNNL_${_ZAD_UNAME}_INC_DIR}")
-  else()
-
-    add_library(zendnnl_${_zad_NAME}_deps STATIC IMPORTED GLOBAL)
-    #add_dependencies(zendnnl_${_zad_NAME}_deps ${_zad_DEPENDS})
-
-    set_target_properties(zendnnl_${_zad_NAME}_deps
-      PROPERTIES
-      IMPORTED_LOCATION "${ZENDNNL_${_ZAD_UNAME}_LIB_DIR}/${_zad_ARCHIVE_FILE}"
-      INCLUDE_DIRECTORIES "${ZENDNNL_${_ZAD_UNAME}_INC_DIR}"
-      INTERFACE_INCLUDE_DIRECTORIES "${ZENDNNL_${_ZAD_UNAME}_INC_DIR}")
-  endif()
-
-  add_library(${_zad_ALIAS} ALIAS zendnnl_${_zad_NAME}_deps)
-
-  list(APPEND ZNL_BYPRODUCTS "${ZENDNNL_${_ZAD_UNAME}_LIB_DIR}/${_zad_ARCHIVE_FILE}")
-endmacro()
-
-macro(zendnnl_add_option )
-  set(options EXECLUDE_FROM_COMMAND_LIST FORCE)
-  set(oneValueArgs NAME VALUE TYPE CACHE_STRING COMMAND_LIST)
-  set(multiValueArgs "")
-  cmake_parse_arguments(_zao "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
-
-  if(${_zao_FORCE})
-    set(${_zao_NAME} ${_zao_VALUE} CACHE ${_zao_TYPE} ${_zao_CACHE_STRING} FORCE)
-  else()
-    set(${_zao_NAME} ${_zao_VALUE} CACHE ${_zao_TYPE} ${_zao_CACHE_STRING})
-  endif()
-
-  if (NOT ${_zao_EXECLUDE_FROM_COMMAND_LIST})
-    list(APPEND ${_zao_COMMAND_LIST} "-D${_zao_NAME}:${_zao_TYPE}=${_zao_VALUE}")
-  endif()
-endmacro()
-
-message(AUTHOR_WARNING "(ZENDNNL) please ensure all zendnnl variables are set properly.")
-
-if(NOT ZENDNN_FOUND)
-  # find openmp
-  find_package(OpenMP REQUIRED QUIET)
-
-  # set zendnnl source dir, where zendnnl has been downloaded.
-  zendnnl_add_option(NAME ZENDNNL_SOURCE_DIR
-    VALUE ${PROJECT_SOURCE_DIR}/third_party/ZenDNN
-    TYPE PATH
-    CACHE_STRING "zendnnl_source_dir"
-    COMMAND_LIST ZNL_CMAKE_ARGS)
-
-  # set zendnnl binary dir, if unsure set ${CMAKE_CURRENT_BINARY_DIR}/zendnnl.
-  zendnnl_add_option(NAME ZENDNNL_BINARY_DIR
-    VALUE ${ZENDNNL_SOURCE_DIR}/build
-    TYPE PATH
-    CACHE_STRING "zendnnl_binary_dir"
-    COMMAND_LIST ZNL_CMAKE_ARGS)
-
-  # set zendnnl install dir, if unsure set ${CMAKE_INSTALL_PREFIX}/zendnnl.
-  zendnnl_add_option(NAME ZENDNNL_INSTALL_PREFIX
-    VALUE ${ZENDNNL_BINARY_DIR}/install
-    TYPE PATH
-    CACHE_STRING "zendnnl_install_dir"
-    COMMAND_LIST ZNL_CMAKE_ARGS)
-
-  ## general zendnnl options
-  # set ZenDNNL framework build, this should on ON to avoid standalone build.
-  zendnnl_add_option(NAME ZENDNNL_FWK_BUILD
-    VALUE ON
-    TYPE BOOL
-    CACHE_STRING "zendnnl framework build"
-    COMMAND_LIST ZNL_CMAKE_ARGS)
-
-  # set zendnnl build option, default is Release.
-  zendnnl_add_option(NAME ZENDNNL_BUILD_TYPE
-    VALUE "Release"
-    TYPE STRING
-    CACHE_STRING "zendnnl build type"
-    COMMAND_LIST ZNL_CMAKE_ARGS)
-
-  # set zendnnl log level.
-  zendnnl_add_option(NAME ZENDNNL_MESSAGE_LOG_LEVEL
-    VALUE "DEBUG"
-    TYPE STRING
-    CACHE_STRING "zendnnl message log level"
-    COMMAND_LIST ZNL_CMAKE_ARGS)
-
-  # set zendnnl verbose makefile option.
-  zendnnl_add_option(NAME ZENDNNL_VERBOSE_MAKEFILE
-    VALUE ON
-    TYPE BOOL
-    CACHE_STRING "zendnnl verbose makefile"
-    COMMAND_LIST ZNL_CMAKE_ARGS)
-
-  ## components options
-  # set building zendnnl examples, default os OFF.
-  zendnnl_add_option(NAME ZENDNNL_BUILD_EXAMPLES
-    VALUE OFF
-    TYPE BOOL
-    CACHE_STRING "build zendnnl examples"
-    COMMAND_LIST ZNL_CMAKE_ARGS)
-
-  # set building zendnnl gtests, default os OFF.
-  zendnnl_add_option(NAME ZENDNNL_BUILD_GTEST
-    VALUE OFF
-    TYPE BOOL
-    CACHE_STRING "build zendnnl gtests"
-    COMMAND_LIST ZNL_CMAKE_ARGS)
-
-  # set building zendnnl doxygen documentation, default os OFF.
-  zendnnl_add_option(NAME ZENDNNL_BUILD_DOXYGEN
-    VALUE OFF
-    TYPE BOOL
-    CACHE_STRING "build zendnnl doxygen documentation"
-    COMMAND_LIST ZNL_CMAKE_ARGS)
-
-  # set building zendnnl benchmarking tool, default os OFF.
-  zendnnl_add_option(NAME ZENDNNL_BUILD_BENCHDNN
-    VALUE OFF
-    TYPE BOOL
-    CACHE_STRING "build zendnnl benchdnn"
-    COMMAND_LIST ZNL_CMAKE_ARGS)
-
-  # set zendnnl code coverage option, default os OFF.
-  zendnnl_add_option(NAME ZENDNNL_CODE_COVERAGE
-    VALUE OFF
-    TYPE BOOL
-    CACHE_STRING "build zendnnl code coverage"
-    COMMAND_LIST ZNL_CMAKE_ARGS)
-
-  ## dependencies
-  # set if zendnnl depends on amdblis. this should bf OFF only if
-  # aocldlp dependency is ON.
-  zendnnl_add_option(NAME ZENDNNL_DEPENDS_AMDBLIS
-    VALUE OFF
-    TYPE BOOL
-    CACHE_STRING "zendnnl amdblis dependency"
-    COMMAND_LIST ZNL_CMAKE_ARGS)
-
-  # set if zendnnl depends on aocldlp. this should bf ON only if
-  # amdblis dependency is OFF.
-  zendnnl_add_option(NAME ZENDNNL_DEPENDS_AOCLDLP
-    VALUE ON
-    TYPE BOOL
-    CACHE_STRING "zendnnl aocldlp dependency"
-    COMMAND_LIST ZNL_CMAKE_ARGS)
-
-  # set if zendnnl depends on onednn, default is OFF.
-  zendnnl_add_option(NAME ZENDNNL_DEPENDS_ONEDNN
-    VALUE OFF
-    TYPE BOOL
-    CACHE_STRING "zendnnl onednn dependency"
-    COMMAND_LIST ZNL_CMAKE_ARGS)
-
-  # set if zendnnl depends on libxsmm, default is OFF.
-  zendnnl_add_option(NAME ZENDNNL_DEPENDS_LIBXSMM
-    VALUE ON
-    TYPE BOOL
-    CACHE_STRING "zendnnl libxsmm dependency"
-    COMMAND_LIST ZNL_CMAKE_ARGS)
-
-  # set path of amdblis if amdblis is injected. if the framework
-  # does not inject it, set it to "" (empty string).
-  zendnnl_add_option(NAME ZENDNNL_AMDBLIS_FWK_DIR
-    VALUE ""
-    TYPE PATH
-    CACHE_STRING "zendnnl amdblis framework path"
-    COMMAND_LIST ZNL_CMAKE_ARGS)
-
-  # set path of aocldlp if aocldlp is injected. if the framework
-  # does not inject it, set it to "" (empty string).
-  zendnnl_add_option(NAME ZENDNNL_AOCLDLP_FWK_DIR
-    VALUE ""
-    TYPE PATH
-    CACHE_STRING "zendnnl aocldlp framework path"
-    COMMAND_LIST ZNL_CMAKE_ARGS)
-
-  # set path of onednn if onednn is injected. if the framework
-  # does not inject it, set it to "" (empty string).
-  zendnnl_add_option(NAME ZENDNNL_ONEDNN_FWK_DIR
-    VALUE ""
-    TYPE PATH
-    CACHE_STRING "zendnnl onednnn framework path"
-    COMMAND_LIST ZNL_CMAKE_ARGS)
-
-  # set path of libxsmm if libxsmm is injected. if the framework
-  # does not inject it, set it to "" (empty string).
-  zendnnl_add_option(NAME ZENDNNL_LIBXSMM_FWK_DIR
-    VALUE ""
-    TYPE PATH
-    CACHE_STRING "zendnnl libxsmm framework path"
-    COMMAND_LIST ZNL_CMAKE_ARGS)
-
-  # try to find pre-built package
-  set(zendnnl_ROOT "${ZENDNNL_INSTALL_PREFIX}/zendnnl")
-  set(zendnnl_DIR "${zendnnl_ROOT}/lib/cmake")
-  find_package(zendnnl QUIET)
-  if(zendnnl_FOUND)
-    message(STATUS "(ZENDNNL) ZENDNNL FOUND AT ${zendnnl_ROOT}")
-    message(STATUS "(ZENDNNL) if zendnnl options are changed from previous build,")
-    message(STATUS "(ZENDNNL) they will not be reflected")
-    message(STATUS "(ZENDNNL) If options are changed, please do a clean build.")
-    if(TARGET zendnnl::zendnnl_archive)
-      set_target_properties(zendnnl::zendnnl_archive
-        PROPERTIES IMPORTED_GLOBAL ON)
-    else()
-      message(FATAL_ERROR "(ZENDNNL) zendnnl installation does not have imported target zendnnl::zendnnl_archive")
-    endif()
-  else()
-    message(STATUS "(ZENDNNL) ZENDNNL NOT FOUND, will be built as an external project.")
-
-    # declare zendnnl library
-    set(ZENDNNL_LIBRARY_INC_DIR "${ZENDNNL_INSTALL_PREFIX}/zendnnl/include")
-    set(ZENDNNL_LIBRARY_LIB_DIR "${ZENDNNL_INSTALL_PREFIX}/zendnnl/lib")
-
-    if(NOT EXISTS ${ZENDNNL_LIBRARY_INC_DIR})
-      file(MAKE_DIRECTORY ${ZENDNNL_LIBRARY_INC_DIR})
-    endif()
-
-    add_library(zendnnl_library STATIC IMPORTED GLOBAL)
-    add_dependencies(zendnnl_library fwk_zendnnl)
-    set_target_properties(zendnnl_library
-      PROPERTIES
-      IMPORTED_LOCATION "${ZENDNNL_LIBRARY_LIB_DIR}/libzendnnl_archive.a"
-      INCLUDE_DIRECTORIES "${ZENDNNL_LIBRARY_INC_DIR}"
-      INTERFACE_INCLUDE_DIRECTORIES "${ZENDNNL_LIBRARY_INC_DIR}")
-
-    target_link_options(zendnnl_library INTERFACE "-fopenmp")
-    target_link_libraries(zendnnl_library
-      INTERFACE OpenMP::OpenMP_CXX
-      INTERFACE ${CMAKE_DL_LIBS})
-
-    add_library(zendnnl::zendnnl_archive ALIAS zendnnl_library)
-
-    list(APPEND ZNL_BYPRODUCTS "${ZENDNNL_LIBRARY_LIB_DIR}/libzendnnl_archive.a")
-
-    # declare all dependencies
-
-    # json dependency
-    zendnnl_add_dependency(NAME json
-      PATH "${ZENDNNL_INSTALL_PREFIX}/deps/json"
-      ALIAS "nlohmann_json::nlohmann_json"
-      INCLUDE_ONLY)
-
-    target_link_libraries(zendnnl_library INTERFACE nlohmann_json::nlohmann_json)
-
-    # aoclutils dependency
-    if (DEFINED ENV{ZENDNNL_MANYLINUX_BUILD})
-
-      zendnnl_add_dependency(NAME aoclutils
-        PATH "${ZENDNNL_INSTALL_PREFIX}/deps/aoclutils"
-        LIB_SUFFIX lib64
-        ARCHIVE_FILE "libaoclutils.a"
-        ALIAS "au::aoclutils")
-
-      target_link_libraries(zendnnl_library INTERFACE au::aoclutils)
-
-      zendnnl_add_dependency(NAME aucpuid
-        PATH "${ZENDNNL_INSTALL_PREFIX}/deps/aoclutils"
-        LIB_SUFFIX lib64
-        ARCHIVE_FILE "libau_cpuid.a"
-        ALIAS "au::au_cpuid")
-
-      target_link_libraries(zendnnl_library INTERFACE au::au_cpuid)
-
-    else()
-      zendnnl_add_dependency(NAME aoclutils
-        PATH "${ZENDNNL_INSTALL_PREFIX}/deps/aoclutils"
-        ARCHIVE_FILE "libaoclutils.a"
-        ALIAS "au::aoclutils")
-
-      target_link_libraries(zendnnl_library INTERFACE au::aoclutils)
-
-      zendnnl_add_dependency(NAME aucpuid
-        PATH "${ZENDNNL_INSTALL_PREFIX}/deps/aoclutils"
-        ARCHIVE_FILE "libau_cpuid.a"
-        ALIAS "au::au_cpuid")
-
-      target_link_libraries(zendnnl_library INTERFACE au::au_cpuid)
-
-    endif()
-
-    # amdblis dependency
-    if (ZENDNNL_DEPENDS_AMDBLIS)
-      zendnnl_add_dependency(NAME amdblis
-        PATH "${ZENDNNL_INSTALL_PREFIX}/deps/amdblis"
-        ARCHIVE_FILE "libblis-mt.a"
-        ALIAS "amdblis::amdblis_archive")
-
-      target_link_libraries(zendnnl_library INTERFACE amdblis::amdblis_archive)
-    endif()
-
-    if (ZENDNNL_DEPENDS_AOCLDLP)
-      zendnnl_add_dependency(NAME aocldlp
-        PATH "${ZENDNNL_INSTALL_PREFIX}/deps/aocldlp"
-        ARCHIVE_FILE "libaocl-dlp.a"
-        ALIAS "aocldlp::aocl_dlp_static")
-
-      target_link_libraries(zendnnl_library INTERFACE aocldlp::aocl_dlp_static)
-    endif()
-
-    if (ZENDNNL_DEPENDS_ONEDNN)
-      zendnnl_add_dependency(NAME onednn
-        PATH "${ZENDNNL_INSTALL_PREFIX}/deps/onednn"
-        ARCHIVE_FILE "libdnnl.a"
-        ALIAS "DNNL::dnnl")
-
-      target_link_libraries(zendnnl_library INTERFACE DNNL::dnnl)
-    endif()
-
-    # libxsmm dependency
-    if (ZENDNNL_DEPENDS_LIBXSMM)
-      zendnnl_add_dependency(NAME libxsmm
-        PATH "${ZENDNNL_INSTALL_PREFIX}/deps/libxsmm"
-        ARCHIVE_FILE "libxsmm.a"
-        ALIAS "libxsmm::libxsmm_archive")
-
-      target_link_libraries(zendnnl_library INTERFACE libxsmm::libxsmm_archive)
-    endif()
-
-    message(STATUS "(ZENDNNL) ZNL_BYPRODUCTS=${ZNL_BYPRODUCTS}")
-    message(STATUS "(ZENDNNL) ZNL_CMAKE_ARGS=${ZNL_CMAKE_ARGS}")
-
-    ExternalProject_ADD(fwk_zendnnl
-      SOURCE_DIR "${ZENDNNL_SOURCE_DIR}"
-      BINARY_DIR "${ZENDNNL_BINARY_DIR}"
-      CMAKE_ARGS "${ZNL_CMAKE_ARGS}"
-      BUILD_COMMAND cmake --build . --target all -j
-      INSTALL_COMMAND ""
-      BUILD_BYPRODUCTS ${ZNL_BYPRODUCTS})
-
-    list(APPEND ZENDNNL_CLEAN_FILES "${ZENDNNL_BINARY_DIR}")
-    list(APPEND ZENDNNL_CLEAN_FILES "${ZENDNNL_INSTALL_PREFIX}")
-    set_target_properties(fwk_zendnnl
-      PROPERTIES
-      ADDITIONAL_CLEAN_FILES "${ZENDNNL_CLEAN_FILES}")
-
-    # framework dependencies
-    # add_dependencies(fwk_zendnnl <injected dependency targets>)
-    get_target_property(FWK_ZENDNNL_DEPENDS fwk_zendnnl MANUALLY_ADDED_DEPENDENCIES)
-    if(${FWK_ZENDNNL_DEPENDS} STREQUAL "FWK_ZENDNNL_DEPENDS-NOTFOUND")
-      message(AUTHOR_WARNING "(ZENDNNL) please ensure fwk_zendnnl depends on injected dependencies targets")
-    else()
-      message(STATUS "fwk_zendnnl dependencies : ${FWK_ZENDNNL_DEPENDS}")
-    endif()
-
-    # make library and its dependencies depend on fwk_zendnnl
-    add_dependencies(zendnnl_library fwk_zendnnl)
-    add_dependencies(zendnnl_json_deps fwk_zendnnl)
-    add_dependencies(zendnnl_aoclutils_deps fwk_zendnnl)
-    add_dependencies(zendnnl_aucpuid_deps fwk_zendnnl)
-
-    if(ZENDNNL_DEPENDS_AMDBLIS)
-      add_dependencies(zendnnl_amdblis_deps fwk_zendnnl)
-    endif()
-
-    if(ZENDNNL_DEPENDS_AOCLDLP)
-      add_dependencies(zendnnl_aocldlp_deps fwk_zendnnl)
-    endif()
-
-    if(ZENDNNL_DEPENDS_ONEDNN)
-      add_dependencies(zendnnl_onednn_deps fwk_zendnnl)
-    endif()
-
-    if(ZENDNNL_DEPENDS_LIBXSMM)
-      add_dependencies(zendnnl_libxsmm_deps fwk_zendnnl)
-    endif()
-  endif()
-  set(ZENDNN_FOUND TRUE)
-
-endif(NOT ZENDNN_FOUND)
@@ -148,7 +148,6 @@ function(caffe2_print_configuration_summary)
   message(STATUS "  USE_PYTORCH_METAL_EXPORT : ${USE_PYTORCH_METAL_EXPORT}")
   message(STATUS "  USE_MPS               : ${USE_MPS}")
   message(STATUS "  CAN_COMPILE_METAL     : ${CAN_COMPILE_METAL}")
-  message(STATUS "  USE_ZENDNN            : ${USE_ZENDNN}")
   message(STATUS "  USE_MKL               : ${CAFFE2_USE_MKL}")
   if(${CAFFE2_USE_MKL})
     message(STATUS "  USE_STATIC_MKL      : ${USE_STATIC_MKL}")
@@ -1,8 +0,0 @@
-if(NOT EXISTS ${PROJECT_SOURCE_DIR}/third_party/ZenDNN)
-  message(WARNING "(ZENDNNL) Library not found at ${PROJECT_SOURCE_DIR}/third_party/ZenDNN")
-else()
-  find_package(ZENDNN QUIET)
-  if(ZENDNN_FOUND)
-    message(STATUS, "(ZENDNN) ZenDNN library was built successfully.")
-  endif(ZENDNN_FOUND)
-endif()
setup.py: 7 changed lines
@@ -67,9 +67,6 @@
 #   USE_NUMPY=0
 #     disables the NumPy build
 #
-#   USE_ZENDNN=0
-#     disables the ZenDNN build
-#
 #   BUILD_TEST=0
 #     disables the test build
 #
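setup.py forwards USE_* environment variables into the CMake configure step, which is how the removed USE_ZENDNN=0 knob used to work. A hedged sketch of driving such a flag from a build script (the subprocess invocation is illustrative; any of the documented USE_* variables works the same way):

import os
import subprocess

# Illustrative: USE_* env vars are picked up by setup.py and passed to CMake.
env = dict(os.environ, USE_MKLDNN="0")  # e.g. disable the MKLDNN build
subprocess.run(["python", "setup.py", "develop"], env=env, check=True)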
@@ -1193,10 +1190,6 @@ class build_ext(setuptools.command.build_ext.build_ext):
                 report("-- Not using CBLAS in MKLDNN")
         else:
             report("-- Not using MKLDNN")
-        if cmake_cache_vars["USE_ZENDNN"]:
-            report("-- Using ZENDNN")
-        else:
-            report("-- Not using ZENDNN")
         if cmake_cache_vars["USE_NCCL"] and cmake_cache_vars["USE_SYSTEM_NCCL"]:
             report(
                 "-- Using system provided NCCL library at "
@@ -424,7 +424,7 @@ from user code:
         @torch.compile(backend="eager")
         def fn(x):
             d = {"a": 1}
-            optree.tree_flatten(d)
+            optree.tree_flatten_with_path(d)
             return torch.sin(x)

         fn(torch.randn(4))
@@ -434,10 +434,10 @@ from user code:
             first_graph_break,
             """\
 Attempted to call function marked as skipped
-  Explanation: Dynamo cannot trace optree C/C++ function optree._C.PyCapsule.flatten.
+  Explanation: Dynamo cannot trace optree C/C++ function optree._C.PyCapsule.flatten_with_path.
   Hint: Consider using torch.utils._pytree - https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py
 
-  Developer debug context: module: optree._C, qualname: PyCapsule.flatten, skip reason: <missing reason>
+  Developer debug context: module: optree._C, qualname: PyCapsule.flatten_with_path, skip reason: <missing reason>
 
 For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0007.html""",
         )
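The test's hint points at the workaround: Dynamo cannot trace optree's C entry points, but the pure-Python torch.utils._pytree API is traceable. A minimal sketch (eager backend, as in the test above):

import torch
import torch.utils._pytree as pytree

@torch.compile(backend="eager")
def fn(x):
    # Pure-Python pytree calls trace through Dynamo; no graph break here.
    leaves, spec = pytree.tree_flatten({"a": x})
    return torch.sin(leaves[0])

fn(torch.randn(4))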
@@ -110,6 +110,7 @@ if python_pytree._cxx_pytree_dynamo_traceable:
     import torch.utils._cxx_pytree as cxx_pytree

     pytree_modules["cxx"] = cxx_pytree
+    pytree_modules["native_optree"] = cxx_pytree.optree
 else:
     cxx_pytree = None
@@ -12862,6 +12863,9 @@ class MiscTestsPyTree(torch._inductor.test_case.TestCase):
         def fn(xs):
             flat_xs, spec = pytree.tree_flatten(xs)
             res = [x.clone() for x in flat_xs]
+            if pytree.__name__ == "optree":
+                # The treespec argument comes first in OpTree / JAX PyTree
+                return pytree.tree_unflatten(spec, res)
             return pytree.tree_unflatten(res, spec)

         xs = [torch.tensor(i) for i in range(3)]
@@ -12876,6 +12880,9 @@ class MiscTestsPyTree(torch._inductor.test_case.TestCase):
         def fn(xs):
             flat_xs, spec = pytree.tree_flatten(xs)
             res = [x.clone() for x in flat_xs]
+            if pytree.__name__ == "optree":
+                # The treespec argument comes first in OpTree / JAX PyTree
+                return pytree.tree_unflatten(spec, res)
             return pytree.tree_unflatten(res, spec)

         xs = [torch.tensor(i) for i in range(3)]
@@ -12893,6 +12900,9 @@ class MiscTestsPyTree(torch._inductor.test_case.TestCase):
         def fn(xs):
             flat_xs, spec = pytree.tree_flatten(xs)
             res = [x.clone() for x in flat_xs]
+            if pytree.__name__ == "optree":
+                # The treespec argument comes first in OpTree / JAX PyTree
+                return pytree.tree_unflatten(spec, res)
             return pytree.tree_unflatten(res, spec)

         xs = [torch.tensor(i) for i in range(3)]
@@ -12910,6 +12920,9 @@ class MiscTestsPyTree(torch._inductor.test_case.TestCase):
         def fn(xs):
             flat_xs, spec = pytree.tree_flatten(xs)
             res = [x.clone() for x in flat_xs]
+            if pytree.__name__ == "optree":
+                # The treespec argument comes first in OpTree / JAX PyTree
+                return pytree.tree_unflatten(spec, res)
             return pytree.tree_unflatten(res, spec)

         xs = [torch.tensor(i) for i in range(3)]
@@ -12931,6 +12944,9 @@ class MiscTestsPyTree(torch._inductor.test_case.TestCase):
         def fn(xs):
             flat_xs, spec = pytree.tree_flatten(xs)
             res = [x.clone() for x in flat_xs]
+            if pytree.__name__ == "optree":
+                # The treespec argument comes first in OpTree / JAX PyTree
+                return pytree.tree_unflatten(spec, res)
             return pytree.tree_unflatten(res, spec)

         xs = [torch.tensor(i) for i in range(3)]
@@ -13032,7 +13048,13 @@ class MiscTestsPyTree(torch._inductor.test_case.TestCase):
                 torch.ones(3, 2),
                 1,
             ]
-            new_tree = pytree.tree_unflatten(new_leaves, treespec)
+            if pytree.__name__ == "optree":
+                # `None` is a internal node rather than leaf in default OpTree / JAX PyTree
+                new_leaves.pop()
+                # The treespec argument comes first in OpTree / JAX PyTree
+                new_tree = pytree.tree_unflatten(treespec, new_leaves)
+            else:
+                new_tree = pytree.tree_unflatten(new_leaves, treespec)
             return leaves, new_tree

         x = torch.randn(3, 2)
@@ -13087,6 +13109,10 @@ class MiscTestsPyTree(torch._inductor.test_case.TestCase):

     @parametrize_pytree_module
     def test_pytree_tree_map_only(self, pytree):
+        if not callable(getattr(pytree, "tree_map_only", None)):
+            # OpTree and JAX PyTree do not have `tree_map_only`
+            return
+
         def fn(xs):
             def mapper(x):
                 return x.clone()
@@ -9986,6 +9986,20 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
         self.assertEqual(result_triu_min, expected_triu_min)
         self.assertEqual(result_tril_min, expected_tril_min)

+    @dtypes(torch.float)
+    def test_triu_tril_inplace_memory_overlap(self, device, dtype):
+        base = torch.rand((), dtype=dtype, device=device)
+        expanded = base.expand(3, 3)
+        msg = (
+            "unsupported operation: more than one element of the written-to tensor "
+            "refers to a single memory location. Please clone() the tensor before "
+            "performing the operation."
+        )
+        with self.assertRaisesRegex(RuntimeError, msg):
+            expanded.triu_(1)
+        with self.assertRaisesRegex(RuntimeError, msg):
+            expanded.tril_(-1)
+
     @dtypes(torch.float, torch.double)
     @precisionOverride({torch.float32: 1e-4})
     def test_1_sized_with_0_strided(self, device, dtype):
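The new test leans on the fact that expand() materializes no memory: every element of the 3x3 view aliases the single underlying scalar, which is exactly the internal overlap the new checkTrilTriuMemoryOverlap guard rejects for in-place triu_/tril_. A small illustration:

import torch

base = torch.rand(())          # one element of real storage
expanded = base.expand(3, 3)   # 9 logical elements, all aliased
print(expanded.stride())       # (0, 0): zero strides signal internal overlap
# expanded.triu_(1) now raises RuntimeError instead of writing through the
# aliased element; clone() first to get independently writable memory:
expanded.clone().triu_(1)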
@@ -140,7 +140,6 @@ class TestPublicBindings(TestCase):
         "has_mps",
         "has_openmp",
         "has_spectral",
-        "has_zendnn",
         "iinfo",
         "import_ir_module_from_buffer",
         "import_ir_module",
third_party/ZenDNN (vendored): submodule deleted from af92954683
@@ -6,7 +6,7 @@ from __future__ import annotations

 from collections import deque
 from dataclasses import dataclass, field
-from typing import Any, Callable, Literal, TYPE_CHECKING
+from typing import Any, Callable, TYPE_CHECKING
 from typing_extensions import TypeIs

 import torch.utils._pytree as python_pytree
@@ -28,7 +28,7 @@ if python_pytree._cxx_pytree_dynamo_traceable:
     import optree
     import optree._C

-    import torch.utils._cxx_pytree as cxx_pytree
+    import torch.utils._cxx_pytree as cxx_pytree  # noqa: F401

 if TYPE_CHECKING:
     from torch.utils._cxx_pytree import PyTree
@@ -64,45 +64,69 @@ if python_pytree._cxx_pytree_dynamo_traceable:
     del __func
     del __name

-    @substitute_in_graph(cxx_pytree.tree_is_leaf, can_constant_fold_through=True)
+    @substitute_in_graph(optree.tree_is_leaf, can_constant_fold_through=True)
     def tree_is_leaf(
         tree: PyTree,
         /,
         is_leaf: Callable[[PyTree], bool] | None = None,
+        *,
+        none_is_leaf: bool = False,
+        namespace: str = "",
     ) -> bool:
-        if tree is None or (is_leaf is not None and is_leaf(tree)):
+        if (tree is None and none_is_leaf) or (is_leaf is not None and is_leaf(tree)):
             return True
-        if optree.register_pytree_node.get(type(tree), namespace="torch") is None:  # type: ignore[attr-defined]
+        if optree.register_pytree_node.get(type(tree), namespace=namespace) is None:  # type: ignore[attr-defined]
             return True
         return False

-    @substitute_in_graph(cxx_pytree.tree_iter, can_constant_fold_through=False)
+    @substitute_in_graph(optree.tree_iter, can_constant_fold_through=False)
     def tree_iter(
         tree: PyTree,
         /,
         is_leaf: Callable[[PyTree], bool] | None = None,
+        *,
+        none_is_leaf: bool = False,
+        namespace: str = "",
     ) -> Iterable[Any]:
         stack = [tree]
         while stack:
             node = stack.pop()
-            if tree_is_leaf(node, is_leaf=is_leaf):
+            if tree_is_leaf(
+                node,
+                is_leaf=is_leaf,
+                none_is_leaf=none_is_leaf,
+                namespace=namespace,
+            ):
                 yield node
                 continue

             children, *_ = optree.tree_flatten_one_level(
                 node,
                 is_leaf=is_leaf,
-                none_is_leaf=True,
-                namespace="torch",
+                none_is_leaf=none_is_leaf,
+                namespace=namespace,
             )
             stack.extend(reversed(children))

     __all__ += ["tree_iter"]

-    @substitute_in_graph(cxx_pytree.tree_leaves, can_constant_fold_through=True)
+    @substitute_in_graph(optree.tree_leaves, can_constant_fold_through=True)
     def tree_leaves(
         tree: PyTree,
         /,
         is_leaf: Callable[[PyTree], bool] | None = None,
+        *,
+        none_is_leaf: bool = False,
+        namespace: str = "",
     ) -> list[Any]:
-        return list(tree_iter(tree, is_leaf=is_leaf))
+        return list(
+            tree_iter(
+                tree,
+                is_leaf=is_leaf,
+                none_is_leaf=none_is_leaf,
+                namespace=namespace,
+            )
+        )

     __all__ += ["tree_leaves"]
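The polyfills now thread none_is_leaf and namespace through instead of hard-coding optree's "torch" namespace. For reference, optree's documented defaults treat None as an internal node rather than a leaf:

import optree

tree = {"a": 1, "b": (2, None)}
print(list(optree.tree_iter(tree)))                     # [1, 2]
print(list(optree.tree_iter(tree, none_is_leaf=True)))  # [1, 2, None]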
@@ -127,12 +151,12 @@ if python_pytree._cxx_pytree_dynamo_traceable:
         _metadata: Any
         _entries: tuple[Any, ...]
         _unflatten_func: Callable[[Any | None, Iterable[PyTree]], PyTree] | None
+        none_is_leaf: bool
+        namespace: str

         num_nodes: int = field(init=False)
         num_leaves: int = field(init=False)
         num_children: int = field(init=False)
-        none_is_leaf: Literal[True] = field(init=False)
-        namespace: Literal["torch"] = field(init=False)

         def __post_init__(self) -> None:
             if self._type is None:
@@ -152,8 +176,6 @@ if python_pytree._cxx_pytree_dynamo_traceable:
             object.__setattr__(self, "num_nodes", num_nodes)
             object.__setattr__(self, "num_leaves", num_leaves)
             object.__setattr__(self, "num_children", num_children)
-            object.__setattr__(self, "none_is_leaf", True)
-            object.__setattr__(self, "namespace", "torch")

         def __repr__(self) -> str:
             def helper(treespec: PyTreeSpec) -> str:
@@ -168,6 +190,7 @@ if python_pytree._cxx_pytree_dynamo_traceable:
                 ]
                 if (
                     treespec.type in BUILTIN_TYPES
+                    or (treespec.type is type(None) and not self.none_is_leaf)
                     or optree.is_namedtuple_class(treespec.type)
                     or optree.is_structseq_class(treespec.type)
                 ):
@@ -181,9 +204,12 @@ if python_pytree._cxx_pytree_dynamo_traceable:
                     f"[{', '.join(children_representations)}])"
                 )

-            return (
-                f"PyTreeSpec({helper(self)}, NoneIsLeaf, namespace={self.namespace!r})"
-            )
+            inner = [
+                str(helper(self)),
+                *(["NoneIsLeaf"] if self.none_is_leaf else []),
+                f"namespace={self.namespace!r}",
+            ]
+            return f"PyTreeSpec({', '.join(inner)})"

         def __len__(self) -> int:
             return self.num_leaves
@@ -228,8 +254,8 @@ if python_pytree._cxx_pytree_dynamo_traceable:

                 children, metadata, *_ = optree.tree_flatten_one_level(
                     node,
-                    none_is_leaf=True,
-                    namespace="torch",
+                    none_is_leaf=self.none_is_leaf,
+                    namespace=self.namespace,
                 )
                 if len(children) != treespec.num_children:
                     raise ValueError(
@@ -277,8 +303,8 @@ if python_pytree._cxx_pytree_dynamo_traceable:
                 # node_type is treespec.type
                 children, metadata, *_ = optree.tree_flatten_one_level(
                     node,
-                    none_is_leaf=True,
-                    namespace="torch",
+                    none_is_leaf=self.none_is_leaf,
+                    namespace=self.namespace,
                 )
                 if (
                     node_type
@@ -320,25 +346,40 @@ if python_pytree._cxx_pytree_dynamo_traceable:
             assert callable(self._unflatten_func)
             return self._unflatten_func(self._metadata, subtrees)

-    _LEAF_SPEC = PyTreeSpec((), None, None, (), None)
-
     def _is_pytreespec_instance(obj: Any, /) -> TypeIs[PyTreeSpec]:
         return isinstance(obj, PyTreeSpec)

     @substitute_in_graph(  # type: ignore[arg-type]
-        cxx_pytree.tree_flatten,
+        optree.tree_flatten,
         # We need to disable constant folding here because we want the function to reference the
         # PyTreeSpec class defined above, not the one in the C++ module.
         can_constant_fold_through=False,
     )
     def tree_flatten(
         tree: PyTree,
         /,
         is_leaf: Callable[[PyTree], bool] | None = None,
+        *,
+        none_is_leaf: bool = False,
+        namespace: str = "",
     ) -> tuple[list[Any], PyTreeSpec]:
         def helper(node: PyTree, leaves: list[Any]) -> PyTreeSpec:
-            if tree_is_leaf(node, is_leaf=is_leaf):
+            if tree_is_leaf(
+                node,
+                is_leaf=is_leaf,
+                none_is_leaf=none_is_leaf,
+                namespace=namespace,
+            ):
                 leaves.append(node)
-                return _LEAF_SPEC
+                return PyTreeSpec(
+                    (),
+                    None,
+                    None,
+                    (),
+                    None,
+                    none_is_leaf=none_is_leaf,
+                    namespace=namespace,
+                )

             (
                 children,
@@ -348,13 +389,21 @@ if python_pytree._cxx_pytree_dynamo_traceable:
             ) = optree.tree_flatten_one_level(
                 node,
                 is_leaf=is_leaf,
-                none_is_leaf=True,
-                namespace="torch",
+                none_is_leaf=none_is_leaf,
+                namespace=namespace,
             )

             # Recursively flatten the children
             subspecs = tuple(helper(child, leaves) for child in children)
-            return PyTreeSpec(subspecs, type(node), metadata, entries, unflatten_func)  # type: ignore[arg-type]
+            return PyTreeSpec(
+                subspecs,
+                type(node),
+                metadata,
+                entries,
+                unflatten_func,
+                none_is_leaf=none_is_leaf,
+                namespace=namespace,
+            )  # type: ignore[arg-type]

         leaves: list[Any] = []
         treespec = helper(tree, leaves)
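The same defaults apply to the flatten path the helper above reimplements. A short check of the two behaviors, using optree's public API:

import optree

leaves, spec = optree.tree_flatten({"a": 1, "b": None})
print(leaves)  # [1]: None is an internal node by default
leaves, spec = optree.tree_flatten({"a": 1, "b": None}, none_is_leaf=True)
print(leaves)  # [1, None]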
@@ -363,26 +412,35 @@ if python_pytree._cxx_pytree_dynamo_traceable:
     __all__ += ["tree_flatten"]

     @substitute_in_graph(  # type: ignore[arg-type]
-        cxx_pytree.tree_structure,
+        optree.tree_structure,
         # We need to disable constant folding here because we want the function to reference the
         # PyTreeSpec class defined above, not the one in the C++ module.
         can_constant_fold_through=False,
     )
     def tree_structure(
         tree: PyTree,
         /,
         is_leaf: Callable[[PyTree], bool] | None = None,
+        *,
+        none_is_leaf: bool = False,
+        namespace: str = "",
     ) -> PyTreeSpec:
-        return tree_flatten(tree, is_leaf=is_leaf)[1]  # type: ignore[return-value]
+        return tree_flatten(  # type: ignore[return-value]
+            tree,
+            is_leaf=is_leaf,
+            none_is_leaf=none_is_leaf,
+            namespace=namespace,
+        )[1]

     __all__ += ["tree_structure"]

     @substitute_in_graph(  # type: ignore[arg-type]
-        cxx_pytree.tree_unflatten,
+        optree.tree_unflatten,
         # We need to disable constant folding here because we want the function to reference the
         # PyTreeSpec class defined above, not the one in the C++ module.
         can_constant_fold_through=False,
     )
-    def tree_unflatten(leaves: Iterable[Any], treespec: PyTreeSpec) -> PyTree:
+    def tree_unflatten(treespec: PyTreeSpec, leaves: Iterable[Any]) -> PyTree:
         if not _is_pytreespec_instance(treespec):
             raise TypeError(
                 f"tree_unflatten(leaves, treespec): Expected `treespec` to be instance of "
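The signature flip above is the OpTree/JAX convention that the test changes earlier in this compare also account for: optree.tree_unflatten takes the treespec first, while torch.utils._pytree keeps leaves first. Side by side:

import optree
import torch.utils._pytree as python_pytree

leaves, spec = optree.tree_flatten({"a": 1, "b": 2})
assert optree.tree_unflatten(spec, leaves) == {"a": 1, "b": 2}  # treespec first

leaves2, spec2 = python_pytree.tree_flatten({"a": 1, "b": 2})
assert python_pytree.tree_unflatten(leaves2, spec2) == {"a": 1, "b": 2}  # leaves first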
@@ -392,29 +450,57 @@ if python_pytree._cxx_pytree_dynamo_traceable:

     __all__ += ["tree_unflatten"]

-    @substitute_in_graph(cxx_pytree.tree_map, can_constant_fold_through=True)
+    @substitute_in_graph(optree.tree_map, can_constant_fold_through=True)
     def tree_map(
         func: Callable[..., Any],
         tree: PyTree,
         /,
         *rests: PyTree,
         is_leaf: Callable[[PyTree], bool] | None = None,
+        none_is_leaf: bool = False,
+        namespace: str = "",
     ) -> PyTree:
-        leaves, treespec = tree_flatten(tree, is_leaf=is_leaf)
+        leaves, treespec = tree_flatten(
+            tree,
+            is_leaf=is_leaf,
+            none_is_leaf=none_is_leaf,
+            namespace=namespace,
+        )
         flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
         return treespec.unflatten(map(func, *flat_args))

     __all__ += ["tree_map"]

-    @substitute_in_graph(cxx_pytree.tree_map_, can_constant_fold_through=True)
+    @substitute_in_graph(optree.tree_map_, can_constant_fold_through=True)
     def tree_map_(
         func: Callable[..., Any],
         tree: PyTree,
         /,
         *rests: PyTree,
         is_leaf: Callable[[PyTree], bool] | None = None,
+        none_is_leaf: bool = False,
+        namespace: str = "",
     ) -> PyTree:
-        leaves, treespec = tree_flatten(tree, is_leaf=is_leaf)
+        leaves, treespec = tree_flatten(
+            tree,
+            is_leaf=is_leaf,
+            none_is_leaf=none_is_leaf,
+            namespace=namespace,
+        )
         flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
         deque(map(func, *flat_args), maxlen=0)  # consume and exhaust the iterable
         return tree

     __all__ += ["tree_map_"]

+    _none_unflatten = optree.register_pytree_node.get(type(None)).unflatten_func  # type: ignore[union-attr]
+
+    @substitute_in_graph(  # type: ignore[arg-type]
+        _none_unflatten,
+        can_constant_fold_through=True,
+        skip_signature_check=True,
+    )
+    def none_unflatten(_: None, children: Iterable[Any], /) -> None:
+        if len(list(children)) != 0:
+            raise ValueError("Expected no children.")
+        return None
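The new none_unflatten polyfill wraps the handler optree registers for type(None), looked up via the same register_pytree_node.get call the surrounding code uses. A sketch of what that registration exposes (the .unflatten_func access mirrors the polyfill above):

import optree

entry = optree.register_pytree_node.get(type(None))
print(entry is not None)               # None has a default registration
# Rebuilding a None node from zero children, as the polyfill does:
print(entry.unflatten_func(None, []))  # None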
@@ -264,7 +264,8 @@ class OutputComparisonLogger(OutputLogger):
         # fmt: on
         if not self.enabled:
             return x
-        assert isinstance(x, torch.Tensor), "non-tensor inputs not yet supported"
+        if not isinstance(x, torch.Tensor):
+            raise AssertionError("non-tensor inputs not yet supported")
         if self.save_activations:
             # save the activation, for debugging
             self.stats.append(x.detach())
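This hunk starts a long series below that mechanically converts bare asserts in the FX numeric suite into explicit raises. A plausible motivation (an assumption, not stated in the diff): assert statements are stripped when Python runs with -O, while an explicit raise always fires.

# Assumption about the motivation: asserts vanish under `python -O`.
def check_assert(x):
    assert isinstance(x, str), "not a str"  # skipped with -O

def check_raise(x):
    if not isinstance(x, str):
        raise AssertionError("not a str")   # always enforced

check_raise("ok")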
@@ -595,9 +596,8 @@ def _extract_logger_info_one_model(
             key = mod.ref_name
             if key not in results:
                 results[key] = {}
-            assert mod.model_name not in results[key], (
-                f"{mod.model_name} is already present in results"
-            )
+            if mod.model_name in results[key]:
+                raise AssertionError(f"{mod.model_name} is already present in results")
             if mod.results_type not in results[key]:
                 results[key][mod.results_type] = {}
             if mod.model_name not in results[key][mod.results_type]:
@@ -809,12 +809,10 @@ def extend_logger_results_with_comparison(
     """
     for results_type_to_results in results.values():
         for model_name_to_results in results_type_to_results.values():
-            assert model_name_1 in model_name_to_results, (
-                f"{model_name_1} not found in results"
-            )
-            assert model_name_2 in model_name_to_results, (
-                f"{model_name_2} not found in results"
-            )
+            if model_name_1 not in model_name_to_results:
+                raise AssertionError(f"{model_name_1} not found in results")
+            if model_name_2 not in model_name_to_results:
+                raise AssertionError(f"{model_name_2} not found in results")

             results_1 = model_name_to_results[model_name_1]
             results_2 = model_name_to_results[model_name_2]
@@ -832,7 +830,8 @@ def extend_logger_results_with_comparison(
             ):
                 result_1 = cur_result_1
                 break
-            assert result_1 is not None
+            if result_1 is None:
+                raise AssertionError("Expected result_1 to be not None")

             values_1 = result_1["values"]
             values_2 = result_2["values"]
@@ -150,7 +150,8 @@ class _NSGraphMatchableSubgraphsIterator:
         if node.op == "call_function":
             return node.target not in self.non_matchable_functions
         elif node.op == "call_module":
-            assert isinstance(node.target, str)
+            if not isinstance(node.target, str):
+                raise AssertionError(f"Expected str, got {type(node.target)}")
             target_mod = getattr_from_fqn(self.gm, node.target)
             return not any(
                 isinstance(target_mod, t)  # type: ignore[arg-type]
@@ -228,16 +229,19 @@ def _get_subgraph_relationship_type(
         else:
             return SubgraphTypeRelationship.NOT_RELATED
     elif node_a.op == "call_module":
-        assert (
-            subgraph_a.base_op_node == subgraph_a.start_node
-            and subgraph_b.base_op_node == subgraph_b.start_node
-        ), (
-            "Matching call_module patterns where base_op_node != start_node is not supported yet"
-        )
+        if (
+            subgraph_a.base_op_node != subgraph_a.start_node
+            or subgraph_b.base_op_node != subgraph_b.start_node
+        ):
+            raise AssertionError(
+                "Matching call_module patterns where base_op_node != start_node is not supported yet"
+            )
         # for call_module, we need to look up the modules to do the type check
-        assert isinstance(node_a.target, str)
+        if not isinstance(node_a.target, str):
+            raise AssertionError(f"Expected str, got {type(node_a.target)}")
         mod_a = getattr_from_fqn(gm_a, node_a.target)
-        assert isinstance(node_b.target, str)
+        if not isinstance(node_b.target, str):
+            raise AssertionError(f"Expected str, got {type(node_b.target)}")
         mod_b = getattr_from_fqn(gm_b, node_b.target)

         key = (type(mod_a), type(mod_b))
@@ -312,7 +316,8 @@ def _get_node_target_type(node: Node, gm: GraphModule) -> Optional[NSNodeTargetT
     if node.op in ("call_function", "call_method"):
         return node.target
     elif node.op == "call_module":
-        assert isinstance(node.target, str)
+        if not isinstance(node.target, str):
+            raise AssertionError(f"Expected str, got {type(node.target)}")
         mod = getattr_from_fqn(gm, node.target)
         return type(mod)
     return None
@@ -452,9 +457,10 @@ of subgraphs, and each pair of subgraphs is related to each other."""
             key_name_b = _get_name_for_subgraph(
                 cur_subgraph_b, gm_b, base_name_to_sets_of_related_ops, existing_names_b
             )
-            assert key_name_a == key_name_b, (
-                f"Subgraph names {key_name_a} and {key_name_b} do not match"
-            )
+            if key_name_a != key_name_b:
+                raise AssertionError(
+                    f"Subgraph names {key_name_a} and {key_name_b} do not match"
+                )
             results[key_name_a] = (cur_subgraph_a, cur_subgraph_b)
             continue
         elif cur_subgraph_a is None and cur_subgraph_b is None:
@@ -32,7 +32,8 @@ def _maybe_get_fqn(node: Node, gm: GraphModule) -> Optional[str]:
     # an observer, get the fqn of the node being observed.
     node_to_use_for_fqn = node
     if node.op == "call_module":
-        assert isinstance(node.target, str)
+        if not isinstance(node.target, str):
+            raise AssertionError(f"Expected str, got {type(node.target)}")
         module = getattr_from_fqn(gm, node.target)
         if _is_activation_post_process(module):
             node_to_use_for_fqn = get_normalized_nth_input(node, gm, 0)
@@ -348,7 +349,8 @@ def _insert_dtype_cast_after_node(
             new_dtype_cast_name,
         )
     else:
-        assert dtype_cast_mod_cls
+        if not dtype_cast_mod_cls:
+            raise AssertionError("Expected dtype_cast_mod_cls to be not None")
         dtype_cast_mod = dtype_cast_mod_cls()
         setattr(gm_b, new_dtype_cast_name, dtype_cast_mod)
         return graph_c.create_node(
@@ -373,7 +375,8 @@ def _insert_dtype_cast_after_node(
             )
             results.append(new_dtype_cast_node)
         else:
-            assert dtype_cast_mod_cls
+            if not dtype_cast_mod_cls:
+                raise AssertionError("Expected dtype_cast_mod_cls to be not None")
             dtype_cast_mod = dtype_cast_mod_cls()
             setattr(gm_b, new_dtype_cast_name, dtype_cast_mod)
             new_dtype_cast_node = graph_c.create_node(
@@ -412,10 +415,8 @@ def _copy_node_from_a_to_c(
         )
         return node_a_copy
     elif node_a.op == "call_method":
-        assert node_a.target in (
-            "dequantize",
-            "to",
-        ), f"target {node_a.target} is not implemented"
+        if node_a.target not in ("dequantize", "to"):
+            raise AssertionError(f"target {node_a.target} is not implemented")
         if node_a.target == "dequantize":
             arg_copy = _copy_node_from_a_to_c(
                 get_normalized_nth_input(node_a, gm_a, 0), gm_a, gm_b, graph_c
@@ -535,7 +536,8 @@ def _insert_copy_of_subgraph_a_after_input_node_c(
     """
    TODO(before land): real docblock
     """
-    assert isinstance(input_node_c, (Node, list))
+    if not isinstance(input_node_c, (Node, list)):
+        raise AssertionError(f"Expected Node or list, got {type(input_node_c)}")

     # create a sequential list of the subgraphs' nodes from start to end,
     # because we need to add the nodes to graph C in non-reverse order
@@ -621,7 +623,8 @@ def _insert_copy_of_node_a_after_input_node_c(
     if isinstance(input_node_c, Node):
         graph_c = input_node_c.graph
     else:
-        assert isinstance(input_node_c, list)
+        if not isinstance(input_node_c, list):
+            raise AssertionError(f"Expected list, got {type(input_node_c)}")
         graph_c = input_node_c[0].graph

     norm_args_kwargs = node_a.normalized_arguments(
@@ -645,9 +648,10 @@ def _insert_copy_of_node_a_after_input_node_c(
             return arg
         elif isinstance(kwarg_val, (list, tuple)):
             for el in kwarg_val:
-                assert not isinstance(el, Node), (
-                    "handling of Node inside list is not implemented"
-                )
+                if isinstance(el, Node):
+                    raise AssertionError(
+                        "handling of Node inside list is not implemented"
+                    )
             return arg
         else:
             raise AssertionError(
@@ -684,7 +688,8 @@ def _insert_copy_of_node_a_after_input_node_c(
         # if target is a module, we point to the module from gm_b
         new_mod_copy_name = get_new_attr_name_with_prefix(node_name_prefix)(gm_b)
         # fetch the corresponding module from gm_a
-        assert isinstance(node_a.target, str)
+        if not isinstance(node_a.target, str):
+            raise AssertionError(f"Expected str, got {type(node_a.target)}")
         mod_a = getattr_from_fqn(gm_a, node_a.target)
         setattr(gm_b, new_mod_copy_name, mod_a)
         node_a_shadows_c = graph_c.create_node(
@ -696,7 +701,8 @@ def _insert_copy_of_node_a_after_input_node_c(
|
||||
)
|
||||
return node_a_shadows_c
|
||||
else:
|
||||
assert node_a.op in ("call_function", "call_method")
|
||||
if node_a.op not in ("call_function", "call_method"):
|
||||
raise AssertionError(f"Unexpected op: {node_a.op}")
|
||||
node_a_shadows_c = graph_c.create_node(
|
||||
node_a.op,
|
||||
node_a.target,
|
||||
@ -791,7 +797,8 @@ def create_a_shadows_b(
|
||||
ref_node_type_b,
|
||||
) = start_node_b_to_matched_subgraph_a_and_name[node_b]
|
||||
else:
|
||||
assert node_b_is_end_node
|
||||
if not node_b_is_end_node:
|
||||
raise AssertionError("Expected node_b_is_end_node to be not false")
|
||||
(
subgraph_a,
ref_name,
@ -1001,7 +1008,10 @@ def create_a_shadows_b(
)
input_logger: Union[Node, list[Node]] = dtype_cast_node
else:
assert isinstance(dtype_cast_node, list)
if not isinstance(dtype_cast_node, list):
raise AssertionError(
f"Expected list, got {type(dtype_cast_node)}"
)
new_loggers = []
for dtype_cast_idx, dtype_cast_node_inner in enumerate(
dtype_cast_node
@ -1083,7 +1093,10 @@ def create_a_shadows_b(
input_logger_mod.ref_node_name = cur_node.name
else:
# pyrefly: ignore # unbound-name
assert isinstance(input_logger, list)
if not isinstance(input_logger, list):
raise AssertionError(
f"Expected list, got {type(input_logger)}"
)
# pyrefly: ignore # unbound-name
for input_logger_inner in input_logger:
input_logger_mod = getattr(gm_b, input_logger_inner.name)

@ -144,9 +144,11 @@ def _get_dedup_subgraphs(matches: dict[str, _MatchResult]) -> dict[str, list[Nod
seen_nodes.add(node_or_tuple)

else:
assert isinstance(node_or_tuple, tuple)
if not isinstance(node_or_tuple, tuple):
raise AssertionError(f"Expected tuple, got {type(node_or_tuple)}")
for node in node_or_tuple:
assert isinstance(node, Node)
if not isinstance(node, Node):
raise AssertionError(f"Expected Node, got {type(node)}")
if node in seen_nodes:
was_seen = True
seen_nodes.add(node)
@ -160,7 +162,10 @@ def _get_dedup_subgraphs(matches: dict[str, _MatchResult]) -> dict[str, list[Nod
if len(cur_match[1]) == 1:
list_of_nodes = cur_match[1]
else:
assert len(cur_match[1]) == 2
if len(cur_match[1]) != 2:
raise ValueError(
f"Expected cur_match[1] to have length 2, got {len(cur_match[1])}"
)
# either (a, b), or ((a, b), c) or (c, (a, b))
# cannot make any assumptions on order, not clear what the
# _find_matches function is doing to populate this
@ -181,13 +186,12 @@ def _get_dedup_subgraphs(matches: dict[str, _MatchResult]) -> dict[str, list[Nod
last_node = n
else:
mid_node = n
assert (
first_node is not None
and mid_node is not None
and last_node is not None
)
assert mid_node.args[0] is first_node
assert last_node.args[0] is mid_node
if first_node is None or mid_node is None or last_node is None:
raise AssertionError("Expected all nodes to be non-None")
if mid_node.args[0] is not first_node:
raise AssertionError("Expected mid_node.args[0] to be first_node")
if last_node.args[0] is not mid_node:
raise AssertionError("Expected last_node.args[0] to be mid_node")
return [last_node, mid_node, first_node]

if isinstance(cur_match[1][0], Node) and isinstance(cur_match[1][1], Node):
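The first/mid/last checks above encode a strict three-node chain in which each node consumes the node before it. A minimal sketch of that invariant on a toy torch.fx graph (module and names are illustrative):

import torch
import torch.fx

class Toy(torch.nn.Module):
    def forward(self, x):
        a = x + 1  # first_node
        b = a * 2  # mid_node
        c = b - 3  # last_node
        return c

gm = torch.fx.symbolic_trace(Toy())
first_node, mid_node, last_node = [
    n for n in gm.graph.nodes if n.op == "call_function"
]
# mirrors the raises above: each node's first arg is the preceding node
print(mid_node.args[0] is first_node)  # True
print(last_node.args[0] is mid_node)   # True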
@ -377,7 +381,10 @@ def create_submodule_from_subgraph(
# the current implementation is simplistic and cannot handle
# ops with two or more arguments which need to be passed from
# the previous op, so we assert them out
assert cur_node_orig.target not in BINARY_FUNCTIONS
if cur_node_orig.target in BINARY_FUNCTIONS:
raise AssertionError(
f"Unexpected binary function target: {cur_node_orig.target}"
)

# at this point in the code, cur_node_copy is pointing to the copy
# of the previous node
@ -435,9 +442,10 @@ def create_submodule_from_subgraph(
break

# go to next node
assert len(cur_node_orig.users.keys()) == 1, (
f"{cur_node_orig} has more than 1 users, not supported yet"
)
if len(cur_node_orig.users.keys()) != 1:
raise AssertionError(
f"{cur_node_orig} has more than 1 users, not supported yet"
|
||||
)
cur_node_orig = next(iter(cur_node_orig.users.keys()))
cur_iteration += 1
if cur_iteration > iteration_limit:
@ -494,7 +502,8 @@ def create_one_transformed_and_logged_copy_of_subgraph(
)

attr_name = _get_attr_name(subgraph_idx, subgraph_candidate_idx)
assert not hasattr(mt, attr_name)
if hasattr(mt, attr_name):
raise AssertionError(f"Unexpected attribute '{attr_name}' found in {mt}")
setattr(mt, attr_name, logger_mod_orig)
with mt.graph.inserting_after(last_node):
new_node = mt.graph.call_module(attr_name, args=(last_node,), kwargs={})
@ -537,9 +546,10 @@ def create_one_transformed_and_logged_copy_of_subgraph(
"prepare_custom_config",
"qconfig_mapping",
]:
assert kwarg_name not in custom_prepare_kwargs, (
f"cannot specify {kwarg_name} in custom_prepare_kwargs"
)
if kwarg_name in custom_prepare_kwargs:
raise AssertionError(
f"cannot specify {kwarg_name} in custom_prepare_kwargs"
)
prepare_kwargs: dict[str, Any] = {
"example_inputs": example_inputs,
"qconfig_mapping": qconfig_mapping,
@ -551,7 +561,8 @@ def create_one_transformed_and_logged_copy_of_subgraph(

# attach the wrapper to the model
attr_name = _get_attr_wrapper_name(subgraph_idx, subgraph_candidate_idx)
assert not hasattr(mt, attr_name)
if hasattr(mt, attr_name):
raise AssertionError(f"Unexpected attribute '{attr_name}' found in {mt}")
setattr(mt, attr_name, orig_mod_copy_wrapped)

# add a call to the wrapper module from the parent graph
@ -600,7 +611,8 @@ def create_one_transformed_and_logged_copy_of_subgraph(
)

attr_name = _get_attr_name(subgraph_idx, subgraph_candidate_idx)
assert not hasattr(mt, attr_name)
if hasattr(mt, attr_name):
raise AssertionError(f"Unexpected attribute '{attr_name}' found in {mt}")
setattr(mt, attr_name, logger_mod_orig)
with mt.graph.inserting_after(new_node):
logger = mt.graph.call_module(
@ -824,7 +836,8 @@ def create_add_loggers_graph(
):
new_shadow_mod = maybe_shadow_mod
break
assert new_shadow_mod is not None
if new_shadow_mod is None:
raise AssertionError("Expected new_shadow_mod to be non-None")
orig_first_node_to_shadow_in_node[first_node] = new_shadow_mod
orig_first_node_to_shadow_out_node[first_node] = new_shadow_mod

@ -850,7 +863,10 @@ def create_add_loggers_graph(
fqn,
)
attr_name = _get_attr_name(cur_subgraph_idx, subgraph_candidate_idx)
assert not hasattr(model, attr_name)
if hasattr(model, attr_name):
raise AssertionError(
f"Unexpected attribute '{attr_name}' found in {model}"
)
setattr(model, attr_name, logger_mod_orig)
insertion_point = last_node
with model.graph.inserting_after(insertion_point):
@ -887,9 +903,15 @@ def create_add_loggers_graph(
# since now only linear subgraphs are supported, all nodes
# except the last one must have only one user
if cur_node_orig != last_node:
assert len(cur_node_orig.users.keys()) == 1
if len(cur_node_orig.users.keys()) != 1:
raise AssertionError(
f"Expected exactly 1, but got {len(cur_node_orig.users)}"
|
||||
)
cur_node_orig = next(iter(cur_node_orig.users.keys()))
assert not cur_node_orig.name.startswith(SHADOW_NODE_NAME_PREFIX)
if cur_node_orig.name.startswith(SHADOW_NODE_NAME_PREFIX):
raise AssertionError(
"cur_node_orig should not start with SHADOW_NODE_NAME_PREFIX"
)
insertion_point = cur_node_copy

# add a comparison logger after last_node's copy
@ -905,7 +927,10 @@ def create_add_loggers_graph(
fqn,
)
attr_name = _get_attr_name(cur_subgraph_idx, subgraph_candidate_idx)
assert not hasattr(model, attr_name)
if hasattr(model, attr_name):
raise AssertionError(
f"Unexpected attribute '{attr_name}' found in {model}"
)
setattr(model, attr_name, logger_mod_orig)
with model.graph.inserting_after(insertion_point):
logger = model.graph.call_module(
@ -979,7 +1004,8 @@ def create_add_loggers_graph(
return prev_shadow_output

cur_shadow_input = orig_first_node_to_shadow_in_node[first_node]
assert cur_shadow_input is not None
if cur_shadow_input is None:
raise AssertionError("Expected cur_shadow_input to be non-None")
cur_shadow_input.args = tree_map(
maybe_remap_node_to_shadow, cur_shadow_input.args
)
@ -1019,7 +1045,8 @@ def _get_weight_info_from_shadow_wrapper(shadow_wrapper: torch.nn.Module):
# we have `w2_0`, and are navigating this subgraph
# to get `_input_scale_1` and `_input_zero_point_1`

assert len(shadow_n.users) == 1
if len(shadow_n.users) != 1:
raise AssertionError(f"Expected exactly 1, got {len(shadow_n.users)}")
|
||||
quant_node = next(iter(shadow_n.users.keys()))
new_args: Any = None
if quant_node.target == torch.quantize_per_channel:
@ -1028,7 +1055,10 @@ def _get_weight_info_from_shadow_wrapper(shadow_wrapper: torch.nn.Module):
zp_val = getattr_from_fqn(shadow_wrapper, zp_node.target)
new_args = (scale_val, zp_val, axis, dtype)
else:
assert quant_node.target == torch.quantize_per_tensor
if quant_node.target != torch.quantize_per_tensor:
raise AssertionError(
f"Expected torch.quantize_per_tensor, but got {quant_node.target}"
)
_weight, scale_node, zp_node, dtype = quant_node.args
scale_val = getattr_from_fqn(shadow_wrapper, scale_node.target)
zp_val = getattr_from_fqn(shadow_wrapper, zp_node.target)

@ -167,7 +167,8 @@ def end_node_matches_reversed_fusion(
elif cur_node.op == "call_module":
fusion_el_is_mod = isinstance(cur_fusion_el, type)
if fusion_el_is_mod:
assert isinstance(cur_node.target, str)
if not isinstance(cur_node.target, str):
raise AssertionError(f"Expected str, got {type(cur_node.target)}")
target_mod = getattr_from_fqn(gm, cur_node.target)
if not isinstance(cur_fusion_el, type):
return False
@ -190,7 +191,10 @@ def end_node_matches_reversed_fusion(
if cur_node.target != cur_fusion_el:
return False
else:
assert isinstance(cur_fusion_el, tuple)
if not isinstance(cur_fusion_el, tuple):
raise AssertionError(
f"Expected tuple, got {type(cur_fusion_el)}"
)
if cur_node.target != cur_fusion_el[0]:
return False
elif len(cur_node.args) < 2:
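For context on the type/tuple branches above: reversed fusion patterns mix module classes, plain callables, and tuples whose first entry is the callable. A simplified, hypothetical matcher for a single pattern element:

import torch
import torch.nn as nn

def matches_fusion_el(target, fusion_el):
    if isinstance(fusion_el, type):   # module class, e.g. nn.ReLU
        return isinstance(target, fusion_el)
    if isinstance(fusion_el, tuple):  # tuple: compare against first entry
        return target == fusion_el[0]
    return target == fusion_el        # plain callable

print(matches_fusion_el(nn.ReLU(), nn.ReLU))         # True
print(matches_fusion_el(torch.add, (torch.add, 0)))  # True
print(matches_fusion_el(torch.mul, torch.add))       # False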
@ -61,7 +61,8 @@ def get_node_first_input_and_output_type(
return (NodeInputOrOutputType.INT8, NodeInputOrOutputType.INT8)
elif node.target in FUNS_IO_TYPE_FP32_OR_INT8:
first_arg = get_normalized_nth_input(node, gm, 0)
assert isinstance(first_arg, Node)
if not isinstance(first_arg, Node):
raise AssertionError(f"Expected Node, got {type(first_arg)}")
(
_prev_node_input_type,
prev_node_output_type,
@ -73,8 +74,11 @@ def get_node_first_input_and_output_type(
return (NodeInputOrOutputType.UNKNOWN, NodeInputOrOutputType.UNKNOWN)

elif node.op == "call_module":
assert node.op == "call_module"
assert isinstance(node.target, str)
if node.op != "call_module":
raise AssertionError(f"Expected call_module, got '{node.op}'")
if not isinstance(node.target, str):
raise AssertionError(f"Expected str, but got {type(node.target)}")

mod = getattr_from_fqn(gm, node.target)
is_known_fp32_or_int8_input_module = any(
isinstance(mod, target_type) # type: ignore[arg-type]
@ -87,7 +91,8 @@ def get_node_first_input_and_output_type(
# A logger or observer's input and output type is the output
# type of the preceding node.
first_arg = get_normalized_nth_input(node, gm, 0)
assert isinstance(first_arg, Node)
if not isinstance(first_arg, Node):
raise AssertionError(f"Expected Node, got {type(first_arg)}")
(
_prev_node_input_type,
prev_node_output_type,
@ -116,7 +121,8 @@ def get_node_first_input_and_output_type(
# So, we look up the output type of the previous node and return that
# as the input type of this node instance.
prev_node = get_normalized_nth_input(node, gm, 0)
assert isinstance(prev_node, Node)
if not isinstance(prev_node, Node):
raise AssertionError(f"Expected Node, got {type(prev_node)}")
(
_prev_node_input_type,
prev_node_output_type,
@ -131,7 +137,8 @@ def get_node_first_input_and_output_type(
# as the input type of this node instance. We also look up the target
# of to and return the correct output type.
prev_node = get_normalized_nth_input(node, gm, 0)
assert isinstance(prev_node, Node)
if not isinstance(prev_node, Node):
raise AssertionError(f"Expected Node, got {type(prev_node)}")
(
_prev_node_input_type,
prev_node_output_type,
@ -140,15 +147,17 @@ def get_node_first_input_and_output_type(
)

cur_node_dtype_target = get_normalized_nth_input(node, gm, 1)
assert cur_node_dtype_target is torch.float16, (
f"{cur_node_dtype_target} handling needs to be added"
)
if cur_node_dtype_target is not torch.float16:
raise AssertionError(
f"{cur_node_dtype_target} handling needs to be added"
)

return (prev_node_output_type, NodeInputOrOutputType.FP16)

elif node.target in METHS_IO_TYPE_FP32_OR_INT8:
first_arg = get_normalized_nth_input(node, gm, 0)
assert isinstance(first_arg, Node)
if not isinstance(first_arg, Node):
raise AssertionError(f"Expected Node, got {type(first_arg)}")
(
_prev_node_input_type,
prev_node_output_type,
@ -181,8 +190,14 @@ def get_node_input_qparams(
def _get_scale_zp_from_function_args(node, gm, scale_arg_idx, zp_arg_idx):
scale_node = get_normalized_nth_input(node, gm, scale_arg_idx)
zp_node = get_normalized_nth_input(node, gm, zp_arg_idx)
assert isinstance(scale_node, Node) and isinstance(scale_node.target, str)
assert isinstance(zp_node, Node) and isinstance(zp_node.target, str)
if not isinstance(scale_node, Node):
raise AssertionError(f"Expected Node, got {type(scale_node)}")
if not isinstance(scale_node.target, str):
raise AssertionError(f"Expected str, got {type(scale_node.target)}")
if not isinstance(zp_node, Node):
raise AssertionError(f"Expected Node, got {type(zp_node)}")
if not isinstance(zp_node.target, str):
raise AssertionError(f"Expected str, got {type(zp_node.target)}")
scale_obj = getattr_from_fqn(gm, scale_node.target)
zp_obj = getattr_from_fqn(gm, zp_node.target)
return (scale_obj, zp_obj)
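The _get_scale_zp_from_function_args hunk above splits each compound assert into per-condition raises. A small sketch of the payoff, with hypothetical names: unlike `assert a and b`, the split version reports exactly which condition failed.

def check_split(scale_node, zp_node):
    if scale_node is None:
        raise AssertionError("scale_node is None")
    if zp_node is None:
        raise AssertionError("zp_node is None")

try:
    check_split(0.5, None)
except AssertionError as e:
    print(e)  # zp_node is None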
@ -200,7 +215,8 @@ def get_node_input_qparams(

elif prev_node.op == "call_module":
# get type of the module
assert isinstance(prev_node.target, str)
if not isinstance(prev_node.target, str):
raise AssertionError(f"Expected str, got {type(prev_node.target)}")
module_obj = getattr_from_fqn(gm, prev_node.target)
if isinstance(
module_obj,
@ -259,15 +275,24 @@ def return_first_non_observer_node(
if node.op == "call_module":
node_obj = getattr_from_fqn(gm, node.target) # type: ignore[arg-type]
if _is_activation_post_process(node_obj):
assert len(node.args) == 1
assert isinstance(node.args[0], Node)
if len(node.args) != 1:
raise AssertionError(
f"Expected node.args to have length 1, got {len(node.args)}"
)
if not isinstance(node.args[0], Node):
raise AssertionError(f"Expected Node, got {type(node.args[0])}")
node = node.args[0]
# code duplication intended, not worth refactoring
assert isinstance(node.target, str)
if not isinstance(node.target, str):
raise AssertionError(f"Expected str, got {type(node.target)}")
node_obj = getattr_from_fqn(gm, node.target)
if _is_activation_post_process(node_obj):
assert len(node.args) == 1
assert isinstance(node.args[0], Node)
if len(node.args) != 1:
raise AssertionError(
f"Expected node.args to have length 1, got {len(node.args)}"
)
if not isinstance(node.args[0], Node):
raise AssertionError(f"Expected Node, got {type(node.args[0])}")
node = node.args[0]
return node

@ -331,7 +356,8 @@ def get_target_type_str(node: Node, gm: GraphModule) -> str:
if node.op in ("call_function", "call_method"):
target_type = torch.typename(node.target)
elif node.op == "call_module":
assert isinstance(node.target, str)
if not isinstance(node.target, str):
raise AssertionError(f"Expected str, got {type(node.target)}")
target_mod = getattr_from_fqn(gm, node.target)
target_type = torch.typename(target_mod)
return target_type
@ -365,7 +391,8 @@ def rekey_logger_info_on_node_name_of_model(
for model_name_to_results in result_type_to_results.values():
for cur_model_name, list_of_results in model_name_to_results.items():
if cur_model_name == model_name:
assert len(list_of_results)
if len(list_of_results) == 0:
raise AssertionError("Expected list_of_results to be not empty")
|
||||
new_layer_name = list_of_results[0]["ref_node_name"]
else:
continue
@ -519,14 +546,20 @@ def get_normalized_nth_input(node: Node, gm: GraphModule, idx: int) -> Node:
)
if norm_args_and_kwargs is not None:
norm_args, norm_kwargs = norm_args_and_kwargs
assert len(norm_args) + len(norm_kwargs) > idx
if len(norm_args) + len(norm_kwargs) <= idx:
raise AssertionError(
f"Index {idx} out of range: total = {len(norm_args) + len(norm_kwargs)}"
)
if idx < len(norm_args):
return norm_args[idx]
else:
# note: in Python 3.7+ dicts are ordered
return list(norm_kwargs.values())[idx]
else:
assert len(node.args) + len(node.kwargs) > idx
if len(node.args) + len(node.kwargs) <= idx:
raise AssertionError(
f"Index {idx} out of range: total = {len(node.args) + len(node.kwargs)}"
)
if idx < len(node.args):
return node.args[idx] # type: ignore[return-value]
else:
@ -536,7 +569,10 @@ def get_normalized_nth_input(node: Node, gm: GraphModule, idx: int) -> Node:
# this RuntimeError happens when node argument normalization
# requires typehints to proceed, such as for torch.add where
# either the first, second or both arguments could be tensors
assert len(node.args) + len(node.kwargs) > idx
if len(node.args) + len(node.kwargs) <= idx:
raise AssertionError(
f"Index {idx} out of range: total = {len(node.args) + len(node.kwargs)}"
) from None
if idx < len(node.args):
return node.args[idx] # type: ignore[return-value]
else:
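The bounds checks above guard a unified positional/keyword lookup. A simplified stand-in for the indexing scheme (not the real get_normalized_nth_input, which operates on torch.fx nodes):

def nth_input(args, kwargs, idx):
    total = len(args) + len(kwargs)
    if total <= idx:
        raise AssertionError(f"Index {idx} out of range: total = {total}")
    if idx < len(args):
        return args[idx]
    # dicts preserve insertion order in Python 3.7+, as noted in the code above
    return list(kwargs.values())[idx - len(args)]

print(nth_input((10, 20), {"alpha": 30}, 2))  # 30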
@ -77,7 +77,8 @@ def get_lstm_mod_weights(mod: nn.Module) -> list[torch.Tensor]:
res.append(param_value)
return res
else:
assert isinstance(mod, nnqd.LSTM), f"type {type(mod)} not handled yet"
if not isinstance(mod, nnqd.LSTM):
raise AssertionError(f"type {type(mod)} not handled yet")
res = []
for weight_value in mod._all_weight_values:
res.append(
@ -92,10 +93,13 @@ def get_lstm_mod_weights(mod: nn.Module) -> list[torch.Tensor]:
def get_conv_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor:
# traverse backwards from the weight arg, accounting for any observers
weight_arg_node = node.args[1]
assert isinstance(weight_arg_node, Node)
if not isinstance(weight_arg_node, Node):
raise AssertionError(f"Expected Node, got {type(weight_arg_node)}")
weight_node = return_first_non_observer_node(weight_arg_node, gm)
assert isinstance(weight_node, Node)
assert weight_node.op == "get_attr"
if not isinstance(weight_node, Node):
raise AssertionError(f"Expected Node, got {type(weight_node)}")
if weight_node.op != "get_attr":
raise AssertionError(f"Expected get_attr, got {weight_node.op}")
weight = getattr_from_fqn(gm, weight_node.target) # type: ignore[arg-type]
return weight.detach()

@ -103,8 +107,10 @@ def get_conv_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor:
def get_qconv_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor:
# qconv state is arg 1
qconv_state_node = node.args[1]
assert isinstance(qconv_state_node, Node)
assert qconv_state_node.op == "get_attr"
if not isinstance(qconv_state_node, Node):
raise AssertionError(f"Expected Node, got {type(qconv_state_node)}")
if qconv_state_node.op != "get_attr":
raise AssertionError(f"Expected get_attr, got {qconv_state_node.op}")
qconv_state_obj = getattr_from_fqn(gm, qconv_state_node.target) # type: ignore[arg-type]
return qconv_state_obj.weight()

@ -115,34 +121,44 @@ def get_linear_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor:
# weight -> obs -> linear
# weight -> to(torch.float16) -> dequantize -> linear
linear_second_arg = node.args[1]
assert isinstance(linear_second_arg, Node)
if not isinstance(linear_second_arg, Node):
raise AssertionError(f"Expected Node, got {type(linear_second_arg)}")

if linear_second_arg.op == "call_module":
# weight -> obs -> linear
weight_arg_node = node.args[1]
assert isinstance(weight_arg_node, Node)
if not isinstance(weight_arg_node, Node):
raise AssertionError(f"Expected Node, got {type(weight_arg_node)}")
weight_node = weight_arg_node.args[0]
assert isinstance(weight_node, Node)
assert weight_node.op == "get_attr"
if not isinstance(weight_node, Node):
raise AssertionError(f"Expected Node, got {type(weight_node)}")
if weight_node.op != "get_attr":
raise AssertionError(f"Expected get_attr, got {weight_node.op}")
weight = getattr_from_fqn(gm, weight_node.target) # type: ignore[arg-type]
return weight.detach()
elif linear_second_arg.op == "call_method":
# weight -> to(torch.float16) -> dequantize -> linear
assert linear_second_arg.op == "call_method"
if linear_second_arg.op != "call_method":
raise AssertionError(f"Expected call_method, got {linear_second_arg.op}")
dequant_node = node.args[1]
assert isinstance(dequant_node, Node)
if not isinstance(dequant_node, Node):
raise AssertionError(f"Expected Node, got {type(dequant_node)}")
to_fp16_node = dequant_node.args[0]
assert isinstance(to_fp16_node, Node)
if not isinstance(to_fp16_node, Node):
raise AssertionError(f"Expected Node, got {type(to_fp16_node)}")
# extract the dtype, so we can cast to it before returning
target_dtype = to_fp16_node.args[1]
weight_node = to_fp16_node.args[0]
assert isinstance(weight_node, Node)
assert weight_node.op == "get_attr"
if not isinstance(weight_node, Node):
raise AssertionError(f"Expected Node, got {type(weight_node)}")
if weight_node.op != "get_attr":
raise AssertionError(f"Expected get_attr, got {weight_node.op}")
weight = getattr_from_fqn(gm, weight_node.target) # type: ignore[arg-type]
# return the weight with fp16 cast
return weight.detach().to(target_dtype)
else:
assert linear_second_arg.op == "get_attr"
if linear_second_arg.op != "get_attr":
raise AssertionError(f"Expected get_attr, got {linear_second_arg.op}")
weight = getattr_from_fqn(gm, linear_second_arg.target) # type: ignore[arg-type]
return weight.detach()
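Each weight extractor above walks from a call node to a get_attr node and fetches the parameter by its fully qualified name. A hedged sketch of that traversal on a toy module (names are illustrative):

import torch
import torch.fx

class Lin(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.w = torch.nn.Parameter(torch.randn(4, 4))

    def forward(self, x):
        return torch.nn.functional.linear(x, self.w)

gm = torch.fx.symbolic_trace(Lin())
linear_node = next(
    n for n in gm.graph.nodes if n.target is torch.nn.functional.linear
)
weight_node = linear_node.args[1]  # weight is arg 1, as in the code above
if weight_node.op != "get_attr":
    raise AssertionError(f"Expected get_attr, got {weight_node.op}")
# plain getattr suffices here; the real getattr_from_fqn handles dotted paths
print(getattr(gm, weight_node.target).shape)  # torch.Size([4, 4])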
@ -150,8 +166,10 @@ def get_linear_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor:
def get_qlinear_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor:
# packed weight is arg 1
packed_weight_node = node.args[1]
assert isinstance(packed_weight_node, Node)
assert packed_weight_node.op == "get_attr"
if not isinstance(packed_weight_node, Node):
raise AssertionError(f"Expected Node, got {type(packed_weight_node)}")
if packed_weight_node.op != "get_attr":
raise AssertionError(f"Expected get_attr, got {packed_weight_node.op}")
packed_weight = getattr_from_fqn(gm, packed_weight_node.target) # type: ignore[arg-type]
# TODO(future PR): why does packed_weight.unpack() not work?
(weight, _bias), _name = packed_weight.__getstate__()
@ -264,7 +282,8 @@ def extract_weight_from_node(

elif node.op == "call_module":
# for call_module, we need to look up the modules to do the type check
assert isinstance(node.target, str)
if not isinstance(node.target, str):
raise AssertionError(f"Expected str, got {type(node.target)}")
mod = getattr_from_fqn(gm, node.target)
module_mapping = op_to_type_to_weight_extraction_fn["call_module"]
for target_mod_type, weight_extraction_fn in module_mapping.items():

@ -2268,8 +2268,6 @@ Call this whenever a new thread is created in order to propagate values from
set_module_attr("has_lapack", at::hasLAPACK() ? Py_True : Py_False));
ASSERT_TRUE(set_module_attr(
"_has_eigen_sparse", at::hasEigenSparse() ? Py_True : Py_False));
ASSERT_TRUE(
set_module_attr("has_zendnn", at::hasZenDNN() ? Py_True : Py_False));

py_module.def("_valgrind_supported_platform", []() {
#if defined(USE_VALGRIND)

@ -3259,7 +3259,7 @@ struct to_ir {
case TK_IN:
return aten::__contains__;
default:
throw std::runtime_error("unknown kind " + std::to_string(kind));
TORCH_CHECK(false, "unknown kind ", kind);
}
}

@ -3306,7 +3306,7 @@ struct to_ir {
case TK_RSHIFT:
return "__rshift__";
default:
throw std::runtime_error("unknown kind " + std::to_string(kind));
TORCH_CHECK(false, "unknown kind ", kind);
}
}

@ -4120,8 +4120,7 @@ struct to_ir {
} else if (kind == aten::ge) {
return aten::le;
}
throw std::runtime_error(
"reverseComparision: unsupported NodeKind. File a bug");
TORCH_CHECK(false, "reverseComparision: unsupported NodeKind. File a bug");
}

// any expression that can produce a SugaredValue is handled here

@ -94,7 +94,7 @@ C10_EXPORT std::string kindToString(int kind) {
TC_FORALL_TOKEN_KINDS(DEFINE_CASE)
#undef DEFINE_CASE
default:
throw std::runtime_error("Unknown kind: " + std::to_string(kind));
TORCH_CHECK(false, "Unknown kind: ", kind);
}
}


@ -167,12 +167,12 @@ Value* TracingState::getValue(const IValue& var) {
// Didn't find it. Bake in a constant
if (ten.requires_grad()) {
pauseTracing();
std::ostringstream oss;
oss << "Cannot insert a Tensor that requires grad as a constant. "
<< "Consider making it a parameter or input, or detaching the gradient\n"
<< "Tensor:\n"
<< ten;
throw std::runtime_error(oss.str());
TORCH_CHECK(
false,
"Cannot insert a Tensor that requires grad as a constant. ",
"Consider making it a parameter or input, or detaching the gradient\n",
"Tensor:\n",
ten);
}

Value* constant = graph->insertConstant(ten);
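The TORCH_CHECK above is visible from Python: tracing a function that closes over a free requires-grad tensor should trip exactly this path. A hedged sketch of the expected behavior:

import torch

free = torch.randn(2, requires_grad=True)  # neither an input nor a parameter

def f(x):
    return x + free  # the tracer must bake `free` in as a constant

try:
    torch.jit.trace(f, (torch.randn(2),))
except RuntimeError as e:
    # "Cannot insert a Tensor that requires grad as a constant. ..."
    print(str(e).splitlines()[0])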
@ -208,15 +208,19 @@ Value* TracingState::getValue(const IValue& var) {
}
}

std::ostringstream oss;
if (var.isFuture()) {
oss << "Tried to trace Future or Object that the tracer was not aware of.";
TORCH_CHECK(
false,
"Tried to trace Future or Object that the tracer was not aware of.");
} else {
oss << "Tried to trace " << var
<< " but it is not part of the active trace. Modules that are called during a trace"
<< " must be registered as submodules of the thing being traced.";
TORCH_CHECK(
false,
"Tried to trace ",
var,
" but it is not part of the active trace. Modules that are called during a trace",
" must be registered as submodules of the thing being traced.");
}
throw std::runtime_error(oss.str());

} else {
// If the values are non-tensors, we try to create constants
// and bake those constants into the traced graph
@ -225,11 +229,12 @@ Value* TracingState::getValue(const IValue& var) {
recordSourceLocation(constant.value()->node());
return *constant;
}
std::ostringstream os;
os << "Tracer cannot get value trace for type " << var.tagKind() << ". "
<< "The below value could not be materialized as a constant:\n"
<< var;
throw std::runtime_error(os.str());
TORCH_CHECK(
false,
"Tracer cannot get value trace for type ",
var.tagKind(),
". The below value could not be materialized as a constant:\n",
var);
}
}
bool TracingState::hasValue(const IValue& var) const {
@ -252,15 +257,14 @@ Value* TracingState::getOutput(const IValue& iv, size_t i) {

auto& value_map = getTracingState()->env_stack.back();
auto it = value_map.find(iv);
if (it == value_map.end()) {
std::ostringstream os;
os << "output " << i << " (" << var
<< ") of traced region did not have observable "
<< "data dependence with trace inputs; this probably indicates your "
"program "
<< "cannot be understood by the tracer.";
throw std::runtime_error(os.str());
}
TORCH_CHECK(
it != value_map.end(),
"output ",
i,
" (",
var,
") of traced region did not have observable data dependence with trace inputs; ",
"this probably indicates your program cannot be understood by the tracer.");
return it->second;
} else if (iv.isTensorList()) {
if (tracing_mode_strict) {
@ -281,11 +285,10 @@ Value* TracingState::getOutput(const IValue& iv, size_t i) {
graph->insertNode(tuple_node);
return tuple_node->output();
} else if (iv.isGenericDict()) {
if (tracing_mode_strict) {
throw std::runtime_error(
"Encountering a dict at the output of the tracer" +
std::string(STRICT_TRACER_MSG));
}
TORCH_CHECK(
!tracing_mode_strict,
"Encountering a dict at the output of the tracer",
STRICT_TRACER_MSG);
auto dict = iv.toGenericDict();
TypePtr key_type = dict.keyType();
TypePtr value_type = dict.valueType();
@ -304,15 +307,15 @@ Value* TracingState::getOutput(const IValue& iv, size_t i) {
}
}
}

if (!key_type_valid || !value_type_valid) {
std::ostringstream os;
os << "output " << i << " (" << dict << ") of traced region "
<< "cannot be understood by the tracer, only outputs matching"
<< "dict[Union[str, Tensor], Union[Tensor, Tuple[Tensor, ...]]] "
<< "can be a dictionary output of a traced function";
throw std::runtime_error(os.str());
}
TORCH_CHECK(
key_type_valid && value_type_valid,
"output ",
i,
" (",
dict,
") of traced region cannot be understood by the tracer, only outputs matching ",
"dict[Union[str, Tensor], Union[Tensor, Tuple[Tensor, ...]]] ",
"can be a dictionary output of a traced function");
std::vector<Value*> keys;
std::vector<Value*> values;
for (const auto& entry : dict) {
@ -598,10 +601,11 @@ void TracingState::setValue(const IValue& v, Value* value) {
setValue(entry.value(), static_value);
}
} else {
std::ostringstream os;
os << "Tracer cannot set value trace for type " << v.tagKind() << ". "
<< "Supported types are tensor, tensor list, and tuple of tensors.";
throw std::runtime_error(os.str());
TORCH_CHECK(
false,
"Tracer cannot set value trace for type ",
v.tagKind(),
". Supported types are tensor, tensor list, and tuple of tensors.");
}
}

@ -801,11 +805,10 @@ void addInputs(Node* n, const char* name, at::IntArrayRef value) {
recordSourceLocation(info[i]->node());
}
for (jit::Value* v : info) {
if (*v->type() != *jit::IntType::get()) {
throw std::runtime_error(
"Type mismatch in setposattr for IntArrayRef. Check that your program "
"is valid without tracing, and please file a bug report if it is.");
}
TORCH_CHECK(
*v->type() == *jit::IntType::get(),
"Type mismatch in setposattr for IntArrayRef. Check that your program "
"is valid without tracing, and please file a bug report if it is.");
}
n->addInput(
g->insertNode(g->createList(jit::IntType::get(), info))->output());

@ -5,6 +5,7 @@
#include <unordered_map>
#include <vector>

#include <c10/util/Exception.h>
#include <c10/util/SmallVector.h>
#include <c10/util/intrusive_ptr.h>
#include <torch/csrc/jit/frontend/lexer.h>
@ -37,10 +38,10 @@ struct Tree : c10::intrusive_ptr_target {
return true;
}
virtual const SourceRange& range() const {
throw std::runtime_error("is an Atom");
TORCH_CHECK(false, "is an Atom");
}
virtual const std::string& stringValue() const {
throw std::runtime_error("stringValue can only be called on TK_STRING");
TORCH_CHECK(false, "stringValue can only be called on TK_STRING");
}
virtual const TreeList& trees() const {
static const TreeList empty_trees = {};
@ -79,13 +80,16 @@ struct Tree : c10::intrusive_ptr_target {
int lineno,
size_t expected_subtrees,
bool allow_more) const {
if (kind() != k) {
std::stringstream ss;
ss << filename << ":" << lineno << ": expecting kind '" << kindToString(k)
<< "' but found '" << kindToString(kind()) << "'\n";
range().highlight(ss);
throw std::runtime_error(ss.str());
}
TORCH_CHECK(
kind() == k,
filename,
":",
lineno,
": expecting kind '",
kindToString(k),
"' but found '",
kindToString(kind()),
"'\n");
if (trees().size() < expected_subtrees ||
(!allow_more && trees().size() != expected_subtrees)) {
std::stringstream ss;
@ -93,7 +97,7 @@ struct Tree : c10::intrusive_ptr_target {
<< expected_subtrees << " subtrees, but found only " << trees().size()
<< "\n";
range().highlight(ss);
throw std::runtime_error(ss.str());
TORCH_CHECK(false, ss.str());
}
}
~Tree() override = default;

@ -367,11 +367,6 @@ def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree:
The reconstructed pytree, containing the ``leaves`` placed in the structure described by
``treespec``.
"""
if not _is_pytreespec_instance(treespec):
raise TypeError(
f"tree_unflatten(leaves, treespec): Expected `treespec` to be instance of "
f"PyTreeSpec but got item of type {type(treespec)}."
)
return optree.tree_unflatten(treespec, leaves) # type: ignore[arg-type]
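For reference, the flatten/unflatten round trip that the TypeError guard above protects, sketched with torch.utils._pytree (the pure-Python counterpart of the optree-backed function in this hunk):

import torch.utils._pytree as pytree

tree = {"w": [1, 2], "b": (3,)}
leaves, spec = pytree.tree_flatten(tree)
print(leaves)                                       # [1, 2, 3]
print(pytree.tree_unflatten(leaves, spec) == tree)  # True

try:
    pytree.tree_unflatten(leaves, "not a treespec")  # wrong type for `treespec`
except TypeError as e:
    print(type(e).__name__)  # TypeError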