mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[CMake] Add functorch target (#83464)
Move functorch/functorch into `functorch` folder - Add functorch/CMakeLists.txt that adds `functorch` native python exension - Modify `setup.py` to package pytorch and functorch together into a single wheel - Modify `functorch.__version__` is not equal to that of `torch.__version__` - Add dummy `functorch/setup.py` file for the projects that still want to build it Differential Revision: [D39058811](https://our.internmc.facebook.com/intern/diff/D39058811) Pull Request resolved: https://github.com/pytorch/pytorch/pull/83464 Approved by: https://github.com/zou3519
This commit is contained in:
committed by
PyTorch MergeBot
parent
26b5986297
commit
d05a11337c
@ -141,12 +141,6 @@ function checkout_install_torchdynamo() {
|
|||||||
popd
|
popd
|
||||||
}
|
}
|
||||||
|
|
||||||
function install_functorch() {
|
|
||||||
pushd functorch
|
|
||||||
time python setup.py develop
|
|
||||||
popd
|
|
||||||
}
|
|
||||||
|
|
||||||
function test_functorch() {
|
function test_functorch() {
|
||||||
python test/run_test.py --functorch --verbose
|
python test/run_test.py --functorch --verbose
|
||||||
}
|
}
|
||||||
|
@ -177,7 +177,6 @@ test_dynamo() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if [[ "${TEST_CONFIG}" == *functorch* ]]; then
|
if [[ "${TEST_CONFIG}" == *functorch* ]]; then
|
||||||
install_functorch
|
|
||||||
test_functorch
|
test_functorch
|
||||||
elif [[ $NUM_TEST_SHARDS -gt 1 ]]; then
|
elif [[ $NUM_TEST_SHARDS -gt 1 ]]; then
|
||||||
test_python_shard "${SHARD_NUMBER}"
|
test_python_shard "${SHARD_NUMBER}"
|
||||||
|
@ -180,9 +180,6 @@ test_dynamo_shard() {
|
|||||||
echo "NUM_TEST_SHARDS must be defined to run a Python test shard"
|
echo "NUM_TEST_SHARDS must be defined to run a Python test shard"
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
pushd functorch
|
|
||||||
python setup.py develop
|
|
||||||
popd
|
|
||||||
# Temporarily disable test_fx for dynamo pending the investigation on TTS
|
# Temporarily disable test_fx for dynamo pending the investigation on TTS
|
||||||
# regression in https://github.com/pytorch/torchdynamo/issues/784
|
# regression in https://github.com/pytorch/torchdynamo/issues/784
|
||||||
time python test/run_test.py \
|
time python test/run_test.py \
|
||||||
@ -686,7 +683,6 @@ elif [[ "${BUILD_ENVIRONMENT}" == *-mobile-lightweight-dispatch* ]]; then
|
|||||||
elif [[ "${TEST_CONFIG}" = docs_test ]]; then
|
elif [[ "${TEST_CONFIG}" = docs_test ]]; then
|
||||||
test_docs_test
|
test_docs_test
|
||||||
elif [[ "${TEST_CONFIG}" == *functorch* ]]; then
|
elif [[ "${TEST_CONFIG}" == *functorch* ]]; then
|
||||||
install_functorch
|
|
||||||
test_functorch
|
test_functorch
|
||||||
else
|
else
|
||||||
install_torchvision
|
install_torchvision
|
||||||
|
@ -144,7 +144,7 @@ python setup.py install --cmake && sccache --show-stats && (
|
|||||||
if "%BUILD_ENVIRONMENT%"=="" (
|
if "%BUILD_ENVIRONMENT%"=="" (
|
||||||
echo NOTE: To run `import torch`, please make sure to activate the conda environment by running `call %CONDA_PARENT_DIR%\Miniconda3\Scripts\activate.bat %CONDA_PARENT_DIR%\Miniconda3` in Command Prompt before running Git Bash.
|
echo NOTE: To run `import torch`, please make sure to activate the conda environment by running `call %CONDA_PARENT_DIR%\Miniconda3\Scripts\activate.bat %CONDA_PARENT_DIR%\Miniconda3` in Command Prompt before running Git Bash.
|
||||||
) else (
|
) else (
|
||||||
7z a %TMP_DIR_WIN%\%IMAGE_COMMIT_TAG%.7z %CONDA_PARENT_DIR%\Miniconda3\Lib\site-packages\torch %CONDA_PARENT_DIR%\Miniconda3\Lib\site-packages\torchgen %CONDA_PARENT_DIR%\Miniconda3\Lib\site-packages\caffe2 && copy /Y "%TMP_DIR_WIN%\%IMAGE_COMMIT_TAG%.7z" "%PYTORCH_FINAL_PACKAGE_DIR%\"
|
7z a %TMP_DIR_WIN%\%IMAGE_COMMIT_TAG%.7z %CONDA_PARENT_DIR%\Miniconda3\Lib\site-packages\torch %CONDA_PARENT_DIR%\Miniconda3\Lib\site-packages\torchgen %CONDA_PARENT_DIR%\Miniconda3\Lib\site-packages\caffe2 %CONDA_PARENT_DIR%\Miniconda3\Lib\site-packages\functorch && copy /Y "%TMP_DIR_WIN%\%IMAGE_COMMIT_TAG%.7z" "%PYTORCH_FINAL_PACKAGE_DIR%\"
|
||||||
if errorlevel 1 exit /b
|
if errorlevel 1 exit /b
|
||||||
if not errorlevel 0 exit /b
|
if not errorlevel 0 exit /b
|
||||||
|
|
||||||
|
@ -6,15 +6,6 @@ if not errorlevel 0 (
|
|||||||
exit /b
|
exit /b
|
||||||
)
|
)
|
||||||
|
|
||||||
pushd functorch
|
|
||||||
echo "Install functorch"
|
|
||||||
:: --no-deps because for some reason, on windows, `torch` isn't found in
|
|
||||||
:: `pip list` despite being installed. With just `python setup.py develop`,
|
|
||||||
:: setuptools explicitly checks for the existence of torch and can't find it.
|
|
||||||
python setup.py develop --no-deps
|
|
||||||
popd
|
|
||||||
if ERRORLEVEL 1 goto fail
|
|
||||||
|
|
||||||
echo "Installing test dependencies"
|
echo "Installing test dependencies"
|
||||||
pip install networkx
|
pip install networkx
|
||||||
if errorlevel 1 exit /b
|
if errorlevel 1 exit /b
|
||||||
|
@ -355,6 +355,8 @@ option(USE_PER_OPERATOR_HEADERS "Whether ATen should generate separate headers f
|
|||||||
cmake_dependent_option(
|
cmake_dependent_option(
|
||||||
BUILD_LAZY_TS_BACKEND "Build the lazy Torchscript backend, not compatible with mobile builds" ON
|
BUILD_LAZY_TS_BACKEND "Build the lazy Torchscript backend, not compatible with mobile builds" ON
|
||||||
"NOT INTERN_BUILD_MOBILE" OFF)
|
"NOT INTERN_BUILD_MOBILE" OFF)
|
||||||
|
cmake_dependent_option(
|
||||||
|
BUILD_FUNCTORCH "Build Functorch" ON "BUILD_PYTHON" OFF)
|
||||||
|
|
||||||
|
|
||||||
if(USE_CCACHE)
|
if(USE_CCACHE)
|
||||||
@ -622,6 +624,7 @@ if(INTERN_BUILD_MOBILE)
|
|||||||
set(INTERN_DISABLE_AUTOGRAD ON)
|
set(INTERN_DISABLE_AUTOGRAD ON)
|
||||||
endif()
|
endif()
|
||||||
set(BUILD_PYTHON OFF)
|
set(BUILD_PYTHON OFF)
|
||||||
|
set(BUILD_FUNCTORCH OFF)
|
||||||
set(BUILD_CAFFE2_OPS OFF)
|
set(BUILD_CAFFE2_OPS OFF)
|
||||||
set(USE_DISTRIBUTED OFF)
|
set(USE_DISTRIBUTED OFF)
|
||||||
set(NO_API ON)
|
set(NO_API ON)
|
||||||
@ -1175,3 +1178,7 @@ caffe2_print_configuration_summary()
|
|||||||
if(USE_DEPLOY)
|
if(USE_DEPLOY)
|
||||||
add_subdirectory(torch/csrc/deploy)
|
add_subdirectory(torch/csrc/deploy)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
if(BUILD_FUNCTORCH)
|
||||||
|
add_subdirectory(functorch)
|
||||||
|
endif()
|
||||||
|
@ -26,5 +26,6 @@ recursive-include benchmarks *.*
|
|||||||
recursive-include scripts *.*
|
recursive-include scripts *.*
|
||||||
recursive-include mypy_plugins *.*
|
recursive-include mypy_plugins *.*
|
||||||
recursive-include modules *.*
|
recursive-include modules *.*
|
||||||
|
recursive-include functorch *.*
|
||||||
prune */__pycache__
|
prune */__pycache__
|
||||||
global-exclude *.o *.so *.dylib *.a .git *.pyc *.swp
|
global-exclude *.o *.so *.dylib *.a .git *.pyc *.swp
|
||||||
|
@ -1245,6 +1245,8 @@ install(FILES
|
|||||||
|
|
||||||
# ---[ Torch python bindings build
|
# ---[ Torch python bindings build
|
||||||
add_subdirectory(../torch torch)
|
add_subdirectory(../torch torch)
|
||||||
|
set(TORCH_PYTHON_COMPILE_OPTIONS ${TORCH_PYTHON_COMPILE_OPTIONS} PARENT_SCOPE)
|
||||||
|
set(TORCH_PYTHON_LINK_FLAGS ${TORCH_PYTHON_LINK_FLAGS} PARENT_SCOPE)
|
||||||
|
|
||||||
# ==========================================================
|
# ==========================================================
|
||||||
# END formerly-libtorch flags
|
# END formerly-libtorch flags
|
||||||
|
37
functorch/CMakeLists.txt
Normal file
37
functorch/CMakeLists.txt
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
cmake_minimum_required(VERSION 3.12)
|
||||||
|
project(functorch)
|
||||||
|
set(CMAKE_CXX_STANDARD 14)
|
||||||
|
|
||||||
|
include(GNUInstallDirs)
|
||||||
|
include(CMakePackageConfigHelpers)
|
||||||
|
|
||||||
|
set(FT_DIR csrc)
|
||||||
|
file(GLOB_RECURSE FT_SOURCES ${FT_DIR}/*.cpp)
|
||||||
|
|
||||||
|
add_library(${PROJECT_NAME} MODULE ${FT_SOURCES})
|
||||||
|
target_include_directories(${PROJECT_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||||
|
target_compile_definitions(${PROJECT_NAME} PRIVATE FUNCTORCH_BUILD_MAIN_LIB)
|
||||||
|
target_compile_definitions(${PROJECT_NAME} PRIVATE TORCH_EXTENSION_NAME=_C)
|
||||||
|
target_compile_definitions(${PROJECT_NAME} PRIVATE TORCH_API_INCLUDE_EXTENSION_H)
|
||||||
|
target_compile_options(${PROJECT_NAME} PRIVATE ${TORCH_PYTHON_COMPILE_OPTIONS})
|
||||||
|
target_link_libraries(${PROJECT_NAME} PRIVATE torch torch_python)
|
||||||
|
target_link_libraries(${PROJECT_NAME} PRIVATE pybind::pybind11)
|
||||||
|
|
||||||
|
set_target_properties(${PROJECT_NAME} PROPERTIES LIBRARY_OUTPUT_DIRECTORY
|
||||||
|
${CMAKE_BINARY_DIR}/functorch)
|
||||||
|
|
||||||
|
# Copy-pasted prefix/suffix logic for Python extensions from
|
||||||
|
# https://github.com/pytorch/pytorch/blob/33bb8ae350611760139457b85842b1d7edf9aa11/caffe2/CMakeLists.txt#L1975
|
||||||
|
# https://github.com/pytorch/pytorch/blob/33bb8ae350611760139457b85842b1d7edf9aa11/caffe2/CMakeLists.txt#L2022
|
||||||
|
# TODO: It would be good to be able to use Python3_add_library target, but it does not work in many cases
|
||||||
|
set_target_properties(${PROJECT_NAME} PROPERTIES PREFIX "" DEBUG_POSTFIX "")
|
||||||
|
if(WIN32)
|
||||||
|
set_target_properties(${PROJECT_NAME} PROPERTIES SUFFIX ".pyd")
|
||||||
|
else()
|
||||||
|
set_target_properties(${PROJECT_NAME} PROPERTIES SUFFIX ".so")
|
||||||
|
endif()
|
||||||
|
# Needed to link functorch on MacOS
|
||||||
|
if(NOT ${TORCH_PYTHON_LINK_FLAGS} STREQUAL "")
|
||||||
|
set_target_properties(${PROJECT_NAME} PROPERTIES LINK_FLAGS ${TORCH_PYTHON_LINK_FLAGS})
|
||||||
|
endif()
|
||||||
|
install(TARGETS ${PROJECT_NAME} DESTINATION "${CMAKE_CURRENT_SOURCE_DIR}")
|
@ -32,7 +32,4 @@ from ._src.make_functional import (
|
|||||||
FunctionalModuleWithBuffers,
|
FunctionalModuleWithBuffers,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
__version__ = torch.__version__
|
||||||
from .version import __version__ # noqa: F401
|
|
||||||
except ImportError:
|
|
||||||
pass
|
|
@ -1,18 +0,0 @@
|
|||||||
[bdist_wheel]
|
|
||||||
universal=1
|
|
||||||
|
|
||||||
[metadata]
|
|
||||||
license_file = LICENSE
|
|
||||||
|
|
||||||
[pep8]
|
|
||||||
max-line-length = 120
|
|
||||||
|
|
||||||
[flake8]
|
|
||||||
max-line-length = 120
|
|
||||||
exclude = docs, benchmarks, notebooks, tools
|
|
||||||
per-file-ignores =
|
|
||||||
__init__.py: F401
|
|
||||||
functorch/_src/decompositions.py: E501
|
|
||||||
|
|
||||||
[pydocstyle]
|
|
||||||
select = D417 # Missing argument descriptions in the docstring
|
|
@ -3,18 +3,11 @@
|
|||||||
#
|
#
|
||||||
# This source code is licensed under the BSD-style license found in the
|
# This source code is licensed under the BSD-style license found in the
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
# This is a dummy setup.py that does not do anything
|
||||||
|
|
||||||
import distutils.command.clean
|
|
||||||
import sys
|
|
||||||
import shutil
|
|
||||||
import glob
|
|
||||||
import os
|
import os
|
||||||
import subprocess
|
import subprocess
|
||||||
from setuptools import setup, find_packages
|
from setuptools import setup
|
||||||
from torch.utils.cpp_extension import (
|
|
||||||
CppExtension,
|
|
||||||
BuildExtension,
|
|
||||||
)
|
|
||||||
|
|
||||||
cwd = os.path.dirname(os.path.abspath(__file__))
|
cwd = os.path.dirname(os.path.abspath(__file__))
|
||||||
version_txt = os.path.join(cwd, 'version.txt')
|
version_txt = os.path.join(cwd, 'version.txt')
|
||||||
@ -33,16 +26,6 @@ elif sha != 'Unknown':
|
|||||||
version += '+' + sha[:7]
|
version += '+' + sha[:7]
|
||||||
|
|
||||||
|
|
||||||
def write_version_file():
|
|
||||||
version_path = os.path.join(cwd, 'functorch', 'version.py')
|
|
||||||
with open(version_path, 'w') as f:
|
|
||||||
f.write("__version__ = '{}'\n".format(version))
|
|
||||||
f.write("git_version = {}\n".format(repr(sha)))
|
|
||||||
|
|
||||||
|
|
||||||
# pytorch_dep = 'torch'
|
|
||||||
# if os.getenv('PYTORCH_VERSION'):
|
|
||||||
# pytorch_dep += "==" + os.getenv('PYTORCH_VERSION')
|
|
||||||
requirements = [
|
requirements = [
|
||||||
# This represents a nightly version of PyTorch.
|
# This represents a nightly version of PyTorch.
|
||||||
# It can be installed as a binary or from source.
|
# It can be installed as a binary or from source.
|
||||||
@ -53,83 +36,7 @@ extras = {}
|
|||||||
extras["aot"] = ["networkx", ]
|
extras["aot"] = ["networkx", ]
|
||||||
|
|
||||||
|
|
||||||
class clean(distutils.command.clean.clean):
|
|
||||||
def run(self):
|
|
||||||
|
|
||||||
with open(".gitignore", "r") as f:
|
|
||||||
ignores = f.read()
|
|
||||||
for wildcard in filter(None, ignores.split("\n")):
|
|
||||||
for filename in glob.glob(wildcard):
|
|
||||||
try:
|
|
||||||
os.remove(filename)
|
|
||||||
except OSError:
|
|
||||||
shutil.rmtree(filename, ignore_errors=True)
|
|
||||||
|
|
||||||
# It's an old-style class in Python 2.7...
|
|
||||||
distutils.command.clean.clean.run(self)
|
|
||||||
|
|
||||||
|
|
||||||
def get_extensions():
|
|
||||||
extension = CppExtension
|
|
||||||
|
|
||||||
# See functorch/csrc/Macros.h
|
|
||||||
define_macros = [('FUNCTORCH_BUILD_MAIN_LIB', None)]
|
|
||||||
|
|
||||||
extra_link_args = []
|
|
||||||
extra_compile_args = {"cxx": [
|
|
||||||
"-O3",
|
|
||||||
"-std=c++14",
|
|
||||||
"-fdiagnostics-color=always",
|
|
||||||
]}
|
|
||||||
debug_mode = os.getenv('DEBUG', '0') == '1'
|
|
||||||
if debug_mode:
|
|
||||||
print("Compiling in debug mode")
|
|
||||||
extra_compile_args = {
|
|
||||||
"cxx": [
|
|
||||||
"-O0",
|
|
||||||
"-fno-inline",
|
|
||||||
"-g",
|
|
||||||
"-std=c++14",
|
|
||||||
"-fdiagnostics-color=always",
|
|
||||||
]}
|
|
||||||
extra_link_args = ["-O0", "-g"]
|
|
||||||
|
|
||||||
this_dir = os.path.dirname(os.path.abspath(__file__))
|
|
||||||
extensions_dir = os.path.join(this_dir, "functorch", "csrc")
|
|
||||||
|
|
||||||
extension_sources = set(
|
|
||||||
os.path.join(extensions_dir, p)
|
|
||||||
for p in glob.glob(os.path.join(extensions_dir, "*.cpp"))
|
|
||||||
)
|
|
||||||
sources = list(extension_sources)
|
|
||||||
sources.append(os.path.join(extensions_dir, "dim", "dim.cpp"))
|
|
||||||
|
|
||||||
ext_modules = [
|
|
||||||
extension(
|
|
||||||
"functorch._C",
|
|
||||||
sources,
|
|
||||||
include_dirs=[this_dir],
|
|
||||||
define_macros=define_macros,
|
|
||||||
extra_compile_args=extra_compile_args,
|
|
||||||
extra_link_args=extra_link_args,
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
return ext_modules
|
|
||||||
|
|
||||||
|
|
||||||
class BuildExtension_(BuildExtension):
|
|
||||||
def build_extensions(self, *args, **kwargs):
|
|
||||||
# It turns out for windows this isn't populated?
|
|
||||||
if hasattr(self.compiler, 'compiler_so'):
|
|
||||||
if '-Wstrict-prototypes' in self.compiler.compiler_so:
|
|
||||||
self.compiler.compiler_so.remove('-Wstrict-prototypes')
|
|
||||||
super().build_extensions(*args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
print("Building wheel {}-{}".format(package_name, version))
|
|
||||||
write_version_file()
|
|
||||||
try:
|
try:
|
||||||
setup(
|
setup(
|
||||||
# Metadata
|
# Metadata
|
||||||
@ -141,14 +48,10 @@ if __name__ == '__main__':
|
|||||||
license='BSD',
|
license='BSD',
|
||||||
|
|
||||||
# Package info
|
# Package info
|
||||||
packages=find_packages(),
|
packages=[],
|
||||||
install_requires=requirements,
|
install_requires=requirements,
|
||||||
extras_require=extras,
|
extras_require=extras,
|
||||||
ext_modules=get_extensions(),
|
)
|
||||||
cmdclass={
|
|
||||||
"build_ext": BuildExtension_.with_options(no_python_abi_suffix=True),
|
|
||||||
'clean': clean,
|
|
||||||
})
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e, file=sys.stderr)
|
print(e, file=sys.stderr)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
23
setup.py
23
setup.py
@ -614,6 +614,23 @@ class build_ext(setuptools.command.build_ext.build_ext):
|
|||||||
os.makedirs(dst_dir)
|
os.makedirs(dst_dir)
|
||||||
self.copy_file(src, dst)
|
self.copy_file(src, dst)
|
||||||
i += 1
|
i += 1
|
||||||
|
|
||||||
|
# Copy functorch extension
|
||||||
|
for i, ext in enumerate(self.extensions):
|
||||||
|
if ext.name != "functorch._C":
|
||||||
|
continue
|
||||||
|
fullname = self.get_ext_fullname(ext.name)
|
||||||
|
filename = self.get_ext_filename(fullname)
|
||||||
|
fileext = os.path.splitext(filename)[1]
|
||||||
|
src = os.path.join(os.path.dirname(filename), "functorch" + fileext)
|
||||||
|
dst = os.path.join(os.path.realpath(self.build_lib), filename)
|
||||||
|
if os.path.exists(src):
|
||||||
|
report("Copying {} from {} to {}".format(ext.name, src, dst))
|
||||||
|
dst_dir = os.path.dirname(dst)
|
||||||
|
if not os.path.exists(dst_dir):
|
||||||
|
os.makedirs(dst_dir)
|
||||||
|
self.copy_file(src, dst)
|
||||||
|
|
||||||
setuptools.command.build_ext.build_ext.build_extensions(self)
|
setuptools.command.build_ext.build_ext.build_extensions(self)
|
||||||
|
|
||||||
|
|
||||||
@ -893,6 +910,12 @@ def configure_extension_build():
|
|||||||
name=str('caffe2.python.caffe2_pybind11_state_hip'),
|
name=str('caffe2.python.caffe2_pybind11_state_hip'),
|
||||||
sources=[]),
|
sources=[]),
|
||||||
)
|
)
|
||||||
|
if cmake_cache_vars['BUILD_FUNCTORCH']:
|
||||||
|
extensions.append(
|
||||||
|
Extension(
|
||||||
|
name=str('functorch._C'),
|
||||||
|
sources=[]),
|
||||||
|
)
|
||||||
|
|
||||||
cmdclass = {
|
cmdclass = {
|
||||||
'bdist_wheel': wheel_concatenate,
|
'bdist_wheel': wheel_concatenate,
|
||||||
|
@ -143,19 +143,19 @@ def forward(self, a__1):
|
|||||||
|
|
||||||
|
|
||||||
def forward(self, a__1):
|
def forward(self, a__1):
|
||||||
clone_default = torch.ops.aten.clone.default(a__1); a__1 = None
|
clone = torch.ops.aten.clone.default(a__1); a__1 = None
|
||||||
view_default = torch.ops.aten.view.default(clone_default, [-1])
|
view = torch.ops.aten.view.default(clone, [-1])
|
||||||
view_default_1 = torch.ops.aten.view.default(clone_default, [-1])
|
view_1 = torch.ops.aten.view.default(clone, [-1])
|
||||||
select_int = torch.ops.aten.select.int(view_default_1, 0, 0); view_default_1 = None
|
select = torch.ops.aten.select.int(view_1, 0, 0); view_1 = None
|
||||||
view_default_2 = torch.ops.aten.view.default(select_int, [-1]); select_int = None
|
view_2 = torch.ops.aten.view.default(select, [-1]); select = None
|
||||||
add_tensor = torch.ops.aten.add_.Tensor(view_default_2, 1)
|
add = torch.ops.aten.add_.Tensor(view_2, 1)
|
||||||
view_default_3 = torch.ops.aten.view.default(clone_default, [-1]); clone_default = None
|
view_3 = torch.ops.aten.view.default(clone, [-1]); clone = None
|
||||||
select_int_1 = torch.ops.aten.select.int(view_default_3, 0, 0)
|
select_1 = torch.ops.aten.select.int(view_3, 0, 0)
|
||||||
view_default_4 = torch.ops.aten.view.default(view_default_2, []); view_default_2 = None
|
view_4 = torch.ops.aten.view.default(view_2, []); view_2 = None
|
||||||
view_default_5 = torch.ops.aten.view.default(view_default_3, [4]); view_default_3 = None
|
view_5 = torch.ops.aten.view.default(view_3, [4]); view_3 = None
|
||||||
view_default_6 = torch.ops.aten.view.default(view_default_5, [-1])
|
view_6 = torch.ops.aten.view.default(view_5, [-1])
|
||||||
add_tensor_1 = torch.ops.aten.add_.Tensor(view_default_5, view_default_6); view_default_6 = None
|
add_1 = torch.ops.aten.add_.Tensor(view_5, view_6); view_6 = None
|
||||||
return view_default_5
|
return view_5
|
||||||
""")
|
""")
|
||||||
|
|
||||||
def test_reinplace_scatter_twice(self):
|
def test_reinplace_scatter_twice(self):
|
||||||
@ -180,14 +180,14 @@ def forward(self, a__1):
|
|||||||
|
|
||||||
|
|
||||||
def forward(self, a__1):
|
def forward(self, a__1):
|
||||||
clone_default = torch.ops.aten.clone.default(a__1); a__1 = None
|
clone = torch.ops.aten.clone.default(a__1); a__1 = None
|
||||||
slice_tensor = torch.ops.aten.slice.Tensor(clone_default, 0, 0, 9223372036854775807)
|
slice_1 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
|
||||||
select_int = torch.ops.aten.select.int(slice_tensor, 1, 1); slice_tensor = None
|
select = torch.ops.aten.select.int(slice_1, 1, 1); slice_1 = None
|
||||||
select_int_1 = torch.ops.aten.select.int(select_int, 0, 1); select_int = None
|
select_1 = torch.ops.aten.select.int(select, 0, 1); select = None
|
||||||
add_tensor = torch.ops.aten.add_.Tensor(select_int_1, 1); select_int_1 = None
|
add = torch.ops.aten.add_.Tensor(select_1, 1); select_1 = None
|
||||||
slice_tensor_1 = torch.ops.aten.slice.Tensor(clone_default, 0, 0, 9223372036854775807)
|
slice_2 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
|
||||||
select_int_2 = torch.ops.aten.select.int(slice_tensor_1, 1, 1); slice_tensor_1 = None
|
select_2 = torch.ops.aten.select.int(slice_2, 1, 1); slice_2 = None
|
||||||
return clone_default
|
return clone
|
||||||
""")
|
""")
|
||||||
|
|
||||||
def test_reinplace_scatter_twice_with_different_view_op_valid(self):
|
def test_reinplace_scatter_twice_with_different_view_op_valid(self):
|
||||||
@ -319,8 +319,8 @@ def forward(self, a__1):
|
|||||||
|
|
||||||
def forward(self):
|
def forward(self):
|
||||||
zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
|
zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
|
||||||
diagonal_default = torch.ops.aten.diagonal.default(zeros)
|
diagonal = torch.ops.aten.diagonal.default(zeros)
|
||||||
add_tensor = torch.ops.aten.add_.Tensor(diagonal_default, 1); diagonal_default = None
|
add = torch.ops.aten.add_.Tensor(diagonal, 1); diagonal = None
|
||||||
return [zeros]
|
return [zeros]
|
||||||
""")
|
""")
|
||||||
|
|
||||||
@ -343,11 +343,11 @@ def forward(self):
|
|||||||
def forward(self):
|
def forward(self):
|
||||||
zeros = torch.ops.aten.zeros.default([4, 4, 4], device = device(type='cpu'), pin_memory = False)
|
zeros = torch.ops.aten.zeros.default([4, 4, 4], device = device(type='cpu'), pin_memory = False)
|
||||||
ones = torch.ops.aten.ones.default([4, 2, 4], device = device(type='cpu'), pin_memory = False)
|
ones = torch.ops.aten.ones.default([4, 2, 4], device = device(type='cpu'), pin_memory = False)
|
||||||
slice_tensor = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807)
|
slice_1 = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807)
|
||||||
slice_tensor_1 = torch.ops.aten.slice.Tensor(slice_tensor, 1, 2, 9223372036854775807); slice_tensor = None
|
slice_2 = torch.ops.aten.slice.Tensor(slice_1, 1, 2, 9223372036854775807); slice_1 = None
|
||||||
slice_tensor_2 = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807)
|
slice_3 = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807)
|
||||||
slice_tensor_3 = torch.ops.aten.slice.Tensor(slice_tensor_2, 1, 2, 9223372036854775807); slice_tensor_2 = None
|
slice_tensor = torch.ops.aten.slice.Tensor(slice_3, 1, 2, 9223372036854775807); slice_3 = None
|
||||||
copy__default = torch.ops.aten.copy_.default(slice_tensor_3, ones); slice_tensor_3 = ones = None
|
copy__default = torch.ops.aten.copy_.default(slice_tensor, ones); slice_tensor = ones = None
|
||||||
return zeros
|
return zeros
|
||||||
""")
|
""")
|
||||||
|
|
||||||
|
@ -496,3 +496,6 @@ if(NOT ${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
|
|||||||
# Pybind11 requires explicit linking of the torch_python library
|
# Pybind11 requires explicit linking of the torch_python library
|
||||||
target_link_libraries(nnapi_backend PRIVATE torch torch_python pybind::pybind11)
|
target_link_libraries(nnapi_backend PRIVATE torch torch_python pybind::pybind11)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
set(TORCH_PYTHON_COMPILE_OPTIONS ${TORCH_PYTHON_COMPILE_OPTIONS} PARENT_SCOPE)
|
||||||
|
set(TORCH_PYTHON_LINK_FLAGS ${TORCH_PYTHON_LINK_FLAGS} PARENT_SCOPE)
|
||||||
|
Reference in New Issue
Block a user