mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "Support setting grad_dtype on leaf tensors (#162815)"
This reverts commit dca73982c53e9f99f96246b5d9ed9bab83c7423f. Reverted https://github.com/pytorch/pytorch/pull/162815 on behalf of https://github.com/yangw-dev due to break internal test D83850533, see more details below ([comment](https://github.com/pytorch/pytorch/pull/162815#issuecomment-3367498501))
This commit is contained in:
@ -128,7 +128,6 @@ TORCH_API void set_grad_accumulator(
|
||||
/// if it still exists. If the gradient accumulator function has been
|
||||
/// destroyed, returns a `nullptr`.
|
||||
TORCH_API std::shared_ptr<Node> try_get_grad_accumulator(const Variable&);
|
||||
TORCH_API std::shared_ptr<Node> try_get_grad_accumulator(const at::TensorBase&);
|
||||
|
||||
/// Gets the gradient accumulator of the `Variable` if it has one, or else
|
||||
/// create one on the fly and return it.
|
||||
@ -254,13 +253,6 @@ struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface {
|
||||
// correctly when this variable is passed to another function.
|
||||
uint32_t output_nr_;
|
||||
|
||||
// The dtype of the grad field; when nullopt, defaults to tensor's dtype.
|
||||
std::optional<at::ScalarType> grad_dtype_;
|
||||
|
||||
// When true, allows gradient dtype to be different from tensor dtype,
|
||||
// bypassing dtype casting and validation in the autograd engine.
|
||||
bool allow_grad_dtype_mismatch_{false};
|
||||
|
||||
// Mutex to ensure that concurrent read operations that modify internal
|
||||
// state are still thread-safe. Used by grad_fn(), grad_accumulator(),
|
||||
// fw_grad() and set_fw_grad()
|
||||
@ -301,12 +293,6 @@ struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface {
|
||||
uint64_t level,
|
||||
bool is_inplace_op) override;
|
||||
|
||||
std::optional<at::ScalarType> grad_dtype(const at::TensorBase& self) const;
|
||||
|
||||
void set_grad_dtype(
|
||||
const std::optional<at::ScalarType>& grad_dtype,
|
||||
const at::TensorBase& self);
|
||||
|
||||
AutogradMeta(
|
||||
at::TensorImpl* self_impl = nullptr,
|
||||
bool requires_grad = false,
|
||||
@ -954,11 +940,6 @@ struct VariableHooks final : at::impl::VariableHooksInterface {
|
||||
const c10::OperatorHandle& op,
|
||||
c10::DispatchKeySet dispatch_keys,
|
||||
torch::jit::Stack* stack) const override;
|
||||
std::optional<c10::ScalarType> grad_dtype(
|
||||
const at::TensorBase&) const override;
|
||||
void set_grad_dtype(
|
||||
const at::TensorBase&,
|
||||
const std::optional<c10::ScalarType>&) const override;
|
||||
};
|
||||
|
||||
namespace utils {
|
||||
|
Reference in New Issue
Block a user