Revert "Support setting grad_dtype on leaf tensors (#162815)"

This reverts commit dca73982c53e9f99f96246b5d9ed9bab83c7423f.

Reverted https://github.com/pytorch/pytorch/pull/162815 on behalf of https://github.com/yangw-dev because it breaks internal test D83850533; see more details in the linked comment ([comment](https://github.com/pytorch/pytorch/pull/162815#issuecomment-3367498501))
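
For context, the reverted PR added an optional per-tensor gradient dtype to autograd metadata (see the removed `grad_dtype_` / `allow_grad_dtype_mismatch_` fields and the `grad_dtype` / `set_grad_dtype` hooks in the diff below). The following is a minimal, self-contained sketch of that pattern using hypothetical stand-in types; it is an illustration only, not PyTorch's actual implementation:

```cpp
#include <iostream>
#include <optional>

// Hypothetical stand-ins for illustration; not PyTorch's real types.
enum class ScalarType { Float, Half, BFloat16 };

struct AutogradMetaSketch {
  ScalarType tensor_dtype{ScalarType::Float};
  // The dtype of the grad field; when nullopt, defaults to the tensor's dtype.
  std::optional<ScalarType> grad_dtype_;
  // When true, allows the gradient dtype to differ from the tensor dtype.
  bool allow_grad_dtype_mismatch_{false};

  // Resolve the effective gradient dtype.
  ScalarType grad_dtype() const {
    return grad_dtype_.value_or(tensor_dtype);
  }

  // Override the gradient dtype; passing nullopt restores the default.
  void set_grad_dtype(const std::optional<ScalarType>& dtype) {
    grad_dtype_ = dtype;
    allow_grad_dtype_mismatch_ = dtype.has_value() && *dtype != tensor_dtype;
  }
};

int main() {
  AutogradMetaSketch meta;
  std::cout << (meta.grad_dtype() == ScalarType::Float) << "\n";     // 1: defaults to tensor dtype
  meta.set_grad_dtype(ScalarType::BFloat16);
  std::cout << (meta.grad_dtype() == ScalarType::BFloat16) << "\n";  // 1: explicit override
}
```

The diff below removes the corresponding declarations from the real headers.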
Author: PyTorch MergeBot
Date:   2025-10-03 23:14:28 +00:00
Parent: fac6f20ae3
Commit: 3ddf2018d0

19 changed files with 41 additions and 375 deletions

@@ -128,7 +128,6 @@ TORCH_API void set_grad_accumulator(
 /// if it still exists. If the gradient accumulator function has been
 /// destroyed, returns a `nullptr`.
 TORCH_API std::shared_ptr<Node> try_get_grad_accumulator(const Variable&);
-TORCH_API std::shared_ptr<Node> try_get_grad_accumulator(const at::TensorBase&);
 /// Gets the gradient accumulator of the `Variable` if it has one, or else
 /// create one on the fly and return it.
@@ -254,13 +253,6 @@ struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface {
   // correctly when this variable is passed to another function.
   uint32_t output_nr_;
-  // The dtype of the grad field; when nullopt, defaults to tensor's dtype.
-  std::optional<at::ScalarType> grad_dtype_;
-  // When true, allows gradient dtype to be different from tensor dtype,
-  // bypassing dtype casting and validation in the autograd engine.
-  bool allow_grad_dtype_mismatch_{false};
   // Mutex to ensure that concurrent read operations that modify internal
   // state are still thread-safe. Used by grad_fn(), grad_accumulator(),
   // fw_grad() and set_fw_grad()
@@ -301,12 +293,6 @@ struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface {
       uint64_t level,
       bool is_inplace_op) override;
-  std::optional<at::ScalarType> grad_dtype(const at::TensorBase& self) const;
-  void set_grad_dtype(
-      const std::optional<at::ScalarType>& grad_dtype,
-      const at::TensorBase& self);
   AutogradMeta(
       at::TensorImpl* self_impl = nullptr,
       bool requires_grad = false,
@@ -954,11 +940,6 @@ struct VariableHooks final : at::impl::VariableHooksInterface {
       const c10::OperatorHandle& op,
       c10::DispatchKeySet dispatch_keys,
       torch::jit::Stack* stack) const override;
-  std::optional<c10::ScalarType> grad_dtype(
-      const at::TensorBase&) const override;
-  void set_grad_dtype(
-      const at::TensorBase&,
-      const std::optional<c10::ScalarType>&) const override;
 };
 namespace utils {