From 71aefd5595834dd97f38aa978ee32abbd13ac3d6 Mon Sep 17 00:00:00 2001 From: soulitzer Date: Wed, 8 Oct 2025 20:23:13 +0000 Subject: [PATCH] [reland] Allow setting grad_dtype on leaf tensors (#164751) ghstack-source-id: e44b3941530be83a630ec93f1478eec741ffca2e Pull-Request-resolved: https://github.com/pytorch/pytorch/pull/162815 Fixes #ISSUE_NUMBER Relanding due to internal weirdness. Separate PR to codev w/o ghstack. Pull Request resolved: https://github.com/pytorch/pytorch/pull/164751 Approved by: https://github.com/albanD --- aten/src/ATen/core/Tensor.cpp | 8 ++ aten/src/ATen/core/TensorBase.h | 4 + aten/src/ATen/core/VariableHooksInterface.h | 2 + build_variables.bzl | 2 + test/dynamo/test_backward_higher_order_ops.py | 6 +- test/inductor/test_compiled_autograd.py | 17 +-- test/jit/test_builtins.py | 2 + test/test_autograd.py | 124 ++++++++++++++++++ torch/_C/__init__.pyi.in | 1 + torch/_tensor_docs.py | 30 +++++ torch/csrc/autograd/autograd_meta.cpp | 2 + torch/csrc/autograd/engine.cpp | 20 +-- torch/csrc/autograd/function.h | 5 +- torch/csrc/autograd/input_metadata.cpp | 9 +- torch/csrc/autograd/input_metadata.h | 21 ++- torch/csrc/autograd/python_variable.cpp | 79 +++++++++-- torch/csrc/autograd/variable.cpp | 50 ++++++- torch/csrc/autograd/variable.h | 19 +++ torch/csrc/dynamo/compiled_autograd.h | 16 ++- torch/overrides.py | 1 + 20 files changed, 377 insertions(+), 41 deletions(-) diff --git a/aten/src/ATen/core/Tensor.cpp b/aten/src/ATen/core/Tensor.cpp index fea5d5652c39..c5f887f096cd 100644 --- a/aten/src/ATen/core/Tensor.cpp +++ b/aten/src/ATen/core/Tensor.cpp @@ -173,4 +173,12 @@ unsigned TensorBase::_register_hook(std::function return impl::GetVariableHooks()->_register_hook(*this, std::move(hook)); } +std::optional TensorBase::grad_dtype() const { + return impl::GetVariableHooks()->grad_dtype(*this); +} + +void TensorBase::set_grad_dtype(const std::optional& grad_dtype) const { + return impl::GetVariableHooks()->set_grad_dtype(*this, grad_dtype); +} + } // namespace at diff --git a/aten/src/ATen/core/TensorBase.h b/aten/src/ATen/core/TensorBase.h index 5f43738ea0fa..63fe4cad5149 100644 --- a/aten/src/ATen/core/TensorBase.h +++ b/aten/src/ATen/core/TensorBase.h @@ -930,6 +930,10 @@ public: const TensorBase& requires_grad_(bool _requires_grad=true) const; + std::optional grad_dtype() const; + + void set_grad_dtype(const std::optional& grad_dtype) const; + // View Variables //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/aten/src/ATen/core/VariableHooksInterface.h b/aten/src/ATen/core/VariableHooksInterface.h index f9c0aa4a5fc1..c0f270700e3c 100644 --- a/aten/src/ATen/core/VariableHooksInterface.h +++ b/aten/src/ATen/core/VariableHooksInterface.h @@ -68,6 +68,8 @@ struct TORCH_API VariableHooksInterface { const c10::OperatorHandle& op, c10::DispatchKeySet dispatch_keys, torch::jit::Stack* stack) const = 0; + virtual std::optional grad_dtype(const TensorBase&) const = 0; + virtual void set_grad_dtype(const TensorBase&, const std::optional&) const = 0; }; TORCH_API void SetVariableHooks(VariableHooksInterface* hooks); diff --git a/build_variables.bzl b/build_variables.bzl index e4dd849be4fe..ce1c5f1c97b5 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -68,6 +68,8 @@ jit_core_sources = [ # list for the shared files. core_sources_common = [ + # This needs to belong here because it defines the first non-inline virtual + # function, which matters for AutogradMetaInterface's vtable. 
"torch/csrc/autograd/autograd_meta.cpp", "torch/csrc/autograd/forward_grad.cpp", "torch/csrc/jit/frontend/edit_distance.cpp", diff --git a/test/dynamo/test_backward_higher_order_ops.py b/test/dynamo/test_backward_higher_order_ops.py index 2c60d6ba4cf5..97a380934484 100644 --- a/test/dynamo/test_backward_higher_order_ops.py +++ b/test/dynamo/test_backward_higher_order_ops.py @@ -140,7 +140,7 @@ class GraphModule(torch.nn.Module): size: "Sym(s21)" = l_sizes_0_.size(1); l_sizes_0_ = None - validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [size], False)]); getitem = size = None + validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [size], False, 6)]); getitem = size = None getitem_9: "f32[s21]" = validate_outputs[0]; validate_outputs = None call_aot_bwd_prologue = torch__dynamo_compiled_autograd_call_aot_bwd_prologue((), [], getitem_9); getitem_9 = None @@ -171,7 +171,7 @@ class GraphModule(torch.nn.Module): size: "Sym(s21)" = l_sizes_0_.size(1); l_sizes_0_ = None - validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [size], False)]); getitem = size = None + validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [size], False, 6)]); getitem = size = None getitem_9: "f32[s21]" = validate_outputs[0]; validate_outputs = None call_aot_bwd_prologue = torch__dynamo_compiled_autograd_call_aot_bwd_prologue((), [], getitem_9); getitem_9 = None @@ -255,7 +255,7 @@ class GraphModule(torch.nn.Module): size: "Sym(s21)" = l_sizes_0_.size(1); l_sizes_0_ = None - validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [size], False)]); getitem = size = None + validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [size], False, 6)]); getitem = size = None getitem_9: "f32[s21]" = validate_outputs[0]; validate_outputs = None call_aot_bwd_prologue = torch__dynamo_compiled_autograd_call_aot_bwd_prologue((), [], getitem_9); getitem_9 = None diff --git a/test/inductor/test_compiled_autograd.py b/test/inductor/test_compiled_autograd.py index e0cd8b99a6b3..2612af01f6ff 100644 --- a/test/inductor/test_compiled_autograd.py +++ b/test/inductor/test_compiled_autograd.py @@ -3604,12 +3604,12 @@ class CompiledAutograd0(torch.nn.Module): unwrap_maybe_dynamic_int_18 = torch__dynamo_external_utils_unwrap_maybe_dynamic_int(getitem_23); getitem_23 = None unwrap_maybe_dynamic_int_19 = torch__dynamo_external_utils_unwrap_maybe_dynamic_int(getitem_24); getitem_24 = None - validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [], True)]); getitem = None + validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [], True, 6)]); getitem = None getitem_25 = validate_outputs[0]; validate_outputs = None sum_backward0 = torch__dynamo_compiled_autograd_ops_SumBackward0([getitem_25], [True], [unwrap_maybe_dynamic_int, unwrap_maybe_dynamic_int_1]); getitem_25 = unwrap_maybe_dynamic_int = unwrap_maybe_dynamic_int_1 = None getitem_26 = sum_backward0[0]; sum_backward0 = None - validate_outputs_1 = 
torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_26], [((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_2, unwrap_maybe_dynamic_int_3], True)]); getitem_26 = unwrap_maybe_dynamic_int_2 = unwrap_maybe_dynamic_int_3 = None + validate_outputs_1 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_26], [((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_2, unwrap_maybe_dynamic_int_3], True, 6)]); getitem_26 = unwrap_maybe_dynamic_int_2 = unwrap_maybe_dynamic_int_3 = None getitem_27 = validate_outputs_1[0]; validate_outputs_1 = None getitem_28 = hooks[0]; getitem_28 = None @@ -3631,7 +3631,7 @@ class CompiledAutograd0(torch.nn.Module): call_backward = torch__dynamo_external_utils_call_backward(getitem_33, (), make_subclass); getitem_33 = make_subclass = None getitem_36 = call_backward[0] getitem_37 = call_backward[1]; call_backward = None - validate_outputs_2 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_36, getitem_37], [((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_16, unwrap_maybe_dynamic_int_17], False), ((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_18, unwrap_maybe_dynamic_int_19], False)]); getitem_36 = getitem_37 = unwrap_maybe_dynamic_int_16 = unwrap_maybe_dynamic_int_17 = unwrap_maybe_dynamic_int_18 = unwrap_maybe_dynamic_int_19 = None + validate_outputs_2 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_36, getitem_37], [((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_16, unwrap_maybe_dynamic_int_17], False, 6), ((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_18, unwrap_maybe_dynamic_int_19], False, 6)]); getitem_36 = getitem_37 = unwrap_maybe_dynamic_int_16 = unwrap_maybe_dynamic_int_17 = unwrap_maybe_dynamic_int_18 = unwrap_maybe_dynamic_int_19 = None getitem_39 = validate_outputs_2[0] call_accumulate_grad_1 = torch__dynamo_external_utils_call_accumulate_grad(getitem_4, getitem_39, False); getitem_4 = getitem_39 = call_accumulate_grad_1 = None @@ -3866,12 +3866,12 @@ class CompiledAutograd0(torch.nn.Module): unwrap_maybe_dynamic_int_10 = torch__dynamo_external_utils_unwrap_maybe_dynamic_int(getitem_12); getitem_12 = None unwrap_maybe_dynamic_int_11 = torch__dynamo_external_utils_unwrap_maybe_dynamic_int(getitem_13); getitem_13 = None - validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [], False)]); getitem = None + validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [], False, 6)]); getitem = None getitem_14 = validate_outputs[0]; validate_outputs = None sum_backward0 = torch__dynamo_compiled_autograd_ops_SumBackward0([getitem_14], [True], [unwrap_maybe_dynamic_int, unwrap_maybe_dynamic_int_1]); getitem_14 = unwrap_maybe_dynamic_int = unwrap_maybe_dynamic_int_1 = None getitem_15 = sum_backward0[0]; sum_backward0 = None - validate_outputs_1 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_15], [((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_2, unwrap_maybe_dynamic_int_3], False)]); getitem_15 = unwrap_maybe_dynamic_int_2 = unwrap_maybe_dynamic_int_3 = None + validate_outputs_1 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_15], [((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_2, unwrap_maybe_dynamic_int_3], False, 6)]); getitem_15 = 
unwrap_maybe_dynamic_int_2 = unwrap_maybe_dynamic_int_3 = None getitem_16 = validate_outputs_1[0]; validate_outputs_1 = None getitem_17 = hooks[0] @@ -3883,7 +3883,7 @@ class CompiledAutograd0(torch.nn.Module): mul_backward0 = torch__dynamo_compiled_autograd_ops_MulBackward0([getitem_16], [True, True], call_hook, 6, call_hook_1, 6); getitem_16 = call_hook = call_hook_1 = None getitem_21 = mul_backward0[0] getitem_22 = mul_backward0[1]; mul_backward0 = None - validate_outputs_2 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_21, getitem_22], [((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_4, unwrap_maybe_dynamic_int_5], False), ((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_6, unwrap_maybe_dynamic_int_7], False)]); getitem_21 = getitem_22 = unwrap_maybe_dynamic_int_4 = unwrap_maybe_dynamic_int_5 = unwrap_maybe_dynamic_int_6 = unwrap_maybe_dynamic_int_7 = None + validate_outputs_2 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_21, getitem_22], [((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_4, unwrap_maybe_dynamic_int_5], False, 6), ((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_6, unwrap_maybe_dynamic_int_7], False, 6)]); getitem_21 = getitem_22 = unwrap_maybe_dynamic_int_4 = unwrap_maybe_dynamic_int_5 = unwrap_maybe_dynamic_int_6 = unwrap_maybe_dynamic_int_7 = None getitem_23 = validate_outputs_2[0] getitem_24 = validate_outputs_2[1]; validate_outputs_2 = None @@ -3892,7 +3892,7 @@ class CompiledAutograd0(torch.nn.Module): call_hook_2 = torch__dynamo_external_utils_call_hook(getitem_25, getitem_26, hook_type = 'unpack_hook'); getitem_25 = getitem_26 = None cos_backward0 = torch__dynamo_compiled_autograd_ops_CosBackward0([getitem_24], [True], call_hook_2); getitem_24 = call_hook_2 = None getitem_27 = cos_backward0[0]; cos_backward0 = None - validate_outputs_3 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_27], [((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_8, unwrap_maybe_dynamic_int_9], False)]); getitem_27 = unwrap_maybe_dynamic_int_8 = unwrap_maybe_dynamic_int_9 = None + validate_outputs_3 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_27], [((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_8, unwrap_maybe_dynamic_int_9], False, 6)]); getitem_27 = unwrap_maybe_dynamic_int_8 = unwrap_maybe_dynamic_int_9 = None getitem_28 = validate_outputs_3[0]; validate_outputs_3 = None add = torch.add(getitem_23, getitem_28); getitem_23 = getitem_28 = None @@ -3901,7 +3901,7 @@ class CompiledAutograd0(torch.nn.Module): call_hook_3 = torch__dynamo_external_utils_call_hook(getitem_29, getitem_30, hook_type = 'unpack_hook'); getitem_29 = getitem_30 = None sin_backward0 = torch__dynamo_compiled_autograd_ops_SinBackward0([add], [True], call_hook_3); add = call_hook_3 = None getitem_31 = sin_backward0[0]; sin_backward0 = None - validate_outputs_4 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_31], [((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_10, unwrap_maybe_dynamic_int_11], False)]); getitem_31 = unwrap_maybe_dynamic_int_10 = unwrap_maybe_dynamic_int_11 = None + validate_outputs_4 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_31], [((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_10, unwrap_maybe_dynamic_int_11], False, 6)]); getitem_31 = unwrap_maybe_dynamic_int_10 = unwrap_maybe_dynamic_int_11 = None 
getitem_32 = validate_outputs_4[0]; validate_outputs_4 = None call_accumulate_grad = torch__dynamo_external_utils_call_accumulate_grad(getitem_1, getitem_32, False); getitem_1 = getitem_32 = call_accumulate_grad = None @@ -5266,6 +5266,7 @@ xfail_by_backend = { "test_dropout_inductor", # functionalize_rng_ops not yet supported "test_function_with_kwargs", # functionalize_rng_ops not yet supported "test_module", # functionalize_rng_ops not yet supported + "test_grad_dtype", # AttributeError: args / Float did not match Double }, "eager": { # will be run without torch.compiling the CA graph "test_setup_context_when_forward_has_default_args", # autograd.Function with class methods diff --git a/test/jit/test_builtins.py b/test/jit/test_builtins.py index 781080f5deb6..097130b6f164 100644 --- a/test/jit/test_builtins.py +++ b/test/jit/test_builtins.py @@ -326,6 +326,8 @@ class TestTensorBuiltins(JitTestCase): # This has a longer implementation, maybe not worth copying to # TorchScript if named tensors don't work there anyways "names", + # We don't plan to support grad_dtype in TorchScript + "grad_dtype", } for p in properties: diff --git a/test/test_autograd.py b/test/test_autograd.py index 021659b81122..a94a26afdbb8 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -3694,6 +3694,130 @@ class TestAutograd(TestCase): def test_sparse_gather_both_scalar(self): self._test_sparse_gather((), (), 0) + @skipIfTorchDynamo("grad_dtype not supported in compile") + def test_grad_dtype(self): + leaf = torch.tensor([1.0, 2.0], requires_grad=True) + # Default to tensor's dtype + self.assertEqual(leaf.grad_dtype, torch.float32) + leaf.grad_dtype = torch.float16 + self.assertEqual(leaf.grad_dtype, torch.float16) + leaf.grad_dtype = None # Allow any dtype + self.assertIsNone(leaf.grad_dtype) + + # get/set grad_dtype is only allowed on leaf tensors + non_leaf = leaf * 2 + self.assertFalse(non_leaf.is_leaf) + with self.assertRaisesRegex( + RuntimeError, "grad_dtype can only be accessed on leaf tensors" + ): + _ = non_leaf.grad_dtype + with self.assertRaisesRegex( + RuntimeError, "grad_dtype can only be set on leaf tensors" + ): + non_leaf.grad_dtype = torch.float16 + + # Manual setting + x = torch.tensor([1.0, 2.0], requires_grad=True) + grad_match = torch.tensor([1.0, 1.0]) + x.grad = grad_match + self.assertEqual(x.grad.dtype, torch.float32) + + x.grad = None + x.grad_dtype = torch.float16 + grad_mismatch = torch.tensor([1.0, 1.0]) + with self.assertRaisesRegex( + RuntimeError, + "attempting to assign a gradient with dtype.*float.*to a tensor with grad_dtype.*Half", + ): + x.grad = grad_mismatch + + # When grad_dtype is None, any dtype is allowed + x.grad = None + x.grad_dtype = None + grad_any = torch.tensor([1.0, 1.0], dtype=torch.float64) + x.grad = grad_any + self.assertEqual(x.grad.dtype, torch.float64) + + # Incoming gradient case + class MismatchedGradientFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, inp): + return inp * 2 + + @staticmethod + def backward(ctx, grad_output): + return grad_output.to(torch.float64) + + d = torch.tensor([1.0, 2.0], requires_grad=True) + output = MismatchedGradientFunction.apply(d) + loss = output.sum() + loss.backward() + # Default behavior is to cast to tensor dtype + self.assertEqual(d.grad.dtype, torch.float32) + self.assertTrue(torch.allclose(d.grad, torch.tensor([1.0, 1.0]))) + + e = torch.tensor([3.0, 4.0], requires_grad=True) + e.grad_dtype = None + output_e = MismatchedGradientFunction.apply(e) + loss_e = output_e.sum() + 
loss_e.backward() + # No casting is done if set to None. + self.assertTrue( + torch.allclose(e.grad, torch.tensor([1.0, 1.0], dtype=torch.float64)) + ) + + f = torch.tensor([5.0, 6.0], requires_grad=True) + f.grad_dtype = torch.float16 # Expect float16 gradients + output_f = MismatchedGradientFunction.apply(f) + loss_f = output_f.sum() + loss_f.backward() + self.assertTrue( + torch.allclose(f.grad, torch.tensor([1.0, 1.0], dtype=torch.float16)) + ) + + # Setting grad_dtype when gradient already exists + g = torch.tensor([1.0, 2.0], requires_grad=True) + g.grad = torch.tensor([1.0, 1.0]) + g.grad_dtype = torch.float32 + self.assertEqual(g.grad_dtype, torch.float32) + with self.assertRaisesRegex( + RuntimeError, "Cannot set grad_dtype.*because there is already a gradient" + ): + g.grad_dtype = torch.float16 + g.grad_dtype = None + self.assertIsNone(g.grad_dtype) + g.grad = None + g.grad_dtype = torch.float16 + self.assertEqual(g.grad_dtype, torch.float16) + + # Test the case where there is an existing accumulate grad + h = torch.tensor([1.0, 2.0], requires_grad=True) + _ = h.clone() + h.grad_dtype = None + output = MismatchedGradientFunction.apply(h) + output.sum().backward() + self.assertEqual(h.grad.dtype, torch.float64) + + # Mixed accumulation cases + k = torch.tensor([1.0, 2.0], requires_grad=True) + k.grad_dtype = None + y = k * 2 + y.sum().backward() + k.grad = k.grad.to(torch.bfloat16) + y2 = k * 3 + # Doesn't type promote to float32, always coerce to current .grad's dtype. + # This is because the accumulation is done in-place on the existing grad. + self.assertEqual(k.grad.dtype, torch.bfloat16) + + l = torch.tensor([3.0, 4.0], requires_grad=True, dtype=torch.bfloat16) + l.grad_dtype = None + z = l * 2 + z.sum().backward() + l.grad = l.grad.to(torch.float32) + z2 = l * 3 + z2.sum().backward() + self.assertEqual(l.grad.dtype, torch.float32) + def test_gc_in_destructor(self): """ Previously, if a Function destructor triggered a garbage collection, diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index a6885945e55e..2f6ad3f6de67 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -1919,6 +1919,7 @@ class TensorBase(metaclass=_TensorMeta): names: list[str] device: _device dtype: _dtype + grad_dtype: _dtype | None layout: _layout real: Tensor imag: Tensor diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py index 33a2184b71f5..bc5ed9d510d5 100644 --- a/torch/_tensor_docs.py +++ b/torch/_tensor_docs.py @@ -6624,6 +6624,36 @@ The attribute will then contain the gradients computed and future calls to """, ) +add_docstr_all( + "grad_dtype", + r""" +The allowed dtype of :attr:``grad`` for this tensor. + +:attr:``grad_dtype`` can be set to a specific dtype or ``None``. By default, +``t.grad_dtype == t.dtype``. When not None, the autograd engine casts +incoming gradients to this dtype. This attribute is only accessible and +settable for leaf tensors. + +.. warning:: + Use with caution. Diverging the dtypes of a tensor and its gradient may + break downstream systems that assume they match. 
+ +Example:: + + >>> x = torch.tensor([1.0, 2.0], requires_grad=True) + >>> x.grad_dtype + torch.float32 + + >>> x.grad_dtype = torch.float16 + >>> x.grad_dtype + torch.float16 + + >>> # Allow any gradient dtype + >>> x.grad_dtype = None + >>> x.grad_dtype +""", +) + add_docstr_all( "retain_grad", r""" diff --git a/torch/csrc/autograd/autograd_meta.cpp b/torch/csrc/autograd/autograd_meta.cpp index b1ef5b3a76a4..072501cbcf04 100644 --- a/torch/csrc/autograd/autograd_meta.cpp +++ b/torch/csrc/autograd/autograd_meta.cpp @@ -1,5 +1,7 @@ #define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include +#include +#include #include #ifndef AT_PER_OPERATOR_HEADERS diff --git a/torch/csrc/autograd/engine.cpp b/torch/csrc/autograd/engine.cpp index 9dc0251fb661..f92af4994fd5 100644 --- a/torch/csrc/autograd/engine.cpp +++ b/torch/csrc/autograd/engine.cpp @@ -948,15 +948,17 @@ static void validate_outputs_impl( TORCH_CHECK( isFloatingType(grad.scalar_type()) || (input_is_complex == grad_is_complex)); - if (c10::typeMetaToScalarType(metadata.options().dtype()) != - grad.scalar_type()) { - grad = grad.to(c10::typeMetaToScalarType(metadata.options().dtype())); - } - if (grad.dtype() != metadata.dtype()) { - std::stringstream ss; - ss << "invalid gradient at index " << i << " - expected dtype "; - ss << metadata.dtype() << " but got " << grad.dtype(); - TORCH_CHECK(false, format_error(ss.str())); + + if (metadata.grad_dtype().has_value()) { + if (grad.scalar_type() != metadata.grad_dtype().value()) { + grad = grad.to(metadata.grad_dtype().value()); + } + if (grad.scalar_type() != metadata.grad_dtype().value()) { + std::stringstream ss; + ss << "invalid gradient at index " << i << " - expected dtype "; + ss << metadata.grad_dtype().value() << " but got " << grad.dtype(); + TORCH_CHECK(false, format_error(ss.str())); + } } if (grad.layout() != metadata.layout()) { // TODO: Currently we only support (*, Sparse) combination for diff --git a/torch/csrc/autograd/function.h b/torch/csrc/autograd/function.h index fba950bbcec5..ca97c43ca726 100644 --- a/torch/csrc/autograd/function.h +++ b/torch/csrc/autograd/function.h @@ -200,11 +200,12 @@ struct TORCH_API Node : std::enable_shared_from_this { const at::TensorOptions& options, c10::SymIntArrayRef shape, bool is_tensor_subclass, - bool is_nested) noexcept { + bool is_nested, + std::optional grad_dtype) noexcept { uint32_t input_nr = input_metadata_.size(); auto meta_shape = MetadataShape{std::in_place_type, shape}; input_metadata_.emplace_back( - options, meta_shape, is_tensor_subclass, is_nested); + options, meta_shape, is_tensor_subclass, is_nested, grad_dtype); return input_nr; } diff --git a/torch/csrc/autograd/input_metadata.cpp b/torch/csrc/autograd/input_metadata.cpp index 74a39ed68381..f43368bbded0 100644 --- a/torch/csrc/autograd/input_metadata.cpp +++ b/torch/csrc/autograd/input_metadata.cpp @@ -29,12 +29,14 @@ InputMetadata::InputMetadata( const at::TensorOptions& options, MetadataShape input_shape, bool is_tensor_subclass, - bool is_nested) + bool is_nested, + std::optional grad_dtype) : options_{options}, shape_{std::move(input_shape)}, is_tensor_subclass_{is_tensor_subclass}, is_nested_{is_nested}, - was_default_constructed_{false} { + was_default_constructed_{false}, + grad_dtype_{grad_dtype} { auto device_ = options.device(); stream_ = c10::impl::getDeviceGuardImpl(device_.type())->getStream(device_); } @@ -44,7 +46,8 @@ InputMetadata::InputMetadata(const at::Tensor& t) t.options(), compute_variant_shape(t), is_python_dispatch(t), - t.is_nested()) {} + 
t.is_nested(), + t.grad_dtype()) {} at::Tensor InputMetadata::zeros_like() const { TORCH_CHECK( diff --git a/torch/csrc/autograd/input_metadata.h b/torch/csrc/autograd/input_metadata.h index 1f74e72cae7c..1facbf345bc6 100644 --- a/torch/csrc/autograd/input_metadata.h +++ b/torch/csrc/autograd/input_metadata.h @@ -38,7 +38,8 @@ struct TORCH_API InputMetadata { const at::TensorOptions& options, MetadataShape input_shape, bool is_tensor_subclass, - bool is_nested); + bool is_nested, + std::optional grad_dtype); InputMetadata(const at::Tensor& t); const at::TensorOptions& options() const { @@ -97,11 +98,23 @@ struct TORCH_API InputMetadata { // Danger: not thread safe, caller must protect with lock SymIntSmallVec& mutable_shape_as_dim_vector(); + std::optional grad_dtype() const { + TORCH_INTERNAL_ASSERT(!was_default_constructed_); + return grad_dtype_; + } + + void set_grad_dtype(const std::optional& grad_dtype) { + TORCH_INTERNAL_ASSERT(!was_default_constructed_); + grad_dtype_ = grad_dtype; + } + private: at::Tensor shape_as_tensor() const; bool is_nestedness_same(const at::Tensor& grad) const; bool maybe_expandable_to(const at::Tensor& grad) const; + // NB: The engine does not use the dtype from the options, but rather the + // grad_dtype_ field to validate grad_output dtype. // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) const at::TensorOptions options_; MetadataShape shape_; @@ -109,5 +122,11 @@ struct TORCH_API InputMetadata { bool is_tensor_subclass_ = false; bool is_nested_ = false; bool was_default_constructed_ = true; + + // The grad_dtype_ field is the dtype that the engine expects the grad to be. + // When nullopt, grad_dtype_ is allowed to be any dtype. + // This field is mutated if THPVariable_set_grad_dtype is called + // and the AccumulateGrad has already been created. + std::optional grad_dtype_; }; } // namespace torch::autograd diff --git a/torch/csrc/autograd/python_variable.cpp b/torch/csrc/autograd/python_variable.cpp index 91f4bffbd103..3ede97905aee 100644 --- a/torch/csrc/autograd/python_variable.cpp +++ b/torch/csrc/autograd/python_variable.cpp @@ -1319,13 +1319,18 @@ static int THPVariable_set_grad( self != (THPVariable*)py_grad, "can't assign Variable as its own grad"); const auto& grad = THPVariable_Unpack(py_grad); - TORCH_CHECK( - var.dtype() == grad.dtype(), - "attempting to assign a gradient with dtype '", - grad.dtype(), - "' to a tensor with dtype '", - var.dtype(), - "'. Please ensure that the gradient and the tensor have the same dtype"); + if (var.grad_dtype().has_value()) { + TORCH_CHECK( + grad.dtype() == var.grad_dtype().value(), + "attempting to assign a gradient with dtype '", + grad.dtype(), + "' to a tensor with grad_dtype '", + var.grad_dtype().value(), + "'. The gradient must match the tensor's grad_dtype (defaults to the tensor's " + "dtype). You can set the tensor's grad_dtype attribute with a specific dtype, or " + "None to allow any dtype. Set grad_dtype with caution. Diverging the dtypes of " + "a tensor and its gradient may break downstream systems that assume they match."); + } TORCH_CHECK( var.device().type() == grad.device().type(), "attempting to assign a gradient with device type '", @@ -1334,8 +1339,11 @@ static int THPVariable_set_grad( var.device().type(), "'. Please ensure that the gradient and the tensor are on the same device"); if (grad.layout() != kSparse) { + auto expected_options = var.options().dtype( + var.grad_dtype().has_value() ? 
var.grad_dtype().value() + : grad.scalar_type()); TORCH_CHECK( - grad.options().type_equal(var.options()), + grad.options().type_equal(expected_options), "attempting to assign a gradient to a tensor that has data of a different type"); } TORCH_CHECK( @@ -1841,6 +1849,56 @@ static PyObject* THPVariable_get_nbytes(THPVariable* self, void* unused) { END_HANDLE_TH_ERRORS } +static PyObject* THPVariable_get_grad_dtype(THPVariable* self, void* unused) { + HANDLE_TH_ERRORS + if (check_has_torch_function((PyObject*)self)) { + return handle_torch_function_getter(self, "grad_dtype"); + } + const auto& var = THPVariable_Unpack(self); + TORCH_CHECK( + !var.grad_fn(), "grad_dtype can only be accessed on leaf tensors."); + if (!var.grad_dtype().has_value()) { + Py_RETURN_NONE; + } else { + return torch::autograd::utils::wrap(var.grad_dtype().value()); + } + END_HANDLE_TH_ERRORS +} + +static int THPVariable_set_grad_dtype( + THPVariable* self, + PyObject* obj, + void* unused) { + HANDLE_TH_ERRORS + if (check_has_torch_function((PyObject*)self)) { + return handle_torch_function_setter(self, "grad_dtype", obj); + } + const auto& var = THPVariable_Unpack(self); + TORCH_CHECK( + THPDtype_Check(obj) || obj == Py_None, + "grad_dtype must be a torch.dtype or None, but got ", + Py_TYPE(obj)->tp_name); + if (var.grad().defined() && obj != Py_None) { + auto new_dtype = reinterpret_cast(obj); + TORCH_CHECK( + var.grad().dtype() == new_dtype->scalar_type, + "Cannot set grad_dtype to '", + new_dtype->scalar_type, + "' because there is already a gradient with dtype '", + var.grad().dtype(), + "'. Please clear the gradient (.grad = None) before changing grad_dtype, " + "or ensure the new grad_dtype matches the existing gradient's dtype."); + } + std::optional new_dtype; + if (obj != Py_None) { + auto* dtype = reinterpret_cast(obj); + new_dtype = dtype->scalar_type; + } + var.set_grad_dtype(new_dtype); + return 0; + END_HANDLE_TH_ERRORS_RET(-1) +} + static PyObject* THPVariable_get_itemsize(THPVariable* self, void* unused) { HANDLE_TH_ERRORS if (check_has_torch_function((PyObject*)self)) { @@ -1999,6 +2057,11 @@ static struct PyGetSetDef THPVariable_properties[] = { (setter)THPVariable_set_imag, nullptr, nullptr}, + {"grad_dtype", + (getter)THPVariable_get_grad_dtype, + (setter)THPVariable_set_grad_dtype, + nullptr, + nullptr}, {nullptr}}; static PyMappingMethods THPVariable_as_mapping = { diff --git a/torch/csrc/autograd/variable.cpp b/torch/csrc/autograd/variable.cpp index 9610c008c10c..b559ba44bf52 100644 --- a/torch/csrc/autograd/variable.cpp +++ b/torch/csrc/autograd/variable.cpp @@ -274,7 +274,7 @@ void set_grad_accumulator( std::move(grad_accumulator); } -std::shared_ptr try_get_grad_accumulator(const Variable& self) { +std::shared_ptr try_get_grad_accumulator(const at::TensorBase& self) { if (get_autograd_meta(self)) { return get_autograd_meta(self)->grad_accumulator_.lock(); } else { @@ -282,6 +282,10 @@ std::shared_ptr try_get_grad_accumulator(const Variable& self) { } } +std::shared_ptr try_get_grad_accumulator(const Variable& self) { + return try_get_grad_accumulator(get_tensor_base(self)); +} + std::shared_ptr grad_accumulator(const Variable& self) { auto autograd_meta = get_autograd_meta(self); if (!autograd_meta) { @@ -713,7 +717,8 @@ const std::shared_ptr& VariableHooks::grad_fn( self.sym_sizes(), // Note: sizes(), not base_.sizes(), is // intentional self.unsafeGetTensorImpl()->is_python_dispatch(), - self.is_nested()); + self.is_nested(), + self.grad_dtype()); diff_view_meta->grad_fn_ = std::move(fn); 
} diff_view_meta->set_attr_version(current_version); @@ -909,4 +914,45 @@ std::unique_ptr ChainedViewFunc::clone_and_set( second->clone_and_set(second_symints, second_tensors)); } +std::optional VariableHooks::grad_dtype( + const at::TensorBase& self) const { + if (auto* meta = impl::get_autograd_meta(self)) { + return meta->grad_dtype(self); + } + return self.scalar_type(); +} + +void VariableHooks::set_grad_dtype( + const at::TensorBase& self, + const std::optional& grad_dtype) const { + auto* meta = impl::materialize_autograd_meta(self); + meta->set_grad_dtype(grad_dtype, self); +} + +std::optional AutogradMeta::grad_dtype( + const at::TensorBase& self) const { + if (allow_grad_dtype_mismatch_) { + return std::nullopt; + } else if (grad_dtype_.has_value()) { + return grad_dtype_; + } else { + return std::optional(self.scalar_type()); + } +} +void AutogradMeta::set_grad_dtype( + const std::optional& grad_dtype, + const at::TensorBase& self) { + TORCH_CHECK(!grad_fn_, "grad_dtype can only be set on leaf tensors."); + if (grad_dtype.has_value()) { + grad_dtype_ = grad_dtype; + allow_grad_dtype_mismatch_ = false; + } else { + allow_grad_dtype_mismatch_ = true; + } + auto grad_acc = impl::try_get_grad_accumulator(self); + if (grad_acc) { + grad_acc->mutable_input_metadata(0).set_grad_dtype(grad_dtype); + } +} + } // namespace torch::autograd diff --git a/torch/csrc/autograd/variable.h b/torch/csrc/autograd/variable.h index dfffd3d97095..2ed4a1e8fd5a 100644 --- a/torch/csrc/autograd/variable.h +++ b/torch/csrc/autograd/variable.h @@ -128,6 +128,7 @@ TORCH_API void set_grad_accumulator( /// if it still exists. If the gradient accumulator function has been /// destroyed, returns a `nullptr`. TORCH_API std::shared_ptr try_get_grad_accumulator(const Variable&); +TORCH_API std::shared_ptr try_get_grad_accumulator(const at::TensorBase&); /// Gets the gradient accumulator of the `Variable` if it has one, or else /// create one on the fly and return it. @@ -253,6 +254,13 @@ struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface { // correctly when this variable is passed to another function. uint32_t output_nr_; + // The dtype of the grad field; when nullopt, defaults to tensor's dtype. + std::optional grad_dtype_; + + // When true, allows gradient dtype to be different from tensor dtype, + // bypassing dtype casting and validation in the autograd engine. + bool allow_grad_dtype_mismatch_{false}; + // Mutex to ensure that concurrent read operations that modify internal // state are still thread-safe. 
Used by grad_fn(), grad_accumulator(), // fw_grad() and set_fw_grad() @@ -293,6 +301,12 @@ struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface { uint64_t level, bool is_inplace_op) override; + std::optional grad_dtype(const at::TensorBase& self) const; + + void set_grad_dtype( + const std::optional& grad_dtype, + const at::TensorBase& self); + AutogradMeta( at::TensorImpl* self_impl = nullptr, bool requires_grad = false, @@ -940,6 +954,11 @@ struct VariableHooks final : at::impl::VariableHooksInterface { const c10::OperatorHandle& op, c10::DispatchKeySet dispatch_keys, torch::jit::Stack* stack) const override; + std::optional grad_dtype( + const at::TensorBase&) const override; + void set_grad_dtype( + const at::TensorBase&, + const std::optional&) const override; }; namespace utils { diff --git a/torch/csrc/dynamo/compiled_autograd.h b/torch/csrc/dynamo/compiled_autograd.h index c5f5fd8d2f18..ca9eb3e638f4 100644 --- a/torch/csrc/dynamo/compiled_autograd.h +++ b/torch/csrc/dynamo/compiled_autograd.h @@ -1458,24 +1458,30 @@ struct IValuePacker { auto tuple = std::make_tuple( pack_TensorOptions(t.options()), t.shape_as_dim_vector().vec(), - t.is_tensor_subclass()); + t.is_tensor_subclass(), + t.grad_dtype()); return tuple; } static InputMetadata unpack(const at::IValue& t) { - auto tuple = t.to< - std::tuple, bool>>(); + auto tuple = t.to, + bool, + std::optional>>(); return InputMetadata( unpack_TensorOptions(std::get<0>(tuple)), SymIntSmallVec(std::get<1>(tuple)), std::get<2>(tuple), - false); + false, + std::get<3>(tuple)); } static at::TypePtr packed_type() { return at::TupleType::create( {IValuePacker::packed_type(), IValuePacker>::packed_type(), - at::BoolType::get()}); + at::BoolType::get(), + IValuePacker>::packed_type()}); } }; diff --git a/torch/overrides.py b/torch/overrides.py index 0e4c22525312..aa672cd28d96 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -1351,6 +1351,7 @@ def get_testing_overrides() -> dict[Callable, Callable]: Tensor._grad.__get__: lambda self: -1, Tensor._grad_fn.__get__: lambda self: -1, Tensor.grad_fn.__get__: lambda self: -1, + Tensor.grad_dtype.__get__: lambda self: -1, Tensor._version.__get__: lambda self: -1, Tensor._autocast_to_reduced_precision: lambda self, cuda_enabled, cpu_enabled, cuda_dtype, cpu_dtype: -1, Tensor._autocast_to_full_precision: lambda self, cuda_enabled, cpu_enabled: -1,
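
A minimal usage sketch of the `grad_dtype` attribute this patch adds, distilled from the new `test_grad_dtype` test and the `_tensor_docs.py` docstring above; `MismatchedGradient` is just an illustrative `autograd.Function` that hands the engine a float64 incoming gradient, not part of the patch itself:

    import torch

    class MismatchedGradient(torch.autograd.Function):
        """Toy function whose backward emits a float64 gradient on purpose."""

        @staticmethod
        def forward(ctx, inp):
            return inp * 2

        @staticmethod
        def backward(ctx, grad_output):
            return grad_output.to(torch.float64)

    x = torch.tensor([1.0, 2.0], requires_grad=True)
    assert x.grad_dtype == torch.float32      # defaults to the tensor's dtype

    # Default behavior: the engine casts the incoming float64 grad back to float32.
    MismatchedGradient.apply(x).sum().backward()
    assert x.grad.dtype == torch.float32

    # grad_dtype = None opts out of the cast, so .grad keeps the mismatched dtype.
    y = torch.tensor([3.0, 4.0], requires_grad=True)
    y.grad_dtype = None
    MismatchedGradient.apply(y).sum().backward()
    assert y.grad.dtype == torch.float64

    # A concrete grad_dtype makes the engine cast incoming gradients to it.
    z = torch.tensor([5.0, 6.0], requires_grad=True)
    z.grad_dtype = torch.float16
    MismatchedGradient.apply(z).sum().backward()
    assert z.grad.dtype == torch.float16

As the test also exercises, `grad_dtype` is only readable/settable on leaf tensors, and assigning `.grad` manually must match a non-None `grad_dtype`.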