mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[BE] Move flatbuffer related python C bindings to script_init (#97476)
Summary: Extra C binding module for flatbuffer was introduced because not all dependencies of Pytorch want (or can) bundle in flatbuffer. However, flatbuffer is in by default now so this separate binding is not longer needed. Test Plan: existing unit tests Differential Revision: D44352583 Pull Request resolved: https://github.com/pytorch/pytorch/pull/97476 Approved by: https://github.com/dbort
This commit is contained in:
committed by
PyTorch MergeBot
parent
d8cc8ffebc
commit
b895a0a675
10
setup.py
10
setup.py
@ -923,16 +923,7 @@ def configure_extension_build():
|
||||
include_dirs=[],
|
||||
library_dirs=library_dirs,
|
||||
extra_link_args=extra_link_args + main_link_args + make_relative_rpath_args('lib'))
|
||||
C_flatbuffer = Extension("torch._C_flatbuffer",
|
||||
libraries=main_libraries,
|
||||
sources=["torch/csrc/stub_with_flatbuffer.c"],
|
||||
language='c',
|
||||
extra_compile_args=main_compile_args + extra_compile_args,
|
||||
include_dirs=[],
|
||||
library_dirs=library_dirs,
|
||||
extra_link_args=extra_link_args + main_link_args + make_relative_rpath_args('lib'))
|
||||
extensions.append(C)
|
||||
extensions.append(C_flatbuffer)
|
||||
|
||||
# These extensions are built by cmake and copied manually in build_extensions()
|
||||
# inside the build_ext implementation
|
||||
@ -1066,7 +1057,6 @@ def main():
|
||||
'bin/*',
|
||||
'test/*',
|
||||
'_C/*.pyi',
|
||||
'_C_flatbuffer/*.pyi',
|
||||
'cuda/*.pyi',
|
||||
'optim/*.pyi',
|
||||
'autograd/*.pyi',
|
||||
|
@ -2,7 +2,7 @@
|
||||
# Owner(s): ["oncall: quantization"]
|
||||
|
||||
import torch
|
||||
import torch._C_flatbuffer
|
||||
import torch._C
|
||||
|
||||
from torch.ao.quantization import (
|
||||
default_dynamic_qconfig,
|
||||
@ -443,8 +443,8 @@ class TestOnDeviceDynamicPTQFinalize(TestCase):
|
||||
|
||||
# Now serialize to flabuffer and load from fb and check
|
||||
dict: Dict[str, str] = {}
|
||||
bytes = torch._C_flatbuffer._save_mobile_module_to_bytes(m._c, dict)
|
||||
m = LiteScriptModule(torch._C_flatbuffer._load_mobile_module_from_bytes(bytes))
|
||||
bytes = torch._C._save_mobile_module_to_bytes(m._c, dict)
|
||||
m = LiteScriptModule(torch._C._load_mobile_module_from_bytes(bytes))
|
||||
fb_output = m(*inputs)
|
||||
self.assertTrue(torch.allclose(ref_output, fb_output))
|
||||
|
||||
|
@ -44,9 +44,6 @@ set(TORCH_PYTHON_SRCS
|
||||
)
|
||||
append_filelist("libtorch_python_core_sources" TORCH_PYTHON_SRCS)
|
||||
|
||||
list(APPEND TORCH_PYTHON_SRCS
|
||||
${TORCH_SRC_DIR}/csrc/init_flatbuffer_module.cpp)
|
||||
|
||||
# NB: This has to match the condition under which the JIT test directory
|
||||
# is included (at the time of writing that's in caffe2/CMakeLists.txt).
|
||||
if(BUILD_TEST)
|
||||
@ -327,9 +324,6 @@ set_source_files_properties(
|
||||
# Disable certain warnings for GCC-9.X
|
||||
if(CMAKE_COMPILER_IS_GNUCXX AND (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 9.0.0))
|
||||
set_source_files_properties(${TORCH_SRC_DIR}/csrc/Module.cpp PROPERTIES COMPILE_FLAGS "-Wno-cast-function-type")
|
||||
set_source_files_properties(
|
||||
${TORCH_SRC_DIR}/csrc/init_flatbuffer_module.cpp
|
||||
PROPERTIES COMPILE_FLAGS "-Wno-cast-function-type")
|
||||
set_source_files_properties(${TORCH_SRC_DIR}/csrc/autograd/python_variable.cpp PROPERTIES COMPILE_FLAGS "-Wno-cast-function-type")
|
||||
endif()
|
||||
|
||||
|
@ -1839,3 +1839,13 @@ class CapturedTraceback:
|
||||
pass
|
||||
def gather_traceback(python: _bool, script: _bool, cpp: _bool) -> CapturedTraceback: ...
|
||||
def symbolize_tracebacks(tracebacks: List[CapturedTraceback]) -> List[Dict[str, Any]]: ...
|
||||
|
||||
def _load_mobile_module_from_file(filename: str): ...
|
||||
def _load_mobile_module_from_bytes(bytes_: bytes): ...
|
||||
def _load_jit_module_from_file(filename: str): ...
|
||||
def _load_jit_module_from_bytes(bytes_: bytes): ...
|
||||
def _save_mobile_module(m: LiteScriptModule, filename: str): ...
|
||||
def _save_jit_module(m: ScriptModule, filename: str, extra_files: Dict[str, Any]): ...
|
||||
def _save_mobile_module_to_bytes(m: LiteScriptModule) -> bytes: ...
|
||||
def _save_jit_module_to_bytes(m: ScriptModule, extra_files: Dict[str, Any]) -> bytes: ...
|
||||
def _get_module_info_from_flatbuffer(data: bytes): ...
|
||||
|
@ -1,131 +0,0 @@
|
||||
#include <torch/csrc/python_headers.h>
|
||||
|
||||
#include <libshm.h>
|
||||
#include <cstdlib>
|
||||
|
||||
#include <pybind11/detail/common.h>
|
||||
#include <pybind11/functional.h>
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/pytypes.h>
|
||||
#include <pybind11/stl.h>
|
||||
#include <pybind11/stl_bind.h>
|
||||
#include <torch/csrc/utils/pybind.h>
|
||||
|
||||
#include <Python.h> // NOLINT
|
||||
#include <torch/csrc/jit/mobile/flatbuffer_loader.h>
|
||||
#include <torch/csrc/jit/python/module_python.h>
|
||||
#include <torch/csrc/jit/python/python_ivalue.h>
|
||||
#include <torch/csrc/jit/python/python_sugared_value.h>
|
||||
#include <torch/csrc/jit/serialization/export.h>
|
||||
#include <torch/csrc/jit/serialization/flatbuffer_serializer.h>
|
||||
#include <torch/csrc/jit/serialization/import.h>
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
using torch::jit::kFlatbufferDataAlignmentBytes;
|
||||
|
||||
static std::shared_ptr<char> copyStr(const std::string& bytes) {
|
||||
size_t size = (bytes.size() / kFlatbufferDataAlignmentBytes + 1) *
|
||||
kFlatbufferDataAlignmentBytes;
|
||||
#ifdef _WIN32
|
||||
std::shared_ptr<char> bytes_copy(
|
||||
static_cast<char*>(_aligned_malloc(size, kFlatbufferDataAlignmentBytes)),
|
||||
_aligned_free);
|
||||
#elif defined(__APPLE__)
|
||||
void* p;
|
||||
::posix_memalign(&p, kFlatbufferDataAlignmentBytes, size);
|
||||
TORCH_INTERNAL_ASSERT(p, "Could not allocate memory for flatbuffer");
|
||||
std::shared_ptr<char> bytes_copy(static_cast<char*>(p), free);
|
||||
#else
|
||||
std::shared_ptr<char> bytes_copy(
|
||||
static_cast<char*>(aligned_alloc(kFlatbufferDataAlignmentBytes, size)),
|
||||
free);
|
||||
#endif
|
||||
memcpy(bytes_copy.get(), bytes.data(), bytes.size());
|
||||
return bytes_copy;
|
||||
}
|
||||
|
||||
extern "C"
|
||||
#ifdef _WIN32
|
||||
__declspec(dllexport)
|
||||
#endif
|
||||
PyObject* initModuleFlatbuffer() {
|
||||
using namespace torch::jit;
|
||||
PyMethodDef m[] = {{nullptr, nullptr, 0, nullptr}}; // NOLINT
|
||||
static struct PyModuleDef torchmodule = {
|
||||
PyModuleDef_HEAD_INIT,
|
||||
"torch._C_flatbuffer",
|
||||
nullptr,
|
||||
-1,
|
||||
m,
|
||||
}; // NOLINT
|
||||
PyObject* module = PyModule_Create(&torchmodule);
|
||||
auto pym = py::handle(module).cast<py::module>();
|
||||
pym.def("_load_mobile_module_from_file", [](const std::string& filename) {
|
||||
return torch::jit::load_mobile_module_from_file(filename);
|
||||
});
|
||||
pym.def("_load_mobile_module_from_bytes", [](const std::string& bytes) {
|
||||
auto bytes_copy = copyStr(bytes);
|
||||
return torch::jit::parse_and_initialize_mobile_module(
|
||||
bytes_copy, bytes.size());
|
||||
});
|
||||
pym.def("_load_jit_module_from_file", [](const std::string& filename) {
|
||||
ExtraFilesMap extra_files = ExtraFilesMap();
|
||||
return torch::jit::load_jit_module_from_file(filename, extra_files);
|
||||
});
|
||||
pym.def("_load_jit_module_from_bytes", [](const std::string& bytes) {
|
||||
auto bytes_copy = copyStr(bytes);
|
||||
ExtraFilesMap extra_files = ExtraFilesMap();
|
||||
return torch::jit::parse_and_initialize_jit_module(
|
||||
bytes_copy, bytes.size(), extra_files);
|
||||
});
|
||||
pym.def(
|
||||
"_save_mobile_module",
|
||||
[](const torch::jit::mobile::Module& module,
|
||||
const std::string& filename,
|
||||
const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
|
||||
return torch::jit::save_mobile_module(module, filename, _extra_files);
|
||||
});
|
||||
pym.def(
|
||||
"_save_jit_module",
|
||||
[](const torch::jit::Module& module,
|
||||
const std::string& filename,
|
||||
const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
|
||||
return torch::jit::save_jit_module(module, filename, _extra_files);
|
||||
});
|
||||
pym.def(
|
||||
"_save_mobile_module_to_bytes",
|
||||
[](const torch::jit::mobile::Module& module,
|
||||
const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
|
||||
auto detached_buffer =
|
||||
torch::jit::save_mobile_module_to_bytes(module, _extra_files);
|
||||
return py::bytes(
|
||||
reinterpret_cast<char*>(detached_buffer->data()),
|
||||
detached_buffer->size());
|
||||
});
|
||||
pym.def(
|
||||
"_save_jit_module_to_bytes",
|
||||
[](const torch::jit::Module& module,
|
||||
const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
|
||||
auto detached_buffer =
|
||||
torch::jit::save_jit_module_to_bytes(module, _extra_files);
|
||||
return py::bytes(
|
||||
reinterpret_cast<char*>(detached_buffer->data()),
|
||||
detached_buffer->size());
|
||||
});
|
||||
pym.def(
|
||||
"_get_module_info_from_flatbuffer", [](std::string flatbuffer_content) {
|
||||
py::gil_scoped_acquire acquire;
|
||||
py::dict result;
|
||||
mobile::ModuleInfo minfo =
|
||||
torch::jit::get_module_info_from_flatbuffer(&flatbuffer_content[0]);
|
||||
result["bytecode_version"] = minfo.bytecode_version;
|
||||
result["operator_version"] = minfo.operator_version;
|
||||
result["function_names"] = minfo.function_names;
|
||||
result["type_names"] = minfo.type_names;
|
||||
result["opname_to_num_args"] = minfo.opname_to_num_args;
|
||||
return result;
|
||||
});
|
||||
|
||||
return module;
|
||||
}
|
@ -14,6 +14,7 @@
|
||||
#include <torch/csrc/jit/mobile/compatibility/backport.h>
|
||||
#include <torch/csrc/jit/mobile/compatibility/model_compatibility.h>
|
||||
#include <torch/csrc/jit/mobile/file_format.h>
|
||||
#include <torch/csrc/jit/mobile/flatbuffer_loader.h>
|
||||
#include <torch/csrc/jit/mobile/import.h>
|
||||
#include <torch/csrc/jit/mobile/module.h>
|
||||
#include <torch/csrc/jit/mobile/quantization.h>
|
||||
@ -25,6 +26,7 @@
|
||||
#include <torch/csrc/jit/python/python_ivalue.h>
|
||||
#include <torch/csrc/jit/python/python_sugared_value.h>
|
||||
#include <torch/csrc/jit/serialization/export_bytecode.h>
|
||||
#include <torch/csrc/jit/serialization/flatbuffer_serializer.h>
|
||||
#include <torch/csrc/jit/serialization/import.h>
|
||||
#include <torch/csrc/jit/testing/file_check.h>
|
||||
|
||||
@ -706,6 +708,30 @@ void pyCompilationUnitDefine(
|
||||
}
|
||||
}
|
||||
|
||||
// This function will copy bytes into a shared_ptr of chars aligned
|
||||
// at kFlatbufferDataAlignmentBytes boundary (currently 16).
|
||||
// This is required because tensors need to be aligned at 16 bytes boundary.
|
||||
static std::shared_ptr<char> copyStr(const std::string& bytes) {
|
||||
size_t size = (bytes.size() / kFlatbufferDataAlignmentBytes + 1) *
|
||||
kFlatbufferDataAlignmentBytes;
|
||||
#ifdef _WIN32
|
||||
std::shared_ptr<char> bytes_copy(
|
||||
static_cast<char*>(_aligned_malloc(size, kFlatbufferDataAlignmentBytes)),
|
||||
_aligned_free);
|
||||
#elif defined(__APPLE__)
|
||||
void* p;
|
||||
::posix_memalign(&p, kFlatbufferDataAlignmentBytes, size);
|
||||
TORCH_INTERNAL_ASSERT(p, "Could not allocate memory for flatbuffer");
|
||||
std::shared_ptr<char> bytes_copy(static_cast<char*>(p), free);
|
||||
#else
|
||||
std::shared_ptr<char> bytes_copy(
|
||||
static_cast<char*>(aligned_alloc(kFlatbufferDataAlignmentBytes, size)),
|
||||
free);
|
||||
#endif
|
||||
memcpy(bytes_copy.get(), bytes.data(), bytes.size());
|
||||
return bytes_copy;
|
||||
}
|
||||
|
||||
void initJitScriptBindings(PyObject* module) {
|
||||
auto m = py::handle(module).cast<py::module>();
|
||||
|
||||
@ -2327,6 +2353,71 @@ void initJitScriptBindings(PyObject* module) {
|
||||
_save_parameters(map, filename, use_flatbuffer);
|
||||
});
|
||||
|
||||
m.def("_load_mobile_module_from_file", [](const std::string& filename) {
|
||||
return torch::jit::load_mobile_module_from_file(filename);
|
||||
});
|
||||
m.def("_load_mobile_module_from_bytes", [](const std::string& bytes) {
|
||||
auto bytes_copy = copyStr(bytes);
|
||||
return torch::jit::parse_and_initialize_mobile_module(
|
||||
bytes_copy, bytes.size());
|
||||
});
|
||||
m.def("_load_jit_module_from_file", [](const std::string& filename) {
|
||||
ExtraFilesMap extra_files = ExtraFilesMap();
|
||||
return torch::jit::load_jit_module_from_file(filename, extra_files);
|
||||
});
|
||||
m.def("_load_jit_module_from_bytes", [](const std::string& bytes) {
|
||||
auto bytes_copy = copyStr(bytes);
|
||||
ExtraFilesMap extra_files = ExtraFilesMap();
|
||||
return torch::jit::parse_and_initialize_jit_module(
|
||||
bytes_copy, bytes.size(), extra_files);
|
||||
});
|
||||
m.def(
|
||||
"_save_mobile_module",
|
||||
[](const torch::jit::mobile::Module& module,
|
||||
const std::string& filename,
|
||||
const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
|
||||
return torch::jit::save_mobile_module(module, filename, _extra_files);
|
||||
});
|
||||
m.def(
|
||||
"_save_jit_module",
|
||||
[](const torch::jit::Module& module,
|
||||
const std::string& filename,
|
||||
const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
|
||||
return torch::jit::save_jit_module(module, filename, _extra_files);
|
||||
});
|
||||
m.def(
|
||||
"_save_mobile_module_to_bytes",
|
||||
[](const torch::jit::mobile::Module& module,
|
||||
const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
|
||||
auto detached_buffer =
|
||||
torch::jit::save_mobile_module_to_bytes(module, _extra_files);
|
||||
return py::bytes(
|
||||
reinterpret_cast<char*>(detached_buffer->data()),
|
||||
detached_buffer->size());
|
||||
});
|
||||
m.def(
|
||||
"_save_jit_module_to_bytes",
|
||||
[](const torch::jit::Module& module,
|
||||
const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
|
||||
auto detached_buffer =
|
||||
torch::jit::save_jit_module_to_bytes(module, _extra_files);
|
||||
return py::bytes(
|
||||
reinterpret_cast<char*>(detached_buffer->data()),
|
||||
detached_buffer->size());
|
||||
});
|
||||
m.def("_get_module_info_from_flatbuffer", [](std::string flatbuffer_content) {
|
||||
py::gil_scoped_acquire acquire;
|
||||
py::dict result;
|
||||
mobile::ModuleInfo minfo =
|
||||
torch::jit::get_module_info_from_flatbuffer(&flatbuffer_content[0]);
|
||||
result["bytecode_version"] = minfo.bytecode_version;
|
||||
result["operator_version"] = minfo.operator_version;
|
||||
result["function_names"] = minfo.function_names;
|
||||
result["type_names"] = minfo.type_names;
|
||||
result["opname_to_num_args"] = minfo.opname_to_num_args;
|
||||
return result;
|
||||
});
|
||||
|
||||
initScriptDictBindings(module);
|
||||
initScriptListBindings(module);
|
||||
}
|
||||
|
@ -1,18 +0,0 @@
|
||||
#include <Python.h> // NOLINT
|
||||
|
||||
#ifdef _WIN32
|
||||
__declspec(dllimport)
|
||||
#endif
|
||||
extern PyObject* initModuleFlatbuffer(void);
|
||||
|
||||
#ifndef _WIN32
|
||||
#ifdef __cplusplus
|
||||
extern "C"
|
||||
#endif
|
||||
__attribute__((visibility("default"))) PyObject* PyInit__C_flatbuffer(void);
|
||||
#endif
|
||||
|
||||
PyMODINIT_FUNC PyInit__C_flatbuffer(void)
|
||||
{
|
||||
return initModuleFlatbuffer();
|
||||
}
|
@ -184,29 +184,12 @@ def validate_map_location(map_location=None):
|
||||
return map_location
|
||||
|
||||
|
||||
def get_ff_module():
|
||||
try:
|
||||
import torch._C_flatbuffer as ff
|
||||
|
||||
return ff
|
||||
except ImportError:
|
||||
print("Please include //caffe2:_C_flatbuffer as dependency.")
|
||||
raise
|
||||
|
||||
|
||||
def jit_module_from_flatbuffer(f):
|
||||
ff = get_ff_module()
|
||||
if isinstance(f, str):
|
||||
if not os.path.exists(f): # type: ignore[type-var]
|
||||
raise ValueError("The provided filename {} does not exist".format(f)) # type: ignore[str-bytes-safe]
|
||||
if os.path.isdir(f):
|
||||
raise ValueError("The provided filename {} is a directory".format(f)) # type: ignore[str-bytes-safe]
|
||||
|
||||
if isinstance(f, (str, pathlib.Path)):
|
||||
f = str(f)
|
||||
return wrap_cpp_module(ff._load_jit_module_from_file(f))
|
||||
return wrap_cpp_module(torch._C._load_jit_module_from_file(f))
|
||||
else:
|
||||
return wrap_cpp_module(ff._load_jit_module_from_bytes(f.read()))
|
||||
return wrap_cpp_module(torch._C._load_jit_module_from_bytes(f.read()))
|
||||
|
||||
|
||||
def save_jit_module_to_flatbuffer(m, f, _extra_files=None):
|
||||
@ -252,12 +235,11 @@ def save_jit_module_to_flatbuffer(m, f, _extra_files=None):
|
||||
if extra_files is None:
|
||||
extra_files = {}
|
||||
|
||||
ff = get_ff_module()
|
||||
if isinstance(f, (str, pathlib.Path)):
|
||||
f = str(f)
|
||||
ff._save_jit_module(m._c, f, extra_files)
|
||||
torch._C._save_jit_module(m._c, f, extra_files)
|
||||
else:
|
||||
s = ff._save_jit_module_to_bytes(m._c, extra_files)
|
||||
s = torch._C._save_jit_module_to_bytes(m._c, extra_files)
|
||||
f.write(s)
|
||||
|
||||
|
||||
@ -282,10 +264,9 @@ def get_flatbuffer_module_info(path_or_file):
|
||||
'opname_to_num_args': {'aten::linear': 3} # Dict[str, int]
|
||||
}
|
||||
"""
|
||||
ff = get_ff_module()
|
||||
if isinstance(path_or_file, (str, pathlib.Path)):
|
||||
with open(path_or_file, "rb") as f:
|
||||
all_bytes = f.read()
|
||||
else:
|
||||
all_bytes = path_or_file.read()
|
||||
return ff._get_module_info_from_flatbuffer(all_bytes)
|
||||
return torch._C._get_module_info_from_flatbuffer(all_bytes)
|
||||
|
Reference in New Issue
Block a user