[BE] Move flatbuffer related python C bindings to script_init (#97476)

Summary:
An extra C binding module for flatbuffer was introduced because
not all dependencies of PyTorch want (or can) bundle in flatbuffer.

However, flatbuffer is included by default now, so this separate binding is no longer needed.
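
For callers this is purely a move: the same functions now hang off torch._C, and the separate torch._C_flatbuffer extension module goes away. A quick sanity check (a sketch, not part of this PR):

    import torch

    # The flatbuffer entry points are now bound directly on torch._C.
    assert hasattr(torch._C, "_save_mobile_module_to_bytes")
    assert hasattr(torch._C, "_load_jit_module_from_bytes")

    # The standalone extension no longer exists after this change.
    try:
        import torch._C_flatbuffer  # noqa: F401
    except ImportError:
        pass  # expected once the extra binding module is removed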

Test Plan: existing unit tests

Differential Revision: D44352583

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97476
Approved by: https://github.com/dbort
Han Qi (qihqi)
2023-03-28 17:56:32 +00:00
committed by PyTorch MergeBot
parent d8cc8ffebc
commit b895a0a675
8 changed files with 109 additions and 192 deletions


@@ -923,16 +923,7 @@ def configure_extension_build():
include_dirs=[],
library_dirs=library_dirs,
extra_link_args=extra_link_args + main_link_args + make_relative_rpath_args('lib'))
C_flatbuffer = Extension("torch._C_flatbuffer",
libraries=main_libraries,
sources=["torch/csrc/stub_with_flatbuffer.c"],
language='c',
extra_compile_args=main_compile_args + extra_compile_args,
include_dirs=[],
library_dirs=library_dirs,
extra_link_args=extra_link_args + main_link_args + make_relative_rpath_args('lib'))
extensions.append(C)
extensions.append(C_flatbuffer)
# These extensions are built by cmake and copied manually in build_extensions()
# inside the build_ext implementation
@@ -1066,7 +1057,6 @@ def main():
'bin/*',
'test/*',
'_C/*.pyi',
'_C_flatbuffer/*.pyi',
'cuda/*.pyi',
'optim/*.pyi',
'autograd/*.pyi',


@@ -2,7 +2,7 @@
# Owner(s): ["oncall: quantization"]
import torch
import torch._C_flatbuffer
import torch._C
from torch.ao.quantization import (
default_dynamic_qconfig,
@@ -443,8 +443,8 @@ class TestOnDeviceDynamicPTQFinalize(TestCase):
# Now serialize to flatbuffer and load from fb and check
dict: Dict[str, str] = {}
bytes = torch._C_flatbuffer._save_mobile_module_to_bytes(m._c, dict)
m = LiteScriptModule(torch._C_flatbuffer._load_mobile_module_from_bytes(bytes))
bytes = torch._C._save_mobile_module_to_bytes(m._c, dict)
m = LiteScriptModule(torch._C._load_mobile_module_from_bytes(bytes))
fb_output = m(*inputs)
self.assertTrue(torch.allclose(ref_output, fb_output))
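
The updated test above round-trips a quantized module through flatbuffer bytes via torch._C. The same pattern works for any lite-interpreter module; a minimal sketch, assuming the existing torch.jit.mobile helpers (_load_for_lite_interpreter, LiteScriptModule) and ScriptModule._save_to_buffer_for_lite_interpreter, none of which are touched by this PR:

    import io

    import torch
    from torch.jit.mobile import LiteScriptModule, _load_for_lite_interpreter

    scripted = torch.jit.script(torch.nn.Identity())
    # Lower to a mobile (lite interpreter) module; its ._c is a mobile::Module.
    lite = _load_for_lite_interpreter(
        io.BytesIO(scripted._save_to_buffer_for_lite_interpreter()))

    # Serialize to flatbuffer bytes and rebuild, using the relocated bindings.
    fb_bytes = torch._C._save_mobile_module_to_bytes(lite._c, {})
    restored = LiteScriptModule(torch._C._load_mobile_module_from_bytes(fb_bytes))

    x = torch.randn(3)
    assert torch.allclose(lite(x), restored(x))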


@@ -44,9 +44,6 @@ set(TORCH_PYTHON_SRCS
)
append_filelist("libtorch_python_core_sources" TORCH_PYTHON_SRCS)
list(APPEND TORCH_PYTHON_SRCS
${TORCH_SRC_DIR}/csrc/init_flatbuffer_module.cpp)
# NB: This has to match the condition under which the JIT test directory
# is included (at the time of writing that's in caffe2/CMakeLists.txt).
if(BUILD_TEST)
@@ -327,9 +324,6 @@ set_source_files_properties(
# Disable certain warnings for GCC-9.X
if(CMAKE_COMPILER_IS_GNUCXX AND (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 9.0.0))
set_source_files_properties(${TORCH_SRC_DIR}/csrc/Module.cpp PROPERTIES COMPILE_FLAGS "-Wno-cast-function-type")
set_source_files_properties(
${TORCH_SRC_DIR}/csrc/init_flatbuffer_module.cpp
PROPERTIES COMPILE_FLAGS "-Wno-cast-function-type")
set_source_files_properties(${TORCH_SRC_DIR}/csrc/autograd/python_variable.cpp PROPERTIES COMPILE_FLAGS "-Wno-cast-function-type")
endif()


@@ -1839,3 +1839,13 @@ class CapturedTraceback:
pass
def gather_traceback(python: _bool, script: _bool, cpp: _bool) -> CapturedTraceback: ...
def symbolize_tracebacks(tracebacks: List[CapturedTraceback]) -> List[Dict[str, Any]]: ...
def _load_mobile_module_from_file(filename: str): ...
def _load_mobile_module_from_bytes(bytes_: bytes): ...
def _load_jit_module_from_file(filename: str): ...
def _load_jit_module_from_bytes(bytes_: bytes): ...
def _save_mobile_module(m: LiteScriptModule, filename: str): ...
def _save_jit_module(m: ScriptModule, filename: str, extra_files: Dict[str, Any]): ...
def _save_mobile_module_to_bytes(m: LiteScriptModule) -> bytes: ...
def _save_jit_module_to_bytes(m: ScriptModule, extra_files: Dict[str, Any]) -> bytes: ...
def _get_module_info_from_flatbuffer(data: bytes): ...
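
These stubs describe the bindings as they now appear on torch._C. A minimal sketch of the bytes round trip for a scripted module; wrap_cpp_module is the helper the public loader uses later in this diff, and its torch.jit._recursive location is an assumption here:

    import torch
    from torch.jit._recursive import wrap_cpp_module  # assumed location

    m = torch.jit.script(torch.nn.Linear(2, 2))

    # Serialize the underlying C++ module to flatbuffer bytes and rebuild it.
    data = torch._C._save_jit_module_to_bytes(m._c, {})
    restored = wrap_cpp_module(torch._C._load_jit_module_from_bytes(data))

    x = torch.randn(1, 2)
    assert torch.allclose(m(x), restored(x))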


@@ -1,131 +0,0 @@
#include <torch/csrc/python_headers.h>
#include <libshm.h>
#include <cstdlib>
#include <pybind11/detail/common.h>
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <pybind11/pytypes.h>
#include <pybind11/stl.h>
#include <pybind11/stl_bind.h>
#include <torch/csrc/utils/pybind.h>
#include <Python.h> // NOLINT
#include <torch/csrc/jit/mobile/flatbuffer_loader.h>
#include <torch/csrc/jit/python/module_python.h>
#include <torch/csrc/jit/python/python_ivalue.h>
#include <torch/csrc/jit/python/python_sugared_value.h>
#include <torch/csrc/jit/serialization/export.h>
#include <torch/csrc/jit/serialization/flatbuffer_serializer.h>
#include <torch/csrc/jit/serialization/import.h>
namespace py = pybind11;
using torch::jit::kFlatbufferDataAlignmentBytes;
static std::shared_ptr<char> copyStr(const std::string& bytes) {
size_t size = (bytes.size() / kFlatbufferDataAlignmentBytes + 1) *
kFlatbufferDataAlignmentBytes;
#ifdef _WIN32
std::shared_ptr<char> bytes_copy(
static_cast<char*>(_aligned_malloc(size, kFlatbufferDataAlignmentBytes)),
_aligned_free);
#elif defined(__APPLE__)
void* p;
::posix_memalign(&p, kFlatbufferDataAlignmentBytes, size);
TORCH_INTERNAL_ASSERT(p, "Could not allocate memory for flatbuffer");
std::shared_ptr<char> bytes_copy(static_cast<char*>(p), free);
#else
std::shared_ptr<char> bytes_copy(
static_cast<char*>(aligned_alloc(kFlatbufferDataAlignmentBytes, size)),
free);
#endif
memcpy(bytes_copy.get(), bytes.data(), bytes.size());
return bytes_copy;
}
extern "C"
#ifdef _WIN32
__declspec(dllexport)
#endif
PyObject* initModuleFlatbuffer() {
using namespace torch::jit;
PyMethodDef m[] = {{nullptr, nullptr, 0, nullptr}}; // NOLINT
static struct PyModuleDef torchmodule = {
PyModuleDef_HEAD_INIT,
"torch._C_flatbuffer",
nullptr,
-1,
m,
}; // NOLINT
PyObject* module = PyModule_Create(&torchmodule);
auto pym = py::handle(module).cast<py::module>();
pym.def("_load_mobile_module_from_file", [](const std::string& filename) {
return torch::jit::load_mobile_module_from_file(filename);
});
pym.def("_load_mobile_module_from_bytes", [](const std::string& bytes) {
auto bytes_copy = copyStr(bytes);
return torch::jit::parse_and_initialize_mobile_module(
bytes_copy, bytes.size());
});
pym.def("_load_jit_module_from_file", [](const std::string& filename) {
ExtraFilesMap extra_files = ExtraFilesMap();
return torch::jit::load_jit_module_from_file(filename, extra_files);
});
pym.def("_load_jit_module_from_bytes", [](const std::string& bytes) {
auto bytes_copy = copyStr(bytes);
ExtraFilesMap extra_files = ExtraFilesMap();
return torch::jit::parse_and_initialize_jit_module(
bytes_copy, bytes.size(), extra_files);
});
pym.def(
"_save_mobile_module",
[](const torch::jit::mobile::Module& module,
const std::string& filename,
const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
return torch::jit::save_mobile_module(module, filename, _extra_files);
});
pym.def(
"_save_jit_module",
[](const torch::jit::Module& module,
const std::string& filename,
const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
return torch::jit::save_jit_module(module, filename, _extra_files);
});
pym.def(
"_save_mobile_module_to_bytes",
[](const torch::jit::mobile::Module& module,
const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
auto detached_buffer =
torch::jit::save_mobile_module_to_bytes(module, _extra_files);
return py::bytes(
reinterpret_cast<char*>(detached_buffer->data()),
detached_buffer->size());
});
pym.def(
"_save_jit_module_to_bytes",
[](const torch::jit::Module& module,
const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
auto detached_buffer =
torch::jit::save_jit_module_to_bytes(module, _extra_files);
return py::bytes(
reinterpret_cast<char*>(detached_buffer->data()),
detached_buffer->size());
});
pym.def(
"_get_module_info_from_flatbuffer", [](std::string flatbuffer_content) {
py::gil_scoped_acquire acquire;
py::dict result;
mobile::ModuleInfo minfo =
torch::jit::get_module_info_from_flatbuffer(&flatbuffer_content[0]);
result["bytecode_version"] = minfo.bytecode_version;
result["operator_version"] = minfo.operator_version;
result["function_names"] = minfo.function_names;
result["type_names"] = minfo.type_names;
result["opname_to_num_args"] = minfo.opname_to_num_args;
return result;
});
return module;
}


@@ -14,6 +14,7 @@
#include <torch/csrc/jit/mobile/compatibility/backport.h>
#include <torch/csrc/jit/mobile/compatibility/model_compatibility.h>
#include <torch/csrc/jit/mobile/file_format.h>
#include <torch/csrc/jit/mobile/flatbuffer_loader.h>
#include <torch/csrc/jit/mobile/import.h>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/mobile/quantization.h>
@@ -25,6 +26,7 @@
#include <torch/csrc/jit/python/python_ivalue.h>
#include <torch/csrc/jit/python/python_sugared_value.h>
#include <torch/csrc/jit/serialization/export_bytecode.h>
#include <torch/csrc/jit/serialization/flatbuffer_serializer.h>
#include <torch/csrc/jit/serialization/import.h>
#include <torch/csrc/jit/testing/file_check.h>
@@ -706,6 +708,30 @@ void pyCompilationUnitDefine(
}
}
// This function will copy bytes into a shared_ptr of chars aligned
// at kFlatbufferDataAlignmentBytes boundary (currently 16).
// This is required because tensors need to be aligned at 16 bytes boundary.
static std::shared_ptr<char> copyStr(const std::string& bytes) {
size_t size = (bytes.size() / kFlatbufferDataAlignmentBytes + 1) *
kFlatbufferDataAlignmentBytes;
#ifdef _WIN32
std::shared_ptr<char> bytes_copy(
static_cast<char*>(_aligned_malloc(size, kFlatbufferDataAlignmentBytes)),
_aligned_free);
#elif defined(__APPLE__)
void* p;
::posix_memalign(&p, kFlatbufferDataAlignmentBytes, size);
TORCH_INTERNAL_ASSERT(p, "Could not allocate memory for flatbuffer");
std::shared_ptr<char> bytes_copy(static_cast<char*>(p), free);
#else
std::shared_ptr<char> bytes_copy(
static_cast<char*>(aligned_alloc(kFlatbufferDataAlignmentBytes, size)),
free);
#endif
memcpy(bytes_copy.get(), bytes.data(), bytes.size());
return bytes_copy;
}
void initJitScriptBindings(PyObject* module) {
auto m = py::handle(module).cast<py::module>();
@@ -2327,6 +2353,71 @@ void initJitScriptBindings(PyObject* module) {
_save_parameters(map, filename, use_flatbuffer);
});
m.def("_load_mobile_module_from_file", [](const std::string& filename) {
return torch::jit::load_mobile_module_from_file(filename);
});
m.def("_load_mobile_module_from_bytes", [](const std::string& bytes) {
auto bytes_copy = copyStr(bytes);
return torch::jit::parse_and_initialize_mobile_module(
bytes_copy, bytes.size());
});
m.def("_load_jit_module_from_file", [](const std::string& filename) {
ExtraFilesMap extra_files = ExtraFilesMap();
return torch::jit::load_jit_module_from_file(filename, extra_files);
});
m.def("_load_jit_module_from_bytes", [](const std::string& bytes) {
auto bytes_copy = copyStr(bytes);
ExtraFilesMap extra_files = ExtraFilesMap();
return torch::jit::parse_and_initialize_jit_module(
bytes_copy, bytes.size(), extra_files);
});
m.def(
"_save_mobile_module",
[](const torch::jit::mobile::Module& module,
const std::string& filename,
const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
return torch::jit::save_mobile_module(module, filename, _extra_files);
});
m.def(
"_save_jit_module",
[](const torch::jit::Module& module,
const std::string& filename,
const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
return torch::jit::save_jit_module(module, filename, _extra_files);
});
m.def(
"_save_mobile_module_to_bytes",
[](const torch::jit::mobile::Module& module,
const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
auto detached_buffer =
torch::jit::save_mobile_module_to_bytes(module, _extra_files);
return py::bytes(
reinterpret_cast<char*>(detached_buffer->data()),
detached_buffer->size());
});
m.def(
"_save_jit_module_to_bytes",
[](const torch::jit::Module& module,
const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
auto detached_buffer =
torch::jit::save_jit_module_to_bytes(module, _extra_files);
return py::bytes(
reinterpret_cast<char*>(detached_buffer->data()),
detached_buffer->size());
});
m.def("_get_module_info_from_flatbuffer", [](std::string flatbuffer_content) {
py::gil_scoped_acquire acquire;
py::dict result;
mobile::ModuleInfo minfo =
torch::jit::get_module_info_from_flatbuffer(&flatbuffer_content[0]);
result["bytecode_version"] = minfo.bytecode_version;
result["operator_version"] = minfo.operator_version;
result["function_names"] = minfo.function_names;
result["type_names"] = minfo.type_names;
result["opname_to_num_args"] = minfo.opname_to_num_args;
return result;
});
initScriptDictBindings(module);
initScriptListBindings(module);
}
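
Besides the save/load entry points, the relocated bindings include _get_module_info_from_flatbuffer, which reads metadata out of a flatbuffer payload without constructing a runnable module. A sketch, assuming bytes produced by _save_jit_module_to_bytes are accepted here (the public get_flatbuffer_module_info in the last file below goes through the same binding):

    import torch

    m = torch.jit.script(torch.nn.Linear(3, 3))
    data = torch._C._save_jit_module_to_bytes(m._c, {})

    # Inspect metadata straight from the flatbuffer bytes.
    info = torch._C._get_module_info_from_flatbuffer(data)
    print(info["bytecode_version"], info["operator_version"])
    print(sorted(info["function_names"]), sorted(info["opname_to_num_args"]))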


@@ -1,18 +0,0 @@
#include <Python.h> // NOLINT
#ifdef _WIN32
__declspec(dllimport)
#endif
extern PyObject* initModuleFlatbuffer(void);
#ifndef _WIN32
#ifdef __cplusplus
extern "C"
#endif
__attribute__((visibility("default"))) PyObject* PyInit__C_flatbuffer(void);
#endif
PyMODINIT_FUNC PyInit__C_flatbuffer(void)
{
return initModuleFlatbuffer();
}


@@ -184,29 +184,12 @@ def validate_map_location(map_location=None):
return map_location
def get_ff_module():
try:
import torch._C_flatbuffer as ff
return ff
except ImportError:
print("Please include //caffe2:_C_flatbuffer as dependency.")
raise
def jit_module_from_flatbuffer(f):
ff = get_ff_module()
if isinstance(f, str):
if not os.path.exists(f): # type: ignore[type-var]
raise ValueError("The provided filename {} does not exist".format(f)) # type: ignore[str-bytes-safe]
if os.path.isdir(f):
raise ValueError("The provided filename {} is a directory".format(f)) # type: ignore[str-bytes-safe]
if isinstance(f, (str, pathlib.Path)):
f = str(f)
return wrap_cpp_module(ff._load_jit_module_from_file(f))
return wrap_cpp_module(torch._C._load_jit_module_from_file(f))
else:
return wrap_cpp_module(ff._load_jit_module_from_bytes(f.read()))
return wrap_cpp_module(torch._C._load_jit_module_from_bytes(f.read()))
def save_jit_module_to_flatbuffer(m, f, _extra_files=None):
@@ -252,12 +235,11 @@ def save_jit_module_to_flatbuffer(m, f, _extra_files=None):
if extra_files is None:
extra_files = {}
ff = get_ff_module()
if isinstance(f, (str, pathlib.Path)):
f = str(f)
ff._save_jit_module(m._c, f, extra_files)
torch._C._save_jit_module(m._c, f, extra_files)
else:
s = ff._save_jit_module_to_bytes(m._c, extra_files)
s = torch._C._save_jit_module_to_bytes(m._c, extra_files)
f.write(s)
@@ -282,10 +264,9 @@ def get_flatbuffer_module_info(path_or_file):
'opname_to_num_args': {'aten::linear': 3} # Dict[str, int]
}
"""
ff = get_ff_module()
if isinstance(path_or_file, (str, pathlib.Path)):
with open(path_or_file, "rb") as f:
all_bytes = f.read()
else:
all_bytes = path_or_file.read()
return ff._get_module_info_from_flatbuffer(all_bytes)
return torch._C._get_module_info_from_flatbuffer(all_bytes)
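
End to end, the public helpers in this file keep their signatures and simply forward to torch._C. A sketch of the full path; the torch.jit._serialization import path is an assumption, since the diff view hides the file name:

    import torch
    from torch.jit._serialization import (  # assumed path of the file shown above
        get_flatbuffer_module_info,
        jit_module_from_flatbuffer,
        save_jit_module_to_flatbuffer,
    )

    m = torch.jit.script(torch.nn.Linear(4, 4))
    save_jit_module_to_flatbuffer(m, "linear.ff")

    info = get_flatbuffer_module_info("linear.ff")
    print(info["bytecode_version"], sorted(info["function_names"]))

    restored = jit_module_from_flatbuffer("linear.ff")
    x = torch.randn(2, 4)
    assert torch.allclose(m(x), restored(x))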