Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[graph_manipulation] Set fused dtypes for all constant params/buffers (#77401)
Summary: We were handling constant attrs in a few different ways before, leading to confusion and missed handling for fused dtypes. This diff consolidates some of that code and unbreaks the current breakage.

Test Plan: CI. Recently broken tests now pass.

Differential Revision: D36335238

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77401
Approved by: https://github.com/jaybean-dev, https://github.com/jamesr66a
Committed by: PyTorch MergeBot
Parent: 942f04172a
Commit: 18e36a6295
test/test_fx_experimental.py
@@ -123,7 +123,7 @@ class TestFXExperimental(JitTestCase):
         assert len(serialized_graph1["weights"]) == 4
         assert len(serialized_graph1["modules"]) == 0
         assert len(serialized_graph2["nodes"]) == 6
-        assert len(serialized_graph2["weights"]) == 4
+        assert len(serialized_graph2["weights"]) == 1
         assert len(serialized_graph2["modules"]) == 1
         assert serialized_graph1["weights"]["linear.weight"]["shape"] == "[4, 4]"
         assert serialized_graph1["weights"]["linear.weight"]["dtype"] == "torch.float32"
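For orientation, here is a hedged, minimal reproduction of the serialization flow this test exercises, using only APIs visible in this diff (TinyModule and the input shape are illustrative stand-ins for the test's TestModule; the expected strings match the unchanged assertions above):

import torch
from torch.fx import symbolic_trace
from torch.fx.passes import graph_manipulation

class TinyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)

traced = symbolic_trace(TinyModule())
# The test runs this first: it propagates shapes/dtypes and attaches
# node.size_bytes, which serialize_module relies on downstream.
graph_manipulation.get_size_of_all_nodes(traced, [torch.randn(4)])

weights: dict = {}
serialized = graph_manipulation.serialize_module(traced, weights)
print(serialized["weights"]["linear.weight"]["shape"])  # "[4, 4]"
print(serialized["weights"]["linear.weight"]["dtype"])  # "torch.float32"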
torch/fx/passes/graph_manipulation.py
@@ -1,15 +1,14 @@
 from typing import Dict, List, NamedTuple, Any, Optional, Tuple
 
 import torch
+from torch.fx.passes.param_fetch import lift_lowering_attrs_to_nodes
+from torch.fx._compatibility import compatibility
 from torch.fx.graph import Graph
 from torch.fx.graph_module import GraphModule
 from torch.fx.node import Node, Target, Argument, map_arg, map_aggregate
 from torch.fx.node import _get_qualified_name
-from torch.fx.passes.param_fetch import lift_lowering_attrs_to_nodes
 from torch.fx.passes.shape_prop import ShapeProp
-
-from torch.fx._compatibility import compatibility
 
 
 @compatibility(is_backward_compatible=False)
 def replace_target_nodes_with(
@@ -36,11 +35,13 @@ def replace_target_nodes_with(
         val_map[node] = new_graph.node_copy(node, lambda n: val_map[n])
     fx_module.graph = new_graph
 
 
+@compatibility(is_backward_compatible=False)
 class size_bytes(NamedTuple):
     output_size: int
     total_size: int
 
 
+@compatibility(is_backward_compatible=False)
 def get_size_of_all_nodes(
     fx_module: GraphModule, args: Optional[List[torch.Tensor]] = None
@@ -59,6 +60,7 @@ def get_size_of_all_nodes(
         node.size_bytes = get_size_of_node(fx_module, node)
     return
 
 
+@compatibility(is_backward_compatible=False)
 def get_tensor_meta(node: Node) -> Any:
     tensor_meta = node.meta.get("tensor_meta")
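The loop above attaches a size_bytes record (output_size, total_size) to each node, after running ShapeProp when example args are supplied. A minimal usage sketch, with an illustrative module:

import torch
from torch.fx import symbolic_trace
from torch.fx.passes.graph_manipulation import get_size_of_all_nodes

gm = symbolic_trace(torch.nn.Linear(4, 4))
get_size_of_all_nodes(gm, [torch.rand(4)])
for node in gm.graph.nodes:
    if hasattr(node, "size_bytes"):
        # e.g. the float32 [4, 4] weight: 16 elements * 4 bytes = 64 bytes
        print(node.name, node.size_bytes.output_size, node.size_bytes.total_size)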
@@ -71,6 +73,7 @@ def get_tensor_meta(node: Node) -> Any:
 
     return tensor_meta
 
 
+@compatibility(is_backward_compatible=False)
 def get_size_of_node(fx_module: GraphModule, node: Node) -> size_bytes:
     """Given a node with node.dtype and node.shape, return its total size and its output size.
@@ -102,14 +105,17 @@ def get_size_of_node(fx_module: GraphModule, node: Node) -> size_bytes:
     output_size = size_per_elem_bytes * output_elem
     return size_bytes(output_size, total_size)
 
 
+@compatibility(is_backward_compatible=False)
 def serialize_shape(shape: torch.Size) -> str:
     return str(list(shape))
 
 
+@compatibility(is_backward_compatible=False)
 def serialize_stride(stride: Tuple[int]) -> str:
     return str(list(stride))
 
 
+@compatibility(is_backward_compatible=False)
 def serialize_tensor_quantization(
     tensor: torch.Tensor, weights: Dict, pcq_prefix: str
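A worked instance of the size arithmetic above (hedged: a lone float32 [4, 4] output, with total size equal to output size when a node holds no other intermediates):

import torch

size_per_elem_bytes = torch.tensor([], dtype=torch.float32).element_size()  # 4
output_elem = 4 * 4                       # elements in a [4, 4] output
print(size_per_elem_bytes * output_elem)  # 64 bytes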
@@ -209,6 +215,7 @@ def serialize_tensor_quantization(
         scheme["q_per_channel_axis"] = tensor.q_per_channel_axis()
     return scheme, per_channel_dict
 
 
+@compatibility(is_backward_compatible=False)
 def serialize_weight(tensor: torch.Tensor, weights: Dict, name: str) -> Dict:
     weight_dict: Dict[str, Dict] = {name: {}}
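The q_per_channel_axis field above is only populated for per-channel schemes; for a per-tensor quantized tensor the per-channel dict comes back empty. A hedged sketch (the exact keys inside scheme are not spelled out in this hunk):

import torch
from torch.fx.passes.graph_manipulation import serialize_tensor_quantization

q = torch.quantize_per_tensor(torch.rand(4, 4), scale=0.1, zero_point=0, dtype=torch.quint8)
scheme, per_channel = serialize_tensor_quantization(q, {}, "")
print(scheme)       # includes the qscheme plus per-tensor scale/zero point
print(per_channel)  # {} for per-tensor schemes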
@@ -227,6 +234,7 @@ def serialize_weight(tensor: torch.Tensor, weights: Dict, name: str) -> Dict:
 
     return weight_dict
 
 
+@compatibility(is_backward_compatible=False)
 def serialize_leaf_module(
     node: Node, weights_metadata: Dict, weights: Dict, name_prefix: str
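A hedged sketch of the per-tensor entry serialize_weight builds; the "shape" and "dtype" fields and their string form are the ones checked by the test assertions at the top of this diff:

import torch
from torch.fx.passes.graph_manipulation import serialize_weight

entry = serialize_weight(torch.rand(4, 4), {}, "linear.weight")
print(entry["linear.weight"]["shape"])  # "[4, 4]"
print(entry["linear.weight"]["dtype"])  # "torch.float32"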
@@ -244,6 +252,39 @@ def serialize_leaf_module(
 
     return parameters
 
 
+def _update_weight_fused_dtypes(weight, name, node):
+    """
+    For quantized embedding tables we need to update the shape/type, so we check if the
+    users of this get_attr node is a quantized EB and this is the weight for the EB, and
+    update the dtype accordingly.
+    """
+    user_targets = {
+        _get_qualified_name(n.target)
+        .replace("fx2trt_oss.tracer.acc_tracer.", "")
+        .replace("glow.fb.fx.", ""): n
+        for n in node.users.keys()
+        if n.op == "call_function"
+    }
+    if (
+        "acc_ops.embedding_bag_byte_rowwise_offsets" in user_targets
+        and str(
+            user_targets["acc_ops.embedding_bag_byte_rowwise_offsets"].kwargs["weight"]
+        )
+        == name
+    ):
+        weight[name]["dtype"] = "acc.uint8fused"
+    # Same as above, but for the 4 bit version.
+    if (
+        "acc_ops.embedding_bag_4bit_rowwise_offsets" in user_targets
+        and str(
+            user_targets["acc_ops.embedding_bag_4bit_rowwise_offsets"].kwargs["weight"]
+        )
+        == name
+    ):
+        weight[name]["dtype"] = "acc.uint4fused"
+
+
+@compatibility(is_backward_compatible=False)
 def serialize_module(fx_module: GraphModule, weights: Dict, name_prefix="") -> Dict:
     """Recursively Serializes a graph module (fx_module) to a dictionary which is later exported to JSON.
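Since acc_ops lives outside this file, here is a hedged, self-contained way to watch the new helper fire: build a tiny FX graph by hand and give a stub function the qualified name the matcher looks for. The stub and its faked __module__ are purely illustrative:

from torch.fx import Graph
from torch.fx.passes.graph_manipulation import _update_weight_fused_dtypes

def embedding_bag_byte_rowwise_offsets(weight, indices, offsets):
    raise NotImplementedError  # illustrative stub, never called

# Only the qualified name "acc_ops.embedding_bag_byte_rowwise_offsets"
# matters to the matching logic, so fake the module for this sketch.
embedding_bag_byte_rowwise_offsets.__module__ = "acc_ops"

g = Graph()
w = g.get_attr("emb_weight")
g.call_function(
    embedding_bag_byte_rowwise_offsets,
    kwargs={
        "weight": w,
        "indices": g.placeholder("indices"),
        "offsets": g.placeholder("offsets"),
    },
)

entry = {w.name: {"dtype": "torch.quint8", "shape": "[1000, 16]"}}
_update_weight_fused_dtypes(entry, w.name, w)
print(entry[w.name]["dtype"])  # "acc.uint8fused"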
@@ -291,17 +332,6 @@ def serialize_module(fx_module: GraphModule, weights: Dict, name_prefix="") -> D
     submodules = dict(fx_module.named_modules())
     prefix = f"{name_prefix}." if name_prefix else ""
 
-    def add_weight_tensors(named_tensors):
-        for name, p in named_tensors:
-            if name.startswith("parent.") or not isinstance(p, torch.Tensor):
-                continue
-            weight_dict = serialize_weight(p, weights, prefix + name)
-            serialized_dict["weights"].update(weight_dict)
-            weights[prefix + name] = p
-
-    add_weight_tensors(fx_module.named_parameters())
-    add_weight_tensors(fx_module.named_buffers())
-
     def get_node_info(node):
         tensor_meta = get_tensor_meta(node)
         node_rep = {
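With the blanket add_weight_tensors pass removed above, constants now reach the weights dict when their get_attr nodes are serialized (next hunk), which is exactly where the fused-dtype check can inspect their users. A hedged sketch showing a registered buffer being picked up through its get_attr use (names and the expected key list are illustrative):

import torch
from torch.fx import symbolic_trace
from torch.fx.passes import graph_manipulation

class WithBuffer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)
        self.register_buffer("scale", torch.ones(4))

    def forward(self, x):
        return self.linear(x) * self.scale

traced = symbolic_trace(WithBuffer())
graph_manipulation.get_size_of_all_nodes(traced, [torch.randn(4)])
serialized = graph_manipulation.serialize_module(traced, {})
print(sorted(serialized["weights"]))  # e.g. ['linear.bias', 'linear.weight', 'scale']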
@@ -373,58 +403,25 @@ def serialize_module(fx_module: GraphModule, weights: Dict, name_prefix="") -> D
         if node.op == "get_attr":
             # If we are targeting a parent constant we update the target.
             if node.target.startswith("parent."):
-                stripped_name = node.target[len("parent.") :]
-                node.name = stripped_name
-                node_rep["target"] = stripped_name
-                weight = serialize_weight(
-                    weights[stripped_name], weights, node.target[len("parent.") :]
-                )
-                # For quantized embedding tables we need to update the shape/type,
-                # so we check if the users of this get_attr is a quantized EB and this is the weight for the EB.
-                user_targets = {
-                    _get_qualified_name(n.target)
-                    .replace("fx2trt_oss.tracer.acc_tracer.", "")
-                    .replace("glow.fb.fx.", ""): n
-                    for n in node.users.keys()
-                }
-                if (
-                    "acc_ops.embedding_bag_byte_rowwise_offsets" in user_targets
-                    and str(
-                        user_targets[
-                            "acc_ops.embedding_bag_byte_rowwise_offsets"
-                        ].kwargs["weight"]
-                    )
-                    == stripped_name
-                ):
-                    weight[stripped_name]["dtype"] = "acc.uint8fused"
-                # Same as above, but for the 4 bit version.
-                if (
-                    "acc_ops.embedding_bag_4bit_rowwise_offsets" in user_targets
-                    and str(
-                        user_targets[
-                            "acc_ops.embedding_bag_4bit_rowwise_offsets"
-                        ].kwargs["weight"]
-                    )
-                    == stripped_name
-                ):
-                    weight[stripped_name]["dtype"] = "acc.uint4fused"
-
-                serialized_dict["weights"].update(weight)
+                qualname = node.target[len("parent.") :]
+                node.name = qualname
+                node_rep["target"] = qualname
             else:
-                # Find the actual target parameter/buffer from the fx_module.
-                submod_path, _, target_name = node.target.rpartition(".")
-                submod: Optional[torch.nn.Module] = (
-                    fx_module.get_submodule(submod_path) if submod_path else fx_module
-                )
-                assert submod is not None, f"submod {submod_path} not found"
-                target = getattr(submod, target_name, None)
-                assert target is not None, f"{target_name} not an attr of {submod_path}"
                 qualname = prefix + node.target
-                # Check that the target is a tensor, and that we haven't added it already from a leaf module.
-                if isinstance(target, torch.Tensor) and qualname not in weights:
-                    weight = serialize_weight(target, weights, qualname)
-                    serialized_dict["weights"].update(weight)
-                    weights[qualname] = target
+            # Find the actual target parameter/buffer from the fx_module.
+            submod_path, _, target_name = node.target.rpartition(".")
+            submod: Optional[torch.nn.Module] = (
+                fx_module.get_submodule(submod_path) if submod_path else fx_module
+            )
+            assert submod is not None, f"submod {submod_path} not found"
+            target = getattr(submod, target_name, None)
+            assert target is not None, f"{target_name} not an attr of {submod_path}"
+            # Check that the target is a tensor, and that we haven't added it already from a leaf module.
+            if isinstance(target, torch.Tensor) and qualname not in weights:
+                weight = serialize_weight(target, weights, qualname)
+                _update_weight_fused_dtypes(weight, qualname, node)
+                serialized_dict["weights"].update(weight)
+                weights[qualname] = target
 
         node_rep["op_code"] = node.op
         node_rep["name"] = node.name
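The rpartition-based lookup shared by the old and new branches can be seen in isolation. A hedged sketch, assuming a traced module whose forward touches a parameter directly so a dotted get_attr target appears:

import torch
from torch.fx import symbolic_trace

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x) + self.linear.weight.sum()

gm = symbolic_trace(M())
for node in gm.graph.nodes:
    if node.op == "get_attr":
        submod_path, _, target_name = node.target.rpartition(".")
        submod = gm.get_submodule(submod_path) if submod_path else gm
        target = getattr(submod, target_name)
        print(node.target, type(target).__name__, tuple(target.shape))
        # -> linear.weight Parameter (4, 4)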