[reland] Allow setting grad_dtype on leaf tensors (#164751)

ghstack-source-id: e44b3941530be83a630ec93f1478eec741ffca2e
Pull-Request-resolved: https://github.com/pytorch/pytorch/pull/162815

Fixes #ISSUE_NUMBER

Relanding due to internal weirdness. Landing as a separate PR to codev without ghstack.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164751
Approved by: https://github.com/albanD
Author: soulitzer
Date: 2025-10-08 20:23:13 +00:00
Committer: PyTorch MergeBot
parent 001e1d2637
commit 71aefd5595
20 changed files with 377 additions and 41 deletions
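
As context for the diff below, here is a minimal end-to-end sketch of the feature this commit adds. It assumes the C++ surface matches the header declarations in this diff; that set_grad_dtype is reachable through the tensor's AutogradMeta, and that backward() honors the setting for leaves, are assumptions inferred from the diff rather than verified behavior.

#include <iostream>
#include <torch/torch.h>
#include <torch/csrc/autograd/variable.h>

int main() {
  // A float32 leaf tensor that participates in autograd.
  auto x = torch::randn({3}, torch::requires_grad());
  // Ask autograd to accumulate x's gradient in bfloat16 (assumption: the
  // setter added in this PR is invocable via the tensor's AutogradMeta).
  if (auto* meta = torch::autograd::impl::get_autograd_meta(x)) {
    meta->set_grad_dtype(c10::ScalarType::BFloat16, x);
  }
  x.sum().backward();
  // Under this PR, x.grad() is expected to carry the requested dtype.
  std::cout << x.grad().dtype() << std::endl;
}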


@@ -128,6 +128,7 @@ TORCH_API void set_grad_accumulator(
 /// if it still exists. If the gradient accumulator function has been
 /// destroyed, returns a `nullptr`.
 TORCH_API std::shared_ptr<Node> try_get_grad_accumulator(const Variable&);
+TORCH_API std::shared_ptr<Node> try_get_grad_accumulator(const at::TensorBase&);
 
 /// Gets the gradient accumulator of the `Variable` if it has one, or else
 /// create one on the fly and return it.
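Per the doc comment above, the accumulator may have been destroyed, in which case the call returns nullptr; a minimal sketch using the new at::TensorBase overload (the helper name is hypothetical):

// Returns true iff a grad accumulator node currently exists for t.
bool has_grad_accumulator(const at::TensorBase& t) {
  return torch::autograd::impl::try_get_grad_accumulator(t) != nullptr;
}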
@@ -253,6 +254,13 @@ struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface {
   // correctly when this variable is passed to another function.
   uint32_t output_nr_;
 
+  // The dtype of the grad field; when nullopt, defaults to tensor's dtype.
+  std::optional<at::ScalarType> grad_dtype_;
+
+  // When true, allows gradient dtype to be different from tensor dtype,
+  // bypassing dtype casting and validation in the autograd engine.
+  bool allow_grad_dtype_mismatch_{false};
+
   // Mutex to ensure that concurrent read operations that modify internal
   // state are still thread-safe. Used by grad_fn(), grad_accumulator(),
   // fw_grad() and set_fw_grad()
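The "defaults to tensor's dtype" rule in the comment above reduces to an optional fallback; a minimal sketch (the free-standing helper is hypothetical, the members are from this hunk):

// Hypothetical helper illustrating the fallback documented above:
// nullopt means "use the tensor's own dtype".
at::ScalarType effective_grad_dtype(
    const std::optional<at::ScalarType>& grad_dtype,
    at::ScalarType tensor_dtype) {
  return grad_dtype.value_or(tensor_dtype);
}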
@@ -293,6 +301,12 @@ struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface {
       uint64_t level,
       bool is_inplace_op) override;
 
+  std::optional<at::ScalarType> grad_dtype(const at::TensorBase& self) const;
+
+  void set_grad_dtype(
+      const std::optional<at::ScalarType>& grad_dtype,
+      const at::TensorBase& self);
+
   AutogradMeta(
       at::TensorImpl* self_impl = nullptr,
       bool requires_grad = false,
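A sketch exercising the two new AutogradMeta accessors via the existing torch::autograd::impl::get_autograd_meta helper (whether an explicit set also requires allow_grad_dtype_mismatch_ for non-matching dtypes is an assumption left open here):

void force_grad_dtype(const at::TensorBase& t, at::ScalarType dt) {
  if (auto* meta = torch::autograd::impl::get_autograd_meta(t)) {
    meta->set_grad_dtype(dt, t);     // setter added in this hunk
    auto cur = meta->grad_dtype(t);  // getter added in this hunk
    TORCH_INTERNAL_ASSERT(cur == dt);
  }
}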
@@ -940,6 +954,11 @@ struct VariableHooks final : at::impl::VariableHooksInterface {
       const c10::OperatorHandle& op,
       c10::DispatchKeySet dispatch_keys,
       torch::jit::Stack* stack) const override;
+  std::optional<c10::ScalarType> grad_dtype(
+      const at::TensorBase&) const override;
+  void set_grad_dtype(
+      const at::TensorBase&,
+      const std::optional<c10::ScalarType>&) const override;
 };
 
 namespace utils {
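
These overrides follow the existing VariableHooksInterface pattern: grad-dtype queries made at the ATen level route into autograd through the registered hooks. A sketch of that indirection (the free function is hypothetical; at::impl::GetVariableHooks is the existing hooks accessor in ATen):

// Hypothetical illustration of the hooks indirection used above.
std::optional<c10::ScalarType> grad_dtype_of(const at::TensorBase& self) {
  return at::impl::GetVariableHooks()->grad_dtype(self);
}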