Revert "Support setting grad_dtype on leaf tensors (#162815)"

This reverts commit dca73982c53e9f99f96246b5d9ed9bab83c7423f.

Reverted https://github.com/pytorch/pytorch/pull/162815 on behalf of https://github.com/yangw-dev because it breaks internal test D83850533; see more details below ([comment](https://github.com/pytorch/pytorch/pull/162815#issuecomment-3367498501))
PyTorch MergeBot committed 2025-10-03 23:14:28 +00:00
parent fac6f20ae3
commit 3ddf2018d0
19 changed files with 41 additions and 375 deletions
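
For context, a minimal sketch of the leaf-tensor grad_dtype API that this commit reverts, reconstructed from the docstring and test removed in this diff; the reverted calls are left as comments so the snippet also runs on a build without #162815:

import torch

x = torch.tensor([1.0, 2.0], requires_grad=True)
# Before the revert (per the removed docstring/test):
#   x.grad_dtype                  # -> torch.float32, defaults to x.dtype
#   x.grad_dtype = torch.float16  # autograd casts incoming gradients to float16
#   x.grad_dtype = None           # any gradient dtype is accepted as-is
x.sum().backward()
print(x.grad.dtype)  # torch.float32 -- after the revert, .grad matches x.dtype here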

View File

@ -173,12 +173,4 @@ unsigned TensorBase::_register_hook(std::function<TensorBase(const TensorBase&)>
return impl::GetVariableHooks()->_register_hook(*this, std::move(hook));
}
std::optional<ScalarType> TensorBase::grad_dtype() const {
return impl::GetVariableHooks()->grad_dtype(*this);
}
void TensorBase::set_grad_dtype(const std::optional<ScalarType>& grad_dtype) const {
return impl::GetVariableHooks()->set_grad_dtype(*this, grad_dtype);
}
} // namespace at

View File

@ -930,10 +930,6 @@ public:
const TensorBase& requires_grad_(bool _requires_grad=true) const;
std::optional<ScalarType> grad_dtype() const;
void set_grad_dtype(const std::optional<ScalarType>& grad_dtype) const;
// View Variables
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

View File

@ -68,8 +68,6 @@ struct TORCH_API VariableHooksInterface {
const c10::OperatorHandle& op,
c10::DispatchKeySet dispatch_keys,
torch::jit::Stack* stack) const = 0;
virtual std::optional<c10::ScalarType> grad_dtype(const TensorBase&) const = 0;
virtual void set_grad_dtype(const TensorBase&, const std::optional<c10::ScalarType>&) const = 0;
};
TORCH_API void SetVariableHooks(VariableHooksInterface* hooks);

View File

@ -140,7 +140,7 @@ class GraphModule(torch.nn.Module):
size: "Sym(s21)" = l_sizes_0_.size(1); l_sizes_0_ = None
validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [size], False, 6)]); getitem = size = None
validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [size], False)]); getitem = size = None
getitem_9: "f32[s21]" = validate_outputs[0]; validate_outputs = None
call_aot_bwd_prologue = torch__dynamo_compiled_autograd_call_aot_bwd_prologue((), [], getitem_9); getitem_9 = None
@ -171,7 +171,7 @@ class GraphModule(torch.nn.Module):
size: "Sym(s21)" = l_sizes_0_.size(1); l_sizes_0_ = None
validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [size], False, 6)]); getitem = size = None
validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [size], False)]); getitem = size = None
getitem_9: "f32[s21]" = validate_outputs[0]; validate_outputs = None
call_aot_bwd_prologue = torch__dynamo_compiled_autograd_call_aot_bwd_prologue((), [], getitem_9); getitem_9 = None
@ -255,7 +255,7 @@ class GraphModule(torch.nn.Module):
size: "Sym(s21)" = l_sizes_0_.size(1); l_sizes_0_ = None
validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [size], False, 6)]); getitem = size = None
validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [size], False)]); getitem = size = None
getitem_9: "f32[s21]" = validate_outputs[0]; validate_outputs = None
call_aot_bwd_prologue = torch__dynamo_compiled_autograd_call_aot_bwd_prologue((), [], getitem_9); getitem_9 = None

View File

@ -3604,12 +3604,12 @@ class CompiledAutograd0(torch.nn.Module):
unwrap_maybe_dynamic_int_18 = torch__dynamo_external_utils_unwrap_maybe_dynamic_int(getitem_23); getitem_23 = None
unwrap_maybe_dynamic_int_19 = torch__dynamo_external_utils_unwrap_maybe_dynamic_int(getitem_24); getitem_24 = None
validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [], True, 6)]); getitem = None
validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [], True)]); getitem = None
getitem_25 = validate_outputs[0]; validate_outputs = None
sum_backward0 = torch__dynamo_compiled_autograd_ops_SumBackward0([getitem_25], [True], [unwrap_maybe_dynamic_int, unwrap_maybe_dynamic_int_1]); getitem_25 = unwrap_maybe_dynamic_int = unwrap_maybe_dynamic_int_1 = None
getitem_26 = sum_backward0[0]; sum_backward0 = None
validate_outputs_1 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_26], [((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_2, unwrap_maybe_dynamic_int_3], True, 6)]); getitem_26 = unwrap_maybe_dynamic_int_2 = unwrap_maybe_dynamic_int_3 = None
validate_outputs_1 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_26], [((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_2, unwrap_maybe_dynamic_int_3], True)]); getitem_26 = unwrap_maybe_dynamic_int_2 = unwrap_maybe_dynamic_int_3 = None
getitem_27 = validate_outputs_1[0]; validate_outputs_1 = None
getitem_28 = hooks[0]; getitem_28 = None
@ -3631,7 +3631,7 @@ class CompiledAutograd0(torch.nn.Module):
call_backward = torch__dynamo_external_utils_call_backward(getitem_33, (), make_subclass); getitem_33 = make_subclass = None
getitem_36 = call_backward[0]
getitem_37 = call_backward[1]; call_backward = None
validate_outputs_2 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_36, getitem_37], [((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_16, unwrap_maybe_dynamic_int_17], False, 6), ((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_18, unwrap_maybe_dynamic_int_19], False, 6)]); getitem_36 = getitem_37 = unwrap_maybe_dynamic_int_16 = unwrap_maybe_dynamic_int_17 = unwrap_maybe_dynamic_int_18 = unwrap_maybe_dynamic_int_19 = None
validate_outputs_2 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_36, getitem_37], [((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_16, unwrap_maybe_dynamic_int_17], False), ((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_18, unwrap_maybe_dynamic_int_19], False)]); getitem_36 = getitem_37 = unwrap_maybe_dynamic_int_16 = unwrap_maybe_dynamic_int_17 = unwrap_maybe_dynamic_int_18 = unwrap_maybe_dynamic_int_19 = None
getitem_39 = validate_outputs_2[0]
call_accumulate_grad_1 = torch__dynamo_external_utils_call_accumulate_grad(getitem_4, getitem_39, False); getitem_4 = getitem_39 = call_accumulate_grad_1 = None
@ -3866,12 +3866,12 @@ class CompiledAutograd0(torch.nn.Module):
unwrap_maybe_dynamic_int_10 = torch__dynamo_external_utils_unwrap_maybe_dynamic_int(getitem_12); getitem_12 = None
unwrap_maybe_dynamic_int_11 = torch__dynamo_external_utils_unwrap_maybe_dynamic_int(getitem_13); getitem_13 = None
validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [], False, 6)]); getitem = None
validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [], False)]); getitem = None
getitem_14 = validate_outputs[0]; validate_outputs = None
sum_backward0 = torch__dynamo_compiled_autograd_ops_SumBackward0([getitem_14], [True], [unwrap_maybe_dynamic_int, unwrap_maybe_dynamic_int_1]); getitem_14 = unwrap_maybe_dynamic_int = unwrap_maybe_dynamic_int_1 = None
getitem_15 = sum_backward0[0]; sum_backward0 = None
validate_outputs_1 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_15], [((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_2, unwrap_maybe_dynamic_int_3], False, 6)]); getitem_15 = unwrap_maybe_dynamic_int_2 = unwrap_maybe_dynamic_int_3 = None
validate_outputs_1 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_15], [((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_2, unwrap_maybe_dynamic_int_3], False)]); getitem_15 = unwrap_maybe_dynamic_int_2 = unwrap_maybe_dynamic_int_3 = None
getitem_16 = validate_outputs_1[0]; validate_outputs_1 = None
getitem_17 = hooks[0]
@ -3883,7 +3883,7 @@ class CompiledAutograd0(torch.nn.Module):
mul_backward0 = torch__dynamo_compiled_autograd_ops_MulBackward0([getitem_16], [True, True], call_hook, 6, call_hook_1, 6); getitem_16 = call_hook = call_hook_1 = None
getitem_21 = mul_backward0[0]
getitem_22 = mul_backward0[1]; mul_backward0 = None
validate_outputs_2 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_21, getitem_22], [((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_4, unwrap_maybe_dynamic_int_5], False, 6), ((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_6, unwrap_maybe_dynamic_int_7], False, 6)]); getitem_21 = getitem_22 = unwrap_maybe_dynamic_int_4 = unwrap_maybe_dynamic_int_5 = unwrap_maybe_dynamic_int_6 = unwrap_maybe_dynamic_int_7 = None
validate_outputs_2 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_21, getitem_22], [((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_4, unwrap_maybe_dynamic_int_5], False), ((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_6, unwrap_maybe_dynamic_int_7], False)]); getitem_21 = getitem_22 = unwrap_maybe_dynamic_int_4 = unwrap_maybe_dynamic_int_5 = unwrap_maybe_dynamic_int_6 = unwrap_maybe_dynamic_int_7 = None
getitem_23 = validate_outputs_2[0]
getitem_24 = validate_outputs_2[1]; validate_outputs_2 = None
@ -3892,7 +3892,7 @@ class CompiledAutograd0(torch.nn.Module):
call_hook_2 = torch__dynamo_external_utils_call_hook(getitem_25, getitem_26, hook_type = 'unpack_hook'); getitem_25 = getitem_26 = None
cos_backward0 = torch__dynamo_compiled_autograd_ops_CosBackward0([getitem_24], [True], call_hook_2); getitem_24 = call_hook_2 = None
getitem_27 = cos_backward0[0]; cos_backward0 = None
validate_outputs_3 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_27], [((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_8, unwrap_maybe_dynamic_int_9], False, 6)]); getitem_27 = unwrap_maybe_dynamic_int_8 = unwrap_maybe_dynamic_int_9 = None
validate_outputs_3 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_27], [((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_8, unwrap_maybe_dynamic_int_9], False)]); getitem_27 = unwrap_maybe_dynamic_int_8 = unwrap_maybe_dynamic_int_9 = None
getitem_28 = validate_outputs_3[0]; validate_outputs_3 = None
add = torch.add(getitem_23, getitem_28); getitem_23 = getitem_28 = None
@ -3901,7 +3901,7 @@ class CompiledAutograd0(torch.nn.Module):
call_hook_3 = torch__dynamo_external_utils_call_hook(getitem_29, getitem_30, hook_type = 'unpack_hook'); getitem_29 = getitem_30 = None
sin_backward0 = torch__dynamo_compiled_autograd_ops_SinBackward0([add], [True], call_hook_3); add = call_hook_3 = None
getitem_31 = sin_backward0[0]; sin_backward0 = None
validate_outputs_4 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_31], [((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_10, unwrap_maybe_dynamic_int_11], False, 6)]); getitem_31 = unwrap_maybe_dynamic_int_10 = unwrap_maybe_dynamic_int_11 = None
validate_outputs_4 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_31], [((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_10, unwrap_maybe_dynamic_int_11], False)]); getitem_31 = unwrap_maybe_dynamic_int_10 = unwrap_maybe_dynamic_int_11 = None
getitem_32 = validate_outputs_4[0]; validate_outputs_4 = None
call_accumulate_grad = torch__dynamo_external_utils_call_accumulate_grad(getitem_1, getitem_32, False); getitem_1 = getitem_32 = call_accumulate_grad = None
@ -5266,7 +5266,6 @@ xfail_by_backend = {
"test_dropout_inductor", # functionalize_rng_ops not yet supported
"test_function_with_kwargs", # functionalize_rng_ops not yet supported
"test_module", # functionalize_rng_ops not yet supported
"test_grad_dtype", # AttributeError: args / Float did not match Double
},
"eager": { # will be run without torch.compiling the CA graph
"test_setup_context_when_forward_has_default_args", # autograd.Function with class methods

View File

@ -326,8 +326,6 @@ class TestTensorBuiltins(JitTestCase):
# This has a longer implementation, maybe not worth copying to
# TorchScript if named tensors don't work there anyways
"names",
# We don't plan to support grad_dtype in TorchScript
"grad_dtype",
}
for p in properties:

View File

@ -3694,130 +3694,6 @@ class TestAutograd(TestCase):
def test_sparse_gather_both_scalar(self):
self._test_sparse_gather((), (), 0)
@skipIfTorchDynamo("grad_dtype not supported in compile")
def test_grad_dtype(self):
leaf = torch.tensor([1.0, 2.0], requires_grad=True)
# Default to tensor's dtype
self.assertEqual(leaf.grad_dtype, torch.float32)
leaf.grad_dtype = torch.float16
self.assertEqual(leaf.grad_dtype, torch.float16)
leaf.grad_dtype = None # Allow any dtype
self.assertIsNone(leaf.grad_dtype)
# get/set grad_dtype is only allowed on leaf tensors
non_leaf = leaf * 2
self.assertFalse(non_leaf.is_leaf)
with self.assertRaisesRegex(
RuntimeError, "grad_dtype can only be accessed on leaf tensors"
):
_ = non_leaf.grad_dtype
with self.assertRaisesRegex(
RuntimeError, "grad_dtype can only be set on leaf tensors"
):
non_leaf.grad_dtype = torch.float16
# Manual setting
x = torch.tensor([1.0, 2.0], requires_grad=True)
grad_match = torch.tensor([1.0, 1.0])
x.grad = grad_match
self.assertEqual(x.grad.dtype, torch.float32)
x.grad = None
x.grad_dtype = torch.float16
grad_mismatch = torch.tensor([1.0, 1.0])
with self.assertRaisesRegex(
RuntimeError,
"attempting to assign a gradient with dtype.*float.*to a tensor with grad_dtype.*Half",
):
x.grad = grad_mismatch
# When grad_dtype is None, any dtype is allowed
x.grad = None
x.grad_dtype = None
grad_any = torch.tensor([1.0, 1.0], dtype=torch.float64)
x.grad = grad_any
self.assertEqual(x.grad.dtype, torch.float64)
# Incoming gradient case
class MismatchedGradientFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, inp):
return inp * 2
@staticmethod
def backward(ctx, grad_output):
return grad_output.to(torch.float64)
d = torch.tensor([1.0, 2.0], requires_grad=True)
output = MismatchedGradientFunction.apply(d)
loss = output.sum()
loss.backward()
# Default behavior is to cast to tensor dtype
self.assertEqual(d.grad.dtype, torch.float32)
self.assertTrue(torch.allclose(d.grad, torch.tensor([1.0, 1.0])))
e = torch.tensor([3.0, 4.0], requires_grad=True)
e.grad_dtype = None
output_e = MismatchedGradientFunction.apply(e)
loss_e = output_e.sum()
loss_e.backward()
# No casting is done if set to None.
self.assertTrue(
torch.allclose(e.grad, torch.tensor([1.0, 1.0], dtype=torch.float64))
)
f = torch.tensor([5.0, 6.0], requires_grad=True)
f.grad_dtype = torch.float16 # Expect float16 gradients
output_f = MismatchedGradientFunction.apply(f)
loss_f = output_f.sum()
loss_f.backward()
self.assertTrue(
torch.allclose(f.grad, torch.tensor([1.0, 1.0], dtype=torch.float16))
)
# Setting grad_dtype when gradient already exists
g = torch.tensor([1.0, 2.0], requires_grad=True)
g.grad = torch.tensor([1.0, 1.0])
g.grad_dtype = torch.float32
self.assertEqual(g.grad_dtype, torch.float32)
with self.assertRaisesRegex(
RuntimeError, "Cannot set grad_dtype.*because there is already a gradient"
):
g.grad_dtype = torch.float16
g.grad_dtype = None
self.assertIsNone(g.grad_dtype)
g.grad = None
g.grad_dtype = torch.float16
self.assertEqual(g.grad_dtype, torch.float16)
# Test the case where there is an existing accumulate grad
h = torch.tensor([1.0, 2.0], requires_grad=True)
_ = h.clone()
h.grad_dtype = None
output = MismatchedGradientFunction.apply(h)
output.sum().backward()
self.assertEqual(h.grad.dtype, torch.float64)
# Mixed accumulation cases
k = torch.tensor([1.0, 2.0], requires_grad=True)
k.grad_dtype = None
y = k * 2
y.sum().backward()
k.grad = k.grad.to(torch.bfloat16)
y2 = k * 3
# Doesn't type promote to float32, always coerce to current .grad's dtype.
# This is because the accumulation is done in-place on the existing grad.
self.assertEqual(k.grad.dtype, torch.bfloat16)
l = torch.tensor([3.0, 4.0], requires_grad=True, dtype=torch.bfloat16)
l.grad_dtype = None
z = l * 2
z.sum().backward()
l.grad = l.grad.to(torch.float32)
z2 = l * 3
z2.sum().backward()
self.assertEqual(l.grad.dtype, torch.float32)
def test_gc_in_destructor(self):
"""
Previously, if a Function destructor triggered a garbage collection,

View File

@ -1917,7 +1917,6 @@ class TensorBase(metaclass=_TensorMeta):
names: list[str]
device: _device
dtype: _dtype
grad_dtype: _dtype | None
layout: _layout
real: Tensor
imag: Tensor

View File

@ -6624,36 +6624,6 @@ The attribute will then contain the gradients computed and future calls to
""",
)
add_docstr_all(
"grad_dtype",
r"""
The allowed dtype of :attr:`grad` for this tensor.
:attr:`grad_dtype` can be set to a specific dtype or ``None``. By default,
``t.grad_dtype == t.dtype``. When not ``None``, the autograd engine casts
incoming gradients to this dtype. This attribute is only accessible and
settable for leaf tensors.
.. warning::
Use with caution. Diverging the dtypes of a tensor and its gradient may
break downstream systems that assume they match.
Example::
>>> x = torch.tensor([1.0, 2.0], requires_grad=True)
>>> x.grad_dtype
torch.float32
>>> x.grad_dtype = torch.float16
>>> x.grad_dtype
torch.float16
>>> # Allow any gradient dtype
>>> x.grad_dtype = None
>>> x.grad_dtype
""",
)
add_docstr_all(
"retain_grad",
r"""

View File

@ -1,7 +1,5 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <c10/util/irange.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/input_metadata.h>
#include <torch/csrc/autograd/variable.h>
#ifndef AT_PER_OPERATOR_HEADERS
@ -316,30 +314,4 @@ const Variable& AutogradMeta::fw_grad(
return direct_fw_grad;
}
std::optional<at::ScalarType> AutogradMeta::grad_dtype(
const at::TensorBase& self) const {
if (allow_grad_dtype_mismatch_) {
return std::nullopt;
} else if (grad_dtype_.has_value()) {
return grad_dtype_;
} else {
return std::optional<at::ScalarType>(self.scalar_type());
}
}
void AutogradMeta::set_grad_dtype(
const std::optional<at::ScalarType>& grad_dtype,
const at::TensorBase& self) {
TORCH_CHECK(!grad_fn_, "grad_dtype can only be set on leaf tensors.");
if (grad_dtype.has_value()) {
grad_dtype_ = grad_dtype;
allow_grad_dtype_mismatch_ = false;
} else {
allow_grad_dtype_mismatch_ = true;
}
auto grad_acc = impl::try_get_grad_accumulator(self);
if (grad_acc) {
grad_acc->mutable_input_metadata(0).set_grad_dtype(grad_dtype);
}
}
} // namespace torch::autograd

View File

@ -948,17 +948,15 @@ static void validate_outputs_impl(
TORCH_CHECK(
isFloatingType(grad.scalar_type()) ||
(input_is_complex == grad_is_complex));
if (metadata.grad_dtype().has_value()) {
if (grad.scalar_type() != metadata.grad_dtype().value()) {
grad = grad.to(metadata.grad_dtype().value());
}
if (grad.scalar_type() != metadata.grad_dtype().value()) {
std::stringstream ss;
ss << "invalid gradient at index " << i << " - expected dtype ";
ss << metadata.grad_dtype().value() << " but got " << grad.dtype();
TORCH_CHECK(false, format_error(ss.str()));
}
if (c10::typeMetaToScalarType(metadata.options().dtype()) !=
grad.scalar_type()) {
grad = grad.to(c10::typeMetaToScalarType(metadata.options().dtype()));
}
if (grad.dtype() != metadata.dtype()) {
std::stringstream ss;
ss << "invalid gradient at index " << i << " - expected dtype ";
ss << metadata.dtype() << " but got " << grad.dtype();
TORCH_CHECK(false, format_error(ss.str()));
}
if (grad.layout() != metadata.layout()) {
// TODO: Currently we only support (*, Sparse) combination for
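
Net effect of the engine.cpp hunk above, sketched under the assumption of a build with this revert applied: validate_outputs goes back to coercing an incoming gradient to the dtype recorded in the input metadata (the tensor's own dtype) instead of consulting a per-tensor grad_dtype. The behavior mirrors the default case in the removed test_grad_dtype:

import torch

class MismatchedGradientFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp):
        return inp * 2

    @staticmethod
    def backward(ctx, grad_output):
        # Deliberately return a gradient with a different dtype than the input.
        return grad_output.to(torch.float64)

d = torch.tensor([1.0, 2.0], requires_grad=True)  # float32 leaf
MismatchedGradientFunction.apply(d).sum().backward()
print(d.grad.dtype)  # torch.float32 -- the engine casts the float64 grad back to the tensor's dtype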

View File

@ -200,12 +200,11 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
const at::TensorOptions& options,
c10::SymIntArrayRef shape,
bool is_tensor_subclass,
bool is_nested,
std::optional<at::ScalarType> grad_dtype) noexcept {
bool is_nested) noexcept {
uint32_t input_nr = input_metadata_.size();
auto meta_shape = MetadataShape{std::in_place_type<SymIntSmallVec>, shape};
input_metadata_.emplace_back(
options, meta_shape, is_tensor_subclass, is_nested, grad_dtype);
options, meta_shape, is_tensor_subclass, is_nested);
return input_nr;
}

View File

@ -29,14 +29,12 @@ InputMetadata::InputMetadata(
const at::TensorOptions& options,
MetadataShape input_shape,
bool is_tensor_subclass,
bool is_nested,
std::optional<at::ScalarType> grad_dtype)
bool is_nested)
: options_{options},
shape_{std::move(input_shape)},
is_tensor_subclass_{is_tensor_subclass},
is_nested_{is_nested},
was_default_constructed_{false},
grad_dtype_{grad_dtype} {
was_default_constructed_{false} {
auto device_ = options.device();
stream_ = c10::impl::getDeviceGuardImpl(device_.type())->getStream(device_);
}
@ -46,8 +44,7 @@ InputMetadata::InputMetadata(const at::Tensor& t)
t.options(),
compute_variant_shape(t),
is_python_dispatch(t),
t.is_nested(),
t.grad_dtype()) {}
t.is_nested()) {}
at::Tensor InputMetadata::zeros_like() const {
TORCH_CHECK(

View File

@ -38,8 +38,7 @@ struct TORCH_API InputMetadata {
const at::TensorOptions& options,
MetadataShape input_shape,
bool is_tensor_subclass,
bool is_nested,
std::optional<at::ScalarType> grad_dtype);
bool is_nested);
InputMetadata(const at::Tensor& t);
const at::TensorOptions& options() const {
@ -98,23 +97,11 @@ struct TORCH_API InputMetadata {
// Danger: not thread safe, caller must protect with lock
SymIntSmallVec& mutable_shape_as_dim_vector();
std::optional<at::ScalarType> grad_dtype() const {
TORCH_INTERNAL_ASSERT(!was_default_constructed_);
return grad_dtype_;
}
void set_grad_dtype(const std::optional<at::ScalarType>& grad_dtype) {
TORCH_INTERNAL_ASSERT(!was_default_constructed_);
grad_dtype_ = grad_dtype;
}
private:
at::Tensor shape_as_tensor() const;
bool is_nestedness_same(const at::Tensor& grad) const;
bool maybe_expandable_to(const at::Tensor& grad) const;
// NB: The engine does not use the dtype from the options, but rather the
// grad_dtype_ field to validate grad_output dtype.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
const at::TensorOptions options_;
MetadataShape shape_;
@ -122,11 +109,5 @@ struct TORCH_API InputMetadata {
bool is_tensor_subclass_ = false;
bool is_nested_ = false;
bool was_default_constructed_ = true;
// The grad_dtype_ field is the dtype that the engine expects the grad to be.
// When nullopt, grad_dtype_ is allowed to be any dtype.
// This field is mutated if THPVariable_set_grad_dtype is called
// and the AccumulateGrad has already been created.
std::optional<at::ScalarType> grad_dtype_;
};
} // namespace torch::autograd

View File

@ -1317,18 +1317,13 @@ static int THPVariable_set_grad(
self != (THPVariable*)py_grad, "can't assign Variable as its own grad");
const auto& grad = THPVariable_Unpack(py_grad);
if (var.grad_dtype().has_value()) {
TORCH_CHECK(
grad.dtype() == var.grad_dtype().value(),
"attempting to assign a gradient with dtype '",
grad.dtype(),
"' to a tensor with grad_dtype '",
var.grad_dtype().value(),
"'. The gradient must match the tensor's grad_dtype (defaults to the tensor's "
"dtype). You can set the tensor's grad_dtype attribute with a specific dtype, or "
"None to allow any dtype. Set grad_dtype with caution. Diverging the dtypes of "
"a tensor and its gradient may break downstream systems that assume they match.");
}
TORCH_CHECK(
var.dtype() == grad.dtype(),
"attempting to assign a gradient with dtype '",
grad.dtype(),
"' to a tensor with dtype '",
var.dtype(),
"'. Please ensure that the gradient and the tensor have the same dtype");
TORCH_CHECK(
var.device().type() == grad.device().type(),
"attempting to assign a gradient with device type '",
@ -1337,11 +1332,8 @@ static int THPVariable_set_grad(
var.device().type(),
"'. Please ensure that the gradient and the tensor are on the same device");
if (grad.layout() != kSparse) {
auto expected_options = var.options().dtype(
var.grad_dtype().has_value() ? var.grad_dtype().value()
: grad.scalar_type());
TORCH_CHECK(
grad.options().type_equal(expected_options),
grad.options().type_equal(var.options()),
"attempting to assign a gradient to a tensor that has data of a different type");
}
TORCH_CHECK(
@ -1847,56 +1839,6 @@ static PyObject* THPVariable_get_nbytes(THPVariable* self, void* unused) {
END_HANDLE_TH_ERRORS
}
static PyObject* THPVariable_get_grad_dtype(THPVariable* self, void* unused) {
HANDLE_TH_ERRORS
if (check_has_torch_function((PyObject*)self)) {
return handle_torch_function_getter(self, "grad_dtype");
}
const auto& var = THPVariable_Unpack(self);
TORCH_CHECK(
!var.grad_fn(), "grad_dtype can only be accessed on leaf tensors.");
if (!var.grad_dtype().has_value()) {
Py_RETURN_NONE;
} else {
return torch::autograd::utils::wrap(var.grad_dtype().value());
}
END_HANDLE_TH_ERRORS
}
static int THPVariable_set_grad_dtype(
THPVariable* self,
PyObject* obj,
void* unused) {
HANDLE_TH_ERRORS
if (check_has_torch_function((PyObject*)self)) {
return handle_torch_function_setter(self, "grad_dtype", obj);
}
const auto& var = THPVariable_Unpack(self);
TORCH_CHECK(
THPDtype_Check(obj) || obj == Py_None,
"grad_dtype must be a torch.dtype or None, but got ",
Py_TYPE(obj)->tp_name);
if (var.grad().defined() && obj != Py_None) {
auto new_dtype = reinterpret_cast<THPDtype*>(obj);
TORCH_CHECK(
var.grad().dtype() == new_dtype->scalar_type,
"Cannot set grad_dtype to '",
new_dtype->scalar_type,
"' because there is already a gradient with dtype '",
var.grad().dtype(),
"'. Please clear the gradient (.grad = None) before changing grad_dtype, "
"or ensure the new grad_dtype matches the existing gradient's dtype.");
}
std::optional<at::ScalarType> new_dtype;
if (obj != Py_None) {
auto* dtype = reinterpret_cast<THPDtype*>(obj);
new_dtype = dtype->scalar_type;
}
var.set_grad_dtype(new_dtype);
return 0;
END_HANDLE_TH_ERRORS_RET(-1)
}
static PyObject* THPVariable_get_itemsize(THPVariable* self, void* unused) {
HANDLE_TH_ERRORS
if (check_has_torch_function((PyObject*)self)) {
@ -2055,11 +1997,6 @@ static struct PyGetSetDef THPVariable_properties[] = {
(setter)THPVariable_set_imag,
nullptr,
nullptr},
{"grad_dtype",
(getter)THPVariable_get_grad_dtype,
(setter)THPVariable_set_grad_dtype,
nullptr,
nullptr},
{nullptr}};
static PyMappingMethods THPVariable_as_mapping = {
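
With the restored check in THPVariable_set_grad earlier in this file, manually assigning .grad with a mismatched dtype is rejected again; a small sketch of the expected post-revert behavior:

import torch

x = torch.tensor([1.0, 2.0], requires_grad=True)
x.grad = torch.ones(2)  # accepted: gradient dtype matches x.dtype
try:
    x.grad = torch.ones(2, dtype=torch.float64)
except RuntimeError as err:
    # raises: "attempting to assign a gradient with dtype ... Please ensure
    # that the gradient and the tensor have the same dtype"
    print(type(err).__name__)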

View File

@ -274,7 +274,7 @@ void set_grad_accumulator(
std::move(grad_accumulator);
}
std::shared_ptr<Node> try_get_grad_accumulator(const at::TensorBase& self) {
std::shared_ptr<Node> try_get_grad_accumulator(const Variable& self) {
if (get_autograd_meta(self)) {
return get_autograd_meta(self)->grad_accumulator_.lock();
} else {
@ -282,10 +282,6 @@ std::shared_ptr<Node> try_get_grad_accumulator(const at::TensorBase& self) {
}
}
std::shared_ptr<Node> try_get_grad_accumulator(const Variable& self) {
return try_get_grad_accumulator(get_tensor_base(self));
}
std::shared_ptr<Node> grad_accumulator(const Variable& self) {
auto autograd_meta = get_autograd_meta(self);
if (!autograd_meta) {
@ -717,8 +713,7 @@ const std::shared_ptr<torch::autograd::Node>& VariableHooks::grad_fn(
self.sym_sizes(), // Note: sizes(), not base_.sizes(), is
// intentional
self.unsafeGetTensorImpl()->is_python_dispatch(),
self.is_nested(),
self.grad_dtype());
self.is_nested());
diff_view_meta->grad_fn_ = std::move(fn);
}
diff_view_meta->set_attr_version(current_version);
@ -914,19 +909,4 @@ std::unique_ptr<ViewFunc> ChainedViewFunc::clone_and_set(
second->clone_and_set(second_symints, second_tensors));
}
std::optional<c10::ScalarType> VariableHooks::grad_dtype(
const at::TensorBase& self) const {
if (auto* meta = impl::get_autograd_meta(self)) {
return meta->grad_dtype(self);
}
return self.scalar_type();
}
void VariableHooks::set_grad_dtype(
const at::TensorBase& self,
const std::optional<c10::ScalarType>& grad_dtype) const {
auto* meta = impl::materialize_autograd_meta(self);
meta->set_grad_dtype(grad_dtype, self);
}
} // namespace torch::autograd

View File

@ -128,7 +128,6 @@ TORCH_API void set_grad_accumulator(
/// if it still exists. If the gradient accumulator function has been
/// destroyed, returns a `nullptr`.
TORCH_API std::shared_ptr<Node> try_get_grad_accumulator(const Variable&);
TORCH_API std::shared_ptr<Node> try_get_grad_accumulator(const at::TensorBase&);
/// Gets the gradient accumulator of the `Variable` if it has one, or else
/// create one on the fly and return it.
@ -254,13 +253,6 @@ struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface {
// correctly when this variable is passed to another function.
uint32_t output_nr_;
// The dtype of the grad field; when nullopt, defaults to tensor's dtype.
std::optional<at::ScalarType> grad_dtype_;
// When true, allows gradient dtype to be different from tensor dtype,
// bypassing dtype casting and validation in the autograd engine.
bool allow_grad_dtype_mismatch_{false};
// Mutex to ensure that concurrent read operations that modify internal
// state are still thread-safe. Used by grad_fn(), grad_accumulator(),
// fw_grad() and set_fw_grad()
@ -301,12 +293,6 @@ struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface {
uint64_t level,
bool is_inplace_op) override;
std::optional<at::ScalarType> grad_dtype(const at::TensorBase& self) const;
void set_grad_dtype(
const std::optional<at::ScalarType>& grad_dtype,
const at::TensorBase& self);
AutogradMeta(
at::TensorImpl* self_impl = nullptr,
bool requires_grad = false,
@ -954,11 +940,6 @@ struct VariableHooks final : at::impl::VariableHooksInterface {
const c10::OperatorHandle& op,
c10::DispatchKeySet dispatch_keys,
torch::jit::Stack* stack) const override;
std::optional<c10::ScalarType> grad_dtype(
const at::TensorBase&) const override;
void set_grad_dtype(
const at::TensorBase&,
const std::optional<c10::ScalarType>&) const override;
};
namespace utils {

View File

@ -1458,30 +1458,24 @@ struct IValuePacker<InputMetadata> {
auto tuple = std::make_tuple(
pack_TensorOptions(t.options()),
t.shape_as_dim_vector().vec(),
t.is_tensor_subclass(),
t.grad_dtype());
t.is_tensor_subclass());
return tuple;
}
static InputMetadata unpack(const at::IValue& t) {
auto tuple = t.to<std::tuple<
packed_tensoroptions_t,
std::vector<at::SymInt>,
bool,
std::optional<c10::ScalarType>>>();
auto tuple = t.to<
std::tuple<packed_tensoroptions_t, std::vector<at::SymInt>, bool>>();
return InputMetadata(
unpack_TensorOptions(std::get<0>(tuple)),
SymIntSmallVec(std::get<1>(tuple)),
std::get<2>(tuple),
false,
std::get<3>(tuple));
false);
}
static at::TypePtr packed_type() {
return at::TupleType::create(
{IValuePacker<at::TensorOptions>::packed_type(),
IValuePacker<std::vector<at::SymInt>>::packed_type(),
at::BoolType::get(),
IValuePacker<std::optional<at::ScalarType>>::packed_type()});
at::BoolType::get()});
}
};

View File

@ -1351,7 +1351,6 @@ def get_testing_overrides() -> dict[Callable, Callable]:
Tensor._grad.__get__: lambda self: -1,
Tensor._grad_fn.__get__: lambda self: -1,
Tensor.grad_fn.__get__: lambda self: -1,
Tensor.grad_dtype.__get__: lambda self: -1,
Tensor._version.__get__: lambda self: -1,
Tensor._autocast_to_reduced_precision: lambda self, cuda_enabled, cpu_enabled, cuda_dtype, cpu_dtype: -1,
Tensor._autocast_to_full_precision: lambda self, cuda_enabled, cpu_enabled: -1,