Summary:
This is a reimplemented version of the FB specific code in https://www.internalfb.com/diff/D54230697
The new strategy is that we unconditionally install an FB handler to trace_log logger (and always set level to DEBUG). When the first log message is emitted, we check the JK/filesystem to see if we should actually do logging. If we decide we don't do logging, we remove the handler from trace_log and are done.
build_only[github-export-checks,executorch,pytorch_benchmark,pytorch_quantization,pytorch_distributed,pytorch_distributed_gpu,pytorch_dynamo_inductor,pytorch_functorch,pytorch_fx2trt,pytorch_diff_train_tests_ads,glow_fb_pytorch_tests,training_platform,training_platform_compatibility,training_toolkit_applications,training_toolkit_examples,training_toolkit_model_optimization,dper3_pytorch,xplat_caffe2,pytorch_dev,android-pytorch-instrumentation-tests,smartpytorchgithub_first_try_merge,frl-target-determinator,f6-buck,training_platform_for_github,sigmoid_cpu,sigmoid_gpu,aiplatform_modelprocessing_for_github,accelerators_workloads_models_slimdsnn,ae_aotinductor_benchmark_test,aps_,aps_deterministic_ne_tests,dper_lib_silvertorch,torchrec,torchrec_fb,deeplearning_aot_inductor]
Test Plan:
sandcastle
```
buck2 test 'fbcode//mode/dev-nosan' fbcode//torchrec/inference/tests:test_single_gpu_executor -- --exact 'torchrec/inference/tests:test_single_gpu_executor - TorchDeployGPUTest.NestedModelSingleGPU'
buck2 test 'fbcode//mode/dev-nosan' fbcode//dper_lib/silvertorch/modules/dynamic_stats/tests:accumulators_test -- --exact 'dper_lib/silvertorch/modules/dynamic_stats/tests:accumulators_test - test_global_fixed_interval_accumulator (dper_lib.silvertorch.modules.dynamic_stats.tests.accumulators_test.GlobalFixedIntervalUnivalentAcculumatorTest)'
```
Also running a test flow with/without JK enabled
Differential Revision: D54275086
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120915
Approved by: https://github.com/yanboliang
Shorthand for `"%(levelname)s:%(name)s:%(message)s"` which is hard to
remember.
I find the default formatter annoying since just the metadata fills up
most of the width of my terminal.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120757
Approved by: https://github.com/ezyang
Overall design: https://docs.google.com/document/d/1CX_hJ0PNy9f3R1y8TJrfkSeLkvGjjjLU84BSXgS2AZ8/edit
How to read the diff:
* Most files are me augmenting pre-existing logging with structured variants. For the most part it's simple (esp FX graphs, which have a canonical string representation); it gets more complicated when I decided to JSON-ify some data structure instead of keeping the ad hoc printing (notably, guards and dynamo output graph sizes)
* torch/_functorch/_aot_autograd/collect_metadata_analysis.py is some unrelated fixes I noticed while auditing artifact logs
* torch/_logging/_internal.py has the actual trace log implementation. The trace logger is implement as a logger named torch.__trace which is disconnected from the logging hierarchy. It gets its own handler and formatter (TorchLogsFormatter with _is_trace True). `trace_structured` is the main way to emit a trace log. Unusually, there's a separate "metadata" and "payload" field. The metadata field should not be too long (as it is serialized as a single line) and is always JSON (we put contextual things like compile id in it); the payload field can be long and is emitted after the metadata log line and can span multiple lines.
* torch/_logging/structured.py contains some helpers for converting Python data structures into JSON form. Notably, we have a string interning implementation here, which helps reduce the cost of serializing filenames into the log.
* test/dynamo/test_structured_trace.py the tests are cribbed from test_logging.py, but all rewritten to use expect tests on munged versions of what we'd actually output. Payloads are never tested, since they tend not be very stable.
https://github.com/ezyang/tlparse is a POC Rust program that can interpret these logs.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120289
Approved by: https://github.com/Skylion007
ghstack dependencies: #120712
Yuzhen Huang was complaining to me that searching for `__recompile`
no longer works. This is because the glog format is filename, not
logger name, so we lost the artifact name. Add it back.
Looks like:
```
V0226 15:56:04.142000 139828992779264 torch/_dynamo/guards.py:1084] [0/2] __guards: ___check_type_id(L['inputs'], 7626144)
```
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120671
Approved by: https://github.com/Skylion007
Overall design: https://docs.google.com/document/d/1CX_hJ0PNy9f3R1y8TJrfkSeLkvGjjjLU84BSXgS2AZ8/edit
How to read the diff:
* Most files are me augmenting pre-existing logging with structured variants. For the most part it's simple (esp FX graphs, which have a canonical string representation); it gets more complicated when I decided to JSON-ify some data structure instead of keeping the ad hoc printing (notably, guards and dynamo output graph sizes)
* torch/_functorch/_aot_autograd/collect_metadata_analysis.py is some unrelated fixes I noticed while auditing artifact logs
* torch/_logging/_internal.py has the actual trace log implementation. The trace logger is implement as a logger named torch.__trace which is disconnected from the logging hierarchy. It gets its own handler and formatter (TorchLogsFormatter with _is_trace True). There's a teensy bit of FB specific code to automatically enable trace logging if a /logs directory exists. `trace_structured` is the main way to emit a trace log. Unusually, there's a separate "metadata" and "payload" field. The metadata field should not be too long (as it is serialized as a single line) and is always JSON (we put contextual things like compile id in it); the payload field can be long and is emitted after the metadata log line and can span multiple lines.
* torch/_logging/structured.py contains some helpers for converting Python data structures into JSON form. Notably, we have a string interning implementation here, which helps reduce the cost of serializing filenames into the log.
* test/dynamo/test_structured_trace.py the tests are cribbed from test_logging.py, but all rewritten to use expect tests on munged versions of what we'd actually output. Payloads are never tested, since they tend not be very stable.
https://github.com/ezyang/tlparse is a POC Rust program that can interpret these logs.
Testing that the fbcode detection works at https://www.internalfb.com/mlhub/pipelines/runs/fblearner/534553450 (Meta-only)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120289
Approved by: https://github.com/Skylion007
Before:
```
[2024-02-13 19:34:50,591] [0/0] torch._dynamo.guards.__guards: [DEBUG] GUARDS:
[2024-02-13 19:34:50,591] [0/0] torch._dynamo.guards.__guards: [DEBUG] ___check_type_id(L['x'], 70049616) # assert x.shape[0] > 2 # b.py:5 in f
[2024-02-13 19:34:50,592] [0/0] torch._dynamo.guards.__guards: [DEBUG] hasattr(L['x'], '_dynamo_dynamic_indices') == False # assert x.shape[0] > 2 # b.py:5 in f
```
After this change, the logs look like this:
```
V0214 07:00:49.354000 139646045393920 torch/_dynamo/guards.py:1023 [0/0] GUARDS:
V0214 07:00:49.354000 139646045393920 torch/_dynamo/guards.py:1039 [0/0] ___check_type_id(L['x'], 70050096) # assert x.shape[0] > 2 # b.py:5 in f
V0214 07:00:49.355000 139646045393920 torch/_dynamo/guards.py:1039 [0/0] hasattr(L['x'], '_dynamo_dynamic_indices') == False # assert x.shape[0] > 2 # b.py:5 in f
```
The main differences from what we had before:
* We don't print DEBUG/INFO/WARNING, instead, we only print a single character. DEBUG, somewhat oddly, maps to V, because it corresponds to glog VERBOSE
* The year is omitted, and a more compact representation for date/month is adopted. Somewhat perplexingly, six digits are allocated for the nanoseconds, even though Python typically doesn't have that level of resolution
* The thread ID is included (in a containerized environment, this thread id will be typically much lower)
* Instead of using the module name, we give a filepath, as well as the line the log message was emitted from. I think the line number is a nice touch and improvement over our old logs, but one downside is we do lose the artifact name in the log message, in case anyone was grepping for that.
* I chose to move the compile id prefix to the very end so as to keep a uniform layout before it, but I do think there are benefits to having it before the filename
Meta only: This format was reverse engineered off of 6b8bbe3b53/supervisor/logging.py and https://www.internalfb.com/code/fbsource/[e6728305a48540110f2bdba198aa74eee47290f9]/fbcode/tupperware/front_end/log_reader/filter/StreamingLogLineFilter.cpp?lines=105-114
Now, I think this may be slightly controversial, but I have chosen to apply this format *by default* in OSS. My reasoning is that many PT2 developers work with the logs in OSS, and keeping the format identical to what we run in prod will make it easier for these skills to transfer.
The non-negotiable portion of the new format is "V0213 19:28:32"; the date string is expected to be in exactly this form or Tupperware will fail to parse it as a date.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119869
Approved by: https://github.com/oulgen, https://github.com/mlazos, https://github.com/Skylion007
Partially fixes https://github.com/pytorch/pytorch/issues/105077
Repro:
```python
import tempfile
import torch
from torch._subclasses import fake_tensor
class TheModelClass(torch.nn.Module):
def __init__(self):
super(TheModelClass, self).__init__()
self.fc1 = torch.nn.Linear(5, 10)
def forward(self, x):
return self.fc1(x)
with tempfile.NamedTemporaryFile() as state_dict_file:
# Create state_dict to be loaded later
model = TheModelClass()
torch.save(model.state_dict(), state_dict_file.name)
fake_mode = fake_tensor.FakeTensorMode()
with fake_mode:
# This is where the bug is triggered
state_dict = torch.load(state_dict_file.name)
```
Error:
```bash
Traceback (most recent call last):
File "issue_gh_torch_105077.py", line 22, in <module>
state_dict = torch.load(state_dict_file.name)
File "/opt/pytorch/torch/serialization.py", line 1014, in load
return _load(opened_zipfile,
File "/opt/pytorch/torch/serialization.py", line 1422, in _load
result = unpickler.load()
File "/opt/pytorch/torch/_utils.py", line 205, in _rebuild_tensor_v2
tensor = _rebuild_tensor(storage, storage_offset, size, stride)
File "/opt/pytorch/torch/_utils.py", line 184, in _rebuild_tensor
return t.set_(storage._untyped_storage, storage_offset, size, stride)
File "/opt/pytorch/torch/utils/_stats.py", line 20, in wrapper
return fn(*args, **kwargs)
File "/opt/pytorch/torch/_subclasses/fake_tensor.py", line 1288, in __torch_dispatch__
return self.dispatch(func, types, args, kwargs)
File "/opt/pytorch/torch/_subclasses/fake_tensor.py", line 1468, in dispatch
self.invalidate_written_to_constants(func, flat_arg_fake_tensors, args, kwargs)
File "/opt/pytorch/torch/_subclasses/fake_tensor.py", line 1733, in invalidate_written_to_constants
_, new_kwargs = normalize_function(
File "/opt/pytorch/torch/fx/operator_schemas.py", line 297, in normalize_function
torch_op_schemas = get_signature_for_torch_op(target)
File "/opt/pytorch/torch/fx/operator_schemas.py", line 167, in get_signature_for_torch_op
signatures = [_torchscript_schema_to_signature(schema) for schema in schemas]
File "/opt/pytorch/torch/fx/operator_schemas.py", line 167, in <listcomp>
signatures = [_torchscript_schema_to_signature(schema) for schema in schemas]
File "/opt/pytorch/torch/fx/operator_schemas.py", line 70, in _torchscript_schema_to_signature
arg_type = _torchscript_type_to_python_type(arg.type)
File "/opt/pytorch/torch/fx/operator_schemas.py", line 64, in _torchscript_type_to_python_type
return eval(ts_type.annotation_str, _type_eval_globals)
File "<string>", line 1, in <module>
NameError: name 'Storage' is not defined
```
This PR adds the ability to create fake tensors during `torch.load` by wrapping the `torch.tensor.set_` call around a `torch.utils._mode_utils.no_dispatch()` to skip fake mode dispatcher for it and thus create a real tensor. It later calls `fake_mode.from_tensor(t)` to finally create the fake tensor.
Co-authored-by: Edward Z. Yang <ezyang@mit.edu>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108186
Approved by: https://github.com/ezyang
This PR is the start to enable the integrate pytorch distributed logs in Torch LOGs. We now already have one tag "distributed" for all distributed components but distributed is a very large component and we want to have some hierarchy and give users options to only turn on logs for certain submodules. So we also added tags starting with "dist_*" for each submodule. (This PR only adds some of them and we are going to add more down the road)
Related discussions can be found here: https://github.com/pytorch/pytorch/issues/113544
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116434
Approved by: https://github.com/awgu, https://github.com/wanchaol
There are now 3 ways to see logs from ddpoptimzer.
1) TORCH_LOGS="distributed"
2) TORCH_LOGS="dynamo"
3) TORCH_LOGS="torch._dynamo.backends.distributed"
(1 and 2 are different supersets of 3 that also include other content)
Note: ddp_graphs is still a separate 'artifact' logger, which just
includes graph dumps from the graph-splitting process.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114376
Approved by: https://github.com/wanchaol
Followup to https://github.com/pytorch/pytorch/pull/110325 - re-add the `report_all_guard_failures config` as a logging artifact `recompiles_verbose` with the following changes:
- evaluating the check must be wrapped with exception handling because subsequent code parts following the first failure may result in errors if evaluated (e.g. if a guard checks first for size, then tries to index - a guard failure due to insufficient size would result in an index error for the latter check).
- Adding a test for this case
Sample:
```python
import torch
def fn(x):
return torch.rand(x[-1], len(x))
opt_fn = torch.compile(fn)
opt_fn([4, 5, 6])
opt_fn([7, 8])
opt_fn([9])
```
Output (with `TORCH_LOGS="recompiles_verbose"`):
```bash
[2023-11-15 16:13:26,741] torch._dynamo.guards.__recompiles_verbose: [DEBUG] Recompiling function fn in /data/users/williamwen/pytorch/playground5.py:15
[2023-11-15 16:13:26,741] torch._dynamo.guards.__recompiles_verbose: [DEBUG] triggered by the following guard failure(s):
[2023-11-15 16:13:26,741] torch._dynamo.guards.__recompiles_verbose: [DEBUG] guard 0 failures:
[2023-11-15 16:13:26,741] torch._dynamo.guards.__recompiles_verbose: [DEBUG] - len(L['x']) == 3
[2023-11-15 16:13:26,741] torch._dynamo.guards.__recompiles_verbose: [DEBUG] - L['x'][0] == 4
[2023-11-15 16:13:26,741] torch._dynamo.guards.__recompiles_verbose: [DEBUG] - L['x'][1] == 5
[2023-11-15 16:13:26,970] torch._dynamo.guards.__recompiles_verbose: [DEBUG] Recompiling function fn in /data/users/williamwen/pytorch/playground5.py:15
[2023-11-15 16:13:26,970] torch._dynamo.guards.__recompiles_verbose: [DEBUG] triggered by the following guard failure(s):
[2023-11-15 16:13:26,970] torch._dynamo.guards.__recompiles_verbose: [DEBUG] guard 0 failures:
[2023-11-15 16:13:26,970] torch._dynamo.guards.__recompiles_verbose: [DEBUG] - len(L['x']) == 2
[2023-11-15 16:13:26,970] torch._dynamo.guards.__recompiles_verbose: [DEBUG]
[2023-11-15 16:13:26,970] torch._dynamo.guards.__recompiles_verbose: [DEBUG] guard 1 failures:
[2023-11-15 16:13:26,970] torch._dynamo.guards.__recompiles_verbose: [DEBUG] - len(L['x']) == 3
[2023-11-15 16:13:26,970] torch._dynamo.guards.__recompiles_verbose: [DEBUG] - L['x'][0] == 4
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113585
Approved by: https://github.com/jon-chuang, https://github.com/ezyang
When SymNode was refactored into its own module, this broke logging for this file, as the `dynamic` alias no longer covered it. This PR adds supports for an alias to point to multiple qualified module names. To drive the refactor, I renamed `log_alias_to_log_qname` to `log_alias_to_log_qnames` and then audited all use sites. I invite you to do so as well.
For good measure, I also add dynamic to dynamo, so that I always get dynamic logs when dynamo is enabled. Empirically this will be helpful because people keep sending me dynamo debug logs that don't have dynamic logs.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113567
Approved by: https://github.com/Skylion007, https://github.com/lezcano, https://github.com/mlazos
ghstack dependencies: #113566
Previously, the way our logging system worked was that for every registered log, we would explicit set a log level for it. This would lead to unintuitive behavior when you had multiple overlapping loggers, e.g., from the module hierarchy. Specifically, if you had `TORCH_LOGS=torch`, this would not actually set the logging level for torch._dynamo to be INFO, because the default log level is WARNING, and because torch._dynamo has a registered logger 'dynamo' this would end up setting the level on torch._dynamo to be WARNING, thereby overriding the level of the parent module torch. The 'all' logger did not display this behavior, but only because it was special cased to directly modify the default log level of all other loggers (and so this would not work for any sub-hierarchies).
This PR refactors our code into a much more logical setup using NOTSET. Instead of setting the level of all loggers to some level, we instead default all loggers to NOTSET, unless a user explicitly requested logging from some logger. This means that if we have some logger which isn't explicitly mentioned by the user, parent loggers now have a chance to influence their log behavior. With this, I can eliminate the 'all' special case; 'all' really just means 'torch'. (I keep special handling directing all to torch for backwards compatibility, though arguably I can probably just turn all into an alias.)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113566
Approved by: https://github.com/mlazos, https://github.com/Skylion007
This PR will cause logging.exception() to also dump the exception and stacktrace. Copied from 74723e1110/Lib/logging/__init__.py (L707-L711)
repro:
<details>
```python
import torch
import torch._inductor.config
torch._inductor.config.triton.inject_relu_bug_TESTING_ONLY = "runtime_error"
def fn(x, y):
return (x @ y).relu()
x, y = [torch.rand((16, 16), device='cuda') for _ in range (2)]
torch.compile(fn)(x, y)
```
run with TORCHDYNAMO_REPRO_AFTER=aot TORCHDYNAMO_REPRO_LEVEL=4
</details>
before:
```
...
[2023-10-12 14:18:52,902] torch._dynamo.debug_utils: [ERROR] While minifying the program in accuracy minification mode, ran into a runtime exception which is likely an unrelated issue. Skipping this graph.
```
now:
```
...
[2023-10-12 14:18:52,902] torch._dynamo.debug_utils: [ERROR] While minifying the program in accuracy minification mode, ran into a runtime exception which is likely an unrelated issue. Skipping this graph.
Traceback (most recent call last):
File "/data/users/dberard/scripts/relu_accuracy_issue.py", line 10, in <module>
torch.compile(fn)(x, y)
...
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111164
Approved by: https://github.com/eellison
Now looks like:
```
[2023-09-08 06:04:48,532] [0/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE STORE_ATTR foo [ConstantVariable(int), NNModule
Variable()]
[2023-09-08 06:04:48,532] [0/0] torch._dynamo.convert_frame: [INFO]
Restarting analysis due to _dynamo/variables/nn_module.py
:138 in convert_to_unspecialized
[2023-09-08 06:04:48,533] [0/0_1] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing f /data/users/ezyang/c/pytorch/a.py:6
[2023-09-08 06:04:48,533] [0/0_1] torch._dynamo.symbolic_convert.__trace_source: [DEBUG] TRACE starts_line f /data/users/ezyang/c/pytorch/a.py:6
```
I'm happy to bikeshed the exact formatting of the attempt number if you
want.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108864
Approved by: https://github.com/mlazos, https://github.com/voznesenskym
All log messages that occur while running Dynamo compilation now have `[X/Y]` added to the beginning of their message. X represents the frame being compiled, while Y says which compilation of the frame. For example, if you are debugging a frame that is repeatedly recompiling, you can look for N/0, N/1, N/2, etc. for the same N. Here is what the logs look like as you transition from one frame to another:
<img width="1372" alt="image" src="https://github.com/pytorch/pytorch/assets/13564/4897e368-1e50-4807-b342-54e911bcf087">
To accurately get this prefix added to all messages, I had to expand the scope of the `tracing` context manager. Its scope now coincides with `log_compilation_event`. To do this, I had to populate fake mode lazily in the TracingContext, since it isn't created until later, inside the OutputGraph.
This subsumes the previous X.Y logging that was solely for dynamic shapes.
Unfortunately I had to reindent some stuff. Review the diff with whitespace off.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107530
Approved by: https://github.com/anijain2305
ghstack dependencies: #107505, #107516
It looks like this:
```
[DEBUG] GUARD: ___check_type_id(L['z'][L["MyEnum"].BAR], 7640416) and L['z'][L["MyEnum"].BAR] == 10
[DEBUG] Stack:
[DEBUG] File "/data/users/ezyang/b/pytorch/test/dynamo/test_misc.py", line 6657, in <module>
[DEBUG] run_tests()
[DEBUG] File "/data/users/ezyang/b/pytorch/torch/_dynamo/test_case.py", line 38, in run_tests
[DEBUG] run_tests()
[DEBUG] File "/data/users/ezyang/b/pytorch/torch/testing/_internal/common_utils.py", line 985, in run_tests
[DEBUG] unittest.main(argv=argv)
[DEBUG] File "/home/ezyang/local/b/pytorch-env/lib/python3.10/unittest/main.py", line 101, in __init__
[DEBUG] self.runTests()
[DEBUG] File "/home/ezyang/local/b/pytorch-env/lib/python3.10/unittest/main.py", line 271, in runTests
[DEBUG] self.result = testRunner.run(self.test)
[DEBUG] File "/home/ezyang/local/b/pytorch-env/lib/python3.10/unittest/runner.py", line 184, in run
[DEBUG] test(result)
[DEBUG] File "/home/ezyang/local/b/pytorch-env/lib/python3.10/unittest/suite.py", line 84, in __call__
[DEBUG] return self.run(*args, **kwds)
[DEBUG] File "/home/ezyang/local/b/pytorch-env/lib/python3.10/unittest/suite.py", line 122, in run
[DEBUG] test(result)
[DEBUG] File "/home/ezyang/local/b/pytorch-env/lib/python3.10/unittest/suite.py", line 84, in __call__
[DEBUG] return self.run(*args, **kwds)
[DEBUG] File "/home/ezyang/local/b/pytorch-env/lib/python3.10/unittest/suite.py", line 122, in run
[DEBUG] test(result)
[DEBUG] File "/home/ezyang/local/b/pytorch-env/lib/python3.10/unittest/case.py", line 650, in __call__
[DEBUG] return self.run(*args, **kwds)
[DEBUG] File "/data/users/ezyang/b/pytorch/torch/testing/_internal/common_utils.py", line 2521, in run
[DEBUG] self._run_with_retry(
[DEBUG] File "/data/users/ezyang/b/pytorch/torch/testing/_internal/common_utils.py", line 2450, in _run_with_retry
[DEBUG] super_run(result=result)
[DEBUG] File "/home/ezyang/local/b/pytorch-env/lib/python3.10/unittest/case.py", line 591, in run
[DEBUG] self._callTestMethod(testMethod)
[DEBUG] File "/home/ezyang/local/b/pytorch-env/lib/python3.10/unittest/case.py", line 549, in _callTestMethod
[DEBUG] method()
[DEBUG] File "/data/users/ezyang/b/pytorch/torch/testing/_internal/common_utils.py", line 2377, in wrapper
[DEBUG] method(*args, **kwargs)
[DEBUG] File "/data/users/ezyang/b/pytorch/test/dynamo/test_misc.py", line 2529, in test_enum_as_dict_key_with_overloaded_str
[DEBUG] res = opt_fn(x)
[DEBUG] File "/data/users/ezyang/b/pytorch/torch/_dynamo/eval_frame.py", line 333, in _fn
[DEBUG] return fn(*args, **kwargs)
[DEBUG] File "/data/users/ezyang/b/pytorch/test/dynamo/test_misc.py", line 2519, in fn
[DEBUG] torch._dynamo.graph_break()
[DEBUG] File "/data/users/ezyang/b/pytorch/torch/_dynamo/eval_frame.py", line 493, in catch_errors
[DEBUG] return callback(frame, cache_size, hooks, frame_state)
[DEBUG] File "/data/users/ezyang/b/pytorch/torch/_dynamo/convert_frame.py", line 637, in _convert_frame
[DEBUG] result = inner_convert(frame, cache_size, hooks, frame_state)
[DEBUG] File "/data/users/ezyang/b/pytorch/torch/_dynamo/convert_frame.py", line 133, in _fn
[DEBUG] return fn(*args, **kwargs)
[DEBUG] File "/data/users/ezyang/b/pytorch/torch/_dynamo/convert_frame.py", line 371, in _convert_frame_assert
[DEBUG] return _compile(
[DEBUG] File "/data/users/ezyang/b/pytorch/torch/_dynamo/convert_frame.py", line 567, in _compile
[DEBUG] guarded_code = compile_inner(code, one_graph, hooks, transform)
[DEBUG] File "/data/users/ezyang/b/pytorch/torch/_dynamo/utils.py", line 181, in time_wrapper
[DEBUG] r = func(*args, **kwargs)
[DEBUG] File "/data/users/ezyang/b/pytorch/torch/_dynamo/convert_frame.py", line 466, in compile_inner
[DEBUG] out_code = transform_code_object(code, transform)
[DEBUG] File "/data/users/ezyang/b/pytorch/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
[DEBUG] transformations(instructions, code_options)
[DEBUG] File "/data/users/ezyang/b/pytorch/torch/_dynamo/convert_frame.py", line 416, in transform
[DEBUG] tracer = InstructionTranslator(
[DEBUG] File "/data/users/ezyang/b/pytorch/torch/_dynamo/symbolic_convert.py", line 2018, in __init__
[DEBUG] self.symbolic_locals = collections.OrderedDict(
[DEBUG] File "/data/users/ezyang/b/pytorch/torch/_dynamo/symbolic_convert.py", line 2021, in <genexpr>
[DEBUG] VariableBuilder(
[DEBUG] File "/data/users/ezyang/b/pytorch/torch/_dynamo/variables/builder.py", line 211, in __call__
[DEBUG] vt = self._wrap(value).clone(**self.options())
[DEBUG] File "/data/users/ezyang/b/pytorch/torch/_dynamo/variables/builder.py", line 404, in _wrap
[DEBUG] result = {
[DEBUG] File "/data/users/ezyang/b/pytorch/torch/_dynamo/variables/builder.py", line 405, in <dictcomp>
[DEBUG] k: VariableBuilder(
[DEBUG] File "/data/users/ezyang/b/pytorch/torch/_dynamo/variables/builder.py", line 211, in __call__
[DEBUG] vt = self._wrap(value).clone(**self.options())
[DEBUG] File "/data/users/ezyang/b/pytorch/torch/_dynamo/variables/builder.py", line 354, in _wrap
[DEBUG] return type_dispatch(self, value)
[DEBUG] File "/data/users/ezyang/b/pytorch/torch/_dynamo/variables/builder.py", line 837, in wrap_literal
[DEBUG] return self.wrap_unspecialized_primitive(value)
[DEBUG] File "/data/users/ezyang/b/pytorch/torch/_dynamo/variables/builder.py", line 1073, in wrap_unspecialized_primitive
[DEBUG] guards=self.make_guards(GuardBuilder.CONSTANT_MATCH),
[DEBUG] File "/data/users/ezyang/b/pytorch/torch/_dynamo/variables/builder.py", line 269, in make_guards
[DEBUG] return {source.make_guard(guard) for guard in guards}
[DEBUG] File "/data/users/ezyang/b/pytorch/torch/_dynamo/variables/builder.py", line 269, in <setcomp>
[DEBUG] return {source.make_guard(guard) for guard in guards}
[DEBUG] File "/data/users/ezyang/b/pytorch/torch/_guards.py", line 641, in make_guard
[DEBUG] return Guard(self.name(), self.guard_sou
```
One downside is I can't report *why* the guard was added. I'm not entirely sure how to do this; the problem is guards will propagate to a bunch of variables before finally getting included as part of the final set. Maybe a very very verbose version could report stack traces at every handoff point.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107388
Approved by: https://github.com/mlazos
ghstack dependencies: #107438, #107358
This adds some utilities for conveniently working with fast combined CapturedTraceback from Python. The main goal of these utilities is to make it easier for people to use CapturedTraceback as a drop-in replacement for `traceback.extract_stack`, which is 20x slower than CapturedTraceback.
I port symbolic shapes to use the new CapturedTraceback code, to validate that the APIs work and are useful.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107358
Approved by: https://github.com/zdevito, https://github.com/albanD
ghstack dependencies: #107438
Since Python 3.11 bytecode contains endline and column information, for each bytecode, we attribute the source code corresponding to the bytecode in a more accurate way. For example, we can highlight a function call in a series of nested function calls, or highlight a function call spanning multiple lines.
Sample:
```python
import torch
import torch._dynamo
from functorch.experimental.control_flow import cond
def h(x):
return x * 5
def true_fn(x):
return x * 2
def false_fn(x):
return x * 3
def f(pred, x):
x = h(
h(h(x))
)
x = x[1:][:2]
torch._dynamo.graph_break()
x = cond(pred, true_fn, false_fn, [x])
opt_f = torch.compile(f, backend="eager")
opt_f(torch.tensor(True), torch.randn(3, 3, 3, 3))
```
Output:
```
$ TORCH_LOGS="trace_call" python playground9.py
TRACE inlined call h from f /scratch/williamwen/work/pytorch/playground9.py:16
h(h(x))
~^^^
TRACE FX call mul from h /scratch/williamwen/work/pytorch/playground9.py:6 (inline depth: 1)
return x * 5
~~^~~
TRACE inlined call h from f /scratch/williamwen/work/pytorch/playground9.py:16
h(h(x))
~^^^^^^
TRACE FX call mul_1 from h /scratch/williamwen/work/pytorch/playground9.py:6 (inline depth: 1)
return x * 5
~~^~~
TRACE inlined call h from f /scratch/williamwen/work/pytorch/playground9.py:15
x = h(
~^
h(h(x))
^^^^^^^
)
^
TRACE FX call mul_2 from h /scratch/williamwen/work/pytorch/playground9.py:6 (inline depth: 1)
return x * 5
~~^~~
TRACE FX call getitem from f /scratch/williamwen/work/pytorch/playground9.py:18
x = x[1:][:2]
~^^^^
TRACE FX call getitem_1 from f /scratch/williamwen/work/pytorch/playground9.py:18
x = x[1:][:2]
~~~~~^^^^
TRACE inlined call true_fn from <resume in f> /scratch/williamwen/work/pytorch/playground9.py:20
x = cond(pred, true_fn, false_fn, [x])
~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TRACE FX call mul from true_fn /scratch/williamwen/work/pytorch/playground9.py:9 (inline depth: 1)
return x * 2
~~^~~
TRACE inlined call false_fn from <resume in f> /scratch/williamwen/work/pytorch/playground9.py:20
x = cond(pred, true_fn, false_fn, [x])
~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TRACE FX call mul from false_fn /scratch/williamwen/work/pytorch/playground9.py:12 (inline depth: 1)
return x * 3
~~^~~
TRACE FX call cond from <resume in f> /scratch/williamwen/work/pytorch/playground9.py:20
x = cond(pred, true_fn, false_fn, [x])
~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104676
Approved by: https://github.com/ezyang