Files
pytorch/torch/_dynamo/create_parameter_op.py
Xuehai Pan e74ba1b34a [BE][Easy][15/19] enforce style for empty lines in import segments in torch/_d*/ (#129767)
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by the linter.

You can review these PRs via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129767
Approved by: https://github.com/anijain2305
2024-07-31 21:18:11 +00:00
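
For reference, the style being enforced separates import segments (standard library vs. third-party) with blank lines. A minimal sketch of the convention, using this file's own imports:

```python
# Standard-library imports form one segment...
import threading
from contextlib import contextmanager

# ...separated by a blank line from the third-party segment.
import torch
```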


# mypy: allow-untyped-defs
import threading
from contextlib import contextmanager

import torch

doc = """
This is used when dynamo traces torch.nn.Parameter, which normally would not trace properly
with AOTAutograd. We instead create a placeholder torch.nn.Parameter before the graph, which
becomes a graph arg and has no storage backing it. At the point in the graph where the parameter
actually should be created we mutate this sacrificial placeholder into it. This allows gradients
to flow into the parameter as if it were an input to the graph (which is the only thing we are
allowed to compute gradients on).
""".strip()


class TracableCreateParameter(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor, placeholder):
        assert not tensor.requires_grad
        return placeholder.set_(tensor)

    @staticmethod
    def backward(ctx, grad):
        return None, grad  # grad flows to placeholder


def tracable_create_parameter(tensor, placeholder):
    with torch.set_grad_enabled(placeholder.requires_grad):
        out = TracableCreateParameter.apply(tensor, placeholder)
    return out


def new_parameter_placeholder(size, dtype, device, requires_grad):
    """Create a placeholder to be passed to the above functions"""
    result = torch.nn.Parameter(
        torch.empty(size, dtype=dtype, device=device), requires_grad=requires_grad
    )
    # TODO(jansel): alloc followed by free is inefficient, need a way to allocate an unbacked tensor.
    # Allocating a zero tensor would cause assert failures in autograd.
    result.untyped_storage().resize_(0)
    return result


_TLS = threading.local()


@contextmanager
def do_not_convert_to_tracable_parameter():
    # Thread-local switch: temporarily disable conversion to tracable parameters.
    old_flag = getattr(_TLS, "convert_tracable_parameter", True)
    _TLS.convert_tracable_parameter = False
    try:
        yield False
    finally:
        _TLS.convert_tracable_parameter = old_flag


def can_convert_to_tracable_parameter():
    return getattr(_TLS, "convert_tracable_parameter", True)
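
For orientation, here is a minimal eager-mode sketch of how these helpers compose. It is illustrative only, not how dynamo invokes them internally; the shape, toy loss, and variable names are made up for the example.

```python
import torch

from torch._dynamo.create_parameter_op import (
    can_convert_to_tracable_parameter,
    do_not_convert_to_tracable_parameter,
    new_parameter_placeholder,
    tracable_create_parameter,
)

# Before the graph: a sacrificial placeholder parameter with no storage behind it.
placeholder = new_parameter_placeholder(
    size=(4,), dtype=torch.float32, device="cpu", requires_grad=True
)
assert placeholder.untyped_storage().nbytes() == 0

# At the point where the parameter should really be created: materialize the data
# (requires_grad must be False, per the assert in forward) and mutate the
# placeholder into it.
data = torch.randn(4)
param = tracable_create_parameter(data, placeholder)

# Gradients computed through `param` flow back to the placeholder, as if the
# parameter had been an input to the graph.
(param * 2.0).sum().backward()
print(placeholder.grad)  # expected: tensor([2., 2., 2., 2.])

# The thread-local flag lets callers opt out of the conversion for a region of code.
assert can_convert_to_tracable_parameter()
with do_not_convert_to_tracable_parameter():
    assert not can_convert_to_tracable_parameter()
assert can_convert_to_tracable_parameter()
```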