Support functionalization on torch.cond (#89966)
This PR adds a functionalization path for torch.cond. As this is a first pass, we only functionalize very restrictive use cases. We explicitly disallow the following:
- Output of a branch aliasing its input
- In-place mutation on the inputs given to a branch

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89966
Approved by: https://github.com/zou3519
Committed by: PyTorch MergeBot
Parent: d1123c94a7
Commit: 76a3869fc6
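A minimal usage sketch of the new path, based on the tests added in this PR (the branch functions and input shape below are illustrative): functionalize now traces through torch.cond, and a branch that mutates its input is rejected at functionalization time, whichever branch the predicate ends up selecting.

import torch
from functorch.experimental import functionalize
from functorch.experimental.control_flow import cond, UnsupportedAliasMutationException

def good_branch(x):
    return x.sin().sum()

def mutating_branch(x):
    x.add_(1)  # in-place mutation of the branch input is not supported
    return x.cos().sum()

def f(x):
    return cond(x.shape[0] == 4, good_branch, mutating_branch, [x])

try:
    functionalize(f)(torch.ones(4, 5))
except UnsupportedAliasMutationException as e:
    print(e)  # "One of torch.cond branch might be modifying the input!"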
@@ -1,14 +1,18 @@
from dataclasses import dataclass
import torch
from torch.multiprocessing.reductions import StorageWeakRef

import torch.utils._pytree as pytree

from torch._C import DispatchKey, DispatchKeySet, ExcludeDispatchKeyGuard
from torch._functorch.eager_transforms import _unwrap_all_tensors_from_functional, _wrap_all_tensors_to_functional, functionalize
from torch._ops import PyOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import (
    get_isolated_graphmodule,
    get_proxy_slot,
    ProxyTorchDispatchMode,
    make_fx,
    track_tensor_tree,
)
from torch.fx.passes.shape_prop import _extract_tensor_metadata
@@ -19,6 +23,11 @@ from torch.utils._python_dispatch import (
from torch.utils._pytree import tree_flatten


@dataclass
class UnsupportedAliasMutationException(RuntimeError):
    reason: str


"""
We're going to define a `cond` operation.
In order to do this, we need implementations for each of the dispatch keys.
@@ -149,6 +158,100 @@ def cond_python_dispatcher(*args):
    return cond(*args)


def _has_potential_branch_input_mutation(branch, fake_inputs):
    """
    Dispatch-trace the branch with fake inputs and check if the
    produced graph has a mutable op on the input. This is a bit
    restrictive as the branch must be traceable.
    """
    try:
        gm = make_fx(branch)(*fake_inputs)
    except UnsupportedAliasMutationException:
        # this can happen when nested cond is
        # functionalized
        return True
    except Exception as e:
        raise e

    input_nodes = set()
    for node in gm.graph.nodes:
        if node.op == "placeholder":
            input_nodes.add(node)
        if node.op == "call_function":
            target = node.target
            if isinstance(target, torch._ops.OpOverload) and target._schema.is_mutable:
                for arg in node.args:
                    if arg in input_nodes:
                        return True

    return False

def _has_potential_branch_input_alias(branch, fake_inputs):
    """
    Dispatch-trace the branch with fake inputs and check if the
    produced graph has an output aliasing the branch input. This is
    a bit restrictive as the branch must be traceable.
    """
    try:
        gm = make_fx(branch)(*fake_inputs)
    except UnsupportedAliasMutationException:
        # this can happen when nested cond is
        # functionalized
        return True
    except Exception as e:
        raise e

    input_storages = set()
    for node in gm.graph.nodes:
        if node.op == "placeholder":
            input_storages.add(StorageWeakRef(node.meta['val']._typed_storage()))

    outs, _ = pytree.tree_flatten(gm(*fake_inputs))
    for out in outs:
        if isinstance(out, torch.Tensor) and StorageWeakRef(out._typed_storage()) in input_storages:
            return True

    return False


@cond.py_impl(torch._C._functorch.TransformType.Functionalize)
def cond_functionalize(interpreter, pred, true_fn, false_fn, inputs):
    """
    Functionalization implementation for torch.cond. Currently:
    1. We don't allow any input mutation inside the branches
    2. Our check for the above condition is not exhaustive
    """
    reapply_views = interpreter.functionalize_add_back_views()
    mode = 'mutations_and_views' if reapply_views else 'mutations'
    # At this point, we will see functionalized tensors, so need to unwrap them first
    unwrapped_inputs = _unwrap_all_tensors_from_functional(inputs, reapply_views=reapply_views)
    unwrapped_pred = _unwrap_all_tensors_from_functional(pred, reapply_views=reapply_views)

    functional_true_fn = functionalize(true_fn, remove=mode)
    functional_false_fn = functionalize(false_fn, remove=mode)

    with interpreter.lower():
        fake_tensor_mode = FakeTensorMode()
        with fake_tensor_mode as ft_mode:
            for branch in [functional_true_fn, functional_false_fn]:
                def convert(x):
                    return ft_mode.fake_tensor_converter(ft_mode, x)
                fake_inputs = pytree.tree_map_only(torch.Tensor, convert, unwrapped_inputs)
                if _has_potential_branch_input_mutation(branch, fake_inputs):
                    raise UnsupportedAliasMutationException("One of torch.cond branch "
                                                            "might be modifying the input!")
            for branch in [true_fn, false_fn]:
                def convert(x):
                    return ft_mode.fake_tensor_converter(ft_mode, x)
                fake_inputs = pytree.tree_map_only(torch.Tensor, convert, unwrapped_inputs)
                if _has_potential_branch_input_alias(branch, fake_inputs):
                    raise UnsupportedAliasMutationException("One of torch.cond branch "
                                                            "might be aliasing the input!")

        cond_return = cond(unwrapped_pred, functional_true_fn, functional_false_fn, unwrapped_inputs)
        return _wrap_all_tensors_to_functional(cond_return, level=interpreter.level())

# TODO(voz): Make this automatic for keys, this is very ugly atm
cond.fallthrough(DispatchKey.PythonTLSSnapshot)
cond.fallthrough(DispatchKey.ADInplaceOrView)
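To make the mutation check above concrete, here is a standalone sketch (not part of the diff; the branch function and shape are illustrative): tracing a mutating branch with make_fx yields call_function nodes whose OpOverload targets report _schema.is_mutable, which is exactly what _has_potential_branch_input_mutation scans for.

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def branch(x):
    x.add_(1)  # in-place op on the branch input
    return x.sin().sum()

gm = make_fx(branch)(torch.ones(4, 5))
mutable_ops = [
    node.target
    for node in gm.graph.nodes
    if node.op == "call_function"
    and isinstance(node.target, torch._ops.OpOverload)
    and node.target._schema.is_mutable
]
print(mutable_ops)  # expected to include the in-place aten add overload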
@@ -1,2 +1,2 @@
 from ._map import map # noqa: F401
-from ._cond import cond # noqa: F401
+from ._cond import cond, UnsupportedAliasMutationException # noqa: F401
@@ -2,6 +2,8 @@
import torch
from functorch.experimental import control_flow
from functorch.experimental.control_flow import cond
from functorch.experimental.control_flow import UnsupportedAliasMutationException
from functorch.experimental import functionalize
from torch.fx.experimental.proxy_tensor import make_fx

from torch.testing._internal.common_utils import run_tests, TestCase
@@ -73,6 +75,177 @@ class TestControlFlowTraced(TestCase):

        self.assertEqual(result_false_true, torch.cos(x))

    def test_cond_functionalized(self):
        def true_fn(x):
            y = x.sin()
            y.add_(4)
            return x.sin().max() + y.sum()

        def false_fn(x):
            return x.cos().min()

        def f(x):
            pred = x.shape[0] == 1
            return cond(pred, true_fn, false_fn, [x])

        example_inputs = (torch.ones(4, 5),)
        functional_f = functionalize(f)
        self.assertEqual(functional_f(*example_inputs), f(*example_inputs))

        graph_module = make_fx(functionalize(f))(*example_inputs)
        self.assertEqual(graph_module(*example_inputs), f(*example_inputs))

        all_ops_in_true_branch = []
        for node in graph_module.true_graph_0.graph.nodes:
            if node.op == "call_function":
                all_ops_in_true_branch.append(node.target)

        self.assertFalse(any([op._schema.is_mutable for op in all_ops_in_true_branch]))

    def test_cond_functionalized_nested(self):
        def true_true_fn(x):
            y = x.cos()
            y.add_(4)
            return x.sin().max() + y.sin().max()

        def true_false_fn(x):
            return x.cos().min()

        def true_fn(x):
            pred = x.shape[0] == 1
            return cond(pred, true_true_fn, true_false_fn, [x])

        def false_fn(x):
            return x.sum()

        def f(x):
            pred = x.shape[0] == 1
            return cond(pred, true_fn, false_fn, [x])

        example_inputs = (torch.ones(4, 5),)
        functional_f = functionalize(f)
        self.assertEqual(functional_f(*example_inputs), f(*example_inputs))

        graph_module = make_fx(functionalize(f))(*example_inputs)
        self.assertEqual(graph_module(*example_inputs), f(*example_inputs))

        gm_true_true_branch = graph_module.true_graph_0.true_graph_0

        all_ops = []
        for node in gm_true_true_branch.graph.nodes:
            if node.op == "call_function":
                all_ops.append(node.target)

        self.assertFalse(any([op._schema.is_mutable for op in all_ops]))

    def test_cond_functionalized_data_dependent_pred(self):
        def true_fn(x):
            return x.sin().sum()

        def false_fn(x):
            return x.cos().sum()

        def f(x):
            pred = x.nonzero().shape[0] == 1
            return cond(pred, true_fn, false_fn, [x])

        example_inputs = (torch.ones(4, 5),)
        functional_f = functionalize(f)
        self.assertEqual(functional_f(*example_inputs), f(*example_inputs))

        graph_module = make_fx(functionalize(f))(*example_inputs)
        self.assertEqual(graph_module(*example_inputs), f(*example_inputs))

    def test_cond_functionalized_input_mutation_on_true_branch(self):
        def true_fn(x):
            view_x = x.view(x.shape)
            view_x.add_(1)
            return view_x.sin().sum()

        def false_fn(x):
            return x.cos().sum()

        def f(x):
            pred = x.shape[0] == 4
            return cond(pred, true_fn, false_fn, [x])

        example_inputs = (torch.ones(4, 5),)
        functional_f = functionalize(f)
        with self.assertRaisesRegex(UnsupportedAliasMutationException, "One of torch.cond branch"):
            functional_f(*example_inputs)

        with self.assertRaisesRegex(UnsupportedAliasMutationException, "One of torch.cond branch"):
            make_fx(functionalize(f))(*example_inputs)

    def test_cond_functionalized_input_mutation_on_false_branch(self):
        def true_fn(x):
            return x.sin().sum()

        def false_fn(x):
            view_x = x.view(x.shape)
            view_x.add_(1)
            return view_x.cos().sum()

        def f(x):
            pred = x.shape[0] == 4
            return cond(pred, true_fn, false_fn, [x])

        example_inputs = (torch.ones(5, 5),)
        functional_f = functionalize(f)
        with self.assertRaisesRegex(UnsupportedAliasMutationException, "One of torch.cond branch"):
            functional_f(*example_inputs)

        with self.assertRaisesRegex(UnsupportedAliasMutationException, "One of torch.cond branch"):
            make_fx(functionalize(f))(*example_inputs)

    def test_cond_functionalized_output_alias_input(self):
        def true_fn(x):
            return x

        def false_fn(x):
            view_x = x.view(x.shape)
            return view_x

        def f(x):
            pred = x.shape[0] == 4
            return cond(pred, true_fn, false_fn, [x])

        example_inputs = (torch.ones(5, 5),)
        functional_f = functionalize(f)

        with self.assertRaisesRegex(UnsupportedAliasMutationException, "One of torch.cond branch might be aliasing"):
            functional_f(*example_inputs)

        with self.assertRaisesRegex(UnsupportedAliasMutationException, "One of torch.cond branch might be aliasing"):
            make_fx(functionalize(f))(*example_inputs)

    def test_cond_functionalized_nested_input_mutation(self):
        def true_true_fn(x):
            x.add_(4)
            return x.sin().max()

        def true_false_fn(x):
            return x.cos().min()

        def true_fn(x):
            pred = x.shape[0] == 1
            return cond(pred, true_true_fn, true_false_fn, [x])

        def false_fn(x):
            return x.sum()

        def f(x):
            pred = x.shape[0] == 1
            return cond(pred, true_fn, false_fn, [x])

        example_inputs = (torch.ones(4, 5),)
        functional_f = functionalize(f)
        with self.assertRaisesRegex(UnsupportedAliasMutationException, "One of torch.cond branch"):
            functional_f(*example_inputs)

        with self.assertRaisesRegex(UnsupportedAliasMutationException, "One of torch.cond branch"):
            make_fx(functionalize(f))(*example_inputs)

    def test_cond_nested_traced_other_inputs(self):
        def true_nested(y):
            return y * y
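The first functionalization test above asserts that no mutable op survives in the traced true branch. A quick way to see that outside the test suite (a sketch reusing the same branch functions; the printed code is only for inspection) is to dump the generated code of the true-branch submodule, where y.add_(4) is expected to appear as an out-of-place add.

import torch
from functorch.experimental import functionalize
from functorch.experimental.control_flow import cond
from torch.fx.experimental.proxy_tensor import make_fx

def true_fn(x):
    y = x.sin()
    y.add_(4)  # in-place on a local; functionalization rewrites it out-of-place
    return x.sin().max() + y.sum()

def false_fn(x):
    return x.cos().min()

def f(x):
    return cond(x.shape[0] == 1, true_fn, false_fn, [x])

gm = make_fx(functionalize(f))(torch.ones(4, 5))
print(gm.true_graph_0.code)  # no mutable ops expected in the printed graph code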
@@ -48,6 +48,12 @@ class CJvpInterpreterPtr:
    def lift(self, Tensor) -> Tensor: ...
    def prevFwdGradMode(self) -> bool: ...

class CFunctionalizeInterpreterPtr:
    def __init__(self, interpreter: CInterpreter): ...
    def key(self) -> TransformType: ...
    def level(self) -> int: ...
    def functionalizeAddBackViews(self) -> bool: ...

class CVmapInterpreterPtr:
    def __init__(self, interpreter: CInterpreter): ...
    def key(self) -> TransformType: ...
@@ -8,6 +8,7 @@ from torch._C._functorch import (
    RandomnessType,
    CInterpreter,
    CGradInterpreterPtr,
    CFunctionalizeInterpreterPtr,
    CVmapInterpreterPtr,
    CJvpInterpreterPtr,
    pop_dynamic_layer_stack,
@@ -172,6 +173,20 @@ class JvpInterpreter(FuncTorchInterpreter):
        return self._cptr.prevFwdGradMode()


class FunctionalizeInterpreter(FuncTorchInterpreter):
    def __init__(self, cdata: CInterpreter):
        assert cdata.key() == TransformType.Functionalize
        self._cdata = cdata
        self._cptr = CFunctionalizeInterpreterPtr(cdata)

    def process(self, op, args, kwargs):
        kernel = op.functorch_table[TransformType.Functionalize]
        return kernel(self, *args, **kwargs)

    def functionalize_add_back_views(self):
        return self._cptr.functionalizeAddBackViews()


def coerce_cinterpreter(cinterpreter: CInterpreter) -> FuncTorchInterpreter:
    key = cinterpreter.key()
    if key == TransformType.Grad:
@@ -180,6 +195,8 @@ def coerce_cinterpreter(cinterpreter: CInterpreter) -> FuncTorchInterpreter:
        return VmapInterpreter(cinterpreter)
    if key == TransformType.Jvp:
        return JvpInterpreter(cinterpreter)
    if key == TransformType.Functionalize:
        return FunctionalizeInterpreter(cinterpreter)
    raise RuntimeError(f"NYI: PyDispatcher has not implemented support for {key}")
@@ -519,6 +519,13 @@ void initFuncTorchBindings(PyObject* module) {
      .def("level", &VmapInterpreterPtr::level)
      .def("batchSize", &VmapInterpreterPtr::batchSize)
      .def("randomness", &VmapInterpreterPtr::randomness);
  py::class_<FunctionalizeInterpreterPtr>(m, "CFunctionalizeInterpreterPtr")
      .def(py::init<const Interpreter*>())
      .def("key", &FunctionalizeInterpreterPtr::key)
      .def("level", &FunctionalizeInterpreterPtr::level)
      .def(
          "functionalizeAddBackViews",
          &FunctionalizeInterpreterPtr::functionalizeAddBackViews);
}

} // namespace impl