add return_and_correct_aliasing() util for wrapper subclasses (#107915)

This PR adds a `return_and_correct_aliasing()` utility, that wrapper subclasses can use to get correct aliasing. I updated `TwoTensor` to use it, and added some testing that the aliasing of my `TwoTensor` subclass now matches the aliasing behavior of normal tensors.

Right now my test just uses a few hand-picked opinfos (that have varying aliasing behavior). I thought all op infos might be overkill (does that take a while to run?), but I'm happy to add them all if people prefer.

One more general question about this PR: eventually, proper aliasing will be a **requirement** in order for AOTAutograd to handle aliasing/mutations on subclasses properly during compilation. How can we make sure that wrapper subclasses use this API? A few options (from talking to Richard):

(1) Yolo require subclasses to use the API and hope users do as well (what this PR does)

(2) Yolo require subclasses to use the API, but add a kwarg to `_make_wrapper_subclass`, e.g. `manual_aliasing=True`, that torch.compile checks for before allowing the subclass to be used in compilation

(3) Automatically run this API in our python fallback, for **every** tensor subclass that currently implements `__tensor_flatten__` (aka only the "traceable" subclasses)

(4) Automatically run this API in our python fallback, for **every** tensor subclass. This would be a bit higher blast radius, since it would change the existing aliasing behavior of wrapper subclasses. Maybe.. this is the right thing to do though?

Either way, my tentative plan is to do (1) to unblock, and revisit this later once we want to come up with public docs + a more general "tensor subclass in PT2 requirements" plan

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107915
Approved by: https://github.com/ezyang
This commit is contained in:
Brian Hirsh
2023-08-28 19:43:08 -07:00
committed by PyTorch MergeBot
parent 6c28de2437
commit 4f34caf164
4 changed files with 298 additions and 9 deletions

View File

@ -19,6 +19,11 @@ from torch.utils._python_dispatch import TorchDispatchMode, _get_current_dispatc
from torch._custom_op.functional import register_functional_op
import torch.utils._pytree as pytree
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_device_type import ops
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.custom_op_db import custom_op_db
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.multiprocessing.reductions import StorageWeakRef
import logging
import sys
@ -2107,5 +2112,73 @@ class TestPythonDispatcher(TestCase):
python_disp_shape = torch.linalg.lstsq(a, b).solution.shape
self.assertEqual(expected_shape, python_disp_shape)
class TestWrapperSubclassAliasing(TestCase):
def _test_wrapper_subclass_aliasing(self, op, args, kwargs):
def to_subclass(t: torch.Tensor):
return TwoTensor(t, t.clone())
result_ref = op(*args, **kwargs)
args_subclass = pytree.tree_map_only(torch.Tensor, to_subclass, args)
kwargs_subclass = pytree.tree_map_only(torch.Tensor, to_subclass, kwargs)
result_test = op(*args_subclass, **kwargs_subclass)
args_ref_flat, _ = pytree.tree_flatten((args, kwargs))
args_ref_flat_tensors = [x for x in args_ref_flat if isinstance(x, torch.Tensor)]
args_test_flat, _ = pytree.tree_flatten((args_subclass, kwargs_subclass))
args_test_flat_tensors = [x for x in args_test_flat if isinstance(x, torch.Tensor)]
result_ref_flat, _ = pytree.tree_flatten(result_ref)
result_ref_flat_tensors = [x for x in result_ref_flat if isinstance(x, torch.Tensor)]
result_test_flat, _ = pytree.tree_flatten(result_test)
result_test_flat_tensors = [x for x in result_test_flat if isinstance(x, torch.Tensor)]
for o_ref, o_test in zip(result_ref_flat_tensors, result_test_flat_tensors):
for a_ref, a_test in zip(args_ref_flat_tensors, args_test_flat_tensors):
out_is_inpt = o_ref is a_ref
if out_is_inpt:
self.assertTrue(o_test is a_test)
out_aliases_inpt = StorageWeakRef(o_ref.untyped_storage()) == StorageWeakRef(a_ref.untyped_storage())
if out_aliases_inpt:
self.assertTrue(StorageWeakRef(o_test.untyped_storage()) == StorageWeakRef(a_test.untyped_storage()))
else:
self.assertFalse(StorageWeakRef(o_test.untyped_storage()) == StorageWeakRef(a_test.untyped_storage()))
# This tests the correctness of `torch.utils._python_dispatch.return_and_correct_aliasing`,
# a util for wrapper subclasses to promise correct aliasing behavior.
# It's probably overkill to test every OpInfo,
# so I picked a sampling of ops with representative schemas.
@ops([op for op in op_db if op.name in [
'mul', # out-of-place
'cat', # out-of-place (TensorList input)
'index', # out-of-place (Optional TensorList input)
'mul_', # inplace
'view', # view
't_', # inplace-view
'split', # view (multi-return)
'native_batch_norm', # mutable op (returns outputs and mutates some inputs)
]], allowed_dtypes=(torch.float,))
def test_wrapper_subclass_aliasing(self, device, dtype, op):
samples = op.sample_inputs(device, dtype)
sample = first_sample(self, samples)
args = (sample.input, *sample.args)
kwargs = sample.kwargs
self._test_wrapper_subclass_aliasing(op, args, kwargs)
@ops(custom_op_db, allowed_dtypes=(torch.float,))
def test_wrapper_subclass_aliasing_custom(self, device, dtype, op):
samples = op.sample_inputs(device, dtype)
sample = first_sample(self, samples)
args = (sample.input, *sample.args)
kwargs = sample.kwargs
self._test_wrapper_subclass_aliasing(op, args, kwargs)
instantiate_device_type_tests(TestWrapperSubclassAliasing, globals())
if __name__ == '__main__':
run_tests()