mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "Support map autograd and pytree in/out (#100494)"
This reverts commit b8fa41be9d396d97cfcd53964a228e2f987e104a. Reverted https://github.com/pytorch/pytorch/pull/100494 on behalf of https://github.com/PaliC due to breaking tests on trunk, please check hud.pytorch.org for the broken tests ([comment](https://github.com/pytorch/pytorch/pull/100494#issuecomment-1550454835))
This commit is contained in:
@ -5,7 +5,6 @@ import torch
|
||||
|
||||
from torch.testing._internal.common_utils import TestGradients, run_tests
|
||||
from torch.testing._internal.common_methods_invocations import op_db
|
||||
from torch.testing._internal.control_flow_opinfo_db import control_flow_opinfo_db
|
||||
from torch.testing._internal.common_device_type import \
|
||||
(instantiate_device_type_tests, ops, OpDTypes)
|
||||
|
||||
@ -18,7 +17,7 @@ _gradcheck_ops = partial(ops, dtypes=OpDTypes.supported,
|
||||
|
||||
class TestBwdGradients(TestGradients):
|
||||
# Tests that gradients are computed correctly
|
||||
@_gradcheck_ops(op_db + control_flow_opinfo_db)
|
||||
@_gradcheck_ops(op_db)
|
||||
def test_fn_grad(self, device, dtype, op):
|
||||
# This is verified by test_dtypes in test_ops.py
|
||||
if dtype not in op.supported_backward_dtypes(torch.device(device).type):
|
||||
@ -52,7 +51,7 @@ class TestBwdGradients(TestGradients):
|
||||
self._grad_test_helper(device, dtype, op, self._get_safe_inplace(op.get_inplace()))
|
||||
|
||||
# Test that gradients of gradients are computed correctly
|
||||
@_gradcheck_ops(op_db + control_flow_opinfo_db)
|
||||
@_gradcheck_ops(op_db)
|
||||
def test_fn_gradgrad(self, device, dtype, op):
|
||||
self._skip_helper(op, device, dtype)
|
||||
if not op.supports_gradgrad:
|
||||
|
Reference in New Issue
Block a user