[3/n] loading meta to device (#100495)

Summary: Make it possible to `torch.jit.load(model, device)` to a device when `model` contains weights that are on device `meta`. Just leave the `meta` weights on `meta`, and load the weights that can be loaded to the target device.

Reviewed By: singlaiiit, RoshanPAN, sayitmemory

Differential Revision: D45099145

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100495
Approved by: https://github.com/houseroad
This commit is contained in:
Janet Yang
2023-05-08 22:14:38 +00:00
committed by PyTorch MergeBot
parent bde7b81f34
commit 812cadf90a
2 changed files with 46 additions and 1 deletions

View File

@ -565,6 +565,50 @@ class TestSaveLoad(JitTestCase):
self.assertTrue(m_buffers["buffer"].is_meta)
self.assertTrue(m_loaded_buffers["buffer"].is_meta)
def test_save_load_meta_tensors_to_device(self):
"""
Check that when loading a module with meta tensors to device, the meta tensors
stay on meta, but non-meta tensors are set to the indicated device.
"""
class Foo(torch.nn.Module):
def __init__(self):
super().__init__()
self.foo = torch.nn.Linear(2, 3, device="meta")
self.bar = torch.nn.Linear(3, 4)
def forward(self, x):
x = self.foo(x)
x = self.bar(x)
return x
m = Foo()
m_loaded = self.getExportImportCopy(torch.jit.script(m), map_location="cpu")
# Check submodules.
self.assertEqual(
len(list(m.named_modules())), len(list(m_loaded.named_modules()))
)
self.assertEqual(
{name for name, _ in m.named_modules()},
{name for name, _ in m_loaded.named_modules()},
)
# Check parameters.
m_params = dict(m.named_parameters())
m_loaded_params = dict(m_loaded.named_parameters())
self.assertEqual(len(m_params), len(m_loaded_params))
self.assertEqual(m_params, m_loaded_params)
# Check params and buffers that are/are not meta tensors
self.assertTrue(m_params["foo.weight"].is_meta)
self.assertTrue(m_loaded_params["foo.weight"].is_meta)
self.assertTrue(m_params["foo.bias"].is_meta)
self.assertTrue(m_loaded_params["foo.bias"].is_meta)
self.assertTrue(m_params["bar.weight"].is_cpu)
self.assertTrue(m_loaded_params["bar.weight"].is_cpu)
self.assertTrue(m_params["bar.bias"].is_cpu)
self.assertTrue(m_loaded_params["bar.bias"].is_cpu)
def test_save_load_with_saved_traced_inputs(self):
"""
Check that saving and loading with traced inputs works as expected

View File

@ -532,7 +532,8 @@ PickleOpCode Unpickler::readInstruction() {
const std::string& key = args.at(2).toStringRef();
at::Device device(args.at(3).toStringRef());
if (device_) {
// remap device location if it's not meta
if (device_ && !device.is_meta()) {
device = *device_;
}