mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
add return_and_correct_aliasing() util for wrapper subclasses (#107915)
This PR adds a `return_and_correct_aliasing()` utility, that wrapper subclasses can use to get correct aliasing. I updated `TwoTensor` to use it, and added some testing that the aliasing of my `TwoTensor` subclass now matches the aliasing behavior of normal tensors. Right now my test just uses a few hand-picked opinfos (that have varying aliasing behavior). I thought all op infos might be overkill (does that take a while to run?), but I'm happy to add them all if people prefer. One more general question about this PR: eventually, proper aliasing will be a **requirement** in order for AOTAutograd to handle aliasing/mutations on subclasses properly during compilation. How can we make sure that wrapper subclasses use this API? A few options (from talking to Richard): (1) Yolo require subclasses to use the API and hope users do as well (what this PR does) (2) Yolo require subclasses to use the API, but add a kwarg to `_make_wrapper_subclass`, e.g. `manual_aliasing=True`, that torch.compile checks for before allowing the subclass to be used in compilation (3) Automatically run this API in our python fallback, for **every** tensor subclass that currently implements `__tensor_flatten__` (aka only the "traceable" subclasses) (4) Automatically run this API in our python fallback, for **every** tensor subclass. This would be a bit higher blast radius, since it would change the existing aliasing behavior of wrapper subclasses. Maybe.. this is the right thing to do though? Either way, my tentative plan is to do (1) to unblock, and revisit this later once we want to come up with public docs + a more general "tensor subclass in PT2 requirements" plan Pull Request resolved: https://github.com/pytorch/pytorch/pull/107915 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
6c28de2437
commit
4f34caf164
@ -19,6 +19,11 @@ from torch.utils._python_dispatch import TorchDispatchMode, _get_current_dispatc
|
||||
from torch._custom_op.functional import register_functional_op
|
||||
import torch.utils._pytree as pytree
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch.testing._internal.common_device_type import ops
|
||||
from torch.testing._internal.common_methods_invocations import op_db
|
||||
from torch.testing._internal.custom_op_db import custom_op_db
|
||||
from torch.testing._internal.common_device_type import instantiate_device_type_tests
|
||||
from torch.multiprocessing.reductions import StorageWeakRef
|
||||
|
||||
import logging
|
||||
import sys
|
||||
@ -2107,5 +2112,73 @@ class TestPythonDispatcher(TestCase):
|
||||
python_disp_shape = torch.linalg.lstsq(a, b).solution.shape
|
||||
self.assertEqual(expected_shape, python_disp_shape)
|
||||
|
||||
class TestWrapperSubclassAliasing(TestCase):
|
||||
|
||||
def _test_wrapper_subclass_aliasing(self, op, args, kwargs):
|
||||
def to_subclass(t: torch.Tensor):
|
||||
return TwoTensor(t, t.clone())
|
||||
|
||||
result_ref = op(*args, **kwargs)
|
||||
|
||||
args_subclass = pytree.tree_map_only(torch.Tensor, to_subclass, args)
|
||||
kwargs_subclass = pytree.tree_map_only(torch.Tensor, to_subclass, kwargs)
|
||||
|
||||
result_test = op(*args_subclass, **kwargs_subclass)
|
||||
|
||||
args_ref_flat, _ = pytree.tree_flatten((args, kwargs))
|
||||
args_ref_flat_tensors = [x for x in args_ref_flat if isinstance(x, torch.Tensor)]
|
||||
|
||||
args_test_flat, _ = pytree.tree_flatten((args_subclass, kwargs_subclass))
|
||||
args_test_flat_tensors = [x for x in args_test_flat if isinstance(x, torch.Tensor)]
|
||||
|
||||
result_ref_flat, _ = pytree.tree_flatten(result_ref)
|
||||
result_ref_flat_tensors = [x for x in result_ref_flat if isinstance(x, torch.Tensor)]
|
||||
|
||||
result_test_flat, _ = pytree.tree_flatten(result_test)
|
||||
result_test_flat_tensors = [x for x in result_test_flat if isinstance(x, torch.Tensor)]
|
||||
|
||||
for o_ref, o_test in zip(result_ref_flat_tensors, result_test_flat_tensors):
|
||||
for a_ref, a_test in zip(args_ref_flat_tensors, args_test_flat_tensors):
|
||||
out_is_inpt = o_ref is a_ref
|
||||
if out_is_inpt:
|
||||
self.assertTrue(o_test is a_test)
|
||||
|
||||
out_aliases_inpt = StorageWeakRef(o_ref.untyped_storage()) == StorageWeakRef(a_ref.untyped_storage())
|
||||
if out_aliases_inpt:
|
||||
self.assertTrue(StorageWeakRef(o_test.untyped_storage()) == StorageWeakRef(a_test.untyped_storage()))
|
||||
else:
|
||||
self.assertFalse(StorageWeakRef(o_test.untyped_storage()) == StorageWeakRef(a_test.untyped_storage()))
|
||||
|
||||
# This tests the correctness of `torch.utils._python_dispatch.return_and_correct_aliasing`,
|
||||
# a util for wrapper subclasses to promise correct aliasing behavior.
|
||||
# It's probably overkill to test every OpInfo,
|
||||
# so I picked a sampling of ops with representative schemas.
|
||||
@ops([op for op in op_db if op.name in [
|
||||
'mul', # out-of-place
|
||||
'cat', # out-of-place (TensorList input)
|
||||
'index', # out-of-place (Optional TensorList input)
|
||||
'mul_', # inplace
|
||||
'view', # view
|
||||
't_', # inplace-view
|
||||
'split', # view (multi-return)
|
||||
'native_batch_norm', # mutable op (returns outputs and mutates some inputs)
|
||||
]], allowed_dtypes=(torch.float,))
|
||||
def test_wrapper_subclass_aliasing(self, device, dtype, op):
|
||||
samples = op.sample_inputs(device, dtype)
|
||||
sample = first_sample(self, samples)
|
||||
args = (sample.input, *sample.args)
|
||||
kwargs = sample.kwargs
|
||||
self._test_wrapper_subclass_aliasing(op, args, kwargs)
|
||||
|
||||
@ops(custom_op_db, allowed_dtypes=(torch.float,))
|
||||
def test_wrapper_subclass_aliasing_custom(self, device, dtype, op):
|
||||
samples = op.sample_inputs(device, dtype)
|
||||
sample = first_sample(self, samples)
|
||||
args = (sample.input, *sample.args)
|
||||
kwargs = sample.kwargs
|
||||
self._test_wrapper_subclass_aliasing(op, args, kwargs)
|
||||
|
||||
instantiate_device_type_tests(TestWrapperSubclassAliasing, globals())
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_tests()
|
||||
|
Reference in New Issue
Block a user