Compare commits

...

2 Commits

Author SHA1 Message Date
af79feb29d Merge branch 'main' into bf/cudagraph-partition 2025-01-09 09:51:28 -08:00
71f2d29ebd init 2024-12-09 10:10:50 -08:00
3 changed files with 57 additions and 7 deletions

View File

@ -647,7 +647,7 @@ class PythonWrapperCodegen(CodeGen):
Generate outer wrapper in Python that calls the kernels.
"""
def __init__(self):
def __init__(self, is_subgraph=False):
super().__init__()
self._names_iter: Iterator[int] = count()
self.imports = IndentedBuffer()
@ -687,6 +687,8 @@ class PythonWrapperCodegen(CodeGen):
self.codegened_graph_stack = []
self.computed_sizes_stack = []
self.is_subgraph = is_subgraph
# breakpoint()
self.write_header()
self.write_prefix()
self.write_kernel_autotune_defs_header()
@ -731,6 +733,7 @@ class PythonWrapperCodegen(CodeGen):
def create(
is_subgraph: bool, subgraph_name: str, parent_wrapper: PythonWrapperCodegen
):
# breakpoint()
if is_subgraph:
return SubgraphPythonWrapperCodegen(subgraph_name, parent_wrapper)
return PythonWrapperCodegen()
@ -875,6 +878,7 @@ class PythonWrapperCodegen(CodeGen):
return
def codegen_input_size_asserts(self) -> None:
# breakpoint()
for name, buf in V.graph.graph_inputs.items():
if isinstance(buf, sympy.Expr):
continue
@ -887,6 +891,7 @@ class PythonWrapperCodegen(CodeGen):
self.prefix.writeline(f"assert_size_stride({name}, {size}, {stride})")
def codegen_input_nan_asserts(self) -> None:
# breakpoint()
self.prefix.writeline("# make sure graph inputs are not nan/inf")
for name, buf in V.graph.graph_inputs.items():
if isinstance(buf, sympy.Expr):
@ -906,7 +911,7 @@ class PythonWrapperCodegen(CodeGen):
"""
)
def write_prefix(self) -> None:
def write_prefix_for_subgraph(self) -> None:
assert self.launcher_fn_name is not None
self.write_async_compile_wait()
self.prefix.splice(
@ -932,6 +937,43 @@ class PythonWrapperCodegen(CodeGen):
self.codegen_inputs()
self.codegen_input_size_and_nan_asserts()
def write_prefix(self) -> None:
# breakpoint()
if self.is_subgraph:
self.write_prefix_for_subgraph()
return
assert self.launcher_fn_name is not None
self.write_async_compile_wait()
self.prefix.splice(
"""
class Runner:
def __init__(self):
pass
def recursively_apply_fns(self, fns):
new_callables = []
for fn, c in zip(fns, self.callables):
new_callables.append(fn(c))
self.callables = new_callables
def call(self, args):
"""
)
with self.prefix.indent(2):
if config.triton.debug_sync_graph:
self.prefix.writeline(V.graph.device_ops.synchronize())
if V.graph.graph_inputs:
lhs = ", ".join(V.graph.graph_input_names)
if len(V.graph.graph_input_names) == 1:
lhs += ","
self.prefix.writeline(f"{lhs} = args")
self.prefix.writeline("args.clear()")
self.codegen_inputs(self.prefix, V.graph.graph_inputs)
self.codegen_input_size_and_nan_asserts()
def codegen_input_size_and_nan_asserts(self) -> None:
if config.size_asserts:
self.codegen_input_size_asserts()
@ -1187,7 +1229,7 @@ class PythonWrapperCodegen(CodeGen):
if config.triton.store_cubin and not config.triton.autotune_at_compile_time:
self.generate_reset_kernel_saved_flags()
breakpoint()
for line in self.lines:
if isinstance(line, WrapperLine):
line.codegen(self.wrapper_call)
@ -1217,12 +1259,19 @@ class PythonWrapperCodegen(CodeGen):
self.finalize_prefix()
result.splice(self.prefix)
with result.indent():
with result.indent(2):
result.splice(self.wrapper_call)
self.generate_before_suffix(result)
result.splice(self.suffix)
if not self.is_subgraph:
result.splice(
"""
runner = Runner()
call = runner.call
recursively_apply_fns = runner.recursively_apply_fns
"""
)
self.generate_end(result)
self.add_benchmark_harness(result)
@ -2510,7 +2559,7 @@ class SubgraphPythonWrapperCodegen(PythonWrapperCodegen):
# because __init__ calls set_launcher_fn_name
self.subgraph_name = subgraph_name
self.parent_wrapper = parent_wrapper
super().__init__()
super().__init__(is_subgraph=True)
def set_launcher_fn_name(self) -> None:
# This sets up the name of the function containing the launcher code of

View File

@ -984,6 +984,7 @@ class _InProcessFxCompile(FxCompile):
)
metrics_helper = metrics.CachedMetricsHelper()
with V.set_graph_handler(graph):
breakpoint()
graph.run(*example_inputs)
output_strides: List[Optional[tuple[_StrideExprStr, ...]]] = []
if graph.graph_outputs is not None:

View File

@ -3734,7 +3734,7 @@ class Scheduler:
V.graph.wrapper_code.codegen_device_guard_enter(device.index)
self.buffer_names_to_free.update(node.last_usage)
# breakpoint()
if node.is_template():
prologue, template_node, epilogue = node.get_prologue_template_epilogue(
list(node.get_nodes())