Revert "Support setting grad_dtype on leaf tensors (#162815)"
This reverts commit dca73982c53e9f99f96246b5d9ed9bab83c7423f. Reverted https://github.com/pytorch/pytorch/pull/162815 on behalf of https://github.com/yangw-dev because it breaks internal test D83850533; see more details below ([comment](https://github.com/pytorch/pytorch/pull/162815#issuecomment-3367498501))
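For context: with this revert applied, a manually assigned gradient must again match the tensor's dtype exactly. A minimal, illustrative Python sketch of the restored check (not part of the commit):

import torch

t = torch.ones(3, requires_grad=True)        # float32 leaf tensor

t.grad = torch.zeros(3)                      # OK: gradient dtype matches t.dtype
try:
    t.grad = torch.zeros(3, dtype=torch.float16)
except RuntimeError as e:
    # With the revert applied, the mismatch is rejected with the restored
    # "Please ensure that the gradient and the tensor have the same dtype" error.
    print(e)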
@@ -1317,18 +1317,13 @@ static int THPVariable_set_grad(
       self != (THPVariable*)py_grad, "can't assign Variable as its own grad");
 
   const auto& grad = THPVariable_Unpack(py_grad);
-  if (var.grad_dtype().has_value()) {
-    TORCH_CHECK(
-        grad.dtype() == var.grad_dtype().value(),
-        "attempting to assign a gradient with dtype '",
-        grad.dtype(),
-        "' to a tensor with grad_dtype '",
-        var.grad_dtype().value(),
-        "'. The gradient must match the tensor's grad_dtype (defaults to the tensor's "
-        "dtype). You can set the tensor's grad_dtype attribute with a specific dtype, or "
-        "None to allow any dtype. Set grad_dtype with caution. Diverging the dtypes of "
-        "a tensor and its gradient may break downstream systems that assume they match.");
-  }
+  TORCH_CHECK(
+      var.dtype() == grad.dtype(),
+      "attempting to assign a gradient with dtype '",
+      grad.dtype(),
+      "' to a tensor with dtype '",
+      var.dtype(),
+      "'. Please ensure that the gradient and the tensor have the same dtype");
   TORCH_CHECK(
       var.device().type() == grad.device().type(),
       "attempting to assign a gradient with device type '",
@@ -1337,11 +1332,8 @@ static int THPVariable_set_grad(
       var.device().type(),
       "'. Please ensure that the gradient and the tensor are on the same device");
   if (grad.layout() != kSparse) {
-    auto expected_options = var.options().dtype(
-        var.grad_dtype().has_value() ? var.grad_dtype().value()
-                                     : grad.scalar_type());
     TORCH_CHECK(
-        grad.options().type_equal(expected_options),
+        grad.options().type_equal(var.options()),
         "attempting to assign a gradient to a tensor that has data of a different type");
   }
   TORCH_CHECK(
@@ -1847,56 +1839,6 @@ static PyObject* THPVariable_get_nbytes(THPVariable* self, void* unused) {
   END_HANDLE_TH_ERRORS
 }
 
-static PyObject* THPVariable_get_grad_dtype(THPVariable* self, void* unused) {
-  HANDLE_TH_ERRORS
-  if (check_has_torch_function((PyObject*)self)) {
-    return handle_torch_function_getter(self, "grad_dtype");
-  }
-  const auto& var = THPVariable_Unpack(self);
-  TORCH_CHECK(
-      !var.grad_fn(), "grad_dtype can only be accessed on leaf tensors.");
-  if (!var.grad_dtype().has_value()) {
-    Py_RETURN_NONE;
-  } else {
-    return torch::autograd::utils::wrap(var.grad_dtype().value());
-  }
-  END_HANDLE_TH_ERRORS
-}
-
-static int THPVariable_set_grad_dtype(
-    THPVariable* self,
-    PyObject* obj,
-    void* unused) {
-  HANDLE_TH_ERRORS
-  if (check_has_torch_function((PyObject*)self)) {
-    return handle_torch_function_setter(self, "grad_dtype", obj);
-  }
-  const auto& var = THPVariable_Unpack(self);
-  TORCH_CHECK(
-      THPDtype_Check(obj) || obj == Py_None,
-      "grad_dtype must be a torch.dtype or None, but got ",
-      Py_TYPE(obj)->tp_name);
-  if (var.grad().defined() && obj != Py_None) {
-    auto new_dtype = reinterpret_cast<THPDtype*>(obj);
-    TORCH_CHECK(
-        var.grad().dtype() == new_dtype->scalar_type,
-        "Cannot set grad_dtype to '",
-        new_dtype->scalar_type,
-        "' because there is already a gradient with dtype '",
-        var.grad().dtype(),
-        "'. Please clear the gradient (.grad = None) before changing grad_dtype, "
-        "or ensure the new grad_dtype matches the existing gradient's dtype.");
-  }
-  std::optional<at::ScalarType> new_dtype;
-  if (obj != Py_None) {
-    auto* dtype = reinterpret_cast<THPDtype*>(obj);
-    new_dtype = dtype->scalar_type;
-  }
-  var.set_grad_dtype(new_dtype);
-  return 0;
-  END_HANDLE_TH_ERRORS_RET(-1)
-}
-
 static PyObject* THPVariable_get_itemsize(THPVariable* self, void* unused) {
   HANDLE_TH_ERRORS
   if (check_has_torch_function((PyObject*)self)) {
@@ -2055,11 +1997,6 @@ static struct PyGetSetDef THPVariable_properties[] = {
      (setter)THPVariable_set_imag,
      nullptr,
      nullptr},
-    {"grad_dtype",
-     (getter)THPVariable_get_grad_dtype,
-     (setter)THPVariable_set_grad_dtype,
-     nullptr,
-     nullptr},
     {nullptr}};
 
 static PyMappingMethods THPVariable_as_mapping = {
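For reference, a hedged sketch of how the removed getter/setter surfaced at the Python level on builds that still include #162815. It is guarded with hasattr so it runs either way, and the behavior comments are inferred from the checks and error messages in the hunks above rather than verified against such a build:

import torch

t = torch.ones(3, requires_grad=True)  # float32 leaf tensor

if hasattr(t, "grad_dtype"):
    # Builds that include #162815: the setter accepts a torch.dtype or None
    # (None allows any gradient dtype). Per the removed error messages,
    # grad_dtype defaults to the tensor's dtype, the getter works only on
    # leaf tensors, and it cannot be changed while a mismatching .grad exists.
    t.grad_dtype = torch.float16
    t.grad = torch.zeros(3, dtype=torch.float16)  # accepted under grad_dtype
else:
    # After this revert the attribute is gone; gradients must match t.dtype.
    t.grad = torch.zeros(3)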