Revert "Support map autograd and pytree in/out (#100494)"

This reverts commit b8fa41be9d396d97cfcd53964a228e2f987e104a.

Reverted https://github.com/pytorch/pytorch/pull/100494 on behalf of https://github.com/PaliC due to breaking tests on trunk, please check hud.pytorch.org for the broken tests ([comment](https://github.com/pytorch/pytorch/pull/100494#issuecomment-1550454835))
This commit is contained in:
PyTorch MergeBot
2023-05-16 22:50:18 +00:00
parent b8fa41be9d
commit e69198b043
7 changed files with 95 additions and 617 deletions

View File

@@ -5,7 +5,6 @@ import torch
from torch.testing._internal.common_utils import TestGradients, run_tests
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.control_flow_opinfo_db import control_flow_opinfo_db
from torch.testing._internal.common_device_type import \
(instantiate_device_type_tests, ops, OpDTypes)
@@ -18,7 +17,7 @@ _gradcheck_ops = partial(ops, dtypes=OpDTypes.supported,
class TestBwdGradients(TestGradients):
# Tests that gradients are computed correctly
@_gradcheck_ops(op_db + control_flow_opinfo_db)
@_gradcheck_ops(op_db)
def test_fn_grad(self, device, dtype, op):
# This is verified by test_dtypes in test_ops.py
if dtype not in op.supported_backward_dtypes(torch.device(device).type):
@@ -52,7 +51,7 @@ class TestBwdGradients(TestGradients):
self._grad_test_helper(device, dtype, op, self._get_safe_inplace(op.get_inplace()))
# Test that gradients of gradients are computed correctly
@_gradcheck_ops(op_db + control_flow_opinfo_db)
@_gradcheck_ops(op_db)
def test_fn_gradgrad(self, device, dtype, op):
self._skip_helper(op, device, dtype)
if not op.supports_gradgrad: