[graph_manipulation] Set fused dtypes for all constant params/buffers (#77401)

Summary: We were handling constant attrs in a few different ways before, leading to confusion and missed handling for fused dtypes. This diff consolidates some of that code and unbreaks current breakage.

Test Plan: CI. Recently broken tests now pass.

Differential Revision: D36335238

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77401
Approved by: https://github.com/jaybean-dev, https://github.com/jamesr66a
This commit is contained in:
Jordan Fix
2022-05-17 07:42:29 +00:00
committed by PyTorch MergeBot
parent 942f04172a
commit 18e36a6295
2 changed files with 62 additions and 65 deletions

View File

@ -123,7 +123,7 @@ class TestFXExperimental(JitTestCase):
assert len(serialized_graph1["weights"]) == 4
assert len(serialized_graph1["modules"]) == 0
assert len(serialized_graph2["nodes"]) == 6
assert len(serialized_graph2["weights"]) == 4
assert len(serialized_graph2["weights"]) == 1
assert len(serialized_graph2["modules"]) == 1
assert serialized_graph1["weights"]["linear.weight"]["shape"] == "[4, 4]"
assert serialized_graph1["weights"]["linear.weight"]["dtype"] == "torch.float32"

View File

@ -1,15 +1,14 @@
from typing import Dict, List, NamedTuple, Any, Optional, Tuple
import torch
from torch.fx.passes.param_fetch import lift_lowering_attrs_to_nodes
from torch.fx._compatibility import compatibility
from torch.fx.graph import Graph
from torch.fx.graph_module import GraphModule
from torch.fx.node import Node, Target, Argument, map_arg, map_aggregate
from torch.fx.node import _get_qualified_name
from torch.fx.passes.param_fetch import lift_lowering_attrs_to_nodes
from torch.fx.passes.shape_prop import ShapeProp
from torch.fx._compatibility import compatibility
@compatibility(is_backward_compatible=False)
def replace_target_nodes_with(
@ -36,11 +35,13 @@ def replace_target_nodes_with(
val_map[node] = new_graph.node_copy(node, lambda n: val_map[n])
fx_module.graph = new_graph
@compatibility(is_backward_compatible=False)
class size_bytes(NamedTuple):
    # Size in bytes of the node's output (see get_size_of_node, which
    # returns "its total size and its output size").
    output_size: int
    # Total size in bytes for the node; presumably output plus any
    # parameters — confirm against get_size_of_node's computation.
    total_size: int
@compatibility(is_backward_compatible=False)
def get_size_of_all_nodes(
fx_module: GraphModule, args: Optional[List[torch.Tensor]] = None
@ -59,6 +60,7 @@ def get_size_of_all_nodes(
node.size_bytes = get_size_of_node(fx_module, node)
return
@compatibility(is_backward_compatible=False)
def get_tensor_meta(node: Node) -> Any:
tensor_meta = node.meta.get("tensor_meta")
@ -71,6 +73,7 @@ def get_tensor_meta(node: Node) -> Any:
return tensor_meta
@compatibility(is_backward_compatible=False)
def get_size_of_node(fx_module: GraphModule, node: Node) -> size_bytes:
"""Given a node with node.dtype and node.shape, return its total size and its output size.
@ -102,14 +105,17 @@ def get_size_of_node(fx_module: GraphModule, node: Node) -> size_bytes:
output_size = size_per_elem_bytes * output_elem
return size_bytes(output_size, total_size)
@compatibility(is_backward_compatible=False)
def serialize_shape(shape: torch.Size) -> str:
    """Render a tensor shape as a list-style string, e.g. ``"[4, 4]"``."""
    dims = [dim for dim in shape]
    return str(dims)
@compatibility(is_backward_compatible=False)
def serialize_stride(stride: Tuple[int]) -> str:
    """Render a tensor stride tuple as a list-style string, e.g. ``"[4, 1]"``."""
    elems = [s for s in stride]
    return str(elems)
@compatibility(is_backward_compatible=False)
def serialize_tensor_quantization(
tensor: torch.Tensor, weights: Dict, pcq_prefix: str
@ -209,6 +215,7 @@ def serialize_tensor_quantization(
scheme["q_per_channel_axis"] = tensor.q_per_channel_axis()
return scheme, per_channel_dict
@compatibility(is_backward_compatible=False)
def serialize_weight(tensor: torch.Tensor, weights: Dict, name: str) -> Dict:
weight_dict: Dict[str, Dict] = {name: {}}
@ -227,6 +234,7 @@ def serialize_weight(tensor: torch.Tensor, weights: Dict, name: str) -> Dict:
return weight_dict
@compatibility(is_backward_compatible=False)
def serialize_leaf_module(
node: Node, weights_metadata: Dict, weights: Dict, name_prefix: str
@ -244,6 +252,39 @@ def serialize_leaf_module(
return parameters
def _update_weight_fused_dtypes(weight, name, node):
    """
    Patch the serialized dtype of a weight that feeds a fused/quantized
    embedding-bag op.

    If any ``call_function`` user of this ``get_attr`` node is a quantized
    embedding-bag and this tensor is passed as its ``weight`` kwarg, the
    serialized entry must report the fused dtype instead of the raw tensor
    dtype. Mutates ``weight[name]["dtype"]`` in place; no return value.
    """
    # Target name -> fused dtype string. Insertion order matters: the byte
    # variant is checked before the 4-bit variant, matching the original
    # sequence of checks (the 4-bit match wins if both somehow apply).
    fused_dtype_by_target = {
        "acc_ops.embedding_bag_byte_rowwise_offsets": "acc.uint8fused",
        "acc_ops.embedding_bag_4bit_rowwise_offsets": "acc.uint4fused",
    }
    # Index the call_function users by their qualified target name, with
    # known tracer module prefixes stripped so internal and OSS spellings
    # of acc_ops both match.
    users_by_target = {}
    for user in node.users.keys():
        if user.op != "call_function":
            continue
        target_name = (
            _get_qualified_name(user.target)
            .replace("fx2trt_oss.tracer.acc_tracer.", "")
            .replace("glow.fb.fx.", "")
        )
        users_by_target[target_name] = user
    for target_name, fused_dtype in fused_dtype_by_target.items():
        user = users_by_target.get(target_name)
        if user is not None and str(user.kwargs["weight"]) == name:
            weight[name]["dtype"] = fused_dtype
@compatibility(is_backward_compatible=False)
def serialize_module(fx_module: GraphModule, weights: Dict, name_prefix="") -> Dict:
"""Recursively Serializes a graph module (fx_module) to a dictionary which is later exported to JSON.
@ -291,17 +332,6 @@ def serialize_module(fx_module: GraphModule, weights: Dict, name_prefix="") -> D
submodules = dict(fx_module.named_modules())
prefix = f"{name_prefix}." if name_prefix else ""
def add_weight_tensors(named_tensors):
for name, p in named_tensors:
if name.startswith("parent.") or not isinstance(p, torch.Tensor):
continue
weight_dict = serialize_weight(p, weights, prefix + name)
serialized_dict["weights"].update(weight_dict)
weights[prefix + name] = p
add_weight_tensors(fx_module.named_parameters())
add_weight_tensors(fx_module.named_buffers())
def get_node_info(node):
tensor_meta = get_tensor_meta(node)
node_rep = {
@ -373,58 +403,25 @@ def serialize_module(fx_module: GraphModule, weights: Dict, name_prefix="") -> D
if node.op == "get_attr":
# If we are targeting a parent constant we update the target.
if node.target.startswith("parent."):
stripped_name = node.target[len("parent.") :]
node.name = stripped_name
node_rep["target"] = stripped_name
weight = serialize_weight(
weights[stripped_name], weights, node.target[len("parent.") :]
)
# For quantized embedding tables we need to update the shape/type,
# so we check if the users of this get_attr is a quantized EB and this is the weight for the EB.
user_targets = {
_get_qualified_name(n.target)
.replace("fx2trt_oss.tracer.acc_tracer.", "")
.replace("glow.fb.fx.", ""): n
for n in node.users.keys()
}
if (
"acc_ops.embedding_bag_byte_rowwise_offsets" in user_targets
and str(
user_targets[
"acc_ops.embedding_bag_byte_rowwise_offsets"
].kwargs["weight"]
)
== stripped_name
):
weight[stripped_name]["dtype"] = "acc.uint8fused"
# Same as above, but for the 4 bit version.
if (
"acc_ops.embedding_bag_4bit_rowwise_offsets" in user_targets
and str(
user_targets[
"acc_ops.embedding_bag_4bit_rowwise_offsets"
].kwargs["weight"]
)
== stripped_name
):
weight[stripped_name]["dtype"] = "acc.uint4fused"
serialized_dict["weights"].update(weight)
qualname = node.target[len("parent.") :]
node.name = qualname
node_rep["target"] = qualname
else:
# Find the actual target parameter/buffer from the fx_module.
submod_path, _, target_name = node.target.rpartition(".")
submod: Optional[torch.nn.Module] = (
fx_module.get_submodule(submod_path) if submod_path else fx_module
)
assert submod is not None, f"submod {submod_path} not found"
target = getattr(submod, target_name, None)
assert target is not None, f"{target_name} not an attr of {submod_path}"
qualname = prefix + node.target
# Check that the target is a tensor, and that we haven't added it already from a leaf module.
if isinstance(target, torch.Tensor) and qualname not in weights:
weight = serialize_weight(target, weights, qualname)
serialized_dict["weights"].update(weight)
weights[qualname] = target
# Find the actual target parameter/buffer from the fx_module.
submod_path, _, target_name = node.target.rpartition(".")
submod: Optional[torch.nn.Module] = (
fx_module.get_submodule(submod_path) if submod_path else fx_module
)
assert submod is not None, f"submod {submod_path} not found"
target = getattr(submod, target_name, None)
assert target is not None, f"{target_name} not an attr of {submod_path}"
# Check that the target is a tensor, and that we haven't added it already from a leaf module.
if isinstance(target, torch.Tensor) and qualname not in weights:
weight = serialize_weight(target, weights, qualname)
_update_weight_fused_dtypes(weight, qualname, node)
serialized_dict["weights"].update(weight)
weights[qualname] = target
node_rep["op_code"] = node.op
node_rep["name"] = node.name