Support setting grad_dtype on leaf tensors (#162815)
`grad_dtype` is a new attribute on Tensor that controls the gradient dtype:

- Access and setting are leaf-only.
- `grad_dtype` is respected (1) when assigning to `.grad`, and (2) in the engine after the previous node produces incoming gradients for AccumulateGrad. (See the table below for details.)
- Not setting `grad_dtype` preserves the current behavior; accessing it returns `t.dtype`.
- `grad_dtype` cannot be set when a `.grad` is already present and the dtypes conflict.

| `grad_dtype` setting | Setting `.grad` manually | Incoming gradient from autograd engine |
|----------------------|--------------------------|----------------------------------------|
| **Default (tensor's dtype)** | `.grad` must match tensor's dtype | Engine casts incoming grad to tensor's dtype |
| **Set to specific dtype** | `.grad` must match that dtype | Engine casts incoming grad to the specified dtype |
| **Set to `None`** | `.grad` may be any dtype | Engine does not cast; accepts incoming grad dtype as-is |

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162815
Approved by: https://github.com/albanD
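A minimal sketch of the semantics in the table above (assuming the feature lands as described in this PR; the exact error messages and surface may differ):

```python
import torch

# Default: grad_dtype mirrors the tensor's dtype; the engine casts incoming grads to it.
t = torch.randn(3, requires_grad=True)          # float32 leaf
assert t.grad_dtype == torch.float32
(t * 2.0).sum().backward()
assert t.grad.dtype == torch.float32

# Set to a specific dtype: the engine casts incoming grads to that dtype.
t.grad = None
t.grad_dtype = torch.float16
(t * 2.0).sum().backward()
assert t.grad.dtype == torch.float16

# Set to None: no enforcement; e.g. a manually assigned .grad may be any dtype.
u = torch.randn(3, requires_grad=True)
u.grad_dtype = None
u.grad = torch.zeros(3, dtype=torch.bfloat16)   # accepted as-is, no cast
```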
Committed by PyTorch MergeBot · parent 43848b71d9 · commit dca73982c5
```diff
@@ -1317,13 +1317,18 @@ static int THPVariable_set_grad(
       self != (THPVariable*)py_grad, "can't assign Variable as its own grad");
 
   const auto& grad = THPVariable_Unpack(py_grad);
-  TORCH_CHECK(
-      var.dtype() == grad.dtype(),
-      "attempting to assign a gradient with dtype '",
-      grad.dtype(),
-      "' to a tensor with dtype '",
-      var.dtype(),
-      "'. Please ensure that the gradient and the tensor have the same dtype");
+  if (var.grad_dtype().has_value()) {
+    TORCH_CHECK(
+        grad.dtype() == var.grad_dtype().value(),
+        "attempting to assign a gradient with dtype '",
+        grad.dtype(),
+        "' to a tensor with grad_dtype '",
+        var.grad_dtype().value(),
+        "'. The gradient must match the tensor's grad_dtype (defaults to the tensor's "
+        "dtype). You can set the tensor's grad_dtype attribute with a specific dtype, or "
+        "None to allow any dtype. Set grad_dtype with caution. Diverging the dtypes of "
+        "a tensor and its gradient may break downstream systems that assume they match.");
+  }
   TORCH_CHECK(
       var.device().type() == grad.device().type(),
       "attempting to assign a gradient with device type '",
```
```diff
@@ -1332,8 +1337,11 @@ static int THPVariable_set_grad(
       var.device().type(),
       "'. Please ensure that the gradient and the tensor are on the same device");
   if (grad.layout() != kSparse) {
+    auto expected_options = var.options().dtype(
+        var.grad_dtype().has_value() ? var.grad_dtype().value()
+                                     : grad.scalar_type());
     TORCH_CHECK(
-        grad.options().type_equal(var.options()),
+        grad.options().type_equal(expected_options),
         "attempting to assign a gradient to a tensor that has data of a different type");
   }
   TORCH_CHECK(
```
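In Python terms, the `.grad` assignment checks in the two hunks above translate roughly to the following behavior (a sketch based on the PR description; exception types and messages are illustrative):

```python
import torch

t = torch.randn(2, requires_grad=True)          # float32, default grad_dtype

# Matching dtype is accepted.
t.grad = torch.ones(2)

# A mismatched dtype is rejected while grad_dtype is the default (tensor's dtype).
try:
    t.grad = torch.ones(2, dtype=torch.float64)
except RuntimeError as e:
    print(e)  # gradient dtype must match the tensor's grad_dtype

# With grad_dtype set to None, any dtype is allowed on assignment.
t.grad = None
t.grad_dtype = None
t.grad = torch.ones(2, dtype=torch.float64)     # accepted
```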
```diff
@@ -1839,6 +1847,56 @@ static PyObject* THPVariable_get_nbytes(THPVariable* self, void* unused) {
   END_HANDLE_TH_ERRORS
 }
 
+static PyObject* THPVariable_get_grad_dtype(THPVariable* self, void* unused) {
+  HANDLE_TH_ERRORS
+  if (check_has_torch_function((PyObject*)self)) {
+    return handle_torch_function_getter(self, "grad_dtype");
+  }
+  const auto& var = THPVariable_Unpack(self);
+  TORCH_CHECK(
+      !var.grad_fn(), "grad_dtype can only be accessed on leaf tensors.");
+  if (!var.grad_dtype().has_value()) {
+    Py_RETURN_NONE;
+  } else {
+    return torch::autograd::utils::wrap(var.grad_dtype().value());
+  }
+  END_HANDLE_TH_ERRORS
+}
+
+static int THPVariable_set_grad_dtype(
+    THPVariable* self,
+    PyObject* obj,
+    void* unused) {
+  HANDLE_TH_ERRORS
+  if (check_has_torch_function((PyObject*)self)) {
+    return handle_torch_function_setter(self, "grad_dtype", obj);
+  }
+  const auto& var = THPVariable_Unpack(self);
+  TORCH_CHECK(
+      THPDtype_Check(obj) || obj == Py_None,
+      "grad_dtype must be a torch.dtype or None, but got ",
+      Py_TYPE(obj)->tp_name);
+  if (var.grad().defined() && obj != Py_None) {
+    auto new_dtype = reinterpret_cast<THPDtype*>(obj);
+    TORCH_CHECK(
+        var.grad().dtype() == new_dtype->scalar_type,
+        "Cannot set grad_dtype to '",
+        new_dtype->scalar_type,
+        "' because there is already a gradient with dtype '",
+        var.grad().dtype(),
+        "'. Please clear the gradient (.grad = None) before changing grad_dtype, "
+        "or ensure the new grad_dtype matches the existing gradient's dtype.");
+  }
+  std::optional<at::ScalarType> new_dtype;
+  if (obj != Py_None) {
+    auto* dtype = reinterpret_cast<THPDtype*>(obj);
+    new_dtype = dtype->scalar_type;
+  }
+  var.set_grad_dtype(new_dtype);
+  return 0;
+  END_HANDLE_TH_ERRORS_RET(-1)
+}
+
 static PyObject* THPVariable_get_itemsize(THPVariable* self, void* unused) {
   HANDLE_TH_ERRORS
   if (check_has_torch_function((PyObject*)self)) {
```
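The new getter and setter above enforce, roughly, the following at the Python level (a sketch based on the checks in this diff; exception types and messages are illustrative):

```python
import torch

leaf = torch.randn(2, requires_grad=True)
nonleaf = leaf * 2

# grad_dtype is leaf-only: accessing it on a non-leaf tensor raises.
try:
    nonleaf.grad_dtype
except RuntimeError as e:
    print(e)  # grad_dtype can only be accessed on leaf tensors

# grad_dtype must be a torch.dtype or None.
try:
    leaf.grad_dtype = "float16"
except (RuntimeError, TypeError) as e:  # exact exception type may vary
    print(e)  # grad_dtype must be a torch.dtype or None

# A conflicting grad_dtype cannot be set while a .grad of a different dtype exists.
leaf.grad = torch.zeros(2)                      # float32 gradient
try:
    leaf.grad_dtype = torch.float16
except RuntimeError as e:
    print(e)  # clear .grad first or match the existing gradient's dtype
leaf.grad = None
leaf.grad_dtype = torch.float16                 # ok once the gradient is cleared
```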
```diff
@@ -1997,6 +2055,11 @@ static struct PyGetSetDef THPVariable_properties[] = {
      (setter)THPVariable_set_imag,
      nullptr,
      nullptr},
+    {"grad_dtype",
+     (getter)THPVariable_get_grad_dtype,
+     (setter)THPVariable_set_grad_dtype,
+     nullptr,
+     nullptr},
     {nullptr}};
 
 static PyMappingMethods THPVariable_as_mapping = {
```