Revert "Support setting grad_dtype on leaf tensors (#162815)"

This reverts commit dca73982c53e9f99f96246b5d9ed9bab83c7423f.

Reverted https://github.com/pytorch/pytorch/pull/162815 on behalf of https://github.com/yangw-dev because it breaks internal test D83850533; see more details below ([comment](https://github.com/pytorch/pytorch/pull/162815#issuecomment-3367498501))
PyTorch MergeBot
2025-10-03 23:14:28 +00:00
parent fac6f20ae3
commit 3ddf2018d0
19 changed files with 41 additions and 375 deletions
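For context, the reverted PR exposed a grad_dtype attribute on leaf tensors so that the dtype of .grad could be declared independently of the tensor's own dtype. The Python sketch below is inferred from the removed THPVariable_{get,set}_grad_dtype bindings in the diff that follows; the tensor name and values are illustrative, and the attribute no longer exists once this revert lands.

import torch

# Rough sketch of the reverted API (PR #162815); illustrative only.
w = torch.zeros(4, dtype=torch.bfloat16, requires_grad=True)  # leaf tensor

w.grad_dtype = torch.float32  # declare the dtype .grad must carry, or None to allow any dtype
w.grad = torch.ones(4, dtype=torch.float32)  # accepted: matches grad_dtype rather than w.dtype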

@@ -1317,18 +1317,13 @@ static int THPVariable_set_grad(
self != (THPVariable*)py_grad, "can't assign Variable as its own grad");
const auto& grad = THPVariable_Unpack(py_grad);
-if (var.grad_dtype().has_value()) {
-TORCH_CHECK(
-grad.dtype() == var.grad_dtype().value(),
-"attempting to assign a gradient with dtype '",
-grad.dtype(),
-"' to a tensor with grad_dtype '",
-var.grad_dtype().value(),
-"'. The gradient must match the tensor's grad_dtype (defaults to the tensor's "
-"dtype). You can set the tensor's grad_dtype attribute with a specific dtype, or "
-"None to allow any dtype. Set grad_dtype with caution. Diverging the dtypes of "
-"a tensor and its gradient may break downstream systems that assume they match.");
-}
+TORCH_CHECK(
+var.dtype() == grad.dtype(),
+"attempting to assign a gradient with dtype '",
+grad.dtype(),
+"' to a tensor with dtype '",
+var.dtype(),
+"'. Please ensure that the gradient and the tensor have the same dtype");
TORCH_CHECK(
var.device().type() == grad.device().type(),
"attempting to assign a gradient with device type '",
@@ -1337,11 +1332,8 @@ static int THPVariable_set_grad(
var.device().type(),
"'. Please ensure that the gradient and the tensor are on the same device");
if (grad.layout() != kSparse) {
-auto expected_options = var.options().dtype(
-var.grad_dtype().has_value() ? var.grad_dtype().value()
-: grad.scalar_type());
TORCH_CHECK(
-grad.options().type_equal(expected_options),
+grad.options().type_equal(var.options()),
"attempting to assign a gradient to a tensor that has data of a different type");
}
TORCH_CHECK(
@@ -1847,56 +1839,6 @@ static PyObject* THPVariable_get_nbytes(THPVariable* self, void* unused) {
END_HANDLE_TH_ERRORS
}
-static PyObject* THPVariable_get_grad_dtype(THPVariable* self, void* unused) {
-HANDLE_TH_ERRORS
-if (check_has_torch_function((PyObject*)self)) {
-return handle_torch_function_getter(self, "grad_dtype");
-}
-const auto& var = THPVariable_Unpack(self);
-TORCH_CHECK(
-!var.grad_fn(), "grad_dtype can only be accessed on leaf tensors.");
-if (!var.grad_dtype().has_value()) {
-Py_RETURN_NONE;
-} else {
-return torch::autograd::utils::wrap(var.grad_dtype().value());
-}
-END_HANDLE_TH_ERRORS
-}
-static int THPVariable_set_grad_dtype(
-THPVariable* self,
-PyObject* obj,
-void* unused) {
-HANDLE_TH_ERRORS
-if (check_has_torch_function((PyObject*)self)) {
-return handle_torch_function_setter(self, "grad_dtype", obj);
-}
-const auto& var = THPVariable_Unpack(self);
-TORCH_CHECK(
-THPDtype_Check(obj) || obj == Py_None,
-"grad_dtype must be a torch.dtype or None, but got ",
-Py_TYPE(obj)->tp_name);
-if (var.grad().defined() && obj != Py_None) {
-auto new_dtype = reinterpret_cast<THPDtype*>(obj);
-TORCH_CHECK(
-var.grad().dtype() == new_dtype->scalar_type,
-"Cannot set grad_dtype to '",
-new_dtype->scalar_type,
-"' because there is already a gradient with dtype '",
-var.grad().dtype(),
-"'. Please clear the gradient (.grad = None) before changing grad_dtype, "
-"or ensure the new grad_dtype matches the existing gradient's dtype.");
-}
-std::optional<at::ScalarType> new_dtype;
-if (obj != Py_None) {
-auto* dtype = reinterpret_cast<THPDtype*>(obj);
-new_dtype = dtype->scalar_type;
-}
-var.set_grad_dtype(new_dtype);
-return 0;
-END_HANDLE_TH_ERRORS_RET(-1)
-}
static PyObject* THPVariable_get_itemsize(THPVariable* self, void* unused) {
HANDLE_TH_ERRORS
if (check_has_torch_function((PyObject*)self)) {
@@ -2055,11 +1997,6 @@ static struct PyGetSetDef THPVariable_properties[] = {
(setter)THPVariable_set_imag,
nullptr,
nullptr},
{"grad_dtype",
(getter)THPVariable_get_grad_dtype,
(setter)THPVariable_set_grad_dtype,
nullptr,
nullptr},
{nullptr}};
static PyMappingMethods THPVariable_as_mapping = {
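With the revert applied, .grad assignment goes back to requiring the tensor's own dtype, per the restored TORCH_CHECK in THPVariable_set_grad above. A minimal Python check of that behavior (tensor name and shapes are illustrative):

import torch

w = torch.zeros(4, dtype=torch.bfloat16, requires_grad=True)

# Same dtype as the tensor: accepted before and after the revert.
w.grad = torch.ones(4, dtype=torch.bfloat16)

# Mismatched dtype: the restored check rejects the assignment.
try:
    w.grad = torch.ones(4, dtype=torch.float32)
except RuntimeError as err:
    print(err)  # attempting to assign a gradient with dtype 'float' to a tensor with dtype ...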