Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 12:54:11 +08:00

[reland] Allow setting grad_dtype on leaf tensors (#164751)

ghstack-source-id: e44b3941530be83a630ec93f1478eec741ffca2e
Pull-Request-resolved: https://github.com/pytorch/pytorch/pull/162815
Fixes #ISSUE_NUMBER
Relanding due to internal weirdness. Separate PR to codev w/o ghstack.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164751
Approved by: https://github.com/albanD

Committed by: PyTorch MergeBot
Parent: 001e1d2637
Commit: 71aefd5595
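
In short, the commit adds a settable grad_dtype attribute to leaf tensors. A minimal usage sketch, pieced together from the test and docstring added in this diff (illustrative only, not part of the commit):

    import torch

    # By default a leaf's grad_dtype matches its own dtype.
    x = torch.tensor([1.0, 2.0], requires_grad=True)
    assert x.grad_dtype == torch.float32

    # Ask autograd to produce float16 gradients for this leaf; incoming
    # gradients are cast to this dtype during backward.
    x.grad_dtype = torch.float16
    (x * 2).sum().backward()
    assert x.grad.dtype == torch.float16

    # None disables the check entirely: gradients of any dtype are accepted.
    y = torch.tensor([3.0, 4.0], requires_grad=True)
    y.grad_dtype = None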
@@ -173,4 +173,12 @@ unsigned TensorBase::_register_hook(std::function<TensorBase(const TensorBase&)>
   return impl::GetVariableHooks()->_register_hook(*this, std::move(hook));
 }

+std::optional<ScalarType> TensorBase::grad_dtype() const {
+  return impl::GetVariableHooks()->grad_dtype(*this);
+}
+
+void TensorBase::set_grad_dtype(const std::optional<ScalarType>& grad_dtype) const {
+  return impl::GetVariableHooks()->set_grad_dtype(*this, grad_dtype);
+}
+
 } // namespace at
@@ -930,6 +930,10 @@ public:

   const TensorBase& requires_grad_(bool _requires_grad=true) const;

+  std::optional<ScalarType> grad_dtype() const;
+
+  void set_grad_dtype(const std::optional<ScalarType>& grad_dtype) const;
+
   // View Variables
   //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

@@ -68,6 +68,8 @@ struct TORCH_API VariableHooksInterface {
       const c10::OperatorHandle& op,
       c10::DispatchKeySet dispatch_keys,
       torch::jit::Stack* stack) const = 0;
+  virtual std::optional<c10::ScalarType> grad_dtype(const TensorBase&) const = 0;
+  virtual void set_grad_dtype(const TensorBase&, const std::optional<c10::ScalarType>&) const = 0;
 };

 TORCH_API void SetVariableHooks(VariableHooksInterface* hooks);
@@ -68,6 +68,8 @@ jit_core_sources = [
 # list for the shared files.

 core_sources_common = [
+    # This needs to belong here because it defines the first non-inline virtual
+    # function, which matters for AutogradMetaInterface's vtable.
     "torch/csrc/autograd/autograd_meta.cpp",
     "torch/csrc/autograd/forward_grad.cpp",
     "torch/csrc/jit/frontend/edit_distance.cpp",
@@ -140,7 +140,7 @@ class GraphModule(torch.nn.Module):

         size: "Sym(s21)" = l_sizes_0_.size(1); l_sizes_0_ = None

-        validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [size], False)]); getitem = size = None
+        validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [size], False, 6)]); getitem = size = None
         getitem_9: "f32[s21]" = validate_outputs[0]; validate_outputs = None

         call_aot_bwd_prologue = torch__dynamo_compiled_autograd_call_aot_bwd_prologue((), [], getitem_9); getitem_9 = None
@@ -171,7 +171,7 @@ class GraphModule(torch.nn.Module):

         size: "Sym(s21)" = l_sizes_0_.size(1); l_sizes_0_ = None

-        validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [size], False)]); getitem = size = None
+        validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [size], False, 6)]); getitem = size = None
         getitem_9: "f32[s21]" = validate_outputs[0]; validate_outputs = None

         call_aot_bwd_prologue = torch__dynamo_compiled_autograd_call_aot_bwd_prologue((), [], getitem_9); getitem_9 = None
@@ -255,7 +255,7 @@ class GraphModule(torch.nn.Module):

         size: "Sym(s21)" = l_sizes_0_.size(1); l_sizes_0_ = None

-        validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [size], False)]); getitem = size = None
+        validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [size], False, 6)]); getitem = size = None
         getitem_9: "f32[s21]" = validate_outputs[0]; validate_outputs = None

         call_aot_bwd_prologue = torch__dynamo_compiled_autograd_call_aot_bwd_prologue((), [], getitem_9); getitem_9 = None
@@ -3604,12 +3604,12 @@ class CompiledAutograd0(torch.nn.Module):
         unwrap_maybe_dynamic_int_18 = torch__dynamo_external_utils_unwrap_maybe_dynamic_int(getitem_23); getitem_23 = None
         unwrap_maybe_dynamic_int_19 = torch__dynamo_external_utils_unwrap_maybe_dynamic_int(getitem_24); getitem_24 = None

-        validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [], True)]); getitem = None
+        validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [], True, 6)]); getitem = None
         getitem_25 = validate_outputs[0]; validate_outputs = None

         sum_backward0 = torch__dynamo_compiled_autograd_ops_SumBackward0([getitem_25], [True], [unwrap_maybe_dynamic_int, unwrap_maybe_dynamic_int_1]); getitem_25 = unwrap_maybe_dynamic_int = unwrap_maybe_dynamic_int_1 = None
         getitem_26 = sum_backward0[0]; sum_backward0 = None
-        validate_outputs_1 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_26], [((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_2, unwrap_maybe_dynamic_int_3], True)]); getitem_26 = unwrap_maybe_dynamic_int_2 = unwrap_maybe_dynamic_int_3 = None
+        validate_outputs_1 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_26], [((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_2, unwrap_maybe_dynamic_int_3], True, 6)]); getitem_26 = unwrap_maybe_dynamic_int_2 = unwrap_maybe_dynamic_int_3 = None
         getitem_27 = validate_outputs_1[0]; validate_outputs_1 = None

         getitem_28 = hooks[0]; getitem_28 = None
@@ -3631,7 +3631,7 @@ class CompiledAutograd0(torch.nn.Module):
         call_backward = torch__dynamo_external_utils_call_backward(getitem_33, (), make_subclass); getitem_33 = make_subclass = None
         getitem_36 = call_backward[0]
         getitem_37 = call_backward[1]; call_backward = None
-        validate_outputs_2 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_36, getitem_37], [((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_16, unwrap_maybe_dynamic_int_17], False), ((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_18, unwrap_maybe_dynamic_int_19], False)]); getitem_36 = getitem_37 = unwrap_maybe_dynamic_int_16 = unwrap_maybe_dynamic_int_17 = unwrap_maybe_dynamic_int_18 = unwrap_maybe_dynamic_int_19 = None
+        validate_outputs_2 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_36, getitem_37], [((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_16, unwrap_maybe_dynamic_int_17], False, 6), ((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_18, unwrap_maybe_dynamic_int_19], False, 6)]); getitem_36 = getitem_37 = unwrap_maybe_dynamic_int_16 = unwrap_maybe_dynamic_int_17 = unwrap_maybe_dynamic_int_18 = unwrap_maybe_dynamic_int_19 = None
         getitem_39 = validate_outputs_2[0]

         call_accumulate_grad_1 = torch__dynamo_external_utils_call_accumulate_grad(getitem_4, getitem_39, False); getitem_4 = getitem_39 = call_accumulate_grad_1 = None
@@ -3866,12 +3866,12 @@ class CompiledAutograd0(torch.nn.Module):
         unwrap_maybe_dynamic_int_10 = torch__dynamo_external_utils_unwrap_maybe_dynamic_int(getitem_12); getitem_12 = None
         unwrap_maybe_dynamic_int_11 = torch__dynamo_external_utils_unwrap_maybe_dynamic_int(getitem_13); getitem_13 = None

-        validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [], False)]); getitem = None
+        validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [], False, 6)]); getitem = None
         getitem_14 = validate_outputs[0]; validate_outputs = None

         sum_backward0 = torch__dynamo_compiled_autograd_ops_SumBackward0([getitem_14], [True], [unwrap_maybe_dynamic_int, unwrap_maybe_dynamic_int_1]); getitem_14 = unwrap_maybe_dynamic_int = unwrap_maybe_dynamic_int_1 = None
         getitem_15 = sum_backward0[0]; sum_backward0 = None
-        validate_outputs_1 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_15], [((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_2, unwrap_maybe_dynamic_int_3], False)]); getitem_15 = unwrap_maybe_dynamic_int_2 = unwrap_maybe_dynamic_int_3 = None
+        validate_outputs_1 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_15], [((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_2, unwrap_maybe_dynamic_int_3], False, 6)]); getitem_15 = unwrap_maybe_dynamic_int_2 = unwrap_maybe_dynamic_int_3 = None
         getitem_16 = validate_outputs_1[0]; validate_outputs_1 = None

         getitem_17 = hooks[0]
@@ -3883,7 +3883,7 @@ class CompiledAutograd0(torch.nn.Module):
         mul_backward0 = torch__dynamo_compiled_autograd_ops_MulBackward0([getitem_16], [True, True], call_hook, 6, call_hook_1, 6); getitem_16 = call_hook = call_hook_1 = None
         getitem_21 = mul_backward0[0]
         getitem_22 = mul_backward0[1]; mul_backward0 = None
-        validate_outputs_2 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_21, getitem_22], [((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_4, unwrap_maybe_dynamic_int_5], False), ((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_6, unwrap_maybe_dynamic_int_7], False)]); getitem_21 = getitem_22 = unwrap_maybe_dynamic_int_4 = unwrap_maybe_dynamic_int_5 = unwrap_maybe_dynamic_int_6 = unwrap_maybe_dynamic_int_7 = None
+        validate_outputs_2 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_21, getitem_22], [((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_4, unwrap_maybe_dynamic_int_5], False, 6), ((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_6, unwrap_maybe_dynamic_int_7], False, 6)]); getitem_21 = getitem_22 = unwrap_maybe_dynamic_int_4 = unwrap_maybe_dynamic_int_5 = unwrap_maybe_dynamic_int_6 = unwrap_maybe_dynamic_int_7 = None
         getitem_23 = validate_outputs_2[0]
         getitem_24 = validate_outputs_2[1]; validate_outputs_2 = None

@@ -3892,7 +3892,7 @@ class CompiledAutograd0(torch.nn.Module):
         call_hook_2 = torch__dynamo_external_utils_call_hook(getitem_25, getitem_26, hook_type = 'unpack_hook'); getitem_25 = getitem_26 = None
         cos_backward0 = torch__dynamo_compiled_autograd_ops_CosBackward0([getitem_24], [True], call_hook_2); getitem_24 = call_hook_2 = None
         getitem_27 = cos_backward0[0]; cos_backward0 = None
-        validate_outputs_3 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_27], [((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_8, unwrap_maybe_dynamic_int_9], False)]); getitem_27 = unwrap_maybe_dynamic_int_8 = unwrap_maybe_dynamic_int_9 = None
+        validate_outputs_3 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_27], [((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_8, unwrap_maybe_dynamic_int_9], False, 6)]); getitem_27 = unwrap_maybe_dynamic_int_8 = unwrap_maybe_dynamic_int_9 = None
         getitem_28 = validate_outputs_3[0]; validate_outputs_3 = None
         add = torch.add(getitem_23, getitem_28); getitem_23 = getitem_28 = None

@@ -3901,7 +3901,7 @@ class CompiledAutograd0(torch.nn.Module):
         call_hook_3 = torch__dynamo_external_utils_call_hook(getitem_29, getitem_30, hook_type = 'unpack_hook'); getitem_29 = getitem_30 = None
         sin_backward0 = torch__dynamo_compiled_autograd_ops_SinBackward0([add], [True], call_hook_3); add = call_hook_3 = None
         getitem_31 = sin_backward0[0]; sin_backward0 = None
-        validate_outputs_4 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_31], [((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_10, unwrap_maybe_dynamic_int_11], False)]); getitem_31 = unwrap_maybe_dynamic_int_10 = unwrap_maybe_dynamic_int_11 = None
+        validate_outputs_4 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_31], [((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_10, unwrap_maybe_dynamic_int_11], False, 6)]); getitem_31 = unwrap_maybe_dynamic_int_10 = unwrap_maybe_dynamic_int_11 = None
         getitem_32 = validate_outputs_4[0]; validate_outputs_4 = None

         call_accumulate_grad = torch__dynamo_external_utils_call_accumulate_grad(getitem_1, getitem_32, False); getitem_1 = getitem_32 = call_accumulate_grad = None
@@ -5266,6 +5266,7 @@ xfail_by_backend = {
         "test_dropout_inductor",  # functionalize_rng_ops not yet supported
         "test_function_with_kwargs",  # functionalize_rng_ops not yet supported
         "test_module",  # functionalize_rng_ops not yet supported
+        "test_grad_dtype",  # AttributeError: args / Float did not match Double
     },
     "eager": {  # will be run without torch.compiling the CA graph
         "test_setup_context_when_forward_has_default_args",  # autograd.Function with class methods
@@ -326,6 +326,8 @@ class TestTensorBuiltins(JitTestCase):
             # This has a longer implementation, maybe not worth copying to
             # TorchScript if named tensors don't work there anyways
             "names",
+            # We don't plan to support grad_dtype in TorchScript
+            "grad_dtype",
         }

         for p in properties:
@@ -3694,6 +3694,130 @@ class TestAutograd(TestCase):
     def test_sparse_gather_both_scalar(self):
         self._test_sparse_gather((), (), 0)

+    @skipIfTorchDynamo("grad_dtype not supported in compile")
+    def test_grad_dtype(self):
+        leaf = torch.tensor([1.0, 2.0], requires_grad=True)
+        # Default to tensor's dtype
+        self.assertEqual(leaf.grad_dtype, torch.float32)
+        leaf.grad_dtype = torch.float16
+        self.assertEqual(leaf.grad_dtype, torch.float16)
+        leaf.grad_dtype = None  # Allow any dtype
+        self.assertIsNone(leaf.grad_dtype)
+
+        # get/set grad_dtype is only allowed on leaf tensors
+        non_leaf = leaf * 2
+        self.assertFalse(non_leaf.is_leaf)
+        with self.assertRaisesRegex(
+            RuntimeError, "grad_dtype can only be accessed on leaf tensors"
+        ):
+            _ = non_leaf.grad_dtype
+        with self.assertRaisesRegex(
+            RuntimeError, "grad_dtype can only be set on leaf tensors"
+        ):
+            non_leaf.grad_dtype = torch.float16
+
+        # Manual setting
+        x = torch.tensor([1.0, 2.0], requires_grad=True)
+        grad_match = torch.tensor([1.0, 1.0])
+        x.grad = grad_match
+        self.assertEqual(x.grad.dtype, torch.float32)
+
+        x.grad = None
+        x.grad_dtype = torch.float16
+        grad_mismatch = torch.tensor([1.0, 1.0])
+        with self.assertRaisesRegex(
+            RuntimeError,
+            "attempting to assign a gradient with dtype.*float.*to a tensor with grad_dtype.*Half",
+        ):
+            x.grad = grad_mismatch
+
+        # When grad_dtype is None, any dtype is allowed
+        x.grad = None
+        x.grad_dtype = None
+        grad_any = torch.tensor([1.0, 1.0], dtype=torch.float64)
+        x.grad = grad_any
+        self.assertEqual(x.grad.dtype, torch.float64)
+
+        # Incoming gradient case
+        class MismatchedGradientFunction(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, inp):
+                return inp * 2
+
+            @staticmethod
+            def backward(ctx, grad_output):
+                return grad_output.to(torch.float64)
+
+        d = torch.tensor([1.0, 2.0], requires_grad=True)
+        output = MismatchedGradientFunction.apply(d)
+        loss = output.sum()
+        loss.backward()
+        # Default behavior is to cast to tensor dtype
+        self.assertEqual(d.grad.dtype, torch.float32)
+        self.assertTrue(torch.allclose(d.grad, torch.tensor([1.0, 1.0])))
+
+        e = torch.tensor([3.0, 4.0], requires_grad=True)
+        e.grad_dtype = None
+        output_e = MismatchedGradientFunction.apply(e)
+        loss_e = output_e.sum()
+        loss_e.backward()
+        # No casting is done if set to None.
+        self.assertTrue(
+            torch.allclose(e.grad, torch.tensor([1.0, 1.0], dtype=torch.float64))
+        )
+
+        f = torch.tensor([5.0, 6.0], requires_grad=True)
+        f.grad_dtype = torch.float16  # Expect float16 gradients
+        output_f = MismatchedGradientFunction.apply(f)
+        loss_f = output_f.sum()
+        loss_f.backward()
+        self.assertTrue(
+            torch.allclose(f.grad, torch.tensor([1.0, 1.0], dtype=torch.float16))
+        )
+
+        # Setting grad_dtype when gradient already exists
+        g = torch.tensor([1.0, 2.0], requires_grad=True)
+        g.grad = torch.tensor([1.0, 1.0])
+        g.grad_dtype = torch.float32
+        self.assertEqual(g.grad_dtype, torch.float32)
+        with self.assertRaisesRegex(
+            RuntimeError, "Cannot set grad_dtype.*because there is already a gradient"
+        ):
+            g.grad_dtype = torch.float16
+        g.grad_dtype = None
+        self.assertIsNone(g.grad_dtype)
+        g.grad = None
+        g.grad_dtype = torch.float16
+        self.assertEqual(g.grad_dtype, torch.float16)
+
+        # Test the case where there is an existing accumulate grad
+        h = torch.tensor([1.0, 2.0], requires_grad=True)
+        _ = h.clone()
+        h.grad_dtype = None
+        output = MismatchedGradientFunction.apply(h)
+        output.sum().backward()
+        self.assertEqual(h.grad.dtype, torch.float64)
+
+        # Mixed accumulation cases
+        k = torch.tensor([1.0, 2.0], requires_grad=True)
+        k.grad_dtype = None
+        y = k * 2
+        y.sum().backward()
+        k.grad = k.grad.to(torch.bfloat16)
+        y2 = k * 3
+        # Doesn't type promote to float32, always coerce to current .grad's dtype.
+        # This is because the accumulation is done in-place on the existing grad.
+        self.assertEqual(k.grad.dtype, torch.bfloat16)
+
+        l = torch.tensor([3.0, 4.0], requires_grad=True, dtype=torch.bfloat16)
+        l.grad_dtype = None
+        z = l * 2
+        z.sum().backward()
+        l.grad = l.grad.to(torch.float32)
+        z2 = l * 3
+        z2.sum().backward()
+        self.assertEqual(l.grad.dtype, torch.float32)
+
     def test_gc_in_destructor(self):
         """
         Previously, if a Function destructor triggered a garbage collection,
@@ -1919,6 +1919,7 @@ class TensorBase(metaclass=_TensorMeta):
     names: list[str]
     device: _device
     dtype: _dtype
+    grad_dtype: _dtype | None
     layout: _layout
     real: Tensor
     imag: Tensor
@@ -6624,6 +6624,36 @@ The attribute will then contain the gradients computed and future calls to
 """,
 )

+add_docstr_all(
+    "grad_dtype",
+    r"""
+The allowed dtype of :attr:`grad` for this tensor.
+
+:attr:`grad_dtype` can be set to a specific dtype or ``None``. By default,
+``t.grad_dtype == t.dtype``. When not None, the autograd engine casts
+incoming gradients to this dtype. This attribute is only accessible and
+settable for leaf tensors.
+
+.. warning::
+    Use with caution. Diverging the dtypes of a tensor and its gradient may
+    break downstream systems that assume they match.
+
+Example::
+
+    >>> x = torch.tensor([1.0, 2.0], requires_grad=True)
+    >>> x.grad_dtype
+    torch.float32
+
+    >>> x.grad_dtype = torch.float16
+    >>> x.grad_dtype
+    torch.float16
+
+    >>> # Allow any gradient dtype
+    >>> x.grad_dtype = None
+    >>> x.grad_dtype
+""",
+)
+
 add_docstr_all(
     "retain_grad",
     r"""
@@ -1,5 +1,7 @@
 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
 #include <c10/util/irange.h>
+#include <torch/csrc/autograd/function.h>
 #include <torch/csrc/autograd/input_metadata.h>
+#include <torch/csrc/autograd/variable.h>

 #ifndef AT_PER_OPERATOR_HEADERS
@@ -948,15 +948,17 @@ static void validate_outputs_impl(
     TORCH_CHECK(
         isFloatingType(grad.scalar_type()) ||
         (input_is_complex == grad_is_complex));
-    if (c10::typeMetaToScalarType(metadata.options().dtype()) !=
-        grad.scalar_type()) {
-      grad = grad.to(c10::typeMetaToScalarType(metadata.options().dtype()));
-    }
-    if (grad.dtype() != metadata.dtype()) {
-      std::stringstream ss;
-      ss << "invalid gradient at index " << i << " - expected dtype ";
-      ss << metadata.dtype() << " but got " << grad.dtype();
-      TORCH_CHECK(false, format_error(ss.str()));
+
+    if (metadata.grad_dtype().has_value()) {
+      if (grad.scalar_type() != metadata.grad_dtype().value()) {
+        grad = grad.to(metadata.grad_dtype().value());
+      }
+      if (grad.scalar_type() != metadata.grad_dtype().value()) {
+        std::stringstream ss;
+        ss << "invalid gradient at index " << i << " - expected dtype ";
+        ss << metadata.grad_dtype().value() << " but got " << grad.dtype();
+        TORCH_CHECK(false, format_error(ss.str()));
+      }
     }
     if (grad.layout() != metadata.layout()) {
       // TODO: Currently we only support (*, Sparse) combination for
@@ -200,11 +200,12 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
       const at::TensorOptions& options,
       c10::SymIntArrayRef shape,
       bool is_tensor_subclass,
-      bool is_nested) noexcept {
+      bool is_nested,
+      std::optional<at::ScalarType> grad_dtype) noexcept {
     uint32_t input_nr = input_metadata_.size();
     auto meta_shape = MetadataShape{std::in_place_type<SymIntSmallVec>, shape};
     input_metadata_.emplace_back(
-        options, meta_shape, is_tensor_subclass, is_nested);
+        options, meta_shape, is_tensor_subclass, is_nested, grad_dtype);
     return input_nr;
   }

@@ -29,12 +29,14 @@ InputMetadata::InputMetadata(
     const at::TensorOptions& options,
     MetadataShape input_shape,
     bool is_tensor_subclass,
-    bool is_nested)
+    bool is_nested,
+    std::optional<at::ScalarType> grad_dtype)
     : options_{options},
       shape_{std::move(input_shape)},
       is_tensor_subclass_{is_tensor_subclass},
       is_nested_{is_nested},
-      was_default_constructed_{false} {
+      was_default_constructed_{false},
+      grad_dtype_{grad_dtype} {
   auto device_ = options.device();
   stream_ = c10::impl::getDeviceGuardImpl(device_.type())->getStream(device_);
 }
@@ -44,7 +46,8 @@ InputMetadata::InputMetadata(const at::Tensor& t)
           t.options(),
           compute_variant_shape(t),
           is_python_dispatch(t),
-          t.is_nested()) {}
+          t.is_nested(),
+          t.grad_dtype()) {}

 at::Tensor InputMetadata::zeros_like() const {
   TORCH_CHECK(
@@ -38,7 +38,8 @@ struct TORCH_API InputMetadata {
       const at::TensorOptions& options,
       MetadataShape input_shape,
       bool is_tensor_subclass,
-      bool is_nested);
+      bool is_nested,
+      std::optional<at::ScalarType> grad_dtype);
   InputMetadata(const at::Tensor& t);

   const at::TensorOptions& options() const {
@@ -97,11 +98,23 @@ struct TORCH_API InputMetadata {
   // Danger: not thread safe, caller must protect with lock
   SymIntSmallVec& mutable_shape_as_dim_vector();

+  std::optional<at::ScalarType> grad_dtype() const {
+    TORCH_INTERNAL_ASSERT(!was_default_constructed_);
+    return grad_dtype_;
+  }
+
+  void set_grad_dtype(const std::optional<at::ScalarType>& grad_dtype) {
+    TORCH_INTERNAL_ASSERT(!was_default_constructed_);
+    grad_dtype_ = grad_dtype;
+  }
+
  private:
   at::Tensor shape_as_tensor() const;
   bool is_nestedness_same(const at::Tensor& grad) const;
   bool maybe_expandable_to(const at::Tensor& grad) const;

+  // NB: The engine does not use the dtype from the options, but rather the
+  // grad_dtype_ field to validate grad_output dtype.
   // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
   const at::TensorOptions options_;
   MetadataShape shape_;
@@ -109,5 +122,11 @@ struct TORCH_API InputMetadata {
   bool is_tensor_subclass_ = false;
   bool is_nested_ = false;
   bool was_default_constructed_ = true;
+
+  // The grad_dtype_ field is the dtype that the engine expects the grad to be.
+  // When nullopt, grad_dtype_ is allowed to be any dtype.
+  // This field is mutated if THPVariable_set_grad_dtype is called
+  // and the AccumulateGrad has already been created.
+  std::optional<at::ScalarType> grad_dtype_;
 };
 } // namespace torch::autograd
@@ -1319,13 +1319,18 @@ static int THPVariable_set_grad(
       self != (THPVariable*)py_grad, "can't assign Variable as its own grad");

   const auto& grad = THPVariable_Unpack(py_grad);
-  TORCH_CHECK(
-      var.dtype() == grad.dtype(),
-      "attempting to assign a gradient with dtype '",
-      grad.dtype(),
-      "' to a tensor with dtype '",
-      var.dtype(),
-      "'. Please ensure that the gradient and the tensor have the same dtype");
+  if (var.grad_dtype().has_value()) {
+    TORCH_CHECK(
+        grad.dtype() == var.grad_dtype().value(),
+        "attempting to assign a gradient with dtype '",
+        grad.dtype(),
+        "' to a tensor with grad_dtype '",
+        var.grad_dtype().value(),
+        "'. The gradient must match the tensor's grad_dtype (defaults to the tensor's "
+        "dtype). You can set the tensor's grad_dtype attribute with a specific dtype, or "
+        "None to allow any dtype. Set grad_dtype with caution. Diverging the dtypes of "
+        "a tensor and its gradient may break downstream systems that assume they match.");
+  }
   TORCH_CHECK(
       var.device().type() == grad.device().type(),
       "attempting to assign a gradient with device type '",
@@ -1334,8 +1339,11 @@ static int THPVariable_set_grad(
       var.device().type(),
       "'. Please ensure that the gradient and the tensor are on the same device");
   if (grad.layout() != kSparse) {
+    auto expected_options = var.options().dtype(
+        var.grad_dtype().has_value() ? var.grad_dtype().value()
+                                     : grad.scalar_type());
     TORCH_CHECK(
-        grad.options().type_equal(var.options()),
+        grad.options().type_equal(expected_options),
         "attempting to assign a gradient to a tensor that has data of a different type");
   }
   TORCH_CHECK(
@@ -1841,6 +1849,56 @@ static PyObject* THPVariable_get_nbytes(THPVariable* self, void* unused) {
   END_HANDLE_TH_ERRORS
 }

+static PyObject* THPVariable_get_grad_dtype(THPVariable* self, void* unused) {
+  HANDLE_TH_ERRORS
+  if (check_has_torch_function((PyObject*)self)) {
+    return handle_torch_function_getter(self, "grad_dtype");
+  }
+  const auto& var = THPVariable_Unpack(self);
+  TORCH_CHECK(
+      !var.grad_fn(), "grad_dtype can only be accessed on leaf tensors.");
+  if (!var.grad_dtype().has_value()) {
+    Py_RETURN_NONE;
+  } else {
+    return torch::autograd::utils::wrap(var.grad_dtype().value());
+  }
+  END_HANDLE_TH_ERRORS
+}
+
+static int THPVariable_set_grad_dtype(
+    THPVariable* self,
+    PyObject* obj,
+    void* unused) {
+  HANDLE_TH_ERRORS
+  if (check_has_torch_function((PyObject*)self)) {
+    return handle_torch_function_setter(self, "grad_dtype", obj);
+  }
+  const auto& var = THPVariable_Unpack(self);
+  TORCH_CHECK(
+      THPDtype_Check(obj) || obj == Py_None,
+      "grad_dtype must be a torch.dtype or None, but got ",
+      Py_TYPE(obj)->tp_name);
+  if (var.grad().defined() && obj != Py_None) {
+    auto new_dtype = reinterpret_cast<THPDtype*>(obj);
+    TORCH_CHECK(
+        var.grad().dtype() == new_dtype->scalar_type,
+        "Cannot set grad_dtype to '",
+        new_dtype->scalar_type,
+        "' because there is already a gradient with dtype '",
+        var.grad().dtype(),
+        "'. Please clear the gradient (.grad = None) before changing grad_dtype, "
+        "or ensure the new grad_dtype matches the existing gradient's dtype.");
+  }
+  std::optional<at::ScalarType> new_dtype;
+  if (obj != Py_None) {
+    auto* dtype = reinterpret_cast<THPDtype*>(obj);
+    new_dtype = dtype->scalar_type;
+  }
+  var.set_grad_dtype(new_dtype);
+  return 0;
+  END_HANDLE_TH_ERRORS_RET(-1)
+}
+
 static PyObject* THPVariable_get_itemsize(THPVariable* self, void* unused) {
   HANDLE_TH_ERRORS
   if (check_has_torch_function((PyObject*)self)) {
@@ -1999,6 +2057,11 @@ static struct PyGetSetDef THPVariable_properties[] = {
      (setter)THPVariable_set_imag,
      nullptr,
      nullptr},
+    {"grad_dtype",
+     (getter)THPVariable_get_grad_dtype,
+     (setter)THPVariable_set_grad_dtype,
+     nullptr,
+     nullptr},
     {nullptr}};

 static PyMappingMethods THPVariable_as_mapping = {
@@ -274,7 +274,7 @@ void set_grad_accumulator(
       std::move(grad_accumulator);
 }

-std::shared_ptr<Node> try_get_grad_accumulator(const Variable& self) {
+std::shared_ptr<Node> try_get_grad_accumulator(const at::TensorBase& self) {
   if (get_autograd_meta(self)) {
     return get_autograd_meta(self)->grad_accumulator_.lock();
   } else {
@@ -282,6 +282,10 @@ std::shared_ptr<Node> try_get_grad_accumulator(const Variable& self) {
   }
 }

+std::shared_ptr<Node> try_get_grad_accumulator(const Variable& self) {
+  return try_get_grad_accumulator(get_tensor_base(self));
+}
+
 std::shared_ptr<Node> grad_accumulator(const Variable& self) {
   auto autograd_meta = get_autograd_meta(self);
   if (!autograd_meta) {
@@ -713,7 +717,8 @@ const std::shared_ptr<torch::autograd::Node>& VariableHooks::grad_fn(
           self.sym_sizes(), // Note: sizes(), not base_.sizes(), is
                             // intentional
           self.unsafeGetTensorImpl()->is_python_dispatch(),
-          self.is_nested());
+          self.is_nested(),
+          self.grad_dtype());
       diff_view_meta->grad_fn_ = std::move(fn);
     }
     diff_view_meta->set_attr_version(current_version);
@@ -909,4 +914,45 @@ std::unique_ptr<ViewFunc> ChainedViewFunc::clone_and_set(
       second->clone_and_set(second_symints, second_tensors));
 }

+std::optional<c10::ScalarType> VariableHooks::grad_dtype(
+    const at::TensorBase& self) const {
+  if (auto* meta = impl::get_autograd_meta(self)) {
+    return meta->grad_dtype(self);
+  }
+  return self.scalar_type();
+}
+
+void VariableHooks::set_grad_dtype(
+    const at::TensorBase& self,
+    const std::optional<c10::ScalarType>& grad_dtype) const {
+  auto* meta = impl::materialize_autograd_meta(self);
+  meta->set_grad_dtype(grad_dtype, self);
+}
+
+std::optional<at::ScalarType> AutogradMeta::grad_dtype(
+    const at::TensorBase& self) const {
+  if (allow_grad_dtype_mismatch_) {
+    return std::nullopt;
+  } else if (grad_dtype_.has_value()) {
+    return grad_dtype_;
+  } else {
+    return std::optional<at::ScalarType>(self.scalar_type());
+  }
+}
+void AutogradMeta::set_grad_dtype(
+    const std::optional<at::ScalarType>& grad_dtype,
+    const at::TensorBase& self) {
+  TORCH_CHECK(!grad_fn_, "grad_dtype can only be set on leaf tensors.");
+  if (grad_dtype.has_value()) {
+    grad_dtype_ = grad_dtype;
+    allow_grad_dtype_mismatch_ = false;
+  } else {
+    allow_grad_dtype_mismatch_ = true;
+  }
+  auto grad_acc = impl::try_get_grad_accumulator(self);
+  if (grad_acc) {
+    grad_acc->mutable_input_metadata(0).set_grad_dtype(grad_dtype);
+  }
+}
+
 } // namespace torch::autograd
@@ -128,6 +128,7 @@ TORCH_API void set_grad_accumulator(
 /// if it still exists. If the gradient accumulator function has been
 /// destroyed, returns a `nullptr`.
 TORCH_API std::shared_ptr<Node> try_get_grad_accumulator(const Variable&);
+TORCH_API std::shared_ptr<Node> try_get_grad_accumulator(const at::TensorBase&);

 /// Gets the gradient accumulator of the `Variable` if it has one, or else
 /// create one on the fly and return it.
@@ -253,6 +254,13 @@ struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface {
   // correctly when this variable is passed to another function.
   uint32_t output_nr_;

+  // The dtype of the grad field; when nullopt, defaults to tensor's dtype.
+  std::optional<at::ScalarType> grad_dtype_;
+
+  // When true, allows gradient dtype to be different from tensor dtype,
+  // bypassing dtype casting and validation in the autograd engine.
+  bool allow_grad_dtype_mismatch_{false};
+
   // Mutex to ensure that concurrent read operations that modify internal
   // state are still thread-safe. Used by grad_fn(), grad_accumulator(),
   // fw_grad() and set_fw_grad()
@@ -293,6 +301,12 @@ struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface {
       uint64_t level,
       bool is_inplace_op) override;

+  std::optional<at::ScalarType> grad_dtype(const at::TensorBase& self) const;
+
+  void set_grad_dtype(
+      const std::optional<at::ScalarType>& grad_dtype,
+      const at::TensorBase& self);
+
   AutogradMeta(
       at::TensorImpl* self_impl = nullptr,
       bool requires_grad = false,
@@ -940,6 +954,11 @@ struct VariableHooks final : at::impl::VariableHooksInterface {
       const c10::OperatorHandle& op,
       c10::DispatchKeySet dispatch_keys,
       torch::jit::Stack* stack) const override;
+  std::optional<c10::ScalarType> grad_dtype(
+      const at::TensorBase&) const override;
+  void set_grad_dtype(
+      const at::TensorBase&,
+      const std::optional<c10::ScalarType>&) const override;
 };

 namespace utils {
@@ -1458,24 +1458,30 @@ struct IValuePacker<InputMetadata> {
     auto tuple = std::make_tuple(
         pack_TensorOptions(t.options()),
         t.shape_as_dim_vector().vec(),
-        t.is_tensor_subclass());
+        t.is_tensor_subclass(),
+        t.grad_dtype());
     return tuple;
   }
   static InputMetadata unpack(const at::IValue& t) {
-    auto tuple = t.to<
-        std::tuple<packed_tensoroptions_t, std::vector<at::SymInt>, bool>>();
+    auto tuple = t.to<std::tuple<
+        packed_tensoroptions_t,
+        std::vector<at::SymInt>,
+        bool,
+        std::optional<c10::ScalarType>>>();

     return InputMetadata(
         unpack_TensorOptions(std::get<0>(tuple)),
         SymIntSmallVec(std::get<1>(tuple)),
         std::get<2>(tuple),
-        false);
+        false,
+        std::get<3>(tuple));
   }
   static at::TypePtr packed_type() {
     return at::TupleType::create(
         {IValuePacker<at::TensorOptions>::packed_type(),
          IValuePacker<std::vector<at::SymInt>>::packed_type(),
-         at::BoolType::get()});
+         at::BoolType::get(),
+         IValuePacker<std::optional<at::ScalarType>>::packed_type()});
   }
 };

@@ -1351,6 +1351,7 @@ def get_testing_overrides() -> dict[Callable, Callable]:
         Tensor._grad.__get__: lambda self: -1,
         Tensor._grad_fn.__get__: lambda self: -1,
         Tensor.grad_fn.__get__: lambda self: -1,
+        Tensor.grad_dtype.__get__: lambda self: -1,
         Tensor._version.__get__: lambda self: -1,
         Tensor._autocast_to_reduced_precision: lambda self, cuda_enabled, cpu_enabled, cuda_dtype, cpu_dtype: -1,
         Tensor._autocast_to_full_precision: lambda self, cuda_enabled, cpu_enabled: -1,