Remove fake inputs from control flow (#95988)

Previously running make_fx with tracing_mode="symbolic" resulted in `RuntimeError: Creating a new Tensor subclass FakeTensor but the raw Tensor object is already associated to a python object of type FakeTensor`. This is probably due to there existing multiple FakeTensorModes.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95988
Approved by: https://github.com/tugsbayasgalan, https://github.com/zhxchen17
This commit is contained in:
Angela Yi
2023-03-04 00:58:49 +00:00
committed by PyTorch MergeBot
parent 9a781ce3e1
commit 5a07c3d3d1
3 changed files with 43 additions and 49 deletions

View File

@ -157,14 +157,14 @@ def cond_python_dispatcher(*args):
return cond(*args)
def _has_potential_branch_input_mutation(branch, fake_inputs):
def _has_potential_branch_input_mutation(branch, inputs):
"""
Dispatch-trace the branch with fake inputs and check if
Dispatch-trace the branch with inputs and check if
producing graph has mutable op on the input. This is
bit restrictive as the branch must be traceable.
"""
try:
gm = make_fx(branch)(*fake_inputs)
gm = make_fx(branch)(*inputs)
except UnsupportedAliasMutationException:
# this can happen when nested cond is
# functionalized
@ -185,14 +185,14 @@ def _has_potential_branch_input_mutation(branch, fake_inputs):
return False
def _has_potential_branch_input_alias(branch, fake_inputs):
def _has_potential_branch_input_alias(branch, inputs):
"""
Dispatch-trace the branch with fake inputs and check if
Dispatch-trace the branch with inputs and check if
producing graph has output aliasing the branch input. This is
bit restrictive as the branch must be traceable.
"""
try:
gm = make_fx(branch)(*fake_inputs)
gm = make_fx(branch)(*inputs)
except UnsupportedAliasMutationException:
# this can happen when nested cond is
# functionalized
@ -204,11 +204,11 @@ def _has_potential_branch_input_alias(branch, fake_inputs):
for node in gm.graph.nodes:
if node.op == "placeholder":
input_storages.add(StorageWeakRef(node.meta['val']._typed_storage()))
outs, _ = pytree.tree_flatten(gm(*fake_inputs))
for out in outs:
if isinstance(out, torch.Tensor) and StorageWeakRef(out._typed_storage()) in input_storages:
return True
if node.op == "output":
for out in node.args:
out_storage = StorageWeakRef(out.meta["val"]._typed_storage())
if out_storage in input_storages:
return True
return False
@ -231,22 +231,14 @@ def cond_functionalize(interpreter, pred, true_fn, false_fn, inputs):
functional_false_fn = functionalize(false_fn, remove=mode)
with interpreter.lower():
fake_tensor_mode = FakeTensorMode()
with fake_tensor_mode as ft_mode:
for branch in [functional_true_fn, functional_false_fn]:
def convert(x):
return ft_mode.fake_tensor_converter(ft_mode, x)
fake_inputs = pytree.tree_map_only(torch.Tensor, convert, unwrapped_inputs)
if _has_potential_branch_input_mutation(branch, fake_inputs):
raise UnsupportedAliasMutationException("One of torch.cond branch "
"might be modifying the input!")
for branch in [true_fn, false_fn]:
def convert(x):
return ft_mode.fake_tensor_converter(ft_mode, x)
fake_inputs = pytree.tree_map_only(torch.Tensor, convert, unwrapped_inputs)
if _has_potential_branch_input_alias(branch, fake_inputs):
raise UnsupportedAliasMutationException("One of torch.cond branch "
"might be aliasing the input!")
for branch in [functional_true_fn, functional_false_fn]:
if _has_potential_branch_input_mutation(branch, unwrapped_inputs):
raise UnsupportedAliasMutationException("One of torch.cond branch "
"might be modifying the input!")
for branch in [true_fn, false_fn]:
if _has_potential_branch_input_alias(branch, unwrapped_inputs):
raise UnsupportedAliasMutationException("One of torch.cond branch "
"might be aliasing the input!")
cond_return = cond(unwrapped_pred, functional_true_fn, functional_false_fn, unwrapped_inputs)
return _wrap_all_tensors_to_functional(cond_return, level=interpreter.level())

View File

@ -115,29 +115,16 @@ def map_functionalize(interpreter, f, xs, *args):
functional_map_fn = functionalize(f, remove=mode)
with interpreter.lower():
fake_tensor_mode = FakeTensorMode()
with fake_tensor_mode as ft_mode:
inputs = (unwrapped_xs,) + unwrapped_args
if _has_potential_branch_input_mutation(functional_map_fn, inputs):
raise UnsupportedAliasMutationException(
"torch.map is mutating the input!"
)
# Returns fake inputs for a single map function call
def get_fake_inputs(unwrapped_xs, unwrapped_args):
fake_xs = ft_mode.fake_tensor_converter(ft_mode, unwrapped_xs)
fake_args = pytree.tree_map_only(
torch.Tensor,
lambda x: ft_mode.fake_tensor_converter(ft_mode, x),
unwrapped_args,
)
return (fake_xs[0],) + fake_args
fake_inputs = get_fake_inputs(unwrapped_xs, unwrapped_args)
if _has_potential_branch_input_mutation(functional_map_fn, fake_inputs):
raise UnsupportedAliasMutationException(
"torch.map is mutating the input!"
)
if _has_potential_branch_input_alias(functional_map_fn, fake_inputs):
raise UnsupportedAliasMutationException(
"torch.map is aliasing the input!"
)
if _has_potential_branch_input_alias(functional_map_fn, inputs):
raise UnsupportedAliasMutationException(
"torch.map is aliasing the input!"
)
map_return = map(functional_map_fn, unwrapped_xs, *unwrapped_args)
return _wrap_all_tensors_to_functional(map_return, level=interpreter.level())