mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add truediv
support in export serializer (#136364)
Fixes #136113 - [x] Inital `truediv` coverage - [ ] Expand/reduce coverage? - [x] Add tests - [x] Re-check docstrings - [ ] Linting Pull Request resolved: https://github.com/pytorch/pytorch/pull/136364 Approved by: https://github.com/pianpwk Co-authored-by: Angela Yi <angelayi@meta.com> Co-authored-by: Pian Pawakapan <pianpwk@meta.com>
This commit is contained in:
@ -890,6 +890,7 @@ API Reference
|
||||
.. autoclass:: OutputSpec
|
||||
.. autoclass:: SymIntArgument
|
||||
.. autoclass:: SymBoolArgument
|
||||
.. autoclass:: SymFloatArgument
|
||||
.. autoclass:: ExportGraphSignature
|
||||
|
||||
.. automethod:: replace_all_uses
|
||||
|
@ -870,6 +870,17 @@ graph():
|
||||
ep_res = torch.export.export(M(), input).module()(*input)
|
||||
self.assertEqual(orig_res, ep_res)
|
||||
|
||||
def test_symfloat_item(self):
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, tensor):
|
||||
return tensor.item()
|
||||
|
||||
input = (torch.tensor([3.14], dtype=torch.float),)
|
||||
|
||||
orig_res = M()(*input)
|
||||
ep_res = torch.export.export(M(), input).module()(*input)
|
||||
self.assertEqual(orig_res, ep_res)
|
||||
|
||||
def test_unbacked_to_cond(self):
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, a):
|
||||
@ -6544,8 +6555,6 @@ graph():
|
||||
ep = export(m, inputs)
|
||||
self.assertEqual(ep.module()(*inputs), m(*inputs))
|
||||
|
||||
@testing.expectedFailureSerDer # symfloat nyi
|
||||
@testing.expectedFailureSerDerNonStrict
|
||||
def test_sym_sqrt(self):
|
||||
import math
|
||||
|
||||
|
@ -322,6 +322,31 @@ def forward(self, x):
|
||||
self.assertEqual(node.inputs[0].name, "self")
|
||||
self.assertEqual(node.inputs[1].name, "dim")
|
||||
|
||||
def test_serialize_sym_float(self) -> None:
|
||||
class DynamicFloatSimpleModel(torch.nn.Module):
|
||||
def __init__(self, multiplier: torch.SymFloat):
|
||||
super().__init__()
|
||||
self.multiplier = multiplier
|
||||
|
||||
def forward(self, a, b, c) -> torch.Tensor:
|
||||
d = (torch.matmul(a, b) + c) / 2
|
||||
e = d * self.multiplier
|
||||
e_s0 = e.shape[0]
|
||||
e_s1 = e.shape[1]
|
||||
e_s3 = e_s0 * e_s1
|
||||
f = e.view(e_s3)
|
||||
return torch.cat([f, f])
|
||||
|
||||
multiplier_sym = torch.SymFloat("multiplier_sym")
|
||||
model = DynamicFloatSimpleModel(multiplier_sym)
|
||||
inputs = (
|
||||
torch.randn(2, 4),
|
||||
torch.randn(4, 7),
|
||||
torch.randn(2, 7),
|
||||
)
|
||||
dim0_ac = Dim("dim0_ac")
|
||||
dim1_bc = Dim("dim1_b")
|
||||
|
||||
def test_serialize_infinite_sym_int(self) -> None:
|
||||
class DynamicShapeSimpleModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
@ -2712,6 +2712,14 @@ class AOTInductorTestsTemplate:
|
||||
inputs = (torch.tensor([0], dtype=torch.bool, device=self.device),)
|
||||
self.check_model(Model(), inputs)
|
||||
|
||||
def test_symfloat_item(self):
|
||||
class Model(torch.nn.Module):
|
||||
def forward(self, tensor):
|
||||
return tensor.item()
|
||||
|
||||
inputs = (torch.tensor([3.14], dtype=torch.float, device=self.device),)
|
||||
self.check_model(Model(), inputs)
|
||||
|
||||
def test_constant_original_fqn_and_dtype(self):
|
||||
class FooBarModule(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
|
@ -166,6 +166,8 @@ CPU_TEST_FAILURES = {
|
||||
"test_symint_item": fail_minimal_arrayref_interface(is_skip=True),
|
||||
# TODO: AttributeError: 'ShapeAsConstantBuffer' object has no attribute 'dtype'
|
||||
"test_symbool_item": fail_minimal_arrayref_interface(is_skip=True),
|
||||
# TODO: AttributeError: 'ShapeAsConstantBuffer' object has no attribute 'dtype'
|
||||
"test_symfloat_item": fail_minimal_arrayref_interface(is_skip=True),
|
||||
"test_issue_140766": fail_minimal_arrayref_interface(),
|
||||
}
|
||||
|
||||
|
@ -37,6 +37,7 @@ from torch.export.graph_signature import (
|
||||
OutputSpec,
|
||||
SymIntArgument,
|
||||
SymBoolArgument,
|
||||
SymFloatArgument,
|
||||
TensorArgument,
|
||||
)
|
||||
from torch.fx import traceback as fx_traceback
|
||||
|
@ -8,7 +8,7 @@ from typing import Dict, List, Optional, Tuple
|
||||
from torch._export.serde.union import _Union
|
||||
|
||||
# NOTE: Please update this value if any modifications are made to the schema
|
||||
SCHEMA_VERSION = (8, 1)
|
||||
SCHEMA_VERSION = (8, 2)
|
||||
TREESPEC_VERSION = 1
|
||||
|
||||
|
||||
@ -77,6 +77,11 @@ class SymInt(_Union):
|
||||
as_expr: SymExpr
|
||||
as_int: int
|
||||
|
||||
@dataclass(repr=False)
|
||||
class SymFloat(_Union):
|
||||
as_expr: SymExpr
|
||||
as_int: float
|
||||
|
||||
|
||||
@dataclass(repr=False)
|
||||
class SymBool(_Union):
|
||||
@ -106,6 +111,16 @@ class SymIntArgument(_Union):
|
||||
as_name: str
|
||||
as_int: int
|
||||
|
||||
# In most cases we will use the "as_name" field to store arguments which are
|
||||
# SymFloats.
|
||||
# The "as_float" field is used in the case where we have a list containing a mix
|
||||
# of SymFloat and float (ex. [1.0, s0, ...]). We will serialize this type of list to
|
||||
# be List[SymFloatArgument] and map the SymFloats to the "as_name" field, and ints
|
||||
# to the "as_float" field.
|
||||
@dataclass(repr=False)
|
||||
class SymFloatArgument(_Union):
|
||||
as_name: str
|
||||
as_float: float
|
||||
|
||||
# In most cases we will use the "as_name" field to store arguments which are
|
||||
# SymBools.
|
||||
@ -165,6 +180,8 @@ class Argument(_Union):
|
||||
as_strings: List[str]
|
||||
as_sym_int: SymIntArgument
|
||||
as_sym_ints: List[SymIntArgument]
|
||||
as_sym_float: SymFloatArgument
|
||||
as_sym_floats: List[SymFloatArgument]
|
||||
as_scalar_type: ScalarType
|
||||
as_memory_format: MemoryFormat
|
||||
as_layout: Layout
|
||||
@ -208,6 +225,7 @@ class Graph:
|
||||
# list.
|
||||
is_single_tensor_return: bool = False
|
||||
custom_obj_values: Dict[str, CustomObjArgument] = field(default_factory=dict)
|
||||
sym_float_values: Dict[str, SymFloat] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -1,5 +1,5 @@
|
||||
# @generated by update_schema.py
|
||||
# checksum<<19d86105f895a10d5eedbc6e13d4d96cf5d9182c0367d6825ef2438e124cc536>>
|
||||
# checksum<<b2d7665a2d5d77eca43ac97af5e691123dd82b7b2582b8e81f2c326761e2f649>>
|
||||
Argument:
|
||||
kind: union
|
||||
fields:
|
||||
@ -25,6 +25,10 @@ Argument:
|
||||
type: SymIntArgument
|
||||
as_sym_ints:
|
||||
type: List[SymIntArgument]
|
||||
as_sym_float:
|
||||
type: SymFloatArgument
|
||||
as_sym_floats:
|
||||
type: List[SymFloatArgument]
|
||||
as_scalar_type:
|
||||
type: ScalarType
|
||||
as_memory_format:
|
||||
@ -136,6 +140,9 @@ Graph:
|
||||
custom_obj_values:
|
||||
type: Dict[str, CustomObjArgument]
|
||||
default: '{}'
|
||||
sym_float_values:
|
||||
type: Dict[str, SymFloat]
|
||||
default: '{}'
|
||||
GraphArgument:
|
||||
kind: struct
|
||||
fields:
|
||||
@ -377,6 +384,20 @@ SymExprHint:
|
||||
type: float
|
||||
as_bool:
|
||||
type: bool
|
||||
SymFloat:
|
||||
kind: union
|
||||
fields:
|
||||
as_expr:
|
||||
type: SymExpr
|
||||
as_int:
|
||||
type: float
|
||||
SymFloatArgument:
|
||||
kind: union
|
||||
fields:
|
||||
as_name:
|
||||
type: str
|
||||
as_float:
|
||||
type: float
|
||||
SymInt:
|
||||
kind: union
|
||||
fields:
|
||||
@ -437,5 +458,5 @@ UserOutputSpec:
|
||||
type: Argument
|
||||
SCHEMA_VERSION:
|
||||
- 8
|
||||
- 1
|
||||
- 2
|
||||
TREESPEC_VERSION: 1
|
||||
|
@ -87,6 +87,8 @@ from .schema import ( # type: ignore[attr-defined]
|
||||
SymExprHint,
|
||||
SymInt,
|
||||
SymIntArgument,
|
||||
SymFloat,
|
||||
SymFloatArgument,
|
||||
TensorArgument,
|
||||
TensorMeta,
|
||||
TokenArgument,
|
||||
@ -118,7 +120,7 @@ def _reverse_map(d: Dict[Any, Enum]):
|
||||
|
||||
|
||||
MetaType = Union[
|
||||
FakeTensor, int, torch.SymInt, bool, torch.SymBool, ep.CustomObjArgument
|
||||
FakeTensor, int, torch.SymInt, float, torch.SymFloat, bool, torch.SymBool, ep.CustomObjArgument
|
||||
]
|
||||
|
||||
|
||||
@ -169,6 +171,9 @@ _TORCH_TO_SERIALIZE_MEMORY_FORMAT = {
|
||||
|
||||
_SERIALIZE_TO_TORCH_MEMORY_FORMAT = _reverse_map(_TORCH_TO_SERIALIZE_MEMORY_FORMAT) # type: ignore[arg-type]
|
||||
|
||||
_SYM_FLOAT_OPS = {
|
||||
operator.truediv,
|
||||
}
|
||||
|
||||
_SYM_INT_OPS = {
|
||||
operator.mul,
|
||||
@ -201,7 +206,7 @@ _SYM_BOOL_OPS = {
|
||||
|
||||
assert not any(isinstance(op, torch._ops.OpOverload) for op in _SYM_INT_OPS)
|
||||
assert not any(isinstance(op, torch._ops.OpOverload) for op in _SYM_BOOL_OPS)
|
||||
|
||||
assert not any(isinstance(op, torch._ops.OpOverload) for op in _SYM_FLOAT_OPS)
|
||||
|
||||
@dataclass
|
||||
class SerializedArtifact:
|
||||
@ -251,6 +256,25 @@ def serialize_sym_int(s: Union[int, torch.SymInt]) -> SymInt:
|
||||
f"SymInt should be either symbol or int, got `{s}` of type `{type(s)}`"
|
||||
)
|
||||
|
||||
def serialize_sym_float(s: Union[float, torch.SymFloat]) -> SymFloat:
|
||||
if isinstance(s, (torch.SymFloat, sympy.Symbol, float)):
|
||||
if symbolic_shapes.is_concrete_float(s):
|
||||
return SymFloat.create(as_float=float(s))
|
||||
else:
|
||||
assert isinstance(s, (torch.SymFloat, sympy.Symbol))
|
||||
if s.node.hint is None:
|
||||
return SymFloat.create(as_expr=SymExpr(_print_sympy(s)))
|
||||
else:
|
||||
return SymFloat.create(
|
||||
as_expr=SymExpr(
|
||||
_print_sympy(s),
|
||||
hint=SymExprHint.create(as_float=s.node.hint),
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise SerializeError(
|
||||
f"SymFloat should be either symbol or float, got `{s}` of type `{type(s)}`"
|
||||
)
|
||||
|
||||
def serialize_sym_bool(s: Union[bool, torch.SymBool]) -> SymBool:
|
||||
if isinstance(s, (torch.SymBool, bool)):
|
||||
@ -425,6 +449,7 @@ class GraphState:
|
||||
tensor_values: Dict[str, TensorMeta] = field(default_factory=dict)
|
||||
sym_int_values: Dict[str, SymInt] = field(default_factory=dict)
|
||||
sym_bool_values: Dict[str, SymBool] = field(default_factory=dict)
|
||||
sym_float_values: Dict[str, SymFloat] = field(default_factory=dict)
|
||||
is_single_tensor_return: bool = False
|
||||
custom_obj_values: Dict[str, CustomObjArgument] = field(default_factory=dict)
|
||||
|
||||
@ -468,6 +493,8 @@ class GraphModuleSerializer(metaclass=Final):
|
||||
)
|
||||
elif isinstance(node.meta["val"], torch.SymInt):
|
||||
raise AssertionError("SymInt graph input is not implemented yet.")
|
||||
elif isinstance(node.meta["val"], torch.SymFloat):
|
||||
raise AssertionError("SymFloat graph input is not implemented yet.")
|
||||
elif isinstance(node.meta["val"], (int, bool, str, float, type(None))):
|
||||
graph_input = self.serialize_input(node.meta["val"])
|
||||
elif isinstance(node.meta["val"], ep.CustomObjArgument):
|
||||
@ -516,21 +543,25 @@ class GraphModuleSerializer(metaclass=Final):
|
||||
if (
|
||||
node.target in _SYM_INT_OPS
|
||||
or node.target in _SYM_BOOL_OPS
|
||||
or (meta_val is not None and isinstance(meta_val, (torch.SymInt, torch.SymBool)))
|
||||
or node.target in _SYM_FLOAT_OPS
|
||||
or (meta_val is not None and isinstance(meta_val, (torch.SymInt, torch.SymBool, torch.SymFloat)))
|
||||
):
|
||||
assert len(node.kwargs) == 0
|
||||
|
||||
# Serialize the node
|
||||
if isinstance(meta_val, torch.SymInt):
|
||||
sym_output = Argument.create(as_sym_int=self.serialize_sym_int_output(node.name, meta_val))
|
||||
elif isinstance(meta_val, torch.SymFloat):
|
||||
sym_output = Argument.create(as_sym_float=self.serialize_sym_float_output(node.name, meta_val))
|
||||
elif isinstance(meta_val, torch.SymBool):
|
||||
sym_output = Argument.create(as_sym_bool=self.serialize_sym_bool_output(node.name, meta_val))
|
||||
else:
|
||||
raise SerializeError(f"Unsupported symbolic type: {type(meta_val)}")
|
||||
|
||||
ex_node = Node(
|
||||
target=self.serialize_operator(node.target),
|
||||
inputs=self.serialize_sym_op_inputs(node.target, node.args),
|
||||
outputs=[
|
||||
Argument.create(
|
||||
as_sym_int=self.serialize_sym_int_output(node.name, meta_val)
|
||||
)
|
||||
if (node.target in _SYM_INT_OPS or isinstance(meta_val, torch.SymInt))
|
||||
else Argument.create(
|
||||
as_sym_bool=self.serialize_sym_bool_output(node.name, meta_val)
|
||||
)
|
||||
],
|
||||
outputs=[sym_output],
|
||||
metadata=self.serialize_metadata(node),
|
||||
)
|
||||
elif isinstance(node.target, torch._ops.OpOverload):
|
||||
@ -645,7 +676,7 @@ class GraphModuleSerializer(metaclass=Final):
|
||||
if isinstance(op, torch._ops.OpOverload):
|
||||
args_names = [arg.name for arg in op._schema.arguments]
|
||||
else:
|
||||
assert op in _SYM_INT_OPS or op in _SYM_BOOL_OPS
|
||||
assert op in _SYM_INT_OPS or op in _SYM_BOOL_OPS or op in _SYM_FLOAT_OPS
|
||||
args_names = list(inspect.signature(op).parameters.keys())
|
||||
serialized_args = []
|
||||
for args_name, arg in zip(args_names, args):
|
||||
@ -713,6 +744,12 @@ class GraphModuleSerializer(metaclass=Final):
|
||||
and arg.name in self.graph_state.sym_int_values
|
||||
)
|
||||
|
||||
def is_sym_float_arg(self, arg) -> bool:
|
||||
return isinstance(arg, float) or (
|
||||
isinstance(arg, torch.fx.Node)
|
||||
and arg.name in self.graph_state.sym_float_values
|
||||
)
|
||||
|
||||
def is_sym_bool_arg(self, arg) -> bool:
|
||||
return isinstance(arg, bool) or (
|
||||
isinstance(arg, torch.fx.Node)
|
||||
@ -752,6 +789,10 @@ class GraphModuleSerializer(metaclass=Final):
|
||||
return Argument.create(
|
||||
as_sym_int=SymIntArgument.create(as_name=arg.name)
|
||||
)
|
||||
elif self.is_sym_float_arg(arg):
|
||||
return Argument.create(
|
||||
as_sym_float=SymFloatArgument.create(as_name=arg.name)
|
||||
)
|
||||
elif self.is_sym_bool_arg(arg):
|
||||
return Argument.create(
|
||||
as_sym_bool=SymBoolArgument.create(as_name=arg.name)
|
||||
@ -781,6 +822,12 @@ class GraphModuleSerializer(metaclass=Final):
|
||||
# For regular FX graph, SymInt arg should be a fx.Node with
|
||||
# self.is_sym_int_arg(arg) being true
|
||||
return Argument.create(as_sym_int=SymIntArgument.create(as_name=str(arg)))
|
||||
elif isinstance(arg, torch.SymFloat):
|
||||
# This is a special branch for handling SymFloat args in inductor's
|
||||
# ExternalFallbackNode.
|
||||
# For regular FX graph, SymInt arg should be a fx.Node with
|
||||
# self.is_sym_float_arg(arg) being true
|
||||
return Argument.create(as_sym_float=SymFloatArgument.create(as_name=str(arg)))
|
||||
elif type(arg) is bool:
|
||||
return Argument.create(as_bool=arg)
|
||||
elif type(arg) is str:
|
||||
@ -841,6 +888,10 @@ class GraphModuleSerializer(metaclass=Final):
|
||||
return Argument.create(
|
||||
as_sym_ints=[SymIntArgument.create(as_name=str(a)) for a in arg]
|
||||
)
|
||||
elif all(isinstance(a, torch.SymFloat) for a in arg):
|
||||
return Argument.create(
|
||||
as_sym_floats=[SymFloatArgument.create(as_name=str(a)) for a in arg]
|
||||
)
|
||||
elif all(self.is_sym_int_arg(a) for a in arg):
|
||||
# list of sym_ints
|
||||
values = []
|
||||
@ -850,6 +901,15 @@ class GraphModuleSerializer(metaclass=Final):
|
||||
elif type(a) is int:
|
||||
values.append(SymIntArgument.create(as_int=a))
|
||||
return Argument.create(as_sym_ints=values)
|
||||
elif all(self.is_sym_float_arg(a) for a in arg):
|
||||
# list of sym_float
|
||||
values = []
|
||||
for a in arg:
|
||||
if isinstance(a, torch.fx.Node):
|
||||
values.append(SymFloatArgument.create(as_name=a.name))
|
||||
elif isinstance(a, float):
|
||||
values.append(SymFloatArgument.create(as_float=a))
|
||||
return Argument.create(as_sym_floats=values)
|
||||
elif all(self.is_sym_bool_arg(a) for a in arg):
|
||||
# list of sym_bools
|
||||
values = []
|
||||
@ -954,6 +1014,11 @@ class GraphModuleSerializer(metaclass=Final):
|
||||
self.graph_state.sym_int_values[name] = serialize_sym_int(meta_val)
|
||||
return SymIntArgument.create(as_name=name)
|
||||
|
||||
def serialize_sym_float_output(self, name, meta_val) -> SymFloatArgument:
|
||||
assert name not in self.graph_state.sym_float_values
|
||||
self.graph_state.sym_float_values[name] = serialize_sym_float(meta_val)
|
||||
return SymFloatArgument.create(as_name=name)
|
||||
|
||||
def serialize_sym_bool_output(self, name, meta_val) -> SymIntArgument:
|
||||
assert name not in self.graph_state.sym_bool_values
|
||||
self.graph_state.sym_bool_values[name] = serialize_sym_bool(meta_val)
|
||||
@ -1102,6 +1167,8 @@ class GraphModuleSerializer(metaclass=Final):
|
||||
return Argument.create(as_tensor=TensorArgument(name=x.name))
|
||||
elif isinstance(x, ep.SymIntArgument):
|
||||
return Argument.create(as_sym_int=SymIntArgument.create(as_name=x.name))
|
||||
elif isinstance(x, ep.SymFloatArgument):
|
||||
return Argument.create(as_sym_Float=SymFloatArgument.create(as_name=x.name))
|
||||
elif isinstance(x, ep.ConstantArgument):
|
||||
return self.serialize_input(x.value)
|
||||
elif isinstance(x, ep.CustomObjArgument):
|
||||
@ -1219,7 +1286,7 @@ class GraphModuleSerializer(metaclass=Final):
|
||||
sub_user_node_name = self._output_node_name_at_index(user_node, i)
|
||||
args.append(self.serialize_tensor_output(sub_user_node_name, m))
|
||||
output_arguments.append(Argument.create(as_tensors=args))
|
||||
elif isinstance(meta, (int, SymInt)):
|
||||
elif isinstance(meta, (int, SymInt, float, SymFloat)):
|
||||
user_node_name = self._output_node_name_at_index(node, idx)
|
||||
output_arguments.append(self.serialize_output(user_node_name, meta))
|
||||
else:
|
||||
@ -1291,6 +1358,11 @@ class GraphModuleSerializer(metaclass=Final):
|
||||
return Argument.create(
|
||||
as_sym_int=self.serialize_sym_int_output(name, meta_val)
|
||||
)
|
||||
elif isinstance(meta_val, (int, torch.SymFloat)):
|
||||
# e.g "-> SymFloat"
|
||||
return Argument.create(
|
||||
as_sym_float=self.serialize_sym_float_output(name, meta_val)
|
||||
)
|
||||
elif isinstance(meta_val, torch.SymBool):
|
||||
# e.g "-> SymBool"
|
||||
return Argument.create(
|
||||
@ -1340,6 +1412,7 @@ class GraphModuleSerializer(metaclass=Final):
|
||||
nodes=self.graph_state.nodes,
|
||||
tensor_values=self.graph_state.tensor_values,
|
||||
sym_int_values=self.graph_state.sym_int_values,
|
||||
sym_float_values=self.graph_state.sym_float_values,
|
||||
sym_bool_values=self.graph_state.sym_bool_values,
|
||||
custom_obj_values=self.graph_state.custom_obj_values,
|
||||
outputs=self.graph_state.outputs,
|
||||
@ -1557,6 +1630,20 @@ class GraphModuleDeserializer(metaclass=Final):
|
||||
f"SymInt has invalid field type {s.type} with value {s.value}"
|
||||
)
|
||||
|
||||
def deserialize_sym_float(self, s: SymFloat) -> Union[float, torch.SymFloat]:
|
||||
val = s.value
|
||||
if s.type == "as_expr":
|
||||
hint = val.hint.as_float if val.hint else None
|
||||
sym = self._parse_sym_expr(val.expr_str, hint)
|
||||
return self.shape_env.create_symfloatnode(sym, hint=hint)
|
||||
elif s.type == "as_float":
|
||||
assert isinstance(val, float)
|
||||
return val
|
||||
else:
|
||||
raise SerializeError(
|
||||
f"SymFloat has invalid field type {s.type} with value {s.value}"
|
||||
)
|
||||
|
||||
def deserialize_sym_bool(self, s: SymBool) -> Union[bool, torch.SymBool]:
|
||||
val = s.value
|
||||
if s.type == "as_expr":
|
||||
@ -1600,8 +1687,12 @@ class GraphModuleDeserializer(metaclass=Final):
|
||||
return self.serialized_name_to_node[output.as_sym_int.as_name]
|
||||
elif output.type == "as_sym_bool":
|
||||
return self.serialized_name_to_node[output.as_sym_bool.as_name]
|
||||
elif output.type == "as_sym_float":
|
||||
return self.serialized_name_to_node[output.as_sym_float.as_name]
|
||||
elif output.type == "as_int":
|
||||
return output.as_int
|
||||
elif output.type == "as_float":
|
||||
return output.as_float
|
||||
elif output.type == "as_none":
|
||||
return None
|
||||
else:
|
||||
@ -1616,6 +1707,9 @@ class GraphModuleDeserializer(metaclass=Final):
|
||||
for name, sym_int_value in serialized_graph.sym_int_values.items():
|
||||
self.serialized_name_to_meta[name] = self.deserialize_sym_int(sym_int_value)
|
||||
|
||||
for name, sym_int_value in serialized_graph.sym_float_values.items():
|
||||
self.serialized_name_to_meta[name] = self.deserialize_sym_float(sym_int_value)
|
||||
|
||||
for name, sym_bool_value in serialized_graph.sym_bool_values.items():
|
||||
self.serialized_name_to_meta[name] = self.deserialize_sym_bool(
|
||||
sym_bool_value
|
||||
@ -1628,7 +1722,7 @@ class GraphModuleDeserializer(metaclass=Final):
|
||||
|
||||
# Inputs: convert to placeholder nodes in FX.
|
||||
for i, input_ in enumerate(serialized_graph.inputs):
|
||||
if input_.type in ("as_tensor", "as_sym_int", "as_custom_obj"):
|
||||
if input_.type in ("as_tensor", "as_sym_int", "as_sym_float", "as_custom_obj"):
|
||||
node_name = input_.value.name
|
||||
placeholder_node = self.graph.placeholder(node_name)
|
||||
# FX might declare a name illegal (e.g. some nn.Modules use "input" as forward() arguments)
|
||||
@ -1684,6 +1778,7 @@ class GraphModuleDeserializer(metaclass=Final):
|
||||
if (
|
||||
target in _SYM_BOOL_OPS
|
||||
or target in _SYM_INT_OPS
|
||||
or target in _SYM_FLOAT_OPS
|
||||
or target == torch.ops.aten.item.default # this can produce either SymInt or SymBool
|
||||
):
|
||||
name = serialized_node.outputs[0].value.as_name
|
||||
@ -2021,6 +2116,8 @@ class GraphModuleDeserializer(metaclass=Final):
|
||||
return inp.as_string
|
||||
elif typ_ == "as_sym_int":
|
||||
return self.deserialize_sym_argument(inp.as_sym_int)
|
||||
elif typ_ == "as_sym_float":
|
||||
return self.deserialize_sym_argument(inp.as_sym_float)
|
||||
elif typ_ == "as_sym_bool":
|
||||
return self.deserialize_sym_argument(inp.as_sym_bool)
|
||||
elif isinstance(value, list):
|
||||
@ -2032,7 +2129,7 @@ class GraphModuleDeserializer(metaclass=Final):
|
||||
elif typ_ in ("as_ints", "as_floats", "as_bools", "as_strings"):
|
||||
# convert from serialized.python.types.List to python list
|
||||
return list(value)
|
||||
elif typ_ in ("as_sym_ints", "as_sym_bools"):
|
||||
elif typ_ in ("as_sym_ints", "as_sym_bools", "as_sym_floats"):
|
||||
return [self.deserialize_sym_argument(arg) for arg in value]
|
||||
elif typ_ == "as_optional_tensors":
|
||||
|
||||
@ -2077,6 +2174,11 @@ class GraphModuleDeserializer(metaclass=Final):
|
||||
return sym_arg.as_int
|
||||
elif sym_arg.type == "as_name":
|
||||
return self.serialized_name_to_node[sym_arg.as_name]
|
||||
elif isinstance(sym_arg, SymFloatArgument):
|
||||
if sym_arg.type == "as_float":
|
||||
return sym_arg.as_float
|
||||
elif sym_arg.type == "as_name":
|
||||
return self.serialized_name_to_node[sym_arg.as_name]
|
||||
elif isinstance(sym_arg, SymBoolArgument):
|
||||
if sym_arg.type == "as_bool":
|
||||
return sym_arg.as_bool
|
||||
@ -2098,7 +2200,7 @@ class GraphModuleDeserializer(metaclass=Final):
|
||||
self.sync_fx_node(serialized_node.outputs[0].as_tensor.name, fx_node)
|
||||
return
|
||||
elif len(serialized_node.outputs) == 1 and isinstance(
|
||||
serialized_node.outputs[0].value, (SymIntArgument, SymBoolArgument)
|
||||
serialized_node.outputs[0].value, (SymIntArgument, SymBoolArgument, SymFloatArgument)
|
||||
):
|
||||
self.sync_fx_node(serialized_node.outputs[0].value.as_name, fx_node)
|
||||
return
|
||||
@ -2113,13 +2215,15 @@ class GraphModuleDeserializer(metaclass=Final):
|
||||
def generate_getitem(
|
||||
meta_val,
|
||||
fx_node: torch.fx.Node,
|
||||
arg: Union[TensorArgument, SymIntArgument],
|
||||
arg: Union[TensorArgument, SymIntArgument, SymFloatArgument],
|
||||
idx: int,
|
||||
):
|
||||
if isinstance(arg, TensorArgument):
|
||||
name = arg.name
|
||||
elif isinstance(arg, SymIntArgument):
|
||||
name = arg.as_name
|
||||
elif isinstance(arg, SymFloatArgument):
|
||||
name = arg.as_name
|
||||
else:
|
||||
raise AssertionError(
|
||||
f"generate_getitem got unknown argument type {type(arg)}"
|
||||
@ -2140,7 +2244,7 @@ class GraphModuleDeserializer(metaclass=Final):
|
||||
for idx, arg in enumerate(args):
|
||||
if isinstance(arg, Argument):
|
||||
arg = arg.value
|
||||
if isinstance(arg, (TensorArgument, SymIntArgument)):
|
||||
if isinstance(arg, (TensorArgument, SymIntArgument, SymFloatArgument)):
|
||||
generate_getitem(meta_val, fx_node, arg, idx)
|
||||
elif isinstance(arg, (list, tuple)):
|
||||
list_output = self.graph.create_node(
|
||||
@ -2241,6 +2345,8 @@ class GraphModuleDeserializer(metaclass=Final):
|
||||
return ep.TensorArgument(name=x.as_tensor.name)
|
||||
elif x.type == "as_sym_int":
|
||||
return ep.SymIntArgument(name=x.as_sym_int.as_name)
|
||||
elif x.type == "as_sym_float":
|
||||
return ep.SymFloatArgument(name=x.as_sym_float.as_name)
|
||||
elif x.type == "as_custom_obj":
|
||||
return ep.ConstantArgument(name=x.as_custom_obj.name, value=self.deserialize_input(x))
|
||||
else:
|
||||
@ -2486,6 +2592,10 @@ def _canonicalize_graph(
|
||||
return a.as_sym_int
|
||||
elif a.type == "as_sym_ints":
|
||||
return a.as_sym_ints
|
||||
elif a.type == "as_sym_float":
|
||||
return a.as_sym_float
|
||||
elif a.type == "as_sym_floats":
|
||||
return a.as_sym_floats
|
||||
elif a.type == "as_scalar_type":
|
||||
return None
|
||||
elif a.type == "as_memory_format":
|
||||
@ -2536,10 +2646,10 @@ def _canonicalize_graph(
|
||||
return None
|
||||
if isinstance(a, TensorArgument):
|
||||
return a.name
|
||||
elif isinstance(a, (SymIntArgument, SymBoolArgument)):
|
||||
elif isinstance(a, (SymIntArgument, SymBoolArgument, SymFloatArgument)):
|
||||
if a.type == "as_name":
|
||||
return a.as_name
|
||||
elif a.type in ("as_int", "as_bool"):
|
||||
elif a.type in ("as_int", "as_bool", "as_float"):
|
||||
return None
|
||||
else:
|
||||
raise AssertionError(f"Unknown argument type: {a}")
|
||||
@ -2654,6 +2764,9 @@ def _canonicalize_graph(
|
||||
elif isinstance(a, SymIntArgument):
|
||||
if a.type == "as_name":
|
||||
a.as_name = _rename(a.as_name, graph.sym_int_values)
|
||||
elif isinstance(a, SymFloatArgument):
|
||||
if a.type == "as_name":
|
||||
a.as_name = _rename(a.as_name, graph.sym_float_values)
|
||||
elif isinstance(a, SymBoolArgument):
|
||||
if a.type == "as_name":
|
||||
a.as_name = _rename(a.as_name, graph.sym_bool_values)
|
||||
@ -2665,7 +2778,7 @@ def _canonicalize_graph(
|
||||
return
|
||||
if isinstance(a, TensorArgument):
|
||||
a.name = name_table.get(a.name, a.name)
|
||||
elif isinstance(a, SymIntArgument):
|
||||
elif isinstance(a, (SymIntArgument, SymFloatArgument)):
|
||||
if a.type == "as_name":
|
||||
a.as_name = name_table.get(a.as_name, a.as_name)
|
||||
elif isinstance(a, SymBoolArgument):
|
||||
@ -2700,6 +2813,9 @@ def _canonicalize_graph(
|
||||
sorted_sym_int_values = dict(
|
||||
sorted(graph.sym_int_values.items(), key=operator.itemgetter(0))
|
||||
)
|
||||
sorted_sym_float_values = dict(
|
||||
sorted(graph.sym_float_values.items(), key=operator.itemgetter(0))
|
||||
)
|
||||
sorted_sym_bool_values = dict(
|
||||
sorted(graph.sym_bool_values.items(), key=operator.itemgetter(0))
|
||||
)
|
||||
@ -2722,6 +2838,7 @@ def _canonicalize_graph(
|
||||
nodes=sorted_nodes,
|
||||
tensor_values=sorted_tensor_values,
|
||||
sym_int_values=sorted_sym_int_values,
|
||||
sym_float_values=sorted_sym_float_values,
|
||||
sym_bool_values=sorted_sym_bool_values,
|
||||
is_single_tensor_return=graph.is_single_tensor_return,
|
||||
)
|
||||
@ -2832,6 +2949,14 @@ def canonicalize(ep: ExportedProgram) -> ExportedProgram:
|
||||
pass
|
||||
else:
|
||||
raise AssertionError(f"Unknown sym_int type: {s}")
|
||||
elif arg.type == "as_sym_float":
|
||||
f = arg.as_sym_float
|
||||
if f.type == "as_name":
|
||||
f.as_name = replace_table[f.as_name]
|
||||
elif f.type == "as_float":
|
||||
pass
|
||||
else:
|
||||
raise AssertionError(f"Unknown sym_float type: {f}")
|
||||
elif arg.type in (
|
||||
"as_none",
|
||||
"as_bool",
|
||||
@ -2877,6 +3002,14 @@ def canonicalize(ep: ExportedProgram) -> ExportedProgram:
|
||||
pass
|
||||
else:
|
||||
raise AssertionError(f"Unknown sym_int type: {s}")
|
||||
elif arg.type == "as_sym_float":
|
||||
f = arg.as_sym_float
|
||||
if f.type == "as_name":
|
||||
f.as_name = replace_table[f.as_name]
|
||||
elif f.type == "as_float":
|
||||
pass
|
||||
else:
|
||||
raise AssertionError(f"Unknown sym_float type: {f}")
|
||||
elif arg.type in ("as_none", "as_int", "as_float", "as_string"):
|
||||
return
|
||||
else:
|
||||
|
@ -12,6 +12,7 @@ from torch.export.graph_signature import (
|
||||
CustomObjArgument,
|
||||
InputKind,
|
||||
SymIntArgument,
|
||||
SymFloatArgument,
|
||||
SymBoolArgument,
|
||||
TensorArgument,
|
||||
TokenArgument,
|
||||
@ -310,7 +311,7 @@ def _verify_exported_program_signature(exported_program) -> None:
|
||||
)
|
||||
|
||||
for input_spec, node in zip(gs.input_specs, input_node_names):
|
||||
if isinstance(input_spec.arg, (TensorArgument, SymIntArgument, SymBoolArgument)):
|
||||
if isinstance(input_spec.arg, (TensorArgument, SymIntArgument, SymFloatArgument, SymBoolArgument)):
|
||||
if input_spec.arg.name != node:
|
||||
raise SpecViolationError(
|
||||
f"Input spec name {input_spec.arg.name} does not match node name {node}"
|
||||
|
193
torch/csrc/utils/generated_serialization_types.h
generated
193
torch/csrc/utils/generated_serialization_types.h
generated
@ -1,5 +1,5 @@
|
||||
// @generated by update_schema.py
|
||||
// checksum<<19d86105f895a10d5eedbc6e13d4d96cf5d9182c0367d6825ef2438e124cc536>>
|
||||
// checksum<<b2d7665a2d5d77eca43ac97af5e691123dd82b7b2582b8e81f2c326761e2f649>>
|
||||
// clang-format off
|
||||
|
||||
#pragma once
|
||||
@ -118,6 +118,8 @@ class SymBool;
|
||||
class SymBoolArgument;
|
||||
class SymExpr;
|
||||
class SymExprHint;
|
||||
class SymFloat;
|
||||
class SymFloatArgument;
|
||||
class SymInt;
|
||||
class SymIntArgument;
|
||||
class TensorArgument;
|
||||
@ -320,6 +322,58 @@ class SymInt {
|
||||
}
|
||||
};
|
||||
|
||||
class SymFloat {
|
||||
struct Void {};
|
||||
|
||||
public:
|
||||
enum class Tag {
|
||||
AS_EXPR, AS_INT
|
||||
};
|
||||
|
||||
private:
|
||||
std::variant<Void, SymExpr, double> variant_;
|
||||
Tag tag_;
|
||||
|
||||
public:
|
||||
Tag tag() const {
|
||||
return tag_;
|
||||
}
|
||||
|
||||
const SymExpr& get_as_expr() const {
|
||||
return std::get<1>(variant_);
|
||||
}
|
||||
|
||||
const double& get_as_int() const {
|
||||
return std::get<2>(variant_);
|
||||
}
|
||||
|
||||
friend void to_json(nlohmann::json& nlohmann_json_j, const SymFloat& nlohmann_json_t) {
|
||||
|
||||
if (nlohmann_json_t.tag_ == Tag::AS_EXPR) {
|
||||
nlohmann_json_j["as_expr"] = nlohmann_json_t.get_as_expr();
|
||||
return;
|
||||
}
|
||||
if (nlohmann_json_t.tag_ == Tag::AS_INT) {
|
||||
nlohmann_json_j["as_int"] = nlohmann_json_t.get_as_int();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
friend void from_json(const nlohmann::json& nlohmann_json_j, SymFloat& nlohmann_json_t) {
|
||||
|
||||
if (nlohmann_json_j.contains("as_expr")) {
|
||||
nlohmann_json_t.variant_.emplace<1>(nlohmann_json_j.at("as_expr").template get<SymExpr>());
|
||||
nlohmann_json_t.tag_ = Tag::AS_EXPR;
|
||||
return;
|
||||
}
|
||||
if (nlohmann_json_j.contains("as_int")) {
|
||||
nlohmann_json_t.variant_.emplace<2>(nlohmann_json_j.at("as_int").template get<double>());
|
||||
nlohmann_json_t.tag_ = Tag::AS_INT;
|
||||
return;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
class SymBool {
|
||||
struct Void {};
|
||||
|
||||
@ -468,6 +522,58 @@ class SymIntArgument {
|
||||
}
|
||||
};
|
||||
|
||||
class SymFloatArgument {
|
||||
struct Void {};
|
||||
|
||||
public:
|
||||
enum class Tag {
|
||||
AS_NAME, AS_FLOAT
|
||||
};
|
||||
|
||||
private:
|
||||
std::variant<Void, std::string, double> variant_;
|
||||
Tag tag_;
|
||||
|
||||
public:
|
||||
Tag tag() const {
|
||||
return tag_;
|
||||
}
|
||||
|
||||
const std::string& get_as_name() const {
|
||||
return std::get<1>(variant_);
|
||||
}
|
||||
|
||||
const double& get_as_float() const {
|
||||
return std::get<2>(variant_);
|
||||
}
|
||||
|
||||
friend void to_json(nlohmann::json& nlohmann_json_j, const SymFloatArgument& nlohmann_json_t) {
|
||||
|
||||
if (nlohmann_json_t.tag_ == Tag::AS_NAME) {
|
||||
nlohmann_json_j["as_name"] = nlohmann_json_t.get_as_name();
|
||||
return;
|
||||
}
|
||||
if (nlohmann_json_t.tag_ == Tag::AS_FLOAT) {
|
||||
nlohmann_json_j["as_float"] = nlohmann_json_t.get_as_float();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
friend void from_json(const nlohmann::json& nlohmann_json_j, SymFloatArgument& nlohmann_json_t) {
|
||||
|
||||
if (nlohmann_json_j.contains("as_name")) {
|
||||
nlohmann_json_t.variant_.emplace<1>(nlohmann_json_j.at("as_name").template get<std::string>());
|
||||
nlohmann_json_t.tag_ = Tag::AS_NAME;
|
||||
return;
|
||||
}
|
||||
if (nlohmann_json_j.contains("as_float")) {
|
||||
nlohmann_json_t.variant_.emplace<2>(nlohmann_json_j.at("as_float").template get<double>());
|
||||
nlohmann_json_t.tag_ = Tag::AS_FLOAT;
|
||||
return;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
class SymBoolArgument {
|
||||
struct Void {};
|
||||
|
||||
@ -643,11 +749,11 @@ class Argument {
|
||||
|
||||
public:
|
||||
enum class Tag {
|
||||
AS_NONE, AS_TENSOR, AS_TENSORS, AS_INT, AS_INTS, AS_FLOAT, AS_FLOATS, AS_STRING, AS_STRINGS, AS_SYM_INT, AS_SYM_INTS, AS_SCALAR_TYPE, AS_MEMORY_FORMAT, AS_LAYOUT, AS_DEVICE, AS_BOOL, AS_BOOLS, AS_SYM_BOOL, AS_SYM_BOOLS, AS_GRAPH, AS_OPTIONAL_TENSORS, AS_CUSTOM_OBJ, AS_OPERATOR
|
||||
AS_NONE, AS_TENSOR, AS_TENSORS, AS_INT, AS_INTS, AS_FLOAT, AS_FLOATS, AS_STRING, AS_STRINGS, AS_SYM_INT, AS_SYM_INTS, AS_SYM_FLOAT, AS_SYM_FLOATS, AS_SCALAR_TYPE, AS_MEMORY_FORMAT, AS_LAYOUT, AS_DEVICE, AS_BOOL, AS_BOOLS, AS_SYM_BOOL, AS_SYM_BOOLS, AS_GRAPH, AS_OPTIONAL_TENSORS, AS_CUSTOM_OBJ, AS_OPERATOR
|
||||
};
|
||||
|
||||
private:
|
||||
std::variant<Void, std::tuple<>, TensorArgument, std::vector<TensorArgument>, int64_t, std::vector<int64_t>, double, std::vector<double>, std::string, std::vector<std::string>, SymIntArgument, std::vector<SymIntArgument>, ScalarType, MemoryFormat, Layout, Device, bool, std::vector<bool>, SymBoolArgument, std::vector<SymBoolArgument>, GraphArgument, std::vector<OptionalTensorArgument>, CustomObjArgument, std::string> variant_;
|
||||
std::variant<Void, std::tuple<>, TensorArgument, std::vector<TensorArgument>, int64_t, std::vector<int64_t>, double, std::vector<double>, std::string, std::vector<std::string>, SymIntArgument, std::vector<SymIntArgument>, SymFloatArgument, std::vector<SymFloatArgument>, ScalarType, MemoryFormat, Layout, Device, bool, std::vector<bool>, SymBoolArgument, std::vector<SymBoolArgument>, GraphArgument, std::vector<OptionalTensorArgument>, CustomObjArgument, std::string> variant_;
|
||||
Tag tag_;
|
||||
|
||||
public:
|
||||
@ -699,54 +805,62 @@ class Argument {
|
||||
return std::get<11>(variant_);
|
||||
}
|
||||
|
||||
const ScalarType& get_as_scalar_type() const {
|
||||
const SymFloatArgument& get_as_sym_float() const {
|
||||
return std::get<12>(variant_);
|
||||
}
|
||||
|
||||
const MemoryFormat& get_as_memory_format() const {
|
||||
const std::vector<SymFloatArgument>& get_as_sym_floats() const {
|
||||
return std::get<13>(variant_);
|
||||
}
|
||||
|
||||
const Layout& get_as_layout() const {
|
||||
const ScalarType& get_as_scalar_type() const {
|
||||
return std::get<14>(variant_);
|
||||
}
|
||||
|
||||
const Device& get_as_device() const {
|
||||
const MemoryFormat& get_as_memory_format() const {
|
||||
return std::get<15>(variant_);
|
||||
}
|
||||
|
||||
const bool& get_as_bool() const {
|
||||
const Layout& get_as_layout() const {
|
||||
return std::get<16>(variant_);
|
||||
}
|
||||
|
||||
const std::vector<bool>& get_as_bools() const {
|
||||
const Device& get_as_device() const {
|
||||
return std::get<17>(variant_);
|
||||
}
|
||||
|
||||
const SymBoolArgument& get_as_sym_bool() const {
|
||||
const bool& get_as_bool() const {
|
||||
return std::get<18>(variant_);
|
||||
}
|
||||
|
||||
const std::vector<SymBoolArgument>& get_as_sym_bools() const {
|
||||
const std::vector<bool>& get_as_bools() const {
|
||||
return std::get<19>(variant_);
|
||||
}
|
||||
|
||||
const GraphArgument& get_as_graph() const {
|
||||
const SymBoolArgument& get_as_sym_bool() const {
|
||||
return std::get<20>(variant_);
|
||||
}
|
||||
|
||||
const std::vector<OptionalTensorArgument>& get_as_optional_tensors() const {
|
||||
const std::vector<SymBoolArgument>& get_as_sym_bools() const {
|
||||
return std::get<21>(variant_);
|
||||
}
|
||||
|
||||
const CustomObjArgument& get_as_custom_obj() const {
|
||||
const GraphArgument& get_as_graph() const {
|
||||
return std::get<22>(variant_);
|
||||
}
|
||||
|
||||
const std::string& get_as_operator() const {
|
||||
const std::vector<OptionalTensorArgument>& get_as_optional_tensors() const {
|
||||
return std::get<23>(variant_);
|
||||
}
|
||||
|
||||
const CustomObjArgument& get_as_custom_obj() const {
|
||||
return std::get<24>(variant_);
|
||||
}
|
||||
|
||||
const std::string& get_as_operator() const {
|
||||
return std::get<25>(variant_);
|
||||
}
|
||||
|
||||
friend void to_json(nlohmann::json& nlohmann_json_j, const Argument& nlohmann_json_t) {
|
||||
|
||||
if (nlohmann_json_t.tag_ == Tag::AS_NONE) {
|
||||
@ -793,6 +907,14 @@ class Argument {
|
||||
nlohmann_json_j["as_sym_ints"] = nlohmann_json_t.get_as_sym_ints();
|
||||
return;
|
||||
}
|
||||
if (nlohmann_json_t.tag_ == Tag::AS_SYM_FLOAT) {
|
||||
nlohmann_json_j["as_sym_float"] = nlohmann_json_t.get_as_sym_float();
|
||||
return;
|
||||
}
|
||||
if (nlohmann_json_t.tag_ == Tag::AS_SYM_FLOATS) {
|
||||
nlohmann_json_j["as_sym_floats"] = nlohmann_json_t.get_as_sym_floats();
|
||||
return;
|
||||
}
|
||||
if (nlohmann_json_t.tag_ == Tag::AS_SCALAR_TYPE) {
|
||||
nlohmann_json_j["as_scalar_type"] = nlohmann_json_t.get_as_scalar_type();
|
||||
return;
|
||||
@ -900,63 +1022,73 @@ class Argument {
|
||||
nlohmann_json_t.tag_ = Tag::AS_SYM_INTS;
|
||||
return;
|
||||
}
|
||||
if (nlohmann_json_j.contains("as_sym_float")) {
|
||||
nlohmann_json_t.variant_.emplace<12>(nlohmann_json_j.at("as_sym_float").template get<SymFloatArgument>());
|
||||
nlohmann_json_t.tag_ = Tag::AS_SYM_FLOAT;
|
||||
return;
|
||||
}
|
||||
if (nlohmann_json_j.contains("as_sym_floats")) {
|
||||
nlohmann_json_t.variant_.emplace<13>(nlohmann_json_j.at("as_sym_floats").template get<std::vector<SymFloatArgument>>());
|
||||
nlohmann_json_t.tag_ = Tag::AS_SYM_FLOATS;
|
||||
return;
|
||||
}
|
||||
if (nlohmann_json_j.contains("as_scalar_type")) {
|
||||
nlohmann_json_t.variant_.emplace<12>(nlohmann_json_j.at("as_scalar_type").template get<ScalarType>());
|
||||
nlohmann_json_t.variant_.emplace<14>(nlohmann_json_j.at("as_scalar_type").template get<ScalarType>());
|
||||
nlohmann_json_t.tag_ = Tag::AS_SCALAR_TYPE;
|
||||
return;
|
||||
}
|
||||
if (nlohmann_json_j.contains("as_memory_format")) {
|
||||
nlohmann_json_t.variant_.emplace<13>(nlohmann_json_j.at("as_memory_format").template get<MemoryFormat>());
|
||||
nlohmann_json_t.variant_.emplace<15>(nlohmann_json_j.at("as_memory_format").template get<MemoryFormat>());
|
||||
nlohmann_json_t.tag_ = Tag::AS_MEMORY_FORMAT;
|
||||
return;
|
||||
}
|
||||
if (nlohmann_json_j.contains("as_layout")) {
|
||||
nlohmann_json_t.variant_.emplace<14>(nlohmann_json_j.at("as_layout").template get<Layout>());
|
||||
nlohmann_json_t.variant_.emplace<16>(nlohmann_json_j.at("as_layout").template get<Layout>());
|
||||
nlohmann_json_t.tag_ = Tag::AS_LAYOUT;
|
||||
return;
|
||||
}
|
||||
if (nlohmann_json_j.contains("as_device")) {
|
||||
nlohmann_json_t.variant_.emplace<15>(nlohmann_json_j.at("as_device").template get<Device>());
|
||||
nlohmann_json_t.variant_.emplace<17>(nlohmann_json_j.at("as_device").template get<Device>());
|
||||
nlohmann_json_t.tag_ = Tag::AS_DEVICE;
|
||||
return;
|
||||
}
|
||||
if (nlohmann_json_j.contains("as_bool")) {
|
||||
nlohmann_json_t.variant_.emplace<16>(nlohmann_json_j.at("as_bool").template get<bool>());
|
||||
nlohmann_json_t.variant_.emplace<18>(nlohmann_json_j.at("as_bool").template get<bool>());
|
||||
nlohmann_json_t.tag_ = Tag::AS_BOOL;
|
||||
return;
|
||||
}
|
||||
if (nlohmann_json_j.contains("as_bools")) {
|
||||
nlohmann_json_t.variant_.emplace<17>(nlohmann_json_j.at("as_bools").template get<std::vector<bool>>());
|
||||
nlohmann_json_t.variant_.emplace<19>(nlohmann_json_j.at("as_bools").template get<std::vector<bool>>());
|
||||
nlohmann_json_t.tag_ = Tag::AS_BOOLS;
|
||||
return;
|
||||
}
|
||||
if (nlohmann_json_j.contains("as_sym_bool")) {
|
||||
nlohmann_json_t.variant_.emplace<18>(nlohmann_json_j.at("as_sym_bool").template get<SymBoolArgument>());
|
||||
nlohmann_json_t.variant_.emplace<20>(nlohmann_json_j.at("as_sym_bool").template get<SymBoolArgument>());
|
||||
nlohmann_json_t.tag_ = Tag::AS_SYM_BOOL;
|
||||
return;
|
||||
}
|
||||
if (nlohmann_json_j.contains("as_sym_bools")) {
|
||||
nlohmann_json_t.variant_.emplace<19>(nlohmann_json_j.at("as_sym_bools").template get<std::vector<SymBoolArgument>>());
|
||||
nlohmann_json_t.variant_.emplace<21>(nlohmann_json_j.at("as_sym_bools").template get<std::vector<SymBoolArgument>>());
|
||||
nlohmann_json_t.tag_ = Tag::AS_SYM_BOOLS;
|
||||
return;
|
||||
}
|
||||
if (nlohmann_json_j.contains("as_graph")) {
|
||||
nlohmann_json_t.variant_.emplace<20>(nlohmann_json_j.at("as_graph").template get<GraphArgument>());
|
||||
nlohmann_json_t.variant_.emplace<22>(nlohmann_json_j.at("as_graph").template get<GraphArgument>());
|
||||
nlohmann_json_t.tag_ = Tag::AS_GRAPH;
|
||||
return;
|
||||
}
|
||||
if (nlohmann_json_j.contains("as_optional_tensors")) {
|
||||
nlohmann_json_t.variant_.emplace<21>(nlohmann_json_j.at("as_optional_tensors").template get<std::vector<OptionalTensorArgument>>());
|
||||
nlohmann_json_t.variant_.emplace<23>(nlohmann_json_j.at("as_optional_tensors").template get<std::vector<OptionalTensorArgument>>());
|
||||
nlohmann_json_t.tag_ = Tag::AS_OPTIONAL_TENSORS;
|
||||
return;
|
||||
}
|
||||
if (nlohmann_json_j.contains("as_custom_obj")) {
|
||||
nlohmann_json_t.variant_.emplace<22>(nlohmann_json_j.at("as_custom_obj").template get<CustomObjArgument>());
|
||||
nlohmann_json_t.variant_.emplace<24>(nlohmann_json_j.at("as_custom_obj").template get<CustomObjArgument>());
|
||||
nlohmann_json_t.tag_ = Tag::AS_CUSTOM_OBJ;
|
||||
return;
|
||||
}
|
||||
if (nlohmann_json_j.contains("as_operator")) {
|
||||
nlohmann_json_t.variant_.emplace<23>(nlohmann_json_j.at("as_operator").template get<std::string>());
|
||||
nlohmann_json_t.variant_.emplace<25>(nlohmann_json_j.at("as_operator").template get<std::string>());
|
||||
nlohmann_json_t.tag_ = Tag::AS_OPERATOR;
|
||||
return;
|
||||
}
|
||||
@ -1021,6 +1153,7 @@ class Graph {
|
||||
std::unordered_map<std::string, SymBool> sym_bool_values;
|
||||
bool is_single_tensor_return = false;
|
||||
std::unordered_map<std::string, CustomObjArgument> custom_obj_values = {};
|
||||
std::unordered_map<std::string, SymFloat> sym_float_values = {};
|
||||
|
||||
public:
|
||||
|
||||
@ -1056,6 +1189,10 @@ class Graph {
|
||||
return custom_obj_values;
|
||||
}
|
||||
|
||||
const std::unordered_map<std::string, SymFloat>& get_sym_float_values() const {
|
||||
return sym_float_values;
|
||||
}
|
||||
|
||||
friend void to_json(nlohmann::json& nlohmann_json_j, const Graph& nlohmann_json_t);
|
||||
friend void from_json(const nlohmann::json& nlohmann_json_j, Graph& nlohmann_json_t);
|
||||
};
|
||||
@ -1892,6 +2029,7 @@ inline void to_json(nlohmann::json& nlohmann_json_j, const Graph& nlohmann_json_
|
||||
nlohmann_json_j["sym_bool_values"] = nlohmann_json_t.sym_bool_values;
|
||||
nlohmann_json_j["is_single_tensor_return"] = nlohmann_json_t.is_single_tensor_return;
|
||||
nlohmann_json_j["custom_obj_values"] = nlohmann_json_t.custom_obj_values;
|
||||
nlohmann_json_j["sym_float_values"] = nlohmann_json_t.sym_float_values;
|
||||
}
|
||||
|
||||
inline void from_json(const nlohmann::json& nlohmann_json_j, Graph& nlohmann_json_t) {
|
||||
@ -1904,6 +2042,7 @@ inline void from_json(const nlohmann::json& nlohmann_json_j, Graph& nlohmann_jso
|
||||
nlohmann_json_t.sym_bool_values = nlohmann_json_j.value("sym_bool_values", nlohmann_json_default_obj.sym_bool_values);
|
||||
nlohmann_json_t.is_single_tensor_return = nlohmann_json_j.value("is_single_tensor_return", nlohmann_json_default_obj.is_single_tensor_return);
|
||||
nlohmann_json_t.custom_obj_values = nlohmann_json_j.value("custom_obj_values", nlohmann_json_default_obj.custom_obj_values);
|
||||
nlohmann_json_t.sym_float_values = nlohmann_json_j.value("sym_float_values", nlohmann_json_default_obj.sym_float_values);
|
||||
}
|
||||
|
||||
inline void to_json(nlohmann::json& nlohmann_json_j, const GraphArgument& nlohmann_json_t) {
|
||||
|
@ -80,6 +80,7 @@ from .graph_signature import ( # noqa: F401
|
||||
OutputKind,
|
||||
OutputSpec,
|
||||
SymBoolArgument,
|
||||
SymFloatArgument,
|
||||
SymIntArgument,
|
||||
TensorArgument,
|
||||
TokenArgument,
|
||||
@ -521,6 +522,8 @@ def _decompose_and_get_gm_with_new_signature_constants(
|
||||
return TensorArgument(name=new_ph.name)
|
||||
elif isinstance(old_arg, SymIntArgument):
|
||||
return SymIntArgument(name=new_ph.name)
|
||||
elif isinstance(old_arg, SymFloatArgument):
|
||||
return SymFloatArgument(name=new_ph.name)
|
||||
elif isinstance(old_arg, SymBoolArgument):
|
||||
return SymBoolArgument(name=new_ph.name)
|
||||
raise RuntimeError(f"Type of old_arg not supported: {type(old_arg)}")
|
||||
|
@ -20,6 +20,7 @@ __all__ = [
|
||||
"OutputKind",
|
||||
"OutputSpec",
|
||||
"SymIntArgument",
|
||||
"SymFloatArgument",
|
||||
"SymBoolArgument",
|
||||
"TensorArgument",
|
||||
]
|
||||
@ -40,6 +41,11 @@ class SymIntArgument:
|
||||
name: str
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class SymFloatArgument:
|
||||
name: str
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class SymBoolArgument:
|
||||
name: str
|
||||
@ -61,6 +67,7 @@ class ConstantArgument:
|
||||
ArgumentSpec = Union[
|
||||
TensorArgument,
|
||||
SymIntArgument,
|
||||
SymFloatArgument,
|
||||
SymBoolArgument,
|
||||
ConstantArgument,
|
||||
CustomObjArgument,
|
||||
@ -94,6 +101,7 @@ class InputSpec:
|
||||
(
|
||||
TensorArgument,
|
||||
SymIntArgument,
|
||||
SymFloatArgument,
|
||||
SymBoolArgument,
|
||||
ConstantArgument,
|
||||
CustomObjArgument,
|
||||
@ -124,6 +132,7 @@ class OutputSpec:
|
||||
(
|
||||
TensorArgument,
|
||||
SymIntArgument,
|
||||
SymFloatArgument,
|
||||
SymBoolArgument,
|
||||
ConstantArgument,
|
||||
TokenArgument,
|
||||
@ -273,7 +282,13 @@ class ExportGraphSignature:
|
||||
|
||||
if isinstance(
|
||||
s.arg,
|
||||
(TensorArgument, SymIntArgument, SymBoolArgument, CustomObjArgument),
|
||||
(
|
||||
TensorArgument,
|
||||
SymIntArgument,
|
||||
SymFloatArgument,
|
||||
SymBoolArgument,
|
||||
CustomObjArgument,
|
||||
),
|
||||
):
|
||||
user_inputs.append(s.arg.name)
|
||||
elif isinstance(s.arg, ConstantArgument):
|
||||
@ -294,7 +309,10 @@ class ExportGraphSignature:
|
||||
]:
|
||||
continue
|
||||
|
||||
if isinstance(s.arg, (TensorArgument, SymIntArgument, SymBoolArgument)):
|
||||
if isinstance(
|
||||
s.arg,
|
||||
(TensorArgument, SymIntArgument, SymFloatArgument, SymBoolArgument),
|
||||
):
|
||||
user_outputs.append(s.arg.name)
|
||||
elif isinstance(s.arg, ConstantArgument):
|
||||
user_outputs.append(s.arg.value)
|
||||
@ -444,6 +462,7 @@ class ExportGraphSignature:
|
||||
arg_types = (
|
||||
TensorArgument,
|
||||
SymIntArgument,
|
||||
SymFloatArgument,
|
||||
SymBoolArgument,
|
||||
CustomObjArgument,
|
||||
TokenArgument,
|
||||
@ -478,7 +497,7 @@ def _immutable_dict(items):
|
||||
|
||||
|
||||
def _make_argument_spec(node, token_names) -> ArgumentSpec:
|
||||
from torch import ScriptObject, SymBool, SymInt
|
||||
from torch import ScriptObject, SymBool, SymFloat, SymInt
|
||||
from torch._library.fake_class_registry import FakeScriptObject
|
||||
from torch._subclasses.fake_tensor import FakeTensor
|
||||
|
||||
@ -496,6 +515,8 @@ def _make_argument_spec(node, token_names) -> ArgumentSpec:
|
||||
return TensorArgument(name=node.name)
|
||||
elif isinstance(val, SymInt):
|
||||
return SymIntArgument(name=node.name)
|
||||
elif isinstance(val, SymFloat):
|
||||
return SymFloatArgument(name=node.name)
|
||||
elif isinstance(val, SymBool):
|
||||
return SymBoolArgument(name=node.name)
|
||||
elif isinstance(val, ScriptObject):
|
||||
|
@ -23,6 +23,7 @@ from torch.export.exported_program import (
|
||||
InputKind,
|
||||
ModuleCallSignature,
|
||||
SymBoolArgument,
|
||||
SymFloatArgument,
|
||||
SymIntArgument,
|
||||
TensorArgument,
|
||||
)
|
||||
@ -1023,7 +1024,13 @@ class _ModuleFrame:
|
||||
input_nodes.append(None)
|
||||
else:
|
||||
assert isinstance(
|
||||
input, (TensorArgument, SymIntArgument, SymBoolArgument)
|
||||
input,
|
||||
(
|
||||
TensorArgument,
|
||||
SymIntArgument,
|
||||
SymBoolArgument,
|
||||
SymFloatArgument,
|
||||
),
|
||||
)
|
||||
input_nodes.append(
|
||||
self.parent.remap_input(self.seen_nodes[input.name])
|
||||
@ -1133,7 +1140,8 @@ class _ModuleFrame:
|
||||
if signature is not None and self.parent is not None:
|
||||
for output in signature.outputs:
|
||||
if isinstance(
|
||||
output, (TensorArgument, SymIntArgument, SymBoolArgument)
|
||||
output,
|
||||
(TensorArgument, SymIntArgument, SymBoolArgument, SymFloatArgument),
|
||||
):
|
||||
if output.name in self.seen_nodes:
|
||||
orig_outputs.append(self.seen_nodes[output.name])
|
||||
|
@ -130,6 +130,7 @@ __all__ = [
|
||||
"create_contiguous",
|
||||
"ShapeEnv",
|
||||
"is_concrete_int",
|
||||
"is_concrete_float",
|
||||
"guard_int",
|
||||
"guard_float",
|
||||
"guard_scalar",
|
||||
@ -367,6 +368,25 @@ def is_concrete_int(a: Union[int, SymInt]) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def is_concrete_float(a: Union[float, SymFloat]) -> bool:
|
||||
r"""Utility to check if underlying object
|
||||
in SymInt is concrete value. Also returns
|
||||
true if integer is passed in.
|
||||
|
||||
Args:
|
||||
a (SymInt or float): Object to test if it float
|
||||
"""
|
||||
assert isinstance(a, (SymFloat, float))
|
||||
|
||||
if isinstance(a, float):
|
||||
return True
|
||||
|
||||
if isinstance(a.node.expr, sympy.core.numbers.Float):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def guard_size_oblivious(expr: Union[torch.SymBool, bool]) -> bool:
|
||||
"""
|
||||
Perform a guard on a symbolic boolean expression in a size oblivious way.
|
||||
|
Reference in New Issue
Block a user