Compare commits

...

3 Commits

Author SHA1 Message Date
365f214867 Allow multiple cudagraph recordings per compiled graph
ghstack-source-id: 9d946b20f4defdf57ef3be32a2747dceccce97e5
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126822
2024-06-05 12:51:50 -07:00
d089aa03e0 Remove unused arg to GraphLowering
ghstack-source-id: 0df0c3537abc6042338153845e25ea43c00a38c7
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126821
2024-06-04 17:08:55 -07:00
15cba86750 Collect static parameter metadata in aot
ghstack-source-id: 41acbb28498e97a0dd03fb60b35002f82bcd7b16
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126820
2024-06-04 17:08:55 -07:00
9 changed files with 222 additions and 27 deletions
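
Taken together, these commits replace Inductor's single num_fixed prefix count with an explicit static_input_idxs list, and extend that list with the positions of torch.nn.Parameter inputs when Dynamo inlines built-in nn modules. For orientation only, the index bookkeeping roughly amounts to the following standalone sketch (not code from the diff; the helper name is made up):

from typing import Any, List

import torch


def sketch_static_input_idxs(flat_args: List[Any], num_fixed: int) -> List[int]:
    # The first num_fixed inputs stay static as before; parameter inputs are
    # additionally marked static so cudagraphs can re-record instead of copying
    # them into cudagraph-owned memory on every run.
    param_idxs = [
        i for i, arg in enumerate(flat_args) if isinstance(arg, torch.nn.Parameter)
    ]
    return list(range(num_fixed)) + param_idxs


args = [torch.ones(2, 2), torch.nn.Parameter(torch.rand(2, 2))]
print(sketch_static_input_idxs(args, num_fixed=0))  # prints [1]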

View File

@@ -2232,7 +2232,6 @@ class CPUReproTests(TestCase):
graph_lowering = GraphLowering(
torch.fx.GraphModule(submodules, _graph),
shape_env=None,
-num_static_inputs=0,
)
def set_opt_dtype(graph):
@@ -2343,7 +2342,6 @@ class CPUReproTests(TestCase):
graph_lowering = GraphLowering(
torch.fx.GraphModule(submodules, _graph),
shape_env=None,
-num_static_inputs=0,
)
with patch.object(graph_lowering, "wrapper_code", ""), V.set_graph_handler(
graph_lowering

View File

@@ -648,7 +648,9 @@ if HAS_CUDA and not TEST_WITH_ASAN:
with mode:
inps = [torch.rand([6, 5], device="cuda")[1:] for _ in range(2)]
-compiled_f = compile_fx_inner(mod, inps, num_fixed=1, cudagraphs=True)
+compiled_f = compile_fx_inner(
+mod, inps, static_input_idxs=[0], cudagraphs=True
+)
def get_unaligned_inputs():
return [torch.rand([6, 5], device="cuda")[1:] for _ in range(2)]
@@ -1770,6 +1772,148 @@
[foo.goo.linear.weight, foo.goo.linear.bias, foo.static_tensor, inp]
)
def run_static_input_param_test(self, fn_eager, num_graphs):
with torch.device("cuda"):
fn_compiled = torch.compile(fn_eager, mode="reduce-overhead")
def run_iter(param, fn):
fwd_output = fn(torch.ones(2, 2), param)
fwd_output.sum().backward()
grad_output = param.grad.clone().detach()
param.grad = None
return fwd_output, grad_output
def loop(param):
exp_output, exp_grad = run_iter(param, fn_eager)
for _ in range(5):
compiled_output, compiled_grad = run_iter(param, fn_compiled)
self.assertEqual(exp_output, compiled_output)
self.assertEqual(exp_grad, compiled_grad)
p1 = torch.nn.Parameter(torch.rand([2, 2]))
loop(p1)
p2 = torch.nn.Parameter(torch.rand([2, 2]))
loop(p2)
# Run p1 again to ensure we reuse the previous recording
loop(p1)
self.assertEqual(self.get_manager().new_graph_id().id, num_graphs)
def _module_test(self, mod):
with torch.device("cuda"):
def fn(x, mod):
return mod(x)
fn_compiled = torch.compile(fn, mode="reduce-overhead", fullgraph=True)
def run_test_iter(mod, fn):
fwd_output = fn(torch.ones(2, 2), mod)
fwd_output.sum().backward()
grad_output = mod.weight.grad.clone().detach()
mod.zero_grad()
return fwd_output, grad_output
def run_test():
exp_output, exp_grad = run_test_iter(mod, fn)
for _ in range(5):
compiled_output, compiled_grad = run_test_iter(mod, fn_compiled)
self.assertEqual(exp_output, compiled_output)
self.assertEqual(exp_grad, compiled_grad)
run_test()
old = mod.weight.data
mod.weight.data = torch.rand_like(mod.weight.data)
run_test()
# Run original version to verify we reuse the other recording
mod.weight.data = old
run_test()
# Fwd + bwd graphs for each version of the function => 4 graphs
self.assertEqual(self.get_manager().new_graph_id().id, 4)
@torch._inductor.config.patch("triton.cudagraphs", True)
@torch._dynamo.config.patch("error_on_recompile", True)
@torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
def test_multi_dispatch_single_compile_param_inputs(self):
# Verify that we can record multiple cudagraphs for a single
# compiled function with param inputs
def fn(x, y):
return x * y
# Fwd + bwd graphs for each version of the function => 4 graphs
self.run_static_input_param_test(fn, 4)
@torch._inductor.config.patch("triton.cudagraphs", True)
@torch._dynamo.config.patch("error_on_recompile", True)
@torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
def test_multi_dispatch_single_compile_builtin_module(self):
# Verify that we don't recompile when changing the param of a builtin module
# and that we record another cudagraph
# Note: Linear is a builtin module so we enable that config setting above
self._module_test(torch.nn.Linear(2, 3, device="cuda"))
@torch._inductor.config.patch("triton.cudagraphs", True)
@torch._dynamo.config.patch("error_on_recompile", True)
@torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
def test_multi_dispatch_custom_module(self):
# Test that we can correctly dispatch multiple graphs
# if params of a custom module change
class TestModule(torch.nn.Module):
def __init__(self, param) -> None:
super().__init__()
self.weight = param
def forward(self, x):
return self.weight * x
self._module_test(
TestModule(torch.nn.Parameter(torch.rand([2, 2], device="cuda")))
)
@torch._inductor.config.patch("triton.cudagraphs", True)
@torch._dynamo.config.patch("error_on_recompile", True)
@torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
def test_multi_dispatch_child_node(self):
# Test that we can correctly dispatch multiple graphs if a child node
# in the tree has stable input pointers change
def fn(x, p):
# Graph 1
y = x * x
torch._dynamo.graph_break()
# Graph 2
return y * p
# We have 5 graphs here
# Graph 1
# / \
# Graph 2 w/ p1 Graph 2 w/ p2
# and then two backward graphs
self.run_static_input_param_test(fn, 5)
@torch._inductor.config.patch("triton.cudagraphs", True)
@torch._dynamo.config.patch("error_on_recompile", True)
@torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
def test_multi_dispatch_parent_node(self):
def fn(x, p):
# Graph 1
y = x * p
torch._dynamo.graph_break()
# Graph 2
return y + x
# We have 6 graphs here
# Graph 1 w/ p1 Graph 1 w/ p2
# | |
# Graph 2 (v1) Graph 2 (v2)
# There are two versions of graph 2 because
# we re-record due to different memory state after running the
# two versions of Graph 1
# and then two backward graphs
self.run_static_input_param_test(fn, 6)
instantiate_parametrized_tests(CudaGraphTreeTests)
if __name__ == "__main__":
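
The tests above count recorded graphs via self.get_manager().new_graph_id().id to verify that swapping parameter objects on one compiled function re-records a cudagraph rather than triggering a recompile. A rough standalone repro of the behavior they exercise might look like this (hypothetical script, not part of the diff; assumes a CUDA device and the same config flags the tests patch):

import torch

torch._dynamo.config.inline_inbuilt_nn_modules = True
torch._dynamo.config.error_on_recompile = True  # swapping params must not recompile


@torch.compile(mode="reduce-overhead")  # reduce-overhead enables cudagraph trees
def fn(x, p):
    return x * p


x = torch.ones(2, 2, device="cuda")
p1 = torch.nn.Parameter(torch.rand(2, 2, device="cuda"))
p2 = torch.nn.Parameter(torch.rand(2, 2, device="cuda"))

# p1 -> p2 should record a second graph; returning to p1 should replay the first.
for p in (p1, p2, p1):
    fn(x, p).sum().backward()
    p.grad = None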

View File

@@ -9190,7 +9190,6 @@ class CommonTemplate:
graph = GraphLowering(
gm,
shape_env=shape_env,
-num_static_inputs=0,
)
with V.set_graph_handler(graph), V.set_debug_handler(DebugContext()):
graph.run(*example_inputs)
@@ -10417,7 +10416,6 @@ if HAS_GPU and not TEST_WITH_ASAN:
cxt = TritonCodeGenTests.NoOpCompilerBackend()
torch._dynamo.optimize(backend=cxt.noop_backend)(fn)(*args)
graph = GraphLowering(cxt.model)
-graph.num_static_inputs = 0
kernels = []
with V.set_graph_handler(graph), V.set_debug_handler(DebugContext()):
graph.run(*(cxt.example_args))

View File

@@ -665,6 +665,15 @@ from a multi-output view call"
)
user_outs = pytree.tree_map(from_fun, f_output_tangents)

+if torch._dynamo.config.inline_inbuilt_nn_modules:
+static_parameter_input_indices = [
+i
+for i, arg in enumerate(flat_args)
+if isinstance(arg, torch.nn.Parameter)
+]
+else:
+static_parameter_input_indices = []
f_mutated_inputs = [
inp
for inp, info in zip(flat_f_args, input_info)
@@ -716,6 +725,7 @@ from a multi-output view call"
subclass_tangent_meta=create_subclass_meta(traced_tangents),
is_train=is_train,
grad_enabled_mutation=grad_enabled_mutation,
+static_parameter_indices=static_parameter_input_indices,
tokens=mode._tokens,
)
return metadata

View File

@@ -304,6 +304,9 @@ class ViewAndMutationMeta:
# raised
deterministic: Optional[bool] = None
+# Keeps track of which input indices store parameters (which we will treat as static)
+static_parameter_indices: List[int] = field(default_factory=list)
# Map of effect type (ex. _EffectType.ORDERED) to token. If there are
# side-effectful operators, FunctionalTensorMode will populate this
# dictionary telling us how many tokens we will need during tracing.

View File

@@ -120,6 +120,19 @@ def complex_memory_overlap(t: torch.Tensor) -> bool:
return False

+def get_static_input_idxs(num_fixed):
+# If we are inlining NNModules, we treat all torch.nn.Parameters as static for the purposes
+# of cudagraphs. Rather than copying these into cudagraph-owned memory
+# like we do for normal inputs on each run, we will re-record a cudagraph if these
+# parameter locations change.
+context = torch._guards.TracingContext.try_get()
+fixed = list(range(num_fixed))
+if not context or not context.fw_metadata:
+return fixed
+return fixed + context.fw_metadata.static_parameter_indices
@functools.lru_cache(None)
def _step_logger():
return dynamo_logging.get_step_logger(log)
@@ -415,7 +428,7 @@ def compile_fx_inner(
gm: torch.fx.GraphModule,
example_inputs: List[torch.Tensor],
cudagraphs: Optional[BoxedBool] = None,
-num_fixed: int = 0,
+static_input_idxs: Optional[List[int]] = None,
is_backward: bool = False,
graph_id: Optional[int] = None,
cpp_wrapper: bool = False,
@@ -440,6 +453,9 @@
_LazyGraphModule.force_recompile(gm)
return make_boxed_func(gm.forward)

+if static_input_idxs is None:
+static_input_idxs = []
assert isinstance(
next(iter(reversed(gm.graph.nodes))).args[0], (tuple, list)
), f"inductor can only compile FX graphs which return a tuple/list, but got {gm.graph}"
@@ -449,7 +465,7 @@
gm,
example_inputs,
cudagraphs=cudagraphs,
-num_fixed=num_fixed,
+static_input_idxs=static_input_idxs,
is_backward=is_backward,
graph_id=graph_id,
cpp_wrapper=cpp_wrapper,
@@ -468,7 +484,7 @@
# of fx_codegen_and_compile changes, the dict should be updated accordingly
graph_kwargs = {
"cudagraphs": cudagraphs,
-"num_fixed": num_fixed,
+"static_input_idxs": static_input_idxs,
"is_backward": is_backward,
"graph_id": graph_id,
"cpp_wrapper": cpp_wrapper,
@@ -482,7 +498,7 @@
start = time.time()
fx_graph_remote_cache = should_use_remote_fx_graph_cache()
-inputs_to_check = get_input_idxs_to_check(example_inputs, range(num_fixed))
+inputs_to_check = get_input_idxs_to_check(example_inputs, static_input_idxs)
if (
not config.force_disable_caches
and (config.fx_graph_cache or fx_graph_remote_cache)
@@ -492,7 +508,7 @@
if (
isinstance(input, torch.Tensor)
and input.device.type == "cuda"
-and i < num_fixed
+and i in static_input_idxs
):
input._is_inductor_static = True # type: ignore[attr-defined]
@@ -551,7 +567,7 @@
)
has_mutation_str = check_for_mutation_ignore_cuda_graph_managed_tensor(
-gm, compiled_graph, num_fixed
+gm, compiled_graph, static_input_idxs
)
has_mutation = has_mutation_str is not None
@@ -591,7 +607,7 @@
compiled_graph.current_callable = cudagraphify(
compiled_graph.current_callable,
example_inputs,
-static_input_idxs=range(num_fixed),
+static_input_idxs=static_input_idxs,
device_index=next(iter(compiled_graph.device_idxs)),
stack_traces=stack_traces,
is_backward=is_backward,
@@ -660,7 +676,7 @@ def fx_codegen_and_compile(
gm: torch.fx.GraphModule,
example_inputs: List[torch.Tensor],
cudagraphs: Optional[BoxedBool] = None,
-num_fixed: int = 0,
+static_input_idxs: Optional[List[int]] = None,
is_backward: bool = False,
graph_id: Optional[int] = None,
cpp_wrapper: bool = False,
@@ -749,7 +765,6 @@
const_gm,
example_inputs=[],
shape_env=shape_env,
-num_static_inputs=num_fixed,
graph_id=graph_id,
cpp_wrapper=cpp_wrapper,
aot_mode=aot_mode,
@@ -771,7 +786,6 @@
# we currently use fake tensors and defake them later.
example_inputs=example_inputs,
shape_env=shape_env,
-num_static_inputs=num_fixed,
graph_id=graph_id,
cpp_wrapper=cpp_wrapper,
aot_mode=aot_mode,
@@ -1180,6 +1194,7 @@ def fw_compiler_freezing(
n.name for n in model_outputs if isinstance(n, torch.fx.Node)
)

+static_input_idxs = list(range(num_fixed))
# constant params will be real tensors, not fake
tracing_context = torch._guards.TracingContext.try_get()
if tracing_context is not None:
@@ -1189,11 +1204,14 @@
if i not in preserved_arg_indices:
params_flat[i] = None

+if tracing_context.fw_metadata:
+static_input_idxs += tracing_context.fw_metadata.static_parameter_indices

with mock.patch.object(fake_mode, "allow_non_fake_inputs", True):
optimized_function = inner_compile(
opt_model,
aot_example_inputs,
-num_fixed=num_fixed,
+static_input_idxs=static_input_idxs,
cudagraphs=cudagraphs,
graph_id=graph_id,
is_inference=True,
@@ -1324,6 +1342,7 @@ def compile_fx(
fixed = torch._inductor.utils.num_fw_fixed_arguments(
num_example_inputs, len(example_inputs)
)
user_visible_outputs = {}
if config.keep_output_stride:
@@ -1379,7 +1398,7 @@
return inner_compile(
model,
example_inputs,
-num_fixed=fixed,
+static_input_idxs=get_static_input_idxs(fixed),
cudagraphs=cudagraphs,
graph_id=graph_id,
is_inference=is_inference,
@@ -1423,7 +1442,7 @@
return inner_compile(
model,
example_inputs,
-num_fixed=fixed,
+static_input_idxs=list(range(fixed)),
cudagraphs=cudagraphs,
is_backward=True,
graph_id=graph_id,

View File

@@ -753,6 +753,11 @@ class CUDAGraphNode:
self.device = device_index
self.stack_traces = stack_traces
self.stream = stream

+# If we are inlining builtin nn modules we will re-record if static inputs change
+# if not we should error because dynamo should have recompiled in this case
+self.rerecord_if_static_inputs_change = (
+torch._dynamo.config.inline_inbuilt_nn_modules
+)
# if this is a root parent will be None. use weakref to prevent reference cycle
self._parent = weakref.ref(parent) if parent is not None else None
@@ -952,8 +957,13 @@
def check_static_inputs_are_stable(self, new_inputs):
# avoid checking managed tensor static points since we already checked those in check_invariants
-if not torch._C._tensors_data_ptrs_at_indices_equal(
-new_inputs, self.static_input_data_ptrs, self.non_managed_static_input_idxs
+if (
+not self.rerecord_if_static_inputs_change
+and not torch._C._tensors_data_ptrs_at_indices_equal(
+new_inputs,
+self.static_input_data_ptrs,
+self.non_managed_static_input_idxs,
+)
):
# this should error
static_tensors = [new_inputs[i] for i in self.non_managed_static_input_idxs]
@@ -1000,6 +1010,9 @@
if config.triton.force_cudagraph_sync:
torch.cuda.synchronize()

+# Reset this to run the check in the future
+self.static_inputs_stable = False
return outputs
def reconstruct_outputs(self):
@@ -1553,8 +1566,8 @@
def check_invariants(self, inputs: List[Tensor]) -> bool:
"""
-Checks if this node can be run. The same pattern of tensor liveness and tensors
-managed in the cudagraph private pool must remain stable.
+Checks if this node can be run. The same pattern of tensor liveness, static inputs,
+and tensors managed in the cudagraph private pool must remain stable.
"""
# previously managed data pointers remain stable
@@ -1565,6 +1578,18 @@
):
return False

+# static input data pointers should remain stable
+# if we are inlining builtin nn modules we re-record in this case
+# if we are not inlining builtin nn modules, we check this in check_static_inputs_are_stable
+# and error if they are not stable
+if (
+self.rerecord_if_static_inputs_change
+and not torch._C._tensors_data_ptrs_at_indices_equal(
+inputs, self.static_input_data_ptrs, self.static_input_idxs
+)
+):
+return False
if not self._check_liveness(
self.expected_dead_indices_before_graph, self.path_weakrefs
):
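
Net effect of the cudagraph_trees changes: a static input pointer mismatch is handled in one of two ways depending on whether built-in nn modules are inlined. If they are, check_invariants simply fails and the tree records another graph; if not, check_static_inputs_are_stable still raises, since Dynamo should have recompiled instead. A condensed sketch of that policy (simplified standalone function, not the real class):

def dispatch_on_static_inputs(ptrs_unchanged: bool, rerecord_if_static_inputs_change: bool) -> str:
    if ptrs_unchanged:
        return "replay the existing recording"
    if rerecord_if_static_inputs_change:
        # inline_inbuilt_nn_modules=True: invariants fail, so a sibling graph is recorded
        return "record a new graph"
    # inline_inbuilt_nn_modules=False: a moved static input indicates a bug upstream
    raise RuntimeError("static input data pointer changed")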

View File

@@ -143,15 +143,16 @@ class BoxedDeviceIndex:
def check_for_mutation_ignore_cuda_graph_managed_tensor(
-gm: torch.fx.GraphModule, compiled_graph, num_fixed: int
+gm: torch.fx.GraphModule, compiled_graph, static_input_idxs: List[int]
) -> Optional[str]:
default_msg = format_default_skip_message("mutated inputs")

# doesnt work for non-trees because the warmup run would apply mutation twice
if torch._inductor.config.triton.cudagraph_trees:
+unique_idxs = set(static_input_idxs)
# checking if mutation is only on parameters/static inputs
mutation_indices = [
-idx for idx in compiled_graph.mutated_input_idxs if idx >= num_fixed
+idx for idx in compiled_graph.mutated_input_idxs if idx not in unique_idxs
]
has_mutation = len(mutation_indices) != 0
if not has_mutation:
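
The mutation filter above switches from a positional cutoff (idx >= num_fixed) to set membership, so static parameter inputs are ignored even when they are not a leading prefix of the input list. A toy illustration (made-up index values, not from the diff):

mutated_input_idxs = [0, 2, 4]
static_input_idxs = [0, 1, 4]  # e.g. num_fixed=2 plus a parameter at index 4

old_style = [i for i in mutated_input_idxs if i >= 2]                            # [2, 4]
new_style = [i for i in mutated_input_idxs if i not in set(static_input_idxs)]  # [2]
print(old_style, new_style)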

View File

@@ -296,7 +296,6 @@ class GraphLowering(torch.fx.Interpreter):
gm: torch.fx.GraphModule,
example_inputs: Optional[List[torch.Tensor]] = None,
shape_env=None,
-num_static_inputs=None,
graph_id=None,
cpp_wrapper=False,
aot_mode=False,
@@ -311,7 +310,6 @@ class GraphLowering(torch.fx.Interpreter):
name=None,
):
super().__init__(gm)
self.example_inputs = example_inputs
self.layout_opt = (
layout_opt
@@ -374,7 +372,6 @@
Callable[[List[ir.ExternKernelNode]], Any]
] = extern_node_serializer
self.current_node: torch.fx.Node = None # type: ignore[assignment]
-self.num_static_inputs = num_static_inputs
self.lists: Dict[str, List[str]] = {}
self.mutated_inputs: Set[str] = set()
self.mutated_input_idxs: List[int] = []