Add AotAutogradFallbackTests to dynamic suite (#100454)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100454
Approved by: https://github.com/ezyang
Committed by: PyTorch MergeBot
Parent: 2dca418112
Commit: fe3ecfe0cf
@@ -6,9 +6,26 @@ import torch
import torch._dynamo
import torch._dynamo.test_case
from torch._dynamo.testing import CompileCounter, rand_strided
from torch._dynamo.utils import ifdyn, ifdynstaticdefault
from torch.testing._internal.common_utils import compare_equal_outs_and_grads


def maybe_dupe_op(x):
    y = x + 1
    z = x + 2
    if x.numel() < 5:
        return y, y
    else:
        return y, z


aten = torch.ops.aten
lib = torch.library.Library("custom", "DEF")
lib.define("maybe_dupe_op(Tensor a) -> (Tensor, Tensor)")
lib.impl("maybe_dupe_op", maybe_dupe_op, "CPU")
lib.impl("maybe_dupe_op", maybe_dupe_op, "Meta")


class AotAutogradFallbackTests(torch._dynamo.test_case.TestCase):
    def test_LSTM(self):
        # https://github.com/pytorch/torchdynamo/issues/1147
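For orientation, a hedged usage sketch (not part of the diff): once the module-level library registration above has run, the op is callable through torch.ops.custom. The concrete input values below are illustrative assumptions.

import torch

# Assumes the lib.define / lib.impl calls from the hunk above have executed.
a = torch.ones(3)                        # numel() < 5, so the impl returns (y, y)
y, z = torch.ops.custom.maybe_dupe_op(a)
assert torch.equal(y, z)                 # both outputs equal a + 1

# The "Meta" registration lets the same op run on meta tensors, e.g. while tracing.
m1, m2 = torch.ops.custom.maybe_dupe_op(torch.empty(3, device="meta"))
assert m1.shape == (3,) and m1.device.type == "meta"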
@@ -385,12 +402,13 @@ class AotAutogradFallbackTests(torch._dynamo.test_case.TestCase):
        fxy(x1, y1)
        fxy(x2, y2)

        self.assertTrue(failure_reason is None)
        if not torch._dynamo.config.dynamic_shapes:
            self.assertTrue(failure_reason is None)

        # Reset failure reason
        failure_reason = None

        self.assertEqual(cc.frame_count, 1)
        self.assertEqual(cc.frame_count, ifdyn(ifdynstaticdefault(1, 2), 1))

        torch._dynamo.reset()  # for new backend
        cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
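The expected frame count in the hunk above is chosen by two helpers imported earlier from torch._dynamo.utils. Their definitions are not part of this diff, so the following is only a sketch of the assumed behavior (the config flag names mirror the ones used elsewhere in the diff):

import torch._dynamo

# Assumed behavior only; the real definitions live in torch/_dynamo/utils.py.
def ifdyn(value_if_dynamic, value_if_static):
    return value_if_dynamic if torch._dynamo.config.dynamic_shapes else value_if_static

def ifdynstaticdefault(value_if_static_by_default, value_if_fully_dynamic):
    if torch._dynamo.config.assume_static_by_default:
        return value_if_static_by_default
    return value_if_fully_dynamic

Under that reading, ifdyn(ifdynstaticdefault(1, 2), 1) expects one compiled frame with static shapes, one frame under dynamic shapes with assume_static_by_default, and two frames when shapes are fully dynamic.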
@@ -424,10 +442,19 @@ class AotAutogradFallbackTests(torch._dynamo.test_case.TestCase):
        f(a)
        f(a)
        self.assertEqual(cc.frame_count, 2)
        self.assertExpectedInline(
            failure_reason,
            """tensor 'L['a']' stride mismatch at index 0. expected 3, actual 1""",
        )
        if (
            torch._dynamo.config.dynamic_shapes
            and not torch._dynamo.config.assume_static_by_default
        ):
            self.assertExpectedInline(
                failure_reason,
                """tensor 'L['a']' stride mismatch at index 1. expected 1, actual 3""",
            )
        else:
            self.assertExpectedInline(
                failure_reason,
                """tensor 'L['a']' stride mismatch at index 0. expected 3, actual 1""",
            )

        torch._dynamo.reset()
@@ -665,21 +692,8 @@ class AotAutogradFallbackTests(torch._dynamo.test_case.TestCase):
        self.assertExpectedInline(failure_reason, """L['c'] is L['d']""")

    @patch("torch._functorch.config.debug_assert", True)
    @patch("torch._dynamo.config.dynamic_shapes", False)
    def test_multiple_aot_autograd_calls_dupe_args(self):
        def maybe_dupe_op(x):
            y = x + 1
            z = x + 2
            if x.numel() < 5:
                return y, y
            else:
                return y, z

        aten = torch.ops.aten
        lib = torch.library.Library("custom", "DEF")
        lib.define("maybe_dupe_op(Tensor a) -> (Tensor, Tensor)")
        lib.impl("maybe_dupe_op", maybe_dupe_op, "CPU")
        lib.impl("maybe_dupe_op", maybe_dupe_op, "Meta")

        # this is just dealing with the fact that
        # aot_module_simplified expects submods to always return tuples/lists
        class WrapperModule(torch.nn.Module):
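The WrapperModule referenced by the comment above is only partially visible in this hunk. A minimal sketch of the idea it describes, where the wrapped callable (`inner`) is an assumption, not the test's actual module:

import torch

class WrapperModuleSketch(torch.nn.Module):
    # aot_module_simplified expects submodules to return tuples/lists,
    # so a plain tensor result is wrapped into a one-element tuple.
    def __init__(self, inner):
        super().__init__()
        self.inner = inner

    def forward(self, *args):
        out = self.inner(*args)
        return out if isinstance(out, (tuple, list)) else (out,)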
@@ -4,6 +4,7 @@ from torch._dynamo.testing import make_test_cls_with_patches

try:
    from . import (
        test_aot_autograd,
        test_ctx_manager,
        test_export,
        test_functions,
@@ -14,6 +15,7 @@ try:
        test_subgraphs,
    )
except ImportError:
    import test_aot_autograd
    import test_ctx_manager
    import test_export
    import test_functions
@@ -82,6 +84,7 @@ tests = [
    test_export.ExportTests,
    test_subgraphs.SubGraphTests,
    test_higher_order_ops.HigherOrderOpTests,
    test_aot_autograd.AotAutogradFallbackTests,
]
for test in tests:
    make_dynamic_cls(test)
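The registration above is the whole change on the dynamic-shapes side: the class is appended to `tests` and `make_dynamic_cls` generates a patched copy of it. The helper's exact signature is not shown in this diff, so the sketch below only illustrates the underlying idea of re-running a test class under a patched Dynamo config (the flag name mirrors the @patch usage elsewhere in the diff; the generated class name is assumed):

from unittest.mock import patch

def make_dynamic_cls_sketch(cls):
    # Build a "DynamicShapes<Name>" subclass whose test methods run with
    # torch._dynamo.config.dynamic_shapes patched to True.
    attrs = {
        name: patch("torch._dynamo.config.dynamic_shapes", True)(getattr(cls, name))
        for name in dir(cls)
        if name.startswith("test_")
    }
    return type("DynamicShapes" + cls.__name__, (cls,), attrs)

# e.g. DynamicShapesAotAutogradFallbackTests = make_dynamic_cls_sketch(AotAutogradFallbackTests)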
@@ -1577,9 +1577,6 @@ inplace_symbolic_tensor_failures = {
    xfail('unique', ''),
    # in-place has a different signature than out-of-place
    xfail('uniform', ''),
    # Views
    xfail('t', ''),
    xfail('transpose', ''),
}

# Copies inputs to inplace operations to avoid inplace modifications
@@ -1954,13 +1954,15 @@ def aot_wrapper_dedupe(
    duped_arg_len = len(flat_args)

    j = 0  # index into deduped_flat_args
    for i, t in enumerate(flat_args):
        if t in seen_args:
            keep_arg_mask.append(False)
            add_dupe_map.append(seen_args[t])
            continue
    for t in flat_args:
        if isinstance(t, torch.Tensor):
            if t in seen_args:
                keep_arg_mask.append(False)
                add_dupe_map.append(seen_args[t])
                continue
            seen_args[t] = j

        keep_arg_mask.append(True)
        seen_args[t] = j
        add_dupe_map.append(j)
        j += 1
    assert len(add_dupe_map) == duped_arg_len, (
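To make the bookkeeping above concrete, a small hedged walk-through of what keep_arg_mask and add_dupe_map end up holding for a duplicated argument list (toy values; the surrounding aot_wrapper_dedupe machinery is omitted, and only tensors are keyed into seen_args, presumably so that non-tensor arguments seen under dynamic shapes are left alone):

import torch

a = torch.randn(2)
b = torch.randn(2)
flat_args = [a, b, a]          # the third argument duplicates the first

seen_args = {}
keep_arg_mask, add_dupe_map = [], []
j = 0
for t in flat_args:            # same loop as in the hunk above
    if isinstance(t, torch.Tensor):
        if t in seen_args:
            keep_arg_mask.append(False)
            add_dupe_map.append(seen_args[t])
            continue
        seen_args[t] = j
    keep_arg_mask.append(True)
    add_dupe_map.append(j)
    j += 1

assert keep_arg_mask == [True, True, False]   # the duplicate is dropped
assert add_dupe_map == [0, 1, 0]              # and remapped to the first slot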
@@ -3103,6 +3103,51 @@ def nan_to_num(self, nan=None, posinf=None, neginf=None):
    return self.new_empty(result_size)


@register_meta(torch.ops.aten.transpose_)
def transpose_(self, dim0, dim1):
    assert self.layout not in {
        torch.sparse_csr,
        torch.sparse_csc,
        torch.sparse_bsr,
        torch.sparse_bsc,
    }, f"torch.transpose_: in-place transposition is not supported for {self.layout} layout"

    ndims = self.ndim

    dim0 = maybe_wrap_dim(dim0, ndims)
    dim1 = maybe_wrap_dim(dim1, ndims)

    if dim0 == dim1:
        return self

    size = list(self.size())
    stride = list(self.stride())

    stride[dim0], stride[dim1] = stride[dim1], stride[dim0]
    size[dim0], size[dim1] = size[dim1], size[dim0]

    self.as_strided_(size, stride)
    return self


@register_meta(torch.ops.aten.t_)
def t_(self):
    ndims = self.ndim

    if self.is_sparse:
        sparse_dim = self.sparse_dim()
        dense_dim = self.dense_dim()
        assert (
            sparse_dim <= 2 and dense_dim == 0
        ), f"t_ expects a tensor with <= 2 sparse and 0 dense dimensions, but got {sparse_dim} sparse and {dense_dim} dense dimensions"  # noqa: B950
    else:
        assert (
            self.dim() <= 2
        ), f"t_ expects a tensor with <= 2 dimensions, but self is {ndims}D"

    return transpose_(self, 0, 0 if ndims < 2 else 1)


# We must also trigger meta registrations from PrimTorch ref
# decompositions
import torch._refs
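A hedged sketch of what the two registrations above enable: in-place transpose on meta tensors (and therefore on fake tensors during tracing) now computes the swapped sizes and strides directly, which is presumably also why the 't' and 'transpose' in-place xfails were dropped from inplace_symbolic_tensor_failures earlier in the diff. The shapes below are illustrative.

import torch

x = torch.empty(2, 3, device="meta")
x.transpose_(0, 1)
assert x.shape == (3, 2) and x.stride() == (1, 3)   # size and stride both swapped

y = torch.empty(4, 5, device="meta")
y.t_()                                              # t_ delegates to transpose_(0, 1)
assert y.shape == (5, 4)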