Add more error checking in subclass creation (#64746)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64746

This extracts the error checking that used to be part of the PR above.
We are not going to land the proposed fix there, but I think we want this error checking in right now: the two failure modes it catches would otherwise lead to a memory leak and to an arbitrary memory read/write, respectively.
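
For context, here is a minimal sketch of the wrapper-subclass pattern these checks protect, mirroring the LoggingTensor in the test file touched below (the WrapperTensor name is ours, not part of the change):

```python
import torch

# A wrapper subclass: the visible tensor is a meta tensor, so it owns no
# memory; the real data is stashed as an attribute on the wrapper.
class WrapperTensor(torch.Tensor):
    elem: torch.Tensor

    @staticmethod
    def __new__(cls, elem):
        r = torch.Tensor._make_subclass(cls, elem.to("meta"), elem.requires_grad)
        r.elem = elem
        return r

w = WrapperTensor(torch.rand(2))
assert w.device.type == "meta" and w.elem.device.type == "cpu"
```

As the new test shows, re-wrapping a tensor that is already associated with one Python subclass into another now raises instead of silently corrupting that association.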

Test Plan: Imported from OSS

Reviewed By: ngimel

Differential Revision: D30867569

Pulled By: albanD

fbshipit-source-id: bf468033fb8b49fcb26eed423f5fad82b4a46c56
commit d8ae3cc318
Author: Alban Desmaison
Date: 2021-09-10 13:07:37 -07:00
Committed by: Facebook GitHub Bot
Parent: 89f94fc15f
3 changed files with 54 additions and 1 deletion


@@ -23,6 +23,7 @@ def no_dispatch() -> Iterator[None]:
 # 3. Enter dispatcher, wind your way through Autograd
 # 4. Hit Python dispatch key, call __torch_dispatch__
 
+WRAPPER_DEVICE = "meta"
 # TODO: TensorBase should work
 class LoggingTensor(torch.Tensor):
     elem: torch.Tensor
@@ -34,7 +35,7 @@ class LoggingTensor(torch.Tensor):
         # The wrapping tensor (LoggingTensor) is just a meta tensor, so it
         # doesn't hold any memory (meta tensor is generally the preferred type
         # of tensor you want to make a subclass from)...
-        r = torch.Tensor._make_subclass(cls, elem.to('meta'), elem.requires_grad)
+        r = torch.Tensor._make_subclass(cls, elem.to(WRAPPER_DEVICE), elem.requires_grad)
         # ...the real tensor is held as an element on the tensor.
         r.elem = elem
         return r
@@ -335,6 +336,38 @@ $4 = torch._ops.aten.mul($3, tensor(2))
 $5 = torch._ops.aten.mul($4, $0)
 $6 = torch._ops.aten.add_($1, $5)''')
 
+    def test_subclass_creation(self):
+        # Make sure these statements run without error.
+        # In particular, check that when an internal detach returns
+        # subclasses, these are cleanly overwritten.
+        class Foo(torch.Tensor):
+            pass
+
+        err_msg = "subclass Foo but.*already associated to a python object of type LoggingTensor"
+        with self.assertRaisesRegex(RuntimeError, err_msg):
+            a = torch.Tensor._make_subclass(Foo, LoggingTensor(torch.rand(2)))
+        with self.assertRaisesRegex(RuntimeError, err_msg):
+            b = LoggingTensor(torch.rand(2)).as_subclass(Foo)
+
+        # And in cases where we don't know whether the user wants this
+        # subclass overwritten, raise a nice error.
+        # The standard LoggingTensor will fail because it is not on the right device.
+        with self.assertRaisesRegex(TypeError, "expected.*device=cpu.*device=meta"):
+            Foo(LoggingTensor(torch.rand(2)))
+
+        # And if we put it on the right device, we still get a nice error.
+        try:
+            global WRAPPER_DEVICE
+            prev_device = WRAPPER_DEVICE
+            WRAPPER_DEVICE = "cpu"
+            err_msg = "Creating a new Tensor subclass Foo.*python object of type LoggingTensor"
+            with self.assertRaisesRegex(RuntimeError, err_msg):
+                Foo(LoggingTensor(torch.rand(2)))
+        finally:
+            WRAPPER_DEVICE = prev_device
+
 if __name__ == '__main__':
     run_tests()
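
For contrast, here is the ordinary, allowed use of the two entry points the test exercises, as a minimal sketch (this Foo is a hypothetical trivial subclass; per the error messages above, the new checks only fire when the source tensor is already associated with a different Python subclass object):

```python
import torch

class Foo(torch.Tensor):
    pass

# A plain tensor has no pre-existing Python subclass identity, so both
# entry points produce a Foo without hitting the new checks.
a = torch.Tensor._make_subclass(Foo, torch.rand(2))
b = torch.rand(2).as_subclass(Foo)
assert isinstance(a, Foo) and isinstance(b, Foo)
```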