Add fake impl for aten.unique2 (#124306)

Reapply of: https://github.com/pytorch/pytorch/pull/121571
Differential Revision: [D56258431](https://our.internmc.facebook.com/intern/diff/D56258431)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124306
Approved by: https://github.com/gmagogsfm
This commit is contained in:
Tugsbayasgalan Manlaibaatar
2024-04-17 11:59:20 -07:00
committed by PyTorch MergeBot
parent cc18afa25f
commit d23bf9cef0
3 changed files with 163 additions and 45 deletions

View File

@ -2374,6 +2374,15 @@ dynamic_output_op_tests = (
"linalg.lstsq.grad_oriented",
)
# Ops that have dynamic output shapes that we can handle when
# allow_dynamic_shape_ops is True in fake tensor shape environment.
supported_dynamic_output_op_tests = (
"nonzero",
"unique",
"repeat_interleave",
"masked_select",
)
# some inputs invoke dynamic output shape operators, some do not
sometimes_dynamic_output_op_test = (
"__getitem__",
@ -2442,12 +2451,28 @@ class TestFakeTensor(TestCase):
samples = op.sample_inputs(device, dtype, requires_grad=False)
for sample in samples:
try:
mode = FakeTensorMode()
mode = FakeTensorMode()
from torch.fx.experimental.symbolic_shapes import ShapeEnv
allow_dynamic_output_shape_shape_env = ShapeEnv(
allow_dynamic_output_shape_ops=True
)
allow_dynamic_output_shape_mode = FakeTensorMode(
shape_env=allow_dynamic_output_shape_shape_env
)
try:
with context():
res = op(sample.input, *sample.args, **sample.kwargs)
except Exception:
continue
def run_with_fake_mode_and_verify(fake_mode, match_results=True):
def map_to_fake(e):
if isinstance(e, torch.Tensor):
return mode.from_tensor(e)
return fake_mode.from_tensor(e)
else:
return e
@ -2457,56 +2482,65 @@ class TestFakeTensor(TestCase):
try:
with context():
res = op(sample.input, *sample.args, **sample.kwargs)
except Exception as e:
continue
with fake_mode:
res_fake = op(input, *args, **kwargs)
with context():
with mode:
res_fake = op(input, *args, **kwargs)
if not match_results:
return
for fake_out, real_out in zip(
pytree.tree_leaves(res_fake), pytree.tree_leaves(res)
):
if not isinstance(fake_out, torch.Tensor):
self.assertTrue(not isinstance(real_out, torch.Tensor))
self.assertEqual(fake_out, real_out)
continue
for fake_out, real_out in zip(
pytree.tree_leaves(res_fake), pytree.tree_leaves(res)
):
if not isinstance(fake_out, torch.Tensor):
self.assertTrue(not isinstance(real_out, torch.Tensor))
self.assertEqual(fake_out, real_out)
continue
self.assertTrue(isinstance(fake_out, FakeTensor))
# if you see a shape exception here, you may need to add
# a `dynamic_output_shape` tag to an operator
self.assertTrue(isinstance(fake_out, FakeTensor))
# if you see a shape exception here, you may need to add
# a `dynamic_output_shape` tag to an operator
# prims/decomps must correctly model strides,
# see https://github.com/pytorch/pytorch/issues/78050#issuecomment-1253950325
prims.utils.compare_tensor_meta(fake_out, real_out, True)
# prims/decomps must correctly model strides,
# see https://github.com/pytorch/pytorch/issues/78050#issuecomment-1253950325
prims.utils.compare_tensor_meta(fake_out, real_out, True)
if name not in aliasing_failures:
fake_aliasing = outputs_alias_inputs(
(input, args, kwargs), res_fake
)
real_aliasing = outputs_alias_inputs(
(sample.input, sample, args, sample.kwargs), res
)
self.assertEqual(fake_aliasing, real_aliasing)
if name not in aliasing_failures:
fake_aliasing = outputs_alias_inputs(
(input, args, kwargs), res_fake
)
real_aliasing = outputs_alias_inputs(
(sample.input, sample, args, sample.kwargs), res
)
self.assertEqual(fake_aliasing, real_aliasing)
self.assertTrue(
name not in dynamic_output_op_tests
and name not in data_dependent_op_tests
self.assertTrue(
name not in dynamic_output_op_tests
and name not in data_dependent_op_tests
)
except torch._subclasses.fake_tensor.UnsupportedFakeTensorException:
pass
except torch._subclasses.fake_tensor.UnsupportedOperatorException:
pass
except torch._subclasses.fake_tensor.DynamicOutputShapeException:
self.assertTrue(
name in dynamic_output_op_tests
or name in sometimes_dynamic_output_op_test
)
self.assertTrue(
mode.shape_env is None
or not mode.shape_env.allow_dynamic_output_shape_ops
or name not in supported_dynamic_output_op_tests
)
except torch._subclasses.fake_tensor.DataDependentOutputException:
self.assertTrue(name in data_dependent_op_tests)
run_with_fake_mode_and_verify(mode)
if name in supported_dynamic_output_op_tests:
run_with_fake_mode_and_verify(
allow_dynamic_output_shape_mode, match_results=False
)
except torch._subclasses.fake_tensor.UnsupportedFakeTensorException:
pass
except torch._subclasses.fake_tensor.UnsupportedOperatorException:
pass
except torch._subclasses.fake_tensor.DynamicOutputShapeException:
self.assertTrue(
name in dynamic_output_op_tests
or name in sometimes_dynamic_output_op_test
)
except torch._subclasses.fake_tensor.DataDependentOutputException:
self.assertTrue(name in data_dependent_op_tests)
@ops(op_db, dtypes=OpDTypes.any_one)
def test_pointwise_ops(self, device, dtype, op):
name = op.name