If node is AC region output and has a backward hook on it, we intentionally choose to save it.
This is to work around circular dependencies in Traceable FSDP2+AC.
Example:
```
out = fully_shard(utils.checkpoint(module))(x)
norm_out = layer_norm(out)
```
and there is a circular dependency:
1. In backward, grad_input of layer_norm aka. `out_grad` is actually dependent on `out`.
2. `out` depends on `out`'s backward hook created by FSDP2 (which does all-gather for `module` weights) in order to be recomputed.
3. `out`'s FSDP2 backward hook, as is the case for all eager backward hooks, depends on `out_grad` -> circular dependency with (1)!
Solution: check whether `out` has a backward hook, and if so, intentionally save `out` in forward graph outputs. With this, we can break the above circular dependency.
----
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135727
Approved by: https://github.com/Chillee
Looks like one of the first failures seen is `test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cuda` when `test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cuda` passes.
What seems interesting here is that the `torch.compile` version fails while the eager version passes. Not sure what the difference would be here...
Nevertheless, is there a recommended mechanism to skip cuDNN SDPA as a backend for this test? CC @drisspg
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125343
Approved by: https://github.com/Skylion007
Looks like one of the first failures seen is `test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cuda` when `test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cuda` passes.
What seems interesting here is that the `torch.compile` version fails while the eager version passes. Not sure what the difference would be here...
Nevertheless, is there a recommended mechanism to skip cuDNN SDPA as a backend for this test? CC @drisspg
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125343
Approved by: https://github.com/Skylion007
### bc-breaking for existing users of the private API:
- Existing policy functions must now change their return value to be [CheckpointPolicy](c0b40ab42e/torch/utils/checkpoint.py (L1204-L1230)) Enum instead of bool.
- To restore previous behavior, return `PREFER_RECOMPUTE` instead of `False` and `{PREFER,MUST}_SAVE` instead of `True` depending whether you prefer the compiler to override your policy.
- Policy function now accepts a `ctx` object instead of `mode` for its first argument.
- To restore previous behavior, `mode = "recompute" if ctx.is_recompute else "forward"`.
- Existing calls to `_pt2_selective_checkpoint_context_fn_gen` must be renamed to `create_selective_checkpoint_contexts `. The way you use the API remains the same. It would've been nice to do something different (not make the user have to use functools.partial?), but this was the easiest to compile (idk if this should actually be a constraint).
Related doc: https://docs.google.com/document/d/1BKyizkZPdri9mHqdDOLAUpkI7SbbKfLHRFVVpK9ZWqo/edit
Memory considerations:
- As with the existing SAC, cached values are cleared upon first use.
- We error if the user wishes to backward a second time on a region forwarded with SAC enabled.
In-place:
- We use version counting to enforce that if any cached tensor has been mutated. In-place operations not mutating cached tensors are allowed.
- `allow_cache_entry_mutation=True` can be passed to disable this check (useful in the case of auto AC where the user is cleverly also saves the output of the in-place)
Randomness, views
- Currently in this PR, we don't do anything special for randomness or views, the author of the policy function is expected to handle them properly. (Would it would be beneficial to error? - we either want to save all or recompute all random tensors)
Tensor object preservation
- ~We guarantee that if a tensor does not requires grad, and it is saved, then what you get out is the same tensor object.~ UPDATE: We guarantee that if a tensor is of non-differentiable dtype AND it is not a view, and it is saved, then what you get out is the same tensor object. This is a nice guarantee for nested tensors which care about the object identity of of the offsets tensor.
Policy function
- Enum values are `{MUST,PREFER}_{SAVE,RECOMPUTE}` (bikeshed welcome). Alternatively there was `{SAVE,RECOMPUTE}_{NON_,}OVERRIDABLE`. The former was preferred bc it seemed clearer that two `MUST` clashing should error, versus it is ambiguous whether two `NON_OVERRIDABLE` being stacked should silently ignore or error.
- The usage of Enum today. There actually is NO API to stack SAC policies today. The only thing the Enum should matter for in the near term is the compiler. The stacking SAC policy would be useful if someone wants to implement something like simple FSDP, but it is not perfect because with a policy of `PREFER_SAVE` you are actually saving more than autograd would save normally (would be fixed with AC v3).
- The number of times we call the policy_fn is something that should be documented as part of public API. We call the policy function for all ops except ~~detach~~ UPDATE : metadata ops listed in `torch.utils.checkpoint.SAC_IGNORED_OPS`) because these ops may be called a different number of times by AC itself between forward and recompute.
- The policy function can be a stateful object (we do NOT make separate copies of this object for forward/recompute, the user is expected to handle that via is_recompute see below).
Tensors guaranteed to be the same tensor as-is
- Policy function signature takes ctx object as its first argument. The ctx function is an object encapsulating info that may be useful to the user, it currently only holds "is_recompute". Adding this indirection gives us flexibility to add more attrs later if necessary.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125795
Approved by: https://github.com/Chillee, https://github.com/fmassa
Related doc: https://docs.google.com/document/d/1BKyizkZPdri9mHqdDOLAUpkI7SbbKfLHRFVVpK9ZWqo/edit
Memory considerations:
- As with the existing SAC, cached values are cleared upon first use.
- We error if the user wishes to backward a second time on a region forwarded with SAC enabled.
In-place:
- We use version counting to enforce that if any cached tensor has been mutated. In-place operations not mutating cached tensors are allowed.
- `allow_cache_entry_mutation=True` can be passed to disable this check (useful in the case of auto AC where the user is cleverly also saves the output of the in-place)
Randomness, views
- Currently in this PR, we don't do anything special for randomness or views, the author of the policy function is expected to handle them properly. (Would it would be beneficial to error? - we either want to save all or recompute all random tensors)
Tensor object preservation
- We guarantee that if a tensor does not requires grad, and it is saved, then what you get out is the same tensor object. If the tensor does require grad, we must detach to avoid creating a reference cycle. This is a nice guarantee for nested tensors which care about the object identity of of the offsets tensor.
Policy function
- Enum values are `{MUST,PREFER}_{SAVE,RECOMPUTE}` (bikeshed welcome). Alternatively there was `{SAVE,RECOMPUTE}_{NON_,}OVERRIDABLE`. The former was preferred bc it seemed clearer that two `MUST` clashing should error, versus it is ambiguous whether two `NON_OVERRIDABLE` being stacked should silently ignore or error.
- The usage of Enum today. There actually is NO API to stack SAC policies today. The only thing the Enum should matter for in the near term is the compiler. The stacking SAC policy would be useful if someone wants to implement something like simple FSDP, but it is not perfect because with a policy of `PREFER_SAVE` you are actually saving more than autograd would save normally (would be fixed with AC v3).
- The number of times we call the policy_fn is something documented part of public API. We call the policy function for all ops except detach because detach is itself called a different number of times by AC between forward and recompute.
- The policy function can be a stateful object (we do NOT make separate copies of this object for forward/recompute, the user is expected to handle that via is_recompute see below).
Tensors guaranteed to be the same tensor as-is
- Policy function signature takes ctx object as its first argument. The ctx function is an object encapsulating info that may be useful to the user, it currently only holds "is_recompute". Adding this indirection gives us flexibility to add more attrs later if necessary.
"bc-breaking" for existing users of the private API:
- Existing policy functions must now change their return value to use the Enum.
- Existing calls to `_pt2_selective_checkpoint_context_fn_gen` must be renamed to `gen_selective_checkpoint_context_fn`. The way you use the API remains the same. It would've been nice to do something different (not make the user have to use functools.partial?), but this was the easiest to compile (idk if this should actually be a constraint).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125795
Approved by: https://github.com/Chillee, https://github.com/fmassa
The generated bytecode for the first frame is below. Inlined comments about the LOAD_ATTR which causes Dynamo to trigger again on `__getattr__`.
~~~
[__bytecode] MODIFIED BYTECODE fn /data/users/anijain/pytorch2/test/dynamo/test_activation_checkpointing.py line 1129
[__bytecode] 1129 0 COPY_FREE_VARS 1
[__bytecode] 2 RESUME 0
[__bytecode] 4 PUSH_NULL
[__bytecode] 6 LOAD_GLOBAL 10 (__compiled_fn_1)
[__bytecode] 18 LOAD_FAST 0 (x)
[__bytecode] 20 LOAD_DEREF 1 (mod)
[__bytecode] 22 LOAD_ATTR 6 (_checkpoint_wrapped_module)
[__bytecode] 32 LOAD_CONST 1 (0)
[__bytecode] 34 BINARY_SUBSCR
[__bytecode] 44 LOAD_ATTR 7 (weight)
[__bytecode] 54 LOAD_DEREF 1 (mod)
[__bytecode] 56 LOAD_ATTR 6 (_checkpoint_wrapped_module)
[__bytecode] 66 LOAD_CONST 1 (0)
[__bytecode] 68 BINARY_SUBSCR
[__bytecode] 78 LOAD_ATTR 8 (bias)
# When this optimized bytecode is executed, these two lines call the __getattr__ of ActivationWrapper module.
# Dynamo gets invoked on __getattr__.
# If we had inlined __getattr__ during the tracing, we would have seen the LOAD_ATTR
# on more low level data structures like _modules, obviating the need for CPython
# to call python overriden __getattr__. But today, UnspecializedNNModuleVariable
# calls python getattr at tracing time (instead of inlining it), resulting in LOAD_ATTR
# on the module itself.
# To prevent Dynamo to skip tracing of __Getattr__ on the optimized bytecode,
# we can check if its top level frame and just skip it.
[__bytecode] 88 LOAD_DEREF 1 (mod)
[__bytecode] 90 LOAD_ATTR 0 (a)
[__bytecode] 100 PRECALL 4
[__bytecode] 104 CALL 4
[__bytecode] 114 UNPACK_SEQUENCE 1
[__bytecode] 118 RETURN_VALUE
~~~~
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127263
Approved by: https://github.com/yf225
The `usort` config in `pyproject.toml` has no effect due to a typo. Fixing the typo make `usort` do more and generate the changes in the PR. Except `pyproject.toml`, all changes are generated by `lintrunner -a --take UFMT --all-files`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127126
Approved by: https://github.com/kit1980
The `usort` config in `pyproject.toml` has no effect due to a typo. Fixing the typo make `usort` do more and generate the changes in the PR. Except `pyproject.toml`, all changes are generated by `lintrunner -a --take UFMT --all-files`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127126
Approved by: https://github.com/kit1980
ghstack dependencies: #127122, #127123, #127124, #127125
The `usort` config in `pyproject.toml` has no effect due to a typo. Fixing the typo make `usort` do more and generate the changes in the PR. Except `pyproject.toml`, all changes are generated by `lintrunner -a --take UFMT --all-files`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127125
Approved by: https://github.com/Skylion007
ghstack dependencies: #127122, #127123, #127124
Fixes https://github.com/pytorch/pytorch/issues/111020
For the following code:
```python
import torch
import torch._higher_order_ops.wrap
glob = []
def f(x):
glob.append(x)
return x.clone()
@torch.compile(backend='eager', fullgraph=True)
def g(x):
return torch.ops.higher_order.wrap(f, x)
x = torch.randn(3)
g(x)
```
The stacktrace now becomes:
```
[2024-02-01 15:23:34,691] [0/0] torch._dynamo.variables.higher_order_ops: [WARNING] speculate_subgraph: while introspecting wrap, we were unable to trace function `f` into a single graph. This means that Dynamo was unable to prove safety for this API and will fall back to eager-mode PyTorch, which could lead to a slowdown.
[2024-02-01 15:23:34,692] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] HigherOrderOperator: Mutating a variable not in the current scope (SideEffects)
[2024-02-01 15:23:34,692] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] Traceback (most recent call last):
[2024-02-01 15:23:34,692] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 381, in speculate_subgraph
[2024-02-01 15:23:34,692] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] output = f.call_function(tx, args, sub_kwargs)
[2024-02-01 15:23:34,692] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 278, in call_function
[2024-02-01 15:23:34,692] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] return super().call_function(tx, args, kwargs)
[2024-02-01 15:23:34,692] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 86, in call_function
[2024-02-01 15:23:34,692] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] return tx.inline_user_function_return(
[2024-02-01 15:23:34,692] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 657, in inline_user_function_return
[2024-02-01 15:23:34,692] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
[2024-02-01 15:23:34,692] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2261, in inline_call
[2024-02-01 15:23:34,692] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] return cls.inline_call_(parent, func, args, kwargs)
[2024-02-01 15:23:34,692] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2370, in inline_call_
[2024-02-01 15:23:34,692] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] tracer.run()
[2024-02-01 15:23:34,692] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 787, in run
[2024-02-01 15:23:34,692] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] and self.step()
[2024-02-01 15:23:34,692] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 750, in step
[2024-02-01 15:23:34,692] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] getattr(self, inst.opname)(inst)
[2024-02-01 15:23:34,692] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 469, in wrapper
[2024-02-01 15:23:34,692] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] return inner_fn(self, inst)
[2024-02-01 15:23:34,692] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1196, in CALL_FUNCTION
[2024-02-01 15:23:34,692] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] self.call_function(fn, args, {})
[2024-02-01 15:23:34,692] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 651, in call_function
[2024-02-01 15:23:34,692] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] self.push(fn.call_function(self, args, kwargs))
[2024-02-01 15:23:34,692] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/yidi/local/pytorch/torch/_dynamo/variables/misc.py", line 583, in call_function
[2024-02-01 15:23:34,692] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] return self.obj.call_method(tx, self.name, args, kwargs)
[2024-02-01 15:23:34,692] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/yidi/local/pytorch/torch/_dynamo/variables/lists.py", line 330, in call_method
[2024-02-01 15:23:34,692] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] return super().call_method(tx, name, args, kwargs)
[2024-02-01 15:23:34,692] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/yidi/local/pytorch/torch/_dynamo/variables/lists.py", line 241, in call_method
[2024-02-01 15:23:34,692] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] tx.output.side_effects.mutation(self)
[2024-02-01 15:23:34,692] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/yidi/local/pytorch/torch/_dynamo/side_effects.py", line 325, in mutation
[2024-02-01 15:23:34,692] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] self.check_allowed_side_effect(var)
[2024-02-01 15:23:34,692] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/yidi/local/pytorch/torch/_dynamo/side_effects.py", line 157, in check_allowed_side_effect
[2024-02-01 15:23:34,692] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] unimplemented(
[2024-02-01 15:23:34,692] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/yidi/local/pytorch/torch/_dynamo/exc.py", line 190, in unimplemented
[2024-02-01 15:23:34,692] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] raise Unsupported(msg)
[2024-02-01 15:23:34,692] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] torch._dynamo.exc.Unsupported: HigherOrderOperator: Mutating a variable not in the current scope (SideEffects)
Traceback (most recent call last):
File "/home/yidi/local/pytorch/test.py", line 219, in <module>
g(x)
File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 453, in _fn
return fn(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 615, in catch_errors
return callback(frame, cache_entry, hooks, frame_state)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 390, in _convert_frame_assert
return _compile(
File "/home/yidi/local/miniconda3/envs/pytorch-3.10/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 650, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/home/yidi/local/pytorch/torch/_dynamo/utils.py", line 248, in time_wrapper
r = func(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 531, in compile_inner
out_code = transform_code_object(code, transform)
File "/home/yidi/local/pytorch/torch/_dynamo/bytecode_transformation.py", line 1033, in transform_code_object
transformations(instructions, code_options)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 155, in _fn
return fn(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 496, in transform
tracer.run()
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2125, in run
super().run()
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 787, in run
and self.step()
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 750, in step
getattr(self, inst.opname)(inst)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 469, in wrapper
return inner_fn(self, inst)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1196, in CALL_FUNCTION
self.call_function(fn, args, {})
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 651, in call_function
self.push(fn.call_function(self, args, kwargs))
File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 1227, in call_function
p_args, p_kwargs, example_value, body_r, treespec, _ = self.create_wrapped_node(
File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 1190, in create_wrapped_node
) = speculate_subgraph(
File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 453, in speculate_subgraph
raise ex
File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 381, in speculate_subgraph
output = f.call_function(tx, args, sub_kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 278, in call_function
return super().call_function(tx, args, kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 86, in call_function
return tx.inline_user_function_return(
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 657, in inline_user_function_return
return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2261, in inline_call
return cls.inline_call_(parent, func, args, kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2370, in inline_call_
tracer.run()
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 787, in run
and self.step()
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 750, in step
getattr(self, inst.opname)(inst)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 469, in wrapper
return inner_fn(self, inst)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1196, in CALL_FUNCTION
self.call_function(fn, args, {})
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 651, in call_function
self.push(fn.call_function(self, args, kwargs))
File "/home/yidi/local/pytorch/torch/_dynamo/variables/misc.py", line 583, in call_function
return self.obj.call_method(tx, self.name, args, kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/variables/lists.py", line 330, in call_method
return super().call_method(tx, name, args, kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/variables/lists.py", line 241, in call_method
tx.output.side_effects.mutation(self)
File "/home/yidi/local/pytorch/torch/_dynamo/side_effects.py", line 325, in mutation
self.check_allowed_side_effect(var)
File "/home/yidi/local/pytorch/torch/_dynamo/side_effects.py", line 157, in check_allowed_side_effect
unimplemented(
File "/home/yidi/local/pytorch/torch/_dynamo/exc.py", line 190, in unimplemented
raise Unsupported(msg)
torch._dynamo.exc.Unsupported: HigherOrderOperator: Mutating a variable not in the current scope (SideEffects)
from user code:
File "/home/yidi/local/pytorch/test.py", line 216, in g
return torch.ops.higher_order.wrap(f, x)
File "/home/yidi/local/pytorch/test.py", line 211, in f
glob.append(x)
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118826
Approved by: https://github.com/yanboliang, https://github.com/zou3519
Don't require using it as `@requires_cuda()` -> `@requires_cuda` instead No need for the partial function invoked many times
Split out this change from the initial large refactoring in #117741 to hopefully get merged before conflicts arise
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118281
Approved by: https://github.com/ezyang
# Context
In some cases, we might want to build the `context_fn` with runtime-defined policies. One way of implementing this is to make `context_fn` be a partial, which holds the information that we want to pass. One concrete example is the [automatic policy selection from `xformers`](ad986981b1/xformers/checkpoint.py (L185)).
# The problem
The previous implementation wouldn't work with partials because `FunctoolsPartialVariable` doesn't have a `fn` attribute.
This PR addresses this case, but ideally we could get this solved in a more general fashion, as callable classes and `NestedUserFunctionVariable` are not supported by this PR.
# Tests
I've added a basic test that mimics the tests around it. The tests could probably be simplified, but I've decided to keep changes to a minimum.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117657
Approved by: https://github.com/yf225
as titled, when using SAC + torch.compile, it currently only check for
functional tensor, but not checking any tensor subclasses, therefore SAC
under torch.compile would ignore the tensor types like tensor
subclasses. Fixed in this PR
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115960
Approved by: https://github.com/bdhirsh
Fixes https://github.com/pytorch/pytorch/issues/113717.
When `preserve_rng_state=True`, we let AOTAutograd trace through `torch.random.fork_rng` op, and the tracing doesn't work under CUDA, hence the original error reported in the issue.
But since we are already doing RNG functionalization at Inductor level, we don't actually need to trace this `fork_rng` op. So we should just rewrite `preserve_rng_state` to False when we are using torch.compile (and let Inductor do its RNG functionalization which it's already been doing).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113718
Approved by: https://github.com/wanchaol
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105489
NOTE: this PR is tagged "not user facing", because it's not ready to be announced externally yet.
This PR implements torch.compile + selective activation checkpoint (SAC) integration, by using `TagActivationCheckpoint` (same backend as torch.compile + full activation checkpoint integration).
TorchDispatchMode based implementation cannot support including inplace ops in the checkpointed region at the moment (the reason for this needs investigation), and there is also no way to ban them (because TorchDispatchMode now only sees "after-functionalization" ops, so can't detect if an op is in-place). Hence we hide torch.compile + SAC behind a flag (`torch._dynamo.config._experimental_support_context_fn_in_torch_utils_checkpoint`) and will only use it internally for cases that are known to not have in-place ops. This state won't last too long, because in-place op will at least be able to be detected after Brian's mode reordering and related functionalization changes.
So next steps after this PR:
1. Wait for Brian's mode reordering and related functionalization changes to land, and then try to enable the "inplace ops" unit test for torch.compile + selective activation checkpoint (if it doesn't work, investigate why).
2. Unify selective- and full-checkpoint under TorchDispatchMode based implementation.
Differential Revision: D47497145
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105489
Approved by: https://github.com/anijain2305