mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add fake impl for aten.unique2 (#124306)
Reapply of: https://github.com/pytorch/pytorch/pull/121571 Differential Revision: [D56258431](https://our.internmc.facebook.com/intern/diff/D56258431) Pull Request resolved: https://github.com/pytorch/pytorch/pull/124306 Approved by: https://github.com/gmagogsfm
This commit is contained in:
committed by
PyTorch MergeBot
parent
cc18afa25f
commit
d23bf9cef0
124
test/test_ops.py
124
test/test_ops.py
@ -2374,6 +2374,15 @@ dynamic_output_op_tests = (
|
||||
"linalg.lstsq.grad_oriented",
|
||||
)
|
||||
|
||||
# Ops that have dynamic output shapes that we can handle when
|
||||
# allow_dynamic_shape_ops is True in fake tensor shape environment.
|
||||
supported_dynamic_output_op_tests = (
|
||||
"nonzero",
|
||||
"unique",
|
||||
"repeat_interleave",
|
||||
"masked_select",
|
||||
)
|
||||
|
||||
# some inputs invoke dynamic output shape operators, some do not
|
||||
sometimes_dynamic_output_op_test = (
|
||||
"__getitem__",
|
||||
@ -2442,12 +2451,28 @@ class TestFakeTensor(TestCase):
|
||||
|
||||
samples = op.sample_inputs(device, dtype, requires_grad=False)
|
||||
for sample in samples:
|
||||
try:
|
||||
mode = FakeTensorMode()
|
||||
mode = FakeTensorMode()
|
||||
|
||||
from torch.fx.experimental.symbolic_shapes import ShapeEnv
|
||||
|
||||
allow_dynamic_output_shape_shape_env = ShapeEnv(
|
||||
allow_dynamic_output_shape_ops=True
|
||||
)
|
||||
|
||||
allow_dynamic_output_shape_mode = FakeTensorMode(
|
||||
shape_env=allow_dynamic_output_shape_shape_env
|
||||
)
|
||||
|
||||
try:
|
||||
with context():
|
||||
res = op(sample.input, *sample.args, **sample.kwargs)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
def run_with_fake_mode_and_verify(fake_mode, match_results=True):
|
||||
def map_to_fake(e):
|
||||
if isinstance(e, torch.Tensor):
|
||||
return mode.from_tensor(e)
|
||||
return fake_mode.from_tensor(e)
|
||||
else:
|
||||
return e
|
||||
|
||||
@ -2457,56 +2482,65 @@ class TestFakeTensor(TestCase):
|
||||
|
||||
try:
|
||||
with context():
|
||||
res = op(sample.input, *sample.args, **sample.kwargs)
|
||||
except Exception as e:
|
||||
continue
|
||||
with fake_mode:
|
||||
res_fake = op(input, *args, **kwargs)
|
||||
|
||||
with context():
|
||||
with mode:
|
||||
res_fake = op(input, *args, **kwargs)
|
||||
if not match_results:
|
||||
return
|
||||
|
||||
for fake_out, real_out in zip(
|
||||
pytree.tree_leaves(res_fake), pytree.tree_leaves(res)
|
||||
):
|
||||
if not isinstance(fake_out, torch.Tensor):
|
||||
self.assertTrue(not isinstance(real_out, torch.Tensor))
|
||||
self.assertEqual(fake_out, real_out)
|
||||
continue
|
||||
for fake_out, real_out in zip(
|
||||
pytree.tree_leaves(res_fake), pytree.tree_leaves(res)
|
||||
):
|
||||
if not isinstance(fake_out, torch.Tensor):
|
||||
self.assertTrue(not isinstance(real_out, torch.Tensor))
|
||||
self.assertEqual(fake_out, real_out)
|
||||
continue
|
||||
|
||||
self.assertTrue(isinstance(fake_out, FakeTensor))
|
||||
# if you see a shape exception here, you may need to add
|
||||
# a `dynamic_output_shape` tag to an operator
|
||||
self.assertTrue(isinstance(fake_out, FakeTensor))
|
||||
# if you see a shape exception here, you may need to add
|
||||
# a `dynamic_output_shape` tag to an operator
|
||||
|
||||
# prims/decomps must correctly model strides,
|
||||
# see https://github.com/pytorch/pytorch/issues/78050#issuecomment-1253950325
|
||||
prims.utils.compare_tensor_meta(fake_out, real_out, True)
|
||||
# prims/decomps must correctly model strides,
|
||||
# see https://github.com/pytorch/pytorch/issues/78050#issuecomment-1253950325
|
||||
prims.utils.compare_tensor_meta(fake_out, real_out, True)
|
||||
|
||||
if name not in aliasing_failures:
|
||||
fake_aliasing = outputs_alias_inputs(
|
||||
(input, args, kwargs), res_fake
|
||||
)
|
||||
real_aliasing = outputs_alias_inputs(
|
||||
(sample.input, sample, args, sample.kwargs), res
|
||||
)
|
||||
self.assertEqual(fake_aliasing, real_aliasing)
|
||||
if name not in aliasing_failures:
|
||||
fake_aliasing = outputs_alias_inputs(
|
||||
(input, args, kwargs), res_fake
|
||||
)
|
||||
real_aliasing = outputs_alias_inputs(
|
||||
(sample.input, sample, args, sample.kwargs), res
|
||||
)
|
||||
self.assertEqual(fake_aliasing, real_aliasing)
|
||||
|
||||
self.assertTrue(
|
||||
name not in dynamic_output_op_tests
|
||||
and name not in data_dependent_op_tests
|
||||
self.assertTrue(
|
||||
name not in dynamic_output_op_tests
|
||||
and name not in data_dependent_op_tests
|
||||
)
|
||||
|
||||
except torch._subclasses.fake_tensor.UnsupportedFakeTensorException:
|
||||
pass
|
||||
except torch._subclasses.fake_tensor.UnsupportedOperatorException:
|
||||
pass
|
||||
except torch._subclasses.fake_tensor.DynamicOutputShapeException:
|
||||
self.assertTrue(
|
||||
name in dynamic_output_op_tests
|
||||
or name in sometimes_dynamic_output_op_test
|
||||
)
|
||||
self.assertTrue(
|
||||
mode.shape_env is None
|
||||
or not mode.shape_env.allow_dynamic_output_shape_ops
|
||||
or name not in supported_dynamic_output_op_tests
|
||||
)
|
||||
except torch._subclasses.fake_tensor.DataDependentOutputException:
|
||||
self.assertTrue(name in data_dependent_op_tests)
|
||||
|
||||
run_with_fake_mode_and_verify(mode)
|
||||
if name in supported_dynamic_output_op_tests:
|
||||
run_with_fake_mode_and_verify(
|
||||
allow_dynamic_output_shape_mode, match_results=False
|
||||
)
|
||||
|
||||
except torch._subclasses.fake_tensor.UnsupportedFakeTensorException:
|
||||
pass
|
||||
except torch._subclasses.fake_tensor.UnsupportedOperatorException:
|
||||
pass
|
||||
except torch._subclasses.fake_tensor.DynamicOutputShapeException:
|
||||
self.assertTrue(
|
||||
name in dynamic_output_op_tests
|
||||
or name in sometimes_dynamic_output_op_test
|
||||
)
|
||||
except torch._subclasses.fake_tensor.DataDependentOutputException:
|
||||
self.assertTrue(name in data_dependent_op_tests)
|
||||
|
||||
@ops(op_db, dtypes=OpDTypes.any_one)
|
||||
def test_pointwise_ops(self, device, dtype, op):
|
||||
name = op.name
|
||||
|
Reference in New Issue
Block a user