mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
[cond] make cond re-dispatch in proxy mode (#146954)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146954 Approved by: https://github.com/zou3519
This commit is contained in:
committed by
PyTorch MergeBot
parent
67cbbb29e0
commit
2ce6de2415
@ -339,6 +339,17 @@ def unmask_none_gradients(grads, operands):
|
||||
return unmasked_grads
|
||||
|
||||
|
||||
def _maybe_fake_prop_ignore_unbacked(fn, args):
|
||||
with ExitStack() as ctx_stack:
|
||||
if (fake_mode := detect_fake_mode(args)) is not None:
|
||||
ctx_stack.enter_context(fake_mode)
|
||||
if fake_mode.shape_env is not None:
|
||||
ctx_stack.enter_context(
|
||||
fake_mode.shape_env.ignore_fresh_unbacked_symbols()
|
||||
)
|
||||
return fn(*args)
|
||||
|
||||
|
||||
# TODO: The parameter use_output_and_grad_bw is required because some operations
|
||||
# that utilize this function, such as the while_loop, may require (grad, fwd_outputs)
|
||||
def create_fw_bw_graph(fn, use_output_and_grad_bw, fw_inputs, fw_outputs):
|
||||
@ -578,18 +589,12 @@ def check_input_alias_and_mutation(
|
||||
# We need to temporarily turn inference_mode off because
|
||||
# under inference mode, tensor version counter is not tracked.
|
||||
ctx_stack.enter_context(torch.inference_mode(False))
|
||||
if (fake_mode := detect_fake_mode(fake_args)) is not None:
|
||||
ctx_stack.enter_context(fake_mode)
|
||||
if fake_mode.shape_env is not None:
|
||||
ctx_stack.enter_context(
|
||||
fake_mode.shape_env.ignore_fresh_unbacked_symbols()
|
||||
)
|
||||
cloned = [
|
||||
clone_preserve_strides(arg) if isinstance(arg, torch.Tensor) else arg
|
||||
for arg in fake_args
|
||||
]
|
||||
before = [_tensor_version(arg) for arg in cloned]
|
||||
outputs = gm(*cloned)
|
||||
outputs = _maybe_fake_prop_ignore_unbacked(gm, cloned)
|
||||
outputs = [outputs] if not isinstance(outputs, (list, tuple)) else outputs
|
||||
after = [_tensor_version(arg) for arg in cloned]
|
||||
mutated_inputs = [
|
||||
|
Reference in New Issue
Block a user