Add AotAutogradFallbackTests to dynamic suite (#100454)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100454
Approved by: https://github.com/ezyang
This commit is contained in:
Michael Voznesensky
2023-05-04 01:22:17 +00:00
committed by PyTorch MergeBot
parent 2dca418112
commit fe3ecfe0cf
5 changed files with 90 additions and 29 deletions

View File

@ -6,9 +6,26 @@ import torch
import torch._dynamo
import torch._dynamo.test_case
from torch._dynamo.testing import CompileCounter, rand_strided
from torch._dynamo.utils import ifdyn, ifdynstaticdefault
from torch.testing._internal.common_utils import compare_equal_outs_and_grads
def maybe_dupe_op(x):
y = x + 1
z = x + 2
if x.numel() < 5:
return y, y
else:
return y, z
aten = torch.ops.aten
lib = torch.library.Library("custom", "DEF")
lib.define("maybe_dupe_op(Tensor a) -> (Tensor, Tensor)")
lib.impl("maybe_dupe_op", maybe_dupe_op, "CPU")
lib.impl("maybe_dupe_op", maybe_dupe_op, "Meta")
class AotAutogradFallbackTests(torch._dynamo.test_case.TestCase):
def test_LSTM(self):
# https://github.com/pytorch/torchdynamo/issues/1147
@ -385,12 +402,13 @@ class AotAutogradFallbackTests(torch._dynamo.test_case.TestCase):
fxy(x1, y1)
fxy(x2, y2)
self.assertTrue(failure_reason is None)
if not torch._dynamo.config.dynamic_shapes:
self.assertTrue(failure_reason is None)
# Reset failure reason
failure_reason = None
self.assertEqual(cc.frame_count, 1)
self.assertEqual(cc.frame_count, ifdyn(ifdynstaticdefault(1, 2), 1))
torch._dynamo.reset() # for new backend
cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
@ -424,10 +442,19 @@ class AotAutogradFallbackTests(torch._dynamo.test_case.TestCase):
f(a)
f(a)
self.assertEqual(cc.frame_count, 2)
self.assertExpectedInline(
failure_reason,
"""tensor 'L['a']' stride mismatch at index 0. expected 3, actual 1""",
)
if (
torch._dynamo.config.dynamic_shapes
and not torch._dynamo.config.assume_static_by_default
):
self.assertExpectedInline(
failure_reason,
"""tensor 'L['a']' stride mismatch at index 1. expected 1, actual 3""",
)
else:
self.assertExpectedInline(
failure_reason,
"""tensor 'L['a']' stride mismatch at index 0. expected 3, actual 1""",
)
torch._dynamo.reset()
@ -665,21 +692,8 @@ class AotAutogradFallbackTests(torch._dynamo.test_case.TestCase):
self.assertExpectedInline(failure_reason, """L['c'] is L['d']""")
@patch("torch._functorch.config.debug_assert", True)
@patch("torch._dynamo.config.dynamic_shapes", False)
def test_multiple_aot_autograd_calls_dupe_args(self):
def maybe_dupe_op(x):
y = x + 1
z = x + 2
if x.numel() < 5:
return y, y
else:
return y, z
aten = torch.ops.aten
lib = torch.library.Library("custom", "DEF")
lib.define("maybe_dupe_op(Tensor a) -> (Tensor, Tensor)")
lib.impl("maybe_dupe_op", maybe_dupe_op, "CPU")
lib.impl("maybe_dupe_op", maybe_dupe_op, "Meta")
# this is just dealing with the fact that
# aot_module_simplified expects submods to always return tuples/lists
class WrapperModule(torch.nn.Module):

View File

@ -4,6 +4,7 @@ from torch._dynamo.testing import make_test_cls_with_patches
try:
from . import (
test_aot_autograd,
test_ctx_manager,
test_export,
test_functions,
@ -14,6 +15,7 @@ try:
test_subgraphs,
)
except ImportError:
import test_aot_autograd
import test_ctx_manager
import test_export
import test_functions
@ -82,6 +84,7 @@ tests = [
test_export.ExportTests,
test_subgraphs.SubGraphTests,
test_higher_order_ops.HigherOrderOpTests,
test_aot_autograd.AotAutogradFallbackTests,
]
for test in tests:
make_dynamic_cls(test)

View File

@ -1577,9 +1577,6 @@ inplace_symbolic_tensor_failures = {
xfail('unique', ''),
# in-place has a different signature than out-of-place
xfail('uniform', ''),
# Views
xfail('t', ''),
xfail('transpose', ''),
}
# Copies inputs to inplace operations to avoid inplace modifications

View File

@ -1954,13 +1954,15 @@ def aot_wrapper_dedupe(
duped_arg_len = len(flat_args)
j = 0 # index into deduped_flat_args
for i, t in enumerate(flat_args):
if t in seen_args:
keep_arg_mask.append(False)
add_dupe_map.append(seen_args[t])
continue
for t in flat_args:
if isinstance(t, torch.Tensor):
if t in seen_args:
keep_arg_mask.append(False)
add_dupe_map.append(seen_args[t])
continue
seen_args[t] = j
keep_arg_mask.append(True)
seen_args[t] = j
add_dupe_map.append(j)
j += 1
assert len(add_dupe_map) == duped_arg_len, (

View File

@ -3103,6 +3103,51 @@ def nan_to_num(self, nan=None, posinf=None, neginf=None):
return self.new_empty(result_size)
@register_meta(torch.ops.aten.transpose_)
def transpose_(self, dim0, dim1):
assert self.layout not in {
torch.sparse_csr,
torch.sparse_csc,
torch.sparse_bsr,
torch.sparse_bsc,
}, f"torch.transpose_: in-place transposition is not supported for {self.layout} layout"
ndims = self.ndim
dim0 = maybe_wrap_dim(dim0, ndims)
dim1 = maybe_wrap_dim(dim1, ndims)
if dim0 == dim1:
return self
size = list(self.size())
stride = list(self.stride())
stride[dim0], stride[dim1] = stride[dim1], stride[dim0]
size[dim0], size[dim1] = size[dim1], size[dim0]
self.as_strided_(size, stride)
return self
@register_meta(torch.ops.aten.t_)
def t_(self):
ndims = self.ndim
if self.is_sparse:
sparse_dim = self.sparse_dim()
dense_dim = self.dense_dim()
assert (
sparse_dim <= 2 and dense_dim == 0
), f"t_ expects a tensor with <= 2 sparse and 0 dense dimensions, but got {sparse_dim} sparse and {dense_dim} dense dimensions" # noqa: B950
else:
assert (
self.dim() <= 2
), f"t_ expects a tensor with <= 2 dimensions, but self is {ndims}D"
return transpose_(self, 0, 0 if ndims < 2 else 1)
# We must also trigger meta registrations from PrimTorch ref
# decompositions
import torch._refs