# encoding: utf-8
import operator

import torch  # isort:skip
from typing import Sequence, Optional, List, cast

import torch.fx.experimental.fx_acc.acc_utils as acc_utils
import torch.nn as nn
from torch.fx.experimental.fx_acc.acc_normalizer import (
    register_acc_op,
    register_acc_op_mapping,
    register_custom_acc_mapper_fn,
)
from torch.fx.passes.shape_prop import _extract_tensor_metadata

this_arg_is_optional = True


@register_acc_op_mapping(op_and_target=("call_function", nn.functional.linear))
@register_acc_op
def linear(*, input, weight, bias):
    return nn.functional.linear(**locals())


@register_acc_op
def quantized_linear(*, input, weight, bias, acc_out_ty=None):
    assert acc_out_ty is not None
    return nn.quantized.functional.linear(
        input,
        weight,
        bias,
        acc_utils.get_field_from_acc_out_ty(acc_out_ty, "q_scale"),
        acc_utils.get_field_from_acc_out_ty(acc_out_ty, "q_zero_point"),
    )
@register_acc_op_mapping(
    op_and_target=("call_method", "flatten"),
    arg_replacement_tuples=[
        ("input", "input"),
        ("start_dim", "start_dim", this_arg_is_optional),
        ("end_dim", "end_dim", this_arg_is_optional),
    ],
)
@register_acc_op_mapping(op_and_target=("call_function", torch.flatten))
@register_acc_op
def flatten(*, input, start_dim=0, end_dim=-1):
    return torch.flatten(**locals())


@register_acc_op_mapping(
    op_and_target=("call_method", "squeeze"),
    arg_replacement_tuples=[
        ("input", "input"),
        ("dim", "dim", this_arg_is_optional),
    ],
)
@register_acc_op_mapping(
    op_and_target=("call_function", torch.squeeze),
    arg_replacement_tuples=[
        ("input", "input"),
        ("dim", "dim", this_arg_is_optional),
    ],
)
@register_acc_op
def squeeze(*, input, dim=None):
    if dim is None:
        return input.squeeze()
    return input.squeeze(dim=dim)
@register_acc_op_mapping(op_and_target=("call_function", nn.functional.max_pool2d))
@register_acc_op
def max_pool2d(
    *, input, kernel_size, stride, padding, dilation, ceil_mode, return_indices
):
    return nn.functional.max_pool2d(**locals())


@register_acc_op_mapping(
    op_and_target=("call_function", nn.functional.adaptive_avg_pool2d)
)
@register_acc_op
def adaptive_avg_pool2d(*, input, output_size):
    return nn.functional.adaptive_avg_pool2d(**locals())


@register_acc_op_mapping(op_and_target=("call_function", nn.functional.avg_pool2d))
@register_acc_op
def avg_pool2d(
    *,
    input,
    kernel_size,
    stride,
    padding,
    ceil_mode,
    count_include_pad,
    divisor_override,
):
    return nn.functional.avg_pool2d(**locals())


@register_acc_op_mapping(op_and_target=("call_function", torch.sign))
@register_acc_op
def sign(*, input):
    return torch.sign(input)


@register_acc_op
def size(*, input):
    return input.size()
@register_custom_acc_mapper_fn(
    op_and_target=("call_function", getattr),
    arg_replacement_tuples=[],
)
def custom_getattr_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node:
    """
    Custom function for mapping a call_function getattr to other ops. Currently only
    supports a getattr called on a torch.Tensor with attr name "shape", which is
    mapped to acc_ops.size().
    """
    # Have to use args here since getattr forces positional args.
    input_obj = node.args[0]
    attr_name = node.args[1]
    assert isinstance(input_obj, torch.fx.Node)
    assert (
        input_obj.meta["type"] == torch.Tensor
    ), f"Expected torch.Tensor type, got {input_obj.meta['type']}"
    assert (
        attr_name == "shape"
    ), f"Only supporting shape getattr for now, not {attr_name}"
    with node.graph.inserting_before(node):
        size_node = node.graph.call_function(size, kwargs={"input": input_obj})
        size_node.meta = node.meta.copy()
        return size_node


@register_custom_acc_mapper_fn(
    op_and_target=("call_method", "size"),
    arg_replacement_tuples=[
        ("input", "input"),
        ("dim", "dim", this_arg_is_optional),
    ],
)
def tensor_size_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node:
    """
    Mapping from Tensor.size() to acc_ops.size. We map size() to acc_ops.size directly
    and map size(dim) to acc_ops.size + acc_ops.getitem.
    """

    with node.graph.inserting_before(node):
        size_node = node.graph.call_function(
            size, kwargs={"input": node.kwargs["input"]}
        )

        if "dim" not in node.kwargs:
            size_node.meta = node.meta.copy()
            return size_node

        size_node.meta["type"] = torch.Size
        getitem_node = node.graph.call_function(
            getitem, kwargs={"input": size_node, "idx": node.kwargs["dim"]}
        )
        getitem_node.meta = node.meta.copy()
        return getitem_node
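# Editor's note (illustrative sketch, not part of the original file): after the two
# mappers above run, a traced call such as `x.size(1)` is rewritten into
#     s = acc_ops.size(input=x)            # full torch.Size
#     d = acc_ops.getitem(input=s, idx=1)  # the requested dimension
# while a bare `x.size()` (or `x.shape`) maps to a single acc_ops.size node.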
@register_acc_op_mapping(op_and_target=("call_function", operator.add))
@register_acc_op_mapping(op_and_target=("call_method", "add"))
@register_acc_op
def add(*, input, other):
    return input + other


@register_acc_op_mapping(op_and_target=("call_method", "unsqueeze"))
@register_acc_op_mapping(op_and_target=("call_function", torch.unsqueeze))
@register_acc_op
def unsqueeze(*, input, dim):
    return torch.unsqueeze(**locals())
@register_custom_acc_mapper_fn(
    op_and_target=("call_function", torch.stack),
    arg_replacement_tuples=[
        ("tensors", "tensors"),
        ("dim", "dim"),
    ],
)
def stack_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node:
    """
    Map torch.stack to unsqueeze + cat.
    """
    with node.graph.inserting_before(node):
        inputs = node.kwargs["tensors"]
        unsqueeze_nodes = []
        assert isinstance(inputs, Sequence)
        for i, t in enumerate(inputs):
            new_node = node.graph.create_node(
                "call_function",
                unsqueeze,
                kwargs={"input": t, "dim": node.kwargs["dim"]},
                name=f"{node.name}_unsqueeze_{i}",
            )
            new_node.meta["type"] = torch.Tensor
            unsqueeze_nodes.append(new_node)
        cat_node = node.graph.create_node(
            "call_function",
            cat,
            kwargs={"tensors": unsqueeze_nodes, "dim": node.kwargs["dim"]},
        )
        cat_node.meta = node.meta.copy()
        return cat_node
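# Editor's note (illustrative sketch, not part of the original file): stack_mapper
# turns `torch.stack([a, b], dim=0)` in the traced graph into
#     a0  = acc_ops.unsqueeze(input=a, dim=0)
#     b0  = acc_ops.unsqueeze(input=b, dim=0)
#     out = acc_ops.cat(tensors=[a0, b0], dim=0)
# which is numerically equivalent to the original stack.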
@register_acc_op_mapping(op_and_target=("call_function", torch.clamp))
@register_acc_op_mapping(op_and_target=("call_method", "clamp"))
@register_acc_op
def clamp(*, input, min, max):
    return torch.clamp(**locals())


@register_acc_op_mapping(op_and_target=("call_function", torch.cat))
@register_acc_op
def cat(*, tensors, dim):
    return torch.cat(**locals())
@register_custom_acc_mapper_fn(
    op_and_target=("call_function", torch.transpose),
    arg_replacement_tuples=[
        ("input", "input"),
        ("dim0", "dim0"),
        ("dim1", "dim1"),
    ],
)
@register_custom_acc_mapper_fn(
    op_and_target=("call_method", "transpose"),
    arg_replacement_tuples=[
        ("input", "input"),
        ("dim0", "dim0"),
        ("dim1", "dim1"),
    ],
)
def transpose_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node:
    # Get the dim-permutation/shuffle
    shape_as_list = node.meta["tensor_meta"].shape
    ranks = len(shape_as_list)
    shuffle = list(i for i in range(ranks))
    dim0 = cast(int, node.kwargs["dim0"])
    dim1 = cast(int, node.kwargs["dim1"])
    shuffle[dim0] = dim1
    shuffle[dim1] = dim0

    # Create the new acc_ops.permute node. Update all uses of the transpose
    # node and then delete the transpose node.
    with node.graph.inserting_after(node):
        permute_node = node.graph.call_function(
            the_function=permute,
            kwargs={
                "input": node.kwargs.get("input"),
                "permutation": shuffle,
            },
        )
        permute_node.meta = node.meta.copy()
        node.replace_all_uses_with(permute_node)

    permute_node.graph.erase_node(node)
    return permute_node
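# Editor's note (illustrative sketch, not part of the original file): for a rank-3
# tensor, `x.transpose(0, 2)` is rewritten by transpose_mapper into
#     acc_ops.permute(input=x, permutation=[2, 1, 0])
# i.e. the identity permutation with dim0 and dim1 swapped.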
@register_acc_op_mapping(op_and_target=("call_method", "contiguous"))
@register_acc_op
def contiguous(*, input):
    return input.contiguous()


@register_acc_op_mapping(op_and_target=("call_function", torch.nn.functional.softmax))
@register_acc_op
def softmax(*, input, dim, dtype):
    """
    The `_stacklevel` argument is ignored here.
    """
    return torch.nn.functional.softmax(**locals())
@register_custom_acc_mapper_fn(
    op_and_target=("call_function", torch.addmm),
    arg_replacement_tuples=[
        ("input", "input"),
        ("mat1", "mat1"),
        ("mat2", "mat2"),
        ("beta", "beta"),
        ("alpha", "alpha"),
    ],
)
def addmm_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node:
    """
    Mapping from torch.addmm to acc_ops.mm -> acc_ops.add. If alpha or beta is not 1,
    we also insert acc_ops.mul in the right place.
    """
    with node.graph.inserting_before(node):
        mm_kwargs = {"input": node.kwargs["mat1"], "other": node.kwargs["mat2"]}
        mm_node = node.graph.create_node(
            "call_function", matmul, kwargs=mm_kwargs, name=f"{node.name}_mm"
        )
        mm_node.meta = node.meta.copy()

        if node.kwargs["alpha"] != 1:
            mul_kwargs = {"input": mm_node, "other": node.kwargs["alpha"]}
            mm_node = node.graph.create_node(
                "call_function", mul, kwargs=mul_kwargs, name=f"{mm_node.name}_mul"
            )
            mm_node.meta = node.meta.copy()

        input_node = node.kwargs["input"]
        if node.kwargs["beta"] != 1:
            mul_kwargs = {"input": input_node, "other": node.kwargs["beta"]}
            new_input_node = node.graph.create_node(
                "call_function", mul, kwargs=mul_kwargs, name=f"{node.name}_input_mul"
            )
            assert isinstance(input_node, torch.fx.Node)
            new_input_node.meta = input_node.meta.copy()
            input_node = new_input_node

        add_kwargs = {"input": mm_node, "other": input_node}
        add_node = node.graph.create_node(
            "call_function", add, kwargs=add_kwargs, name=f"{node.name}_add"
        )
        add_node.meta = node.meta.copy()
        return add_node
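# Editor's note (illustrative sketch, not part of the original file): addmm computes
# beta * input + alpha * (mat1 @ mat2), so `torch.addmm(c, a, b, beta=2, alpha=3)`
# becomes, after this mapper,
#     mm  = acc_ops.matmul(input=a, other=b)
#     mm3 = acc_ops.mul(input=mm, other=3)
#     c2  = acc_ops.mul(input=c, other=2)
#     out = acc_ops.add(input=mm3, other=c2)
# and the two mul nodes are skipped entirely when alpha == 1 and beta == 1.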
@register_custom_acc_mapper_fn(
    op_and_target=("call_function", torch.t),
    arg_replacement_tuples=[
        ("input", "input"),
    ],
)
@register_custom_acc_mapper_fn(
    op_and_target=("call_method", "t"),
    arg_replacement_tuples=[
        ("input", "input"),
    ],
)
def t_mapper(node: torch.fx.Node, _: nn.Module):
    ranks = len(node.meta["tensor_meta"].shape)
    shuffle = [1, 0] if (ranks > 1) else [0]

    with node.graph.inserting_before(node):
        new_node = node.graph.create_node(
            "call_function",
            permute,
            kwargs={"input": node.kwargs["input"], "permutation": shuffle},
        )
        new_node.meta = node.meta.copy()
        return new_node


@register_acc_op_mapping(
    op_and_target=("call_method", "permute"),
    arg_replacement_tuples=[
        ("input", "input"),
        ("*", "permutation"),
    ],
)
@register_acc_op
def permute(*, input, permutation):
    return input.permute(*permutation)
@register_custom_acc_mapper_fn(
    op_and_target=("call_function", torch.square),
    arg_replacement_tuples=[
        ("input", "input"),
    ],
)
def square_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node:
    input_node = node.kwargs["input"]
    with node.graph.inserting_before(node):
        new_node = node.graph.call_function(
            mul, kwargs={"input": input_node, "other": input_node}
        )
        new_node.meta = node.meta.copy()
        return new_node


@register_acc_op_mapping(
    op_and_target=("call_function", torch.bmm),
    arg_replacement_tuples=[
        ("input", "input"),
        ("mat2", "other"),
    ],
)
@register_acc_op_mapping(op_and_target=("call_function", torch.matmul))
@register_acc_op
def matmul(*, input, other):
    return torch.matmul(**locals())
@register_custom_acc_mapper_fn(
    op_and_target=("call_function", nn.functional.dropout),
    arg_replacement_tuples=[("input", "input")],
)
@register_custom_acc_mapper_fn(
    op_and_target=("call_method", "detach"),
    arg_replacement_tuples=[("input", "input")],
)
def dropout_mapper(node: torch.fx.Node, mod: nn.Module):
    """
    Remove dropout node and directly map its input to output.
    """
    return node.kwargs["input"]


@register_acc_op_mapping(
    op_and_target=("call_function", nn.functional.hardtanh),
    arg_replacement_tuples=[
        ("input", "input"),
        ("min_val", "left"),
        ("max_val", "right"),
    ],
)
@register_acc_op
def hardtanh(*, input, left, right):
    return nn.functional.hardtanh(input, min_val=left, max_val=right)
@register_acc_op_mapping(
    op_and_target=("call_function", nn.functional.hardsigmoid)
)
@register_acc_op
def hardsigmoid(*, input):
    return nn.functional.hardsigmoid(input)


@register_custom_acc_mapper_fn(
    op_and_target=("call_function", nn.functional.hardswish),
    arg_replacement_tuples=[
        ("input", "input"),
    ],
)
def hardswish_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node:
    input_node = node.kwargs["input"]
    with node.graph.inserting_before(node):
        new_sigmoid_node = node.graph.call_function(
            hardsigmoid, kwargs={"input": input_node}
        )
        new_sigmoid_node.meta = node.meta.copy()
        new_node = node.graph.call_function(
            mul, kwargs={"input": new_sigmoid_node, "other": input_node}
        )
        new_node.meta = node.meta.copy()
        return new_node
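# Editor's note (illustrative sketch, not part of the original file): this decomposes
# hardswish(x) = x * hardsigmoid(x), so `nn.functional.hardswish(x)` becomes
#     hs  = acc_ops.hardsigmoid(input=x)
#     out = acc_ops.mul(input=hs, other=x)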
@register_acc_op_mapping(
    op_and_target=("call_function", torch.ops.quantized.add),
    arg_replacement_tuples=[
        ("qa", "input"),
        ("qb", "other"),
        ("scale", "scale"),
        ("zero_point", "zero_point"),
    ],
    kwargs_to_move_to_acc_out_ty=[
        ("scale", "q_scale"),
        ("zero_point", "q_zero_point"),
    ],
)
@register_acc_op
def quantized_add(*, input, other, acc_out_ty=None):
    assert acc_out_ty is not None
    return torch.ops.quantized.add(
        input,
        other,
        acc_utils.get_field_from_acc_out_ty(acc_out_ty, "q_scale"),
        acc_utils.get_field_from_acc_out_ty(acc_out_ty, "q_zero_point"),
    )


@register_acc_op_mapping(
    op_and_target=("call_function", torch.ops.quantized.mul),
    arg_replacement_tuples=[
        ("qa", "input"),
        ("qb", "other"),
        ("scale", "scale"),
        ("zero_point", "zero_point"),
    ],
    kwargs_to_move_to_acc_out_ty=[
        ("scale", "q_scale"),
        ("zero_point", "q_zero_point"),
    ],
)
@register_acc_op
def quantized_mul(*, input, other, acc_out_ty=None):
    assert acc_out_ty is not None
    return torch.ops.quantized.mul(
        input,
        other,
        acc_utils.get_field_from_acc_out_ty(acc_out_ty, "q_scale"),
        acc_utils.get_field_from_acc_out_ty(acc_out_ty, "q_zero_point"),
    )


@register_acc_op_mapping(
    op_and_target=("call_function", torch.quantize_per_tensor),
    arg_replacement_tuples=[
        ("input", "input"),
        ("scale", "scale"),
        ("zero_point", "zero_point"),
        ("dtype", "dtype"),
    ],
    kwargs_to_move_to_acc_out_ty=[
        ("dtype", "dtype"),
        ("scale", "q_scale"),
        ("zero_point", "q_zero_point"),
    ],
)
@register_acc_op
def quantize_per_tensor(*, input, acc_out_ty=None):
    assert acc_out_ty is not None
    return torch.quantize_per_tensor(
        input,
        acc_utils.get_field_from_acc_out_ty(acc_out_ty, "q_scale"),
        acc_utils.get_field_from_acc_out_ty(acc_out_ty, "q_zero_point"),
        acc_utils.get_field_from_acc_out_ty(acc_out_ty, "dtype"),
    )
@register_acc_op_mapping(op_and_target=("call_method", "dequantize"))
@register_acc_op_mapping(op_and_target=("call_function", torch.dequantize))
@register_acc_op
def dequantize(*, input):
    return torch.dequantize(input)


@register_acc_op_mapping(op_and_target=("call_function", operator.sub))
@register_acc_op
def sub(*, input, other):
    return input - other


@register_acc_op_mapping(op_and_target=("call_function", torch.mul))
@register_acc_op_mapping(op_and_target=("call_function", operator.mul))
@register_acc_op_mapping(op_and_target=("call_method", "mul"))
@register_acc_op
def mul(*, input, other):
    return input * other


@register_acc_op_mapping(op_and_target=("call_function", operator.truediv))
@register_acc_op
def div(*, input, other):
    return input / other


@register_acc_op_mapping(op_and_target=("call_function", torch.pow))
@register_acc_op
def pow(*, input, exponent):
    return torch.pow(input, exponent)


@register_acc_op_mapping(op_and_target=("call_function", nn.functional.relu))
@register_acc_op_mapping(
    op_and_target=("call_function", torch.relu),
    arg_replacement_tuples=[("input", "input")],
)
@register_acc_op_mapping(
    op_and_target=("call_method", "relu"),
    arg_replacement_tuples=[("input", "input")],
)
@register_acc_op
def relu(*, input, inplace=False):
    return nn.functional.relu(**locals())
@register_custom_acc_mapper_fn(
    op_and_target=("call_function", torch.log1p),
    arg_replacement_tuples=[
        ("input", "input"),
    ],
)
def torch_log1p_mapper(node: torch.fx.Node, _: torch.nn.Module) -> torch.fx.Node:
    with node.graph.inserting_before(node):
        add_kwargs = {"input": node.kwargs["input"], "other": 1.0}
        add_node = node.graph.call_function(add, kwargs=add_kwargs)
        add_node.meta = node.meta.copy()
        log_kwargs = {"input": add_node}
        log_node = node.graph.call_function(log, kwargs=log_kwargs)
        log_node.meta = node.meta.copy()
        return log_node
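# Editor's note (illustrative sketch, not part of the original file): log1p(x) is
# decomposed as log(x + 1), so `torch.log1p(x)` becomes
#     y   = acc_ops.add(input=x, other=1.0)
#     out = acc_ops.log(input=y)
# Note this trades away log1p's extra numerical precision for very small x.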
@register_custom_acc_mapper_fn(
    op_and_target=("call_method", "sum"),
    arg_replacement_tuples=[
        ("input", "input"),
        ("dim", "dim", this_arg_is_optional),
        ("keepdim", "keepdim", this_arg_is_optional),
        ("dtype", "dtype", this_arg_is_optional),
    ],
)
@register_custom_acc_mapper_fn(
    op_and_target=("call_function", torch.sum),
    arg_replacement_tuples=[
        ("input", "input"),
        ("dim", "dim", this_arg_is_optional),
        ("keepdim", "keepdim", this_arg_is_optional),
        ("dtype", "dtype", this_arg_is_optional),
    ],
)
def add_sum_mapper(node: torch.fx.Node, mod: torch.fx.GraphModule) -> torch.fx.Node:
    with node.graph.inserting_before(node):
        sum_kwargs = dict(node.kwargs)
        if "dim" in sum_kwargs and isinstance(sum_kwargs["dim"], int):
            sum_kwargs["dim"] = (sum_kwargs["dim"],)
        sum_node = node.graph.call_function(sum, kwargs=sum_kwargs)
        sum_node.meta = node.meta.copy()
        return sum_node


@register_acc_op
def sum(*, input, dim=None, keepdim=False, dtype=None):
    if dim is not None:
        return torch.sum(**locals())
    else:
        return input.sum(dtype=dtype)
@register_custom_acc_mapper_fn(
    op_and_target=("call_method", "max"),
    arg_replacement_tuples=[
        ("input", "input"),
        (("dim", "other"), "dim_or_other", this_arg_is_optional),
        ("keepdim", "keepdim", this_arg_is_optional),
    ],
)
@register_custom_acc_mapper_fn(
    op_and_target=("call_function", torch.max),
    arg_replacement_tuples=[
        ("input", "input"),
        (("dim", "other"), "dim_or_other", this_arg_is_optional),
        ("keepdim", "keepdim", this_arg_is_optional),
    ],
)
@register_custom_acc_mapper_fn(
    op_and_target=("call_method", "min"),
    arg_replacement_tuples=[
        ("input", "input"),
        (("dim", "other"), "dim_or_other", this_arg_is_optional),
        ("keepdim", "keepdim", this_arg_is_optional),
    ],
)
@register_custom_acc_mapper_fn(
    op_and_target=("call_function", torch.min),
    arg_replacement_tuples=[
        ("input", "input"),
        (("dim", "other"), "dim_or_other", this_arg_is_optional),
        ("keepdim", "keepdim", this_arg_is_optional),
    ],
)
def add_maximum_minimum_mapper(
    node: torch.fx.Node, mod: torch.fx.GraphModule
) -> torch.fx.Node:
    # There are effectively three versions of torch.max / torch.min:
    #   full reduce: torch.max(input) -> Tensor
    #   dimensional reduce: torch.max(input, dim, keepdim=False, *, out=None) -> (Tensor, LongTensor)
    #   elementwise: torch.max(input, other, *, out=None) -> Tensor

    # The mapper function remaps both the min and max situations. This helper
    # makes the available choices clearer and provides an easier way to look up
    # the right function.
    def target_map(op, target):
        if (op, target) in (("call_method", "max"), ("call_function", torch.max)):
            return dict(
                full_reduce=max_full_reduce,
                dim_reduce=max_dim_reduce,
                elementwise=maximum,
            )
        elif (op, target) in (("call_method", "min"), ("call_function", torch.min)):
            return dict(
                full_reduce=min_full_reduce,
                dim_reduce=min_dim_reduce,
                elementwise=minimum,
            )

    with node.graph.inserting_before(node):
        new_targets = target_map(node.op, node.target)
        max_kwargs = dict()
        max_kwargs["input"] = node.kwargs["input"]
        if ("dim_or_other" not in node.kwargs) or (node.kwargs["dim_or_other"] is None):
            nt = new_targets["full_reduce"]
            max_node = node.graph.call_function(nt, kwargs=max_kwargs)
        elif isinstance(node.kwargs["dim_or_other"], int):
            nt = new_targets["dim_reduce"]
            dim = node.kwargs["dim_or_other"]
            max_kwargs["dim"] = dim
            max_kwargs["keepdim"] = node.kwargs.get("keepdim", False)
            max_node = node.graph.call_function(nt, kwargs=max_kwargs)
        else:
            other = node.kwargs["dim_or_other"]
            assert isinstance(other, torch.fx.Node)
            # Lowering path for when provided "other", where we do elem-wise max.
            nt = new_targets["elementwise"]
            max_kwargs["other"] = other
            max_node = node.graph.call_function(nt, kwargs=max_kwargs)
        max_node.meta = node.meta.copy()
        return max_node
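# Editor's note (illustrative sketch, not part of the original file): the three call
# forms therefore lower to three different acc ops, e.g. for max:
#     torch.max(x)        -> acc_ops.max_full_reduce(input=x)
#     torch.max(x, dim=1) -> acc_ops.max_dim_reduce(input=x, dim=1, keepdim=False)
#     torch.max(x, y)     -> acc_ops.maximum(input=x, other=y)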
@register_acc_op
def max_full_reduce(*, input):
    return torch.max(**locals())


@register_acc_op
def max_dim_reduce(*, input, dim=None, keepdim=False):
    return torch.max(**locals())


@register_acc_op_mapping(op_and_target=("call_function", torch.maximum))
@register_acc_op_mapping(op_and_target=("call_method", "maximum"))
@register_acc_op
def maximum(*, input, other):
    return torch.maximum(**locals())


@register_acc_op
def min_full_reduce(*, input):
    return torch.min(input)


@register_acc_op
def min_dim_reduce(*, input, dim=None, keepdim=False):
    return torch.min(input, dim=dim, keepdim=keepdim)


@register_acc_op_mapping(op_and_target=("call_function", torch.minimum))
@register_acc_op_mapping(op_and_target=("call_method", "minimum"))
@register_acc_op
def minimum(*, input, other):
    return torch.minimum(**locals())


@register_acc_op_mapping(op_and_target=("call_function", torch.sigmoid))
@register_acc_op_mapping(op_and_target=("call_method", "sigmoid"))
@register_acc_op
def sigmoid(*, input):
    return torch.sigmoid(**locals())
@register_acc_op_mapping(op_and_target=("call_function", torch.sinh))
@register_acc_op
def sinh(*, input):
    return torch.sinh(**locals())


@register_acc_op_mapping(op_and_target=("call_function", torch.cosh))
@register_acc_op
def cosh(*, input):
    return torch.cosh(**locals())


@register_acc_op_mapping(op_and_target=("call_function", torch.tanh))
@register_acc_op_mapping(op_and_target=("call_method", "tanh"))
@register_acc_op
def tanh(*, input):
    return torch.tanh(**locals())


@register_acc_op_mapping(op_and_target=("call_function", torch.asin))
@register_acc_op
def asin(*, input):
    return torch.asin(**locals())


@register_acc_op_mapping(op_and_target=("call_function", torch.acos))
@register_acc_op
def acos(*, input):
    return torch.acos(**locals())


@register_acc_op_mapping(op_and_target=("call_function", torch.atan))
@register_acc_op
def atan(*, input):
    return torch.atan(**locals())


@register_acc_op_mapping(op_and_target=("call_function", torch.exp))
@register_acc_op
def exp(*, input):
    return torch.exp(**locals())


@register_acc_op_mapping(op_and_target=("call_function", torch.log))
@register_acc_op
def log(*, input):
    return torch.log(**locals())


@register_acc_op_mapping(op_and_target=("call_function", torch.sqrt))
@register_acc_op
def sqrt(*, input):
    return torch.sqrt(**locals())


@register_acc_op_mapping(op_and_target=("call_function", torch.reciprocal))
@register_acc_op
def reciprocal(*, input):
    return torch.reciprocal(**locals())


@register_acc_op_mapping(op_and_target=("call_function", torch.abs))
@register_acc_op
def abs(*, input):
    return torch.abs(**locals())


@register_acc_op_mapping(op_and_target=("call_function", torch.neg))
@register_acc_op
def neg(*, input):
    return torch.neg(**locals())


@register_acc_op_mapping(op_and_target=("call_function", torch.floor))
@register_acc_op
def floor(*, input):
    return torch.floor(**locals())


@register_acc_op_mapping(op_and_target=("call_function", torch.ceil))
@register_acc_op
def ceil(*, input):
    return torch.ceil(**locals())


@register_acc_op_mapping(op_and_target=("call_function", torch.conv2d))
@register_acc_op
def conv2d(*, input, weight, bias, stride, padding, dilation, groups):
    return nn.functional.conv2d(**locals())
@register_acc_op
def quantized_conv2d(
    *,
    input,
    weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    padding_mode,
    acc_out_ty=None,
):
    assert acc_out_ty is not None
    return torch.nn.quantized.functional.conv2d(
        input,
        weight,
        bias,
        stride,
        padding,
        dilation,
        groups,
        padding_mode,
        acc_utils.get_field_from_acc_out_ty(acc_out_ty, "q_scale"),
        acc_utils.get_field_from_acc_out_ty(acc_out_ty, "q_zero_point"),
    )


@register_acc_op_mapping(op_and_target=("call_function", nn.functional.batch_norm))
@register_acc_op
def batch_norm(
    *, input, running_mean, running_var, weight, bias, training, momentum, eps
):
    return nn.functional.batch_norm(**locals())


@register_acc_op_mapping(op_and_target=("call_function", nn.functional.layer_norm))
@register_acc_op
def layer_norm(*, input, normalized_shape, weight, bias, eps):
    return nn.functional.layer_norm(**locals())
def argmin_max_mapper_impl(node: torch.fx.Node, largest: bool) -> torch.fx.Node:
    """
    Map torch.argmin or torch.argmax to acc_ops.flatten (depending on dim) +
    acc_ops.topk + acc_ops.getitem + acc_ops.squeeze (depending on keepdim).
    """
    input_node = node.kwargs["input"]
    dim = node.kwargs["dim"]
    keepdim = node.kwargs["keepdim"]

    if dim is None and keepdim:
        raise RuntimeError(
            "We currently don't support argmin/argmax with dim=None and keepdim=True"
        )

    with node.graph.inserting_before(node):
        if dim is None:
            flatten_kwargs = {
                "input": node.kwargs["input"],
                "start_dim": 0,
                "end_dim": -1,
            }
            flatten_node = node.graph.call_function(flatten, kwargs=flatten_kwargs)
            flatten_node.meta["type"] = torch.Tensor
            input_node = flatten_node
            dim = -1

        topk_kwargs = {
            "input": input_node,
            "k": 1,
            "dim": dim,
            "largest": largest,
            "sorted": False,
        }
        topk_node = node.graph.call_function(topk, kwargs=topk_kwargs)
        # It's actually more like NamedTuple but tuple here should be fine.
        topk_node.meta["type"] = tuple

        getitem_kwargs = {"input": topk_node, "idx": 1}
        getitem_node = node.graph.call_function(getitem, kwargs=getitem_kwargs)
        getitem_node.meta["type"] = torch.Tensor
        output_node = getitem_node

        if not keepdim:
            squeeze_kwargs = {"input": getitem_node, "dim": dim}
            output_node = node.graph.call_function(squeeze, kwargs=squeeze_kwargs)

        output_node.meta = node.meta.copy()
        return output_node
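# Editor's note (illustrative sketch, not part of the original file): for example,
# `torch.argmax(x, dim=1)` lowers to
#     values_indices = acc_ops.topk(input=x, k=1, dim=1, largest=True, sorted=False)
#     indices        = acc_ops.getitem(input=values_indices, idx=1)
#     out            = acc_ops.squeeze(input=indices, dim=1)
# and when dim is None the input is first flattened and reduced along dim=-1.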
@register_custom_acc_mapper_fn(
    op_and_target=("call_function", torch.argmin),
    arg_replacement_tuples=[
        ("input", "input"),
        ("dim", "dim"),
        ("keepdim", "keepdim"),
    ],
)
def torch_argmin_mapper(node: torch.fx.Node, _: torch.nn.Module) -> torch.fx.Node:
    """
    Map torch.argmin to acc_ops.flatten (depending on dim) + acc_ops.topk +
    acc_ops.getitem + acc_ops.squeeze (depending on keepdim).
    """
    return argmin_max_mapper_impl(node, largest=False)


@register_acc_op_mapping(op_and_target=("call_function", torch.linalg.norm))
@register_acc_op
def linalg_norm(*, input, ord, dim, keepdim):
    return torch.linalg.norm(**locals())
@register_custom_acc_mapper_fn(
    op_and_target=("call_method", "split"),
    arg_replacement_tuples=[
        ("tensor", "input"),
        ("split_size_or_sections", "split_size_or_sections"),
        ("dim", "dim"),
    ],
)
@register_custom_acc_mapper_fn(
    op_and_target=("call_function", torch.split),
    arg_replacement_tuples=[
        ("tensor", "input"),
        ("split_size_or_sections", "split_size_or_sections"),
        ("dim", "dim"),
    ],
)
def torch_split_mapper(node: torch.fx.Node, mod: nn.Module) -> torch.fx.Node:
    """
    If split_size_or_sections is sections, map the node to slice_tensors
    + tuple_construct. Otherwise, if split_size_or_sections is split_size,
    map the node to acc_ops.split.
    """
    split_size_or_sections = node.kwargs["split_size_or_sections"]
    with node.graph.inserting_before(node):
        if isinstance(split_size_or_sections, int):
            new_kwargs = {
                "input": node.kwargs["input"],
                "split_size": split_size_or_sections,
                "dim": node.kwargs["dim"],
            }
            new_node = node.graph.call_function(split, kwargs=new_kwargs)
            new_node.meta = node.meta.copy()
            return new_node

        assert isinstance(split_size_or_sections, Sequence)
        start = 0
        slice_nodes = []
        for i in split_size_or_sections:
            assert isinstance(i, int)
            new_kwargs = {
                "input": node.kwargs["input"],
                "dims": (node.kwargs["dim"],),
                "starts": (start,),
                "stops": (start + i,),
                "steps": (1,),
            }
            new_node = node.graph.call_function(slice_tensor, kwargs=new_kwargs)
            new_node.meta["type"] = torch.Tensor
            slice_nodes.append(new_node)
            start += i

        new_node = node.graph.call_function(
            tuple_construct, kwargs={"tensors": tuple(slice_nodes)}
        )
        new_node.meta = node.meta.copy()
        return new_node
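# Editor's note (illustrative sketch, not part of the original file): with explicit
# sections, e.g. `torch.split(x, [2, 3], dim=0)`, the mapper emits one slice per
# section plus a tuple_construct:
#     s0  = acc_ops.slice_tensor(input=x, dims=(0,), starts=(0,), stops=(2,), steps=(1,))
#     s1  = acc_ops.slice_tensor(input=x, dims=(0,), starts=(2,), stops=(5,), steps=(1,))
#     out = acc_ops.tuple_construct(tensors=(s0, s1))
# whereas an int split size goes straight to acc_ops.split.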
@register_acc_op
def split(*, input, split_size, dim):
    return torch.split(input, split_size, dim)


@register_acc_op
def tuple_construct(*, tensors):
    return tuple(tensors)
@register_acc_op_mapping(
    op_and_target=("call_function", torch.ops.quantized.batch_norm2d),
    arg_replacement_tuples=[
        ("input", "input"),
        ("weight", "weight"),
        ("bias", "bias"),
        ("running_mean", "running_mean"),
        ("running_var", "running_var"),
        ("eps", "eps"),
        ("scale", "scale"),
        ("zero_point", "zero_point"),
    ],
    kwargs_to_move_to_acc_out_ty=[
        ("scale", "q_scale"),
        ("zero_point", "q_zero_point"),
    ],
)
@register_acc_op
def quantized_batch_norm2d(
    *, input, running_mean, running_var, weight, bias, eps, acc_out_ty
):
    return torch.ops.quantized.batch_norm2d(
        input,
        weight,
        bias,
        running_mean,
        running_var,
        eps,
        acc_utils.get_field_from_acc_out_ty(acc_out_ty, "q_scale"),
        acc_utils.get_field_from_acc_out_ty(acc_out_ty, "q_zero_point"),
    )


@register_acc_op_mapping(op_and_target=("call_function", nn.functional.embedding_bag))
@register_acc_op
def embedding_bag(
    *,
    input,
    weight,
    offsets,
    max_norm,
    norm_type,
    scale_grad_by_freq,
    mode,
    sparse,
    per_sample_weights,
    include_last_offset,
    padding_idx,
):
    return nn.functional.embedding_bag(**locals())
@register_acc_op_mapping(
    op_and_target=(
        "call_function",
        torch.ops.quantized.embedding_bag_byte_rowwise_offsets,
    )
)
@register_acc_op
def embedding_bag_byte_rowwise_offsets(
    *,
    weight,
    indices,
    offsets,
    scale_grad_by_freq,
    mode,
    pruned_weights,
    per_sample_weights,
    compressed_indices_mapping,
    include_last_offset,
):
    return torch.ops.quantized.embedding_bag_byte_rowwise_offsets(**locals())


@register_acc_op_mapping(
    op_and_target=(
        "call_function",
        torch.ops.quantized.embedding_bag_4bit_rowwise_offsets,
    )
)
@register_acc_op
def embedding_bag_4bit_rowwise_offsets(
    *,
    weight,
    indices,
    offsets,
    scale_grad_by_freq,
    mode,
    pruned_weights,
    per_sample_weights,
    compressed_indices_mapping,
    include_last_offset,
):
    return torch.ops.quantized.embedding_bag_4bit_rowwise_offsets(**locals())
@register_acc_op_mapping(op_and_target=("call_function", torch.sin))
@register_acc_op
def sin(*, input):
    return torch.sin(**locals())


@register_acc_op_mapping(op_and_target=("call_function", torch.cos))
@register_acc_op
def cos(*, input):
    return torch.cos(**locals())


@register_acc_op_mapping(op_and_target=("call_function", torch.tan))
@register_acc_op
def tan(*, input):
    return torch.tan(**locals())


@register_acc_op_mapping(op_and_target=("call_function", torch.topk))
@register_acc_op
def topk(*, input, k, dim, largest, sorted):
    return torch.topk(**locals())


@register_acc_op_mapping(op_and_target=("call_function", operator.getitem))
@register_acc_op
def getitem(*, input, idx):
    return input[idx]
@register_acc_op
def slice_tensor(*, input, dims, starts, stops, steps):
    slices: List[Optional[slice]] = [None for _ in range(input.dim())]

    # For all provided dims, extract out a slice for starts/stops/steps.
    for idx, dim in enumerate(dims):
        slices[dim] = slice(starts[idx], stops[idx], steps[idx])

    # For all unspecified dims, default to the full slice.
    for idx, s in enumerate(slices):
        if s is None:
            slices[idx] = slice(None, None, None)

    return input[slices]
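# Editor's note (illustrative sketch, not part of the original file): slice_tensor
# builds one Python slice per listed dim and leaves the rest untouched, e.g.
#     slice_tensor(input=x, dims=(1,), starts=(0,), stops=(4,), steps=(2,))
# is equivalent to `x[:, 0:4:2]` for a 2-D tensor x.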
@register_custom_acc_mapper_fn(
    op_and_target=("call_function", torch.narrow),
    arg_replacement_tuples=[
        ("input", "input"),
        ("dim", "dim"),
        ("start", "start"),
        ("length", "length"),
    ],
)
@register_custom_acc_mapper_fn(
    op_and_target=("call_method", "narrow"),
    arg_replacement_tuples=[
        ("input", "input"),
        ("dim", "dim"),
        ("start", "start"),
        ("length", "length"),
    ],
)
def custom_narrow_mapper(node: torch.fx.Node, mod: nn.Module) -> torch.fx.Node:
    assert isinstance(node.kwargs["start"], int) and isinstance(
        node.kwargs["length"], int
    )
    kwargs = {
        "input": node.kwargs["input"],
        "dims": (node.kwargs["dim"],),
        "starts": (node.kwargs["start"],),
        "stops": (node.kwargs["start"] + node.kwargs["length"],),
        "steps": (1,),
    }
    with node.graph.inserting_before(node):
        new_node = node.graph.call_function(slice_tensor, kwargs=kwargs)
        new_node.meta = node.meta.copy()
        return new_node
@register_acc_op_mapping(
    op_and_target=("call_function", torch.reshape),
    arg_replacement_tuples=[
        ("input", "input"),
        ("shape", "shape"),
    ],
    kwargs_to_move_to_acc_out_ty=[("shape", "shape")],
)
@register_acc_op_mapping(
    op_and_target=("call_method", "view"),
    arg_replacement_tuples=[
        ("input", "input"),
        ("*", "shape"),
    ],
    kwargs_to_move_to_acc_out_ty=[("shape", "shape")],
)
@register_acc_op
def reshape(*, input, acc_out_ty=None):
    assert acc_out_ty is not None
    return torch.reshape(
        input, tuple(acc_utils.get_field_from_acc_out_ty(acc_out_ty, "shape"))
    )
@register_custom_acc_mapper_fn(
    op_and_target=("call_method", "reshape"),
    arg_replacement_tuples=[
        ("input", "input"),
        ("*", "shape"),
    ],
)
def custom_tensor_reshape_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node:
    """
    For a Tensor.reshape node, args could be (input, 1, 2, 3) or (input, (1, 2, 3)).
    Here we do some special handling with the `shape` arg in order to map it to
    acc_ops.reshape. It also handles the case when `shape` is a list instead of
    a tuple.
    """
    input_node = node.kwargs["input"]
    shape = node.kwargs["shape"]

    assert isinstance(shape, Sequence)
    if isinstance(shape[0], (tuple, list)):  # type: ignore[index]
        shape = shape[0]  # type: ignore[index]

    with node.graph.inserting_before(node):
        new_node = node.graph.call_function(
            reshape,
            kwargs={
                "input": input_node,
                "acc_out_ty": acc_utils.build_raw_tensor_meta(shape=shape),
            },
        )
        new_node.meta = node.meta.copy()
        return new_node
@register_acc_op
def to_dtype(input, acc_out_ty=None):
    assert acc_out_ty is not None, "valid acc_out_ty needed"
    return input.to(dtype=acc_utils.get_field_from_acc_out_ty(acc_out_ty, "dtype"))


@register_custom_acc_mapper_fn(
    op_and_target=("call_method", "to"),
    arg_replacement_tuples=[
        ("input", "input"),
        ("dtype", "dtype"),
    ],
)
def custom_tensor_to_mapper(node: torch.fx.Node, _: nn.Module):
    dest_dtype = node.kwargs["dtype"]
    mem_format = node.kwargs.get("memory_format")
    device = node.kwargs.get("device")
    assert dest_dtype is not None
    assert mem_format is None or mem_format == torch.preserve_format
    assert device is None

    new_kwargs = {
        "input": node.kwargs["input"],
        "acc_out_ty": acc_utils.build_raw_tensor_meta(dtype=dest_dtype),
    }

    with node.graph.inserting_before(node):
        new_node = node.graph.create_node(
            "call_function", to_dtype, kwargs=new_kwargs, name=node.name
        )
        new_node.meta = node.meta
        return new_node
@register_custom_acc_mapper_fn(
    op_and_target=("call_function", torch.add),
    # Note that we may have aliases for inputs here due to issues with deterministically
    # knowing the correct target that will be resolved by pytorch.
    arg_replacement_tuples=[
        (("input", "a"), "input"),
        (("other", "b"), "other"),
        ("alpha", "alpha", this_arg_is_optional),
    ],
)
def custom_torch_add_mapper(node: torch.fx.Node, mod: nn.Module) -> torch.fx.Node:
    """
    Add custom mapping for torch.add because it has an `alpha` parameter which scales
    the `other` input, and we want to make that mul a separate node.
    """
    with node.graph.inserting_before(node):
        # If alpha is in kwargs, check whether we need to add a mul, and use the
        # correct kwargs.
        if "alpha" in node.kwargs:
            # Add a mul node only if it has a numerical impact, i.e. alpha != 1.0.
            if node.kwargs["alpha"] != 1.0:
                other_node = node.graph.create_node(
                    "call_function",
                    mul,
                    kwargs={
                        "input": node.kwargs["other"],
                        "other": node.kwargs["alpha"],
                    },
                    name=node.name + "_mul_alpha",
                )
                other_node.meta = node.meta
            else:
                other_node = node.kwargs["other"]
            add_kwargs = {"input": node.kwargs["input"], "other": other_node}
        else:
            add_kwargs = node.kwargs

        new_node = node.graph.create_node(
            "call_function", add, kwargs=add_kwargs, name=node.name
        )
        new_node.meta = node.meta
        return new_node
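# Editor's note (illustrative sketch, not part of the original file): since
# torch.add(x, y, alpha=a) computes x + a * y, a call with alpha != 1.0 such as
# `torch.add(x, y, alpha=2.0)` lowers to
#     y2  = acc_ops.mul(input=y, other=2.0)
#     out = acc_ops.add(input=x, other=y2)
# while the default alpha == 1.0 maps to a single acc_ops.add.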
@register_custom_acc_mapper_fn(
    op_and_target=("call_module", nn.quantized.Linear),
    arg_replacement_tuples=[
        ("input", "input"),
    ],
)
def packed_quantized_linear_mapper(
    node: torch.fx.Node, mod: nn.Module
) -> torch.fx.Node:
    """
    Mapping from a quantized Linear module to acc_ops.quantized_linear. We unpack the
    weight and bias in this mapper and pass them directly to the linear node.
    """
    assert isinstance(node.target, str)
    linear_module = dict(mod.named_modules())[node.target]
    prefix = node.target.replace(".", "_")
    weight_name = f"{prefix}_weight"
    bias_name = f"{prefix}_bias"

    # Store weight and bias in the main module
    mod.register_buffer(weight_name, linear_module.weight())
    if linear_module.bias() is not None:
        mod.register_buffer(bias_name, linear_module.bias())

    with node.graph.inserting_before(node):
        # Insert get_attr nodes for weight and bias
        get_weight = node.graph.get_attr(weight_name)
        get_weight.meta["tensor_meta"] = _extract_tensor_metadata(
            linear_module.weight()
        )

        get_bias = None
        if linear_module.bias() is not None:
            get_bias = node.graph.get_attr(bias_name)
            get_bias.meta["tensor_meta"] = _extract_tensor_metadata(
                linear_module.bias()
            )

        # Create kwargs for acc_op.quantized_linear
        kwargs = {
            "input": node.kwargs["input"],
            "weight": get_weight,
            "bias": get_bias,
            "acc_out_ty": acc_utils.build_raw_tensor_meta(
                q_scale=linear_module.scale, q_zero_point=linear_module.zero_point
            ),
        }

        new_node = node.graph.call_function(quantized_linear, kwargs=kwargs)
        new_node.meta = node.meta
        return new_node
@register_custom_acc_mapper_fn(
    op_and_target=("call_module", nn.quantized.Conv2d),
    arg_replacement_tuples=[
        ("input", "input"),
    ],
)
def packed_quantized_conv2d_mapper(
    node: torch.fx.Node, mod: nn.Module
) -> torch.fx.Node:
    """
    Mapping from a quantized Conv2d module to acc_ops.quantized_conv2d. We unpack all
    the parameters in this mapper and pass them directly to the conv2d node.
    """
    assert isinstance(node.target, str)
    conv_module = dict(mod.named_modules())[node.target]
    prefix = node.target.replace(".", "_")
    weight_name = f"{prefix}_weight"
    bias_name = f"{prefix}_bias"

    # Store weight and bias in the main module
    mod.register_buffer(weight_name, conv_module.weight())
    if conv_module.bias() is not None:
        mod.register_buffer(bias_name, conv_module.bias())

    with node.graph.inserting_before(node):
        # Insert get_attr nodes for weight and bias
        get_weight = node.graph.get_attr(weight_name)
        get_weight.meta["tensor_meta"] = _extract_tensor_metadata(conv_module.weight())

        get_bias = None
        if conv_module.bias() is not None:
            get_bias = node.graph.get_attr(bias_name)
            get_bias.meta["tensor_meta"] = _extract_tensor_metadata(conv_module.bias())

        # Create kwargs for acc_op.conv
        kwargs = {
            "input": node.kwargs["input"],
            "weight": get_weight,
            "bias": get_bias,
            "stride": conv_module.stride,
            "padding": conv_module.padding,
            "dilation": conv_module.dilation,
            "groups": conv_module.groups,
            "padding_mode": conv_module.padding_mode,
            "acc_out_ty": acc_utils.build_raw_tensor_meta(
                q_scale=conv_module.scale, q_zero_point=conv_module.zero_point
            ),
        }

        new_node = node.graph.call_function(quantized_conv2d, kwargs=kwargs)
        new_node.meta = node.meta
        return new_node
@register_custom_acc_mapper_fn(
    op_and_target=("call_function", torch.ops.quantized.add_relu),
    arg_replacement_tuples=[
        ("input", "input"),
        ("other", "other"),
        ("scale", "scale"),
        ("zero_point", "zero_point"),
    ],
)
def add_relu_unfuse_mapper(
    node: torch.fx.Node, mod: torch.fx.GraphModule
) -> torch.fx.Node:
    with node.graph.inserting_before(node):
        add_kwargs = {
            "input": node.kwargs["input"],
            "other": node.kwargs["other"],
            "acc_out_ty": acc_utils.build_raw_tensor_meta(
                q_scale=node.kwargs["scale"],
                q_zero_point=node.kwargs["zero_point"],
            ),
        }
        add_node = node.graph.call_function(quantized_add, kwargs=add_kwargs)
        add_node.meta = node.meta.copy()

        relu_node = node.graph.call_function(
            relu, kwargs={"input": add_node, "inplace": False}
        )
        relu_node.meta = node.meta
        return relu_node
@register_custom_acc_mapper_fn(
    op_and_target=("call_module", nn.intrinsic.quantized.ConvReLU2d),
    arg_replacement_tuples=[
        ("input", "input"),
    ],
)
def packed_quantized_convrelu2d_mapper(
    node: torch.fx.Node, mod: nn.Module
) -> torch.fx.Node:
    """
    Mapping from a quantized ConvReLU2d module to acc_ops.quantized_conv2d +
    acc_ops.relu. We reuse packed_quantized_conv2d_mapper to unpack all the
    parameters and then feed the returned conv2d node directly into a relu node.
    """

    with node.graph.inserting_before(node):
        # conv2d op
        conv2d_node = packed_quantized_conv2d_mapper(node, mod)

        # relu op
        relu_node = node.graph.call_function(
            relu, kwargs={"input": conv2d_node, "inplace": False}
        )
        relu_node.meta = node.meta
        return relu_node