mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[activation checkpointing] Higher order functional rng op wrappers (#102934)
Introduces two higher order operators * run_and_save_rng_state - Saves the current rng state and then runs the op. * run_with_rng_state - Runs the op with the rng state supplied as an input Ideally, we would like to use torch.compile for these operators. But currently the plan is to introduce these operators at the partitioner level, obviating the need to support them fully through the torch.compile stack. To ensure that we have good enough debugging with minifiers, we have ensure that they work with make_fx. In future, we can move on torch.compile. Pull Request resolved: https://github.com/pytorch/pytorch/pull/102934 Approved by: https://github.com/jansel, https://github.com/zou3519
This commit is contained in:
committed by
PyTorch MergeBot
parent
31ee1512d3
commit
58d2c66a70
@ -2566,6 +2566,7 @@ class TestPartitioning(AOTTestCase):
|
||||
res = aot_mod(x)
|
||||
res.sum().backward()
|
||||
|
||||
|
||||
class TestAOTModuleSimplified(AOTTestCase):
|
||||
def test_aot_module_simplified(self):
|
||||
class MockModule(torch.nn.Module):
|
||||
|
@ -4819,6 +4819,50 @@ class CommonTemplate:
|
||||
self.assertTrue((d >= 0).all())
|
||||
self.assertTrue((d < 1).all())
|
||||
|
||||
def test_functionalize_rng_wrappers(self):
|
||||
# Ideally, we would like to use torch.compile for these operators. But
|
||||
# currently the plan is to introduce these operators at the partitioner
|
||||
# level, obviating the need to support them fully through the
|
||||
# torch.compile stack. To ensure that we have good enough debugging with
|
||||
# minifiers, we have ensure that they work with make_fx. This test uses
|
||||
# make_fx to do the testing. In future, we can move on torch.compile.
|
||||
def fn():
|
||||
rng_state1, a1 = torch._prims.rng_prims.run_and_save_rng_state(
|
||||
torch.ops.aten.rand.default,
|
||||
[4, 4],
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
rng_state2, a2 = torch._prims.rng_prims.run_and_save_rng_state(
|
||||
torch.ops.aten.rand.default,
|
||||
[4, 4],
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
b1 = torch._prims.rng_prims.run_with_rng_state(
|
||||
rng_state1,
|
||||
torch.ops.aten.rand.default,
|
||||
[4, 4],
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
b2 = torch._prims.rng_prims.run_with_rng_state(
|
||||
rng_state2,
|
||||
torch.ops.aten.rand.default,
|
||||
[4, 4],
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
return (a1, a2, b1, b2)
|
||||
|
||||
mod = make_fx(fn)()
|
||||
compiled_f = compile_fx_inner(mod, ())
|
||||
a1, a2, b1, b2 = compiled_f(())
|
||||
self.assertEqual(a1, b1)
|
||||
self.assertEqual(a2, b2)
|
||||
|
||||
@patch.object(torch._functorch.config, "functionalize_rng_ops", True)
|
||||
def test_philox_rand(self):
|
||||
if self.device == "cpu":
|
||||
|
@ -1124,6 +1124,27 @@ class TestPrims(TestCase):
|
||||
for a, b in zip(references, results):
|
||||
self.assertEqual(a, b)
|
||||
|
||||
|
||||
@dtypes(torch.float32)
|
||||
def test_functional_rng_wrappers(self, device, dtype):
|
||||
|
||||
torch.manual_seed(123)
|
||||
ref1 = torch.rand(10, device=device, dtype=dtype)
|
||||
ref2 = torch.rand(10, device=device, dtype=dtype)
|
||||
|
||||
|
||||
torch.manual_seed(123)
|
||||
rng_state1, res1 = torch._prims.rng_prims.run_and_save_rng_state(torch.rand, 10, device=device, dtype=dtype)
|
||||
rng_state2, res2 = torch._prims.rng_prims.run_and_save_rng_state(torch.rand, 10, device=device, dtype=dtype)
|
||||
|
||||
res3 = torch._prims.rng_prims.run_with_rng_state(rng_state1, torch.rand, 10, device=device, dtype=dtype)
|
||||
res4 = torch._prims.rng_prims.run_with_rng_state(rng_state2, torch.rand, 10, device=device, dtype=dtype)
|
||||
|
||||
self.assertEqual(ref1, res1)
|
||||
self.assertEqual(ref2, res2)
|
||||
self.assertEqual(ref1, res3)
|
||||
self.assertEqual(ref2, res4)
|
||||
|
||||
class TestPrimsBasic(TestCase):
|
||||
def test_torch_ops(self):
|
||||
r = make_tensor((2,), device='cpu', dtype=torch.float)
|
||||
|
@ -14,6 +14,7 @@ from sympy import Expr
|
||||
import torch
|
||||
from torch._dynamo.utils import counters, dynamo_timed
|
||||
from torch.fx.experimental.symbolic_shapes import SymTypes
|
||||
from torch.fx.node import _get_qualified_name
|
||||
from .. import codecache, config, ir
|
||||
from ..codecache import CudaKernelParamCache
|
||||
from ..utils import (
|
||||
@ -734,6 +735,8 @@ class WrapperCodeGen(CodeGen):
|
||||
return self.ref
|
||||
|
||||
return repr(type(s)(Shim(self.val_to_str(a)) for a in s))
|
||||
elif isinstance(s, torch._ops.OpOverload):
|
||||
return _get_qualified_name(s)
|
||||
else:
|
||||
return repr(s)
|
||||
|
||||
|
@ -3235,6 +3235,13 @@ class FallbackKernel(ExternKernelAlloc):
|
||||
if V.graph.cpp_wrapper
|
||||
else f"aten.{op_overload_packet.__name__}"
|
||||
)
|
||||
elif isinstance(kernel, torch._ops.HigherOrderOperator):
|
||||
if getattr(torch._prims.rng_prims, kernel.__name__, None) is kernel:
|
||||
self.kernel = f"torch._prims.rng_prims.{kernel.__name__}"
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Unable to find HigherOrderOperator kernel name"
|
||||
)
|
||||
else:
|
||||
if V.graph.cpp_wrapper:
|
||||
from torch._inductor.codegen.wrapper import (
|
||||
@ -3300,6 +3307,24 @@ class FallbackKernel(ExternKernelAlloc):
|
||||
self.kwargs.update(kwargs)
|
||||
return args
|
||||
|
||||
@staticmethod
|
||||
def find_device(tensor_args, example_output):
|
||||
if tensor_args:
|
||||
return tensor_args[0].get_device()
|
||||
if isinstance(example_output, torch.Tensor):
|
||||
return example_output.device
|
||||
if isinstance(example_output, (list, tuple)):
|
||||
devices = {FallbackKernel.find_device(None, x) for x in example_output}
|
||||
# Remove None
|
||||
devices = [device for device in devices if device]
|
||||
if len(devices) == 1:
|
||||
return devices[0]
|
||||
for device in devices:
|
||||
if device.type == "cuda":
|
||||
return device
|
||||
return devices[0]
|
||||
return None
|
||||
|
||||
def codegen(self, wrapper):
|
||||
if self.use_cpp_op_schema:
|
||||
args = [*self.codegen_args(), *self.codegen_kwargs()]
|
||||
@ -3337,13 +3362,10 @@ class FallbackKernel(ExternKernelAlloc):
|
||||
unflatten_args,
|
||||
) = cls.process_kernel(kernel, *args, **kwargs)
|
||||
|
||||
assert tensor_args or isinstance(
|
||||
example_output, torch.Tensor
|
||||
), "Not sure where to find device info"
|
||||
device = FallbackKernel.find_device(tensor_args, example_output)
|
||||
assert device, "Not sure where to find device info"
|
||||
packed = FallbackKernel(
|
||||
MultiOutputLayout(
|
||||
tensor_args[0].get_device() if tensor_args else example_output.device
|
||||
),
|
||||
MultiOutputLayout(device),
|
||||
kernel,
|
||||
tensor_args,
|
||||
non_tensor_args,
|
||||
|
@ -1634,6 +1634,10 @@ make_fallback(aten.gcd.default, warn=False)
|
||||
make_fallback(aten._linalg_eigh)
|
||||
make_fallback(aten.zeros.names)
|
||||
|
||||
|
||||
make_fallback(torch._prims.rng_prims.run_and_save_rng_state)
|
||||
make_fallback(torch._prims.rng_prims.run_with_rng_state)
|
||||
|
||||
# fails accuracy on test_torch.py, and explicit fallback required to avoid warn=True on implicit
|
||||
make_fallback(aten.exponential.default, warn=False)
|
||||
|
||||
|
@ -409,6 +409,8 @@ def has_incompatible_cudagraph_ops(gm):
|
||||
"aten._fused_moving_avg_obs_fq_helper_functional.default",
|
||||
"fbgemm.dense_to_jagged.default",
|
||||
"fbgemm.jagged_to_padded_dense.default",
|
||||
"run_with_rng_state",
|
||||
"run_and_save_rng_state",
|
||||
}
|
||||
if torch.are_deterministic_algorithms_enabled():
|
||||
forbidden_set.update(
|
||||
|
@ -251,8 +251,11 @@ class HigherOrderOperator(OperatorBase):
|
||||
dispatch_key_set = _compute_keyset(args, kwargs, self.non_fallthrough_keys)
|
||||
return self.dispatch(dispatch_key_set.highestPriorityTypeId(), *args, **kwargs)
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.name()}"
|
||||
|
||||
def name(self):
|
||||
return self.name
|
||||
return self._name
|
||||
|
||||
|
||||
def _to_flat_tuple(args, kwargs):
|
||||
|
@ -1,11 +1,27 @@
|
||||
from functools import partial
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.utils._pytree as pytree
|
||||
from torch import _prims
|
||||
from torch._C import DispatchKey
|
||||
from torch._ops import HigherOrderOperator
|
||||
|
||||
from torch._prims_common import CUDARngStateHelper, make_contiguous_strides_for
|
||||
from torch._prims_common.wrappers import backwards_not_supported
|
||||
from torch._subclasses.fake_tensor import FakeTensorMode
|
||||
from torch.fx.experimental.proxy_tensor import (
|
||||
disable_proxy_modes_tracing,
|
||||
ProxyTorchDispatchMode,
|
||||
track_tensor_tree,
|
||||
unwrap_proxy,
|
||||
)
|
||||
from torch.types import _device, _dtype
|
||||
from torch.utils._python_dispatch import (
|
||||
_get_current_dispatch_mode,
|
||||
_pop_mode_temporarily,
|
||||
)
|
||||
|
||||
|
||||
rngprim_namespace = "rngprims"
|
||||
rngprim = torch.library.Library(rngprim_namespace, "DEF")
|
||||
@ -132,5 +148,144 @@ def register_philox_rand():
|
||||
)
|
||||
|
||||
|
||||
def get_device(args, kwargs):
|
||||
if kwargs.get("device"):
|
||||
device = kwargs.get("device")
|
||||
if isinstance(device, str):
|
||||
device = torch.device(device)
|
||||
return device.type
|
||||
|
||||
devices = {arg.device.type for arg in args if isinstance(arg, torch.Tensor)}
|
||||
if any(dev == "cuda" for dev in devices):
|
||||
return "cuda"
|
||||
elif any(dev == "cpu" for dev in devices):
|
||||
return "cpu"
|
||||
return None
|
||||
|
||||
|
||||
def register_run_and_save_rng_state_op():
|
||||
run_and_save_rng_state = HigherOrderOperator("run_and_save_rng_state")
|
||||
|
||||
run_and_save_rng_state.fallthrough(DispatchKey.ADInplaceOrView)
|
||||
run_and_save_rng_state.fallthrough(DispatchKey.PythonDispatcher) # type: ignore[attr-defined]
|
||||
run_and_save_rng_state.fallthrough(DispatchKey.PythonTLSSnapshot) # type: ignore[attr-defined]
|
||||
|
||||
@run_and_save_rng_state.py_impl(DispatchKey.Autograd)
|
||||
def impl_autograd(op, *args, **kwargs):
|
||||
with torch._C._AutoDispatchBelowAutograd():
|
||||
return run_and_save_rng_state(op, *args, **kwargs)
|
||||
|
||||
@run_and_save_rng_state.py_impl(DispatchKey.CUDA)
|
||||
def impl_cuda(op, *args, **kwargs):
|
||||
return torch.cuda.get_rng_state(), op(*args, **kwargs)
|
||||
|
||||
@run_and_save_rng_state.py_impl(DispatchKey.CPU)
|
||||
def impl_cpu(op, *args, **kwargs):
|
||||
return torch.get_rng_state(), op(*args, **kwargs)
|
||||
|
||||
@run_and_save_rng_state.py_impl(DispatchKey.BackendSelect)
|
||||
def impl_backend_select(op, *args, **kwargs):
|
||||
impl_map = {"cuda": impl_cuda, "cpu": impl_cpu}
|
||||
device = get_device(args, kwargs)
|
||||
assert device in impl_map, f"Backend not supported for {device}"
|
||||
impl = impl_map[device]
|
||||
return impl(op, *args, **kwargs)
|
||||
|
||||
@run_and_save_rng_state.py_impl(FakeTensorMode)
|
||||
def impl_fake_tensor_mode(op, *args, **kwargs):
|
||||
# Check device to call the right impl
|
||||
return impl_backend_select(op, *args, **kwargs)
|
||||
|
||||
@run_and_save_rng_state.py_impl(ProxyTorchDispatchMode)
|
||||
def impl_proxy_dispatch_mode(op, *args, **kwargs):
|
||||
mode = _get_current_dispatch_mode()
|
||||
assert mode is not None
|
||||
with _pop_mode_temporarily() as mode:
|
||||
if mode.enable_tracing:
|
||||
out = impl_fake_tensor_mode(op, *args, **kwargs)
|
||||
proxy_args = pytree.tree_map(partial(unwrap_proxy, mode), (op, *args))
|
||||
proxy_kwargs = pytree.tree_map(partial(unwrap_proxy, mode), kwargs)
|
||||
out_proxy = mode.tracer.create_proxy(
|
||||
"call_function", run_and_save_rng_state, proxy_args, proxy_kwargs
|
||||
)
|
||||
return track_tensor_tree(
|
||||
out, out_proxy, constant=None, tracer=mode.tracer
|
||||
)
|
||||
else:
|
||||
return run_and_save_rng_state(op, *args, **kwargs)
|
||||
|
||||
return run_and_save_rng_state
|
||||
|
||||
|
||||
def register_run_with_rng_state_op():
|
||||
run_with_rng_state = HigherOrderOperator("run_with_rng_state")
|
||||
|
||||
run_with_rng_state.fallthrough(DispatchKey.ADInplaceOrView)
|
||||
run_with_rng_state.fallthrough(DispatchKey.PythonTLSSnapshot) # type: ignore[attr-defined]
|
||||
run_with_rng_state.fallthrough(DispatchKey.PythonDispatcher) # type: ignore[attr-defined]
|
||||
|
||||
@run_with_rng_state.py_impl(DispatchKey.Autograd)
|
||||
def impl_autograd(rng_state, op, *args, **kwargs):
|
||||
with torch._C._AutoDispatchBelowAutograd():
|
||||
return run_with_rng_state(rng_state, op, *args, **kwargs)
|
||||
|
||||
@run_with_rng_state.py_impl(DispatchKey.CUDA)
|
||||
def impl_cuda(rng_state, op, *args, **kwargs):
|
||||
current_state = torch.cuda.get_rng_state()
|
||||
torch.cuda.set_rng_state(rng_state.cpu())
|
||||
out = op(*args, **kwargs)
|
||||
torch.cuda.set_rng_state(current_state)
|
||||
return out
|
||||
|
||||
@run_with_rng_state.py_impl(DispatchKey.CPU)
|
||||
def impl_cpu(rng_state, op, *args, **kwargs):
|
||||
current_state = torch.get_rng_state()
|
||||
torch.set_rng_state(rng_state)
|
||||
out = op(*args, **kwargs)
|
||||
torch.set_rng_state(current_state)
|
||||
return out
|
||||
|
||||
@run_with_rng_state.py_impl(ProxyTorchDispatchMode)
|
||||
def impl_proxy_dispatch_mode(rng_state, op, *args, **kwargs):
|
||||
mode = _get_current_dispatch_mode()
|
||||
assert mode is not None
|
||||
with _pop_mode_temporarily() as mode:
|
||||
if mode.enable_tracing:
|
||||
with disable_proxy_modes_tracing():
|
||||
out = run_with_rng_state(rng_state, op, *args, **kwargs)
|
||||
proxy_args = pytree.tree_map(
|
||||
partial(unwrap_proxy, mode), (rng_state, op, *args)
|
||||
)
|
||||
proxy_kwargs = pytree.tree_map(partial(unwrap_proxy, mode), kwargs)
|
||||
out_proxy = mode.tracer.create_proxy(
|
||||
"call_function", run_with_rng_state, proxy_args, proxy_kwargs
|
||||
)
|
||||
return track_tensor_tree(
|
||||
out, out_proxy, constant=None, tracer=mode.tracer
|
||||
)
|
||||
else:
|
||||
return run_with_rng_state(rng_state, op, *args, **kwargs)
|
||||
|
||||
@run_with_rng_state.py_impl(DispatchKey.BackendSelect)
|
||||
def impl_backend_select(rng_state, op, *args, **kwargs):
|
||||
impl_map = {"cuda": impl_cuda, "cpu": impl_cpu}
|
||||
device = get_device(args, kwargs)
|
||||
assert device in impl_map, f"Backend not supported for {device}"
|
||||
impl = impl_map[device]
|
||||
return impl(rng_state, op, *args, **kwargs)
|
||||
|
||||
@run_with_rng_state.py_impl(FakeTensorMode)
|
||||
def impl_fake_tensor_mode(rng_state, op, *args, **kwargs):
|
||||
# Skip setting the set_rng_state as it does not work well with fake tensors.
|
||||
# And it does not matter for the fake tensor mode.
|
||||
return op(*args, **kwargs)
|
||||
|
||||
return run_with_rng_state
|
||||
|
||||
|
||||
run_and_save_rng_state = register_run_and_save_rng_state_op()
|
||||
run_with_rng_state = register_run_with_rng_state_op()
|
||||
|
||||
|
||||
def register_rng_prims():
|
||||
register_philox_rand()
|
||||
|
@ -15,7 +15,7 @@ if TYPE_CHECKING:
|
||||
__all__ = ['Node', 'map_arg', 'map_aggregate', "has_side_effect"]
|
||||
|
||||
BaseArgumentTypes = Union[str, int, float, bool, complex, torch.dtype,
|
||||
torch.Tensor, torch.device, torch.memory_format, torch.layout]
|
||||
torch.Tensor, torch.device, torch.memory_format, torch.layout, torch._ops.OpOverload]
|
||||
base_types = BaseArgumentTypes.__args__ # type: ignore[attr-defined]
|
||||
|
||||
Target = Union[Callable[..., Any], str]
|
||||
|
@ -352,9 +352,9 @@ def _export_fx_node_to_onnxscript(
|
||||
elif node.op == "call_function":
|
||||
# aten ops and other stateless functions.
|
||||
if node.target == operator.getitem and isinstance(
|
||||
fx_name_to_onnxscript_value[node.args[0].name], tuple # type: ignore[union-attr]
|
||||
fx_name_to_onnxscript_value[node.args[0].name], tuple # type: ignore[union-attr,index]
|
||||
):
|
||||
onnx_tensor_tuple = fx_name_to_onnxscript_value[node.args[0].name] # type: ignore[union-attr]
|
||||
onnx_tensor_tuple = fx_name_to_onnxscript_value[node.args[0].name] # type: ignore[union-attr,index]
|
||||
index = node.args[1]
|
||||
output = onnx_tensor_tuple[index] # type: ignore[index]
|
||||
assert (
|
||||
|
@ -372,6 +372,7 @@ BaseArgumentTypes = Union[
|
||||
torch.device,
|
||||
torch.memory_format,
|
||||
torch.layout,
|
||||
torch._ops.OpOverload,
|
||||
]
|
||||
Argument = Optional[
|
||||
Union[
|
||||
|
Reference in New Issue
Block a user