mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Support setting grad_dtype on leaf tensors (#162815)
`grad_dtype` is a new attribute on Tensor to control gradient dtype: - Access/setting is leaf-only. - grad_dtype is respected when (1) when assigning to .grad, and (2) in the engine after the previous node produces incoming gradients for AccumulateGrad. (See table below for details) - Not setting grad_dtype preserves the current behavior. Accessing it returns `t.dtype` - `grad_dtype` cannot be set when there is already a `.grad` present and the dtypes conflict. | `grad_dtype` setting | Setting `.grad` manually | Incoming gradient from autograd engine | |-----------------------|--------------------------|-----------------------------------------| | **Default (tensor’s dtype)** | `.grad` must match tensor’s dtype | Engine casts incoming grad to tensor’s dtype | | **Set to specific dtype** | `.grad` must match that dtype | Engine casts incoming grad to the specified dtype | | **Set to `None`** | `.grad` may be any dtype | Engine does not cast; accepts incoming grad dtype as-is | Pull Request resolved: https://github.com/pytorch/pytorch/pull/162815 Approved by: https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
43848b71d9
commit
dca73982c5
@ -326,6 +326,8 @@ class TestTensorBuiltins(JitTestCase):
|
||||
# This has a longer implementation, maybe not worth copying to
|
||||
# TorchScript if named tensors don't work there anyways
|
||||
"names",
|
||||
# We don't plan to support grad_dtype in TorchScript
|
||||
"grad_dtype",
|
||||
}
|
||||
|
||||
for p in properties:
|
||||
|
Reference in New Issue
Block a user