mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[3/n] loading meta to device (#100495)
Summary: Make it possible to `torch.jit.load(model, device)` to a device when `model` contains weights that are on device `meta`. Just leave the `meta` weights on `meta`, and load the weights that can be loaded to the target device. Reviewed By: singlaiiit, RoshanPAN, sayitmemory Differential Revision: D45099145 Pull Request resolved: https://github.com/pytorch/pytorch/pull/100495 Approved by: https://github.com/houseroad
This commit is contained in:
committed by
PyTorch MergeBot
parent
bde7b81f34
commit
812cadf90a
@ -565,6 +565,50 @@ class TestSaveLoad(JitTestCase):
|
||||
self.assertTrue(m_buffers["buffer"].is_meta)
|
||||
self.assertTrue(m_loaded_buffers["buffer"].is_meta)
|
||||
|
||||
def test_save_load_meta_tensors_to_device(self):
|
||||
"""
|
||||
Check that when loading a module with meta tensors to device, the meta tensors
|
||||
stay on meta, but non-meta tensors are set to the indicated device.
|
||||
"""
|
||||
|
||||
class Foo(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.foo = torch.nn.Linear(2, 3, device="meta")
|
||||
self.bar = torch.nn.Linear(3, 4)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.foo(x)
|
||||
x = self.bar(x)
|
||||
return x
|
||||
|
||||
m = Foo()
|
||||
|
||||
m_loaded = self.getExportImportCopy(torch.jit.script(m), map_location="cpu")
|
||||
# Check submodules.
|
||||
self.assertEqual(
|
||||
len(list(m.named_modules())), len(list(m_loaded.named_modules()))
|
||||
)
|
||||
self.assertEqual(
|
||||
{name for name, _ in m.named_modules()},
|
||||
{name for name, _ in m_loaded.named_modules()},
|
||||
)
|
||||
# Check parameters.
|
||||
m_params = dict(m.named_parameters())
|
||||
m_loaded_params = dict(m_loaded.named_parameters())
|
||||
self.assertEqual(len(m_params), len(m_loaded_params))
|
||||
self.assertEqual(m_params, m_loaded_params)
|
||||
# Check params and buffers that are/are not meta tensors
|
||||
self.assertTrue(m_params["foo.weight"].is_meta)
|
||||
self.assertTrue(m_loaded_params["foo.weight"].is_meta)
|
||||
self.assertTrue(m_params["foo.bias"].is_meta)
|
||||
self.assertTrue(m_loaded_params["foo.bias"].is_meta)
|
||||
self.assertTrue(m_params["bar.weight"].is_cpu)
|
||||
self.assertTrue(m_loaded_params["bar.weight"].is_cpu)
|
||||
self.assertTrue(m_params["bar.bias"].is_cpu)
|
||||
self.assertTrue(m_loaded_params["bar.bias"].is_cpu)
|
||||
|
||||
|
||||
def test_save_load_with_saved_traced_inputs(self):
|
||||
"""
|
||||
Check that saving and loading with traced inputs works as expected
|
||||
|
@ -532,7 +532,8 @@ PickleOpCode Unpickler::readInstruction() {
|
||||
const std::string& key = args.at(2).toStringRef();
|
||||
|
||||
at::Device device(args.at(3).toStringRef());
|
||||
if (device_) {
|
||||
// remap device location if it's not meta
|
||||
if (device_ && !device.is_meta()) {
|
||||
device = *device_;
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user