mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[Dynamo][Logging] Add sources/types to LazyVariableTracker logging (#165402)
Fixes #162860. This task adds variable source attribution to LazyVariableTracker when outputting trace bytecode. Test plan -- test/dynamo/test_error_messages.py ErrorMessagesTest.test_variable_tracker_source_attribution. The output is as specified in the previously mentioned GitHub issue. <img width="961" height="59" alt="Screenshot 2025-10-13 at 10 19 44 PM" src="https://github.com/user-attachments/assets/fb27da3f-d00b-437b-bf2e-52e892572cd7" /> This is specifically for the log setup with ``TORCH_LOGS=trace_bytecode`` Pull Request resolved: https://github.com/pytorch/pytorch/pull/165402 Approved by: https://github.com/Lucaskabela, https://github.com/williamwen42 Co-authored-by: William Wen <williamwen@meta.com>
This commit is contained in:
committed by
PyTorch MergeBot
parent
b54e466fd0
commit
568d2f3ae7
@ -113,7 +113,7 @@ sort with non-constant keys
|
||||
Explanation: Cannot perform sort with non-constant key. First non-constant key type: <class 'torch.Tensor'>. Most notably, we cannot sort with Tensor or SymInt keys, but we can sort ints.
|
||||
Hint: Use something else as the key.
|
||||
|
||||
Developer debug context: TensorVariable()
|
||||
Developer debug context: LazyVariableTracker(realized: TensorVariable())
|
||||
|
||||
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0207.html
|
||||
|
||||
@ -216,7 +216,7 @@ Unsupported context manager
|
||||
Hint: If the context manager seems like it should be supported (e.g. torch.set_grad_enabled), then it may be the case that it was created outside the compiled region, which Dynamo does not support. Supported context managers can cross graph break boundaries only if they are local non-closure variables, or are intermediate values.
|
||||
Hint: File an issue to PyTorch. Simple context managers can potentially be supported, but note that context managers can't be supported in general
|
||||
|
||||
Developer debug context: Attempted SETUP_WITH/BEFORE_WITH/LOAD_SPECIAL on ConstantVariable(int: 3)
|
||||
Developer debug context: Attempted SETUP_WITH/BEFORE_WITH/LOAD_SPECIAL on LazyVariableTracker(realized: ConstantVariable(int: 3))
|
||||
|
||||
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0142.html
|
||||
|
||||
@ -543,7 +543,7 @@ Dynamic slicing with Tensor arguments
|
||||
Explanation: Creating slices with Tensor arguments is not supported. e.g. `l[:x]`, where `x` is a 1-element tensor.
|
||||
Hint: It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues.
|
||||
|
||||
Developer debug context: SliceVariable start: ConstantVariable(NoneType: None), stop: TensorVariable(), step: ConstantVariable(NoneType: None)
|
||||
Developer debug context: SliceVariable start: ConstantVariable(NoneType: None), stop: LazyVariableTracker(realized: TensorVariable()), step: ConstantVariable(NoneType: None)
|
||||
|
||||
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0038.html
|
||||
|
||||
@ -869,6 +869,51 @@ from user code:
|
||||
if x.sum() > 0:""",
|
||||
)
|
||||
|
||||
# Test that the bytecode source attribution is correct with VariableTracker
@make_logging_test(trace_bytecode=True)
def test_variable_tracker_source_attribution(self, records):
    """Check that TORCH_LOGS=trace_bytecode output annotates
    LazyVariableTracker entries with realization state and source type.

    Compiles a small function, collects every emitted log record, and
    asserts that LazyVariableTracker lines carry the new
    ``unrealized: <type>`` / ``realized: <VariableTracker repr>`` info.
    """

    def inner(x):
        return x + 1

    @torch.compile(backend="eager")
    def fn(x):
        x = inner(x)
        return inner(x)

    fn(torch.ones(3))

    def find_trace_bytecode_lines(long_string):
        """Return every line mentioning a LazyVariableTracker(...) entry."""
        # Hoist the compiled pattern out of the filtering loop.
        pattern = re.compile(r"LazyVariableTracker\([^)]*\)")
        return [line for line in long_string.split("\n") if pattern.search(line)]

    # Gather all log messages, not just the last one; skip=0 keeps the
    # full (munged) message text for matching.
    all_messages = [munge_exc(record.getMessage(), skip=0) for record in records]

    # Combine all messages so line order across records is preserved.
    combined_msg = "\n".join(all_messages)
    all_lines = find_trace_bytecode_lines(combined_msg)

    # Sanity check before indexing into all_lines below.
    self.assertGreater(
        len(all_lines), 0, "Should find at least one LazyVariableTracker line"
    )

    # First sighting: still lazy, only the underlying type is known.
    self.assertIn(
        "LazyVariableTracker(unrealized: <class 'function'>)", all_lines[0]
    )
    # Later sighting: realized into a concrete UserFunctionVariable.
    self.assertIn(
        "LazyVariableTracker(realized: UserFunctionVariable())", all_lines[3]
    )
|
||||
|
||||
@make_logging_test(graph_breaks=True)
|
||||
def test_data_dependent_branching_gb(self, records):
|
||||
def fn(x):
|
||||
@ -1141,17 +1186,17 @@ NOTE: the most recent `torch.compile` tracing attempt might not be where you app
|
||||
Most recent bytecode instructions traced (max 20):
|
||||
TRACE RESUME 0 []
|
||||
TRACE LOAD_FAST 'x' []
|
||||
TRACE LOAD_CONST 1 [LazyVariableTracker()]
|
||||
TRACE BINARY_OP 0 [LazyVariableTracker(), ConstantVariable(int: 1)]
|
||||
TRACE LOAD_CONST 1 [LazyVariableTracker(unrealized: <class 'torch.Tensor'>)]
|
||||
TRACE BINARY_OP 0 [LazyVariableTracker(unrealized: <class 'torch.Tensor'>), ConstantVariable(int: 1)]
|
||||
TRACE STORE_FAST 'y' [TensorVariable()]
|
||||
TRACE LOAD_FAST 'x' []
|
||||
TRACE LOAD_FAST 'y' [TensorVariable()]
|
||||
TRACE BINARY_OP 0 [TensorVariable(), TensorVariable()]
|
||||
TRACE STORE_FAST 'z' [TensorVariable()]
|
||||
TRACE LOAD_GLOBAL 'torch' []
|
||||
TRACE LOAD_ATTR '_dynamo' [LazyVariableTracker()]
|
||||
TRACE LOAD_ATTR 'graph_break' [LazyVariableTracker()]
|
||||
TRACE CALL 0 [NullVariable, LazyVariableTracker()]""",
|
||||
TRACE LOAD_ATTR '_dynamo' [LazyVariableTracker(unrealized: <class 'module'>)]
|
||||
TRACE LOAD_ATTR 'graph_break' [LazyVariableTracker(unrealized: <class 'module'>)]
|
||||
TRACE CALL 0 [NullVariable, LazyVariableTracker(unrealized: <class 'function'>)]""",
|
||||
)
|
||||
|
||||
@torch._dynamo.config.patch(verbose=True)
|
||||
|
@ -1357,7 +1357,7 @@ class InstructionTranslatorBase(
|
||||
|
||||
if self.is_trace_bytecode_log_enabled:
|
||||
trace_bytecode_log.debug(
|
||||
"TRACE %s %s %s", inst.opname, inst.argval, self.stack
|
||||
"TRACE %s %s %s", inst.opname, inst.argval, repr(self.stack)
|
||||
)
|
||||
|
||||
# Store the latest 20 bytecode execution for the process,
|
||||
|
@ -104,9 +104,13 @@ class LazyVariableTracker(VariableTracker):
|
||||
self._cache.name_hint = name
|
||||
|
||||
def __str__(self) -> str:
    """Render this tracker for trace logs (e.g. ``TORCH_LOGS=trace_bytecode``).

    Returns ``LazyVariableTracker(realized: <repr>)`` once the wrapped
    value has been forced, and ``LazyVariableTracker(unrealized: <type>)``
    otherwise, so log readers can see what the tracker wraps without
    forcing realization themselves.
    """
    variable_info = "LazyVariableTracker("
    if self.is_realized():
        # Realized: delegate to the concrete VariableTracker's repr.
        variable_info += f"realized: {repr(self.unwrap())})"
    else:
        # Unrealized: peek_type() inspects the value's type without
        # triggering realization (side-effect free by design — see class).
        variable_info += f"unrealized: {self.peek_type()})"
    return variable_info
|
||||
|
||||
def __getattr__(self, item: str) -> Any:
    # Fallback attribute access: force realization of the underlying
    # VariableTracker, then delegate the lookup to the realized object.
    # (Only called for attributes not found on this lazy wrapper itself.)
    return getattr(self.realize(), item)
|
||||
|
Reference in New Issue
Block a user