Fixes https://github.com/pytorch/pytorch/issues/133858
Details: Previously Dynamo would treat dataclasses as UserDefinedVariables. This was non-desirable if we would like to proxy the value into the graph, which is needed for TensorSubclassMetadata. To rectify this, frozen dataclasses are now able to be proxied similarly to NamedTuples. We require the object to be frozen, because if arbitrary mutation were allowed, we would need to replay those mutations in the graph after construction of the object.
For tracing construction of the variable, the generated `__init__` for the dataclass uses `object.__setattr__` because frozen dataclasses throw errors on the usual `__setattr__` invocation. With this treatment, no special handling is needed in dynamo for frozen dataclass construction.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134846
Approved by: https://github.com/bdhirsh, https://github.com/anijain2305
This UT actual code only one empty line wrap difference(`linear` and `add`) between Windows/Linux, and the context is right.
Reproduce UTs:
```cmd
pytest test\dynamo\test_higher_order_ops.py -v -k test_functional_call_sequential_params_and_buffers
```
We can add `empty_line_normalizer` to fix it.
```cmd
______________________________________________________________________________________________ FuncTorchHigherOrderOpTests.test_functional_call_sequential_params_and_buffers _______________________________________________________________________________________________
Traceback (most recent call last):
File "D:\xu_git\dnnl_cb\pytorch\test\dynamo\test_higher_order_ops.py", line 3676, in test_functional_call_sequential_params_and_buffers
self.assertExpectedInline(
File "C:\Users\Xuhan\.conda\envs\win_mkl_static\lib\site-packages\torch\testing\_internal\common_utils.py", line 2871, in assertExpectedInline
return super().assertExpectedInline(actual if isinstance(actual, str) else str(actual), expect, skip + 1)
File "C:\Users\Xuhan\.conda\envs\win_mkl_static\lib\site-packages\expecttest\__init__.py", line 271, in assertExpectedInline
self.assertMultiLineEqualMaybeCppStack(expect, actual, msg=help_text)
File "C:\Users\Xuhan\.conda\envs\win_mkl_static\lib\site-packages\expecttest\__init__.py", line 292, in assertMultiLineEqualMaybeCppStack
self.assertMultiLineEqual(expect, actual, *args, **kwargs)
File "C:\Users\Xuhan\.conda\envs\win_mkl_static\lib\unittest\case.py", line 1226, in assertMultiLineEqual
self.fail(self._formatMessage(msg, standardMsg))
File "C:\Users\Xuhan\.conda\envs\win_mkl_static\lib\unittest\case.py", line 675, in fail
raise self.failureException(msg)
AssertionError: 'clas[509 chars]one\n add: "f32[1, 1]" = linear + l_buf[69 chars],)\n' != 'clas[509 chars]one\n\n add: "f32[1, 1]" = linear + l_b[71 chars],)\n'
class GraphModule(torch.nn.Module):
def forward(self, L_params_l1_weight_: "f32[1, 1]", L_params_l1_bias_: "f32[1]", L_buffers_buffer_: "f32[1]", L_inputs_: "f32[1, 1]"):
l_params_l1_weight_ = L_params_l1_weight_
l_params_l1_bias_ = L_params_l1_bias_
l_buffers_buffer_ = L_buffers_buffer_
l_inputs_ = L_inputs_
linear: "f32[1, 1]" = torch._C._nn.linear(l_inputs_, l_params_l1_weight_, l_params_l1_bias_); l_inputs_ = l_params_l1_weight_ = l_params_l1_bias_ = None
+ <<<< (difference is here )
add: "f32[1, 1]" = linear + l_buffers_buffer_; linear = l_buffers_buffer_ = None
return (add,)
: To accept the new output, re-run test with envvar EXPECTTEST_ACCEPT=1 (we recommend staging/committing your changes before doing this)
To execute this test, run the following from the base repo dir:
python test\dynamo\test_higher_order_ops.py FuncTorchHigherOrderOpTests.test_functional_call_sequential_params_and_buffers
This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0
========================================================================================================================== short test summary info ==========================================================================================================================
FAILED [0.4275s] test/dynamo/test_higher_order_ops.py::FuncTorchHigherOrderOpTests::test_functional_call_sequential_params_and_buffers - AssertionError: 'clas[509 chars]one\n add: "f32[1, 1]" = linear + l_buf[69 chars],)\n' != 'clas[509 chars]one\n\n add: "f32[1, 1]" = linear + l_b[71 chars],)\n'
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134394
Approved by: https://github.com/jansel
Co-authored-by: Jason Ansel <jansel@jansel.net>
----
- We now record on CacheEntry what the compile id that populated it was, so now we can say why a specific frame was rejected
- Add structured log for recompiles under name artifact "recompile_reasons". As it stands, it's not terribly structured, but this was the easiest thing I could do to start
- Slightly reformat multi-reason printing; since we only report one guard failure seems better to have it as a single line
Example output:
```
V0703 10:34:13.273000 140345997743104 torch/_dynamo/guards.py:2590] [0/1] [__recompiles] Recompiling function f in /data/users/ezyang/a/pytorch/b.py:3
V0703 10:34:13.273000 140345997743104 torch/_dynamo/guards.py:2590] [0/1] [__recompiles] triggered by the following guard failure(s):
V0703 10:34:13.273000 140345997743104 torch/_dynamo/guards.py:2590] [0/1] [__recompiles] - 0/0: tensor 'L['x']' size mismatch at index 0. expected 4, actual 5
```
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130043
Approved by: https://github.com/anijain2305
This will be helpful in reducing some of the hardcoded and python-version-dependent bytecode generation in various places in dynamo - e.g. resume function generation and object reconstruction.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127359
Approved by: https://github.com/jansel
ghstack dependencies: #127329
A common complaint when working with data-dependent code in PyTorch is that it's hard to tell how far you are from the finish line: every time a GuardOnDataDependentSymNode error is hit, you have to somehow fix or workaround it to see the next one.
This PR adds a new mode `torch._functorch.config.fake_tensor_propagate_real_tensors` which modifies fake tensors to also propagate real tensors. This means that when we try to guard on a data-dependent SymNode, we can actually produce a real result. We also produce a warning which you should consult to figure out what the crux points are.
I ran this on vision_maskrcnn. In the baseline (without this mode), the model has 27 graph breaks, resulting in 40 graphs. With this mode on, the model has only 11 graph breaks, resulting in 15 graphs (the remaining graph breaks are due to missing functionality for item() on float tensor and some other Dynamo missing features.) You get a list of things that would have errored like this:
```
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u1) < 2) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u1), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u1), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Ne(Max(1, u1), 1)) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u0) < 2) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u0), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u0), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Ne(Max(1, u0), 1)) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u1) < 2) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u1), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u1), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Ne(Max(1, u1), 1)) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u0) < 2) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u0), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u0), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Ne(Max(1, u0), 1)) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u1) < 2) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u1), 1)) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Ne(Max(1, u1), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u0) < 2) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u0), 1)) -> False
```
Potential later follow ups:
* Improve the warning messages (in particular, should provide user frames)
* GC real tensors when they are no longer needed by tracing. Right now, this will use A LOT of memory, equal to as if your GC was broken and every intermediate tensor was kept live
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125115
Approved by: https://github.com/IvanKobzarev
aten.div's output device will be its numerator's device. so it is acceptable to do cuda / cpu type divisions. post grad passes operate only on graphs and can't handle runtime graph inputs. so we change user code to move inputs to cuda for cudagraph. this affects any graph that has cpu tensors as graph inputs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119729
Approved by: https://github.com/eellison
aten.div's output device will be its numerator's device. so it is acceptable to do cuda / cpu type divisions. post grad passes operate only on graphs and can't handle runtime graph inputs. so we change user code to move inputs to cuda for cudagraph. this affects any graph that has cpu tensors as graph inputs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119729
Approved by: https://github.com/eellison
Fixes https://github.com/pytorch/pytorch/issues/119607 for 3.11+.
In 3.11+, `_PyFrame_FastToLocalsWithError` could implicity run `COPY_FREE_VARS` on the original frame, leading to double incref's since the dynamo shadow frame can rerun `COPY_FREE_VARS`. So the solution is to skip the first `COPY_FREE_VARS` instruction in the shadow frame if it was already executed in the original frame.
Also move the location for clearing the original frame in 3.12 to handle error cases more thoroughly.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124238
Approved by: https://github.com/jansel
Fixes#114844
In the linked issue we have
```
compiled_module = torch.compile(module)
compiled_module.x = ...
compiled_module(...) # Mutates self.x
```
Where since the module mutates `self.x` you would expect `compiled_module.x`
to be updated but actually `compiled_module.x = ...` sets an attribute "x"
on the `OptimizedModule` object while the forward method of the module mutates
`module.x`.
This gives the expected behavior by forwarding `compiled_module.__setattr__`
down to `module.__setattr__`. There is already a corresponding `__getattr__`
so now `compiled_module.x` becomes an alias for `module.x`.
Co-authored-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122098
Approved by: https://github.com/ezyang, https://github.com/lezcano
Python 3.12 changed a few things with how `_PyInterpreterFrame`s are allocated and freed:
- Frames are now required to be placed on the Python frame stack. In 3.11, we could allocate frames anywhere in memory. In 3.12, we now need to use `THP_PyThreadState_BumpFramePointerSlow`/`push_chunk`/`allocate_chunk`. This method of allocating/freeing frames is also compatible with 3.11.
- The eval frame function is now responsible for clearing the frame (see https://docs.python.org/3/whatsnew/changelog.html#id128, the point about "...which now clear the frame.")
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122146
Approved by: https://github.com/jansel
Fixes#114844
In the linked issue we have
```
compiled_module = torch.compile(module)
compiled_module.x = ...
compiled_module(...) # Mutates self.x
```
Where since the module mutates `self.x` you would expect `compiled_module.x`
to be updated but actually `compiled_module.x = ...` sets an attribute "x"
on the `OptimizedModule` object while the forward method of the module mutates
`module.x`.
This gives the expected behavior by forwarding `compiled_module.__setattr__`
down to `module.__setattr__`. There is already a corresponding `__getattr__`
so now `compiled_module.x` becomes an alias for `module.x`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122098
Approved by: https://github.com/ezyang, https://github.com/lezcano
Fixes#114844
In the linked issue we have
```
compiled_module = torch.compile(module)
compiled_module.x = ...
compiled_module(...) # Mutates self.x
```
Where since the module mutates `self.x` you would expect `compiled_module.x`
to be updated but actually `compiled_module.x = ...` sets an attribute "x"
on the `OptimizedModule` object while the forward method of the module mutates
`module.x`.
This gives the expected behavior by forwarding `compiled_module.__setattr__`
down to `module.__setattr__`. There is already a corresponding `__getattr__`
so now `compiled_module.x` becomes an alias for `module.x`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122098
Approved by: https://github.com/ezyang, https://github.com/lezcano
Fixes https://github.com/pytorch/pytorch/issues/119238
Here's what it looks like now:
```
$ TORCH_LOGS=+torch._dynamo.convert_frame python a.py
[2024-02-05 18:52:07,248] [0/0] torch._dynamo.convert_frame: [DEBUG] torchdynamo start compiling f /data/users/ezyang/b/pytorch/a.py:3, stack (elided 5 frames):
[2024-02-05 18:52:07,248] [0/0] torch._dynamo.convert_frame: [DEBUG] File "/data/users/ezyang/b/pytorch/a.py", line 7, in <module>
[2024-02-05 18:52:07,248] [0/0] torch._dynamo.convert_frame: [DEBUG] f(torch.randn(2))
[2024-02-05 18:52:07,248] [0/0] torch._dynamo.convert_frame: [DEBUG] File "/data/users/ezyang/b/pytorch/torch/_dynamo/eval_frame.py", line 453, in _fn
[2024-02-05 18:52:07,248] [0/0] torch._dynamo.convert_frame: [DEBUG] return fn(*args, **kwargs)
[2024-02-05 18:52:07,248] [0/0] torch._dynamo.convert_frame: [DEBUG]
$ cat a.py
import torch
@torch.compile
def f(x):
return x * 2
f(torch.randn(2))
```
The eval_frame frame is intentionally present, since what happens is you run the torch.compile wrapper, and then you actually hit the user frame to be compiled.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119251
Approved by: https://github.com/yanboliang, https://github.com/mlazos
The `initial` argument in `functools.reduce` can be `None`.
```python
initial_missing = object()
def reduce(function, iterable, initial=initial_missing, /):
it = iter(iterable)
if initial is initial_missing:
value = next(it)
else:
value = initial
for element in it:
value = function(value, element)
return value
```
Reference:
- python/cpython#102759
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116398
Approved by: https://github.com/Skylion007
Fix https://github.com/pytorch/pytorch/issues/109736 .
HF pin move causes regression on accuracy check for HF models on the dashboard. Manually reverting the HF PR ( https://github.com/huggingface/transformers/pull/24696/files ) could recover, but this may hide some real issue. I happen to found that using a warm matmul max-autotune cache can work around the issue. Or putting it in another way:
- make all calls to check_cache cache miss repro the issue
- make all cals to check_cache cache hit works around the issue
I did some sort of 'bisect' to force halving the amount of cache miss each time while still make sure we can repro. Luckily reducing to a single cache miss still repro the issue. With more debugging, it turns out that it's the call to `torch.randn` on cuda device causing the problem.
The fix is to make sure we restore the rng state when we generate random inputs for max-autotune benchmarking.
TBH, I can not fully explain the root cause although I know it's caused by rng state change. AOTAutograd already has some logic to preserve rng state. And I can not repro the issue in unit tests. I have a few guess why the RNG state is not restored in the first place after we generate random inputs for max-autotune:
- maybe AOTAutograd misses some corner case to preserve the rng state
- maybe for the failed models, there are some eager fallback that's not handled by inductor. And if those fallback calles random number related APIs, we will see the issue. But again I don't find a good way to simulate this.
Repro:
```
TORCHINDUCTOR_BENCHMARK_KERNEL=1 TORCHINDUCTOR_MAX_AUTOTUNE_GEMM=1 CUDA_VISIBLE_DEVICES=3 time python benchmarks/dynamo/huggingface.py --backend inductor --amp --accuracy --only PLBartForCausalLM --training --cold-start-latency
```
We always repro the issue without the PR but pass the accuracy check with the PR.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109828
Approved by: https://github.com/eellison
RFC: https://github.com/pytorch/rfcs/pull/54
First commit is the contents of https://github.com/Quansight-Labs/numpy_pytorch_interop/
We have already been using this in core for the last few months as a external dependency. This PR pulls all these into core.
In the next commits, I do a number of things in this order
- Fix a few small issues
- Make the tests that this PR adds pass
- Bend backwards until lintrunner passes
- Remove the optional dependency on `torch_np` and simply rely on the upstreamed code
- Fix a number dynamo tests that were passing before (they were not tasting anything I think) and are not passing now.
Missing from this PR (but not blocking):
- Have a flag that deactivates tracing NumPy functions and simply breaks. There used to be one but after the merge stopped working and I removed it. @lezcano to investigate.
- https://github.com/pytorch/pytorch/pull/106431#issuecomment-1667079543. @voznesenskym to submit a fix after we merge.
All the tests in `tests/torch_np` take about 75s to run.
This was a work by @ev-br, @rgommers @honno and I. I did not create this PR via ghstack (which would have been convenient) as this is a collaboration, and ghstack doesn't allow for shared contributions.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106211
Approved by: https://github.com/ezyang
`TensorMeta.from_irnodes` handles either a single `IRNode` or a tuple or list of them. I tried to express this with overloading, but because this file is in MYPYNOFOLLOW, the `IRNode` subclasses become `Any`, which causes the overloads to be overlapping.
This changes the type of the argument to `benchmark_in_sub_process` to the more specific `TritonTemplateCaller`, since that one has the `bmreq` member and existing docstrings indicate that only the triton template benchmark is handled.
The `rand_strided` call caused a mypy error because the default value for device was a string. This is fixed by adding type hints to `rand_strided` in `torch/_dynamo/testing.py`. Likewise, the return value of `PyCodeCache.load_by_key_path` can be inferred from the type hint on `PyCodeCache.cache`.
Fixes one part of #105230
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105791
Approved by: https://github.com/jansel, https://github.com/Skylion007
Some notes:
* I now manually turn off `_generate` jobs from running with cudagraphs, as it is unrealistic to expect to cudagraph autoregressive generation up to max sequence length, this would imply compiling the entire unrolled sequence generation. Concretely, cm3leon_generate was timing out post this change, likely due to the compile time slowdown of dynamic shapes ON TOP OF accidentally unrolling all the loops
* A few torch._dynamo.reset tactically inserted to force recompiles on tests that expected it
* expectedFailureAutomaticDynamic flip into patching automatic_dynamic_shapes=False
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103623
Approved by: https://github.com/voznesenskym
Added two signpost_event calls to torch.fx.experimental.symbolic_shapes, one for produce_guards (where we can give stats like how many free symbols and how many guards produced) and the other is for evaluate_expr after freeze (so we can look for cases where we're improperly discarding guards in backwards.)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103882
Approved by: https://github.com/Skylion007