Add fake impl for aten.unique2 (#124306)

Reapply of: https://github.com/pytorch/pytorch/pull/121571
Differential Revision: [D56258431](https://our.internmc.facebook.com/intern/diff/D56258431)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124306
Approved by: https://github.com/gmagogsfm
This commit is contained in:
Tugsbayasgalan Manlaibaatar
2024-04-17 11:59:20 -07:00
committed by PyTorch MergeBot
parent cc18afa25f
commit d23bf9cef0
3 changed files with 163 additions and 45 deletions

View File

@ -2374,6 +2374,15 @@ dynamic_output_op_tests = (
"linalg.lstsq.grad_oriented",
)
# Ops that have dynamic output shapes that we can handle when
# allow_dynamic_shape_ops is True in fake tensor shape environment.
supported_dynamic_output_op_tests = (
"nonzero",
"unique",
"repeat_interleave",
"masked_select",
)
# some inputs invoke dynamic output shape operators, some do not
sometimes_dynamic_output_op_test = (
"__getitem__",
@ -2442,12 +2451,28 @@ class TestFakeTensor(TestCase):
samples = op.sample_inputs(device, dtype, requires_grad=False)
for sample in samples:
try:
mode = FakeTensorMode()
mode = FakeTensorMode()
from torch.fx.experimental.symbolic_shapes import ShapeEnv
allow_dynamic_output_shape_shape_env = ShapeEnv(
allow_dynamic_output_shape_ops=True
)
allow_dynamic_output_shape_mode = FakeTensorMode(
shape_env=allow_dynamic_output_shape_shape_env
)
try:
with context():
res = op(sample.input, *sample.args, **sample.kwargs)
except Exception:
continue
def run_with_fake_mode_and_verify(fake_mode, match_results=True):
def map_to_fake(e):
if isinstance(e, torch.Tensor):
return mode.from_tensor(e)
return fake_mode.from_tensor(e)
else:
return e
@ -2457,56 +2482,65 @@ class TestFakeTensor(TestCase):
try:
with context():
res = op(sample.input, *sample.args, **sample.kwargs)
except Exception as e:
continue
with fake_mode:
res_fake = op(input, *args, **kwargs)
with context():
with mode:
res_fake = op(input, *args, **kwargs)
if not match_results:
return
for fake_out, real_out in zip(
pytree.tree_leaves(res_fake), pytree.tree_leaves(res)
):
if not isinstance(fake_out, torch.Tensor):
self.assertTrue(not isinstance(real_out, torch.Tensor))
self.assertEqual(fake_out, real_out)
continue
for fake_out, real_out in zip(
pytree.tree_leaves(res_fake), pytree.tree_leaves(res)
):
if not isinstance(fake_out, torch.Tensor):
self.assertTrue(not isinstance(real_out, torch.Tensor))
self.assertEqual(fake_out, real_out)
continue
self.assertTrue(isinstance(fake_out, FakeTensor))
# if you see a shape exception here, you may need to add
# a `dynamic_output_shape` tag to an operator
self.assertTrue(isinstance(fake_out, FakeTensor))
# if you see a shape exception here, you may need to add
# a `dynamic_output_shape` tag to an operator
# prims/decomps must correctly model strides,
# see https://github.com/pytorch/pytorch/issues/78050#issuecomment-1253950325
prims.utils.compare_tensor_meta(fake_out, real_out, True)
# prims/decomps must correctly model strides,
# see https://github.com/pytorch/pytorch/issues/78050#issuecomment-1253950325
prims.utils.compare_tensor_meta(fake_out, real_out, True)
if name not in aliasing_failures:
fake_aliasing = outputs_alias_inputs(
(input, args, kwargs), res_fake
)
real_aliasing = outputs_alias_inputs(
(sample.input, sample, args, sample.kwargs), res
)
self.assertEqual(fake_aliasing, real_aliasing)
if name not in aliasing_failures:
fake_aliasing = outputs_alias_inputs(
(input, args, kwargs), res_fake
)
real_aliasing = outputs_alias_inputs(
(sample.input, sample, args, sample.kwargs), res
)
self.assertEqual(fake_aliasing, real_aliasing)
self.assertTrue(
name not in dynamic_output_op_tests
and name not in data_dependent_op_tests
self.assertTrue(
name not in dynamic_output_op_tests
and name not in data_dependent_op_tests
)
except torch._subclasses.fake_tensor.UnsupportedFakeTensorException:
pass
except torch._subclasses.fake_tensor.UnsupportedOperatorException:
pass
except torch._subclasses.fake_tensor.DynamicOutputShapeException:
self.assertTrue(
name in dynamic_output_op_tests
or name in sometimes_dynamic_output_op_test
)
self.assertTrue(
mode.shape_env is None
or not mode.shape_env.allow_dynamic_output_shape_ops
or name not in supported_dynamic_output_op_tests
)
except torch._subclasses.fake_tensor.DataDependentOutputException:
self.assertTrue(name in data_dependent_op_tests)
run_with_fake_mode_and_verify(mode)
if name in supported_dynamic_output_op_tests:
run_with_fake_mode_and_verify(
allow_dynamic_output_shape_mode, match_results=False
)
except torch._subclasses.fake_tensor.UnsupportedFakeTensorException:
pass
except torch._subclasses.fake_tensor.UnsupportedOperatorException:
pass
except torch._subclasses.fake_tensor.DynamicOutputShapeException:
self.assertTrue(
name in dynamic_output_op_tests
or name in sometimes_dynamic_output_op_test
)
except torch._subclasses.fake_tensor.DataDependentOutputException:
self.assertTrue(name in data_dependent_op_tests)
@ops(op_db, dtypes=OpDTypes.any_one)
def test_pointwise_ops(self, device, dtype, op):
name = op.name

View File

@ -258,6 +258,62 @@ def dyn_shape(fake_mode, func, *args, **kwargs):
raise DynamicOutputShapeException(func)
@register_op_impl(aten._unique2.default)
def unique2(
fake_mode, func, arg, sorted=True, return_inverse=False, return_counts=False
):
if (
fake_mode.shape_env is None
or not fake_mode.shape_env.allow_dynamic_output_shape_ops
):
# Without symints/symfloats, cannot handle this
raise DynamicOutputShapeException(func)
if arg.unique_memo is None:
# Avoid importing sympy at a module level
from torch.fx.experimental.symbolic_shapes import (
_constrain_range_for_size,
has_free_symbols,
)
if not has_free_symbols(arg.numel()) and arg.numel() == 0:
# If numel is zero, then the output size must be zero.
# In this case, we must not allocate an unbacked SymInt,
# because if we do, it will immediately get refined to
# zero, but this will be inconsistent with size oblivious
# tests (which will continue to claim that the unbacked
# symint cannot equal zero). We could also unconditionally
# allocate an unbacked SymInt and not refine its range,
# but this seems more precise.
nnz = arg._nonzero_memo = 0
arg._nonzero_memo_vc = arg._version
else:
nnz = fake_mode.shape_env.create_unbacked_symint()
maxval = sys.maxsize - 1
if not has_free_symbols(arg.numel()):
maxval = int(arg.numel())
_constrain_range_for_size(nnz, max=maxval)
arg.unique_memo = nnz
ret = [arg.new_empty((arg.unique_memo,))]
if return_inverse:
ret.append(torch.empty_like(arg))
else:
ret.append(arg.new_empty(0))
if return_counts:
ret.append(torch.empty_like(arg))
else:
ret.append(arg.new_empty(0))
return tuple(ret)
@register_op_impl(aten.repeat_interleave.Tensor)
def repeat_interleave_tensor(fake_mode, func, repeats, output_size=None):
if output_size is None:

View File

@ -397,6 +397,31 @@ class FakeTensor(torch.Tensor):
return None
return self._nonzero_memo
# This memorizes the unbacked SymInt representing the number of unique
# elements in this tensor. This is helpful if you do something like
# calling torch.unique(x) multiple times and should
# give a consistent unbacked SymInt. It needs to be invalidated in the
# same way constant is.
# TODO: Generalize this as needed, e.g., into a trie of memos
_unique_memo: Optional[torch.SymInt]
_unique_memo_vc: Optional[int]
@property
def unique_memo(self):
if self._unique_memo is None:
return None
# Version counter based tracking isn't 100% sound but it's close
# enough
if self._unique_memo_vc != self._version:
self._unique_memo = None
return None
return self._unique_memo
@unique_memo.setter
def unique_memo(self, value):
self._unique_memo = value
self._unique_memo_vc = self._version
@property
def device(self):
if self.fake_mode.in_kernel_invocation:
@ -471,6 +496,9 @@ class FakeTensor(torch.Tensor):
self.constant = constant # type: ignore[attr-defined]
self._nonzero_memo = None # type: ignore[attr-defined]
self._nonzero_memo_vc = None # type: ignore[attr-defined]
self._unique_memo = None # type: ignore[attr-defined]
self._unique_memo_vc = None # type: ignore[attr-defined]
if FakeTensorConfig.debug:
self._debug_trace = CapturedTraceback.extract() # type: ignore[attr-defined]
return self