Compare commits

...

1 Commit

SHA1        Message         Date
a9cb7a187b  Initial commit  2024-05-15 17:07:53 -07:00
5 changed files with 56 additions and 10 deletions

View File

@@ -364,22 +364,13 @@ def make_test(
scheduler_eager.last_epoch = 1
with torch.set_grad_enabled(False):
for i in range(2):
for i in range(5):
compiled_step()
opt_eager.step()
if scheduler_cls:
call_scheduler(scheduler_eager)
call_scheduler(scheduler_compiled)
check_optim(
self,
optim_cls,
model_eager.parameters(),
model_compiled.parameters(),
opt_eager.state,
opt_compiled.state,
)
if run_cudagraphs:
self.check_cudagraphs_ran()

View File

@@ -1704,6 +1704,43 @@ if HAS_CUDA and not TEST_WITH_ASAN:
with self.assertRaisesRegex(Exception, "custom error msg"):
device = x.untyped_storage()
@torch._inductor.config.patch("triton.cudagraphs", True)
def test_multiple_dispatch(self):
torch.set_default_device("cuda")
@torch.compile
def fn(x, y):
return x * y
p1 = torch.nn.Parameter(torch.ones([2, 2]))
p2 = torch.nn.Parameter(torch.zeros([2, 2]))
torch._dynamo.decorators.mark_static_address(p1)
# res1 = fn(torch.ones(2, 2), torch.ones(2, 2))
# res1 = fn(torch.ones(2, 2), torch.ones(2, 2))
# res1 = fn(torch.ones(2, 2), torch.ones(2, 2))
print("start call 1")
res1 = fn(torch.ones(2, 2), p1)
print("end call 1")
print("start call 2")
res1 = fn(torch.ones(2, 2), p1)
print("end call 2")
print("start call 3")
res1 = fn(torch.ones(2, 2), p1)
print("end call 3")
print("start call 4")
res1 = fn(torch.ones(2, 2), p1)
print("end call 4")
print("start call 5")
res1 = fn(torch.ones(2, 2), p1)
print("end call 5")
# res1 = fn(torch.ones(2, 2), p1)
# res2 = fn(torch.ones(2, 2), p2)
# res2 = fn(torch.ones(2, 2), p2)
# res2 = fn(torch.ones(2, 2), p2)
# self.assertNotEqual(res1, res2)
self.assertEqual(self.get_manager().new_graph_id().id, 2)
instantiate_parametrized_tests(CudaGraphTreeTests)
if __name__ == "__main__":

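Stripped of the print scaffolding and the repeated calls, the new test above appears to reduce to roughly the sketch below. Treating the commented-out p2 call and the assertNotEqual as part of the intended final test is an assumption; in this commit only the five p1 calls and the graph-id assertion are active.

    # Sketch only: the apparent intent of test_multiple_dispatch once the
    # debugging prints are removed (the p2 path is still commented out above).
    @torch._inductor.config.patch("triton.cudagraphs", True)
    def test_multiple_dispatch(self):
        torch.set_default_device("cuda")

        @torch.compile
        def fn(x, y):
            return x * y

        p1 = torch.nn.Parameter(torch.ones([2, 2]))
        p2 = torch.nn.Parameter(torch.zeros([2, 2]))
        # Mark p1 so cudagraph trees treat its storage address as static.
        torch._dynamo.decorators.mark_static_address(p1)

        res1 = fn(torch.ones(2, 2), p1)  # static-address parameter path
        res2 = fn(torch.ones(2, 2), p2)  # assumed second dispatch path
        self.assertNotEqual(res1, res2)
        # One recorded graph per dispatch path is expected.
        self.assertEqual(self.get_manager().new_graph_id().id, 2)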
View File

@@ -1197,6 +1197,7 @@ class VariableBuilder:
source.guard_source().is_nn_module()
or get_static_address_type(value) is not None
) and not source.guard_source().is_fsdp_module():
print("static")
self.assert_not_wrapped_by_this_graph(value)
return self.tx.output.register_attr_or_module(
value, self.name, source=source

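The instrumented branch is the one that makes a tensor a graph attribute rather than a graph input: it matches when the guard source is an nn.Module, or when get_static_address_type(value) is not None, and the value is not under FSDP. A parameter marked for static addressing takes this path; a minimal usage sketch, assuming the mark_static_address decorator used in the test above is what sets that static address type:

    import torch
    from torch._dynamo.decorators import mark_static_address

    p = torch.nn.Parameter(torch.ones(2, 2, device="cuda"))
    mark_static_address(p)  # assumption: this is what get_static_address_type detects

    @torch.compile
    def fn(x):
        return x * p  # p is routed through register_attr_or_module, not wrapped as an input

    fn(torch.ones(2, 2, device="cuda"))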
View File

@@ -587,6 +587,7 @@ def compile_fx_inner(
]
cudagraph_fail_reasons = [s for b, s in cudagraph_tests if not b]
print("--------------------------------------")
if not cudagraph_fail_reasons:
if not config.triton.cudagraph_trees:
# Force specialize all inputs so that CUDA graphs will work
@@ -958,6 +959,8 @@ def cudagraphify(
cudagraphify_fn: Callable[..., Any]
if config.triton.cudagraph_trees:
print("HI---------------------")
print(static_input_idxs)
cudagraphify_fn = functools.partial(
new_cudagraphify_impl,
device_index=device_index,

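Both prints added in this file expose static_input_idxs at the point where the cudagraph trees implementation is selected. Conceptually, a static input index marks an argument whose storage address should stay fixed between calls (parameters, or tensors marked with mark_static_address), so a recorded CUDA graph can read it in place; other inputs must be copied into graph-owned buffers before each replay. A rough, self-contained illustration of that split, with made-up tensors:

    import torch

    inputs = [torch.rand(4), torch.nn.Parameter(torch.rand(4))]
    static_input_idxs = [1]  # the parameter's address is expected not to change

    # Non-static inputs have to be copied into the recorded graph's buffers on
    # every call; static ones are read in place.
    copy_idxs = [i for i in range(len(inputs)) if i not in static_input_idxs]
    assert copy_idxs == [0]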
View File

@@ -354,6 +354,9 @@ def cudagraphify_impl(model, inputs, static_input_idxs, *args, **kwargs):
def deferred_cudagraphify(inputs):
int_key = get_ints(inputs)
fn = fn_cache.get(int_key)
print("CACHED")
print(int_key)
print(fn)
if fn is not None:
return fn(inputs)
@@ -364,6 +367,7 @@ def cudagraphify_impl(model, inputs, static_input_idxs, *args, **kwargs):
# first get indices we need to check to align, then update our static inputs,
# and finally copy
print(static_input_idxs)
check_input_idxs = get_input_idxs_to_check(inputs, static_input_idxs)
new_static_input_idxs = remove_unaligned_input_idxs(inputs, static_input_idxs)
copy_misaligned_inputs(inputs, check_input_idxs)
@@ -643,6 +647,7 @@ class CUDAWarmupNode:
out_refs = list(self.path_live_weakrefs())
check_memory_pool(self.device_index, self.cuda_graphs_pool, out_refs)
print("ran warmed up node")
return out
@property
@@ -1848,6 +1853,9 @@ class CUDAGraphTreeManager:
raise RuntimeError(f"Unknown node type {type(self.current_node)}")
def _run(self, new_inputs: List[Tensor], function_id: FunctionID):
print(self.path_state)
print(self.current_gen)
print(self.roots)
# we will try to end the current execution lazily, since
# we dont want to do unnecessary checking of the existing outputs
# on the hot path, but both recording and warmup only happen once
@@ -1873,6 +1881,8 @@ class CUDAGraphTreeManager:
# then warm up graph B and make more allocations, the subsequent recording of A will not
# necessarily use the same addresses as in the warm up. Thus any warm up of a node can only
# be followed by warm up runs.
print(self.warmed_up_functions)
print(self.in_warmup)
if (
(
not (
@@ -1888,6 +1898,7 @@ class CUDAGraphTreeManager:
if self.path_state == ExecutionState.EXECUTION:
self.apply_checkpoint_execution_state_in_allocator()
print("running eager")
return self.run_eager(new_inputs, function_id)
child_nodes = (
@@ -2040,6 +2051,7 @@ class CUDAGraphTreeManager:
placeholders,
mutated_input_idxs,
)
print(self.ids_to_funcs)
self.id_to_mode[id] = mode
fn = functools.partial(self.run, function_id=id)
@@ -2138,10 +2150,12 @@ class CUDAGraphTreeManager:
if self.can_start_new_generation():
self.dealloc_current_path_weakrefs()
self.current_node = None
breakpoint()
return
if self.current_node.all_outputs_are_dead():
self.current_node = None
breakpoint()
return
self.check_warn_on_unable_to_start_executing(function_id)
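Several of the prints in this last file probe the fn_cache.get(int_key) lookup in deferred_cudagraphify: the cudagraphified callable is keyed on the integer-valued inputs, so each distinct integer configuration gets its own recorded graph, and a miss builds and caches a new callable. A stripped-down sketch of that caching pattern; get_int_key and build_graph_fn are made-up stand-ins for the real helpers:

    from typing import Any, Callable, Dict, List, Optional, Tuple

    fn_cache: Dict[Optional[Tuple[int, ...]], Callable] = {}

    def get_int_key(inputs: List[Any]) -> Optional[Tuple[int, ...]]:
        # The integer inputs determine which recorded graph is applicable.
        ints = tuple(i for i in inputs if isinstance(i, int))
        return ints if ints else None

    def build_graph_fn(inputs: List[Any]) -> Callable:
        # Placeholder for the real cudagraphify/record step.
        return lambda inps: list(inps)

    def deferred(inputs: List[Any]):
        int_key = get_int_key(inputs)
        fn = fn_cache.get(int_key)
        if fn is None:
            fn = build_graph_fn(inputs)
            fn_cache[int_key] = fn
        return fn(inputs)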