Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Add torch::deploy, an embedded torch-python interpreter (#50458)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/50458

libinterpreter.so contains a frozen python distribution including the torch-python bindings. Freezing refers to serializing the bytecode of python standard library modules, as well as the torch python library, and embedding it in the library code. This library can then be dlopened multiple times in one process context, each interpreter having its own python state and GIL. In addition, each python environment is sealed off from the filesystem and can only import the frozen modules included in the distribution.

This change relies on the newly added frozenpython, a cpython 3.8.6 fork built for this purpose. Frozenpython provides libpython3.8-frozen.a, which contains frozen bytecode and object code for the python standard library. Building on top of frozen python, the frozen torch-python bindings are added in this diff, providing each embedded interpreter with a copy of the torch bindings. Each interpreter is intended to share one instance of libtorch and the underlying tensor libraries.

Known issues
- Autograd is not expected to work with the embedded interpreter currently, as it manages its own python interactions and needs to coordinate with the duplicated python states in each of the interpreters.
- Distributed and cuda functionality is disabled in the libinterpreter.so build and needs to be revisited.
- __file__ is not supported in the context of embedded python, since there are no files for the underlying library modules; code relying on __file__ will not work.
- __version__ is not properly supported in the embedded torch-python; there is just a workaround for now.

Test Plan: tested locally and on CI with cmake and buck builds running the torch::deploy interpreter_test

Reviewed By: ailzhang

Differential Revision: D25850783

fbshipit-source-id: a4656377caff25b73913daae7ae2f88bcab8fd88
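For orientation before the diff: the intended embedding pattern, assembled from the interfaces this commit adds (a minimal sketch, not code from the PR; it assumes LIBINTERPRETER_PATH points at the built libinterpreter.so and simple.pt is the example model checked in below):

    #include <iostream>
    #include <ATen/ATen.h>
    #include <torch/csrc/deploy/interpreter/interpreter.h>

    int main() {
      // Constructing an Interpreter dlopens a private copy of libinterpreter.so,
      // so this instance owns its own frozen python state and GIL.
      Interpreter interp;
      // Non-hermetic path: the embedded python runs torch.jit.load on the file.
      size_t id = interp.load_model("simple.pt", /*hermetic=*/false);
      at::Tensor out = interp.forward_model(id, at::ones({10, 20}));
      std::cout << out.sizes() << std::endl;
      return 0;
    }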
Commit 3192f9e4fe (parent ddf26816d3), committed by Facebook GitHub Bot.
.github/workflows/lint.yml (vendored): 6 lines changed
@@ -170,6 +170,8 @@ jobs:
 # FunctionsManual.cpp is excluded to keep this diff clean. It will be fixed
 # in a follow up PR.
 # /torch/csrc/generic/*.cpp is excluded because those files aren't actually built.
+# deploy/interpreter files are excluded due to using macros and other techniques
+# that are not easily converted to accepted c++
 python tools/clang_tidy.py \
   --verbose \
   --paths torch/csrc/ \
@@ -186,6 +188,10 @@ jobs:
 -g"-torch/csrc/autograd/FunctionsManual.cpp" \
 -g"-torch/csrc/generic/*.cpp" \
 -g"-torch/csrc/jit/codegen/cuda/runtime/*" \
+-g"-torch/csrc/deploy/interpreter/interpreter.cpp" \
+-g"-torch/csrc/deploy/interpreter/interpreter.h" \
+-g"-torch/csrc/deploy/interpreter/interpreter_impl.h" \
+-g"-torch/csrc/deploy/interpreter/test_main.cpp" \
 "$@" > ${GITHUB_WORKSPACE}/clang-tidy-output.txt

 cat ${GITHUB_WORKSPACE}/clang-tidy-output.txt
.gitignore (vendored): 3 lines changed
@@ -66,6 +66,9 @@ torch/csrc/autograd/generated/*
 torch/testing/_internal/generated/annotated_fn_args.py
 torch/testing/_internal/data/*.pt
 torch/csrc/cudnn/cuDNN.cpp
+torch/csrc/deploy/interpreter/cpython
+torch/csrc/deploy/interpreter/frozen
+torch/csrc/deploy/interpreter/third_party/typing_extensions.py
 torch/csrc/generated
 torch/csrc/generic/TensorMethods.cpp
 torch/csrc/jit/generated/*
@@ -23,6 +23,17 @@ if [[ "$BUILD_ENVIRONMENT" == *-mobile-code-analysis* ]]; then
   exec "$(dirname "${BASH_SOURCE[0]}")/build-mobile-code-analysis.sh" "$@"
 fi

+if [[ "$BUILD_ENVIRONMENT" == *linux-xenial-cuda10.2-cudnn7-py3-gcc7* ]]; then
+  # Enabling DEPLOY build (embedded torch python interpreter, experimental)
+  # only on one config for now, can expand later
+  export USE_DEPLOY=ON
+
+  # Deploy feature builds cpython. It requires these packages.
+  # TODO move this to dockerfile?
+  sudo apt-get -qq update
+  sudo apt-get -qq install libffi-dev libbz2-dev libreadline-dev libncurses5-dev libncursesw5-dev libgdbm-dev libsqlite3-dev uuid-dev tk-dev
+fi
+
 echo "Python version:"
 python --version
@@ -354,6 +354,11 @@ test_vec256() {
   fi
 }

+test_torch_deploy() {
+  SIMPLE_MODEL_PATH=torch/csrc/deploy/example/simple.pt LIBINTERPRETER_PATH=build/lib/libinterpreter.so build/bin/interpreter_test
+  assert_git_not_dirty
+}
+
 if ! [[ "${BUILD_ENVIRONMENT}" == *libtorch* || "${BUILD_ENVIRONMENT}" == *-bazel-* ]]; then
   (cd test && python -c "import torch; print(torch.__config__.show())")
   (cd test && python -c "import torch; print(torch.__config__.parallel_info())")
@@ -371,6 +376,9 @@ elif [[ "${BUILD_ENVIRONMENT}" == *libtorch* ]]; then
   # TODO: run some C++ tests
   echo "no-op at the moment"
 elif [[ "${BUILD_ENVIRONMENT}" == *-test1 || "${JOB_BASE_NAME}" == *-test1 ]]; then
+  if [[ "${BUILD_ENVIRONMENT}" == pytorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7-test1 ]]; then
+    test_torch_deploy
+  fi
   install_torchvision
   test_python_shard1
 elif [[ "${BUILD_ENVIRONMENT}" == *-test2 || "${JOB_BASE_NAME}" == *-test2 ]]; then
@@ -919,3 +919,8 @@ endif()

 include(cmake/Summary.cmake)
 caffe2_print_configuration_summary()
+
+# ---[ Torch Deploy
+if(USE_DEPLOY)
+  add_subdirectory(torch/csrc/deploy)
+endif()
@@ -22,7 +22,11 @@ if sys.version_info < (3,):
 from ._utils import _import_dotted_name
 from ._utils_internal import get_file_path, prepare_multiprocessing_environment, \
     USE_RTLD_GLOBAL_WITH_LIBTORCH, USE_GLOBAL_DEPS
-from .version import __version__
+# TODO(torch_deploy) figure out how to freeze version.py in fbcode build
+if sys.executable == 'torch_deploy':
+    __version__ = "torch-deploy-1.8"
+else:
+    from .version import __version__
 from ._six import string_classes as _string_classes

 from typing import Set, Type, TYPE_CHECKING
@@ -134,7 +138,7 @@ if sys.platform == 'win32':

 # See Note [Global dependencies]
 def _load_global_deps():
-    if platform.system() == 'Windows':
+    if platform.system() == 'Windows' or sys.executable == 'torch_deploy':
         return

     lib_name = 'libtorch_global_deps' + ('.dylib' if platform.system() == 'Darwin' else '.so')
@@ -516,7 +520,7 @@ from ._tensor_str import set_printoptions
 ################################################################################

 def manager_path():
-    if platform.system() == 'Windows':
+    if platform.system() == 'Windows' or sys.executable == 'torch_deploy':
         return b""
     path = get_file_path('torch', 'bin', 'torch_shm_manager')
     prepare_multiprocessing_environment(get_file_path('torch'))
|
@ -2,7 +2,6 @@ import torch._C
|
||||
|
||||
import contextlib
|
||||
import ctypes
|
||||
import os
|
||||
import sys
|
||||
import types
|
||||
|
||||
@ -67,7 +66,7 @@ class _OpNamespace(types.ModuleType):
|
||||
return op
|
||||
|
||||
class _Ops(types.ModuleType):
|
||||
__file__ = os.path.join(os.path.dirname(__file__), '_ops.py')
|
||||
__file__ = '_ops.py'
|
||||
|
||||
def __init__(self):
|
||||
super(_Ops, self).__init__('torch.ops')
|
||||
|
@ -1,6 +1,7 @@
|
||||
|
||||
import os
|
||||
import inspect
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
# this arbitrary-looking assortment of functionality is provided here
|
||||
@ -8,11 +9,16 @@ import tempfile
|
||||
# use is the FB build environment, where this source file is replaced
|
||||
# by an equivalent.
|
||||
|
||||
if os.path.basename(os.path.dirname(__file__)) == 'shared':
|
||||
torch_parent = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
|
||||
if sys.executable == 'torch_deploy':
|
||||
# __file__ is meaningless in the context of frozen torch used in torch deploy.
|
||||
# setting empty torch_parent should allow below functions to operate without crashing,
|
||||
# but it's unclear if there is a valid use case for them in the context of deploy.
|
||||
torch_parent = ""
|
||||
else:
|
||||
torch_parent = os.path.dirname(os.path.dirname(__file__))
|
||||
|
||||
if os.path.basename(os.path.dirname(__file__)) == 'shared':
|
||||
torch_parent = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
|
||||
else:
|
||||
torch_parent = os.path.dirname(os.path.dirname(__file__))
|
||||
|
||||
def get_file_path(*path_components):
|
||||
return os.path.join(torch_parent, *path_components)
|
||||
@ -60,7 +66,7 @@ def get_source_lines_and_file(obj, error_msg=None):
|
||||
|
||||
TEST_MASTER_ADDR = '127.0.0.1'
|
||||
TEST_MASTER_PORT = 29500
|
||||
# USE_GLOBAL_DEPS controls whether __init__.py tries to load
|
||||
# USE_GLOBAL_DEPS controls whether __init__.py tries to load
|
||||
# libtorch_global_deps, see Note [Global dependencies]
|
||||
USE_GLOBAL_DEPS = True
|
||||
# USE_RTLD_GLOBAL_WITH_LIBTORCH controls whether __init__.py tries to load
|
||||
|
@@ -692,6 +692,8 @@ extern "C"
 #ifdef _WIN32
 __declspec(dllexport)
 #endif
+TORCH_API PyObject* initModule();
+// separate decl and defn for msvc error C2491
 PyObject* initModule() {
   HANDLE_TH_ERRORS
   at::internal::lazy_init_num_threads();
torch/csrc/deploy/.gitignore (vendored, new file): 1 line
@@ -0,0 +1 @@
example/generated/*
torch/csrc/deploy/CMakeLists.txt (new file): 3 lines
@@ -0,0 +1,3 @@
set(DEPLOY_DIR "${CMAKE_CURRENT_SOURCE_DIR}")

add_subdirectory(interpreter)
torch/csrc/deploy/README.md (new file): 10 lines
@@ -0,0 +1,10 @@
# Torch Deploy
This is an experimental feature to embed multiple python interpreters inside the torch library,
providing a solution to the 'GIL problem' for multithreading with the convenience of python
and eager or torchscripted pytorch programs.

# libinterpreter
This is an internal library used behind the scenes to enable multiple python interpreters in
a single deploy runtime. libinterpreter.so is DLOPENed multiple times by the deploy library.
Each copy of libinterpreter exposes a simple interpreter interface but hides its python and other
internal symbols, preventing the different python instances from seeing each other.
torch/csrc/deploy/example/simple.pt (new binary file; not shown)
torch/csrc/deploy/example/trace_simple.py (new file): 20 lines
@@ -0,0 +1,20 @@
import argparse
import torch

class MyModule(torch.nn.Module):
    def __init__(self, N, M):
        super(MyModule, self).__init__()
        self.weight = torch.nn.Parameter(torch.rand(N, M))

    def forward(self, input):
        output = self.weight + input
        return output

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("save_file", help="Where to save the model")
    args = parser.parse_args()

    my_module = MyModule(10, 20)
    sm = torch.jit.script(my_module)
    sm.save(args.save_file)
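Presumably the checked-in simple.pt above was produced by this script, e.g. python trace_simple.py simple.pt (the save path is its only positional argument); the interpreter_test binary then finds the file through the SIMPLE_MODEL_PATH environment variable, as wired up in test_torch_deploy earlier in this diff.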
torch/csrc/deploy/interpreter/CMakeLists.txt (new file): 115 lines
@@ -0,0 +1,115 @@
SET(INTERPRETER_DIR "${DEPLOY_DIR}/interpreter" )
SET(INTERPRETER_DIR "${DEPLOY_DIR}/interpreter" PARENT_SCOPE)

SET(PYTORCH_ROOT "${CMAKE_CURRENT_SOURCE_DIR}/../../../../")

# Build cpython
SET(PYTHON_INSTALL_DIR "${INTERPRETER_DIR}/cpython")
SET(PYTHON_INC_DIR "${PYTHON_INSTALL_DIR}/include/python3.8")
SET(PYTHON_LIB "${PYTHON_INSTALL_DIR}/lib/libpython3.8.a")
SET(PYTHON_BIN "${PYTHON_INSTALL_DIR}/bin/python3")
ExternalProject_Add(
  cpython
  PREFIX cpython
  GIT_REPOSITORY https://github.com/python/cpython.git
  GIT_TAG v3.8.6
  UPDATE_COMMAND ""
  BUILD_IN_SOURCE True
  CONFIGURE_COMMAND CFLAGS=-fPIC CPPFLAGS=-fPIC <SOURCE_DIR>/configure --prefix ${PYTHON_INSTALL_DIR}
  BUILD_COMMAND CFLAGS=-fPIC CPPFLAGS=-fPIC make -j8
  INSTALL_COMMAND make install
  BYPRODUCTS ${PYTHON_MODULES} ${PYTHON_LIB} ${PYTHON_BIN}
  LOG_OUTPUT_ON_FAILURE True
)

# We find the built python modules. This is confusing because the python build
# outputs the modules in a strange nested path, and that path is then relative
# to the CMake ExternalProject root in the cmake build dir.
ExternalProject_Get_property(cpython SOURCE_DIR)
SET(PYTHON_MODULE_DIR "${SOURCE_DIR}/build/temp.linux-x86_64-3.8/${SOURCE_DIR}/Modules")
SET(PYTHON_STDLIB_DIR "${SOURCE_DIR}/Lib")
SET(PYTHON_STDLIB "${PYTHON_INSTALL_DIR}/lib/libpython_stdlib3.8.a")
# Then we use a hardcoded list of expected module names and include them in our lib
include("CMakePythonModules.txt")
ExternalProject_Add_Step(
  cpython
  archive_stdlib
  DEPENDEES install
  BYPRODUCTS ${PYTHON_STDLIB}
  COMMAND ar -rc ${PYTHON_STDLIB} ${PYTHON_MODULES}
  VERBATIM
)
# Get python typing extension, needed by torch
SET(TYPING_PKG "${INTERPRETER_DIR}/third_party/typing_extensions.py")
ExternalProject_Add(
  typing
  PREFIX typing
  GIT_REPOSITORY https://github.com/python/typing.git
  GIT_TAG 3.7.4.3
  UPDATE_COMMAND ""
  CONFIGURE_COMMAND ""
  BUILD_COMMAND ""
  INSTALL_COMMAND cp ../typing/typing_extensions/src_py3/typing_extensions.py ${TYPING_PKG}
  BYPRODUCTS ${TYPING_PKG}
  LOG_OUTPUT_ON_FAILURE True
)

# Output files generated by freeze script, containing frozen bytecode
SET(FROZEN_DIR "${INTERPRETER_DIR}/frozen")
set(FROZEN_FILES
  ${FROZEN_DIR}/main.c
  ${FROZEN_DIR}/bytecode_0.c
  ${FROZEN_DIR}/bytecode_1.c
  ${FROZEN_DIR}/bytecode_2.c
  ${FROZEN_DIR}/bytecode_3.c
  ${FROZEN_DIR}/bytecode_4.c
)
# Packages to freeze: python stdlib, typing extension, and torch
add_custom_command(
  OUTPUT ${FROZEN_FILES}
  WORKING_DIRECTORY ${INTERPRETER_DIR}
  COMMAND mkdir -p ${FROZEN_DIR}
  COMMAND ${PYTHON_BIN} freeze.py ${PYTHON_STDLIB_DIR} ${TYPING_PKG} ${PYTORCH_ROOT}/torch --oss --install_dir ${FROZEN_DIR} --verbose
  DEPENDS cpython typing
  VERBATIM
)

# instantiate a library based on the objects that make up torch_python
# make sure system python isn't used here
target_include_directories(torch_python_obj BEFORE PRIVATE ${PYTHON_INC_DIR})
add_library(torch_python_static STATIC $<TARGET_OBJECTS:torch_python_obj>)
# Build the interpreter lib, designed to be standalone and dlopened
# We bake the python and torch_python binding objs into libinterpreter
set(LINKER_SCRIPT "${INTERPRETER_DIR}/hide_symbols.script")
set(INTERPRETER_LIB_SOURCES
  ${INTERPRETER_DIR}/interpreter.cpp
  ${FROZEN_FILES}
  ${LINKER_SCRIPT}
)
add_library(interpreter ${INTERPRETER_LIB_SOURCES} ${LINKER_SCRIPT})
set_property(TARGET interpreter APPEND_STRING PROPERTY
  LINK_FLAGS " -Wl,--version-script=${LINKER_SCRIPT}")
# need to ensure headers are present before any .cpp in interpreter is compiled,
# but the cpp files don't declare a dependency on cpython, so there is a race otherwise
add_dependencies(interpreter cpython)
target_compile_options(
  interpreter PRIVATE
  -fvisibility=hidden
)
target_include_directories(interpreter PRIVATE ${INTERPRETER_DIR})
target_include_directories(interpreter PUBLIC ${PYTHON_INC_DIR})
target_link_libraries(interpreter PRIVATE ${PYTHON_LIB} ${PYTHON_STDLIB} torch_python_static)
target_link_libraries(interpreter PRIVATE crypt crypto ssl pthread dl util m z ffi lzma readline nsl ncursesw panelw) # for python builtins
target_link_libraries(interpreter PRIVATE fmt::fmt-header-only protobuf::libprotobuf-lite)

# handy to have a standalone app to verify linkage and usage of interpreter before embedding it in another lib
set(INTERPRETER_TEST_SOURCES
  ${INTERPRETER_DIR}/test_main.cpp
)
add_executable(interpreter_test ${INTERPRETER_TEST_SOURCES})
target_include_directories(interpreter_test PRIVATE ${PYTORCH_ROOT}/torch)
target_include_directories(interpreter_test PRIVATE ${PYTHON_INC_DIR})
target_link_libraries(interpreter_test PUBLIC gtest dl)
# no-as-needed to ensure shm and torch are included to satisfy runtime dlopen
# dependencies for libinterpreter, regardless of whether they are used in interpreter_test
target_link_libraries(interpreter_test PUBLIC "-Wl,--no-as-needed" shm torch protobuf::libprotobuf-lite)
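Read top to bottom, this build does four things: (1) an ExternalProject builds a position-independent cpython 3.8.6 and archives its extension-module objects into libpython_stdlib3.8.a; (2) typing_extensions.py is fetched because torch imports it; (3) freeze.py serializes the stdlib, typing_extensions, and the torch python sources into the bytecode_*.c and main.c files; (4) those generated files, interpreter.cpp, and the torch_python objects are linked into libinterpreter, with the version script plus -fvisibility=hidden keeping the python symbols private to each dlopened copy.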
torch/csrc/deploy/interpreter/CMakePythonModules.txt (new file): 69 lines
@@ -0,0 +1,69 @@
SET(PYTHON_MODULES
  ${PYTHON_MODULE_DIR}/arraymodule.o
  ${PYTHON_MODULE_DIR}/_asynciomodule.o
  ${PYTHON_MODULE_DIR}/audioop.o
  ${PYTHON_MODULE_DIR}/binascii.o
  ${PYTHON_MODULE_DIR}/_bisectmodule.o
  ${PYTHON_MODULE_DIR}/_blake2/blake2module.o ${PYTHON_MODULE_DIR}/_blake2/blake2b_impl.o ${PYTHON_MODULE_DIR}/_blake2/blake2s_impl.o
  ${PYTHON_MODULE_DIR}/_bz2module.o
  ${PYTHON_MODULE_DIR}/cmathmodule.o
  # ${PYTHON_MODULE_DIR}/_math.o
  ${PYTHON_MODULE_DIR}/cjkcodecs/_codecs_cn.o
  ${PYTHON_MODULE_DIR}/cjkcodecs/_codecs_hk.o
  ${PYTHON_MODULE_DIR}/cjkcodecs/_codecs_iso2022.o
  ${PYTHON_MODULE_DIR}/cjkcodecs/_codecs_jp.o
  ${PYTHON_MODULE_DIR}/cjkcodecs/_codecs_kr.o
  ${PYTHON_MODULE_DIR}/cjkcodecs/_codecs_tw.o
  ${PYTHON_MODULE_DIR}/_contextvarsmodule.o
  ${PYTHON_MODULE_DIR}/_cryptmodule.o
  ${PYTHON_MODULE_DIR}/_csv.o
  ${PYTHON_MODULE_DIR}/_ctypes/_ctypes.o ${PYTHON_MODULE_DIR}/_ctypes/callbacks.o ${PYTHON_MODULE_DIR}/_ctypes/callproc.o ${PYTHON_MODULE_DIR}/_ctypes/stgdict.o ${PYTHON_MODULE_DIR}/_ctypes/cfield.o
  ${PYTHON_MODULE_DIR}/_ctypes/_ctypes_test.o
  ${PYTHON_MODULE_DIR}/_cursesmodule.o
  ${PYTHON_MODULE_DIR}/_curses_panel.o
  ${PYTHON_MODULE_DIR}/_datetimemodule.o
  ${PYTHON_MODULE_DIR}/_decimal/_decimal.o ${PYTHON_MODULE_DIR}/_decimal/libmpdec/basearith.o ${PYTHON_MODULE_DIR}/_decimal/libmpdec/constants.o ${PYTHON_MODULE_DIR}/_decimal/libmpdec/context.o ${PYTHON_MODULE_DIR}/_decimal/libmpdec/convolute.o ${PYTHON_MODULE_DIR}/_decimal/libmpdec/crt.o ${PYTHON_MODULE_DIR}/_decimal/libmpdec/difradix2.o ${PYTHON_MODULE_DIR}/_decimal/libmpdec/fnt.o ${PYTHON_MODULE_DIR}/_decimal/libmpdec/fourstep.o ${PYTHON_MODULE_DIR}/_decimal/libmpdec/io.o ${PYTHON_MODULE_DIR}/_decimal/libmpdec/memory.o ${PYTHON_MODULE_DIR}/_decimal/libmpdec/mpdecimal.o ${PYTHON_MODULE_DIR}/_decimal/libmpdec/numbertheory.o ${PYTHON_MODULE_DIR}/_decimal/libmpdec/sixstep.o ${PYTHON_MODULE_DIR}/_decimal/libmpdec/transpose.o
  ${PYTHON_MODULE_DIR}/_elementtree.o
  ${PYTHON_MODULE_DIR}/fcntlmodule.o
  ${PYTHON_MODULE_DIR}/grpmodule.o
  ${PYTHON_MODULE_DIR}/_hashopenssl.o
  ${PYTHON_MODULE_DIR}/_heapqmodule.o
  ${PYTHON_MODULE_DIR}/_json.o
  ${PYTHON_MODULE_DIR}/_lsprof.o
  ${PYTHON_MODULE_DIR}/_lzmamodule.o
  ${PYTHON_MODULE_DIR}/mathmodule.o
  ${PYTHON_MODULE_DIR}/md5module.o
  ${PYTHON_MODULE_DIR}/mmapmodule.o
  ${PYTHON_MODULE_DIR}/cjkcodecs/multibytecodec.o
  ${PYTHON_MODULE_DIR}/_multiprocessing/multiprocessing.o ${PYTHON_MODULE_DIR}/_multiprocessing/semaphore.o
  ${PYTHON_MODULE_DIR}/nismodule.o
  ${PYTHON_MODULE_DIR}/_opcode.o
  ${PYTHON_MODULE_DIR}/ossaudiodev.o
  ${PYTHON_MODULE_DIR}/parsermodule.o
  ${PYTHON_MODULE_DIR}/_pickle.o
  ${PYTHON_MODULE_DIR}/_posixsubprocess.o
  ${PYTHON_MODULE_DIR}/pyexpat.o ${PYTHON_MODULE_DIR}/expat/xmlparse.o ${PYTHON_MODULE_DIR}/expat/xmlrole.o ${PYTHON_MODULE_DIR}/expat/xmltok.o
  ${PYTHON_MODULE_DIR}/_queuemodule.o
  ${PYTHON_MODULE_DIR}/_randommodule.o
  ${PYTHON_MODULE_DIR}/readline.o
  ${PYTHON_MODULE_DIR}/resource.o
  ${PYTHON_MODULE_DIR}/selectmodule.o
  ${PYTHON_MODULE_DIR}/sha1module.o
  ${PYTHON_MODULE_DIR}/sha256module.o
  ${PYTHON_MODULE_DIR}/_sha3/sha3module.o
  ${PYTHON_MODULE_DIR}/sha512module.o
  ${PYTHON_MODULE_DIR}/socketmodule.o
  ${PYTHON_MODULE_DIR}/spwdmodule.o
  ${PYTHON_MODULE_DIR}/_ssl.o
  ${PYTHON_MODULE_DIR}/_struct.o
  ${PYTHON_MODULE_DIR}/syslogmodule.o
  ${PYTHON_MODULE_DIR}/termios.o
  ${PYTHON_MODULE_DIR}/_testbuffer.o
  ${PYTHON_MODULE_DIR}/_testcapimodule.o
  ${PYTHON_MODULE_DIR}/_testimportmultiple.o
  ${PYTHON_MODULE_DIR}/_testmultiphase.o
  ${PYTHON_MODULE_DIR}/unicodedata.o
  ${PYTHON_MODULE_DIR}/xxlimited.o
  ${PYTHON_MODULE_DIR}/_xxtestfuzz/_xxtestfuzz.o ${PYTHON_MODULE_DIR}/_xxtestfuzz/fuzzer.o
  ${PYTHON_MODULE_DIR}/zlibmodule.o
)
torch/csrc/deploy/interpreter/freeze.py (new file): 269 lines
@@ -0,0 +1,269 @@
"""
Freeze Python packages.

Freezing makes it possible to ship arbitrary Python modules as part of a C++
library. The Python source of the module is compiled to bytecode and written
to `.c` files, to be imported by Python's built-in FrozenImporter.

In a normal Python installation, FrozenImporter is only used to bootstrap the
initialization of the import machinery. Python's importers are defined in
Python (see `_bootstrap.py` and `_bootstrap_external.py`) but need to be
retrieved before any importers are available. Freezing the module bytecode
resolves this circular dependency.

This script will freeze the Python standard library. It produces two things:
- Bytecode files: A set of `.c` files that define C variables containing Python bytecode.
- Main file: A `main.c` file listing all of these modules in the right form to be
  consumed by FrozenImporter.

A library that wishes to use these modules can make them available to the local
Python instance by extending `PyImport_FrozenModules` appropriately (see
https://docs.python.org/3/c-api/import.html#c.PyImport_FrozenModules).
"""

import argparse
import functools
import itertools
import marshal
import os
from dataclasses import dataclass
from pathlib import Path
from typing import List


MAIN_INCLUDES = """#include <Python.h>

"""

MAIN_PREFIX = """
// Compiled standard library modules. These should be appended to the existing
// `PyImport_FrozenModules` that ships with CPython.
struct _frozen _PyImport_FrozenModules_torch[] = {
"""

FAKE_PREFIX = """
// Compiled standard library modules. These should be appended to the existing
// `PyImport_FrozenModules` that ships with CPython.
struct _frozen _PyImport_FrozenModules[] = {
"""

MAIN_SUFFIX = """\
    {0, 0, 0} /* sentinel */
};
"""

# Exclude some standard library modules to:
# 1. Slim down the final frozen lib.
# 2. Remove functionality we don't want to support.
DENY_LIST = [
    # Interface to unix databases
    "dbm",
    # ncurses bindings (terminal interfaces)
    "curses",
    # Tcl/Tk GUI
    "tkinter",
    # Tests for the standard library
    "test",
    "tests",
    "idle_test",
    "__phello__.foo.py",
    # importlib frozen modules. These are already baked into CPython.
    "_bootstrap.py",
    "_bootstrap_external.py",
]

NUM_BYTECODE_FILES = 5


def indent_msg(fn):
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        args[0].indent += 1
        ret = fn(*args, **kwargs)
        args[0].indent -= 1
        return ret

    return wrapper


@dataclass
class FrozenModule:
    # The fully qualified module name, e.g. 'foo.bar.baz'
    module_name: str
    # The name of the C variable that holds the bytecode, e.g. 'M_foo__bar__baz'
    c_name: str
    # The size of the C variable. Negative if this module is a package.
    size: int
    # The frozen bytecode
    bytecode: bytes


class Freezer:
    def __init__(self, verbose: bool):
        self.frozen_modules: List[FrozenModule] = []
        self.indent: int = 0
        self.verbose: bool = verbose

    def msg(self, path: Path, code: str):
        if not self.verbose:
            return
        # P: package dir
        # F: python file
        # S: skipped (not a package dir)
        # X: skipped (deny-listed)
        # N: skipped (not a python file)
        for i in range(self.indent):
            print(" ", end="")
        print(f"{code} {path}")

    def write_bytecode(self, install_root):
        """
        Write the `.c` files containing the frozen bytecode. Shard frozen
        modules evenly across the files.
        """
        bytecode_file_names = [
            f"bytecode_{i}.c" for i in range(NUM_BYTECODE_FILES)
        ]
        bytecode_files = [open(os.path.join(install_root, name), "w") for name in bytecode_file_names]
        it = itertools.cycle(bytecode_files)
        for m in self.frozen_modules:
            self.write_frozen(m, next(it))

        for f in bytecode_files:
            f.close()

    def write_main(self, install_root, oss):
        """
        Write the `main.c` file containing a table enumerating all the
        frozen modules.
        """
        with open(os.path.join(install_root, "main.c"), "w") as outfp:
            outfp.write(MAIN_INCLUDES)
            for m in self.frozen_modules:
                outfp.write(f"extern unsigned char {m.c_name}[];\n")

            outfp.write(MAIN_PREFIX)
            for m in self.frozen_modules:
                outfp.write(f'\t{{"{m.module_name}", {m.c_name}, {m.size}}},\n')
            outfp.write(MAIN_SUFFIX)
            if oss:
                outfp.write(FAKE_PREFIX)
                outfp.write(MAIN_SUFFIX)

    def write_frozen(self, m: FrozenModule, outfp):
        """
        Write a single frozen module's bytecode out to a C variable.
        """
        outfp.write(f"unsigned char {m.c_name}[] = {{")
        for i in range(0, len(m.bytecode), 16):
            outfp.write("\n\t")
            for c in bytes(m.bytecode[i : i + 16]):
                outfp.write("%d," % c)
        outfp.write("\n};\n")

    def compile_path(self, path: Path, top_package_path: Path):
        """Generic entry point for compiling a Path object."""
        if path.is_dir():
            self.compile_package(path, top_package_path)
        else:
            self.compile_file(path, top_package_path)

    @indent_msg
    def compile_package(self, path: Path, top_package_path: Path):
        """Compile all the files within a Python package dir."""
        assert path.is_dir()
        if path.name in DENY_LIST:
            self.msg(path, "X")
            return

        # Python packages are directories that have __init__.py in them.
        is_package_dir = any([child.name == "__init__.py" for child in path.iterdir()])
        if not is_package_dir:
            self.msg(path, "S")
            return

        self.msg(path, "P")
        # Recursively compile all children in this dir
        for child in path.iterdir():
            self.compile_path(child, top_package_path)

    def get_module_qualname(self, file_path: Path, top_package_path: Path) -> List[str]:
        # `path` looks like 'Lib/foo/bar/baz.py'

        # chop off 'Lib/' to get something that represents a Python module hierarchy.
        # e.g. 'foo/bar/baz.py', which maps to 'foo.bar.baz'
        normalized_path = file_path.relative_to(top_package_path.parent)

        if normalized_path.name == "__init__.py":
            # Special handling for `__init__.py`. In this case, this file
            # specifies that the containing directory should be treated as a package.
            # For 'foo/bar/baz/__init__.py':
            # - The module name is 'baz'
            module_basename = normalized_path.parent.name
            # - The parent is foo.bar (need to shave off the 'baz')
            module_parent = normalized_path.parent.parent.parts
        else:
            module_basename = normalized_path.stem
            module_parent = normalized_path.parent.parts
        return list(module_parent) + [module_basename]

    @indent_msg
    def compile_file(self, path: Path, top_package_path: Path):
        """
        Compile a Python source file to frozen bytecode. Append the result to
        `self.frozen_modules`.
        """
        assert path.is_file()
        if path.suffix != ".py":
            self.msg(path, "N")
            return

        if path.name in DENY_LIST:
            self.msg(path, "X")
            return

        self.msg(path, "F")
        module_qualname = self.get_module_qualname(path, top_package_path)
        module_mangled_name = "__".join(module_qualname)
        c_name = "M_" + module_mangled_name

        with open(path, "r") as src_file:
            co = compile(src_file.read(), path, "exec")

        bytecode = marshal.dumps(co)
        size = len(bytecode)
        if path.name == '__init__.py':
            # Python packages are signified by negative size.
            size = -size
        self.frozen_modules.append(
            FrozenModule(".".join(module_qualname), c_name, size, bytecode)
        )


parser = argparse.ArgumentParser(description="Compile py source")
parser.add_argument("paths", nargs="*", help="Paths to freeze.")
parser.add_argument("--verbose", action="store_true", help="Print debug logs")
parser.add_argument("--install_dir", help="Root directory for all output files")
parser.add_argument("--fbcode_dir", help="Root directory for all output files")
parser.add_argument("--oss", action="store_true", help="If it's OSS build, add a fake _PyImport_FrozenModules")

args = parser.parse_args()

f = Freezer(args.verbose)

for p in args.paths:
    if args.fbcode_dir:
        p = os.path.join(args.fbcode_dir, p)
    path = Path(p)
    if path.is_dir() and not Path.exists(path / '__init__.py'):
        # this 'top level path p' is a standard directory containing modules,
        # not a module itself
        # each 'mod' could be a dir containing __init__.py or .py file
        for mod in path.glob("*"):
            f.compile_path(mod, mod)
    else:
        f.compile_path(path, path)

f.write_bytecode(args.install_dir)
f.write_main(args.install_dir, args.oss)
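For a sense of what write_frozen and write_main emit, the generated files look roughly like this (module name, byte values, and size are illustrative, not taken from a real run):

    /* bytecode_0.c */
    unsigned char M_foo__bar[] = {
        227,0,0,0, /* ... marshalled code object, 16 bytes per line ... */
    };

    /* main.c */
    #include <Python.h>

    extern unsigned char M_foo__bar[];

    // Compiled standard library modules. These should be appended to the existing
    // `PyImport_FrozenModules` that ships with CPython.
    struct _frozen _PyImport_FrozenModules_torch[] = {
        {"foo.bar", M_foo__bar, 1234}, /* a negative size here would mark a package */
        {0, 0, 0} /* sentinel */
    };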
torch/csrc/deploy/interpreter/hide_symbols.script (new file): 5 lines
@@ -0,0 +1,5 @@
INTERPRETER_0.1 {
  global:
    initialize_interface;
  local: *; # hide everything else
};
torch/csrc/deploy/interpreter/interpreter.cpp (new file): 324 lines
@@ -0,0 +1,324 @@
#include <dlfcn.h>

#define PY_SSIZE_T_CLEAN
#include <Python.h>
#include <iostream>
#include <torch/csrc/deploy/interpreter/interpreter_impl.h>
#include <pybind11/embed.h>
#include <cstdio>
#include <ATen/ATen.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <map>
#include <thread>
#include <fmt/format.h>

namespace py = pybind11;
using namespace py::literals;

// TODO this should come from cmake
#define DEBUG 0

template <typename T>
const auto PYOBJ_ASSERT(T obj) {
#if (DEBUG == 1)
  if (NULL == obj) {
    PyErr_Print();
  }
#endif
  TORCH_INTERNAL_ASSERT(NULL != obj);
}

static wchar_t* program;

#define FOREACH_LIBRARY(_) \
  _(array) \
  _(_asyncio) \
  _(audioop) \
  _(binascii) \
  _(_bisect) \
  _(_blake2) \
  _(_bz2) \
  _(cmath) \
  _(_codecs_cn) \
  _(_codecs_hk) \
  _(_codecs_iso2022) \
  _(_codecs_jp) \
  _(_codecs_kr) \
  _(_codecs_tw) \
  _(_contextvars) \
  _(_crypt) \
  _(_csv) \
  _(_ctypes) \
  _(_ctypes_test) \
  _(_curses) \
  _(_curses_panel) \
  _(_datetime) \
  _(_decimal) \
  _(_elementtree) \
  _(fcntl) \
  _(grp) \
  _(_hashlib) \
  _(_heapq) \
  _(_json) \
  _(_lsprof) \
  _(_lzma) \
  _(math) \
  _(_md5) \
  _(mmap) \
  _(_multibytecodec) \
  _(_multiprocessing) \
  _(nis) \
  _(_opcode) \
  _(ossaudiodev) \
  _(parser) \
  _(_pickle) \
  _(_posixsubprocess) \
  _(pyexpat) \
  _(_queue) \
  _(_random) \
  _(readline) \
  _(resource) \
  _(select) \
  _(_sha1) \
  _(_sha256) \
  _(_sha3) \
  _(_sha512) \
  _(_socket) \
  _(spwd) \
  _(_ssl) \
  _(_struct) \
  _(syslog) \
  _(termios) \
  _(_testbuffer) \
  _(_testcapi) \
  _(_testimportmultiple) \
  _(_testmultiphase) \
  _(unicodedata) \
  _(xxlimited) \
  _(_xxtestfuzz) \
  _(zlib)

#define DECLARE_LIBRARY_INIT(name) extern "C" PyObject* PyInit_##name(void);
FOREACH_LIBRARY(DECLARE_LIBRARY_INIT)
#undef DECLARE_LIBRARY_INIT

extern "C" __attribute__((visibility("default"))) void initialize_interface(
    InterpreterImpl* s) {
#define INITIALIZE_MEMBER(func) s->func = func;
  FOREACH_INTERFACE_FUNCTION(INITIALIZE_MEMBER)
#undef INITIALIZE_MEMBER
}

// These numbers of modules should not change as long as the cpython version
// embedded in the build remains fixed
static const size_t NUM_FROZEN_PY_BUILTIN_MODULES = 6;
static const size_t NUM_FROZEN_PY_STDLIB_MODULES = 680;

// We need to preserve the existing FrozenModules list, since it includes
// important importlib machinery. This code is adapted from the similar
// `PyImport_ExtendInittab`.
int extendFrozenModules(struct _frozen* frozenpython, struct _frozen* frozentorch) {
  struct _frozen* p = nullptr;
  size_t a = 0, b = 0, c = 0;
  int res = 0;

  /* Count the number of entries in both tables */
  for (a = 0; frozenpython[a].name != nullptr; a++) {
    // std::cout << "frozenpython[" << a << "]: " << frozenpython[a].name << std::endl;
  }
  for (b = 0; frozentorch[b].name != nullptr; b++) {
    // std::cout << "frozentorch[" << b << "]: " << frozentorch[b].name << std::endl;
  }
  for (c = 0; PyImport_FrozenModules[c].name != nullptr; c++) {
    // std::cout << "oldfrozen[" << c << "]: " << PyImport_FrozenModules[c].name << std::endl;
  }

  // Num frozen builtins shouldn't change (unless modifying the underlying cpython version)
  TORCH_INTERNAL_ASSERT(c == NUM_FROZEN_PY_BUILTIN_MODULES, "Missing python builtin frozen modules");
  // Check a+b together since in OSS a is empty and b contains stdlib+torch, while
  // in fbcode they are separated due to thirdparty2 frozenpython.
  // No fixed number of torch modules to check for, but there should be at least one.
  TORCH_INTERNAL_ASSERT(a + b > NUM_FROZEN_PY_STDLIB_MODULES + 1, "Missing frozen python stdlib or torch modules");

  /* Allocate new memory for the combined table */
  if (a + b + c <= SIZE_MAX / sizeof(struct _frozen) - 1) {
    size_t size = sizeof(struct _frozen) * (a + b + c + 1);
    p = (_frozen*)PyMem_Realloc(p, size);
  }
  if (p == nullptr) {
    return -1;
  }

  /* Copy the tables into the new memory */
  memcpy(p, PyImport_FrozenModules, (c + 1) * sizeof(struct _frozen));
  memcpy(p + c, frozenpython, (a + 1) * sizeof(struct _frozen));
  memcpy(p + a + c, frozentorch, (b + 1) * sizeof(struct _frozen));
  PyImport_FrozenModules = p;
  return res;
}

// We need to register a custom finder because we are registering `torch._C` as
// a built-in module, and it will otherwise get skipped by the default importer.
const char* finder = R"RAW(
import sys
# Remove the path-based importer, as we don't want our isolated interpreter to read the file system
sys.meta_path = sys.meta_path[:-1]

class F:
    def find_spec(self, fullname, path, target=None):
        if fullname == 'torch._C':
            return sys.meta_path[1].find_spec('torch._C', None, None)
        return None
sys.meta_path.insert(0, F())

# make loader importable
)RAW";

const char* sysprint = R"RAW(
import sys
print("exec_prefix:", sys.base_exec_prefix)
print("_base_executable:", sys._base_executable)
print("base_prefix:", sys.base_prefix)
print("exec_prefix:", sys.exec_prefix)
print("executable:", sys.executable)
print("path:", sys.path)
print("prefix:", sys.prefix)

)RAW";

extern "C" PyObject* initModule(void);
extern "C" struct _frozen _PyImport_FrozenModules[];
extern "C" struct _frozen _PyImport_FrozenModules_torch[];

static std::atomic<size_t> s_id;
std::map<size_t, py::object> forwards;

__attribute__((constructor)) void init() {

}

void startup() {
#define APPEND_INIT(name) PyImport_AppendInittab(#name, PyInit_##name);
  FOREACH_LIBRARY(APPEND_INIT)
#undef APPEND_INIT
  PyImport_AppendInittab("torch._C", initModule);

  int ret = extendFrozenModules(_PyImport_FrozenModules, _PyImport_FrozenModules_torch);
  TORCH_INTERNAL_ASSERT(ret == 0);

  PyPreConfig preconfig;
  PyPreConfig_InitIsolatedConfig(&preconfig);
  PyStatus status = Py_PreInitialize(&preconfig);
  TORCH_INTERNAL_ASSERT(!PyStatus_Exception(status))

  PyConfig config;
  PyConfig_InitIsolatedConfig(&config);

  // Completely blank out the path configuration. This ensures we have complete
  // control of how our embedded Python searches for modules, and we will never
  // consult the external filesystem. See:
  // https://docs.python.org/3/c-api/init_config.html#path-configuration
  config.site_import = 0;

  status = PyConfig_SetString(&config, &config.base_exec_prefix, L"");
  status = PyConfig_SetString(&config, &config.base_executable, L"torch_deploy");
  status = PyConfig_SetString(&config, &config.base_prefix, L"");
  status = PyConfig_SetString(&config, &config.exec_prefix, L"");
  status = PyConfig_SetString(&config, &config.executable, L"torch_deploy");
  status = PyConfig_SetString(&config, &config.prefix, L"");

  config.module_search_paths_set = 1;
  std::array<wchar_t*, 0> module_search_paths = {};
  status = PyConfig_SetWideStringList(
      &config, &config.module_search_paths, 0, module_search_paths.data());

  status = Py_InitializeFromConfig(&config);
  PyConfig_Clear(&config);
  TORCH_INTERNAL_ASSERT(!PyStatus_Exception(status))

  // Uncomment to debug python config
  // PyRun_SimpleString(sysprint);

  PyRun_SimpleString(finder);
  // Release the GIL that PyInitialize acquires
  PyEval_SaveThread();
}

void teardown() {
  PyGILState_Ensure();

  if (Py_FinalizeEx() < 0) {
    std::cout << "IT BROKE SO WE ARE EXITING\n";
    exit(120);
  }
  PyMem_RawFree(program);
}

__attribute__((destructor)) void deinit() {}

void run_some_python(const char* code) {
  PyGILState_STATE gstate = PyGILState_Ensure();

  if (PyRun_SimpleString(code) == -1) {
    throw std::runtime_error("python eval failed\n");
  }
  PyGILState_Release(gstate);
}

void run_python_file(const char* code) {
  PyGILState_STATE gstate = PyGILState_Ensure();

  FILE* f = fopen(code, "r");
  if (PyRun_SimpleFile(f, code) == -1) {
    throw std::runtime_error("python eval failed\n");
  }
  fclose(f);

  PyGILState_Release(gstate);
}

size_t load_model(const char* filename, bool hermetic) {
  PyGILState_STATE gstate = PyGILState_Ensure();
  TORCH_INTERNAL_ASSERT(PyGILState_Check() == 1);
  std::string code;

  if (hermetic) {
    code = fmt::format(R"(
from torch.package import PackageImporter

i = PackageImporter('{}')
model = i.load_pickle('model', 'model.pkl')
)", filename);
  } else {
    code = std::string("model = torch.jit.load('") +
        std::string(filename) + std::string("')");
  }
  py::exec(code);

  auto id = ++s_id;

  PyGILState_Release(gstate);
  return id;
}

at::Tensor forward_model(size_t model_id, at::Tensor const& input) {
  at::Tensor output;
  PyGILState_STATE gstate = PyGILState_Ensure();
  {
    TORCH_INTERNAL_ASSERT(PyGILState_Check() == 1);
    auto forward = py::globals()["model"].attr("forward");

    py::object py_output = forward(input);
    // TODO is this going to leak?
    // added it to prevent crash when using 'output' tensor in callee of
    // forward()
    py_output.inc_ref();
    output = py::cast<at::Tensor>(py_output);
  }

  PyGILState_Release(gstate);

  return output;
  // return input;
}
torch/csrc/deploy/interpreter/interpreter.h (new file): 67 lines
@@ -0,0 +1,67 @@
#pragma once
#include <dlfcn.h>
#include <unistd.h>
#include <experimental/filesystem>
#include <fstream>
#include <iostream>
#include <string>
#include <thread>
#include <vector>
#include <torch/csrc/deploy/interpreter/interpreter_impl.h>


class Interpreter : public InterpreterImpl {
 private:
  std::string library_name_;
  void* handle_;

 public:
  Interpreter() : handle_(nullptr) {
    char library_name[L_tmpnam];
    char* libinterpreter_path = std::getenv("LIBINTERPRETER_PATH");
    if (libinterpreter_path == nullptr) {
      throw std::runtime_error("libinterpreter_path is NULL, set LIBINTERPRETER_PATH env.");
    }
    std::tmpnam(library_name);
    library_name_ = library_name;
    {
      std::ifstream src(libinterpreter_path, std::ios::binary);
      std::ofstream dst(library_name, std::ios::binary);
      dst << src.rdbuf();
    }
    handle_ = dlopen(library_name, RTLD_LOCAL | RTLD_LAZY);
    if (!handle_) {
      throw std::runtime_error(dlerror());
    }

    // technically, we can unlink the library right after dlopen, and this is
    // better for cleanup because even if we crash the library doesn't stick
    // around. However, it's bad for debugging because gdb can't find the
    // symbols if the library is no longer present.
    unlink(library_name_.c_str());

    void* initialize_interface = dlsym(handle_, "initialize_interface");
    if (!initialize_interface) {
      throw std::runtime_error("Unable to load initialize_interface function from interpreter lib.");
    }
    ((void (*)(InterpreterImpl*))initialize_interface)(this);

    this->startup();

    // the actual torch loading process is not thread safe; by doing it
    // in the constructor, before we have multiple worker threads, we
    // ensure it doesn't race.
    run_some_python("import torch");
  }
  ~Interpreter() {
    if (handle_) {
      this->teardown();

      // it segfaults its face off trying to unload, but it's not clear
      // if this is something we caused or if libtorch_python would also do the
      // same if it were opened/closed a lot...
      dlclose(handle_);
    }
  }
  Interpreter(const Interpreter&) = delete;
};
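The copy-to-a-temp-file dance in the constructor above is what makes multiple isolated interpreters possible: dlopen returns the same reference-counted handle when given a path it has already loaded, so opening libinterpreter.so twice by its original path would yield one shared python state. Copying the file to a fresh tmpnam path before each dlopen forces the loader to treat it as a distinct library.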
torch/csrc/deploy/interpreter/interpreter_impl.h (new file): 26 lines
@@ -0,0 +1,26 @@
#pragma once
#include <ATen/ATen.h>

// NOTE- if adding new interface functions,
// update interpreter.cpp initialize_interface.
size_t load_model(const char* model_file, bool hermetic=false);
at::Tensor forward_model(size_t model_id, at::Tensor const& input);
void run_some_python(const char* code);
void startup();
void teardown();
void run_python_file(const char* code);


#define FOREACH_INTERFACE_FUNCTION(_) \
  _(load_model) \
  _(forward_model) \
  _(run_some_python) \
  _(startup) \
  _(teardown) \
  _(run_python_file)

struct InterpreterImpl {
#define DEFINE_POINTER(func) decltype(&::func) func;
  FOREACH_INTERFACE_FUNCTION(DEFINE_POINTER)
#undef DEFINE_POINTER
};
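Expanded by hand, the macro machinery above amounts to a plain struct of function pointers (approximate expansion, for illustration):

    struct InterpreterImpl {
      decltype(&::load_model) load_model;
      decltype(&::forward_model) forward_model;
      decltype(&::run_some_python) run_some_python;
      decltype(&::startup) startup;
      decltype(&::teardown) teardown;
      decltype(&::run_python_file) run_python_file;
    };

libinterpreter's initialize_interface (in interpreter.cpp above) fills each slot with the address of its own hidden definition, which is how the host process routes a call to one particular python instance.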
torch/csrc/deploy/interpreter/test_main.cpp (new file): 49 lines
@@ -0,0 +1,49 @@
#include <gtest/gtest.h>
#include <iostream>
#include <string>
#include <torch/script.h>
#include <torch/torch.h>
#include <torch/csrc/deploy/interpreter/interpreter.h>

int main(int argc, char* argv[]) {
  ::testing::InitGoogleTest(&argc, argv);

  int rc = RUN_ALL_TESTS();

  return rc;
}

TEST(Interpreter, Sanity) {
  ASSERT_TRUE(true);
}

TEST(Interpreter, Hello) {
  Interpreter interp;
  interp.run_some_python("print('hello from first interpreter!')");

  Interpreter interp2;
  interp2.run_some_python("print('hello from second interpreter!')");
}

void compare_torchpy_jit(const char* model_filename, at::Tensor const& input) {
  Interpreter interp;
  // Test
  auto model_id = interp.load_model(model_filename, false);
  at::Tensor output = interp.forward_model(model_id, input);

  // Reference
  auto ref_model = torch::jit::load(model_filename);
  std::vector<torch::jit::IValue> ref_inputs;
  ref_inputs.emplace_back(torch::jit::IValue(input));
  at::Tensor ref_output = ref_model.forward(ref_inputs).toTensor();

  ASSERT_TRUE(ref_output.equal(output));
}

TEST(Interpreter, SimpleModel) {
  char* model_path = std::getenv("SIMPLE_MODEL_PATH");
  ASSERT_NE(model_path, nullptr);
  const int A = 10, B = 20;
  compare_torchpy_jit(
      model_path, torch::ones(at::IntArrayRef({A, B})));
}
torch/csrc/deploy/interpreter/third_party/README.md (vendored, new file): 2 lines
@@ -0,0 +1,2 @@
Python libraries that we want to package along with the Python implementation
bundled in libinterpreter.
@@ -61,7 +61,11 @@ void clearAllPrePasses() {

 // LEGACY CALL
 RegisterPostPass::RegisterPostPass(GraphPass p) {
-  registerPass(std::move(p));
+  id_ = registerPass(std::move(p));
 }
+
+RegisterPostPass::~RegisterPostPass() {
+  clearPostPass(id_);
+}

 } // namespace jit
@@ -52,6 +52,10 @@ TORCH_API void clearAllPrePasses();

 // LEGACY CALL
 struct TORCH_API RegisterPostPass {
   RegisterPostPass(GraphPass p);
+  ~RegisterPostPass();
+
+ private:
+  GraphPassNameType id_;
 };

 using RegisterPass = RegisterPostPass;
|
@ -113,6 +113,10 @@ def _lazy_call(callable):
|
||||
if is_initialized():
|
||||
callable()
|
||||
else:
|
||||
# TODO(torch_deploy): this accesses linecache, which attempts to read the
|
||||
# file system to get traceback info. Patch linecache or do something
|
||||
# else here if this ends up being important.
|
||||
|
||||
# Don't store the actual traceback to avoid memory cycle
|
||||
_queued_calls.append((callable, traceback.format_stack()))
|
||||
|
||||
|
@@ -2,6 +2,7 @@
 from .throughput_benchmark import ThroughputBenchmark

 import os.path as _osp
+import sys

 # Set the module for a given object for nicer printing
 def set_module(obj, mod):
@@ -9,5 +10,8 @@ def set_module(obj, mod):
         raise TypeError("The mod argument should be a string")
     obj.__module__ = mod

 #: Path to folder containing CMake definitions for Torch package
-cmake_prefix_path = _osp.join(_osp.dirname(_osp.dirname(__file__)), 'share', 'cmake')
+if sys.executable == "torch_deploy":
+    # not valid inside the torch_deploy interpreter; no paths exist for frozen modules
+    cmake_prefix_path = None
+else:
+    cmake_prefix_path = _osp.join(_osp.dirname(_osp.dirname(__file__)), 'share', 'cmake')