Compare commits

...

2 Commits

Author SHA1 Message Date
09f003addd Update on "[Requires discussion][dynamo] Overwrite proxy of subgraph outputs"
This ensures that
1) We use the same VariableTracker for the same fake tensor value. Earlier, one
   fake value could be mapped to multiple VTs.
2) If we try to construct a VT for the same fake tensor again, we
   overwrite the proxy of the existing TensorVariable.

This ensures that if a TensorVariable was mutated in speculate_subgraph,
we make it available from the subgraph output.

This does NOT solve the overall problem of tensors missed by side-effect
tracking. It still relies on the user to return the side-effected tensor
from the subgraph.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-11-04 16:28:55 -08:00
d8c2d98b41 [dynamo] Overwrite proxy of subgraph outputs
This ensures that
1) We use the same VariableTracker for a fake tensor value. Earlier, one
   fake value could be mapped to multiple VTs.
2) If we try to construct a VT for the same fake tensor again, we
   overwrite the proxy of the existing TensorVariable.

This ensures that if a TensorVariable was mutated in speculate_subgraph,
we make it available from the subgraph output.

This does NOT solve the overall problem of tensors missed by side-effect
tracking. It still relies on the user to return the side-effected tensor
from the subgraph.

[ghstack-poisoned]
2025-11-04 16:20:32 -08:00
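
Both commits note that the user must still return the side-effected tensor from the subgraph. As a minimal illustration of that requirement (a standalone sketch mirroring the test_nonlocal_list_mutation test added below; gn and fn are just example names), the tensor appended to the nonlocal list is also part of gn's return value, so it remains reachable outside the checkpointed subgraph:

import torch
import torch.utils.checkpoint


def gn(x, z):
    out = x.sin()
    z.append(out)        # side effect on a nonlocal list
    return x.cos(), out  # the side-effected tensor is also returned


def fn(x):
    z = []
    out1, _ = torch.utils.checkpoint.checkpoint(gn, x, z, use_reentrant=False)
    # z[0] is the tensor mutated inside the checkpointed region; because gn
    # also returned it, it is visible outside the subgraph.
    return out1, z[0]


x = torch.randn(4, 4, requires_grad=True)
ref = fn(x)
res = torch.compile(fn, backend="eager", fullgraph=True)(x)
assert torch.allclose(ref[1], res[1])

If gn only appended to z without returning out, the commit messages above say the mutation could still be missed by side-effect handling.
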
3 changed files with 58 additions and 4 deletions

View File

@@ -1672,6 +1672,33 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
        # The mutation is not reapplied in the backward because the flag was on.
        self.assertEqual(counter, 1)

    @torch._dynamo.config.patch(skip_fwd_side_effects_in_bwd_under_checkpoint=True)
    def test_nonlocal_list_mutation(self):
        def gn(x, z):
            out = x.sin()
            z.append(out)
            return torch.cos(torch.sin(torch.matmul(x, x) @ x)), out

        def fn(x):
            z = []
            out1, out2 = torch.utils.checkpoint.checkpoint(
                gn,
                x,
                z,
                use_reentrant=False,
            )
            return out1, z[0]

        x = torch.randn(4, 4, requires_grad=True)
        ref = fn(x)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        res = opt_fn(x)
        self.assertEqual(ref[0], res[0])
        self.assertEqual(ref[1], res[1])

devices = ["cuda", "hpu"]
instantiate_device_type_tests(

View File

@@ -736,6 +736,8 @@ class OutputGraph(OutputGraphCommon):
            # dynamo_flat_name_to_original_fqn mapping.
            self.used_inlined_inbuilt_modules_names: OrderedSet[str] = OrderedSet()

            self.cached_tensor_vts: dict[int, VariableTracker] = {}

    def mark_bytecode_tracing_start(self) -> None:
        self.compiler_trace_stack.enter_context(
            dynamo_timed(

@@ -2994,6 +2996,7 @@ class SubgraphTracer(fx.Tracer):
            raise RuntimeError(
                "Inference mode is supposed to be disabled during compilation. Please open an issue."
            )
        # self.intermediate_tensor_vts: OrderedSet[VariableTracker] = OrderedSet()

    # preserve original meta if it is available
    def _maybe_preserve_original_meta(

@@ -3018,6 +3021,9 @@ class SubgraphTracer(fx.Tracer):
            if "stack_trace" in meta:
                node.meta["stack_trace"] = meta["stack_trace"]

    # def record_tensor_vt(self, vt: VariableTracker) -> None:
    #     self.intermediate_tensor_vts.add(vt)

    def create_proxy(
        self,
        kind: str,

View File

@@ -2785,21 +2785,27 @@ def wrap_fx_proxy_cls(
    target_cls, tx, proxy, example_value=None, subclass_type=None, **options
):
    if example_value is None:
-        return _wrap_fx_proxy(
+        out = _wrap_fx_proxy(
            target_cls, tx, proxy, example_value, subclass_type, **options
        )
    elif isinstance(example_value, torch.Tensor):
-        return _wrap_fx_preexisting_tensor(
+        out = _wrap_fx_preexisting_tensor(
            target_cls, tx, proxy, example_value, subclass_type, **options
        )
    else:
        # This will skip tracing an op and recursively reinvoke wrap_fx_proxy_cls on supported
        # data structures. In essence this just handles tracing some other value which may
        # contain Fake Tensors or is otherwise proxyable.
-        return handle_traced_output(
+        out = handle_traced_output(
            example_value, tx, proxy, options, subclass_type, target_cls
        )
+    # if isinstance(out, torch._dynamo.variables.TensorVariable):
+    #     tx.output.current_tracer.record_tensor_vt(out)
+    #     print("Adding new VT", out, out.proxy, id(out))
+    return out


# This is 1 above (wrapping a preexisting tensor)
def _wrap_fx_preexisting_tensor(
@@ -3159,9 +3165,22 @@ def construct_tensor_variable(
    Actually construct a tensor variable after all the pre-processing from
    wrapping a pre-existing or newly created tensor value.
    """
+    cached_tensor_vts = tx.output.cached_tensor_vts
+    if id(example_value) in cached_tensor_vts and proxy.node.op != "placeholder":
+        saved_vt = cached_tensor_vts[id(example_value)]
+        # Overwrite the existing proxy so that from now on, this VT always uses
+        # the new proxy
+        saved_vt.proxy = proxy
+        # Not sure why we need this
+        example_value = _clone_input(example_value, tx.fake_mode)
+        set_example_value(proxy.node, example_value)
+        return saved_vt

    # NB: In most (all?) cases, this does not actually do a clone.
    # (WARNING: this means that if we mutate metadata on the fake
    # tensor, the stored example value will update too!)
+    old_example_value = example_value
    example_value = _clone_input(example_value, tx.fake_mode)
    set_example_value(proxy.node, example_value)
    # We bind the unbacked symints in sizes/strides of tensor lazily.

@@ -3171,7 +3190,9 @@
    if proxy.node.op != "placeholder":
        tx.output.current_tracer.track_produced_symints(example_value, proxy)

    options.update(get_specialized_props(target_cls, tx, example_value, subclass_type))
-    return target_cls(proxy, **options)
+    out = target_cls(proxy, **options)
+    cached_tensor_vts[id(old_example_value)] = out
+    return out


def get_automatic_dynamic_shapes_mark_as():
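
Taken together, the hunks above amount to an id-keyed cache of tensor VariableTrackers whose proxy is overwritten whenever the same fake value gets wrapped again, e.g. as a subgraph output. A simplified standalone sketch of that pattern (TensorWrapper, wrap, and _cached_wrappers are hypothetical names, not dynamo's real classes or APIs):

from typing import Any


class TensorWrapper:
    """Stand-in for TensorVariable: holds the proxy through which a value is referenced."""

    def __init__(self, proxy: Any) -> None:
        self.proxy = proxy


_cached_wrappers: dict[int, TensorWrapper] = {}  # analogue of cached_tensor_vts


def wrap(fake_value: Any, proxy: Any, is_placeholder: bool = False) -> TensorWrapper:
    key = id(fake_value)
    if key in _cached_wrappers and not is_placeholder:
        saved = _cached_wrappers[key]
        saved.proxy = proxy  # the single wrapper now points at the newest proxy
        return saved
    wrapper = TensorWrapper(proxy)
    _cached_wrappers[key] = wrapper
    return wrapper

Keying on id() mirrors cached_tensor_vts in the diff, and the is_placeholder check corresponds to the proxy.node.op != "placeholder" guard above.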