Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Remove fake inputs from control flow (#95988)
Previously, running make_fx with tracing_mode="symbolic" resulted in `RuntimeError: Creating a new Tensor subclass FakeTensor but the raw Tensor object is already associated to a python object of type FakeTensor`. This was probably caused by multiple FakeTensorModes being active at once.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95988
Approved by: https://github.com/tugsbayasgalan, https://github.com/zhxchen17
Committed by: PyTorch MergeBot
Parent: 9a781ce3e1
Commit: 5a07c3d3d1
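For orientation, here is a minimal sketch of the scenario the commit message describes: symbolically tracing a function that uses cond. It assumes the functorch.experimental control_flow.cond API of this era; the function names, shapes, and the repro itself are illustrative and not taken from the PR.

# Illustrative only: the kind of call pattern the commit message describes.
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from functorch.experimental.control_flow import cond

def true_fn(x):
    return x.sin()

def false_fn(x):
    return x.cos()

def f(pred, x):
    return cond(pred, true_fn, false_fn, [x])

# make_fx(tracing_mode="symbolic") already traces under a FakeTensorMode; before this
# change, cond's functionalization pass created its own FakeTensorMode and re-faked the
# inputs, which the commit message identifies as the likely source of the
# "already associated to a python object of type FakeTensor" RuntimeError.
gm = make_fx(f, tracing_mode="symbolic")(torch.tensor(True), torch.randn(3, 4))
print(gm.graph)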
@@ -157,14 +157,14 @@ def cond_python_dispatcher(*args):
     return cond(*args)


-def _has_potential_branch_input_mutation(branch, fake_inputs):
+def _has_potential_branch_input_mutation(branch, inputs):
     """
-    Dispatch-trace the branch with fake inputs and check if
+    Dispatch-trace the branch with inputs and check if
     producing graph has mutable op on the input. This is
     bit restrictive as the branch must be traceable.
     """
     try:
-        gm = make_fx(branch)(*fake_inputs)
+        gm = make_fx(branch)(*inputs)
     except UnsupportedAliasMutationException:
         # this can happen when nested cond is
         # functionalized
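The function renamed above dispatch-traces the branch and scans the resulting graph for ops that mutate their inputs. Below is a rough standalone sketch of that idea, not the library implementation: it detects mutation via schema alias information, which is one possible way; the exact check in the codebase may differ.

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def might_mutate_inputs(branch, inputs):
    # Trace the branch, then look for ops whose schema marks an argument as written to.
    gm = make_fx(branch)(*inputs)
    for node in gm.graph.nodes:
        if node.op == "call_function" and isinstance(node.target, torch._ops.OpOverload):
            schema = node.target._schema
            if any(a.alias_info is not None and a.alias_info.is_write for a in schema.arguments):
                return True
    return False

print(might_mutate_inputs(lambda x: x.add_(1), (torch.randn(3),)))  # True: in-place add
print(might_mutate_inputs(lambda x: x + 1, (torch.randn(3),)))      # False: out-of-place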
@@ -185,14 +185,14 @@ def _has_potential_branch_input_mutation(branch, fake_inputs):
     return False


-def _has_potential_branch_input_alias(branch, fake_inputs):
+def _has_potential_branch_input_alias(branch, inputs):
     """
-    Dispatch-trace the branch with fake inputs and check if
+    Dispatch-trace the branch with inputs and check if
     producing graph has output aliasing the branch input. This is
     bit restrictive as the branch must be traceable.
     """
     try:
-        gm = make_fx(branch)(*fake_inputs)
+        gm = make_fx(branch)(*inputs)
     except UnsupportedAliasMutationException:
         # this can happen when nested cond is
         # functionalized
@@ -204,11 +204,11 @@ def _has_potential_branch_input_alias(branch, fake_inputs):
     for node in gm.graph.nodes:
         if node.op == "placeholder":
             input_storages.add(StorageWeakRef(node.meta['val']._typed_storage()))
-
-    outs, _ = pytree.tree_flatten(gm(*fake_inputs))
-    for out in outs:
-        if isinstance(out, torch.Tensor) and StorageWeakRef(out._typed_storage()) in input_storages:
-            return True
+        if node.op == "output":
+            for out in node.args:
+                out_storage = StorageWeakRef(out.meta["val"]._typed_storage())
+                if out_storage in input_storages:
+                    return True

     return False
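The alias check in this hunk compares storages by identity via StorageWeakRef, now reading them from graph node metadata instead of re-running the graph on fake inputs. A minimal standalone illustration of that storage comparison (independent of the traced graph, written for this page rather than taken from the PR):

import torch
from torch.multiprocessing.reductions import StorageWeakRef

x = torch.randn(2, 3)
view = x.view(-1)    # a view shares the input's storage
clone = x.clone()    # a clone gets fresh storage

input_storages = {StorageWeakRef(x._typed_storage())}
print(StorageWeakRef(view._typed_storage()) in input_storages)   # True  -> would be flagged as aliasing
print(StorageWeakRef(clone._typed_storage()) in input_storages)  # False -> independent output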
@@ -231,22 +231,14 @@ def cond_functionalize(interpreter, pred, true_fn, false_fn, inputs):
     functional_false_fn = functionalize(false_fn, remove=mode)

     with interpreter.lower():
-        fake_tensor_mode = FakeTensorMode()
-        with fake_tensor_mode as ft_mode:
-            for branch in [functional_true_fn, functional_false_fn]:
-                def convert(x):
-                    return ft_mode.fake_tensor_converter(ft_mode, x)
-                fake_inputs = pytree.tree_map_only(torch.Tensor, convert, unwrapped_inputs)
-                if _has_potential_branch_input_mutation(branch, fake_inputs):
-                    raise UnsupportedAliasMutationException("One of torch.cond branch "
-                                                            "might be modifying the input!")
-            for branch in [true_fn, false_fn]:
-                def convert(x):
-                    return ft_mode.fake_tensor_converter(ft_mode, x)
-                fake_inputs = pytree.tree_map_only(torch.Tensor, convert, unwrapped_inputs)
-                if _has_potential_branch_input_alias(branch, fake_inputs):
-                    raise UnsupportedAliasMutationException("One of torch.cond branch "
-                                                            "might be aliasing the input!")
+        for branch in [functional_true_fn, functional_false_fn]:
+            if _has_potential_branch_input_mutation(branch, unwrapped_inputs):
+                raise UnsupportedAliasMutationException("One of torch.cond branch "
+                                                        "might be modifying the input!")
+        for branch in [true_fn, false_fn]:
+            if _has_potential_branch_input_alias(branch, unwrapped_inputs):
+                raise UnsupportedAliasMutationException("One of torch.cond branch "
+                                                        "might be aliasing the input!")

         cond_return = cond(unwrapped_pred, functional_true_fn, functional_false_fn, unwrapped_inputs)
         return _wrap_all_tensors_to_functional(cond_return, level=interpreter.level())
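After this hunk the cond checks run directly on the unwrapped inputs instead of freshly created fake copies. For reference, hypothetical branch examples (not from the PR) showing the kinds of branches these checks are meant to reject:

import torch

def mutating_branch(x):
    # Writes into its input in place; _has_potential_branch_input_mutation should flag this.
    x.add_(1)
    return x

def aliasing_branch(x):
    # Returns a view that shares storage with the input; _has_potential_branch_input_alias should flag this.
    return x.view(-1)

def safe_branch(x):
    # Out-of-place compute with a freshly allocated result; passes both checks.
    return x.sin() + 1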
@@ -115,29 +115,16 @@ def map_functionalize(interpreter, f, xs, *args):
     functional_map_fn = functionalize(f, remove=mode)

     with interpreter.lower():
-        fake_tensor_mode = FakeTensorMode()
-        with fake_tensor_mode as ft_mode:
-
-            # Returns fake inputs for a single map function call
-            def get_fake_inputs(unwrapped_xs, unwrapped_args):
-                fake_xs = ft_mode.fake_tensor_converter(ft_mode, unwrapped_xs)
-                fake_args = pytree.tree_map_only(
-                    torch.Tensor,
-                    lambda x: ft_mode.fake_tensor_converter(ft_mode, x),
-                    unwrapped_args,
-                )
-                return (fake_xs[0],) + fake_args
-
-            fake_inputs = get_fake_inputs(unwrapped_xs, unwrapped_args)
-            if _has_potential_branch_input_mutation(functional_map_fn, fake_inputs):
-                raise UnsupportedAliasMutationException(
-                    "torch.map is mutating the input!"
-                )
-
-            if _has_potential_branch_input_alias(functional_map_fn, fake_inputs):
-                raise UnsupportedAliasMutationException(
-                    "torch.map is aliasing the input!"
-                )
+        inputs = (unwrapped_xs,) + unwrapped_args
+        if _has_potential_branch_input_mutation(functional_map_fn, inputs):
+            raise UnsupportedAliasMutationException(
+                "torch.map is mutating the input!"
+            )
+
+        if _has_potential_branch_input_alias(functional_map_fn, inputs):
+            raise UnsupportedAliasMutationException(
+                "torch.map is aliasing the input!"
+            )

         map_return = map(functional_map_fn, unwrapped_xs, *unwrapped_args)
         return _wrap_all_tensors_to_functional(map_return, level=interpreter.level())
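For context on the map hunk, a hedged usage sketch of the control-flow map these checks protect, assuming the functorch.experimental.control_flow.map API of this era (which applies the body over the leading dimension of xs); the body and shapes here are illustrative:

import torch
from functorch.experimental.control_flow import map as cf_map

def body(x, y):
    # Out-of-place body: neither mutates nor aliases its inputs.
    return x.sin() + y

xs = torch.randn(4, 3)
y = torch.randn(3)
out = cf_map(body, xs, y)  # applies body to each xs[i], stacking the results
print(out.shape)           # expected: torch.Size([4, 3])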