Files
pytorch/torch/fx/experimental/_config.py
Avik Chaudhuri 8db8ac700d line by line logging (#134298)
Summary:
Today there is no good mechanism to track the progress of non-strict export line by line through user code. This caused some pain recently when trying to find the exact line of user code that triggered a bug: the process appeared stuck because, deep down, something was calling symbolic shapes code that was suffering an exponential blowup.

This PR adds an environment variable for extended debugging that logs the line of user code corresponding to every torch function call. It only works in non-strict export for now. Combine setting this environment variable with `TORCH_LOGS` enabled for `export` at the `DEBUG` level (i.e., with a `+` prefix), e.g.:

```
TORCHEXPORT_EXTENDED_DEBUG_CURRENT_LOC=1 TORCH_LOGS="+export" ...
```

This will show logs with something like:
```
...
prim::device called at .../example.py:4284 in foo
TensorBase.item called at .../example.py:4277 in bar
...
```

We already have a place where we intercept torch functions to process data-dependent errors in non-strict, so the logging is parked there. An alternative place to do this would be where we add `stack_trace` metadata when generating code, but unfortunately at least the example that motivated this gets stuck before code generation, so that would be too late.
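For illustration only (this is not the code in this PR): a minimal sketch of the general interception idea, using a `TorchFunctionMode` from `torch.overrides` whose handler walks the Python stack to the first frame outside the `torch` package and logs the call site. The `CurrentLocLogger` name, the `export` logger name, and the frame-filtering heuristic are assumptions for this sketch; the actual hook lives in the existing non-strict interception point mentioned above.

```
import inspect
import logging
import os

import torch
from torch.overrides import TorchFunctionMode, resolve_name

log = logging.getLogger("export")


class CurrentLocLogger(TorchFunctionMode):
    """Illustrative sketch: log the user source line for every torch call."""

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        torch_dir = os.path.dirname(torch.__file__)
        # Walk up the stack to the first frame that is not inside torch itself
        # and treat that as the "current" user location.
        for frame in inspect.stack()[1:]:
            if not frame.filename.startswith(torch_dir):
                log.debug(
                    "%s called at %s:%d in %s",
                    resolve_name(func) or func,
                    frame.filename,
                    frame.lineno,
                    frame.function,
                )
                break
        # While a mode handler runs, the mode is popped from the mode stack,
        # so this call does not recurse back into __torch_function__.
        return func(*args, **kwargs)


if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)

    def bar(x):
        return x.item()

    with CurrentLocLogger():
        bar(torch.ones(1))
```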

Test Plan: ran it on some sample commands

Differential Revision: D61692156

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134298
Approved by: https://github.com/angelayi
2024-08-25 02:57:11 +00:00

89 lines
3.7 KiB
Python

import os
import sys
from typing import Optional

# [@compile_ignored: debug] Uses z3 for validating the guard optimization transformations.
translation_validation = (
    os.environ.get("TORCHDYNAMO_TRANSLATION_VALIDATION", "0") == "1"
)
# Timeout (in milliseconds) for z3 finding a solution.
# [@compile_ignored: debug]
translation_validation_timeout = int(
    os.environ.get("TORCHDYNAMO_TRANSLATION_VALIDATION_TIMEOUT", "600000")
)
# Disables bisection for translation validation.
#
# Translation validation bisection is enabled by default, if translation validation
# is also enabled. This should help finding guard simplification issues. However,
# since validation uses Z3 for bisecting, it might take a lot of time.
#
# Set this configuration option so as to avoid bisecting.
# [@compile_ignored: debug]
translation_validation_no_bisect = (
    os.environ.get("TORCHDYNAMO_TRANSLATION_NO_BISECT", "0") == "1"
)
# Checks whether replaying ShapeEnv events on a freshly constructed one yields
# a ShapeEnv with the same state. This should be used only in testing.
check_shape_env_recorded_events = False
# TODO: Perhaps consider allowing unions for the configs below (so you can hit
# multiple reps at the same time)
# Give extended debug information if the string representation of a guard
# matches this. For example, set this to "Ne(s0, 10)" and whenever we issue
# this guard, we will generate a full Python and C++ backtrace.
# [@compile_ignored: debug]
extended_debug_guard_added = os.environ.get(
    "TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED", None
)
# Give extended debug information when a particular symbol is allocated. For
# example, set this to "u2" and whenever we create this symbol, we will
# generate a full Python and C++ backtrace.
# [@compile_ignored: debug]
extended_debug_create_symbol = os.environ.get(
    "TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL", None
)
# Give extended debug information (C++ backtrace) for all extended debug
# settings as well as errors. The C++ backtrace is slow and very spammy so we
# don't include it by default even when you're requesting extended debug.
# [@compile_ignored: debug]
extended_debug_cpp = os.environ.get("TORCHDYNAMO_EXTENDED_DEBUG_CPP", "") != ""
# Give extended debug information (line of code) when a torch function
# is called during export. This is useful for showing progress and detecting
# where export might be stuck. Currently only works for strict=False.
# [@compile_ignored: debug]
extended_debug_current_loc = (
    os.environ.get("TORCHEXPORT_EXTENDED_DEBUG_CURRENT_LOC", "0") == "1"
)
# [@compile_ignored: debug] Show a warning for every specialization
print_specializations = False
# wraps (un)equalities with 'Not' class after recording the correct expression
# in the FX graph. This should incorrectly construct the divisible and replacement
# lists, and incorrectly issue guards.
inject_EVALUATE_EXPR_flip_equality_TESTING_ONLY = False
# [@compile_ignored: debug] Validate that ShapeEnv's version key is updated correctly
validate_shape_env_version_key = False
# If we produce more than this many guards on a symbol, force the symbol to
# get specialized and bail out if this many guards mention this particular
# symbol. This may be slightly more aggressive than the true number of guards
# issued (as we test if we've hit the limit on-the-fly, whereas we may
# do further simplifications at final guard issuance time that make guards
# irrelevant.)
symbol_guard_limit_before_specialize: Optional[int] = None
# This flag changes whether we should use the same symbolic variable to represent input sizes that are the same.
use_duck_shape = True

from torch.utils._config_module import install_config_module

install_config_module(sys.modules[__name__])
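
Note: since the values above are read from `os.environ` at module import time, the environment variables generally need to be set before this module (and `torch`) is imported. Below is a minimal sketch of enabling the new logging from Python rather than the shell, assuming `TORCH_LOGS` is picked up when `torch` initializes its logging:

```
import os

# Set the debug knobs before importing torch, since the module above reads
# them from os.environ when it is first imported.
os.environ["TORCHEXPORT_EXTENDED_DEBUG_CURRENT_LOC"] = "1"
os.environ["TORCH_LOGS"] = "+export"

import torch.fx.experimental._config as fx_config  # noqa: E402

# Sanity check for this example: the env var above maps to this config flag.
assert fx_config.extended_debug_current_loc
```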