Compare commits

...

4 Commits

SHA1 Message Date
9babde961f Initial test 2024-05-20 18:58:12 -07:00
642981817e Piping metadata 2024-05-15 11:46:59 -07:00
8b1e61d653 Initial cudagraph tree dispatch impl 2024-05-15 10:39:14 -07:00
a3a1aa12da [DONT MERGE][Dynamo] Inline inbuilt nn modules 2024-05-15 10:39:14 -07:00
    ghstack-source-id: 65033f535dd1cb2421a79aba8b7b0faab19d3d02
    Pull Request resolved: https://github.com/pytorch/pytorch/pull/123986
11 changed files with 110 additions and 31 deletions

View File

@ -364,22 +364,13 @@ def make_test(
scheduler_eager.last_epoch = 1
with torch.set_grad_enabled(False):
for i in range(2):
for i in range(5):
compiled_step()
opt_eager.step()
if scheduler_cls:
call_scheduler(scheduler_eager)
call_scheduler(scheduler_compiled)
check_optim(
self,
optim_cls,
model_eager.parameters(),
model_compiled.parameters(),
opt_eager.state,
opt_compiled.state,
)
if run_cudagraphs:
self.check_cudagraphs_ran()

View File

@ -53,6 +53,17 @@ requires_multigpu = functools.partial(
from io import StringIO
def _debug_print_cudagraph_tree(manager):
    def visit(node):
        for children in node.children.values():
            for child in children:
                visit(child)
                print(f"parent: {node}, child: {child}")

    for node in manager.get_roots():
        visit(node)
def get_compile_fn(backend):
if backend == "cudagraphs":
return functools.partial(torch.compile, backend="cudagraphs")
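Note: the _debug_print_cudagraph_tree helper added above walks the recorded cudagraph tree and prints each parent/child edge. A minimal usage sketch (not part of the diff), assuming the helper is in scope and fetching the manager the same way cudagraphify does later in this changeset:

import torch
from torch._inductor.cudagraph_trees import get_container

@torch.compile(mode="reduce-overhead")
def f(x):
    return x * 2

# a few runs so the tree has something recorded
for _ in range(3):
    f(torch.ones(4, device="cuda"))

manager = get_container(torch.cuda.current_device()).get_tree_manager()
_debug_print_cudagraph_tree(manager)  # helper from the hunk above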
@ -648,7 +659,9 @@ if HAS_CUDA and not TEST_WITH_ASAN:
with mode:
inps = [torch.rand([6, 5], device="cuda")[1:] for _ in range(2)]
compiled_f = compile_fx_inner(mod, inps, num_fixed=1, cudagraphs=True)
compiled_f = compile_fx_inner(
    mod, inps, static_input_idxs=[0], cudagraphs=True
)
def get_unaligned_inputs():
return [torch.rand([6, 5], device="cuda")[1:] for _ in range(2)]
@ -1704,6 +1717,29 @@ if HAS_CUDA and not TEST_WITH_ASAN:
with self.assertRaisesRegex(Exception, "custom error msg"):
device = x.untyped_storage()
@torch._inductor.config.patch("triton.cudagraphs", True)
def test_multiple_dispatch_single_graph(self):
    # Verify that we can record multiple cudagraphs for a single
    # compiled function
    torch.set_default_device("cuda")

    @torch.compile(mode="reduce-overhead")
    def fn(x, y):
        return x * y

    p1 = torch.nn.Parameter(torch.ones([2, 2]))
    p2 = torch.nn.Parameter(torch.zeros([2, 2]))
    for _ in range(5):
        res1 = fn(torch.ones(2, 2), p1)
        res1.sum().backward()

    for _ in range(5):
        res2 = fn(torch.ones(2, 2), p2)
        res2.sum().backward()

    # Fwd + bwd graphs for each version of the function => 4 graphs
    self.assertEqual(self.get_manager().new_graph_id().id, 4)
instantiate_parametrized_tests(CudaGraphTreeTests)
if __name__ == "__main__":
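Note: the expected count in test_multiple_dispatch_single_graph follows from the two Parameters living at different addresses, so the graph recorded against p1 cannot be replayed for p2 and is re-recorded instead of copied into. A standalone sketch of that arithmetic (illustrative, not code from the diff):

import torch

p1 = torch.nn.Parameter(torch.ones([2, 2], device="cuda"))
p2 = torch.nn.Parameter(torch.zeros([2, 2], device="cuda"))
assert p1.data_ptr() != p2.data_ptr()  # distinct static input addresses

# forward + backward recorded for p1 -> 2 graphs
# forward + backward recorded for p2 -> 2 more graphs (re-recorded)
expected_graphs = 2 * 2
assert expected_graphs == 4  # matches new_graph_id().id in the test above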

View File

@ -137,3 +137,4 @@ def install_generation_tagging_init():
Module.___needs_generation_tag_patch = False # type: ignore[attr-defined]
GenerationTracker.generation += 1
print(f"incrementing generation: {GenerationTracker.generation}")

View File

@ -1198,10 +1198,14 @@ class VariableBuilder:
or get_static_address_type(value) is not None
) and not source.guard_source().is_fsdp_module():
self.assert_not_wrapped_by_this_graph(value)
print("ATTR")
print(source)
return self.tx.output.register_attr_or_module(
value, self.name, source=source
)
print(source)
if is_constant_source(source):
self.assert_not_wrapped_by_this_graph(value)
return self.tx.output.register_attr_or_module(

View File

@ -665,6 +665,10 @@ from a multi-output view call"
)
user_outs = pytree.tree_map(from_fun, f_output_tangents)
static_parameter_input_indices = [
i for i, arg in enumerate(flat_args) if isinstance(arg, torch.nn.Parameter)
]
f_mutated_inputs = [
inp
for inp, info in zip(flat_f_args, input_info)
@ -716,6 +720,7 @@ from a multi-output view call"
subclass_tangent_meta=create_subclass_meta(traced_tangents),
is_train=is_train,
grad_enabled_mutation=grad_enabled_mutation,
static_parameter_indices=static_parameter_input_indices,
tokens=mode._tokens,
)
return metadata
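Note: to make the new metadata concrete, here is a toy illustration of what static_parameter_input_indices ends up holding for a flat argument list (values are made up, not taken from the diff):

import torch

flat_args = [
    torch.ones(2),                       # ordinary input  -> not static
    torch.nn.Parameter(torch.zeros(2)),  # parameter       -> static
    3,                                   # non-tensor arg  -> not static
    torch.nn.Parameter(torch.ones(2)),   # parameter       -> static
]
static_parameter_input_indices = [
    i for i, arg in enumerate(flat_args) if isinstance(arg, torch.nn.Parameter)
]
assert static_parameter_input_indices == [1, 3]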

View File

@ -304,6 +304,9 @@ class ViewAndMutationMeta:
# raised
deterministic: Optional[bool] = None
# Keeps track of which input indices store parameters (which we will treat as static)
static_parameter_indices: List[int] = field(default_factory=list)
# Map of effect type (ex. _EffectType.ORDERED) to token. If there are
# side-effectful operators, FunctionalTensorMode will populate this
# dictionary telling us how many tokens we will need during tracing.

View File

@ -834,6 +834,7 @@ def aot_module(mod: nn.Module, *args, **kwargs) -> nn.Module:
:attr:`mod`, but with forward and backward graph compiled.
"""
breakpoint()
# See Note: [Fake Modules and AOTAutograd]
torch._dynamo.utils.assert_no_fake_params_or_buffers(mod)

View File

@ -118,6 +118,19 @@ def complex_memory_overlap(t: torch.Tensor) -> bool:
return False
def get_static_input_idxs(num_fixed):
    # If we are inlining NNModules, we treat all torch.nn.Parameters as static for the purposes
    # of cudagraphs. Rather than copying these into cudagraph-owned memory
    # like we do for normal inputs on each run, we will re-record a cudagraph if these
    # parameter locations change.
    context = torch._guards.TracingContext.try_get()
    return (
        list(range(num_fixed)) + context.fw_metadata.static_parameter_indices
        if context
        else []
    )
@functools.lru_cache(None)
def _step_logger():
return dynamo_logging.get_step_logger(log)
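Note: a small worked example of what get_static_input_idxs returns, using stand-in objects for TracingContext and its fw_metadata (hypothetical, illustration only):

class _FakeMetadata:
    static_parameter_indices = [3, 5]

class _FakeContext:
    fw_metadata = _FakeMetadata()

def _static_input_idxs(num_fixed, context):
    # mirrors the body of get_static_input_idxs above
    return (
        list(range(num_fixed)) + context.fw_metadata.static_parameter_indices
        if context
        else []
    )

assert _static_input_idxs(2, _FakeContext()) == [0, 1, 3, 5]
assert _static_input_idxs(2, None) == []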
@ -342,7 +355,6 @@ def maybe_disable_comprehensive_padding(example_inputs: List[torch.Tensor]):
def count_bytes_inner(
gm: torch.fx.GraphModule,
example_inputs: List[torch.Tensor],
num_fixed: int = 0,
**kwargs,
):
shape_env = _shape_env_from_inputs(example_inputs)
@ -351,7 +363,7 @@ def count_bytes_inner(
with V.set_fake_mode(fake_mode):
_recursive_post_grad_passes(gm, False)
graph = GraphLowering(gm, shape_env=shape_env, num_static_inputs=num_fixed)
graph = GraphLowering(gm, shape_env=shape_env)
with V.set_graph_handler(graph), V.set_real_inputs(
example_inputs
), maybe_disable_comprehensive_padding(example_inputs):
@ -438,7 +450,7 @@ def compile_fx_inner(
gm: torch.fx.GraphModule,
example_inputs: List[torch.Tensor],
cudagraphs: Optional[BoxedBool] = None,
num_fixed: int = 0,
static_input_idxs: Optional[List[int]] = None,
is_backward: bool = False,
graph_id: Optional[int] = None,
cpp_wrapper: bool = False,
@ -449,6 +461,7 @@ def compile_fx_inner(
layout_opt: Optional[bool] = None,
extern_node_serializer: Optional[Callable[[List[ExternKernelNode]], Any]] = None,
) -> Union[CompiledFxGraph, str]:
print(static_input_idxs)
"""
Inductor API that compiles a single graph.
@ -472,7 +485,7 @@ def compile_fx_inner(
gm,
example_inputs,
cudagraphs=cudagraphs,
num_fixed=num_fixed,
static_input_idxs=static_input_idxs,
is_backward=is_backward,
graph_id=graph_id,
cpp_wrapper=cpp_wrapper,
@ -491,7 +504,7 @@ def compile_fx_inner(
# of fx_codegen_and_compile changes, the dict should be updated accordingly
graph_kwargs = {
"cudagraphs": cudagraphs,
"num_fixed": num_fixed,
"static_input_idxs": static_input_idxs,
"is_backward": is_backward,
"graph_id": graph_id,
"cpp_wrapper": cpp_wrapper,
@ -564,7 +577,7 @@ def compile_fx_inner(
)
has_mutation_str = check_for_mutation_ignore_cuda_graph_managed_tensor(
gm, compiled_graph, num_fixed
gm, compiled_graph, static_input_idxs
)
has_mutation = has_mutation_str is not None
@ -587,6 +600,7 @@ def compile_fx_inner(
]
cudagraph_fail_reasons = [s for b, s in cudagraph_tests if not b]
print("--------------------------------------")
if not cudagraph_fail_reasons:
if not config.triton.cudagraph_trees:
# Force specialize all inputs so that CUDA graphs will work
@ -604,7 +618,7 @@ def compile_fx_inner(
compiled_graph.current_callable = cudagraphify(
compiled_graph.current_callable,
example_inputs,
static_input_idxs=range(num_fixed),
static_input_idxs=static_input_idxs,
device_index=next(iter(compiled_graph.device_idxs)),
stack_traces=stack_traces,
is_backward=is_backward,
@ -651,7 +665,7 @@ def compile_fx_inner(
# cudagraphs does its own aligning of inputs
if not cudagraphs:
new_callable = align_inputs(
compiled_graph.current_callable, example_inputs, range(num_fixed)
compiled_graph.current_callable, example_inputs, static_input_idxs
)
if new_callable is not compiled_graph.current_callable:
compiled_graph.current_callable = new_callable
@ -673,7 +687,7 @@ def fx_codegen_and_compile(
gm: torch.fx.GraphModule,
example_inputs: List[torch.Tensor],
cudagraphs: Optional[BoxedBool] = None,
num_fixed: int = 0,
static_input_idxs: Optional[List[int]] = None,
is_backward: bool = False,
graph_id: Optional[int] = None,
cpp_wrapper: bool = False,
@ -762,7 +776,6 @@ def fx_codegen_and_compile(
const_gm,
example_inputs=[],
shape_env=shape_env,
num_static_inputs=num_fixed,
graph_id=graph_id,
cpp_wrapper=cpp_wrapper,
aot_mode=aot_mode,
@ -784,7 +797,6 @@ def fx_codegen_and_compile(
# we currently use fake tensors and defake them later.
example_inputs=example_inputs,
shape_env=shape_env,
num_static_inputs=num_fixed,
graph_id=graph_id,
cpp_wrapper=cpp_wrapper,
aot_mode=aot_mode,
@ -1198,6 +1210,7 @@ def fw_compiler_freezing(
n.name for n in model_outputs if isinstance(n, torch.fx.Node)
)
static_input_idxs = list(range(num_fixed))
# constant params will be real tensors, not fake
tracing_context = torch._guards.TracingContext.try_get()
if tracing_context is not None:
@ -1207,11 +1220,13 @@ def fw_compiler_freezing(
if i not in preserved_arg_indices:
params_flat[i] = None
static_input_idxs += tracing_context.fw_metadata.static_parameter_indices
with mock.patch.object(fake_mode, "allow_non_fake_inputs", True):
optimized_function = inner_compile(
opt_model,
aot_example_inputs,
num_fixed=num_fixed,
static_input_idxs=static_input_idxs,
cudagraphs=cudagraphs,
graph_id=graph_id,
is_inference=True,
@ -1342,7 +1357,8 @@ def compile_fx(
fixed = torch._inductor.utils.num_fw_fixed_arguments(
num_example_inputs, len(example_inputs)
)
user_visible_outputs = {}
user_visible_outputs = set()
if config.keep_output_stride:
model_outputs_node = output_node(model)
@ -1397,7 +1413,7 @@ def compile_fx(
return inner_compile(
model,
example_inputs,
num_fixed=fixed,
static_input_idxs=get_static_input_idxs(fixed),
cudagraphs=cudagraphs,
graph_id=graph_id,
is_inference=is_inference,
@ -1441,7 +1457,7 @@ def compile_fx(
return inner_compile(
model,
example_inputs,
num_fixed=fixed,
static_input_idxs=list(range(fixed)),
cudagraphs=cudagraphs,
is_backward=True,
graph_id=graph_id,
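Note on the two call sites above: the forward compilation now passes get_static_input_idxs(fixed) (the fixed prefix plus parameter indices), while the backward keeps the old meaning of num_fixed by passing list(range(fixed)). A sketch of the resulting index lists, with made-up numbers:

fixed = 2                    # e.g. num_fw_fixed_arguments(...)
parameter_indices = [3, 5]   # e.g. fw_metadata.static_parameter_indices

forward_static_idxs = list(range(fixed)) + parameter_indices
backward_static_idxs = list(range(fixed))

assert forward_static_idxs == [0, 1, 3, 5]
assert backward_static_idxs == [0, 1]  # equivalent to the old num_fixed=2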

View File

@ -354,6 +354,9 @@ def cudagraphify_impl(model, inputs, static_input_idxs, *args, **kwargs):
def deferred_cudagraphify(inputs):
int_key = get_ints(inputs)
fn = fn_cache.get(int_key)
print("CACHED")
print(int_key)
print(fn)
if fn is not None:
return fn(inputs)
@ -364,6 +367,7 @@ def cudagraphify_impl(model, inputs, static_input_idxs, *args, **kwargs):
# first get indices we need to check to align, then update our static inputs,
# and finally copy
print(static_input_idxs)
check_input_idxs = get_input_idxs_to_check(inputs, static_input_idxs)
new_static_input_idxs = remove_unaligned_input_idxs(inputs, static_input_idxs)
copy_misaligned_inputs(inputs, check_input_idxs)
@ -392,6 +396,7 @@ def cudagraphify(
):
manager = get_container(device_index).get_tree_manager()
assert not (is_backward and is_inference)
print(f"inference:{is_inference}")
mode = (
CompilationMode.BACKWARD
if is_backward
@ -643,6 +648,7 @@ class CUDAWarmupNode:
out_refs = list(self.path_live_weakrefs())
check_memory_pool(self.device_index, self.cuda_graphs_pool, out_refs)
print("ran warmed up node")
return out
@property
@ -1541,6 +1547,8 @@ class CUDAGraphNode:
Checks if this node can be run. The same pattern of tensor liveness and tensors
managed in the cudagraph private pool must remain stable.
"""
# breakpoint()
print(self.static_input_data_ptrs)
# previously managed data pointers remain stable
# this is on the hot path so moved to C++. equivalent to:
@ -1550,6 +1558,12 @@ class CUDAGraphNode:
):
return False
# static input data pointers remain stable
if not torch._C._tensors_data_ptrs_at_indices_equal(
inputs, self.static_input_data_ptrs, self.static_input_idxs
):
return False
if not self._check_liveness(
self.expected_dead_indices_before_graph, self.path_weakrefs
):
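Note: the block added above rejects a recorded node when any static input's data pointer has moved, which is what triggers re-recording for a new parameter set; torch._C._tensors_data_ptrs_at_indices_equal performs the check in C++ on the hot path. A rough pure-Python equivalent, assuming data_ptrs is indexed by input position with None for entries that are not checked (a sketch, not the diff's code):

from typing import List, Optional
import torch

def tensors_data_ptrs_at_indices_equal(
    tensors: List[torch.Tensor],
    data_ptrs: List[Optional[int]],
    idxs: List[int],
) -> bool:
    # Reusable only if every checked static input still lives at the address
    # captured when the graph was recorded.
    return all(
        data_ptrs[i] is None or tensors[i].data_ptr() == data_ptrs[i]
        for i in idxs
    )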
@ -1848,6 +1862,8 @@ class CUDAGraphTreeManager:
raise RuntimeError(f"Unknown node type {type(self.current_node)}")
def _run(self, new_inputs: List[Tensor], function_id: FunctionID):
print(self.path_state)
print(f"current_gen: {self.current_gen}")
# we will try to end the current execution lazily, since
# we dont want to do unnecessary checking of the existing outputs
# on the hot path, but both recording and warmup only happen once
@ -1873,6 +1889,7 @@ class CUDAGraphTreeManager:
# then warm up graph B and make more allocations, the subsequent recording of A will not
# necessarily use the same addresses as in the warm up. Thus any warm up of a node can only
# be followed by warm up runs.
print(self.warmed_up_functions)
if (
(
not (
@ -1888,6 +1905,7 @@ class CUDAGraphTreeManager:
if self.path_state == ExecutionState.EXECUTION:
self.apply_checkpoint_execution_state_in_allocator()
print("running eager")
return self.run_eager(new_inputs, function_id)
child_nodes = (
@ -1901,6 +1919,8 @@ class CUDAGraphTreeManager:
# and other
if child.check_invariants(new_inputs):
return self.execute_node(child, new_inputs)
else:
print("NO CHILD MATCH")
# now that we know the new function can't be run as a child of the
# current node, if it is a root, try to end the current execution.
@ -2040,6 +2060,7 @@ class CUDAGraphTreeManager:
placeholders,
mutated_input_idxs,
)
print(self.ids_to_funcs)
self.id_to_mode[id] = mode
fn = functools.partial(self.run, function_id=id)
@ -2135,6 +2156,7 @@ class CUDAGraphTreeManager:
self.clear_current_path_state_and_set_to_none()
def try_end_curr_warmup(self, function_id: FunctionID):
# breakpoint()
if self.can_start_new_generation():
self.dealloc_current_path_weakrefs()
self.current_node = None

View File

@ -143,15 +143,18 @@ class BoxedDeviceIndex:
def check_for_mutation_ignore_cuda_graph_managed_tensor(
gm: torch.fx.GraphModule, compiled_graph, num_fixed: int
gm: torch.fx.GraphModule, compiled_graph, static_input_idxs: List[int]
) -> Optional[str]:
default_msg = format_default_skip_message("mutated inputs")
# doesnt work for non-trees because the warmup run would apply mutation twice
if torch._inductor.config.triton.cudagraph_trees:
static_input_idxs = set(static_input_idxs)
# checking if mutation is only on parameters/static inputs
mutation_indices = [
idx for idx in compiled_graph.mutated_input_idxs if idx >= num_fixed
idx
for idx in compiled_graph.mutated_input_idxs
if idx not in static_input_idxs
]
has_mutation = len(mutation_indices) != 0
if not has_mutation:
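Note: a tiny worked example of the filtering change above: mutation of a static (parameter) input no longer disables cudagraphs, only mutation of a non-static input does (indices below are made up):

static_input_idxs = {0, 1, 3}   # fixed inputs plus parameter inputs
mutated_input_idxs = [1, 2, 3]  # as reported on the compiled graph

# the old check kept any idx >= num_fixed; the new check keeps only indices
# outside the static set
mutation_indices = [idx for idx in mutated_input_idxs if idx not in static_input_idxs]
assert mutation_indices == [2]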

View File

@ -296,7 +296,6 @@ class GraphLowering(torch.fx.Interpreter):
gm: torch.fx.GraphModule,
example_inputs: Optional[List[torch.Tensor]] = None,
shape_env=None,
num_static_inputs=None,
graph_id=None,
cpp_wrapper=False,
aot_mode=False,
@ -311,7 +310,6 @@ class GraphLowering(torch.fx.Interpreter):
name=None,
):
super().__init__(gm)
self.example_inputs = example_inputs
self.layout_opt = (
layout_opt
@ -374,7 +372,6 @@ class GraphLowering(torch.fx.Interpreter):
Callable[[List[ir.ExternKernelNode]], Any]
] = extern_node_serializer
self.current_node: torch.fx.Node = None # type: ignore[assignment]
self.num_static_inputs = num_static_inputs
self.lists: Dict[str, List[str]] = {}
self.mutated_inputs: Set[str] = set()
self.mutated_input_idxs: List[int] = []