Support setting grad_dtype on leaf tensors (#162815)

`grad_dtype` is a new attribute on Tensor to control gradient dtype:
- Access/setting is leaf-only.
- `grad_dtype` is respected (1) when assigning to `.grad`, and (2) in the engine after the previous node produces the incoming gradient for AccumulateGrad (see the table below, and the sketches that follow, for details).
- Not setting `grad_dtype` preserves the current behavior; accessing it returns `t.dtype`.
- `grad_dtype` cannot be set when there is already a `.grad` present and the dtypes conflict.
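
As a rough illustration (not part of the PR text itself), here is a minimal Python sketch of the attribute semantics above; the commented-out lines mark cases the bullets say should error:

```python
import torch

t = torch.randn(3, requires_grad=True)        # leaf tensor
print(t.grad_dtype)                            # not set: returns t.dtype (torch.float32)

t.grad_dtype = torch.float16                   # .grad must now be float16
t.grad = torch.zeros(3, dtype=torch.float16)   # OK: matches grad_dtype

# Conflict rule: changing grad_dtype while a .grad of a different dtype exists
# should raise; clear the gradient first (t.grad = None) or match its dtype.
# t.grad_dtype = torch.bfloat16

# Access/setting is leaf-only; tensors with a grad_fn reject it.
# t.sum().grad_dtype
```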

| `grad_dtype` setting | Setting `.grad` manually | Incoming gradient from autograd engine |
|-----------------------|--------------------------|-----------------------------------------|
| **Default (tensor’s dtype)** | `.grad` must match tensor’s dtype | Engine casts incoming grad to tensor’s dtype |
| **Set to specific dtype** | `.grad` must match that dtype | Engine casts incoming grad to the specified dtype |
| **Set to `None`** | `.grad` may be any dtype | Engine does not cast; accepts incoming grad dtype as-is |
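
Below is a hedged sketch of the three table rows (not taken from the PR's tests; it assumes `backward()` on a plain float32 graph delivers a float32 incoming gradient to AccumulateGrad):

```python
import torch

# Default: the engine casts the incoming gradient to the tensor's dtype.
a = torch.randn(3, requires_grad=True)            # float32 leaf
a.sum().backward()
assert a.grad.dtype == torch.float32

# Specific dtype: the engine casts the incoming gradient to that dtype.
b = torch.randn(3, requires_grad=True)
b.grad_dtype = torch.float16
b.sum().backward()
assert b.grad.dtype == torch.float16

# None: the engine does not cast, and a manually assigned .grad may be any dtype.
c = torch.randn(3, requires_grad=True)
c.grad_dtype = None
c.grad = torch.zeros(3, dtype=torch.bfloat16)     # accepted as-is
```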

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162815
Approved by: https://github.com/albanD
Authored by soulitzer on 2025-10-02 11:28:36 -07:00; committed by PyTorch MergeBot.
Parent: 43848b71d9
Commit: dca73982c5
19 changed files with 375 additions and 41 deletions

torch/csrc/autograd/python_variable.cpp

@@ -1317,13 +1317,18 @@ static int THPVariable_set_grad(
       self != (THPVariable*)py_grad, "can't assign Variable as its own grad");
   const auto& grad = THPVariable_Unpack(py_grad);
-  TORCH_CHECK(
-      var.dtype() == grad.dtype(),
-      "attempting to assign a gradient with dtype '",
-      grad.dtype(),
-      "' to a tensor with dtype '",
-      var.dtype(),
-      "'. Please ensure that the gradient and the tensor have the same dtype");
+  if (var.grad_dtype().has_value()) {
+    TORCH_CHECK(
+        grad.dtype() == var.grad_dtype().value(),
+        "attempting to assign a gradient with dtype '",
+        grad.dtype(),
+        "' to a tensor with grad_dtype '",
+        var.grad_dtype().value(),
+        "'. The gradient must match the tensor's grad_dtype (defaults to the tensor's "
+        "dtype). You can set the tensor's grad_dtype attribute with a specific dtype, or "
+        "None to allow any dtype. Set grad_dtype with caution. Diverging the dtypes of "
+        "a tensor and its gradient may break downstream systems that assume they match.");
+  }
   TORCH_CHECK(
       var.device().type() == grad.device().type(),
       "attempting to assign a gradient with device type '",
@@ -1332,8 +1337,11 @@ static int THPVariable_set_grad(
       var.device().type(),
       "'. Please ensure that the gradient and the tensor are on the same device");
   if (grad.layout() != kSparse) {
+    auto expected_options = var.options().dtype(
+        var.grad_dtype().has_value() ? var.grad_dtype().value()
+                                     : grad.scalar_type());
     TORCH_CHECK(
-        grad.options().type_equal(var.options()),
+        grad.options().type_equal(expected_options),
         "attempting to assign a gradient to a tensor that has data of a different type");
   }
   TORCH_CHECK(
@@ -1839,6 +1847,56 @@ static PyObject* THPVariable_get_nbytes(THPVariable* self, void* unused) {
   END_HANDLE_TH_ERRORS
 }
 
+static PyObject* THPVariable_get_grad_dtype(THPVariable* self, void* unused) {
+  HANDLE_TH_ERRORS
+  if (check_has_torch_function((PyObject*)self)) {
+    return handle_torch_function_getter(self, "grad_dtype");
+  }
+  const auto& var = THPVariable_Unpack(self);
+  TORCH_CHECK(
+      !var.grad_fn(), "grad_dtype can only be accessed on leaf tensors.");
+  if (!var.grad_dtype().has_value()) {
+    Py_RETURN_NONE;
+  } else {
+    return torch::autograd::utils::wrap(var.grad_dtype().value());
+  }
+  END_HANDLE_TH_ERRORS
+}
+
+static int THPVariable_set_grad_dtype(
+    THPVariable* self,
+    PyObject* obj,
+    void* unused) {
+  HANDLE_TH_ERRORS
+  if (check_has_torch_function((PyObject*)self)) {
+    return handle_torch_function_setter(self, "grad_dtype", obj);
+  }
+  const auto& var = THPVariable_Unpack(self);
+  TORCH_CHECK(
+      THPDtype_Check(obj) || obj == Py_None,
+      "grad_dtype must be a torch.dtype or None, but got ",
+      Py_TYPE(obj)->tp_name);
+  if (var.grad().defined() && obj != Py_None) {
+    auto new_dtype = reinterpret_cast<THPDtype*>(obj);
+    TORCH_CHECK(
+        var.grad().dtype() == new_dtype->scalar_type,
+        "Cannot set grad_dtype to '",
+        new_dtype->scalar_type,
+        "' because there is already a gradient with dtype '",
+        var.grad().dtype(),
+        "'. Please clear the gradient (.grad = None) before changing grad_dtype, "
+        "or ensure the new grad_dtype matches the existing gradient's dtype.");
+  }
+  std::optional<at::ScalarType> new_dtype;
+  if (obj != Py_None) {
+    auto* dtype = reinterpret_cast<THPDtype*>(obj);
+    new_dtype = dtype->scalar_type;
+  }
+  var.set_grad_dtype(new_dtype);
+  return 0;
+  END_HANDLE_TH_ERRORS_RET(-1)
+}
+
 static PyObject* THPVariable_get_itemsize(THPVariable* self, void* unused) {
   HANDLE_TH_ERRORS
   if (check_has_torch_function((PyObject*)self)) {
@@ -1997,6 +2055,11 @@ static struct PyGetSetDef THPVariable_properties[] = {
      (setter)THPVariable_set_imag,
      nullptr,
      nullptr},
+    {"grad_dtype",
+     (getter)THPVariable_get_grad_dtype,
+     (setter)THPVariable_set_grad_dtype,
+     nullptr,
+     nullptr},
     {nullptr}};
 
 static PyMappingMethods THPVariable_as_mapping = {