[cond] make cond re-dispatch in proxy mode (#146954)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146954
Approved by: https://github.com/zou3519
This commit is contained in:
Yidi Wu
2025-02-11 16:56:03 -08:00
committed by PyTorch MergeBot
parent 67cbbb29e0
commit 2ce6de2415
2 changed files with 20 additions and 34 deletions

View File

@ -339,6 +339,17 @@ def unmask_none_gradients(grads, operands):
return unmasked_grads
def _maybe_fake_prop_ignore_unbacked(fn, args):
with ExitStack() as ctx_stack:
if (fake_mode := detect_fake_mode(args)) is not None:
ctx_stack.enter_context(fake_mode)
if fake_mode.shape_env is not None:
ctx_stack.enter_context(
fake_mode.shape_env.ignore_fresh_unbacked_symbols()
)
return fn(*args)
# TODO: The parameter use_output_and_grad_bw is required because some operations
# that utilize this function, such as the while_loop, may require (grad, fwd_outputs)
def create_fw_bw_graph(fn, use_output_and_grad_bw, fw_inputs, fw_outputs):
@ -578,18 +589,12 @@ def check_input_alias_and_mutation(
# We need to temporarily turn inference_mode off because
# under inference mode, tensor version counter is not tracked.
ctx_stack.enter_context(torch.inference_mode(False))
if (fake_mode := detect_fake_mode(fake_args)) is not None:
ctx_stack.enter_context(fake_mode)
if fake_mode.shape_env is not None:
ctx_stack.enter_context(
fake_mode.shape_env.ignore_fresh_unbacked_symbols()
)
cloned = [
clone_preserve_strides(arg) if isinstance(arg, torch.Tensor) else arg
for arg in fake_args
]
before = [_tensor_version(arg) for arg in cloned]
outputs = gm(*cloned)
outputs = _maybe_fake_prop_ignore_unbacked(gm, cloned)
outputs = [outputs] if not isinstance(outputs, (list, tuple)) else outputs
after = [_tensor_version(arg) for arg in cloned]
mutated_inputs = [