Compare commits

...

12 Commits

Author SHA1 Message Date
9e1c2dcaca Remove previous commented code 2025-10-13 15:08:12 -07:00
dda8af499f Fix lint 2025-10-13 13:16:18 -07:00
d724462021 Fix lint 2025-10-13 12:54:05 -07:00
acfa941045 Fix the repr bytecode log and other existing tests broken 2025-10-13 11:57:15 -07:00
3f3dee2517 Change the variable source impl into the repr in Lazy 2025-10-13 10:57:50 -07:00
8dd565c7de Update to remove unnecessary comment 2025-10-09 10:55:19 -07:00
2a48ab697b Update to follow the naming guideline and other nits 2025-10-09 10:50:43 -07:00
5f12243673 Fix linter 2025-10-08 16:39:45 -07:00
7ca51cbff0 Update to the latest output expectation 2025-10-08 16:07:22 -07:00
4995111108 Add variable tracker attribution to LazyVariableTracker and realized item to realized lazy
ghstack-source-id: 3f63f33bc24c967742fd9a08bf85fbbe2821f15e
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164612
2025-10-03 14:57:01 -07:00
1bbe2ae477 Add LazyVariableTracker source attribution
ghstack-source-id: 601ce34aeae192db155960dfe4013c79b34f8d01
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164611
2025-10-03 14:56:57 -07:00
dc75389b69 [Dynamo] Variable tracker attribution testing
ghstack-source-id: 0e8190dfebb36209eb02f68d98cda28066339dbc
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164610
2025-10-03 14:56:54 -07:00
3 changed files with 54 additions and 6 deletions

View File

@ -113,7 +113,7 @@ sort with non-constant keys
Explanation: Cannot perform sort with non-constant key. First non-constant key type: <class 'torch.Tensor'>. Most notably, we cannot sort with Tensor or SymInt keys, but we can sort ints.
Hint: Use something else as the key.
Developer debug context: TensorVariable()
Developer debug context: LazyVariableTracker(realized:<class 'method'>, TensorVariable())
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0207.html
@ -216,7 +216,7 @@ Unsupported context manager
Hint: If the context manager seems like it should be supported (e.g. torch.set_grad_enabled), then it may be the case that it was created outside the compiled region, which Dynamo does not support. Supported context managers can cross graph break boundaries only if they are local non-closure variables, or are intermediate values.
Hint: File an issue to PyTorch. Simple context managers can potentially be supported, but note that context managers can't be supported in general
Developer debug context: Attempted SETUP_WITH/BEFORE_WITH/LOAD_SPECIAL on ConstantVariable(int: 3)
Developer debug context: Attempted SETUP_WITH/BEFORE_WITH/LOAD_SPECIAL on LazyVariableTracker(realized:<class 'method'>, ConstantVariable(int: 3))
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0142.html
@ -543,7 +543,7 @@ Dynamic slicing with Tensor arguments
Explanation: Creating slices with Tensor arguments is not supported. e.g. `l[:x]`, where `x` is a 1-element tensor.
Hint: It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues.
Developer debug context: SliceVariable start: ConstantVariable(NoneType: None), stop: TensorVariable(), step: ConstantVariable(NoneType: None)
Developer debug context: SliceVariable start: ConstantVariable(NoneType: None), stop: LazyVariableTracker(realized:<class 'method'>, TensorVariable()), step: ConstantVariable(NoneType: None)
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0038.html
@ -869,6 +869,48 @@ from user code:
if x.sum() > 0:""",
)
# Verify that bytecode trace logs carry source attribution via LazyVariableTracker.
@make_logging_test(trace_bytecode=True)
def test_variable_tracker_source_attribution(self, records):
    def inner(x):
        return x + 1

    @torch.compile(backend="eager")
    def fn(x):
        x = inner(x)
        return inner(x)

    fn(torch.ones(3))

    # Gather every log message emitted during compilation, then keep only
    # the lines whose repr mentions a LazyVariableTracker.
    tracker_pattern = re.compile(r"LazyVariableTracker\([^)]*\)")
    combined = "\n".join(
        munge_exc(record.getMessage(), skip=0) for record in records
    )
    matching_lines = [
        line for line in combined.split("\n") if tracker_pattern.search(line)
    ]

    # The trace must surface at least one tracker, and each occurrence must
    # include the realized-type info (e.g. "realized:<class ...>").
    self.assertGreater(
        len(matching_lines), 0, "Should find at least one LazyVariableTracker line"
    )
    for line in matching_lines:
        self.assertIn("realized", line)
        self.assertIn("class", line)
@make_logging_test(graph_breaks=True)
def test_data_dependent_branching_gb(self, records):
def fn(x):

View File

@ -1338,7 +1338,7 @@ class InstructionTranslatorBase(
if self.is_trace_bytecode_log_enabled:
trace_bytecode_log.debug(
"TRACE %s %s %s", inst.opname, inst.argval, self.stack
"TRACE %s %s %s", inst.opname, inst.argval, repr(self.stack)
)
self.update_block_stack(inst)

View File

@ -104,9 +104,15 @@ class LazyVariableTracker(VariableTracker):
self._cache.name_hint = name
def __str__(self) -> str:
    """Return a debug string describing this tracker's realization state.

    Realized trackers render as
    ``LazyVariableTracker(realized:<type>, <repr of unwrapped variable>)``;
    unrealized trackers render as ``LazyVariableTracker(Unrealized: <type>)``
    using ``peek_type()``, which presumably inspects the value's type
    without forcing realization (TODO confirm against peek_type's impl).
    """
    variable_info = "LazyVariableTracker("
    if self.is_realized():
        variable_info += (
            f"realized:{type(self.original_value)}, {repr(self.unwrap())})"
        )
    else:
        variable_info += f"Unrealized: {self.peek_type()})"
    return variable_info
def __getattr__(self, item: str) -> Any:
    """Delegate unknown attribute lookups to the realized variable."""
    # __getattr__ only fires for attributes not found normally, so any such
    # lookup goes through realize() and then to the underlying variable.
    realized_variable = self.realize()
    return getattr(realized_variable, item)