[activation checkpointing] Higher order functional rng op wrappers (#102934)

Introduces two higher order operators:
* run_and_save_rng_state - Saves the current rng state and then runs the op.
* run_with_rng_state - Runs the op with the rng state supplied as an input.

Ideally, we would like to use torch.compile for these operators. But currently the plan is to introduce them at the partitioner level, obviating the need to support them fully through the torch.compile stack. To ensure that we have good enough debugging with the minifiers, we have ensured that they work with make_fx. In the future, we can move them to torch.compile.
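A minimal eager-mode sketch of how the two wrappers compose (this mirrors the new `test_functional_rng_wrappers` test in this PR; the `torch.rand` op, the size, and the CPU device are only illustrative choices):

```python
import torch
from torch._prims.rng_prims import run_and_save_rng_state, run_with_rng_state

# Capture the RNG state the op is about to consume, then run the op.
rng_state, a = run_and_save_rng_state(torch.rand, 10, device="cpu", dtype=torch.float32)

torch.rand(10)  # unrelated RNG traffic in between

# Replay the op under the captured state; b matches a exactly.
b = run_with_rng_state(rng_state, torch.rand, 10, device="cpu", dtype=torch.float32)
assert torch.equal(a, b)
```

This is the property activation checkpointing relies on: the recomputed forward in the backward pass can replay the same random pattern that the original forward saw.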

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102934
Approved by: https://github.com/jansel, https://github.com/zou3519
Animesh Jain
2023-06-12 18:45:42 +00:00
committed by PyTorch MergeBot
parent 31ee1512d3
commit 58d2c66a70
12 changed files with 266 additions and 10 deletions

View File

@@ -2566,6 +2566,7 @@ class TestPartitioning(AOTTestCase):
res = aot_mod(x)
res.sum().backward()
class TestAOTModuleSimplified(AOTTestCase):
def test_aot_module_simplified(self):
class MockModule(torch.nn.Module):

View File

@@ -4819,6 +4819,50 @@ class CommonTemplate:
self.assertTrue((d >= 0).all())
self.assertTrue((d < 1).all())
def test_functionalize_rng_wrappers(self):
# Ideally, we would like to use torch.compile for these operators. But
# currently the plan is to introduce these operators at the partitioner
# level, obviating the need to support them fully through the
# torch.compile stack. To ensure that we have good enough debugging with
# minifiers, we have ensured that they work with make_fx. This test uses
# make_fx to do the testing. In the future, we can move to torch.compile.
def fn():
rng_state1, a1 = torch._prims.rng_prims.run_and_save_rng_state(
torch.ops.aten.rand.default,
[4, 4],
dtype=torch.float32,
device=self.device,
)
rng_state2, a2 = torch._prims.rng_prims.run_and_save_rng_state(
torch.ops.aten.rand.default,
[4, 4],
dtype=torch.float32,
device=self.device,
)
b1 = torch._prims.rng_prims.run_with_rng_state(
rng_state1,
torch.ops.aten.rand.default,
[4, 4],
dtype=torch.float32,
device=self.device,
)
b2 = torch._prims.rng_prims.run_with_rng_state(
rng_state2,
torch.ops.aten.rand.default,
[4, 4],
dtype=torch.float32,
device=self.device,
)
return (a1, a2, b1, b2)
mod = make_fx(fn)()
compiled_f = compile_fx_inner(mod, ())
a1, a2, b1, b2 = compiled_f(())
self.assertEqual(a1, b1)
self.assertEqual(a2, b2)
@patch.object(torch._functorch.config, "functionalize_rng_ops", True)
def test_philox_rand(self):
if self.device == "cpu":

View File

@@ -1124,6 +1124,27 @@ class TestPrims(TestCase):
for a, b in zip(references, results):
self.assertEqual(a, b)
@dtypes(torch.float32)
def test_functional_rng_wrappers(self, device, dtype):
torch.manual_seed(123)
ref1 = torch.rand(10, device=device, dtype=dtype)
ref2 = torch.rand(10, device=device, dtype=dtype)
torch.manual_seed(123)
rng_state1, res1 = torch._prims.rng_prims.run_and_save_rng_state(torch.rand, 10, device=device, dtype=dtype)
rng_state2, res2 = torch._prims.rng_prims.run_and_save_rng_state(torch.rand, 10, device=device, dtype=dtype)
res3 = torch._prims.rng_prims.run_with_rng_state(rng_state1, torch.rand, 10, device=device, dtype=dtype)
res4 = torch._prims.rng_prims.run_with_rng_state(rng_state2, torch.rand, 10, device=device, dtype=dtype)
self.assertEqual(ref1, res1)
self.assertEqual(ref2, res2)
self.assertEqual(ref1, res3)
self.assertEqual(ref2, res4)
class TestPrimsBasic(TestCase):
def test_torch_ops(self):
r = make_tensor((2,), device='cpu', dtype=torch.float)

View File

@@ -14,6 +14,7 @@ from sympy import Expr
import torch
from torch._dynamo.utils import counters, dynamo_timed
from torch.fx.experimental.symbolic_shapes import SymTypes
from torch.fx.node import _get_qualified_name
from .. import codecache, config, ir
from ..codecache import CudaKernelParamCache
from ..utils import (
@@ -734,6 +735,8 @@ class WrapperCodeGen(CodeGen):
return self.ref
return repr(type(s)(Shim(self.val_to_str(a)) for a in s))
elif isinstance(s, torch._ops.OpOverload):
return _get_qualified_name(s)
else:
return repr(s)

View File

@@ -3235,6 +3235,13 @@ class FallbackKernel(ExternKernelAlloc):
if V.graph.cpp_wrapper
else f"aten.{op_overload_packet.__name__}"
)
elif isinstance(kernel, torch._ops.HigherOrderOperator):
if getattr(torch._prims.rng_prims, kernel.__name__, None) is kernel:
self.kernel = f"torch._prims.rng_prims.{kernel.__name__}"
else:
raise NotImplementedError(
"Unable to find HigherOrderOperator kernel name"
)
else:
if V.graph.cpp_wrapper:
from torch._inductor.codegen.wrapper import (
@@ -3300,6 +3307,24 @@
self.kwargs.update(kwargs)
return args
@staticmethod
def find_device(tensor_args, example_output):
if tensor_args:
return tensor_args[0].get_device()
if isinstance(example_output, torch.Tensor):
return example_output.device
if isinstance(example_output, (list, tuple)):
devices = {FallbackKernel.find_device(None, x) for x in example_output}
# Remove None
devices = [device for device in devices if device]
if len(devices) == 1:
return devices[0]
for device in devices:
if device.type == "cuda":
return device
return devices[0]
return None
def codegen(self, wrapper):
if self.use_cpp_op_schema:
args = [*self.codegen_args(), *self.codegen_kwargs()]
@@ -3337,13 +3362,10 @@
unflatten_args,
) = cls.process_kernel(kernel, *args, **kwargs)
- assert tensor_args or isinstance(
- example_output, torch.Tensor
- ), "Not sure where to find device info"
+ device = FallbackKernel.find_device(tensor_args, example_output)
+ assert device, "Not sure where to find device info"
packed = FallbackKernel(
- MultiOutputLayout(
- tensor_args[0].get_device() if tensor_args else example_output.device
- ),
+ MultiOutputLayout(device),
kernel,
tensor_args,
non_tensor_args,

View File

@@ -1634,6 +1634,10 @@ make_fallback(aten.gcd.default, warn=False)
make_fallback(aten._linalg_eigh)
make_fallback(aten.zeros.names)
make_fallback(torch._prims.rng_prims.run_and_save_rng_state)
make_fallback(torch._prims.rng_prims.run_with_rng_state)
# fails accuracy on test_torch.py, and explicit fallback required to avoid warn=True on implicit
make_fallback(aten.exponential.default, warn=False)

View File

@@ -409,6 +409,8 @@ def has_incompatible_cudagraph_ops(gm):
"aten._fused_moving_avg_obs_fq_helper_functional.default",
"fbgemm.dense_to_jagged.default",
"fbgemm.jagged_to_padded_dense.default",
"run_with_rng_state",
"run_and_save_rng_state",
}
if torch.are_deterministic_algorithms_enabled():
forbidden_set.update(

View File

@@ -251,8 +251,11 @@ class HigherOrderOperator(OperatorBase):
dispatch_key_set = _compute_keyset(args, kwargs, self.non_fallthrough_keys)
return self.dispatch(dispatch_key_set.highestPriorityTypeId(), *args, **kwargs)
+ def __str__(self):
+ return f"{self.name()}"
def name(self):
- return self.name
+ return self._name
def _to_flat_tuple(args, kwargs):

View File

@@ -1,11 +1,27 @@
from functools import partial
from typing import Optional, Tuple
import torch
import torch.utils._pytree as pytree
from torch import _prims
from torch._C import DispatchKey
from torch._ops import HigherOrderOperator
from torch._prims_common import CUDARngStateHelper, make_contiguous_strides_for
from torch._prims_common.wrappers import backwards_not_supported
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import (
disable_proxy_modes_tracing,
ProxyTorchDispatchMode,
track_tensor_tree,
unwrap_proxy,
)
from torch.types import _device, _dtype
from torch.utils._python_dispatch import (
_get_current_dispatch_mode,
_pop_mode_temporarily,
)
rngprim_namespace = "rngprims"
rngprim = torch.library.Library(rngprim_namespace, "DEF")
@@ -132,5 +148,144 @@ def register_philox_rand():
)
def get_device(args, kwargs):
if kwargs.get("device"):
device = kwargs.get("device")
if isinstance(device, str):
device = torch.device(device)
return device.type
devices = {arg.device.type for arg in args if isinstance(arg, torch.Tensor)}
if any(dev == "cuda" for dev in devices):
return "cuda"
elif any(dev == "cpu" for dev in devices):
return "cpu"
return None
def register_run_and_save_rng_state_op():
run_and_save_rng_state = HigherOrderOperator("run_and_save_rng_state")
run_and_save_rng_state.fallthrough(DispatchKey.ADInplaceOrView)
run_and_save_rng_state.fallthrough(DispatchKey.PythonDispatcher) # type: ignore[attr-defined]
run_and_save_rng_state.fallthrough(DispatchKey.PythonTLSSnapshot) # type: ignore[attr-defined]
@run_and_save_rng_state.py_impl(DispatchKey.Autograd)
def impl_autograd(op, *args, **kwargs):
with torch._C._AutoDispatchBelowAutograd():
return run_and_save_rng_state(op, *args, **kwargs)
@run_and_save_rng_state.py_impl(DispatchKey.CUDA)
def impl_cuda(op, *args, **kwargs):
return torch.cuda.get_rng_state(), op(*args, **kwargs)
@run_and_save_rng_state.py_impl(DispatchKey.CPU)
def impl_cpu(op, *args, **kwargs):
return torch.get_rng_state(), op(*args, **kwargs)
@run_and_save_rng_state.py_impl(DispatchKey.BackendSelect)
def impl_backend_select(op, *args, **kwargs):
impl_map = {"cuda": impl_cuda, "cpu": impl_cpu}
device = get_device(args, kwargs)
assert device in impl_map, f"Backend not supported for {device}"
impl = impl_map[device]
return impl(op, *args, **kwargs)
@run_and_save_rng_state.py_impl(FakeTensorMode)
def impl_fake_tensor_mode(op, *args, **kwargs):
# Check device to call the right impl
return impl_backend_select(op, *args, **kwargs)
@run_and_save_rng_state.py_impl(ProxyTorchDispatchMode)
def impl_proxy_dispatch_mode(op, *args, **kwargs):
mode = _get_current_dispatch_mode()
assert mode is not None
with _pop_mode_temporarily() as mode:
if mode.enable_tracing:
out = impl_fake_tensor_mode(op, *args, **kwargs)
proxy_args = pytree.tree_map(partial(unwrap_proxy, mode), (op, *args))
proxy_kwargs = pytree.tree_map(partial(unwrap_proxy, mode), kwargs)
out_proxy = mode.tracer.create_proxy(
"call_function", run_and_save_rng_state, proxy_args, proxy_kwargs
)
return track_tensor_tree(
out, out_proxy, constant=None, tracer=mode.tracer
)
else:
return run_and_save_rng_state(op, *args, **kwargs)
return run_and_save_rng_state
def register_run_with_rng_state_op():
run_with_rng_state = HigherOrderOperator("run_with_rng_state")
run_with_rng_state.fallthrough(DispatchKey.ADInplaceOrView)
run_with_rng_state.fallthrough(DispatchKey.PythonTLSSnapshot) # type: ignore[attr-defined]
run_with_rng_state.fallthrough(DispatchKey.PythonDispatcher) # type: ignore[attr-defined]
@run_with_rng_state.py_impl(DispatchKey.Autograd)
def impl_autograd(rng_state, op, *args, **kwargs):
with torch._C._AutoDispatchBelowAutograd():
return run_with_rng_state(rng_state, op, *args, **kwargs)
@run_with_rng_state.py_impl(DispatchKey.CUDA)
def impl_cuda(rng_state, op, *args, **kwargs):
current_state = torch.cuda.get_rng_state()
torch.cuda.set_rng_state(rng_state.cpu())
out = op(*args, **kwargs)
torch.cuda.set_rng_state(current_state)
return out
@run_with_rng_state.py_impl(DispatchKey.CPU)
def impl_cpu(rng_state, op, *args, **kwargs):
current_state = torch.get_rng_state()
torch.set_rng_state(rng_state)
out = op(*args, **kwargs)
torch.set_rng_state(current_state)
return out
@run_with_rng_state.py_impl(ProxyTorchDispatchMode)
def impl_proxy_dispatch_mode(rng_state, op, *args, **kwargs):
mode = _get_current_dispatch_mode()
assert mode is not None
with _pop_mode_temporarily() as mode:
if mode.enable_tracing:
with disable_proxy_modes_tracing():
out = run_with_rng_state(rng_state, op, *args, **kwargs)
proxy_args = pytree.tree_map(
partial(unwrap_proxy, mode), (rng_state, op, *args)
)
proxy_kwargs = pytree.tree_map(partial(unwrap_proxy, mode), kwargs)
out_proxy = mode.tracer.create_proxy(
"call_function", run_with_rng_state, proxy_args, proxy_kwargs
)
return track_tensor_tree(
out, out_proxy, constant=None, tracer=mode.tracer
)
else:
return run_with_rng_state(rng_state, op, *args, **kwargs)
@run_with_rng_state.py_impl(DispatchKey.BackendSelect)
def impl_backend_select(rng_state, op, *args, **kwargs):
impl_map = {"cuda": impl_cuda, "cpu": impl_cpu}
device = get_device(args, kwargs)
assert device in impl_map, f"Backend not supported for {device}"
impl = impl_map[device]
return impl(rng_state, op, *args, **kwargs)
@run_with_rng_state.py_impl(FakeTensorMode)
def impl_fake_tensor_mode(rng_state, op, *args, **kwargs):
# Skip setting the rng state, as it does not work well with fake tensors
# and does not matter in fake tensor mode.
return op(*args, **kwargs)
return run_with_rng_state
run_and_save_rng_state = register_run_and_save_rng_state_op()
run_with_rng_state = register_run_with_rng_state_op()
def register_rng_prims():
register_philox_rand()

View File

@@ -15,7 +15,7 @@ if TYPE_CHECKING:
__all__ = ['Node', 'map_arg', 'map_aggregate', "has_side_effect"]
BaseArgumentTypes = Union[str, int, float, bool, complex, torch.dtype,
- torch.Tensor, torch.device, torch.memory_format, torch.layout]
+ torch.Tensor, torch.device, torch.memory_format, torch.layout, torch._ops.OpOverload]
base_types = BaseArgumentTypes.__args__ # type: ignore[attr-defined]
Target = Union[Callable[..., Any], str]

View File

@@ -352,9 +352,9 @@ def _export_fx_node_to_onnxscript(
elif node.op == "call_function":
# aten ops and other stateless functions.
if node.target == operator.getitem and isinstance(
- fx_name_to_onnxscript_value[node.args[0].name], tuple # type: ignore[union-attr]
+ fx_name_to_onnxscript_value[node.args[0].name], tuple # type: ignore[union-attr,index]
):
- onnx_tensor_tuple = fx_name_to_onnxscript_value[node.args[0].name] # type: ignore[union-attr]
+ onnx_tensor_tuple = fx_name_to_onnxscript_value[node.args[0].name] # type: ignore[union-attr,index]
index = node.args[1]
output = onnx_tensor_tuple[index] # type: ignore[index]
assert (

View File

@@ -372,6 +372,7 @@ BaseArgumentTypes = Union[
torch.device,
torch.memory_format,
torch.layout,
torch._ops.OpOverload,
]
Argument = Optional[
Union[