Support functionalization on torch.cond (#89966)

This PR adds a functionalization path for torch.cond. As a first pass, we only functionalize very restrictive use cases. We explicitly disallow the following (illustrated in the sketch below):

- Outputs of either branch aliasing an input
- In-place mutation of inputs passed to either branch
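
For illustration, a minimal sketch (not part of this PR, mirroring the new tests) of a branch this path rejects:

import torch
from functorch.experimental import functionalize
from functorch.experimental.control_flow import cond, UnsupportedAliasMutationException

def true_fn(x):
    x.add_(1)  # in-place mutation of the branch input: disallowed
    return x.sin().sum()

def false_fn(x):
    return x.cos().sum()

def f(x):
    return cond(x.shape[0] == 4, true_fn, false_fn, [x])

try:
    functionalize(f)(torch.ones(4, 5))
except UnsupportedAliasMutationException as e:
    print(e)  # expected: One of torch.cond branch might be modifying the input!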

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89966
Approved by: https://github.com/zou3519
Author: Tugsbayasgalan (Tugsuu) Manlaibaatar
Date: 2022-12-21 13:50:36 -08:00
Committed by: PyTorch MergeBot
Parent: d1123c94a7
Commit: 76a3869fc6
6 changed files with 307 additions and 1 deletion

View File

@@ -1,14 +1,18 @@
from dataclasses import dataclass
import torch
from torch.multiprocessing.reductions import StorageWeakRef
import torch.utils._pytree as pytree
from torch._C import DispatchKey, DispatchKeySet, ExcludeDispatchKeyGuard
from torch._functorch.eager_transforms import _unwrap_all_tensors_from_functional, _wrap_all_tensors_to_functional, functionalize
from torch._ops import PyOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import (
    get_isolated_graphmodule,
    get_proxy_slot,
    ProxyTorchDispatchMode,
    make_fx,
    track_tensor_tree,
)
from torch.fx.passes.shape_prop import _extract_tensor_metadata
@@ -19,6 +23,11 @@ from torch.utils._python_dispatch import (
from torch.utils._pytree import tree_flatten


@dataclass
class UnsupportedAliasMutationException(RuntimeError):
    reason: str

"""
We're going to define a `cond` operation.
In order to do this, we need implementations for each of the dispatch keys.
@@ -149,6 +158,100 @@ def cond_python_dispatcher(*args):
    return cond(*args)


def _has_potential_branch_input_mutation(branch, fake_inputs):
    """
    Dispatch-trace the branch with fake inputs and check whether the
    produced graph has a mutable op on an input. This is a bit
    restrictive, as the branch must be traceable.
    """
    try:
        gm = make_fx(branch)(*fake_inputs)
    except UnsupportedAliasMutationException:
        # this can happen when a nested cond is functionalized
        return True
    except Exception as e:
        raise e

    input_nodes = set()
    for node in gm.graph.nodes:
        if node.op == "placeholder":
            input_nodes.add(node)
        if node.op == "call_function":
            target = node.target
            if isinstance(target, torch._ops.OpOverload) and target._schema.is_mutable:
                for arg in node.args:
                    if arg in input_nodes:
                        return True
    return False
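
# Illustration (not part of this diff): what the check above observes.
# Tracing a branch that mutates its input with make_fx produces a graph
# whose call_function nodes include a mutable OpOverload, e.g.
# aten.add_.Tensor. A minimal, standalone sketch:
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def _mutating_branch(x):
    x.add_(1)      # in-place op on the input
    return x.sum()

_gm = make_fx(_mutating_branch)(torch.ones(3))
_mutable = [
    n.target for n in _gm.graph.nodes
    if n.op == "call_function"
    and isinstance(n.target, torch._ops.OpOverload)
    and n.target._schema.is_mutable
]
print(_mutable)  # expected to contain aten.add_.Tensor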
def _has_potential_branch_input_alias(branch, fake_inputs):
    """
    Dispatch-trace the branch with fake inputs and check whether the
    produced graph has an output aliasing a branch input. This is a bit
    restrictive, as the branch must be traceable.
    """
    try:
        gm = make_fx(branch)(*fake_inputs)
    except UnsupportedAliasMutationException:
        # this can happen when a nested cond is functionalized
        return True
    except Exception as e:
        raise e

    input_storages = set()
    for node in gm.graph.nodes:
        if node.op == "placeholder":
            input_storages.add(StorageWeakRef(node.meta['val']._typed_storage()))

    outs, _ = pytree.tree_flatten(gm(*fake_inputs))
    for out in outs:
        if isinstance(out, torch.Tensor) and StorageWeakRef(out._typed_storage()) in input_storages:
            return True
    return False
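
# Illustration (not part of this diff): the aliasing signal used above.
# A view shares its typed storage with its base, so the same StorageWeakRef
# membership test flags outputs that alias an input. A minimal sketch:
import torch
from torch.multiprocessing.reductions import StorageWeakRef

_x = torch.ones(4, 5)
_view = _x.view(_x.shape)  # aliases _x's storage
_input_storages = {StorageWeakRef(_x._typed_storage())}
print(StorageWeakRef(_view._typed_storage()) in _input_storages)  # expected: True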
@cond.py_impl(torch._C._functorch.TransformType.Functionalize)
def cond_functionalize(interpreter, pred, true_fn, false_fn, inputs):
    """
    Functionalization implementation for torch.cond. Currently:
    1. We don't allow any input mutation inside the branches
    2. Our check for the above condition is not exhaustive
    """
    reapply_views = interpreter.functionalize_add_back_views()
    mode = 'mutations_and_views' if reapply_views else 'mutations'
    # At this point, we will see functionalized tensors, so we need to unwrap them first
    unwrapped_inputs = _unwrap_all_tensors_from_functional(inputs, reapply_views=reapply_views)
    unwrapped_pred = _unwrap_all_tensors_from_functional(pred, reapply_views=reapply_views)

    functional_true_fn = functionalize(true_fn, remove=mode)
    functional_false_fn = functionalize(false_fn, remove=mode)

    with interpreter.lower():
        fake_tensor_mode = FakeTensorMode()
        with fake_tensor_mode as ft_mode:
            for branch in [functional_true_fn, functional_false_fn]:
                def convert(x):
                    return ft_mode.fake_tensor_converter(ft_mode, x)
                fake_inputs = pytree.tree_map_only(torch.Tensor, convert, unwrapped_inputs)
                if _has_potential_branch_input_mutation(branch, fake_inputs):
                    raise UnsupportedAliasMutationException("One of torch.cond branch "
                                                            "might be modifying the input!")

            for branch in [true_fn, false_fn]:
                def convert(x):
                    return ft_mode.fake_tensor_converter(ft_mode, x)
                fake_inputs = pytree.tree_map_only(torch.Tensor, convert, unwrapped_inputs)
                if _has_potential_branch_input_alias(branch, fake_inputs):
                    raise UnsupportedAliasMutationException("One of torch.cond branch "
                                                            "might be aliasing the input!")

        cond_return = cond(unwrapped_pred, functional_true_fn, functional_false_fn, unwrapped_inputs)
        return _wrap_all_tensors_to_functional(cond_return, level=interpreter.level())


# TODO(voz): Make this automatic for keys, this is very ugly atm
cond.fallthrough(DispatchKey.PythonTLSSnapshot)
cond.fallthrough(DispatchKey.ADInplaceOrView)
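
End to end, a usage sketch (not part of the diff, mirroring the new tests below): with branches that neither mutate nor alias their inputs, cond composes with functionalize and make_fx, and in-place ops on branch-local tensors are rewritten to functional forms:

import torch
from functorch.experimental import functionalize
from functorch.experimental.control_flow import cond
from torch.fx.experimental.proxy_tensor import make_fx

def true_fn(x):
    y = x.sin()
    y.add_(4)  # mutates a branch-local tensor only, which is allowed
    return x.sin().max() + y.sum()

def false_fn(x):
    return x.cos().min()

def f(x):
    return cond(x.shape[0] == 1, true_fn, false_fn, [x])

gm = make_fx(functionalize(f))(torch.ones(4, 5))
# expected: no mutable ops remain in the traced true branch
print(any(isinstance(n.target, torch._ops.OpOverload) and n.target._schema.is_mutable
          for n in gm.true_graph_0.graph.nodes
          if n.op == "call_function"))  # expected: False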

View File

@@ -1,2 +1,2 @@
 from ._map import map  # noqa: F401
-from ._cond import cond  # noqa: F401
+from ._cond import cond, UnsupportedAliasMutationException  # noqa: F401

View File

@@ -2,6 +2,8 @@
import torch
from functorch.experimental import control_flow
from functorch.experimental.control_flow import cond
from functorch.experimental.control_flow import UnsupportedAliasMutationException
from functorch.experimental import functionalize
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_utils import run_tests, TestCase
@@ -73,6 +75,177 @@ class TestControlFlowTraced(TestCase):
        self.assertEqual(result_false_true, torch.cos(x))

    def test_cond_functionalized(self):
        def true_fn(x):
            y = x.sin()
            y.add_(4)
            return x.sin().max() + y.sum()

        def false_fn(x):
            return x.cos().min()

        def f(x):
            pred = x.shape[0] == 1
            return cond(pred, true_fn, false_fn, [x])

        example_inputs = (torch.ones(4, 5),)
        functional_f = functionalize(f)
        self.assertEqual(functional_f(*example_inputs), f(*example_inputs))

        graph_module = make_fx(functionalize(f))(*example_inputs)
        self.assertEqual(graph_module(*example_inputs), f(*example_inputs))

        all_ops_in_true_branch = []
        for node in graph_module.true_graph_0.graph.nodes:
            if node.op == "call_function":
                all_ops_in_true_branch.append(node.target)

        self.assertFalse(any(op._schema.is_mutable for op in all_ops_in_true_branch))

    def test_cond_functionalized_nested(self):
        def true_true_fn(x):
            y = x.cos()
            y.add_(4)
            return x.sin().max() + y.sin().max()

        def true_false_fn(x):
            return x.cos().min()

        def true_fn(x):
            pred = x.shape[0] == 1
            return cond(pred, true_true_fn, true_false_fn, [x])

        def false_fn(x):
            return x.sum()

        def f(x):
            pred = x.shape[0] == 1
            return cond(pred, true_fn, false_fn, [x])

        example_inputs = (torch.ones(4, 5),)
        functional_f = functionalize(f)
        self.assertEqual(functional_f(*example_inputs), f(*example_inputs))

        graph_module = make_fx(functionalize(f))(*example_inputs)
        self.assertEqual(graph_module(*example_inputs), f(*example_inputs))

        gm_true_true_branch = graph_module.true_graph_0.true_graph_0

        all_ops = []
        for node in gm_true_true_branch.graph.nodes:
            if node.op == "call_function":
                all_ops.append(node.target)

        self.assertFalse(any(op._schema.is_mutable for op in all_ops))

    def test_cond_functionalized_data_dependent_pred(self):
        def true_fn(x):
            return x.sin().sum()

        def false_fn(x):
            return x.cos().sum()

        def f(x):
            pred = x.nonzero().shape[0] == 1
            return cond(pred, true_fn, false_fn, [x])

        example_inputs = (torch.ones(4, 5),)
        functional_f = functionalize(f)
        self.assertEqual(functional_f(*example_inputs), f(*example_inputs))

        graph_module = make_fx(functionalize(f))(*example_inputs)
        self.assertEqual(graph_module(*example_inputs), f(*example_inputs))

    def test_cond_functionalized_input_mutation_on_true_branch(self):
        def true_fn(x):
            view_x = x.view(x.shape)
            view_x.add_(1)
            return view_x.sin().sum()

        def false_fn(x):
            return x.cos().sum()

        def f(x):
            pred = x.shape[0] == 4
            return cond(pred, true_fn, false_fn, [x])

        example_inputs = (torch.ones(4, 5),)
        functional_f = functionalize(f)
        with self.assertRaisesRegex(UnsupportedAliasMutationException, "One of torch.cond branch"):
            functional_f(*example_inputs)

        with self.assertRaisesRegex(UnsupportedAliasMutationException, "One of torch.cond branch"):
            make_fx(functionalize(f))(*example_inputs)

    def test_cond_functionalized_input_mutation_on_false_branch(self):
        def true_fn(x):
            return x.sin().sum()

        def false_fn(x):
            view_x = x.view(x.shape)
            view_x.add_(1)
            return view_x.cos().sum()

        def f(x):
            pred = x.shape[0] == 4
            return cond(pred, true_fn, false_fn, [x])

        example_inputs = (torch.ones(5, 5),)
        functional_f = functionalize(f)
        with self.assertRaisesRegex(UnsupportedAliasMutationException, "One of torch.cond branch"):
            functional_f(*example_inputs)

        with self.assertRaisesRegex(UnsupportedAliasMutationException, "One of torch.cond branch"):
            make_fx(functionalize(f))(*example_inputs)

    def test_cond_functionalized_output_alias_input(self):
        def true_fn(x):
            return x

        def false_fn(x):
            view_x = x.view(x.shape)
            return view_x

        def f(x):
            pred = x.shape[0] == 4
            return cond(pred, true_fn, false_fn, [x])

        example_inputs = (torch.ones(5, 5),)
        functional_f = functionalize(f)
        with self.assertRaisesRegex(UnsupportedAliasMutationException, "One of torch.cond branch might be aliasing"):
            functional_f(*example_inputs)

        with self.assertRaisesRegex(UnsupportedAliasMutationException, "One of torch.cond branch might be aliasing"):
            make_fx(functionalize(f))(*example_inputs)

    def test_cond_functionalized_nested_input_mutation(self):
        def true_true_fn(x):
            x.add_(4)
            return x.sin().max()

        def true_false_fn(x):
            return x.cos().min()

        def true_fn(x):
            pred = x.shape[0] == 1
            return cond(pred, true_true_fn, true_false_fn, [x])

        def false_fn(x):
            return x.sum()

        def f(x):
            pred = x.shape[0] == 1
            return cond(pred, true_fn, false_fn, [x])

        example_inputs = (torch.ones(4, 5),)
        functional_f = functionalize(f)
        with self.assertRaisesRegex(UnsupportedAliasMutationException, "One of torch.cond branch"):
            functional_f(*example_inputs)

        with self.assertRaisesRegex(UnsupportedAliasMutationException, "One of torch.cond branch"):
            make_fx(functionalize(f))(*example_inputs)

    def test_cond_nested_traced_other_inputs(self):
        def true_nested(y):
            return y * y

View File

@@ -48,6 +48,12 @@ class CJvpInterpreterPtr:
    def lift(self, Tensor) -> Tensor: ...
    def prevFwdGradMode(self) -> bool: ...

class CFunctionalizeInterpreterPtr:
    def __init__(self, interpreter: CInterpreter): ...
    def key(self) -> TransformType: ...
    def level(self) -> int: ...
    def functionalizeAddBackViews(self) -> bool: ...

class CVmapInterpreterPtr:
    def __init__(self, interpreter: CInterpreter): ...
    def key(self) -> TransformType: ...

View File

@@ -8,6 +8,7 @@ from torch._C._functorch import (
    RandomnessType,
    CInterpreter,
    CGradInterpreterPtr,
    CFunctionalizeInterpreterPtr,
    CVmapInterpreterPtr,
    CJvpInterpreterPtr,
    pop_dynamic_layer_stack,
@@ -172,6 +173,20 @@ class JvpInterpreter(FuncTorchInterpreter):
        return self._cptr.prevFwdGradMode()


class FunctionalizeInterpreter(FuncTorchInterpreter):
    def __init__(self, cdata: CInterpreter):
        assert cdata.key() == TransformType.Functionalize
        self._cdata = cdata
        self._cptr = CFunctionalizeInterpreterPtr(cdata)

    def process(self, op, args, kwargs):
        kernel = op.functorch_table[TransformType.Functionalize]
        return kernel(self, *args, **kwargs)

    def functionalize_add_back_views(self):
        return self._cptr.functionalizeAddBackViews()


def coerce_cinterpreter(cinterpreter: CInterpreter) -> FuncTorchInterpreter:
    key = cinterpreter.key()
    if key == TransformType.Grad:
@@ -180,6 +195,8 @@ def coerce_cinterpreter(cinterpreter: CInterpreter) -> FuncTorchInterpreter:
        return VmapInterpreter(cinterpreter)
    if key == TransformType.Jvp:
        return JvpInterpreter(cinterpreter)
    if key == TransformType.Functionalize:
        return FunctionalizeInterpreter(cinterpreter)
    raise RuntimeError(f"NYI: PyDispatcher has not implemented support for {key}")

View File

@@ -519,6 +519,13 @@ void initFuncTorchBindings(PyObject* module) {
      .def("level", &VmapInterpreterPtr::level)
      .def("batchSize", &VmapInterpreterPtr::batchSize)
      .def("randomness", &VmapInterpreterPtr::randomness);
  py::class_<FunctionalizeInterpreterPtr>(m, "CFunctionalizeInterpreterPtr")
      .def(py::init<const Interpreter*>())
      .def("key", &FunctionalizeInterpreterPtr::key)
      .def("level", &FunctionalizeInterpreterPtr::level)
      .def(
          "functionalizeAddBackViews",
          &FunctionalizeInterpreterPtr::functionalizeAddBackViews);
}

} // namespace impl