Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 12:54:11 +08:00
[Optimus][fp8_activation_quantization] Only log when there's some node to be quantized (#158129)
Summary: We add an extra check for whether any node has been marked as should-quantize; otherwise we skip the quantization and the tlparse log.
Rollback Plan:
Differential Revision: D78173788
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158129
Approved by: https://github.com/Skylion007, https://github.com/avicizhu
committed by PyTorch MergeBot
parent 5606c516fd
commit 4657a84bc5
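In essence, the marking loop now records whether any saved forward output actually qualified, and the fp8 quantization step (with its tlparse artifact log) runs only in that case. Below is a minimal standalone sketch of the gating pattern, assuming hypothetical stand-in names (Node, mark_and_maybe_quantize, perform_quantization are illustrations, not the actual PyTorch helpers):

    from dataclasses import dataclass, field

    @dataclass
    class Node:
        name: str
        meta: dict = field(default_factory=dict)

    def mark_and_maybe_quantize(nodes, should_quantize, perform_quantization):
        """Mark qualifying nodes; run the pass (and its logging) only if any matched."""
        marked_any = False
        for node in nodes:
            if should_quantize(node):
                node.meta["saved_for_quantization"] = True
                marked_any = True
        # The new gate: before this change, quantization and the tlparse log
        # ran unconditionally; now both are skipped when nothing was marked.
        if marked_any:
            perform_quantization(nodes)

    # Nothing qualifies here, so perform_quantization (and its logging) never runs.
    mark_and_maybe_quantize(
        [Node("primals_1")],
        should_quantize=lambda n: "primals" not in n.name,
        perform_quantization=lambda ns: print("quantizing", [n.name for n in ns]),
    )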
@@ -684,47 +684,11 @@ def quantize_activation_bw(graph: torch.fx.Graph) -> None:
     counters["inductor"]["activation_quantization_bwd_aten_pass"] += 1
 
 
-def enable_activation_quantization(
-    saved_values: list[fx.Node],
+def perform_fp8_activation_quantization(
     fwd_module: fx.GraphModule,
     bwd_module: fx.GraphModule,
-    static_lifetime_input_nodes: Optional[OrderedSet[fx.Node]] = None,
+    bwd_module_inputs: dict[str, fx.Node],
 ) -> None:
-    if (
-        inductor_config.post_grad_fusion_options.get(
-            "activation_quantization_aten_pass", None
-        )
-        is None
-    ):
-        return
-
-    static_input_names = (
-        [node.name for node in static_lifetime_input_nodes]
-        if static_lifetime_input_nodes
-        else []
-    )
-    saved_values_names = {node.name: node for node in saved_values}
-    if torch._inductor.config.post_grad_fusion_options[
-        "activation_quantization_aten_pass"
-    ].get("exclude_primals", False):
-        saved_values_names = {
-            node.name: node for node in saved_values if "primals" not in node.name
-        }
-    fwd_module_outputs = fwd_module.graph.find_nodes(op="output")[0].args[0]
-    bwd_module_inputs = {
-        node.name: node for node in bwd_module.graph.find_nodes(op="placeholder")
-    }
-    for node in fwd_module_outputs:
-        if node.name in saved_values_names and should_quantize(node):
-            if node.name in static_input_names:
-                log.debug("Skipping quantization of static input %s: ", node.name)
-                continue
-            node.meta["saved_for_quantization"] = True
-            node.meta["dequant_type"] = node.meta["val"].dtype
-            # some of the fwd outputs and bwd inputs do not share the same object
-            bwd_module_inputs[node.name].meta["saved_for_quantization"] = True
-            bwd_module_inputs[node.name].meta["dequant_type"] = node.meta["val"].dtype
-
     trace_structured(
         "artifact",
         metadata_fn=lambda: {
@@ -808,6 +772,53 @@ def enable_activation_quantization(
     )
 
 
+def enable_activation_quantization(
+    saved_values: list[fx.Node],
+    fwd_module: fx.GraphModule,
+    bwd_module: fx.GraphModule,
+    static_lifetime_input_nodes: Optional[OrderedSet[fx.Node]] = None,
+) -> None:
+    if (
+        inductor_config.post_grad_fusion_options.get(
+            "activation_quantization_aten_pass", None
+        )
+        is None
+    ):
+        return
+
+    static_input_names = (
+        [node.name for node in static_lifetime_input_nodes]
+        if static_lifetime_input_nodes
+        else []
+    )
+    saved_values_names = {node.name: node for node in saved_values}
+    if torch._inductor.config.post_grad_fusion_options[
+        "activation_quantization_aten_pass"
+    ].get("exclude_primals", False):
+        saved_values_names = {
+            node.name: node for node in saved_values if "primals" not in node.name
+        }
+    fwd_module_outputs = fwd_module.graph.find_nodes(op="output")[0].args[0]
+    bwd_module_inputs = {
+        node.name: node for node in bwd_module.graph.find_nodes(op="placeholder")
+    }
+    should_perform_fp8_quant = False
+    for node in fwd_module_outputs:
+        if node.name in saved_values_names and should_quantize(node):
+            if node.name in static_input_names:
+                log.debug("Skipping quantization of static input %s: ", node.name)
+                continue
+            node.meta["saved_for_quantization"] = True
+            node.meta["dequant_type"] = node.meta["val"].dtype
+            # some of the fwd outputs and bwd inputs do not share the same object
+            bwd_module_inputs[node.name].meta["saved_for_quantization"] = True
+            bwd_module_inputs[node.name].meta["dequant_type"] = node.meta["val"].dtype
+            should_perform_fp8_quant = True
+
+    if should_perform_fp8_quant:
+        perform_fp8_activation_quantization(fwd_module, bwd_module, bwd_module_inputs)
+
+
 def _extract_fwd_bwd_modules(
     joint_module: fx.GraphModule,
     saved_values: list[fx.Node],
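For context, the whole pass is keyed off the inductor post-grad fusion options read at the top of enable_activation_quantization. A hedged sketch of enabling it, assuming the option value is a dict as the .get("exclude_primals", False) lookup in the diff suggests:

    import torch._inductor.config as inductor_config

    # "activation_quantization_aten_pass" and "exclude_primals" are the keys the
    # diff reads; representing the value as a dict is an assumption from that code.
    inductor_config.post_grad_fusion_options["activation_quantization_aten_pass"] = {
        "exclude_primals": True,  # skip saved values whose names contain "primals"
    }

With this commit, enabling the pass on a graph where no saved value qualifies is a no-op: no nodes are re-marked and no tlparse artifact is emitted.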