mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[CMake] Add functorch target (#83464)
Move functorch/functorch into `functorch` folder - Add functorch/CMakeLists.txt that adds `functorch` native python exension - Modify `setup.py` to package pytorch and functorch together into a single wheel - Modify `functorch.__version__` is not equal to that of `torch.__version__` - Add dummy `functorch/setup.py` file for the projects that still want to build it Differential Revision: [D39058811](https://our.internmc.facebook.com/intern/diff/D39058811) Pull Request resolved: https://github.com/pytorch/pytorch/pull/83464 Approved by: https://github.com/zou3519
This commit is contained in:
committed by
PyTorch MergeBot
parent
26b5986297
commit
d05a11337c
@ -141,12 +141,6 @@ function checkout_install_torchdynamo() {
|
||||
popd
|
||||
}
|
||||
|
||||
function install_functorch() {
|
||||
pushd functorch
|
||||
time python setup.py develop
|
||||
popd
|
||||
}
|
||||
|
||||
function test_functorch() {
|
||||
python test/run_test.py --functorch --verbose
|
||||
}
|
||||
|
@ -177,7 +177,6 @@ test_dynamo() {
|
||||
}
|
||||
|
||||
if [[ "${TEST_CONFIG}" == *functorch* ]]; then
|
||||
install_functorch
|
||||
test_functorch
|
||||
elif [[ $NUM_TEST_SHARDS -gt 1 ]]; then
|
||||
test_python_shard "${SHARD_NUMBER}"
|
||||
|
@ -180,9 +180,6 @@ test_dynamo_shard() {
|
||||
echo "NUM_TEST_SHARDS must be defined to run a Python test shard"
|
||||
exit 1
|
||||
fi
|
||||
pushd functorch
|
||||
python setup.py develop
|
||||
popd
|
||||
# Temporarily disable test_fx for dynamo pending the investigation on TTS
|
||||
# regression in https://github.com/pytorch/torchdynamo/issues/784
|
||||
time python test/run_test.py \
|
||||
@ -686,7 +683,6 @@ elif [[ "${BUILD_ENVIRONMENT}" == *-mobile-lightweight-dispatch* ]]; then
|
||||
elif [[ "${TEST_CONFIG}" = docs_test ]]; then
|
||||
test_docs_test
|
||||
elif [[ "${TEST_CONFIG}" == *functorch* ]]; then
|
||||
install_functorch
|
||||
test_functorch
|
||||
else
|
||||
install_torchvision
|
||||
|
@ -144,7 +144,7 @@ python setup.py install --cmake && sccache --show-stats && (
|
||||
if "%BUILD_ENVIRONMENT%"=="" (
|
||||
echo NOTE: To run `import torch`, please make sure to activate the conda environment by running `call %CONDA_PARENT_DIR%\Miniconda3\Scripts\activate.bat %CONDA_PARENT_DIR%\Miniconda3` in Command Prompt before running Git Bash.
|
||||
) else (
|
||||
7z a %TMP_DIR_WIN%\%IMAGE_COMMIT_TAG%.7z %CONDA_PARENT_DIR%\Miniconda3\Lib\site-packages\torch %CONDA_PARENT_DIR%\Miniconda3\Lib\site-packages\torchgen %CONDA_PARENT_DIR%\Miniconda3\Lib\site-packages\caffe2 && copy /Y "%TMP_DIR_WIN%\%IMAGE_COMMIT_TAG%.7z" "%PYTORCH_FINAL_PACKAGE_DIR%\"
|
||||
7z a %TMP_DIR_WIN%\%IMAGE_COMMIT_TAG%.7z %CONDA_PARENT_DIR%\Miniconda3\Lib\site-packages\torch %CONDA_PARENT_DIR%\Miniconda3\Lib\site-packages\torchgen %CONDA_PARENT_DIR%\Miniconda3\Lib\site-packages\caffe2 %CONDA_PARENT_DIR%\Miniconda3\Lib\site-packages\functorch && copy /Y "%TMP_DIR_WIN%\%IMAGE_COMMIT_TAG%.7z" "%PYTORCH_FINAL_PACKAGE_DIR%\"
|
||||
if errorlevel 1 exit /b
|
||||
if not errorlevel 0 exit /b
|
||||
|
||||
|
@ -6,15 +6,6 @@ if not errorlevel 0 (
|
||||
exit /b
|
||||
)
|
||||
|
||||
pushd functorch
|
||||
echo "Install functorch"
|
||||
:: --no-deps because for some reason, on windows, `torch` isn't found in
|
||||
:: `pip list` despite being installed. With just `python setup.py develop`,
|
||||
:: setuptools explicitly checks for the existence of torch and can't find it.
|
||||
python setup.py develop --no-deps
|
||||
popd
|
||||
if ERRORLEVEL 1 goto fail
|
||||
|
||||
echo "Installing test dependencies"
|
||||
pip install networkx
|
||||
if errorlevel 1 exit /b
|
||||
|
@ -355,6 +355,8 @@ option(USE_PER_OPERATOR_HEADERS "Whether ATen should generate separate headers f
|
||||
cmake_dependent_option(
|
||||
BUILD_LAZY_TS_BACKEND "Build the lazy Torchscript backend, not compatible with mobile builds" ON
|
||||
"NOT INTERN_BUILD_MOBILE" OFF)
|
||||
cmake_dependent_option(
|
||||
BUILD_FUNCTORCH "Build Functorch" ON "BUILD_PYTHON" OFF)
|
||||
|
||||
|
||||
if(USE_CCACHE)
|
||||
@ -622,6 +624,7 @@ if(INTERN_BUILD_MOBILE)
|
||||
set(INTERN_DISABLE_AUTOGRAD ON)
|
||||
endif()
|
||||
set(BUILD_PYTHON OFF)
|
||||
set(BUILD_FUNCTORCH OFF)
|
||||
set(BUILD_CAFFE2_OPS OFF)
|
||||
set(USE_DISTRIBUTED OFF)
|
||||
set(NO_API ON)
|
||||
@ -1175,3 +1178,7 @@ caffe2_print_configuration_summary()
|
||||
if(USE_DEPLOY)
|
||||
add_subdirectory(torch/csrc/deploy)
|
||||
endif()
|
||||
|
||||
if(BUILD_FUNCTORCH)
|
||||
add_subdirectory(functorch)
|
||||
endif()
|
||||
|
@ -26,5 +26,6 @@ recursive-include benchmarks *.*
|
||||
recursive-include scripts *.*
|
||||
recursive-include mypy_plugins *.*
|
||||
recursive-include modules *.*
|
||||
recursive-include functorch *.*
|
||||
prune */__pycache__
|
||||
global-exclude *.o *.so *.dylib *.a .git *.pyc *.swp
|
||||
|
@ -1245,6 +1245,8 @@ install(FILES
|
||||
|
||||
# ---[ Torch python bindings build
|
||||
add_subdirectory(../torch torch)
|
||||
set(TORCH_PYTHON_COMPILE_OPTIONS ${TORCH_PYTHON_COMPILE_OPTIONS} PARENT_SCOPE)
|
||||
set(TORCH_PYTHON_LINK_FLAGS ${TORCH_PYTHON_LINK_FLAGS} PARENT_SCOPE)
|
||||
|
||||
# ==========================================================
|
||||
# END formerly-libtorch flags
|
||||
|
37
functorch/CMakeLists.txt
Normal file
37
functorch/CMakeLists.txt
Normal file
@ -0,0 +1,37 @@
|
||||
cmake_minimum_required(VERSION 3.12)
|
||||
project(functorch)
|
||||
set(CMAKE_CXX_STANDARD 14)
|
||||
|
||||
include(GNUInstallDirs)
|
||||
include(CMakePackageConfigHelpers)
|
||||
|
||||
set(FT_DIR csrc)
|
||||
file(GLOB_RECURSE FT_SOURCES ${FT_DIR}/*.cpp)
|
||||
|
||||
add_library(${PROJECT_NAME} MODULE ${FT_SOURCES})
|
||||
target_include_directories(${PROJECT_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
target_compile_definitions(${PROJECT_NAME} PRIVATE FUNCTORCH_BUILD_MAIN_LIB)
|
||||
target_compile_definitions(${PROJECT_NAME} PRIVATE TORCH_EXTENSION_NAME=_C)
|
||||
target_compile_definitions(${PROJECT_NAME} PRIVATE TORCH_API_INCLUDE_EXTENSION_H)
|
||||
target_compile_options(${PROJECT_NAME} PRIVATE ${TORCH_PYTHON_COMPILE_OPTIONS})
|
||||
target_link_libraries(${PROJECT_NAME} PRIVATE torch torch_python)
|
||||
target_link_libraries(${PROJECT_NAME} PRIVATE pybind::pybind11)
|
||||
|
||||
set_target_properties(${PROJECT_NAME} PROPERTIES LIBRARY_OUTPUT_DIRECTORY
|
||||
${CMAKE_BINARY_DIR}/functorch)
|
||||
|
||||
# Copy-pasted prefix/suffix logic for Python extensions from
|
||||
# https://github.com/pytorch/pytorch/blob/33bb8ae350611760139457b85842b1d7edf9aa11/caffe2/CMakeLists.txt#L1975
|
||||
# https://github.com/pytorch/pytorch/blob/33bb8ae350611760139457b85842b1d7edf9aa11/caffe2/CMakeLists.txt#L2022
|
||||
# TODO: It would be good to be able to use Python3_add_library target, but it does not work in many cases
|
||||
set_target_properties(${PROJECT_NAME} PROPERTIES PREFIX "" DEBUG_POSTFIX "")
|
||||
if(WIN32)
|
||||
set_target_properties(${PROJECT_NAME} PROPERTIES SUFFIX ".pyd")
|
||||
else()
|
||||
set_target_properties(${PROJECT_NAME} PROPERTIES SUFFIX ".so")
|
||||
endif()
|
||||
# Needed to link functorch on MacOS
|
||||
if(NOT ${TORCH_PYTHON_LINK_FLAGS} STREQUAL "")
|
||||
set_target_properties(${PROJECT_NAME} PROPERTIES LINK_FLAGS ${TORCH_PYTHON_LINK_FLAGS})
|
||||
endif()
|
||||
install(TARGETS ${PROJECT_NAME} DESTINATION "${CMAKE_CURRENT_SOURCE_DIR}")
|
@ -32,7 +32,4 @@ from ._src.make_functional import (
|
||||
FunctionalModuleWithBuffers,
|
||||
)
|
||||
|
||||
try:
|
||||
from .version import __version__ # noqa: F401
|
||||
except ImportError:
|
||||
pass
|
||||
__version__ = torch.__version__
|
@ -1,18 +0,0 @@
|
||||
[bdist_wheel]
|
||||
universal=1
|
||||
|
||||
[metadata]
|
||||
license_file = LICENSE
|
||||
|
||||
[pep8]
|
||||
max-line-length = 120
|
||||
|
||||
[flake8]
|
||||
max-line-length = 120
|
||||
exclude = docs, benchmarks, notebooks, tools
|
||||
per-file-ignores =
|
||||
__init__.py: F401
|
||||
functorch/_src/decompositions.py: E501
|
||||
|
||||
[pydocstyle]
|
||||
select = D417 # Missing argument descriptions in the docstring
|
@ -3,18 +3,11 @@
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
# This is a dummy setup.py that does not do anything
|
||||
|
||||
import distutils.command.clean
|
||||
import sys
|
||||
import shutil
|
||||
import glob
|
||||
import os
|
||||
import subprocess
|
||||
from setuptools import setup, find_packages
|
||||
from torch.utils.cpp_extension import (
|
||||
CppExtension,
|
||||
BuildExtension,
|
||||
)
|
||||
from setuptools import setup
|
||||
|
||||
cwd = os.path.dirname(os.path.abspath(__file__))
|
||||
version_txt = os.path.join(cwd, 'version.txt')
|
||||
@ -33,16 +26,6 @@ elif sha != 'Unknown':
|
||||
version += '+' + sha[:7]
|
||||
|
||||
|
||||
def write_version_file():
|
||||
version_path = os.path.join(cwd, 'functorch', 'version.py')
|
||||
with open(version_path, 'w') as f:
|
||||
f.write("__version__ = '{}'\n".format(version))
|
||||
f.write("git_version = {}\n".format(repr(sha)))
|
||||
|
||||
|
||||
# pytorch_dep = 'torch'
|
||||
# if os.getenv('PYTORCH_VERSION'):
|
||||
# pytorch_dep += "==" + os.getenv('PYTORCH_VERSION')
|
||||
requirements = [
|
||||
# This represents a nightly version of PyTorch.
|
||||
# It can be installed as a binary or from source.
|
||||
@ -53,83 +36,7 @@ extras = {}
|
||||
extras["aot"] = ["networkx", ]
|
||||
|
||||
|
||||
class clean(distutils.command.clean.clean):
|
||||
def run(self):
|
||||
|
||||
with open(".gitignore", "r") as f:
|
||||
ignores = f.read()
|
||||
for wildcard in filter(None, ignores.split("\n")):
|
||||
for filename in glob.glob(wildcard):
|
||||
try:
|
||||
os.remove(filename)
|
||||
except OSError:
|
||||
shutil.rmtree(filename, ignore_errors=True)
|
||||
|
||||
# It's an old-style class in Python 2.7...
|
||||
distutils.command.clean.clean.run(self)
|
||||
|
||||
|
||||
def get_extensions():
|
||||
extension = CppExtension
|
||||
|
||||
# See functorch/csrc/Macros.h
|
||||
define_macros = [('FUNCTORCH_BUILD_MAIN_LIB', None)]
|
||||
|
||||
extra_link_args = []
|
||||
extra_compile_args = {"cxx": [
|
||||
"-O3",
|
||||
"-std=c++14",
|
||||
"-fdiagnostics-color=always",
|
||||
]}
|
||||
debug_mode = os.getenv('DEBUG', '0') == '1'
|
||||
if debug_mode:
|
||||
print("Compiling in debug mode")
|
||||
extra_compile_args = {
|
||||
"cxx": [
|
||||
"-O0",
|
||||
"-fno-inline",
|
||||
"-g",
|
||||
"-std=c++14",
|
||||
"-fdiagnostics-color=always",
|
||||
]}
|
||||
extra_link_args = ["-O0", "-g"]
|
||||
|
||||
this_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
extensions_dir = os.path.join(this_dir, "functorch", "csrc")
|
||||
|
||||
extension_sources = set(
|
||||
os.path.join(extensions_dir, p)
|
||||
for p in glob.glob(os.path.join(extensions_dir, "*.cpp"))
|
||||
)
|
||||
sources = list(extension_sources)
|
||||
sources.append(os.path.join(extensions_dir, "dim", "dim.cpp"))
|
||||
|
||||
ext_modules = [
|
||||
extension(
|
||||
"functorch._C",
|
||||
sources,
|
||||
include_dirs=[this_dir],
|
||||
define_macros=define_macros,
|
||||
extra_compile_args=extra_compile_args,
|
||||
extra_link_args=extra_link_args,
|
||||
)
|
||||
]
|
||||
|
||||
return ext_modules
|
||||
|
||||
|
||||
class BuildExtension_(BuildExtension):
|
||||
def build_extensions(self, *args, **kwargs):
|
||||
# It turns out for windows this isn't populated?
|
||||
if hasattr(self.compiler, 'compiler_so'):
|
||||
if '-Wstrict-prototypes' in self.compiler.compiler_so:
|
||||
self.compiler.compiler_so.remove('-Wstrict-prototypes')
|
||||
super().build_extensions(*args, **kwargs)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
print("Building wheel {}-{}".format(package_name, version))
|
||||
write_version_file()
|
||||
try:
|
||||
setup(
|
||||
# Metadata
|
||||
@ -141,14 +48,10 @@ if __name__ == '__main__':
|
||||
license='BSD',
|
||||
|
||||
# Package info
|
||||
packages=find_packages(),
|
||||
packages=[],
|
||||
install_requires=requirements,
|
||||
extras_require=extras,
|
||||
ext_modules=get_extensions(),
|
||||
cmdclass={
|
||||
"build_ext": BuildExtension_.with_options(no_python_abi_suffix=True),
|
||||
'clean': clean,
|
||||
})
|
||||
)
|
||||
except Exception as e:
|
||||
print(e, file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
23
setup.py
23
setup.py
@ -614,6 +614,23 @@ class build_ext(setuptools.command.build_ext.build_ext):
|
||||
os.makedirs(dst_dir)
|
||||
self.copy_file(src, dst)
|
||||
i += 1
|
||||
|
||||
# Copy functorch extension
|
||||
for i, ext in enumerate(self.extensions):
|
||||
if ext.name != "functorch._C":
|
||||
continue
|
||||
fullname = self.get_ext_fullname(ext.name)
|
||||
filename = self.get_ext_filename(fullname)
|
||||
fileext = os.path.splitext(filename)[1]
|
||||
src = os.path.join(os.path.dirname(filename), "functorch" + fileext)
|
||||
dst = os.path.join(os.path.realpath(self.build_lib), filename)
|
||||
if os.path.exists(src):
|
||||
report("Copying {} from {} to {}".format(ext.name, src, dst))
|
||||
dst_dir = os.path.dirname(dst)
|
||||
if not os.path.exists(dst_dir):
|
||||
os.makedirs(dst_dir)
|
||||
self.copy_file(src, dst)
|
||||
|
||||
setuptools.command.build_ext.build_ext.build_extensions(self)
|
||||
|
||||
|
||||
@ -893,6 +910,12 @@ def configure_extension_build():
|
||||
name=str('caffe2.python.caffe2_pybind11_state_hip'),
|
||||
sources=[]),
|
||||
)
|
||||
if cmake_cache_vars['BUILD_FUNCTORCH']:
|
||||
extensions.append(
|
||||
Extension(
|
||||
name=str('functorch._C'),
|
||||
sources=[]),
|
||||
)
|
||||
|
||||
cmdclass = {
|
||||
'bdist_wheel': wheel_concatenate,
|
||||
|
@ -143,19 +143,19 @@ def forward(self, a__1):
|
||||
|
||||
|
||||
def forward(self, a__1):
|
||||
clone_default = torch.ops.aten.clone.default(a__1); a__1 = None
|
||||
view_default = torch.ops.aten.view.default(clone_default, [-1])
|
||||
view_default_1 = torch.ops.aten.view.default(clone_default, [-1])
|
||||
select_int = torch.ops.aten.select.int(view_default_1, 0, 0); view_default_1 = None
|
||||
view_default_2 = torch.ops.aten.view.default(select_int, [-1]); select_int = None
|
||||
add_tensor = torch.ops.aten.add_.Tensor(view_default_2, 1)
|
||||
view_default_3 = torch.ops.aten.view.default(clone_default, [-1]); clone_default = None
|
||||
select_int_1 = torch.ops.aten.select.int(view_default_3, 0, 0)
|
||||
view_default_4 = torch.ops.aten.view.default(view_default_2, []); view_default_2 = None
|
||||
view_default_5 = torch.ops.aten.view.default(view_default_3, [4]); view_default_3 = None
|
||||
view_default_6 = torch.ops.aten.view.default(view_default_5, [-1])
|
||||
add_tensor_1 = torch.ops.aten.add_.Tensor(view_default_5, view_default_6); view_default_6 = None
|
||||
return view_default_5
|
||||
clone = torch.ops.aten.clone.default(a__1); a__1 = None
|
||||
view = torch.ops.aten.view.default(clone, [-1])
|
||||
view_1 = torch.ops.aten.view.default(clone, [-1])
|
||||
select = torch.ops.aten.select.int(view_1, 0, 0); view_1 = None
|
||||
view_2 = torch.ops.aten.view.default(select, [-1]); select = None
|
||||
add = torch.ops.aten.add_.Tensor(view_2, 1)
|
||||
view_3 = torch.ops.aten.view.default(clone, [-1]); clone = None
|
||||
select_1 = torch.ops.aten.select.int(view_3, 0, 0)
|
||||
view_4 = torch.ops.aten.view.default(view_2, []); view_2 = None
|
||||
view_5 = torch.ops.aten.view.default(view_3, [4]); view_3 = None
|
||||
view_6 = torch.ops.aten.view.default(view_5, [-1])
|
||||
add_1 = torch.ops.aten.add_.Tensor(view_5, view_6); view_6 = None
|
||||
return view_5
|
||||
""")
|
||||
|
||||
def test_reinplace_scatter_twice(self):
|
||||
@ -180,14 +180,14 @@ def forward(self, a__1):
|
||||
|
||||
|
||||
def forward(self, a__1):
|
||||
clone_default = torch.ops.aten.clone.default(a__1); a__1 = None
|
||||
slice_tensor = torch.ops.aten.slice.Tensor(clone_default, 0, 0, 9223372036854775807)
|
||||
select_int = torch.ops.aten.select.int(slice_tensor, 1, 1); slice_tensor = None
|
||||
select_int_1 = torch.ops.aten.select.int(select_int, 0, 1); select_int = None
|
||||
add_tensor = torch.ops.aten.add_.Tensor(select_int_1, 1); select_int_1 = None
|
||||
slice_tensor_1 = torch.ops.aten.slice.Tensor(clone_default, 0, 0, 9223372036854775807)
|
||||
select_int_2 = torch.ops.aten.select.int(slice_tensor_1, 1, 1); slice_tensor_1 = None
|
||||
return clone_default
|
||||
clone = torch.ops.aten.clone.default(a__1); a__1 = None
|
||||
slice_1 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
|
||||
select = torch.ops.aten.select.int(slice_1, 1, 1); slice_1 = None
|
||||
select_1 = torch.ops.aten.select.int(select, 0, 1); select = None
|
||||
add = torch.ops.aten.add_.Tensor(select_1, 1); select_1 = None
|
||||
slice_2 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
|
||||
select_2 = torch.ops.aten.select.int(slice_2, 1, 1); slice_2 = None
|
||||
return clone
|
||||
""")
|
||||
|
||||
def test_reinplace_scatter_twice_with_different_view_op_valid(self):
|
||||
@ -319,8 +319,8 @@ def forward(self, a__1):
|
||||
|
||||
def forward(self):
|
||||
zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
|
||||
diagonal_default = torch.ops.aten.diagonal.default(zeros)
|
||||
add_tensor = torch.ops.aten.add_.Tensor(diagonal_default, 1); diagonal_default = None
|
||||
diagonal = torch.ops.aten.diagonal.default(zeros)
|
||||
add = torch.ops.aten.add_.Tensor(diagonal, 1); diagonal = None
|
||||
return [zeros]
|
||||
""")
|
||||
|
||||
@ -343,11 +343,11 @@ def forward(self):
|
||||
def forward(self):
|
||||
zeros = torch.ops.aten.zeros.default([4, 4, 4], device = device(type='cpu'), pin_memory = False)
|
||||
ones = torch.ops.aten.ones.default([4, 2, 4], device = device(type='cpu'), pin_memory = False)
|
||||
slice_tensor = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807)
|
||||
slice_tensor_1 = torch.ops.aten.slice.Tensor(slice_tensor, 1, 2, 9223372036854775807); slice_tensor = None
|
||||
slice_tensor_2 = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807)
|
||||
slice_tensor_3 = torch.ops.aten.slice.Tensor(slice_tensor_2, 1, 2, 9223372036854775807); slice_tensor_2 = None
|
||||
copy__default = torch.ops.aten.copy_.default(slice_tensor_3, ones); slice_tensor_3 = ones = None
|
||||
slice_1 = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807)
|
||||
slice_2 = torch.ops.aten.slice.Tensor(slice_1, 1, 2, 9223372036854775807); slice_1 = None
|
||||
slice_3 = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807)
|
||||
slice_tensor = torch.ops.aten.slice.Tensor(slice_3, 1, 2, 9223372036854775807); slice_3 = None
|
||||
copy__default = torch.ops.aten.copy_.default(slice_tensor, ones); slice_tensor = ones = None
|
||||
return zeros
|
||||
""")
|
||||
|
||||
|
@ -496,3 +496,6 @@ if(NOT ${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
|
||||
# Pybind11 requires explicit linking of the torch_python library
|
||||
target_link_libraries(nnapi_backend PRIVATE torch torch_python pybind::pybind11)
|
||||
endif()
|
||||
|
||||
set(TORCH_PYTHON_COMPILE_OPTIONS ${TORCH_PYTHON_COMPILE_OPTIONS} PARENT_SCOPE)
|
||||
set(TORCH_PYTHON_LINK_FLAGS ${TORCH_PYTHON_LINK_FLAGS} PARENT_SCOPE)
|
||||
|
Reference in New Issue
Block a user