Since the functional autograd + compiled autograd migration, we don't trace into nodes anymore, and everything is lifted. We can't support this flag which tries to inline make_fx style in CA initial pass. There's no more usage internally.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146720
Approved by: https://github.com/zou3519
We will always proxy autograd.Function nodes in compiled autograd's
initial graph capture (previously there was an
option to proxy vs trace into the autograd.Function)
We have some requirements for the AOTBackward. Compiled Autograd runs
accumulate grad reordering passes on the AOTBackward graph directly
after the initial graph capture, so we can't just proxy a single node for it.
Instead, we:
- proxy the AOTBackward prologue function into the CA graph
- copy-paste the AOTBackward graph into the CA graph
- trace directly through the epilogue (the traced nodes go into the CA
graph).
Tracing through the epilogue is safe (assuming no Tensor subclasses)
because the only thing the epilogue does is drop some outputs. The
Tensor subclass situation was already broken so this doesn't regress
anything but this PR sets it up to be fixed (in a followup, where we
will proxy "make_subclass" calls into the graph from the epilogue).
Test Plan:
- existing tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143405
Approved by: https://github.com/jansel, https://github.com/xmfan
ghstack dependencies: #143296, #143304, #143387
This PR is on the way to getting compiled autograd's initial capture to
stop specializing on Tensor metadata.
This PR changes compiled autograd's initial capture to proxy an opaque
(w.r.t. Dynamo) function into the graph for all built-in codegen'ed
autograd nodes and validate_outputs.
We changed each codegen'ed apply_with_saved (e.g.
MulBackward0::apply_with_saved) to call into Python to proxy a function
(compiled_autograd.ops.MulBackward0) into the graph. Then, we use the
node's InputMetadata to "guess" at the properties of the output Tensors
to create some new FakeTensors.
Some details:
- MulBackward0::apply_with_saved lives in libtorch_cpu, but needs to be
call to Python via libtorch_python. There is an indirection
(PyCompilerInterface) to do this.
- MulBackward0::apply_with_saved passes a C++ function to Python. To make
our lives easier, every codegen'ed apply_with_saved passes a C++
function with the same signature
`(variable_list, ivalue_list) -> variable_list`.
- We define how to pack arbitrary C++ types into IValue via a helper
IValuePacker struct and codegen functional variants of each builtin
C++ autograd node (e.g. MulBackward0_apply_functional_ivalue).
MulBackward0 before this PR:
https://gist.github.com/zou3519/a80381d5fa38e970e413fcd91b0530de
MulBackward0 after this PR:
https://gist.github.com/zou3519/0c2eee8b3d8d96232b51ef430b53c5b0
Test Plan:
- existing tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143296
Approved by: https://github.com/jansel
For `autograd.Function`, the engine will try to allocate correctly-shaped zeros for `None` grads (i.e. in the case where the output isn't used downstream). It determines the shape of these zeros from the `VariableInfo` entry, which is derived from the forward output shape. For the NJT forward output case, the size info stored will contain a nested int, and calling `zeros()` with this size throws:
```
RuntimeError: .../build/aten/src/ATen/RegisterCPU.cpp:5260: SymIntArrayRef expected to contain only concrete integers
```
This PR fixes this by storing the full tensor in the `VariableInfo` for the nested case and calling `zeros_like()` to allocate correctly-shaped zeros. This is pretty inefficient; ideally we would want to save just the NJT shape and be able to construct zeros from it, but this requires factory function support for nested ints (WIP). So this is a short-term fix until we have that.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136875
Approved by: https://github.com/soulitzer, https://github.com/huydhn
For `autograd.Function`, the engine will try to allocate correctly-shaped zeros for `None` grads (i.e. in the case where the output isn't used downstream). It determines the shape of these zeros from the `VariableInfo` entry, which is derived from the forward output shape. For the NJT forward output case, the size info stored will contain a nested int, and calling `zeros()` with this size throws:
```
RuntimeError: .../build/aten/src/ATen/RegisterCPU.cpp:5260: SymIntArrayRef expected to contain only concrete integers
```
This PR fixes this by storing the full tensor in the `VariableInfo` for the nested case and calling `zeros_like()` to allocate correctly-shaped zeros. This is pretty inefficient; ideally we would want to save just the NJT shape and be able to construct zeros from it, but this requires factory function support for nested ints (WIP). So this is a short-term fix until we have that.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136875
Approved by: https://github.com/soulitzer
This is useful for splitting grad to run in two parts while preserving intermediates:
<details>
<summary>
Click to see code
</summary>
```python
import collections
import weakref
from torch.autograd.graph import GradientEdge
def _get_grad_fn_or_grad_acc(t):
if t.requires_grad and t.grad_fn is None:
return t.view_as(t).grad_fn.next_functions[0][0]
else:
return t.grad_fn
def reverse_closure(roots, target_nodes):
# Recurse until we reach a target node
closure = set()
actual_target_nodes = set()
q: Deque = collections.deque()
for node in roots:
if node is not None and node not in closure:
closure.add(node)
q.append(node)
while q:
node = q.popleft()
reverse_edges = node.metadata.get("reverse_edges", [])
for holder_ref, idx in reverse_edges:
ref = holder_ref()
if ref is not None:
raise RuntimeError("Reverse graph is no longer alive")
fn = ref.node
if fn in closure or fn is None:
continue
if fn in target_nodes:
actual_target_nodes.add(fn)
continue
closure.add(fn)
q.append(fn)
return closure, actual_target_nodes
# Enable weak pointer
class Holder():
def __init__(self, node):
self.node = node
# TODO: use weak references to avoid reference cycle
def construct_reverse_graph(roots):
q: Deque = collections.deque()
root_seen = set()
reverse_graph_refs = []
for node in roots:
if node is not None and node not in root_seen:
q.append(node)
root_seen.add(node)
while q:
node = q.popleft()
for fn, idx in node.next_functions:
if fn is not None:
# Don't necessarily need to store on the graph
reverse_edges = fn.metadata.get("reverse_edges", [])
if len(reverse_edges) == 0:
q.append(fn)
holder = Holder(node)
holder_ref = weakref.ref(holder)
reverse_graph_refs.append(holder)
reverse_edges.append((holder_ref, idx))
fn.metadata["reverse_edges"] = reverse_edges
return reverse_graph_refs
def get_param_groups(inputs, params):
inputs_closure, _ = reverse_closure(inputs, set())
param_groups = dict() # keyed on intermediates
for i, param in enumerate(params):
closure, intersected = reverse_closure([param], inputs_closure)
param_group = {
"params": set([param]),
"intermediates": set(intersected),
}
for input_node in intersected:
existing = param_groups.get(input_node, None)
if existing is not None:
existing["params"] = existing["params"].union(param_group["params"])
existing["intermediates"] = existing["intermediates"].union(param_group["intermediates"])
param_group = existing
else:
param_groups[input_node] = param_group
# Sanity check: union of all param_groups params should be equal to all params
union_params = set()
seen_ids = set()
unique_param_groups = []
for param_group in param_groups.values():
if id(param_group) not in seen_ids:
seen_ids.add(id(param_group))
unique_param_groups.append(param_group)
union_params = union_params.union(param_group["params"])
assert union_params == set(params)
return unique_param_groups
def compute_grads_only_inputs2(roots, inps, weights):
root_grad_fns = list(map(_get_grad_fn_or_grad_acc, roots))
inp_grad_fns = list(map(_get_grad_fn_or_grad_acc, inps))
weight_grad_fns = list(map(_get_grad_fn_or_grad_acc, weights))
reverse_graph_refs = construct_reverse_graph(root_grad_fns)
param_groups = get_param_groups(inp_grad_fns, weight_grad_fns)
del reverse_graph_refs
for param_group in param_groups:
for i, intermediate in enumerate(param_group["intermediates"]):
def get_hook(param_group, i):
def hook(grad_inputs):
if param_group.get("grads", None) is None:
param_group["grads"] = [None] * len(param_group["intermediates"])
param_group["grads"][i] = grad_inputs
return hook
# These are always "split" nodes that we need to recompute, so
# save their inputs.
intermediate.register_prehook(get_hook(param_group, i))
dinputs = torch.autograd.grad((out,), inputs=tuple(inps), grad_outputs=(torch.ones_like(out),), retain_graph=True)
return dinputs, param_groups
def compute_grads_only_weights2(user_weights, param_groups):
all_dweights = dict()
for param_group in param_groups:
# TODO: Handle case where intermediate can have multiple outputs
intermediate_edges = tuple(GradientEdge(i, 0) for i in param_group["intermediates"])
weights_edges = tuple(GradientEdge(w, 0) for w in param_group["params"])
assert all(len(g) == 1 for g in param_group["grads"])
# [NEW!] Able to pass a GradientEdge to autograd.grad as output
# We do not need to retain_graph because... guarantee no overlap?
print("trying to execute: ", intermediate_edges, weights_edges)
dweights = torch.autograd.grad(intermediate_edges, weights_edges, grad_outputs=sum(param_group["grads"], tuple()))
for w, dw in zip(param_group["params"], dweights):
all_dweights[w] = dw
# return grads in the original order weights were provided in
out = []
for w in user_weights:
grad_acc = _get_grad_fn_or_grad_acc(w)
out.append(all_dweights[grad_acc])
return tuple(out)
```
</details>
```python
import torch.nn as nn
# Setup
mod1 = nn.Linear(10, 10)
mod2 = nn.Linear(10, 10)
a = torch.rand(10, requires_grad=True)
weights = tuple(mod1.parameters()) + tuple(mod2.parameters())
inps = (a,)
out = mod2(mod1(a))
class LoggingTensorMode(torch.utils._python_dispatch.TorchDispatchMode):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
rs = func(*args, **kwargs)
print(f"{func.__module__}.{func.__name__}")
return rs
print(" -- SPLIT -- ")
# Compute gradients in two parts
with LoggingTensorMode():
print("PART 1")
dinputs, state = compute_grads_only_inputs2((out,), inps, weights)
print("PART 2")
dweights = compute_grads_only_weights2(weights, state)
out = mod2(mod1(a))
print(" -- REF -- ")
# Compare with reference
with LoggingTensorMode():
ref_all_gradients = torch.autograd.grad(out, inputs=tuple(inps) + weights, grad_outputs=(torch.ones_like(out),))
for actual, ref in zip(dinputs + dweights, ref_all_gradients):
print(torch.allclose(actual, ref))
```
<img width="598" alt="image" src="https://github.com/pytorch/pytorch/assets/13428986/3681b8a7-3ab4-4d1d-a836-abef6913e671">
```
PART 1
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.ones_like.default
V0603 10:17:21.590878 8300067520 torch/autograd/graph.py:751] Executing: <ViewBackward0 object at 0x12a1ee160> with grad_outputs: [f32[10]]
torch._ops.aten.view.default
V0603 10:17:21.591204 8300067520 torch/autograd/graph.py:751] Executing: <AddmmBackward0 object at 0x12a1ee0d0> with grad_outputs: [f32[1, 10]]
torch._ops.aten.t.default
torch._ops.aten.mm.default
V0603 10:17:21.591578 8300067520 torch/autograd/graph.py:751] Executing: <ViewBackward0 object at 0x100d7ae50> with grad_outputs: [f32[1, 10]]
torch._ops.aten.view.default
V0603 10:17:21.591747 8300067520 torch/autograd/graph.py:751] Executing: <ViewBackward0 object at 0x12a1e4a60> with grad_outputs: [f32[10]]
torch._ops.aten.view.default
V0603 10:17:21.591834 8300067520 torch/autograd/graph.py:751] Executing: <AddmmBackward0 object at 0x12a1e4bb0> with grad_outputs: [f32[1, 10]]
torch._ops.aten.t.default
torch._ops.aten.mm.default
V0603 10:17:21.591922 8300067520 torch/autograd/graph.py:751] Executing: <ViewBackward0 object at 0x12a1e4a90> with grad_outputs: [f32[1, 10]]
torch._ops.aten.view.default
PART 2
trying to execute: (GradientEdge(node=<AddmmBackward0 object at 0x12a1e4bb0>, output_nr=0),) (GradientEdge(node=<AccumulateGrad object at 0x12a21b130>, output_nr=0), GradientEdge(node=<AccumulateGrad object at 0x12a21b7c0>, output_nr=0))
V0603 10:17:21.592223 8300067520 torch/autograd/graph.py:751] Executing: <AddmmBackward0 object at 0x12a1e4bb0> with grad_outputs: [f32[1, 10]]
torch._ops.aten.t.default
torch._ops.aten.mm.default
torch._ops.aten.t.default
torch._ops.aten.sum.dim_IntList
torch._ops.aten.view.default
V0603 10:17:21.592421 8300067520 torch/autograd/graph.py:751] Executing: <TBackward0 object at 0x12a1cad60> with grad_outputs: [f32[10, 10]]
torch._ops.aten.t.default
trying to execute: (GradientEdge(node=<AddmmBackward0 object at 0x12a1ee0d0>, output_nr=0),) (GradientEdge(node=<AccumulateGrad object at 0x12a1e41c0>, output_nr=0), GradientEdge(node=<AccumulateGrad object at 0x12a21b670>, output_nr=0))
V0603 10:17:21.593481 8300067520 torch/autograd/graph.py:751] Executing: <AddmmBackward0 object at 0x12a1ee0d0> with grad_outputs: [f32[1, 10]]
torch._ops.aten.t.default
torch._ops.aten.mm.default
torch._ops.aten.t.default
torch._ops.aten.sum.dim_IntList
torch._ops.aten.view.default
V0603 10:17:21.593750 8300067520 torch/autograd/graph.py:751] Executing: <TBackward0 object at 0x12a21b2b0> with grad_outputs: [f32[10, 10]]
torch._ops.aten.t.default
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.view.default
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127766
Approved by: https://github.com/albanD
Fixes https://github.com/pytorch/pytorch/issues/128544
Fixes https://github.com/pytorch/pytorch/issues/128535
We had a problem with multithreading where the nonlocals were being
clobbered. In the first place, we stored these nonlocals because we
wanted to ferry information from an autograd.Function.apply to
autograd.Function.forward.
Our new approach is:
- pass the information directly as an input to the
autograd.Function.apply. This means that the autograd.Function.forward
will receive the information too.
- this messes up ctx.needs_input_grad, which has an element per input to
forward. The user should not see the additional information we passed.
We fix this by temporarily overriding ctx.needs_input_grad to the
right thing.
- this exposed a bug in that ctx.needs_input_grad wasn't correct for
TensorList inputs. This PR fixes that too.
Test Plan:
- existing and new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128547
Approved by: https://github.com/williamwen42, https://github.com/soulitzer
- `FakeContext` hides all fields other than ctx.saved_tensors, this dynamo errors when the autograd.Function.backward uses other attrs on ctx and it also doesn't allow fallback to eager.
- If we remove it, we still can't fallback to eager: node variables are already freed (ctx.saved_tensors throws)
- However, we can fallback to "pseudo-eager" by using a duck-typed ctx and routing the ctx.saved_tensors to lifted tensors
- Dynamo tries to inline external_utils.call_backward, treats BackwardCFunction as a AutogradFunctionContextVariable (only used up until we create the fake context: FakeBackwardCFunction)
- we call_function backward from the forward class AutogradFunctionVariable, and we still pass in the fake context as a UserDefinedObjectVariable (can later use AutogradFunctionContextVariable + HOO graph speculate)
Fixes#125489#124827
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125661
Approved by: https://github.com/jansel
- Adds support for custom ops backed by c++ custom autograd functions, e.g. fbgemm
- Include files more granularly to avoid namespace pollution and circular imports
limitations:
- requires user to audit their code and opt-in their custom autograd::Function via autograd::Function::is_traceable and maybe additional compiled_args + apply_with_saved implementation. this was the only way I can think of for soundness
- will throw if we can't hash the saved_data i.e. for any non implemented type other than list and dict in at::IValue::hash b0cfa96e82/aten/src/ATen/core/ivalue.cpp (L364)
- can technically silently fail if both the typeid hash and the typeid string name of the custom autograd::Function collide at the same time, and an identical autograd graph containing a different custom autograd::Function, yet that has an identical implementation, is called. this case seems extremely unlikely, and the only alternative to hash collision i can think of is compiling with reflection
- tensors not saved via save_variables are not lifted, and are specialized on TensorImpl*'s hash (treated as a memory address). if needed, we can lift them.
Differential Revision: [D54818488](https://our.internmc.facebook.com/intern/diff/D54818488)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120681
Approved by: https://github.com/jansel
- Adds support for custom ops backed by c++ custom autograd functions, e.g. fbgemm
- Include files more granularly to avoid namespace pollution and circular imports
limitations:
- requires user to audit their code and opt-in their custom autograd::Function via autograd::Function::is_traceable and maybe additional compiled_args + apply_with_saved implementation. this was the only way I can think of for soundness
- will throw if we can't hash the saved_data i.e. for any non implemented type other than list and dict in at::IValue::hash b0cfa96e82/aten/src/ATen/core/ivalue.cpp (L364)
- can technically silently fail if both the typeid hash and the typeid string name of the custom autograd::Function collide at the same time, and an identical autograd graph containing a different custom autograd::Function, yet that has an identical implementation, is called. this case seems extremely unlikely, and the only alternative to hash collision i can think of is compiling with reflection
- tensors not saved via save_variables are not lifted, and are specialized on TensorImpl*'s hash (treated as a memory address). if needed, we can lift them.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120681
Approved by: https://github.com/jansel
- Adds support for custom ops backed by c++ custom autograd functions, e.g. fbgemm
- Include files more granularly to avoid namespace pollution and circular imports
limitations:
- requires user to audit their code and opt-in their custom autograd::Function via autograd::Function::is_traceable and maybe additional compiled_args + apply_with_saved implementation. this was the only way I can think of for soundness
- will throw if we can't hash the saved_data i.e. for any non implemented type other than list and dict in at::IValue::hash b0cfa96e82/aten/src/ATen/core/ivalue.cpp (L364)
- can technically silently fail if both the typeid hash and the typeid string name of the custom autograd::Function collide at the same time, and an identical autograd graph containing a different custom autograd::Function, yet that has an identical implementation, is called. this case seems extremely unlikely, and the only alternative to hash collision i can think of is compiling with reflection
- tensors not saved via save_variables are not lifted, and are specialized on TensorImpl*'s hash (treated as a memory address). if needed, we can lift them.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120681
Approved by: https://github.com/jansel
RECORD_FUNCTION in python_function only captures argument that is a Tensor. However, it is very common for user to use non tensor arguments in custom ops, for example, sequence length in GPT attention custom op. My previous PR tries to capture all non-tensor arguments, it turned out in some cases, it is very expensive.
This PR is to support primitive (or its container) arguments in RECORD_FUNCTION.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120949
Approved by: https://github.com/soulitzer
This adds support for backwards hooks that are *both*:
1) Interior to the graph; and
2) Dynamically generated (e.g. lambdas)
We do this by creating a BackwardState object that is used to register the hooks in the forward, then populated by dynamo *after* the forwards runs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120382
Approved by: https://github.com/xmfan
Summary: RECORD_FUNCTION only capture the argument when it is a Tensor. However, it is very common for user to use the argument with primitive data type (int, float, index, bool). This DIFF is to support non tensor arguments in RECORD_FUNCTION.
Test Plan:
unit test
buck test mode/dev-nosan caffe2/test:profiler -- test_execution_trace_with_pt2 test_execution_trace_alone test_execution_trace_with_kineto test_execution_trace_start_stop test_execution_trace_repeat_in_loop test_execution_trace_no_capture
Differential Revision: D53674768
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120017
Approved by: https://github.com/soulitzer
This PR adds support for torch.autograd.Function subclasses in compiled autograd. We do this by:
- Creating a uid for all torch.autograd.Function via its metaclass. This uid is used in the compiled autograd key, which is a subset of the cache key to the compiled graph
- "Lifting" the backward/saved_tensors, having them as input arguments in the compiled graph
- Creating proxies to track the backward's inputs and outputs. Since the backward's outputs (grads) have to match the forward's inputs, we pass the node's `input_info` (forward's input sizes) to build the proxies tracking the backward's outputs.
- Use a `FakeContext` class as a replacement for the autograd node's context object (`BackwardCFunction`) during tracing, only support passing saved_tensors from the forward to the backward
- Index each backward, to support multiple torch.autograd.Functions in the same graph
- Special case for `CompiledFunctionBackward`, lifting CompiledFunction will fail 4 tests and requires some skipfiles changes that I'd rather do that in a separate PR
Example graph: test_custom_fn_saved_multiple_tensors (eager fw + compiled autograd)
```python
class MyFn(torch.autograd.Function):
@staticmethod
def forward(ctx, x, y):
ctx.save_for_backward(x, y)
return torch.sin(x), torch.sin(y)
@staticmethod
def backward(ctx, gO_x, gO_y):
(x, y) = ctx.saved_tensors
return gO_x * torch.cos(x), gO_y * torch.cos(y)
```
The backwards is lifted via `getitem_5` and `call_backward`
```python
# Compiled autograd graph
===== Compiled autograd graph =====
<eval_with_key>.0 class CompiledAutograd(torch.nn.Module):
def forward(self, inputs, sizes, hooks):
# No stacktrace found for following nodes
getitem: "f32[]" = inputs[0]
getitem_1: "f32[10]" = inputs[1]
getitem_2: "f32[10]" = inputs[2]
getitem_3: "f32[10]" = inputs[3]
getitem_4: "f32[10]" = inputs[4]; inputs = None
expand: "f32[10]" = torch.ops.aten.expand.default(getitem, [10]); getitem = None
mul: "f32[10]" = torch.ops.aten.mul.Tensor(expand, getitem_2); getitem_2 = None
mul_1: "f32[10]" = torch.ops.aten.mul.Tensor(expand, getitem_1); expand = getitem_1 = None
getitem_5 = hooks[0]; hooks = None
call_backward = torch__dynamo_external_utils_call_backward(getitem_5, (getitem_3, getitem_4), mul_1, mul); getitem_5 = mul_1 = mul = None
getitem_6: "f32[10]" = call_backward[0]
getitem_7: "f32[10]" = call_backward[1]; call_backward = None
accumulate_grad_ = torch.ops.inductor.accumulate_grad_.default(getitem_4, getitem_7); getitem_4 = getitem_7 = None
accumulate_grad__1 = torch.ops.inductor.accumulate_grad_.default(getitem_3, getitem_6); getitem_3 = getitem_6 = None
return []
```
then is later inlined by dynamo
```python
# Dynamo graph
===== __compiled_fn_0 =====
<eval_with_key>.1 class GraphModule(torch.nn.Module):
def forward(self, L_inputs_0_ : torch.Tensor, L_inputs_1_ : torch.Tensor, L_inputs_2_ : torch.Tensor, L_inputs_3_ : torch.Tensor, L_inputs_4_ : torch.Tensor):
getitem = L_inputs_0_
getitem_1 = L_inputs_1_
getitem_2 = L_inputs_2_
x = L_inputs_3_
y = L_inputs_4_
# File: <eval_with_key>.0:10, code: expand = torch.ops.aten.expand.default(getitem, [10]); getitem = None
expand = torch.ops.aten.expand.default(getitem, [10]); getitem = None
# File: <eval_with_key>.0:11, code: mul = torch.ops.aten.mul.Tensor(expand, getitem_2); getitem_2 = None
mul = torch.ops.aten.mul.Tensor(expand, getitem_2); getitem_2 = None
# File: <eval_with_key>.0:12, code: mul_1 = torch.ops.aten.mul.Tensor(expand, getitem_1); expand = getitem_1 = None
mul_1 = torch.ops.aten.mul.Tensor(expand, getitem_1); expand = getitem_1 = None
# File: /data/users/xmfan/core/pytorch/test/inductor/test_compiled_autograd.py:412, code: return gO_x * torch.cos(x), gO_y * torch.cos(y)
cos = torch.cos(x)
getitem_6 = mul_1 * cos; mul_1 = cos = None
cos_1 = torch.cos(y)
getitem_7 = mul * cos_1; mul = cos_1 = None
# File: <eval_with_key>.0:17, code: accumulate_grad_ = torch.ops.inductor.accumulate_grad_.default(getitem_4, getitem_7); getitem_4 = getitem_7 = None
accumulate_grad__default = torch.ops.inductor.accumulate_grad_.default(y, getitem_7); y = getitem_7 = None
# File: <eval_with_key>.0:18, code: accumulate_grad__1 = torch.ops.inductor.accumulate_grad_.default(getitem_3, getitem_6); getitem_3 = getitem_6 = None
accumulate_grad__default_1 = torch.ops.inductor.accumulate_grad_.default(x, getitem_6); x = getitem_6 = None
return ()
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115573
Approved by: https://github.com/jansel
Summary:
In some cases (especially those involving collective calls) - we would want to always kick off a collective call first before running going down another path.
For example:
```
tbe lookup -> a2a ->
overarch
dense ------------->
```
if the forward code is written as
a2a_out = a2a
dense = dense_net
out = overarch(a2a_out, dense)
out.backward()
The current default is running backwards in the opposite order the forward is called. However, there is no data dependency between a2a and dense, so in reality either of them could be run first. We would like the a2a to run first because it provides optimal (on average) overlap.
Changing the seq_nr of a2a_out to something large enough would allow autograd engine to kick it off first.
Test Plan: Tests incoming
Differential Revision: D51445261
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114120
Approved by: https://github.com/ezyang, https://github.com/albanD
The existing try-catch doesn't work because it doesn't call err.persist(). This is in contrast to the try-catch for evaluate_function which does work because it calls into python_engine's thread_on_exception which calls persist.
Calling persist on a python_error stashes the PyErr state from the thread-local PyThreadState onto the python_error object, so that when this error object is stored onto the future and passed back to the calling cpu thread, python_engine's execute try-catch can then err.restore() the error state. Finally, the python_engine's execute would re-raise so that this is re-caught by the HANDLE_TH_ERRORS macro.
Fixes https://github.com/pytorch/pytorch/issues/75750
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113702
Approved by: https://github.com/albanD
This can be useful for advanced users (like AOTAutograd) who don't want to keep the corresponding Tensor alive (for memory reasons for example) or when inplace op will change the Tensor's grad_fn (but gradients wrt to the original value is needed).
I went minimal API change but open to suggestions.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110867
Approved by: https://github.com/soulitzer