Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Add fake impl for aten.unique2 (#124306)
Reapply of: https://github.com/pytorch/pytorch/pull/121571
Differential Revision: [D56258431](https://our.internmc.facebook.com/intern/diff/D56258431)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124306
Approved by: https://github.com/gmagogsfm
Committed by: PyTorch MergeBot
Parent: cc18afa25f
Commit: d23bf9cef0
test/test_ops.py | 124
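For orientation, a minimal sketch (not part of this commit) of what the new fake impl enables, using only APIs that appear in the diff below (FakeTensorMode, ShapeEnv, aten._unique2); the example tensor and printed values are illustrative:

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv

# A ShapeEnv that allows dynamic-output-shape ops is required; without it the
# fake impl raises DynamicOutputShapeException (see unique2 below).
mode = FakeTensorMode(shape_env=ShapeEnv(allow_dynamic_output_shape_ops=True))

with mode:
    x = mode.from_tensor(torch.tensor([1, 2, 2, 3]))
    # torch.unique lowers to aten._unique2; under FakeTensorMode the number of
    # unique elements is modeled as an unbacked SymInt instead of being computed.
    values, inverse, counts = torch.ops.aten._unique2(
        x, sorted=True, return_inverse=True, return_counts=True
    )
    print(values.shape)  # e.g. torch.Size([u0]), a symbolic data-dependent size
```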
@@ -2374,6 +2374,15 @@ dynamic_output_op_tests = (
     "linalg.lstsq.grad_oriented",
 )
 
+# Ops that have dynamic output shapes that we can handle when
+# allow_dynamic_shape_ops is True in fake tensor shape environment.
+supported_dynamic_output_op_tests = (
+    "nonzero",
+    "unique",
+    "repeat_interleave",
+    "masked_select",
+)
+
 # some inputs invoke dynamic output shape operators, some do not
 sometimes_dynamic_output_op_test = (
     "__getitem__",
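The distinction this new list encodes — the default FakeTensorMode still rejects these ops, while a mode whose ShapeEnv sets allow_dynamic_output_shape_ops=True can fake them — can be seen in a rough sketch (illustrative, not from the commit):

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv

x = torch.randint(0, 5, (16,))

plain_mode = FakeTensorMode()
try:
    with plain_mode:
        plain_mode.from_tensor(x).unique()
except torch._subclasses.fake_tensor.DynamicOutputShapeException:
    # Without a ShapeEnv that allows dynamic output shapes, the fake impl
    # raises, which is what the existing dynamic_output_op_tests expect.
    print("plain FakeTensorMode cannot fake torch.unique")

dyn_mode = FakeTensorMode(shape_env=ShapeEnv(allow_dynamic_output_shape_ops=True))
with dyn_mode:
    out = dyn_mode.from_tensor(x).unique()
    # The result size is an unbacked SymInt, so the op now works under fake
    # tensors; such ops go in supported_dynamic_output_op_tests.
    print(out.shape)
```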
@@ -2442,12 +2451,28 @@ class TestFakeTensor(TestCase):
 
         samples = op.sample_inputs(device, dtype, requires_grad=False)
         for sample in samples:
-            try:
-                mode = FakeTensorMode()
+            mode = FakeTensorMode()
+
+            from torch.fx.experimental.symbolic_shapes import ShapeEnv
+
+            allow_dynamic_output_shape_shape_env = ShapeEnv(
+                allow_dynamic_output_shape_ops=True
+            )
+
+            allow_dynamic_output_shape_mode = FakeTensorMode(
+                shape_env=allow_dynamic_output_shape_shape_env
+            )
+
+            try:
+                with context():
+                    res = op(sample.input, *sample.args, **sample.kwargs)
+            except Exception:
+                continue
 
+            def run_with_fake_mode_and_verify(fake_mode, match_results=True):
                 def map_to_fake(e):
                     if isinstance(e, torch.Tensor):
-                        return mode.from_tensor(e)
+                        return fake_mode.from_tensor(e)
                     else:
                         return e
@@ -2457,56 +2482,65 @@ class TestFakeTensor(TestCase):
-                try:
-                    with context():
-                        res = op(sample.input, *sample.args, **sample.kwargs)
-                except Exception as e:
-                    continue
-
-                with context():
-                    with mode:
-                        res_fake = op(input, *args, **kwargs)
+                    with fake_mode:
+                        res_fake = op(input, *args, **kwargs)
+
+                    if not match_results:
+                        return
 
-                for fake_out, real_out in zip(
-                    pytree.tree_leaves(res_fake), pytree.tree_leaves(res)
-                ):
-                    if not isinstance(fake_out, torch.Tensor):
-                        self.assertTrue(not isinstance(real_out, torch.Tensor))
-                        self.assertEqual(fake_out, real_out)
-                        continue
+                    for fake_out, real_out in zip(
+                        pytree.tree_leaves(res_fake), pytree.tree_leaves(res)
+                    ):
+                        if not isinstance(fake_out, torch.Tensor):
+                            self.assertTrue(not isinstance(real_out, torch.Tensor))
+                            self.assertEqual(fake_out, real_out)
+                            continue
 
-                    self.assertTrue(isinstance(fake_out, FakeTensor))
-                    # if you see a shape exception here, you may need to add
-                    # a `dynamic_output_shape` tag to an operator
+                        self.assertTrue(isinstance(fake_out, FakeTensor))
+                        # if you see a shape exception here, you may need to add
+                        # a `dynamic_output_shape` tag to an operator
 
-                    # prims/decomps must correctly model strides,
-                    # see https://github.com/pytorch/pytorch/issues/78050#issuecomment-1253950325
-                    prims.utils.compare_tensor_meta(fake_out, real_out, True)
+                        # prims/decomps must correctly model strides,
+                        # see https://github.com/pytorch/pytorch/issues/78050#issuecomment-1253950325
+                        prims.utils.compare_tensor_meta(fake_out, real_out, True)
 
-                if name not in aliasing_failures:
-                    fake_aliasing = outputs_alias_inputs(
-                        (input, args, kwargs), res_fake
-                    )
-                    real_aliasing = outputs_alias_inputs(
-                        (sample.input, sample, args, sample.kwargs), res
-                    )
-                    self.assertEqual(fake_aliasing, real_aliasing)
+                    if name not in aliasing_failures:
+                        fake_aliasing = outputs_alias_inputs(
+                            (input, args, kwargs), res_fake
+                        )
+                        real_aliasing = outputs_alias_inputs(
+                            (sample.input, sample, args, sample.kwargs), res
+                        )
+                        self.assertEqual(fake_aliasing, real_aliasing)
 
-                self.assertTrue(
-                    name not in dynamic_output_op_tests
-                    and name not in data_dependent_op_tests
-                )
+                    self.assertTrue(
+                        name not in dynamic_output_op_tests
+                        and name not in data_dependent_op_tests
+                    )
 
-            except torch._subclasses.fake_tensor.UnsupportedFakeTensorException:
-                pass
-            except torch._subclasses.fake_tensor.UnsupportedOperatorException:
-                pass
-            except torch._subclasses.fake_tensor.DynamicOutputShapeException:
-                self.assertTrue(
-                    name in dynamic_output_op_tests
-                    or name in sometimes_dynamic_output_op_test
-                )
-            except torch._subclasses.fake_tensor.DataDependentOutputException:
-                self.assertTrue(name in data_dependent_op_tests)
+                except torch._subclasses.fake_tensor.UnsupportedFakeTensorException:
+                    pass
+                except torch._subclasses.fake_tensor.UnsupportedOperatorException:
+                    pass
+                except torch._subclasses.fake_tensor.DynamicOutputShapeException:
+                    self.assertTrue(
+                        name in dynamic_output_op_tests
+                        or name in sometimes_dynamic_output_op_test
+                    )
+                    self.assertTrue(
+                        mode.shape_env is None
+                        or not mode.shape_env.allow_dynamic_output_shape_ops
+                        or name not in supported_dynamic_output_op_tests
+                    )
+                except torch._subclasses.fake_tensor.DataDependentOutputException:
+                    self.assertTrue(name in data_dependent_op_tests)
+
+            run_with_fake_mode_and_verify(mode)
+            if name in supported_dynamic_output_op_tests:
+                run_with_fake_mode_and_verify(
+                    allow_dynamic_output_shape_mode, match_results=False
+                )
 
         @ops(op_db, dtypes=OpDTypes.any_one)
         def test_pointwise_ops(self, device, dtype, op):
             name = op.name
@@ -258,6 +258,62 @@ def dyn_shape(fake_mode, func, *args, **kwargs):
     raise DynamicOutputShapeException(func)
 
 
+@register_op_impl(aten._unique2.default)
+def unique2(
+    fake_mode, func, arg, sorted=True, return_inverse=False, return_counts=False
+):
+    if (
+        fake_mode.shape_env is None
+        or not fake_mode.shape_env.allow_dynamic_output_shape_ops
+    ):
+        # Without symints/symfloats, cannot handle this
+        raise DynamicOutputShapeException(func)
+
+    if arg.unique_memo is None:
+        # Avoid importing sympy at a module level
+        from torch.fx.experimental.symbolic_shapes import (
+            _constrain_range_for_size,
+            has_free_symbols,
+        )
+
+        if not has_free_symbols(arg.numel()) and arg.numel() == 0:
+            # If numel is zero, then the output size must be zero.
+            # In this case, we must not allocate an unbacked SymInt,
+            # because if we do, it will immediately get refined to
+            # zero, but this will be inconsistent with size oblivious
+            # tests (which will continue to claim that the unbacked
+            # symint cannot equal zero). We could also unconditionally
+            # allocate an unbacked SymInt and not refine its range,
+            # but this seems more precise.
+            nnz = arg._nonzero_memo = 0
+            arg._nonzero_memo_vc = arg._version
+        else:
+            nnz = fake_mode.shape_env.create_unbacked_symint()
+
+            maxval = sys.maxsize - 1
+
+            if not has_free_symbols(arg.numel()):
+                maxval = int(arg.numel())
+
+            _constrain_range_for_size(nnz, max=maxval)
+
+        arg.unique_memo = nnz
+
+    ret = [arg.new_empty((arg.unique_memo,))]
+
+    if return_inverse:
+        ret.append(torch.empty_like(arg))
+    else:
+        ret.append(arg.new_empty(0))
+
+    if return_counts:
+        ret.append(torch.empty_like(arg))
+    else:
+        ret.append(arg.new_empty(0))
+
+    return tuple(ret)
+
+
 @register_op_impl(aten.repeat_interleave.Tensor)
 def repeat_interleave_tensor(fake_mode, func, repeats, output_size=None):
     if output_size is None:
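To illustrate the memoization above: under a dynamic-output-shape-enabled fake mode, repeated unique() calls on the same tensor reuse the cached unbacked SymInt rather than allocating a new one, and when the element count is statically known the symbol is bounded by numel(). A rough sketch (illustrative, not from the commit):

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv

mode = FakeTensorMode(shape_env=ShapeEnv(allow_dynamic_output_shape_ops=True))

with mode:
    x = mode.from_tensor(torch.arange(8))
    a = torch.unique(x)  # dispatches to aten._unique2, allocates an unbacked SymInt
    b = torch.unique(x)  # hits x.unique_memo, so it reports the same SymInt
    # Both shapes print the same symbolic size (e.g. torch.Size([u0])), and
    # _constrain_range_for_size bounds it by x.numel() == 8 here.
    print(a.shape, b.shape)
```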
@@ -397,6 +397,31 @@ class FakeTensor(torch.Tensor):
             return None
         return self._nonzero_memo
 
+    # This memorizes the unbacked SymInt representing the number of unique
+    # elements in this tensor. This is helpful if you do something like
+    # calling torch.unique(x) multiple times and should
+    # give a consistent unbacked SymInt. It needs to be invalidated in the
+    # same way constant is.
+    # TODO: Generalize this as needed, e.g., into a trie of memos
+    _unique_memo: Optional[torch.SymInt]
+    _unique_memo_vc: Optional[int]
+
+    @property
+    def unique_memo(self):
+        if self._unique_memo is None:
+            return None
+        # Version counter based tracking isn't 100% sound but it's close
+        # enough
+        if self._unique_memo_vc != self._version:
+            self._unique_memo = None
+            return None
+        return self._unique_memo
+
+    @unique_memo.setter
+    def unique_memo(self, value):
+        self._unique_memo = value
+        self._unique_memo_vc = self._version
+
     @property
     def device(self):
         if self.fake_mode.in_kernel_invocation:
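The version-counter check above means an in-place mutation drops the memo, so the next unique() call gets a fresh unbacked SymInt. A small sketch of that invalidation (illustrative, not from the commit):

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv

mode = FakeTensorMode(shape_env=ShapeEnv(allow_dynamic_output_shape_ops=True))

with mode:
    x = mode.from_tensor(torch.arange(8))
    a = torch.unique(x)   # caches an unbacked SymInt in x.unique_memo
    x.add_(1)             # in-place op bumps x._version
    b = torch.unique(x)   # unique_memo now returns None, so a new SymInt is allocated
    # The two results report different symbolic sizes (e.g. u0 vs u1).
    print(a.shape, b.shape)
```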
@@ -471,6 +496,9 @@ class FakeTensor(torch.Tensor):
         self.constant = constant  # type: ignore[attr-defined]
         self._nonzero_memo = None  # type: ignore[attr-defined]
         self._nonzero_memo_vc = None  # type: ignore[attr-defined]
+        self._unique_memo = None  # type: ignore[attr-defined]
+        self._unique_memo_vc = None  # type: ignore[attr-defined]
 
         if FakeTensorConfig.debug:
             self._debug_trace = CapturedTraceback.extract()  # type: ignore[attr-defined]
         return self