pytorch/torch/_dynamo/create_parameter_op.py
Jason Ansel 0b7d9711d4 [dynamo] Add support for nn.Parameter constructor (part 2) (#120965)
This handles the case where the tensor passed to the nn.Parameter constructor isn't a graph input.

The changes to dynamo tests are cases where we would previously fall back to eager.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120965
Approved by: https://github.com/yanboliang
ghstack dependencies: #121735
2024-03-16 04:29:58 +00:00
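A minimal sketch of the kind of code this change targets (illustrative only; the function and shapes are made up): an nn.Parameter constructed inside a compiled function from a tensor that is created in the traced region rather than passed in.

import torch

@torch.compile(fullgraph=True)
def fn(x):
    # the tensor handed to nn.Parameter is built inside the traced region,
    # not received as a graph input
    param = torch.nn.Parameter(torch.ones(4, 4) * 2)
    return x @ param

out = fn(torch.randn(2, 4))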


import torch

from torch._prims import _make_prim, RETURN_TYPE
from torch._prims_common import clone_preserve_strides

doc = """
This is used when dynamo traces torch.nn.Parameter, which normally would not trace properly
with AOTAutograd. We instead create a placeholder torch.nn.Parameter before the graph, which
becomes a graph arg and has no storage backing it. At the point in the graph where the parameter
actually should be created we mutate this sacrificial placeholder into it. This allows gradients
to flow into the parameter as if it were an input to the graph (which is the only thing we are
allowed to compute gradients on).
""".strip()
_bind_nn_parameter = _make_prim(
    schema="_bind_nn_parameter(Tensor self, Tensor placeholder) -> Tensor",
    return_type=RETURN_TYPE.NEW,
    meta=lambda self, placeholder: torch.nn.Parameter(
        clone_preserve_strides(self), placeholder.requires_grad
    ),
    impl_aten=lambda self, placeholder: placeholder.set_(self),
    doc=doc,
)
torch.fx.node.has_side_effect(_bind_nn_parameter)


class TracableCreateParameter(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor, placeholder):
        assert not tensor.requires_grad
        return _bind_nn_parameter(tensor, placeholder)

    @staticmethod
    def backward(ctx, grad):
        return None, grad  # grad flows to placeholder


def tracable_create_parameter(tensor, placeholder):
    with torch.set_grad_enabled(placeholder.requires_grad):
        return TracableCreateParameter.apply(tensor, placeholder)


def new_parameter_placeholder(size, dtype, device, requires_grad):
"""Create a placeholder to be passed to the above functions"""
result = torch.nn.Parameter(
torch.empty(size, dtype=dtype, device=device), requires_grad=requires_grad
)
# TODO(jansel): alloc followed by free is inefficient, need a way to allocate an unbacked tensor.
# Allocating a zero tensor would causes assert failures in autograd.
result.untyped_storage().resize_(0)
return result
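
A hypothetical usage sketch (not part of this file) of how the pieces above are meant to compose: the placeholder is allocated ahead of the graph, and the computed tensor is bound to it at the point where nn.Parameter() was originally called.

placeholder = new_parameter_placeholder((4, 4), torch.float32, "cpu", requires_grad=True)
data = torch.ones(4, 4) * 2  # stands in for a tensor computed inside the traced region
param = tracable_create_parameter(data.detach(), placeholder)
# `param` behaves like an nn.Parameter; backward() routes gradients to `placeholder`.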