Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-21 13:44:15 +08:00)
This handles the case where the tensor isn't an input. The changes to dynamo tests are cases where we would previously fall back to eager.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120965
Approved by: https://github.com/yanboliang
ghstack dependencies: #121735
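A minimal sketch (an assumed illustration, not code from the PR or its tests) of the kind of case the message describes: an nn.Parameter constructed from a tensor that is computed inside a compiled function rather than passed in as a graph input.

    import torch

    @torch.compile(fullgraph=True)
    def fn(x):
        # The tensor wrapped by nn.Parameter is produced inside the graph,
        # so it is not itself an input to the graph.
        return torch.nn.Parameter(x * 2)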
51 lines
2.0 KiB
Python
import torch
from torch._prims import _make_prim, RETURN_TYPE
from torch._prims_common import clone_preserve_strides

doc = """
This is used when dynamo traces torch.nn.Parameter, which normally would not trace properly
with AOTAutograd. We instead create a placeholder torch.nn.Parameter before the graph, which
becomes a graph arg and has no storage backing it. At the point in the graph where the parameter
actually should be created we mutate this sacrificial placeholder into it. This allows gradients
to flow into the parameter as if it were an input to the graph (which is the only thing we are
allowed to compute gradients on).
""".strip()

_bind_nn_parameter = _make_prim(
    schema="_bind_nn_parameter(Tensor self, Tensor placeholder) -> Tensor",
    return_type=RETURN_TYPE.NEW,
    meta=lambda self, placeholder: torch.nn.Parameter(
        clone_preserve_strides(self), placeholder.requires_grad
    ),
    # At runtime, point the sacrificial placeholder at self's storage and return it.
    impl_aten=lambda self, placeholder: placeholder.set_(self),
    doc=doc,
)
# Marked side-effectful so FX dead-code elimination never drops the bind.
torch.fx.node.has_side_effect(_bind_nn_parameter)


class TracableCreateParameter(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor, placeholder):
        assert not tensor.requires_grad
        return _bind_nn_parameter(tensor, placeholder)

    @staticmethod
    def backward(ctx, grad):
        return None, grad  # grad flows to placeholder


def tracable_create_parameter(tensor, placeholder):
    with torch.set_grad_enabled(placeholder.requires_grad):
        return TracableCreateParameter.apply(tensor, placeholder)


def new_parameter_placeholder(size, dtype, device, requires_grad):
    """Create a placeholder to be passed to the above functions"""
    result = torch.nn.Parameter(
        torch.empty(size, dtype=dtype, device=device), requires_grad=requires_grad
    )
    # TODO(jansel): alloc followed by free is inefficient, need a way to allocate an unbacked tensor.
    # Allocating a zero tensor would cause assert failures in autograd.
    result.untyped_storage().resize_(0)
    return result
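# Usage sketch (illustrative only, not part of the original file): how a tracer
# could combine the helpers above. The shape/dtype values are assumptions made
# for the example.
#
#   data = torch.randn(8, 4)            # values computed inside the graph
#   placeholder = new_parameter_placeholder(
#       data.size(), data.dtype, data.device, requires_grad=True
#   )                                    # storage-less Parameter created before the graph
#   param = tracable_create_parameter(data, placeholder)
#   # `param` acts like a Parameter created at this point in the graph, while
#   # gradients flow back into `placeholder`, which is a real graph input.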