pytorch/test/test_hop_infra.py
eellison 481a57bc37 Support torch.compile rng selective activation checkpointing with cudagraph (#146878)
TODO:
- [x] Add handling for when forward is invoked multiple times without invoking backward, leaving the fwd/backward states out of sync
- [x] Update rng state initialization to take from the correct device
- [x] Tests
- [x] handling of retain_graph
- [x] respect fallback_random

Fix for https://github.com/pytorch/pytorch/issues/130123.

Updates the aot_eager and cudagraph compilation of `run_and_save_rng_state` to use the new mechanism added by https://github.com/pytorch/pytorch/pull/114068 for CUDAGraph safe rng states.

We have a pair of rng states, one for the fwd and one for the backward. In both forward and backward, the rng op is run with `graphsafe_run_with_rng_state`, which takes an RNG state and hooks it onto the current RNG generator before running the operator. The fwd and backward rng states are initialized with the same value. We ensure that for any given run of the forward, the corresponding backward run sees the same rng state for the op as was observed in the forward.

```
 ===== Forward graph 1 =====
 /data/users/eellison/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[4, 4][4, 1]cuda:0", primals_2: "f32[4, 4][4, 1]cuda:0", fwd_rng_state_0):
        sin: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.sin.default(primals_1)

        # No stacktrace found for following nodes
        graphsafe_run_with_rng_state = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False, rng_state = fwd_rng_state_0);  fwd_rng_state_0 = None
        ...

 ===== Backward graph 1 =====
    def forward(self, primals_1: "f32[4, 4][4, 1]cuda:0", primals_2: "f32[4, 4][4, 1]cuda:0", tangents_1: "f32[4, 4][4, 1]cuda:0", bwd_rng_state_0):
        sin: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.sin.default(primals_1)

        # No stacktrace found for following nodes
        graphsafe_run_with_rng_state = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False, rng_state = bwd_rng_state_0);  bwd_rng_state_0 = None
```
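The fwd/bwd pairing can be illustrated in eager mode with plain `torch.Generator` state. The sketch below assumes a CPU generator and uses a hypothetical helper, `hypothetical_run_with_rng_state`. It is not the `graphsafe_run_with_rng_state` HOP and does not use the CUDA graph-safe state APIs. It only shows why starting both runs from equal states reproduces the same random values.

```python
import torch


def hypothetical_run_with_rng_state(op, rng_state, *args, **kwargs):
    """Run `op` on a private CPU generator set to `rng_state`.

    Hypothetical eager stand-in for the idea behind
    graphsafe_run_with_rng_state (not its implementation). Returns the op
    output and the generator state after the op has consumed randomness.
    """
    gen = torch.Generator()
    gen.set_state(rng_state)
    out = op(*args, generator=gen, **kwargs)
    return out, gen.get_state()


# The fwd and bwd rng states are initialized with the same value.
init_state = torch.Generator().manual_seed(0).get_state()
fwd_rng_state = init_state.clone()
bwd_rng_state = init_state.clone()

# Forward: the random op consumes and advances the forward state.
fwd_out, fwd_rng_state = hypothetical_run_with_rng_state(torch.rand, fwd_rng_state, 4, 4)

# Backward recomputation (selective checkpointing) starts from the matching
# bwd state, so it reproduces exactly the values the forward observed.
bwd_out, bwd_rng_state = hypothetical_run_with_rng_state(torch.rand, bwd_rng_state, 4, 4)

assert torch.equal(fwd_out, bwd_out)
```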

There is some extra complication when a user either calls backward with retain_graph, or calls the backwards in a different order than they called the forwards. If a user starts with states fwd_rng_state0 and bwd_rng_state0 and calls:
- fwd0: fwd_rng_state0 -> fwd_rng_state1
- fwd1: fwd_rng_state1 -> fwd_rng_state2
- bwd1
- bwd0

Naively, when bwd1 is invoked, the bwd rng state would still be bwd_rng_state0, which is not the state fwd1 observed (fwd_rng_state1). I added handling for this in the aot runtime wrappers to detect pending backward invocations and the current position of the bwd rng states, and to update them when necessary.
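A simplified, hypothetical version of that bookkeeping is sketched below for a single rng op. `HypotheticalRngPairTracker` and its methods are illustrative names, not the AOT runtime wrapper code. Each forward records the state it started from, and a backward that finds the bwd state somewhere else resyncs it first. Keeping the record when `retain_graph=True` lets the same backward run again.

```python
import torch


class HypotheticalRngPairTracker:
    """Hypothetical bookkeeping for one (fwd, bwd) rng-state pair.

    Each forward records the state it started from. Before a backward runs,
    the bwd state is resynced to that recorded state if it is not already
    there. Simplified illustration only, not the AOT runtime wrapper code.
    """

    def __init__(self, seed=0):
        init = torch.Generator().manual_seed(seed).get_state()
        self.fwd_state = init.clone()  # fwd and bwd states start equal
        self.bwd_state = init.clone()
        self.pending = {}  # forward id -> rng state that forward started from
        self.next_id = 0

    @staticmethod
    def _rand(state, shape):
        gen = torch.Generator()
        gen.set_state(state)
        return torch.rand(shape, generator=gen), gen.get_state()

    def forward(self, shape):
        fid = self.next_id
        self.next_id += 1
        self.pending[fid] = self.fwd_state.clone()
        # fwd_rng_state_i -> fwd_rng_state_{i+1}
        out, self.fwd_state = self._rand(self.fwd_state, shape)
        return fid, out

    def backward(self, fid, shape, retain_graph=False):
        expected = self.pending[fid]
        if not torch.equal(self.bwd_state, expected):
            self.bwd_state = expected.clone()  # out-of-order backward: resync
        out, self.bwd_state = self._rand(self.bwd_state, shape)
        if not retain_graph:
            del self.pending[fid]  # keep the record if backward may run again
        return out


tracker = HypotheticalRngPairTracker()
id0, f0 = tracker.forward((4, 4))  # fwd0
id1, f1 = tracker.forward((4, 4))  # fwd1
b1 = tracker.backward(id1, (4, 4))  # bwd1 runs first
b0 = tracker.backward(id0, (4, 4))  # then bwd0
assert torch.equal(f1, b1) and torch.equal(f0, b0)
```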

Other notes:

Because nodes that appear later in the forward appear earlier in the backward, we need a separate rng state for each operator. If we reused one rng state across ops, the forward and backward would run each op with different rng states, since the ops are not applied in the same order.
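The ordering problem can be seen with two hypothetical ops, `op_a` and `op_b`, each drawing from a CPU `torch.Generator`. This is an illustration under that assumption, not the compiled graph. With a single shared state, the reversed backward order hands each op the other's random numbers. With one state per op, the order no longer matters. The per-op seeds 1 and 2 are arbitrary.

```python
import torch


def draw(state, n):
    """Sample torch.rand(n) from a private CPU generator; return values and new state."""
    gen = torch.Generator()
    gen.set_state(state)
    return torch.rand(n, generator=gen), gen.get_state()


shared = torch.Generator().manual_seed(0).get_state()

# One shared state: the forward runs op_a then op_b, but the backward
# recomputes op_b first, so each op sees the other's random numbers.
fwd_a, s = draw(shared, 3)
fwd_b, _ = draw(s, 3)
bwd_b, s = draw(shared, 3)
bwd_a, _ = draw(s, 3)
shared_matches = torch.equal(fwd_a, bwd_a) and torch.equal(fwd_b, bwd_b)  # False

# One state per op: execution order no longer matters.
state_a = torch.Generator().manual_seed(1).get_state()
state_b = torch.Generator().manual_seed(2).get_state()
fwd_a, _ = draw(state_a, 3)
fwd_b, _ = draw(state_b, 3)
bwd_b, _ = draw(state_b, 3)
bwd_a, _ = draw(state_a, 3)
per_op_matches = torch.equal(fwd_a, bwd_a) and torch.equal(fwd_b, bwd_b)  # True
```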

Questions for reviewers:

This does change numerics, because the op's rng is now taken from the input rng state instead of whatever the rng state would be midway through running the graph. Technically, we only need this for CUDA graphs, but I'd prefer not to have an rng divergence just for cudagraphs. I am making it respect `fallback_random`.

Edit: decided to apply this to non-cudagraph compilation as well, so long as `fallback_random` is not set.
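To opt out, `fallback_random` can be set, as the description above notes. A minimal sketch, assuming the flag read is the Inductor config option `torch._inductor.config.fallback_random` and scoping it with that config module's `patch()` context manager. That this flag also governs the aot_eager path comes from the PR text and is not verified here.

```python
import torch
import torch._inductor.config as inductor_config

# Per the PR description, the per-op rng-state handling is skipped when
# fallback_random is set. patch() restores the previous value on exit.
with inductor_config.patch(fallback_random=True):
    fallback_inside = inductor_config.fallback_random
    # Models compiled (i.e. first called) inside this block see fallback_random=True.
```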

I'm initializing the rng states by cloning the current state. If you had something like 5 different rands in the model with the same shape, they'd all get the same values. This doesn't seem great. I could use some other initialization scheme, like taking the seed from the graph position. Not sure; let me know your thoughts.

Edit: updated to initialize the rng states from `torch.randint()`.
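The two initialization schemes can be compared on CPU. This sketch assumes each op's state is built from a 31-bit seed drawn with `torch.randint` and passed to `torch.Generator.manual_seed`; the PR only says the states come from `torch.randint`, so the exact scheme here is an assumption. Cloning one state for every op gives identical values for same-shaped rands, while per-op seeds do not.

```python
import torch

torch.manual_seed(0)  # make the example deterministic


def draw(state, shape):
    """Sample torch.rand(shape) from a private CPU generator set to `state`."""
    gen = torch.Generator()
    gen.set_state(state)
    return torch.rand(shape, generator=gen)


num_ops = 5  # e.g. five same-shaped rand ops in one graph
shape = (4, 4)

# Scheme 1: clone the current global state for every op.
current = torch.get_rng_state()
cloned_states = [current.clone() for _ in range(num_ops)]
cloned_outs = [draw(s, shape) for s in cloned_states]
cloned_all_same = all(torch.equal(o, cloned_outs[0]) for o in cloned_outs)  # True

# Scheme 2 (assumed): give each op its own seed drawn with torch.randint.
seeds = torch.randint(0, 2**31 - 1, (num_ops,)).tolist()
seeded_states = [torch.Generator().manual_seed(s).get_state() for s in seeds]
seeded_outs = [draw(s, shape) for s in seeded_states]
per_op_differ = not any(torch.equal(seeded_outs[0], o) for o in seeded_outs[1:])
```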

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146878
Approved by: https://github.com/anijain2305, https://github.com/bdhirsh
2025-02-28 00:47:03 +00:00


# Owner(s): ["module: higher order operators"]
import importlib
import pkgutil

import torch
from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase
from torch.testing._internal.hop_db import (
    FIXME_hop_that_doesnt_have_opinfo_test_allowlist,
    hop_db,
)


def do_imports():
    for mod in pkgutil.walk_packages(
        torch._higher_order_ops.__path__, "torch._higher_order_ops."
    ):
        modname = mod.name
        importlib.import_module(modname)


do_imports()


@skipIfTorchDynamo("not applicable")
class TestHOPInfra(TestCase):
    def test_all_hops_have_opinfo(self):
        """All HOPs should have an OpInfo in torch/testing/_internal/hop_db.py"""
        from torch._ops import _higher_order_ops

        hops_that_have_op_info = {k.name for k in hop_db}
        all_hops = _higher_order_ops.keys()

        missing_ops = set()

        for op in all_hops:
            if (
                op not in hops_that_have_op_info
                and op not in FIXME_hop_that_doesnt_have_opinfo_test_allowlist
            ):
                missing_ops.add(op)

        self.assertTrue(
            len(missing_ops) == 0,
            f"Missing hop_db OpInfo entries for {missing_ops}, please add them to torch/testing/_internal/hop_db.py",
        )

    def test_all_hops_are_imported(self):
        """All HOPs should be listed in torch._higher_order_ops.__all__

        Some constraints (see test_testing.py::TestImports)
        - Sympy must be lazily imported
        - Dynamo must be lazily imported
        """
        imported_hops = torch._higher_order_ops.__all__
        registered_hops = torch._ops._higher_order_ops.keys()

        # Please don't add anything here.
        # We want to ensure that all HOPs are imported at "import torch" time.
        # It is bad if someone tries to access torch.ops.higher_order.cond
        # and it doesn't exist (this may happen if your HOP isn't imported at
        # "import torch" time).
        FIXME_ALLOWLIST = {
            "autograd_function_apply",
            "run_with_rng_state",
            "graphsafe_run_with_rng_state",
            "map_impl",
            "_export_tracepoint",
            "run_and_save_rng_state",
            "map",
            "custom_function_call",
            "trace_wrapped",
            "triton_kernel_wrapper_functional",
            "triton_kernel_wrapper_mutation",
            "wrap",  # Really weird failure -- importing this causes Dynamo to choke on checkpoint
        }
        not_imported_hops = registered_hops - imported_hops
        not_imported_hops = not_imported_hops - FIXME_ALLOWLIST
        self.assertEqual(
            not_imported_hops,
            set(),
            msg="All HOPs must be listed under torch/_higher_order_ops/__init__.py's __all__.",
        )

    def test_imports_from_all_work(self):
        """All APIs listed in torch._higher_order_ops.__all__ must be importable"""
        stuff = torch._higher_order_ops.__all__
        for attr in stuff:
            getattr(torch._higher_order_ops, attr)


if __name__ == "__main__":
    run_tests()