mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
[torch.package/TorchScript] logic to enable sharing of tensors on load (#57573)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/57573 Test Plan: Imported from OSS Reviewed By: suo Differential Revision: D28226975 Pulled By: Lilyjjo fbshipit-source-id: bc8cb3e8052fa18336c437e0601d8b0028fd1895
This commit is contained in:
committed by
Facebook GitHub Bot
parent
307375a88e
commit
9403fe17ce
@ -1057,6 +1057,36 @@ void initJITBindings(PyObject* module) {
|
||||
return self.getAllRecords();
|
||||
});
|
||||
|
||||
// Used by torch.Package to coordinate deserialization of storages across
|
||||
// ScriptModules and eager modules
|
||||
py::class_<StorageContext, std::shared_ptr<StorageContext>>(
|
||||
m, "StorageContext")
|
||||
.def(py::init<>())
|
||||
.def(
|
||||
"get_storage",
|
||||
[](StorageContext& self,
|
||||
const std::string& name,
|
||||
py::object data_type_obj) {
|
||||
c10::Storage storage = self.getStorage(name);
|
||||
auto scalar_type =
|
||||
reinterpret_cast<THPDtype*>(data_type_obj.ptr())->scalar_type;
|
||||
auto ptr =
|
||||
c10::make_intrusive<at::TensorImpl, at::UndefinedTensorImpl>(
|
||||
std::move(storage),
|
||||
at::DispatchKeySet(),
|
||||
at::CPU(scalar_type).typeMeta());
|
||||
|
||||
return at::Tensor(std::move(ptr));
|
||||
})
|
||||
.def(
|
||||
"add_storage",
|
||||
[](StorageContext& self,
|
||||
const std::string& name,
|
||||
const at::Tensor& tensor) {
|
||||
self.addStorage(name, tensor.storage());
|
||||
})
|
||||
.def("has_storage", &StorageContext::hasStorage);
|
||||
|
||||
m.def(
|
||||
"_jit_get_operation",
|
||||
[](const std::string& op_name) {
|
||||
|
Reference in New Issue
Block a user