mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-25 16:14:55 +08:00
Add more error checking in subclass creation (#64746)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/64746 This extracts the error checking that used to be in the PR above. We are not going to land the proposed fix there, but I think we want this error checking in right now as these would lead to respectively a memory leak and arbitrary memory read/write. Test Plan: Imported from OSS Reviewed By: ngimel Differential Revision: D30867569 Pulled By: albanD fbshipit-source-id: bf468033fb8b49fcb26eed423f5fad82b4a46c56
This commit is contained in:
committed by
Facebook GitHub Bot
parent
89f94fc15f
commit
d8ae3cc318
@ -23,6 +23,7 @@ def no_dispatch() -> Iterator[None]:
|
||||
# 3. Enter dispatcher, wind your way through Autograd
|
||||
# 4. Hit Python dispatch key, call __torch_dispatch__
|
||||
|
||||
WRAPPER_DEVICE = "meta"
|
||||
# TODO: TensorBase should work
|
||||
class LoggingTensor(torch.Tensor):
|
||||
elem: torch.Tensor
|
||||
@ -34,7 +35,7 @@ class LoggingTensor(torch.Tensor):
|
||||
# The wrapping tensor (LoggingTensor) is just a meta tensor, so it
|
||||
# doesn't hold any memory (meta tensor is generally the preferred type
|
||||
# of tensor you want to make a subclass from)...
|
||||
r = torch.Tensor._make_subclass(cls, elem.to('meta'), elem.requires_grad)
|
||||
r = torch.Tensor._make_subclass(cls, elem.to(WRAPPER_DEVICE), elem.requires_grad)
|
||||
# ...the real tensor is held as an element on the tensor.
|
||||
r.elem = elem
|
||||
return r
|
||||
@ -335,6 +336,38 @@ $4 = torch._ops.aten.mul($3, tensor(2))
|
||||
$5 = torch._ops.aten.mul($4, $0)
|
||||
$6 = torch._ops.aten.add_($1, $5)''')
|
||||
|
||||
def test_subclass_creation(self):
|
||||
# Make sure these statements runs without error
|
||||
# In particular checking that when internal detach returns
|
||||
# subclasses, these are cleanly overwritten.
|
||||
class Foo(torch.Tensor):
|
||||
pass
|
||||
|
||||
err_msg = "subclass Foo but.*already associated to a python object of type LoggingTensor"
|
||||
with self.assertRaisesRegex(RuntimeError, err_msg):
|
||||
a = torch.Tensor._make_subclass(Foo, LoggingTensor(torch.rand(2)))
|
||||
with self.assertRaisesRegex(RuntimeError, err_msg):
|
||||
b = LoggingTensor(torch.rand(2)).as_subclass(Foo)
|
||||
|
||||
# And in case where we don't know if the user wants this subclass
|
||||
# overwritten, raise a nice error.
|
||||
# The standard LoggingTensor will fail because it is not on the right device
|
||||
with self.assertRaisesRegex(TypeError, "expected.*device=cpu.*device=meta"):
|
||||
Foo(LoggingTensor(torch.rand(2)))
|
||||
|
||||
# And if we put it on the right device, we still get a nice error
|
||||
try:
|
||||
global WRAPPER_DEVICE
|
||||
prev_device = WRAPPER_DEVICE
|
||||
WRAPPER_DEVICE = "cpu"
|
||||
|
||||
err_msg = "Creating a new Tensor subclass Foo.*python object of type LoggingTensor"
|
||||
with self.assertRaisesRegex(RuntimeError, err_msg):
|
||||
Foo(LoggingTensor(torch.rand(2)))
|
||||
|
||||
finally:
|
||||
WRAPPER_DEVICE = prev_device
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_tests()
|
||||
|
||||
Reference in New Issue
Block a user