[Optimus][fp8_activation_quantization] Only log when there's some node to be quantized (#158129)

Summary:
We add an extra check for whether any node has been marked as should-quantize; if none has, we skip the quantization pass and the tlparse log.
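
Condensed sketch of the new control flow (a simplified, non-runnable excerpt that mirrors the diff below; all names come from the actual change):

    should_perform_fp8_quant = False
    for node in fwd_module_outputs:
        if node.name in saved_values_names and should_quantize(node):
            ...  # mark the fwd output and the matching bwd input for quantization
            should_perform_fp8_quant = True

    # The quantization pass and its tlparse artifact log now only run when at
    # least one node was marked.
    if should_perform_fp8_quant:
        perform_fp8_activation_quantization(fwd_module, bwd_module, bwd_module_inputs)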

Rollback Plan:

Differential Revision: D78173788

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158129
Approved by: https://github.com/Skylion007, https://github.com/avicizhu
commit 4657a84bc5
Author: Menglu Yu
Date: 2025-07-15 19:22:26 +00:00
Committed by: PyTorch MergeBot
Parent: 5606c516fd


@@ -684,47 +684,11 @@ def quantize_activation_bw(graph: torch.fx.Graph) -> None:
     counters["inductor"]["activation_quantization_bwd_aten_pass"] += 1
 
 
-def enable_activation_quantization(
-    saved_values: list[fx.Node],
+def perform_fp8_activation_quantization(
     fwd_module: fx.GraphModule,
     bwd_module: fx.GraphModule,
-    static_lifetime_input_nodes: Optional[OrderedSet[fx.Node]] = None,
+    bwd_module_inputs: dict[str, fx.Node],
 ) -> None:
-    if (
-        inductor_config.post_grad_fusion_options.get(
-            "activation_quantization_aten_pass", None
-        )
-        is None
-    ):
-        return
-    static_input_names = (
-        [node.name for node in static_lifetime_input_nodes]
-        if static_lifetime_input_nodes
-        else []
-    )
-    saved_values_names = {node.name: node for node in saved_values}
-    if torch._inductor.config.post_grad_fusion_options[
-        "activation_quantization_aten_pass"
-    ].get("exclude_primals", False):
-        saved_values_names = {
-            node.name: node for node in saved_values if "primals" not in node.name
-        }
-    fwd_module_outputs = fwd_module.graph.find_nodes(op="output")[0].args[0]
-    bwd_module_inputs = {
-        node.name: node for node in bwd_module.graph.find_nodes(op="placeholder")
-    }
-    for node in fwd_module_outputs:
-        if node.name in saved_values_names and should_quantize(node):
-            if node.name in static_input_names:
-                log.debug("Skipping quantization of static input %s: ", node.name)
-                continue
-            node.meta["saved_for_quantization"] = True
-            node.meta["dequant_type"] = node.meta["val"].dtype
-            # some of the fwd outputs and bwd inputs are not share the same object
-            bwd_module_inputs[node.name].meta["saved_for_quantization"] = True
-            bwd_module_inputs[node.name].meta["dequant_type"] = node.meta["val"].dtype
     trace_structured(
         "artifact",
         metadata_fn=lambda: {
@@ -808,6 +772,53 @@ def enable_activation_quantization
     )
 
 
+def enable_activation_quantization(
+    saved_values: list[fx.Node],
+    fwd_module: fx.GraphModule,
+    bwd_module: fx.GraphModule,
+    static_lifetime_input_nodes: Optional[OrderedSet[fx.Node]] = None,
+) -> None:
+    if (
+        inductor_config.post_grad_fusion_options.get(
+            "activation_quantization_aten_pass", None
+        )
+        is None
+    ):
+        return
+    static_input_names = (
+        [node.name for node in static_lifetime_input_nodes]
+        if static_lifetime_input_nodes
+        else []
+    )
+    saved_values_names = {node.name: node for node in saved_values}
+    if torch._inductor.config.post_grad_fusion_options[
+        "activation_quantization_aten_pass"
+    ].get("exclude_primals", False):
+        saved_values_names = {
+            node.name: node for node in saved_values if "primals" not in node.name
+        }
+    fwd_module_outputs = fwd_module.graph.find_nodes(op="output")[0].args[0]
+    bwd_module_inputs = {
+        node.name: node for node in bwd_module.graph.find_nodes(op="placeholder")
+    }
+    should_perform_fp8_quant = False
+    for node in fwd_module_outputs:
+        if node.name in saved_values_names and should_quantize(node):
+            if node.name in static_input_names:
+                log.debug("Skipping quantization of static input %s: ", node.name)
+                continue
+            node.meta["saved_for_quantization"] = True
+            node.meta["dequant_type"] = node.meta["val"].dtype
+            # some of the fwd outputs and bwd inputs are not share the same object
+            bwd_module_inputs[node.name].meta["saved_for_quantization"] = True
+            bwd_module_inputs[node.name].meta["dequant_type"] = node.meta["val"].dtype
+            should_perform_fp8_quant = True
+    if should_perform_fp8_quant:
+        perform_fp8_activation_quantization(fwd_module, bwd_module, bwd_module_inputs)
+
+
 def _extract_fwd_bwd_modules(
     joint_module: fx.GraphModule,
     saved_values: list[fx.Node],