This is an attempt to fix a memory allocation issue when using `torch.compile` with a custom layernorm kernel in vllm:

```C++
// In-place fused Add and RMS Normalization.
ops.def(
    "fused_add_rms_norm(Tensor! input, Tensor! residual, Tensor weight, "
    "float epsilon) -> ()");
ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm);
```

With this op enabled, we observed abnormal extra memory allocations under `torch.compile`:

<img width="738" alt="{374E9FCF-FB46-4750-8B60-D31E3ADCE00A}" src="https://github.com/user-attachments/assets/6c45e1aa-ccde-4c56-99dc-bf4776d699d5" />

and without this op:

<img width="738" alt="{9BB08EFE-FFE3-4D06-82C0-C70BBE6ADD56}" src="https://github.com/user-attachments/assets/56e2ee43-ab87-492d-834c-69e9cafbb0df" />

After investigation, we found that the compiler treats the buffers for the two mutated inputs, `Tensor input` and `Tensor residual`, as sharing the same dependency list, which prevents it from reusing the buffer of `Tensor input`:

```
buf1.users = [
    NodeUser(node=ExternKernelSchedulerNode(name='op2'), can_inplace=False, is_weak=False),
    NodeUser(node=ExternKernelSchedulerNode(name='op9'), can_inplace=False, is_weak=False),
    NodeUser(node=ExternKernelSchedulerNode(name='op13'), can_inplace=False, is_weak=False),
    NodeUser(node=ExternKernelSchedulerNode(name='op20'), can_inplace=False, is_weak=False),
    NodeUser(node=ExternKernelSchedulerNode(name='op24'), can_inplace=False, is_weak=False),
    NodeUser(node=ExternKernelSchedulerNode(name='op31'), can_inplace=False, is_weak=False),
    NodeUser(node=ExternKernelSchedulerNode(name='op35'), can_inplace=False, is_weak=False),
    NodeUser(node=ExternKernelSchedulerNode(name='op42'), can_inplace=False, is_weak=False),
    NodeUser(node=ExternKernelSchedulerNode(name='op46'), can_inplace=False, is_weak=False),
    NodeUser(node=ExternKernelSchedulerNode(name='op53'), can_inplace=False, is_weak=False),
]
buf16.users = [
    NodeUser(node=ExternKernelSchedulerNode(name='op2'), can_inplace=False, is_weak=False),
    NodeUser(node=ExternKernelSchedulerNode(name='op9'), can_inplace=False, is_weak=False),
    NodeUser(node=ExternKernelSchedulerNode(name='op13'), can_inplace=False, is_weak=False),
    NodeUser(node=ExternKernelSchedulerNode(name='op20'), can_inplace=False, is_weak=False),
    NodeUser(node=ExternKernelSchedulerNode(name='op24'), can_inplace=False, is_weak=False),
    NodeUser(node=ExternKernelSchedulerNode(name='op31'), can_inplace=False, is_weak=False),
    NodeUser(node=ExternKernelSchedulerNode(name='op35'), can_inplace=False, is_weak=False),
    NodeUser(node=ExternKernelSchedulerNode(name='op42'), can_inplace=False, is_weak=False),
    NodeUser(node=ExternKernelSchedulerNode(name='op46'), can_inplace=False, is_weak=False),
    NodeUser(node=ExternKernelSchedulerNode(name='op53'), can_inplace=False, is_weak=False),
]
```

```
op13: ExternKernelSchedulerNode(FallbackKernel)
op13.writes = [
    StarDep(name='buf17', mode=None),
    StarDep(name='buf18', mode=None),
    StarDep(name='buf19', mode=None)]
op13.unmet_dependencies = [
    StarDep(name='buf13', mode=None),
    StarDep(name='buf16', mode=None),
    WeakDep(name='buf11', mutating_buf='buf18'),
    WeakDep(name='buf12', mutating_buf='buf18'),
    WeakDep(name='buf13', mutating_buf='buf18'),
    WeakDep(name='buf2', mutating_buf='buf18'),
    WeakDep(name='buf3', mutating_buf='buf18')]
op13.met_dependencies = [StarDep(name='arg11_1', mode=None)]
op13.outputs = [
    buf17: FallbackKernel
    buf17.layout = NoneLayout(device=device(type='cuda', index=0), size=[0], stride=[0])
    buf17.aliases = ['buf16', 'buf1']
    buf17.users = [
        NodeUser(node=ExternKernelSchedulerNode(name='op2'), can_inplace=False, is_weak=False),
        NodeUser(node=ExternKernelSchedulerNode(name='op9'), can_inplace=False, is_weak=False),
        NodeUser(node=ExternKernelSchedulerNode(name='op13'), can_inplace=False, is_weak=False),
        NodeUser(node=ExternKernelSchedulerNode(name='op20'), can_inplace=False, is_weak=False),
        NodeUser(node=ExternKernelSchedulerNode(name='op24'), can_inplace=False, is_weak=False),
        NodeUser(node=ExternKernelSchedulerNode(name='op31'), can_inplace=False, is_weak=False),
        NodeUser(node=ExternKernelSchedulerNode(name='op35'), can_inplace=False, is_weak=False),
        NodeUser(node=ExternKernelSchedulerNode(name='op42'), can_inplace=False, is_weak=False),
        NodeUser(node=ExternKernelSchedulerNode(name='op46'), can_inplace=False, is_weak=False),
        NodeUser(node=ExternKernelSchedulerNode(name='op53'), can_inplace=False, is_weak=False),
    ]
    buf18: MutationOutput
    buf18.layout = NoneLayout(device=device(type='cuda', index=0), size=[0], stride=[0])
    buf18.mutations = ['buf16']
    buf18.users = [
        NodeUser(node=ExternKernelSchedulerNode(name='op14'), can_inplace=False, is_weak=False),
        NodeUser(node=ExternKernelSchedulerNode(name='op20'), can_inplace=False, is_weak=True),
        NodeUser(node=ExternKernelSchedulerNode(name='op24'), can_inplace=False, is_weak=True),
        NodeUser(node=ExternKernelSchedulerNode(name='op31'), can_inplace=False, is_weak=True),
        NodeUser(node=ExternKernelSchedulerNode(name='op35'), can_inplace=False, is_weak=True),
        NodeUser(node=ExternKernelSchedulerNode(name='op42'), can_inplace=False, is_weak=True),
        NodeUser(node=ExternKernelSchedulerNode(name='op46'), can_inplace=False, is_weak=True),
        NodeUser(node=ExternKernelSchedulerNode(name='op53'), can_inplace=False, is_weak=True),
    ]
    buf19: MutationOutput
    buf19.layout = NoneLayout(device=device(type='cuda', index=0), size=[0], stride=[0])
    buf19.mutations = ['buf1']
    buf19.users = [NodeUser(node=ExternKernelSchedulerNode(name='op20'), can_inplace=False, is_weak=False)]
]
op13.node.kernel = torch.ops._C.fused_add_rms_norm.default
```

Here we can see that `buf16` shares the same dependency list as `buf1`, because both `buf16` and `buf1` appear in the alias list of `buf17`. This is incorrect since they are two separate tensors, and it prevents the compiler from reusing `buf16` for subsequent ops.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157133
Approved by: https://github.com/jansel
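For context, a minimal Python-only sketch of an op with the same mutation signature sets up the same situation under `torch.compile`. The namespace `mylib`, the function names, and the eager math below are illustrative stand-ins for vllm's CUDA kernel, not its actual implementation:

```python
import torch

# Hypothetical pure-Python stand-in mirroring the
# "(Tensor! input, Tensor! residual, ...) -> ()" schema above:
# both `input` and `residual` are mutated in place.
@torch.library.custom_op(
    "mylib::fused_add_rms_norm", mutates_args=("input", "residual")
)
def fused_add_rms_norm(
    input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, epsilon: float
) -> None:
    residual.add_(input)  # residual <- input + residual
    rms = residual.pow(2).mean(dim=-1, keepdim=True).add(epsilon).rsqrt()
    input.copy_(residual * rms * weight)  # input <- rmsnorm(input + residual) * weight


@torch.compile
def block(x: torch.Tensor, res: torch.Tensor, w: torch.Tensor):
    # Repeated in-place calls: Inductor should be able to reuse the buffer
    # behind `x` between layers, which the shared dependency list prevented.
    torch.ops.mylib.fused_add_rms_norm(x, res, w, 1e-6)
    torch.ops.mylib.fused_add_rms_norm(x, res, w, 1e-6)
    return x, res
```

Under `torch.compile`, mutating custom ops like this are auto-functionalized, which is where the `MutationOutput` and alias bookkeeping shown in the dump above comes from.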
# flake8: noqa: B950
from ._internal import register_artifact, register_log


DYNAMIC = [
    "torch.fx.experimental.symbolic_shapes",
    "torch.fx.experimental.sym_node",
    "torch.fx.experimental.recording",
]
DISTRIBUTED = [
    "torch.distributed",
    "torch._dynamo.backends.distributed",
    "torch.nn.parallel.distributed",
]

register_log(
    "async_compile",
    [
        "torch._inductor.async_compile",
        "torch._inductor.compile_worker.tracked_process_pool",
    ],
)
register_log(
    "cache", ("torch._inductor.remote_cache", "torch._inductor.fb.remote_cache")
)
register_log("dynamo", ["torch._dynamo", *DYNAMIC])
register_log("fake_tensor", ["torch._subclasses.fake_tensor"])
register_log("aot", ["torch._functorch.aot_autograd", "torch._functorch._aot_autograd"])
register_log("autograd", "torch.autograd")
register_log("inductor", ["torch._inductor", "torch._inductor.cudagraph_trees"])

register_artifact(
    "cudagraphs",
    "Logs information from wrapping inductor generated code with cudagraphs.",
)

register_log("dynamic", DYNAMIC)
register_log("torch", "torch")
register_log("distributed", DISTRIBUTED)
register_log(
    "c10d", ["torch.distributed.distributed_c10d", "torch.distributed.rendezvous"]
)
register_log(
    "ddp", ["torch.nn.parallel.distributed", "torch._dynamo.backends.distributed"]
)
register_log("pp", ["torch.distributed.pipelining"])
register_log("fsdp", ["torch.distributed.fsdp", "torch.distributed._composable.fsdp"])
register_log("dtensor", ["torch.distributed._tensor", "torch.distributed.tensor"])
register_log("onnx", "torch.onnx")
register_log(
    "export",
    [
        "torch._dynamo",
        "torch.export",
        "torch.export.dynamic_shapes",
        *DYNAMIC,
        "torch._export.converter",
        "torch._export.non_strict_utils",
        "torch._export.serde.serialize",
        "torch.fx.experimental.proxy_tensor",
    ],
)

register_artifact(
    "guards",
    "This prints the guards for every compiled Dynamo frame. It does not tell you where the guards come from.",
    visible=True,
)
register_artifact("verbose_guards", "", off_by_default=True)
register_artifact(
    "bytecode",
    "Prints the original and modified bytecode from Dynamo. Mostly useful if you're debugging our bytecode generation in Dynamo.",
    off_by_default=True,
)
register_artifact(
    "graph",
    "Prints the dynamo traced graph (prior to AOTDispatch) in a table. If you prefer python code use `graph_code` instead. ",
)
register_artifact("graph_code", "Like `graph`, but gives you the Python code instead.")
register_artifact(
    "graph_code_verbose",
    "Verbose FX pass logs, e.g. from tensorify_python_scalars and runtime_assert.",
)
register_artifact(
    "graph_sizes", "Prints the sizes of all FX nodes in the dynamo graph."
)
register_artifact(
    "trace_source",
    "As we execute bytecode, prints the file name / line number we are processing and the actual source code. Useful with `bytecode`",
)
register_artifact(
    "trace_call",
    "Like trace_source, but it will give you the per-expression blow-by-blow if your Python is recent enough.",
)
register_artifact(
    "trace_bytecode",
    "As we trace bytecode, prints the instruction and the current stack.",
)
register_artifact(
    "aot_graphs",
    "Prints the FX forward and backward graph generated by AOTDispatch, after partitioning. Useful to understand what's being given to Inductor",
    visible=True,
)
register_artifact(
    "aot_joint_graph",
    "Print FX joint graph from AOTAutograd, prior to partitioning. Useful for debugging partitioning",
)
register_artifact(
    "aot_graphs_effects",
    "Prints the FX forward and backward graph generated by AOTDispatch, useful for debugging effects processing.",
    visible=True,
)
register_artifact(
    "pre_grad_graphs",
    "Prints the FX graph before inductor pre grad passes. Useful to understand what's being given to Inductor before grad passes",
)
register_artifact(
    "post_grad_graphs",
    "Prints the FX graph generated by post grad passes. Useful to understand what's being given to Inductor after post grad passes",
)
register_artifact(
    "ir_pre_fusion",
    "Prints the IR before inductor fusion passes.",
    off_by_default=True,
)
register_artifact(
    "ir_post_fusion",
    "Prints the IR after inductor fusion passes.",
    off_by_default=True,
)
register_artifact(
    "compiled_autograd",
    "Prints various logs in compiled_autograd, including but not limited to the graphs. Useful for debugging compiled_autograd.",
    visible=True,
)
register_artifact(
    "compiled_autograd_verbose",
    "Will affect performance. Prints compiled_autograd logs with C++ info e.g. autograd node -> fx node mapping",
    off_by_default=True,
)
register_artifact(
    "ddp_graphs",
    "Only relevant for compiling DDP. DDP splits into multiple graphs to trigger comms early. This will print each individual graph here.",
)
register_artifact(
    "recompiles",
    "Prints the reason why we recompiled a graph. Very, very useful.",
    visible=True,
)
register_artifact(
    "recompiles_verbose",
    "Prints all guard checks that fail during a recompilation. "
    "At runtime, Dynamo will stop at the first failed check for each failing guard. "
    "So not all logged failing checks are actually run by Dynamo.",
    visible=True,
    off_by_default=True,
)
register_artifact(
    "graph_breaks",
    "Prints whenever Dynamo decides that it needs to graph break (i.e. create a new graph). Useful for debugging why torch.compile has poor performance",
    visible=True,
)
register_artifact(
    "not_implemented",
    "Prints log messages whenever we return NotImplemented in a multi-dispatch, letting you trace through each object we attempted to dispatch to",
)
register_artifact(
    "output_code",
    "Prints the code that Inductor generates (either Triton or C++)",
    off_by_default=True,
    visible=True,
)
register_artifact(
    "kernel_code",
    "Prints the code that Inductor generates (on a per-kernel basis)",
    off_by_default=True,
    visible=True,
)
register_artifact(
    "schedule",
    "Inductor scheduler information. Useful if working on Inductor fusion algo",
    off_by_default=True,
)
register_artifact("perf_hints", "", off_by_default=True)
register_artifact("onnx_diagnostics", "", off_by_default=True)
register_artifact("compute_dependencies", "", off_by_default=True)
register_artifact(
    "fusion",
    "Detailed Inductor fusion decisions. More detailed than 'schedule'",
    off_by_default=True,
)
register_artifact(
    "loop_ordering",
    "Logs related to loop ordering",
    off_by_default=True,
)
register_artifact(
    "loop_tiling",
    "Logs related to loop tiling",
    off_by_default=True,
)

register_artifact(
    "overlap",
    "Detailed Inductor compute/comm overlap decisions",
    off_by_default=True,
)
register_artifact(
    "sym_node",
    "Logs extra info for various SymNode operations",
    off_by_default=True,
)
register_artifact(
    "trace_shape_events",
    "Logs traces for every ShapeEnv operation that we record for replay",
    off_by_default=True,
)
register_artifact(
    "cudagraph_static_inputs",
    "Logs static inputs handling in dynamo, AOT, and cudagraphs",
    off_by_default=True,
)
register_artifact(
    "benchmarking",
    "Detailed Inductor benchmarking information.",
    off_by_default=True,
)
register_artifact(
    "autotuning",
    "Autotuning choice logs, such as kernel source, perf, and tuning parameters.",
    off_by_default=True,
)
register_artifact(
    "graph_region_expansion",
    "Logs detailed steps of the duplicate graph region tracker expansion algorithm",
    off_by_default=True,
)

register_artifact(
    "inductor_metrics",
    "Logs Inductor metrics, such as num_bytes, nodes_num_elem, node_runtimes",
    off_by_default=True,
)
register_artifact(
    "hierarchical_compile",
    "Logs debug info for hierarchical compilation",
    off_by_default=True,
)
register_artifact("custom_format_test_artifact", "Testing only", log_format="")
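# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the registrations above): the log
# and artifact names registered in this module are what users toggle, either
# via the TORCH_LOGS environment variable, e.g.
#
#     TORCH_LOGS="+dynamo,graph_breaks,guards" python my_script.py
#
# or programmatically with torch._logging.set_logs, where log names take a
# logging level and artifact names take a boolean.
if __name__ == "__main__":
    import logging

    import torch._logging

    torch._logging.set_logs(dynamo=logging.DEBUG, graph_breaks=True, guards=True)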