mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
wrapper subclasses: support non-cpu device for dynamic shape overload (#107926)
This is tested by AOTAutograd later in the stack, but I can add direct tests if anyone wants them. Previously, the second overload of `_make_wrapper_subclass` (which supports dynamic shapes) would just always return a wrapper tensor that reported as being on `cpu`. This updates it to properly respect the `device` arg that was passed in. At first I thought about doing this the same way that FakeTensor does it (override device to do a custom impl), but that seemed overly complicated. Since the subclass is a wrapper, we can just bake in the value on the wrapper. Pull Request resolved: https://github.com/pytorch/pytorch/pull/107926 Approved by: https://github.com/ezyang ghstack dependencies: #107915, #107916
This commit is contained in:
committed by
PyTorch MergeBot
parent
c6e3adaf54
commit
80c7fdf49f
@ -702,7 +702,7 @@ static PyObject* THPVariable_make_wrapper_subclass(
|
||||
Storage storage{
|
||||
Storage::use_byte_size_t{},
|
||||
0,
|
||||
at::DataPtr{},
|
||||
at::DataPtr{nullptr, r.device(7)},
|
||||
/*allocator=*/c10::GetAllocator(c10::kMeta),
|
||||
/*resizeable=*/true};
|
||||
|
||||
|
Reference in New Issue
Block a user