wrapper subclasses: support non-cpu device for dynamic shape overload (#107926)

This is tested by AOTAutograd later in the stack, but I can add direct tests if anyone wants them.

Previously, the second overload of `_make_wrapper_subclass` (the one that supports dynamic shapes) always returned a wrapper tensor that reported its device as `cpu`, regardless of what was passed in. This updates it to properly respect the `device` argument.
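
For illustration, a minimal sketch of the kind of wrapper subclass this affects (the class name `DeviceWrapper` is made up for this example). The dynamic shape overload is the one selected when the sizes/strides passed in contain `SymInt`s, e.g. under a `ShapeEnv` while tracing; the sketch uses concrete sizes for brevity, but it is the same call site:

```python
import torch

class DeviceWrapper(torch.Tensor):
    @staticmethod
    def __new__(cls, elem, device):
        # The inner tensor can live on `meta`; the wrapper just records
        # the device it should advertise. When elem's sizes/strides are
        # SymInts, this routes to the dynamic shape overload, which used
        # to ignore `device` and report `cpu`.
        return torch.Tensor._make_wrapper_subclass(
            cls,
            elem.size(),
            strides=elem.stride(),
            dtype=elem.dtype,
            layout=elem.layout,
            device=device,
            requires_grad=elem.requires_grad,
        )

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        raise NotImplementedError(f"sketch implements no ops: {func}")

inner = torch.empty(2, 3, device="meta")
w = DeviceWrapper(inner, device="cuda:0")  # no real CUDA allocation happens
print(w.device)  # cuda:0 (the dynamic shape overload used to say cpu here)
```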

At first I thought about doing this the same way that FakeTensor does it (overriding `device` with a custom impl), but that seemed overly complicated. Since the subclass is a wrapper, we can just bake the value into the wrapper; a rough sketch of the rejected alternative is below for contrast.
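
A rough sketch of that alternative (names are hypothetical; FakeTensor's real implementation has more moving parts): leave the wrapper's baked-in device alone and shadow the `device` property at the Python level:

```python
import torch

class PropertyDeviceWrapper(torch.Tensor):
    @staticmethod
    def __new__(cls, elem, device):
        t = torch.Tensor._make_wrapper_subclass(
            cls, elem.size(), strides=elem.stride(), dtype=elem.dtype
        )
        # Stash the device we want to report on the Python object.
        t._reported_device = torch.device(device)
        return t

    @property
    def device(self):
        # Shadows torch.Tensor's device accessor for Python-level reads
        # only; C++ code still sees the device baked into the TensorImpl,
        # which is part of why this route is more complicated.
        return self._reported_device
```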

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107926
Approved by: https://github.com/ezyang
ghstack dependencies: #107915, #107916
Author: Brian Hirsh
Date: 2023-08-28 19:43:09 -07:00
Committed by: PyTorch MergeBot
Parent: c6e3adaf54
Commit: 80c7fdf49f

@@ -702,7 +702,7 @@ static PyObject* THPVariable_make_wrapper_subclass(
   Storage storage{
       Storage::use_byte_size_t{},
       0,
-      at::DataPtr{},
+      at::DataPtr{nullptr, r.device(7)},
       /*allocator=*/c10::GetAllocator(c10::kMeta),
       /*resizeable=*/true};
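
In other words, the parsed `device` argument (`r.device(7)`) is now recorded directly in the meta storage's `DataPtr`, and the wrapper tensor constructed from that storage reports that device, with no custom `device` override involved.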