[torch.package/TorchScript] logic to enable sharing of tensors on load (#57573)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/57573

Test Plan: Imported from OSS

Reviewed By: suo

Differential Revision: D28226975

Pulled By: Lilyjjo

fbshipit-source-id: bc8cb3e8052fa18336c437e0601d8b0028fd1895
This commit is contained in:
Lillian Johnson
2021-05-14 08:19:13 -07:00
committed by Facebook GitHub Bot
parent 307375a88e
commit 9403fe17ce
11 changed files with 144 additions and 26 deletions

View File

@ -1057,6 +1057,36 @@ void initJITBindings(PyObject* module) {
return self.getAllRecords();
});
// Used by torch.Package to coordinate deserialization of storages across
// ScriptModules and eager modules
py::class_<StorageContext, std::shared_ptr<StorageContext>>(
m, "StorageContext")
.def(py::init<>())
.def(
"get_storage",
[](StorageContext& self,
const std::string& name,
py::object data_type_obj) {
c10::Storage storage = self.getStorage(name);
auto scalar_type =
reinterpret_cast<THPDtype*>(data_type_obj.ptr())->scalar_type;
auto ptr =
c10::make_intrusive<at::TensorImpl, at::UndefinedTensorImpl>(
std::move(storage),
at::DispatchKeySet(),
at::CPU(scalar_type).typeMeta());
return at::Tensor(std::move(ptr));
})
.def(
"add_storage",
[](StorageContext& self,
const std::string& name,
const at::Tensor& tensor) {
self.addStorage(name, tensor.storage());
})
.def("has_storage", &StorageContext::hasStorage);
m.def(
"_jit_get_operation",
[](const std::string& op_name) {