From 58d2c66a70f2717c3cf7af8fd1abcc1ab100e3de Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Mon, 12 Jun 2023 18:45:42 +0000 Subject: [PATCH] [activation checkpointing] Higher order functional rng op wrappers (#102934) Introduces two higher order operators * run_and_save_rng_state - Saves the current rng state and then runs the op. * run_with_rng_state - Runs the op with the rng state supplied as an input Ideally, we would like to use torch.compile for these operators. But currently the plan is to introduce these operators at the partitioner level, obviating the need to support them fully through the torch.compile stack. To ensure that we have good enough debugging with minifiers, we have to ensure that they work with make_fx. In the future, we can move to torch.compile. Pull Request resolved: https://github.com/pytorch/pytorch/pull/102934 Approved by: https://github.com/jansel, https://github.com/zou3519 --- test/functorch/test_aotdispatch.py | 1 + test/inductor/test_torchinductor.py | 44 +++++ test/test_prims.py | 21 +++ torch/_inductor/codegen/wrapper.py | 3 + torch/_inductor/ir.py | 34 +++- torch/_inductor/lowering.py | 4 + torch/_inductor/utils.py | 2 + torch/_ops.py | 5 +- torch/_prims/rng_prims.py | 155 ++++++++++++++++++ torch/fx/node.py | 2 +- .../_internal/fx/passes/fx_to_onnxscript.py | 4 +- torch/onnx/_type_utils.py | 1 + 12 files changed, 266 insertions(+), 10 deletions(-) diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index c0c888c2cbfa..5c820cd9a321 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -2566,6 +2566,7 @@ class TestPartitioning(AOTTestCase): res = aot_mod(x) res.sum().backward() + class TestAOTModuleSimplified(AOTTestCase): def test_aot_module_simplified(self): class MockModule(torch.nn.Module): diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 779ea7a43910..9a3a1d33dfce 100644 --- a/test/inductor/test_torchinductor.py 
+++ b/test/inductor/test_torchinductor.py @@ -4819,6 +4819,50 @@ class CommonTemplate: self.assertTrue((d >= 0).all()) self.assertTrue((d < 1).all()) + def test_functionalize_rng_wrappers(self): + # Ideally, we would like to use torch.compile for these operators. But + # currently the plan is to introduce these operators at the partitioner + # level, obviating the need to support them fully through the + # torch.compile stack. To ensure that we have good enough debugging with + # minifiers, we have ensure that they work with make_fx. This test uses + # make_fx to do the testing. In future, we can move on torch.compile. + def fn(): + rng_state1, a1 = torch._prims.rng_prims.run_and_save_rng_state( + torch.ops.aten.rand.default, + [4, 4], + dtype=torch.float32, + device=self.device, + ) + rng_state2, a2 = torch._prims.rng_prims.run_and_save_rng_state( + torch.ops.aten.rand.default, + [4, 4], + dtype=torch.float32, + device=self.device, + ) + + b1 = torch._prims.rng_prims.run_with_rng_state( + rng_state1, + torch.ops.aten.rand.default, + [4, 4], + dtype=torch.float32, + device=self.device, + ) + b2 = torch._prims.rng_prims.run_with_rng_state( + rng_state2, + torch.ops.aten.rand.default, + [4, 4], + dtype=torch.float32, + device=self.device, + ) + + return (a1, a2, b1, b2) + + mod = make_fx(fn)() + compiled_f = compile_fx_inner(mod, ()) + a1, a2, b1, b2 = compiled_f(()) + self.assertEqual(a1, b1) + self.assertEqual(a2, b2) + @patch.object(torch._functorch.config, "functionalize_rng_ops", True) def test_philox_rand(self): if self.device == "cpu": diff --git a/test/test_prims.py b/test/test_prims.py index 14b8a27d7acb..da1f5a101c3d 100644 --- a/test/test_prims.py +++ b/test/test_prims.py @@ -1124,6 +1124,27 @@ class TestPrims(TestCase): for a, b in zip(references, results): self.assertEqual(a, b) + + @dtypes(torch.float32) + def test_functional_rng_wrappers(self, device, dtype): + + torch.manual_seed(123) + ref1 = torch.rand(10, device=device, dtype=dtype) + ref2 = 
torch.rand(10, device=device, dtype=dtype) + + + torch.manual_seed(123) + rng_state1, res1 = torch._prims.rng_prims.run_and_save_rng_state(torch.rand, 10, device=device, dtype=dtype) + rng_state2, res2 = torch._prims.rng_prims.run_and_save_rng_state(torch.rand, 10, device=device, dtype=dtype) + + res3 = torch._prims.rng_prims.run_with_rng_state(rng_state1, torch.rand, 10, device=device, dtype=dtype) + res4 = torch._prims.rng_prims.run_with_rng_state(rng_state2, torch.rand, 10, device=device, dtype=dtype) + + self.assertEqual(ref1, res1) + self.assertEqual(ref2, res2) + self.assertEqual(ref1, res3) + self.assertEqual(ref2, res4) + class TestPrimsBasic(TestCase): def test_torch_ops(self): r = make_tensor((2,), device='cpu', dtype=torch.float) diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index 8b1742149d27..f7a525346a82 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -14,6 +14,7 @@ from sympy import Expr import torch from torch._dynamo.utils import counters, dynamo_timed from torch.fx.experimental.symbolic_shapes import SymTypes +from torch.fx.node import _get_qualified_name from .. 
import codecache, config, ir from ..codecache import CudaKernelParamCache from ..utils import ( @@ -734,6 +735,8 @@ class WrapperCodeGen(CodeGen): return self.ref return repr(type(s)(Shim(self.val_to_str(a)) for a in s)) + elif isinstance(s, torch._ops.OpOverload): + return _get_qualified_name(s) else: return repr(s) diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 32c376b40ee5..445668fa83d2 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -3235,6 +3235,13 @@ class FallbackKernel(ExternKernelAlloc): if V.graph.cpp_wrapper else f"aten.{op_overload_packet.__name__}" ) + elif isinstance(kernel, torch._ops.HigherOrderOperator): + if getattr(torch._prims.rng_prims, kernel.__name__, None) is kernel: + self.kernel = f"torch._prims.rng_prims.{kernel.__name__}" + else: + raise NotImplementedError( + "Unable to find HigherOrderOperator kernel name" + ) else: if V.graph.cpp_wrapper: from torch._inductor.codegen.wrapper import ( @@ -3300,6 +3307,24 @@ class FallbackKernel(ExternKernelAlloc): self.kwargs.update(kwargs) return args + @staticmethod + def find_device(tensor_args, example_output): + if tensor_args: + return tensor_args[0].get_device() + if isinstance(example_output, torch.Tensor): + return example_output.device + if isinstance(example_output, (list, tuple)): + devices = {FallbackKernel.find_device(None, x) for x in example_output} + # Remove None + devices = [device for device in devices if device] + if len(devices) == 1: + return devices[0] + for device in devices: + if device.type == "cuda": + return device + return devices[0] + return None + def codegen(self, wrapper): if self.use_cpp_op_schema: args = [*self.codegen_args(), *self.codegen_kwargs()] @@ -3337,13 +3362,10 @@ class FallbackKernel(ExternKernelAlloc): unflatten_args, ) = cls.process_kernel(kernel, *args, **kwargs) - assert tensor_args or isinstance( - example_output, torch.Tensor - ), "Not sure where to find device info" + device = 
FallbackKernel.find_device(tensor_args, example_output) + assert device, "Not sure where to find device info" packed = FallbackKernel( - MultiOutputLayout( - tensor_args[0].get_device() if tensor_args else example_output.device - ), + MultiOutputLayout(device), kernel, tensor_args, non_tensor_args, diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index b3a2c6967554..473263489062 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -1634,6 +1634,10 @@ make_fallback(aten.gcd.default, warn=False) make_fallback(aten._linalg_eigh) make_fallback(aten.zeros.names) + +make_fallback(torch._prims.rng_prims.run_and_save_rng_state) +make_fallback(torch._prims.rng_prims.run_with_rng_state) + # fails accuracy on test_torch.py, and explicit fallback required to avoid warn=True on implicit make_fallback(aten.exponential.default, warn=False) diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 118c59ce3f2b..4c9777264f03 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -409,6 +409,8 @@ def has_incompatible_cudagraph_ops(gm): "aten._fused_moving_avg_obs_fq_helper_functional.default", "fbgemm.dense_to_jagged.default", "fbgemm.jagged_to_padded_dense.default", + "run_with_rng_state", + "run_and_save_rng_state", } if torch.are_deterministic_algorithms_enabled(): forbidden_set.update( diff --git a/torch/_ops.py b/torch/_ops.py index 16dbe75d1746..b8e7a6c1dd83 100644 --- a/torch/_ops.py +++ b/torch/_ops.py @@ -251,8 +251,11 @@ class HigherOrderOperator(OperatorBase): dispatch_key_set = _compute_keyset(args, kwargs, self.non_fallthrough_keys) return self.dispatch(dispatch_key_set.highestPriorityTypeId(), *args, **kwargs) + def __str__(self): + return f"{self.name()}" + def name(self): - return self.name + return self._name def _to_flat_tuple(args, kwargs): diff --git a/torch/_prims/rng_prims.py b/torch/_prims/rng_prims.py index 28568811fcbd..d4bcd98b2f80 100644 --- a/torch/_prims/rng_prims.py +++ 
b/torch/_prims/rng_prims.py @@ -1,11 +1,27 @@ +from functools import partial from typing import Optional, Tuple import torch +import torch.utils._pytree as pytree from torch import _prims +from torch._C import DispatchKey +from torch._ops import HigherOrderOperator from torch._prims_common import CUDARngStateHelper, make_contiguous_strides_for from torch._prims_common.wrappers import backwards_not_supported +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx.experimental.proxy_tensor import ( + disable_proxy_modes_tracing, + ProxyTorchDispatchMode, + track_tensor_tree, + unwrap_proxy, +) from torch.types import _device, _dtype +from torch.utils._python_dispatch import ( + _get_current_dispatch_mode, + _pop_mode_temporarily, +) + rngprim_namespace = "rngprims" rngprim = torch.library.Library(rngprim_namespace, "DEF") @@ -132,5 +148,144 @@ def register_philox_rand(): ) +def get_device(args, kwargs): + if kwargs.get("device"): + device = kwargs.get("device") + if isinstance(device, str): + device = torch.device(device) + return device.type + + devices = {arg.device.type for arg in args if isinstance(arg, torch.Tensor)} + if any(dev == "cuda" for dev in devices): + return "cuda" + elif any(dev == "cpu" for dev in devices): + return "cpu" + return None + + +def register_run_and_save_rng_state_op(): + run_and_save_rng_state = HigherOrderOperator("run_and_save_rng_state") + + run_and_save_rng_state.fallthrough(DispatchKey.ADInplaceOrView) + run_and_save_rng_state.fallthrough(DispatchKey.PythonDispatcher) # type: ignore[attr-defined] + run_and_save_rng_state.fallthrough(DispatchKey.PythonTLSSnapshot) # type: ignore[attr-defined] + + @run_and_save_rng_state.py_impl(DispatchKey.Autograd) + def impl_autograd(op, *args, **kwargs): + with torch._C._AutoDispatchBelowAutograd(): + return run_and_save_rng_state(op, *args, **kwargs) + + @run_and_save_rng_state.py_impl(DispatchKey.CUDA) + def impl_cuda(op, *args, **kwargs): + return torch.cuda.get_rng_state(), 
op(*args, **kwargs) + + @run_and_save_rng_state.py_impl(DispatchKey.CPU) + def impl_cpu(op, *args, **kwargs): + return torch.get_rng_state(), op(*args, **kwargs) + + @run_and_save_rng_state.py_impl(DispatchKey.BackendSelect) + def impl_backend_select(op, *args, **kwargs): + impl_map = {"cuda": impl_cuda, "cpu": impl_cpu} + device = get_device(args, kwargs) + assert device in impl_map, f"Backend not supported for {device}" + impl = impl_map[device] + return impl(op, *args, **kwargs) + + @run_and_save_rng_state.py_impl(FakeTensorMode) + def impl_fake_tensor_mode(op, *args, **kwargs): + # Check device to call the right impl + return impl_backend_select(op, *args, **kwargs) + + @run_and_save_rng_state.py_impl(ProxyTorchDispatchMode) + def impl_proxy_dispatch_mode(op, *args, **kwargs): + mode = _get_current_dispatch_mode() + assert mode is not None + with _pop_mode_temporarily() as mode: + if mode.enable_tracing: + out = impl_fake_tensor_mode(op, *args, **kwargs) + proxy_args = pytree.tree_map(partial(unwrap_proxy, mode), (op, *args)) + proxy_kwargs = pytree.tree_map(partial(unwrap_proxy, mode), kwargs) + out_proxy = mode.tracer.create_proxy( + "call_function", run_and_save_rng_state, proxy_args, proxy_kwargs + ) + return track_tensor_tree( + out, out_proxy, constant=None, tracer=mode.tracer + ) + else: + return run_and_save_rng_state(op, *args, **kwargs) + + return run_and_save_rng_state + + +def register_run_with_rng_state_op(): + run_with_rng_state = HigherOrderOperator("run_with_rng_state") + + run_with_rng_state.fallthrough(DispatchKey.ADInplaceOrView) + run_with_rng_state.fallthrough(DispatchKey.PythonTLSSnapshot) # type: ignore[attr-defined] + run_with_rng_state.fallthrough(DispatchKey.PythonDispatcher) # type: ignore[attr-defined] + + @run_with_rng_state.py_impl(DispatchKey.Autograd) + def impl_autograd(rng_state, op, *args, **kwargs): + with torch._C._AutoDispatchBelowAutograd(): + return run_with_rng_state(rng_state, op, *args, **kwargs) + + 
@run_with_rng_state.py_impl(DispatchKey.CUDA) + def impl_cuda(rng_state, op, *args, **kwargs): + current_state = torch.cuda.get_rng_state() + torch.cuda.set_rng_state(rng_state.cpu()) + out = op(*args, **kwargs) + torch.cuda.set_rng_state(current_state) + return out + + @run_with_rng_state.py_impl(DispatchKey.CPU) + def impl_cpu(rng_state, op, *args, **kwargs): + current_state = torch.get_rng_state() + torch.set_rng_state(rng_state) + out = op(*args, **kwargs) + torch.set_rng_state(current_state) + return out + + @run_with_rng_state.py_impl(ProxyTorchDispatchMode) + def impl_proxy_dispatch_mode(rng_state, op, *args, **kwargs): + mode = _get_current_dispatch_mode() + assert mode is not None + with _pop_mode_temporarily() as mode: + if mode.enable_tracing: + with disable_proxy_modes_tracing(): + out = run_with_rng_state(rng_state, op, *args, **kwargs) + proxy_args = pytree.tree_map( + partial(unwrap_proxy, mode), (rng_state, op, *args) + ) + proxy_kwargs = pytree.tree_map(partial(unwrap_proxy, mode), kwargs) + out_proxy = mode.tracer.create_proxy( + "call_function", run_with_rng_state, proxy_args, proxy_kwargs + ) + return track_tensor_tree( + out, out_proxy, constant=None, tracer=mode.tracer + ) + else: + return run_with_rng_state(rng_state, op, *args, **kwargs) + + @run_with_rng_state.py_impl(DispatchKey.BackendSelect) + def impl_backend_select(rng_state, op, *args, **kwargs): + impl_map = {"cuda": impl_cuda, "cpu": impl_cpu} + device = get_device(args, kwargs) + assert device in impl_map, f"Backend not supported for {device}" + impl = impl_map[device] + return impl(rng_state, op, *args, **kwargs) + + @run_with_rng_state.py_impl(FakeTensorMode) + def impl_fake_tensor_mode(rng_state, op, *args, **kwargs): + # Skip setting the set_rng_state as it does not work well with fake tensors. + # And it does not matter for the fake tensor mode. 
+ return op(*args, **kwargs) + + return run_with_rng_state + + +run_and_save_rng_state = register_run_and_save_rng_state_op() +run_with_rng_state = register_run_with_rng_state_op() + + def register_rng_prims(): register_philox_rand() diff --git a/torch/fx/node.py b/torch/fx/node.py index c11923caf57d..afb7d2917e89 100644 --- a/torch/fx/node.py +++ b/torch/fx/node.py @@ -15,7 +15,7 @@ if TYPE_CHECKING: __all__ = ['Node', 'map_arg', 'map_aggregate', "has_side_effect"] BaseArgumentTypes = Union[str, int, float, bool, complex, torch.dtype, - torch.Tensor, torch.device, torch.memory_format, torch.layout] + torch.Tensor, torch.device, torch.memory_format, torch.layout, torch._ops.OpOverload] base_types = BaseArgumentTypes.__args__ # type: ignore[attr-defined] Target = Union[Callable[..., Any], str] diff --git a/torch/onnx/_internal/fx/passes/fx_to_onnxscript.py b/torch/onnx/_internal/fx/passes/fx_to_onnxscript.py index 9396e7e57dee..7a73cebb6434 100644 --- a/torch/onnx/_internal/fx/passes/fx_to_onnxscript.py +++ b/torch/onnx/_internal/fx/passes/fx_to_onnxscript.py @@ -352,9 +352,9 @@ def _export_fx_node_to_onnxscript( elif node.op == "call_function": # aten ops and other stateless functions. 
if node.target == operator.getitem and isinstance( - fx_name_to_onnxscript_value[node.args[0].name], tuple # type: ignore[union-attr] + fx_name_to_onnxscript_value[node.args[0].name], tuple # type: ignore[union-attr,index] ): - onnx_tensor_tuple = fx_name_to_onnxscript_value[node.args[0].name] # type: ignore[union-attr] + onnx_tensor_tuple = fx_name_to_onnxscript_value[node.args[0].name] # type: ignore[union-attr,index] index = node.args[1] output = onnx_tensor_tuple[index] # type: ignore[index] assert ( diff --git a/torch/onnx/_type_utils.py b/torch/onnx/_type_utils.py index c151abb0e28b..345fbe59c17f 100644 --- a/torch/onnx/_type_utils.py +++ b/torch/onnx/_type_utils.py @@ -372,6 +372,7 @@ BaseArgumentTypes = Union[ torch.device, torch.memory_format, torch.layout, + torch._ops.OpOverload, ] Argument = Optional[ Union[