[reland] Allow setting grad_dtype on leaf tensors (#164751)

ghstack-source-id: e44b3941530be83a630ec93f1478eec741ffca2e
Pull-Request-resolved: https://github.com/pytorch/pytorch/pull/162815
Fixes #ISSUE_NUMBER

Relanding due to internal weirdness. Separate PR to codev w/o ghstack.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164751
Approved by: https://github.com/albanD
Committed by: PyTorch MergeBot
Parent: 001e1d2637
Commit: 71aefd5595
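
For context, a minimal end-to-end sketch of what this change enables, written against the AutogradMeta accessors added in the hunks below. It assumes the existing torch::autograd::impl::get_autograd_meta helper from this header can be used to reach the new setter; whether set_grad_dtype performs extra validation (e.g. leaf-only checks) is not visible in this diff, so treat it as an illustration rather than the established API.

#include <iostream>
#include <torch/torch.h>
#include <torch/csrc/autograd/variable.h>

int main() {
  // Leaf tensor that will accumulate a gradient.
  auto w = torch::zeros({4}, torch::requires_grad());

  // Ask autograd to store w.grad() in bfloat16 instead of w's own dtype,
  // via the AutogradMeta::set_grad_dtype declared in this commit.
  if (auto* meta = torch::autograd::impl::get_autograd_meta(w)) {
    meta->set_grad_dtype(at::kBFloat16, w);
  }

  auto loss = (w * 2).sum();
  loss.backward();

  // Under this sketch, the accumulated gradient follows the requested dtype.
  std::cout << w.grad().dtype() << "\n";  // expected: BFloat16
  return 0;
}

From Python, the same capability would presumably surface as a settable attribute on leaf tensors, but that binding is not part of the hunks shown here.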

@@ -128,6 +128,7 @@ TORCH_API void set_grad_accumulator(
 /// if it still exists. If the gradient accumulator function has been
 /// destroyed, returns a `nullptr`.
 TORCH_API std::shared_ptr<Node> try_get_grad_accumulator(const Variable&);
+TORCH_API std::shared_ptr<Node> try_get_grad_accumulator(const at::TensorBase&);
 
 /// Gets the gradient accumulator of the `Variable` if it has one, or else
 /// create one on the fly and return it.

@@ -253,6 +254,13 @@ struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface {
   // correctly when this variable is passed to another function.
   uint32_t output_nr_;
 
+  // The dtype of the grad field; when nullopt, defaults to tensor's dtype.
+  std::optional<at::ScalarType> grad_dtype_;
+
+  // When true, allows gradient dtype to be different from tensor dtype,
+  // bypassing dtype casting and validation in the autograd engine.
+  bool allow_grad_dtype_mismatch_{false};
+
   // Mutex to ensure that concurrent read operations that modify internal
   // state are still thread-safe. Used by grad_fn(), grad_accumulator(),
   // fw_grad() and set_fw_grad()
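
The two new fields interact: grad_dtype_ overrides the default gradient dtype (the tensor's own dtype), while allow_grad_dtype_mismatch_ skips the engine's dtype check entirely. A rough standalone sketch of that policy, using hypothetical helper names (the actual engine logic is not part of this hunk):

#include <optional>
#include <ATen/core/TensorBase.h>
#include <c10/util/Exception.h>

// "when nullopt, defaults to tensor's dtype"
at::ScalarType resolve_grad_dtype(
    const at::TensorBase& self,
    const std::optional<at::ScalarType>& grad_dtype) {
  return grad_dtype.value_or(self.scalar_type());
}

// "bypassing dtype casting and validation in the autograd engine"
void check_incoming_grad(
    const at::TensorBase& self,
    const at::TensorBase& grad,
    const std::optional<at::ScalarType>& grad_dtype,
    bool allow_grad_dtype_mismatch) {
  if (allow_grad_dtype_mismatch) {
    return;  // mismatches are explicitly permitted; no cast, no check
  }
  TORCH_CHECK(
      grad.scalar_type() == resolve_grad_dtype(self, grad_dtype),
      "expected gradient dtype ",
      resolve_grad_dtype(self, grad_dtype),
      " but got ",
      grad.scalar_type());
}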

@@ -293,6 +301,12 @@ struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface {
       uint64_t level,
       bool is_inplace_op) override;
 
+  std::optional<at::ScalarType> grad_dtype(const at::TensorBase& self) const;
+
+  void set_grad_dtype(
+      const std::optional<at::ScalarType>& grad_dtype,
+      const at::TensorBase& self);
+
   AutogradMeta(
       at::TensorImpl* self_impl = nullptr,
       bool requires_grad = false,
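
Note the argument order of the new setter: the optional dtype comes first and the owning tensor second. Passing std::nullopt presumably restores the default described by the grad_dtype_ comment above (gradient stored in the tensor's own dtype); that reset behaviour is an assumption, not something shown in the diff. A small sketch of both directions, again assuming get_autograd_meta is the way to reach the metadata:

#include <optional>
#include <torch/torch.h>
#include <torch/csrc/autograd/variable.h>

void toggle_grad_dtype(const at::Tensor& w) {
  auto* meta = torch::autograd::impl::get_autograd_meta(w);
  if (meta == nullptr) {
    return;  // no autograd metadata (e.g. requires_grad is false)
  }
  meta->set_grad_dtype(at::kBFloat16, w);  // override: accumulate grads in bfloat16
  meta->set_grad_dtype(std::nullopt, w);   // reset: presumably back to w's own dtype
}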

@@ -940,6 +954,11 @@ struct VariableHooks final : at::impl::VariableHooksInterface {
       const c10::OperatorHandle& op,
       c10::DispatchKeySet dispatch_keys,
       torch::jit::Stack* stack) const override;
+  std::optional<c10::ScalarType> grad_dtype(
+      const at::TensorBase&) const override;
+  void set_grad_dtype(
+      const at::TensorBase&,
+      const std::optional<c10::ScalarType>&) const override;
 };
 
 namespace utils {
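
The VariableHooks overrides above are the TensorBase-level entry points; their out-of-line definitions live in torch/csrc/autograd/variable.cpp and are not part of this diff. A plausible forwarding shape, assuming the existing get_autograd_meta / materialize_autograd_meta helpers from this header are used the same way as for the other hooks:

// Hypothetical sketch of the out-of-line definitions (not taken from the diff).
#include <torch/csrc/autograd/variable.h>

namespace torch::autograd {

std::optional<c10::ScalarType> VariableHooks::grad_dtype(
    const at::TensorBase& self) const {
  auto* meta = impl::get_autograd_meta(self);
  return meta ? meta->grad_dtype(self) : std::nullopt;
}

void VariableHooks::set_grad_dtype(
    const at::TensorBase& self,
    const std::optional<c10::ScalarType>& grad_dtype) const {
  impl::materialize_autograd_meta(self)->set_grad_dtype(grad_dtype, self);
}

} // namespace torch::autograd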