Add truediv support in export serializer (#136364)

Fixes #136113

- [x] Initial `truediv` coverage
- [ ] Expand/reduce coverage?
- [x] Add tests
- [x] Re-check docstrings
- [ ] Linting

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136364
Approved by: https://github.com/pianpwk

Co-authored-by: Angela Yi <angelayi@meta.com>
Co-authored-by: Pian Pawakapan <pianpwk@meta.com>
Authored by: bhack
2024-11-27 00:31:45 +00:00
Committed by: PyTorch MergeBot
parent 9b89fa44ba
commit 1df440dc4e
15 changed files with 470 additions and 60 deletions
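
As a quick illustration of what this change enables, a minimal sketch (assuming a build that includes this commit; the module and file names are made up for the example):

import torch

class M(torch.nn.Module):
    def forward(self, x):
        s = x.item()            # becomes a SymFloat during export
        return x * (s / 2.0)    # float true-division, now serializable

ep = torch.export.export(M(), (torch.tensor([3.14]),))
torch.export.save(ep, "m.pt2")   # round-trips through the updated serde schema
reloaded = torch.export.load("m.pt2")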

View File

@ -890,6 +890,7 @@ API Reference
.. autoclass:: OutputSpec
.. autoclass:: SymIntArgument
.. autoclass:: SymBoolArgument
.. autoclass:: SymFloatArgument
.. autoclass:: ExportGraphSignature
.. automethod:: replace_all_uses

View File

@ -870,6 +870,17 @@ graph():
ep_res = torch.export.export(M(), input).module()(*input)
self.assertEqual(orig_res, ep_res)
def test_symfloat_item(self):
class M(torch.nn.Module):
def forward(self, tensor):
return tensor.item()
input = (torch.tensor([3.14], dtype=torch.float),)
orig_res = M()(*input)
ep_res = torch.export.export(M(), input).module()(*input)
self.assertEqual(orig_res, ep_res)
def test_unbacked_to_cond(self):
class M(torch.nn.Module):
def forward(self, a):
@ -6544,8 +6555,6 @@ graph():
ep = export(m, inputs)
self.assertEqual(ep.module()(*inputs), m(*inputs))
@testing.expectedFailureSerDer # symfloat nyi
@testing.expectedFailureSerDerNonStrict
def test_sym_sqrt(self):
import math

View File

@ -322,6 +322,31 @@ def forward(self, x):
self.assertEqual(node.inputs[0].name, "self")
self.assertEqual(node.inputs[1].name, "dim")
def test_serialize_sym_float(self) -> None:
class DynamicFloatSimpleModel(torch.nn.Module):
def __init__(self, multiplier: torch.SymFloat):
super().__init__()
self.multiplier = multiplier
def forward(self, a, b, c) -> torch.Tensor:
d = (torch.matmul(a, b) + c) / 2
e = d * self.multiplier
e_s0 = e.shape[0]
e_s1 = e.shape[1]
e_s3 = e_s0 * e_s1
f = e.view(e_s3)
return torch.cat([f, f])
multiplier_sym = torch.SymFloat("multiplier_sym")
model = DynamicFloatSimpleModel(multiplier_sym)
inputs = (
torch.randn(2, 4),
torch.randn(4, 7),
torch.randn(2, 7),
)
dim0_ac = Dim("dim0_ac")
dim1_bc = Dim("dim1_b")
def test_serialize_infinite_sym_int(self) -> None:
class DynamicShapeSimpleModel(torch.nn.Module):
def __init__(self):

View File

@ -2712,6 +2712,14 @@ class AOTInductorTestsTemplate:
inputs = (torch.tensor([0], dtype=torch.bool, device=self.device),)
self.check_model(Model(), inputs)
def test_symfloat_item(self):
class Model(torch.nn.Module):
def forward(self, tensor):
return tensor.item()
inputs = (torch.tensor([3.14], dtype=torch.float, device=self.device),)
self.check_model(Model(), inputs)
def test_constant_original_fqn_and_dtype(self):
class FooBarModule(torch.nn.Module):
def __init__(self) -> None:

View File

@ -166,6 +166,8 @@ CPU_TEST_FAILURES = {
"test_symint_item": fail_minimal_arrayref_interface(is_skip=True),
# TODO: AttributeError: 'ShapeAsConstantBuffer' object has no attribute 'dtype'
"test_symbool_item": fail_minimal_arrayref_interface(is_skip=True),
# TODO: AttributeError: 'ShapeAsConstantBuffer' object has no attribute 'dtype'
"test_symfloat_item": fail_minimal_arrayref_interface(is_skip=True),
"test_issue_140766": fail_minimal_arrayref_interface(),
}

View File

@ -37,6 +37,7 @@ from torch.export.graph_signature import (
OutputSpec,
SymIntArgument,
SymBoolArgument,
SymFloatArgument,
TensorArgument,
)
from torch.fx import traceback as fx_traceback

View File

@ -8,7 +8,7 @@ from typing import Dict, List, Optional, Tuple
from torch._export.serde.union import _Union
# NOTE: Please update this value if any modifications are made to the schema
SCHEMA_VERSION = (8, 1)
SCHEMA_VERSION = (8, 2)
TREESPEC_VERSION = 1
@ -77,6 +77,11 @@ class SymInt(_Union):
as_expr: SymExpr
as_int: int
@dataclass(repr=False)
class SymFloat(_Union):
as_expr: SymExpr
as_float: float
@dataclass(repr=False)
class SymBool(_Union):
@ -106,6 +111,16 @@ class SymIntArgument(_Union):
as_name: str
as_int: int
# In most cases we will use the "as_name" field to store arguments which are
# SymFloats.
# The "as_float" field is used in the case where we have a list containing a mix
# of SymFloat and float (ex. [1.0, s0, ...]). We will serialize this type of list to
# be List[SymFloatArgument] and map the SymFloats to the "as_name" field, and floats
# to the "as_float" field.
@dataclass(repr=False)
class SymFloatArgument(_Union):
as_name: str
as_float: float
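# Illustrative sketch (hypothetical symbol name s0): the mixed list [1.0, s0]
# would serialize as
#   [SymFloatArgument(as_float=1.0), SymFloatArgument(as_name="s0")]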
# In most cases we will use the "as_name" field to store arguments which are
# SymBools.
@ -165,6 +180,8 @@ class Argument(_Union):
as_strings: List[str]
as_sym_int: SymIntArgument
as_sym_ints: List[SymIntArgument]
as_sym_float: SymFloatArgument
as_sym_floats: List[SymFloatArgument]
as_scalar_type: ScalarType
as_memory_format: MemoryFormat
as_layout: Layout
@ -208,6 +225,7 @@ class Graph:
# list.
is_single_tensor_return: bool = False
custom_obj_values: Dict[str, CustomObjArgument] = field(default_factory=dict)
sym_float_values: Dict[str, SymFloat] = field(default_factory=dict)
@dataclass

View File

@ -1,5 +1,5 @@
# @generated by update_schema.py
# checksum<<19d86105f895a10d5eedbc6e13d4d96cf5d9182c0367d6825ef2438e124cc536>>
# checksum<<b2d7665a2d5d77eca43ac97af5e691123dd82b7b2582b8e81f2c326761e2f649>>
Argument:
kind: union
fields:
@ -25,6 +25,10 @@ Argument:
type: SymIntArgument
as_sym_ints:
type: List[SymIntArgument]
as_sym_float:
type: SymFloatArgument
as_sym_floats:
type: List[SymFloatArgument]
as_scalar_type:
type: ScalarType
as_memory_format:
@ -136,6 +140,9 @@ Graph:
custom_obj_values:
type: Dict[str, CustomObjArgument]
default: '{}'
sym_float_values:
type: Dict[str, SymFloat]
default: '{}'
GraphArgument:
kind: struct
fields:
@ -377,6 +384,20 @@ SymExprHint:
type: float
as_bool:
type: bool
SymFloat:
kind: union
fields:
as_expr:
type: SymExpr
as_float:
type: float
SymFloatArgument:
kind: union
fields:
as_name:
type: str
as_float:
type: float
SymInt:
kind: union
fields:
@ -437,5 +458,5 @@ UserOutputSpec:
type: Argument
SCHEMA_VERSION:
- 8
- 1
- 2
TREESPEC_VERSION: 1
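# Illustrative sketch (assumed names and values): with the schema above, a graph
# holding one symbolic float value "e" would carry a payload entry roughly like
#   sym_float_values:
#     e:
#       as_expr: {expr_str: "s0/2", hint: {as_float: 1.57}}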

View File

@ -87,6 +87,8 @@ from .schema import ( # type: ignore[attr-defined]
SymExprHint,
SymInt,
SymIntArgument,
SymFloat,
SymFloatArgument,
TensorArgument,
TensorMeta,
TokenArgument,
@ -118,7 +120,7 @@ def _reverse_map(d: Dict[Any, Enum]):
MetaType = Union[
FakeTensor, int, torch.SymInt, bool, torch.SymBool, ep.CustomObjArgument
FakeTensor, int, torch.SymInt, float, torch.SymFloat, bool, torch.SymBool, ep.CustomObjArgument
]
@ -169,6 +171,9 @@ _TORCH_TO_SERIALIZE_MEMORY_FORMAT = {
_SERIALIZE_TO_TORCH_MEMORY_FORMAT = _reverse_map(_TORCH_TO_SERIALIZE_MEMORY_FORMAT) # type: ignore[arg-type]
_SYM_FLOAT_OPS = {
operator.truediv,
}
_SYM_INT_OPS = {
operator.mul,
@ -201,7 +206,7 @@ _SYM_BOOL_OPS = {
assert not any(isinstance(op, torch._ops.OpOverload) for op in _SYM_INT_OPS)
assert not any(isinstance(op, torch._ops.OpOverload) for op in _SYM_BOOL_OPS)
assert not any(isinstance(op, torch._ops.OpOverload) for op in _SYM_FLOAT_OPS)
@dataclass
class SerializedArtifact:
@ -251,6 +256,25 @@ def serialize_sym_int(s: Union[int, torch.SymInt]) -> SymInt:
f"SymInt should be either symbol or int, got `{s}` of type `{type(s)}`"
)
def serialize_sym_float(s: Union[float, torch.SymFloat]) -> SymFloat:
if isinstance(s, (torch.SymFloat, sympy.Symbol, float)):
if symbolic_shapes.is_concrete_float(s):
return SymFloat.create(as_float=float(s))
else:
assert isinstance(s, (torch.SymFloat, sympy.Symbol))
if s.node.hint is None:
return SymFloat.create(as_expr=SymExpr(_print_sympy(s)))
else:
return SymFloat.create(
as_expr=SymExpr(
_print_sympy(s),
hint=SymExprHint.create(as_float=s.node.hint),
)
)
else:
raise SerializeError(
f"SymFloat should be either symbol or float, got `{s}` of type `{type(s)}`"
)
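# Illustrative sketch (hypothetical values): how serialize_sym_float above maps
# its inputs, assuming s0 is a backed torch.SymFloat with hint 2.5:
#   serialize_sym_float(3.14)  ->  SymFloat(as_float=3.14)
#   serialize_sym_float(s0)    ->  SymFloat(as_expr=SymExpr("s0", hint=SymExprHint(as_float=2.5)))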
def serialize_sym_bool(s: Union[bool, torch.SymBool]) -> SymBool:
if isinstance(s, (torch.SymBool, bool)):
@ -425,6 +449,7 @@ class GraphState:
tensor_values: Dict[str, TensorMeta] = field(default_factory=dict)
sym_int_values: Dict[str, SymInt] = field(default_factory=dict)
sym_bool_values: Dict[str, SymBool] = field(default_factory=dict)
sym_float_values: Dict[str, SymFloat] = field(default_factory=dict)
is_single_tensor_return: bool = False
custom_obj_values: Dict[str, CustomObjArgument] = field(default_factory=dict)
@ -468,6 +493,8 @@ class GraphModuleSerializer(metaclass=Final):
)
elif isinstance(node.meta["val"], torch.SymInt):
raise AssertionError("SymInt graph input is not implemented yet.")
elif isinstance(node.meta["val"], torch.SymFloat):
raise AssertionError("SymFloat graph input is not implemented yet.")
elif isinstance(node.meta["val"], (int, bool, str, float, type(None))):
graph_input = self.serialize_input(node.meta["val"])
elif isinstance(node.meta["val"], ep.CustomObjArgument):
@ -516,21 +543,25 @@ class GraphModuleSerializer(metaclass=Final):
if (
node.target in _SYM_INT_OPS
or node.target in _SYM_BOOL_OPS
or (meta_val is not None and isinstance(meta_val, (torch.SymInt, torch.SymBool)))
or node.target in _SYM_FLOAT_OPS
or (meta_val is not None and isinstance(meta_val, (torch.SymInt, torch.SymBool, torch.SymFloat)))
):
assert len(node.kwargs) == 0
# Serialize the node
if isinstance(meta_val, torch.SymInt):
sym_output = Argument.create(as_sym_int=self.serialize_sym_int_output(node.name, meta_val))
elif isinstance(meta_val, torch.SymFloat):
sym_output = Argument.create(as_sym_float=self.serialize_sym_float_output(node.name, meta_val))
elif isinstance(meta_val, torch.SymBool):
sym_output = Argument.create(as_sym_bool=self.serialize_sym_bool_output(node.name, meta_val))
else:
raise SerializeError(f"Unsupported symbolic type: {type(meta_val)}")
ex_node = Node(
target=self.serialize_operator(node.target),
inputs=self.serialize_sym_op_inputs(node.target, node.args),
outputs=[
Argument.create(
as_sym_int=self.serialize_sym_int_output(node.name, meta_val)
)
if (node.target in _SYM_INT_OPS or isinstance(meta_val, torch.SymInt))
else Argument.create(
as_sym_bool=self.serialize_sym_bool_output(node.name, meta_val)
)
],
outputs=[sym_output],
metadata=self.serialize_metadata(node),
)
elif isinstance(node.target, torch._ops.OpOverload):
@ -645,7 +676,7 @@ class GraphModuleSerializer(metaclass=Final):
if isinstance(op, torch._ops.OpOverload):
args_names = [arg.name for arg in op._schema.arguments]
else:
assert op in _SYM_INT_OPS or op in _SYM_BOOL_OPS
assert op in _SYM_INT_OPS or op in _SYM_BOOL_OPS or op in _SYM_FLOAT_OPS
args_names = list(inspect.signature(op).parameters.keys())
serialized_args = []
for args_name, arg in zip(args_names, args):
@ -713,6 +744,12 @@ class GraphModuleSerializer(metaclass=Final):
and arg.name in self.graph_state.sym_int_values
)
def is_sym_float_arg(self, arg) -> bool:
return isinstance(arg, float) or (
isinstance(arg, torch.fx.Node)
and arg.name in self.graph_state.sym_float_values
)
def is_sym_bool_arg(self, arg) -> bool:
return isinstance(arg, bool) or (
isinstance(arg, torch.fx.Node)
@ -752,6 +789,10 @@ class GraphModuleSerializer(metaclass=Final):
return Argument.create(
as_sym_int=SymIntArgument.create(as_name=arg.name)
)
elif self.is_sym_float_arg(arg):
return Argument.create(
as_sym_float=SymFloatArgument.create(as_name=arg.name)
)
elif self.is_sym_bool_arg(arg):
return Argument.create(
as_sym_bool=SymBoolArgument.create(as_name=arg.name)
@ -781,6 +822,12 @@ class GraphModuleSerializer(metaclass=Final):
# For regular FX graph, SymInt arg should be a fx.Node with
# self.is_sym_int_arg(arg) being true
return Argument.create(as_sym_int=SymIntArgument.create(as_name=str(arg)))
elif isinstance(arg, torch.SymFloat):
# This is a special branch for handling SymFloat args in inductor's
# ExternalFallbackNode.
# For regular FX graph, SymInt arg should be a fx.Node with
# self.is_sym_float_arg(arg) being true
return Argument.create(as_sym_float=SymFloatArgument.create(as_name=str(arg)))
elif type(arg) is bool:
return Argument.create(as_bool=arg)
elif type(arg) is str:
@ -841,6 +888,10 @@ class GraphModuleSerializer(metaclass=Final):
return Argument.create(
as_sym_ints=[SymIntArgument.create(as_name=str(a)) for a in arg]
)
elif all(isinstance(a, torch.SymFloat) for a in arg):
return Argument.create(
as_sym_floats=[SymFloatArgument.create(as_name=str(a)) for a in arg]
)
elif all(self.is_sym_int_arg(a) for a in arg):
# list of sym_ints
values = []
@ -850,6 +901,15 @@ class GraphModuleSerializer(metaclass=Final):
elif type(a) is int:
values.append(SymIntArgument.create(as_int=a))
return Argument.create(as_sym_ints=values)
elif all(self.is_sym_float_arg(a) for a in arg):
# list of sym_floats
values = []
for a in arg:
if isinstance(a, torch.fx.Node):
values.append(SymFloatArgument.create(as_name=a.name))
elif isinstance(a, float):
values.append(SymFloatArgument.create(as_float=a))
return Argument.create(as_sym_floats=values)
elif all(self.is_sym_bool_arg(a) for a in arg):
# list of sym_bools
values = []
@ -954,6 +1014,11 @@ class GraphModuleSerializer(metaclass=Final):
self.graph_state.sym_int_values[name] = serialize_sym_int(meta_val)
return SymIntArgument.create(as_name=name)
def serialize_sym_float_output(self, name, meta_val) -> SymFloatArgument:
assert name not in self.graph_state.sym_float_values
self.graph_state.sym_float_values[name] = serialize_sym_float(meta_val)
return SymFloatArgument.create(as_name=name)
def serialize_sym_bool_output(self, name, meta_val) -> SymBoolArgument:
assert name not in self.graph_state.sym_bool_values
self.graph_state.sym_bool_values[name] = serialize_sym_bool(meta_val)
@ -1102,6 +1167,8 @@ class GraphModuleSerializer(metaclass=Final):
return Argument.create(as_tensor=TensorArgument(name=x.name))
elif isinstance(x, ep.SymIntArgument):
return Argument.create(as_sym_int=SymIntArgument.create(as_name=x.name))
elif isinstance(x, ep.SymFloatArgument):
return Argument.create(as_sym_float=SymFloatArgument.create(as_name=x.name))
elif isinstance(x, ep.ConstantArgument):
return self.serialize_input(x.value)
elif isinstance(x, ep.CustomObjArgument):
@ -1219,7 +1286,7 @@ class GraphModuleSerializer(metaclass=Final):
sub_user_node_name = self._output_node_name_at_index(user_node, i)
args.append(self.serialize_tensor_output(sub_user_node_name, m))
output_arguments.append(Argument.create(as_tensors=args))
elif isinstance(meta, (int, SymInt)):
elif isinstance(meta, (int, SymInt, float, SymFloat)):
user_node_name = self._output_node_name_at_index(node, idx)
output_arguments.append(self.serialize_output(user_node_name, meta))
else:
@ -1291,6 +1358,11 @@ class GraphModuleSerializer(metaclass=Final):
return Argument.create(
as_sym_int=self.serialize_sym_int_output(name, meta_val)
)
elif isinstance(meta_val, (float, torch.SymFloat)):
# e.g "-> SymFloat"
return Argument.create(
as_sym_float=self.serialize_sym_float_output(name, meta_val)
)
elif isinstance(meta_val, torch.SymBool):
# e.g "-> SymBool"
return Argument.create(
@ -1340,6 +1412,7 @@ class GraphModuleSerializer(metaclass=Final):
nodes=self.graph_state.nodes,
tensor_values=self.graph_state.tensor_values,
sym_int_values=self.graph_state.sym_int_values,
sym_float_values=self.graph_state.sym_float_values,
sym_bool_values=self.graph_state.sym_bool_values,
custom_obj_values=self.graph_state.custom_obj_values,
outputs=self.graph_state.outputs,
@ -1557,6 +1630,20 @@ class GraphModuleDeserializer(metaclass=Final):
f"SymInt has invalid field type {s.type} with value {s.value}"
)
def deserialize_sym_float(self, s: SymFloat) -> Union[float, torch.SymFloat]:
val = s.value
if s.type == "as_expr":
hint = val.hint.as_float if val.hint else None
sym = self._parse_sym_expr(val.expr_str, hint)
return self.shape_env.create_symfloatnode(sym, hint=hint)
elif s.type == "as_float":
assert isinstance(val, float)
return val
else:
raise SerializeError(
f"SymFloat has invalid field type {s.type} with value {s.value}"
)
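# Conversely, a sketch of the inverse mapping (names assumed):
#   deserialize_sym_float(SymFloat(as_float=3.14))          ->  3.14
#   deserialize_sym_float(SymFloat(as_expr=SymExpr("s0")))  ->  torch.SymFloat backed by self.shape_env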
def deserialize_sym_bool(self, s: SymBool) -> Union[bool, torch.SymBool]:
val = s.value
if s.type == "as_expr":
@ -1600,8 +1687,12 @@ class GraphModuleDeserializer(metaclass=Final):
return self.serialized_name_to_node[output.as_sym_int.as_name]
elif output.type == "as_sym_bool":
return self.serialized_name_to_node[output.as_sym_bool.as_name]
elif output.type == "as_sym_float":
return self.serialized_name_to_node[output.as_sym_float.as_name]
elif output.type == "as_int":
return output.as_int
elif output.type == "as_float":
return output.as_float
elif output.type == "as_none":
return None
else:
@ -1616,6 +1707,9 @@ class GraphModuleDeserializer(metaclass=Final):
for name, sym_int_value in serialized_graph.sym_int_values.items():
self.serialized_name_to_meta[name] = self.deserialize_sym_int(sym_int_value)
for name, sym_float_value in serialized_graph.sym_float_values.items():
self.serialized_name_to_meta[name] = self.deserialize_sym_float(sym_float_value)
for name, sym_bool_value in serialized_graph.sym_bool_values.items():
self.serialized_name_to_meta[name] = self.deserialize_sym_bool(
sym_bool_value
@ -1628,7 +1722,7 @@ class GraphModuleDeserializer(metaclass=Final):
# Inputs: convert to placeholder nodes in FX.
for i, input_ in enumerate(serialized_graph.inputs):
if input_.type in ("as_tensor", "as_sym_int", "as_custom_obj"):
if input_.type in ("as_tensor", "as_sym_int", "as_sym_float", "as_custom_obj"):
node_name = input_.value.name
placeholder_node = self.graph.placeholder(node_name)
# FX might declare a name illegal (e.g. some nn.Modules use "input" as forward() arguments)
@ -1684,6 +1778,7 @@ class GraphModuleDeserializer(metaclass=Final):
if (
target in _SYM_BOOL_OPS
or target in _SYM_INT_OPS
or target in _SYM_FLOAT_OPS
or target == torch.ops.aten.item.default # this can produce SymInt, SymFloat, or SymBool
):
name = serialized_node.outputs[0].value.as_name
@ -2021,6 +2116,8 @@ class GraphModuleDeserializer(metaclass=Final):
return inp.as_string
elif typ_ == "as_sym_int":
return self.deserialize_sym_argument(inp.as_sym_int)
elif typ_ == "as_sym_float":
return self.deserialize_sym_argument(inp.as_sym_float)
elif typ_ == "as_sym_bool":
return self.deserialize_sym_argument(inp.as_sym_bool)
elif isinstance(value, list):
@ -2032,7 +2129,7 @@ class GraphModuleDeserializer(metaclass=Final):
elif typ_ in ("as_ints", "as_floats", "as_bools", "as_strings"):
# convert from serialized.python.types.List to python list
return list(value)
elif typ_ in ("as_sym_ints", "as_sym_bools"):
elif typ_ in ("as_sym_ints", "as_sym_bools", "as_sym_floats"):
return [self.deserialize_sym_argument(arg) for arg in value]
elif typ_ == "as_optional_tensors":
@ -2077,6 +2174,11 @@ class GraphModuleDeserializer(metaclass=Final):
return sym_arg.as_int
elif sym_arg.type == "as_name":
return self.serialized_name_to_node[sym_arg.as_name]
elif isinstance(sym_arg, SymFloatArgument):
if sym_arg.type == "as_float":
return sym_arg.as_float
elif sym_arg.type == "as_name":
return self.serialized_name_to_node[sym_arg.as_name]
elif isinstance(sym_arg, SymBoolArgument):
if sym_arg.type == "as_bool":
return sym_arg.as_bool
@ -2098,7 +2200,7 @@ class GraphModuleDeserializer(metaclass=Final):
self.sync_fx_node(serialized_node.outputs[0].as_tensor.name, fx_node)
return
elif len(serialized_node.outputs) == 1 and isinstance(
serialized_node.outputs[0].value, (SymIntArgument, SymBoolArgument)
serialized_node.outputs[0].value, (SymIntArgument, SymBoolArgument, SymFloatArgument)
):
self.sync_fx_node(serialized_node.outputs[0].value.as_name, fx_node)
return
@ -2113,13 +2215,15 @@ class GraphModuleDeserializer(metaclass=Final):
def generate_getitem(
meta_val,
fx_node: torch.fx.Node,
arg: Union[TensorArgument, SymIntArgument],
arg: Union[TensorArgument, SymIntArgument, SymFloatArgument],
idx: int,
):
if isinstance(arg, TensorArgument):
name = arg.name
elif isinstance(arg, SymIntArgument):
name = arg.as_name
elif isinstance(arg, SymFloatArgument):
name = arg.as_name
else:
raise AssertionError(
f"generate_getitem got unknown argument type {type(arg)}"
@ -2140,7 +2244,7 @@ class GraphModuleDeserializer(metaclass=Final):
for idx, arg in enumerate(args):
if isinstance(arg, Argument):
arg = arg.value
if isinstance(arg, (TensorArgument, SymIntArgument)):
if isinstance(arg, (TensorArgument, SymIntArgument, SymFloatArgument)):
generate_getitem(meta_val, fx_node, arg, idx)
elif isinstance(arg, (list, tuple)):
list_output = self.graph.create_node(
@ -2241,6 +2345,8 @@ class GraphModuleDeserializer(metaclass=Final):
return ep.TensorArgument(name=x.as_tensor.name)
elif x.type == "as_sym_int":
return ep.SymIntArgument(name=x.as_sym_int.as_name)
elif x.type == "as_sym_float":
return ep.SymFloatArgument(name=x.as_sym_float.as_name)
elif x.type == "as_custom_obj":
return ep.ConstantArgument(name=x.as_custom_obj.name, value=self.deserialize_input(x))
else:
@ -2486,6 +2592,10 @@ def _canonicalize_graph(
return a.as_sym_int
elif a.type == "as_sym_ints":
return a.as_sym_ints
elif a.type == "as_sym_float":
return a.as_sym_float
elif a.type == "as_sym_floats":
return a.as_sym_floats
elif a.type == "as_scalar_type":
return None
elif a.type == "as_memory_format":
@ -2536,10 +2646,10 @@ def _canonicalize_graph(
return None
if isinstance(a, TensorArgument):
return a.name
elif isinstance(a, (SymIntArgument, SymBoolArgument)):
elif isinstance(a, (SymIntArgument, SymBoolArgument, SymFloatArgument)):
if a.type == "as_name":
return a.as_name
elif a.type in ("as_int", "as_bool"):
elif a.type in ("as_int", "as_bool", "as_float"):
return None
else:
raise AssertionError(f"Unknown argument type: {a}")
@ -2654,6 +2764,9 @@ def _canonicalize_graph(
elif isinstance(a, SymIntArgument):
if a.type == "as_name":
a.as_name = _rename(a.as_name, graph.sym_int_values)
elif isinstance(a, SymFloatArgument):
if a.type == "as_name":
a.as_name = _rename(a.as_name, graph.sym_float_values)
elif isinstance(a, SymBoolArgument):
if a.type == "as_name":
a.as_name = _rename(a.as_name, graph.sym_bool_values)
@ -2665,7 +2778,7 @@ def _canonicalize_graph(
return
if isinstance(a, TensorArgument):
a.name = name_table.get(a.name, a.name)
elif isinstance(a, SymIntArgument):
elif isinstance(a, (SymIntArgument, SymFloatArgument)):
if a.type == "as_name":
a.as_name = name_table.get(a.as_name, a.as_name)
elif isinstance(a, SymBoolArgument):
@ -2700,6 +2813,9 @@ def _canonicalize_graph(
sorted_sym_int_values = dict(
sorted(graph.sym_int_values.items(), key=operator.itemgetter(0))
)
sorted_sym_float_values = dict(
sorted(graph.sym_float_values.items(), key=operator.itemgetter(0))
)
sorted_sym_bool_values = dict(
sorted(graph.sym_bool_values.items(), key=operator.itemgetter(0))
)
@ -2722,6 +2838,7 @@ def _canonicalize_graph(
nodes=sorted_nodes,
tensor_values=sorted_tensor_values,
sym_int_values=sorted_sym_int_values,
sym_float_values=sorted_sym_float_values,
sym_bool_values=sorted_sym_bool_values,
is_single_tensor_return=graph.is_single_tensor_return,
)
@ -2832,6 +2949,14 @@ def canonicalize(ep: ExportedProgram) -> ExportedProgram:
pass
else:
raise AssertionError(f"Unknown sym_int type: {s}")
elif arg.type == "as_sym_float":
f = arg.as_sym_float
if f.type == "as_name":
f.as_name = replace_table[f.as_name]
elif f.type == "as_float":
pass
else:
raise AssertionError(f"Unknown sym_float type: {f}")
elif arg.type in (
"as_none",
"as_bool",
@ -2877,6 +3002,14 @@ def canonicalize(ep: ExportedProgram) -> ExportedProgram:
pass
else:
raise AssertionError(f"Unknown sym_int type: {s}")
elif arg.type == "as_sym_float":
f = arg.as_sym_float
if f.type == "as_name":
f.as_name = replace_table[f.as_name]
elif f.type == "as_float":
pass
else:
raise AssertionError(f"Unknown sym_float type: {f}")
elif arg.type in ("as_none", "as_int", "as_float", "as_string"):
return
else:

View File

@ -12,6 +12,7 @@ from torch.export.graph_signature import (
CustomObjArgument,
InputKind,
SymIntArgument,
SymFloatArgument,
SymBoolArgument,
TensorArgument,
TokenArgument,
@ -310,7 +311,7 @@ def _verify_exported_program_signature(exported_program) -> None:
)
for input_spec, node in zip(gs.input_specs, input_node_names):
if isinstance(input_spec.arg, (TensorArgument, SymIntArgument, SymBoolArgument)):
if isinstance(input_spec.arg, (TensorArgument, SymIntArgument, SymFloatArgument, SymBoolArgument)):
if input_spec.arg.name != node:
raise SpecViolationError(
f"Input spec name {input_spec.arg.name} does not match node name {node}"

View File

@ -1,5 +1,5 @@
// @generated by update_schema.py
// checksum<<19d86105f895a10d5eedbc6e13d4d96cf5d9182c0367d6825ef2438e124cc536>>
// checksum<<b2d7665a2d5d77eca43ac97af5e691123dd82b7b2582b8e81f2c326761e2f649>>
// clang-format off
#pragma once
@ -118,6 +118,8 @@ class SymBool;
class SymBoolArgument;
class SymExpr;
class SymExprHint;
class SymFloat;
class SymFloatArgument;
class SymInt;
class SymIntArgument;
class TensorArgument;
@ -320,6 +322,58 @@ class SymInt {
}
};
class SymFloat {
struct Void {};
public:
enum class Tag {
AS_EXPR, AS_FLOAT
};
private:
std::variant<Void, SymExpr, double> variant_;
Tag tag_;
public:
Tag tag() const {
return tag_;
}
const SymExpr& get_as_expr() const {
return std::get<1>(variant_);
}
const double& get_as_float() const {
return std::get<2>(variant_);
}
friend void to_json(nlohmann::json& nlohmann_json_j, const SymFloat& nlohmann_json_t) {
if (nlohmann_json_t.tag_ == Tag::AS_EXPR) {
nlohmann_json_j["as_expr"] = nlohmann_json_t.get_as_expr();
return;
}
if (nlohmann_json_t.tag_ == Tag::AS_FLOAT) {
nlohmann_json_j["as_float"] = nlohmann_json_t.get_as_float();
return;
}
}
friend void from_json(const nlohmann::json& nlohmann_json_j, SymFloat& nlohmann_json_t) {
if (nlohmann_json_j.contains("as_expr")) {
nlohmann_json_t.variant_.emplace<1>(nlohmann_json_j.at("as_expr").template get<SymExpr>());
nlohmann_json_t.tag_ = Tag::AS_EXPR;
return;
}
if (nlohmann_json_j.contains("as_int")) {
nlohmann_json_t.variant_.emplace<2>(nlohmann_json_j.at("as_int").template get<double>());
nlohmann_json_t.tag_ = Tag::AS_INT;
return;
}
}
};
class SymBool {
struct Void {};
@ -468,6 +522,58 @@ class SymIntArgument {
}
};
class SymFloatArgument {
struct Void {};
public:
enum class Tag {
AS_NAME, AS_FLOAT
};
private:
std::variant<Void, std::string, double> variant_;
Tag tag_;
public:
Tag tag() const {
return tag_;
}
const std::string& get_as_name() const {
return std::get<1>(variant_);
}
const double& get_as_float() const {
return std::get<2>(variant_);
}
friend void to_json(nlohmann::json& nlohmann_json_j, const SymFloatArgument& nlohmann_json_t) {
if (nlohmann_json_t.tag_ == Tag::AS_NAME) {
nlohmann_json_j["as_name"] = nlohmann_json_t.get_as_name();
return;
}
if (nlohmann_json_t.tag_ == Tag::AS_FLOAT) {
nlohmann_json_j["as_float"] = nlohmann_json_t.get_as_float();
return;
}
}
friend void from_json(const nlohmann::json& nlohmann_json_j, SymFloatArgument& nlohmann_json_t) {
if (nlohmann_json_j.contains("as_name")) {
nlohmann_json_t.variant_.emplace<1>(nlohmann_json_j.at("as_name").template get<std::string>());
nlohmann_json_t.tag_ = Tag::AS_NAME;
return;
}
if (nlohmann_json_j.contains("as_float")) {
nlohmann_json_t.variant_.emplace<2>(nlohmann_json_j.at("as_float").template get<double>());
nlohmann_json_t.tag_ = Tag::AS_FLOAT;
return;
}
}
};
class SymBoolArgument {
struct Void {};
@ -643,11 +749,11 @@ class Argument {
public:
enum class Tag {
AS_NONE, AS_TENSOR, AS_TENSORS, AS_INT, AS_INTS, AS_FLOAT, AS_FLOATS, AS_STRING, AS_STRINGS, AS_SYM_INT, AS_SYM_INTS, AS_SCALAR_TYPE, AS_MEMORY_FORMAT, AS_LAYOUT, AS_DEVICE, AS_BOOL, AS_BOOLS, AS_SYM_BOOL, AS_SYM_BOOLS, AS_GRAPH, AS_OPTIONAL_TENSORS, AS_CUSTOM_OBJ, AS_OPERATOR
AS_NONE, AS_TENSOR, AS_TENSORS, AS_INT, AS_INTS, AS_FLOAT, AS_FLOATS, AS_STRING, AS_STRINGS, AS_SYM_INT, AS_SYM_INTS, AS_SYM_FLOAT, AS_SYM_FLOATS, AS_SCALAR_TYPE, AS_MEMORY_FORMAT, AS_LAYOUT, AS_DEVICE, AS_BOOL, AS_BOOLS, AS_SYM_BOOL, AS_SYM_BOOLS, AS_GRAPH, AS_OPTIONAL_TENSORS, AS_CUSTOM_OBJ, AS_OPERATOR
};
private:
std::variant<Void, std::tuple<>, TensorArgument, std::vector<TensorArgument>, int64_t, std::vector<int64_t>, double, std::vector<double>, std::string, std::vector<std::string>, SymIntArgument, std::vector<SymIntArgument>, ScalarType, MemoryFormat, Layout, Device, bool, std::vector<bool>, SymBoolArgument, std::vector<SymBoolArgument>, GraphArgument, std::vector<OptionalTensorArgument>, CustomObjArgument, std::string> variant_;
std::variant<Void, std::tuple<>, TensorArgument, std::vector<TensorArgument>, int64_t, std::vector<int64_t>, double, std::vector<double>, std::string, std::vector<std::string>, SymIntArgument, std::vector<SymIntArgument>, SymFloatArgument, std::vector<SymFloatArgument>, ScalarType, MemoryFormat, Layout, Device, bool, std::vector<bool>, SymBoolArgument, std::vector<SymBoolArgument>, GraphArgument, std::vector<OptionalTensorArgument>, CustomObjArgument, std::string> variant_;
Tag tag_;
public:
@ -699,54 +805,62 @@ class Argument {
return std::get<11>(variant_);
}
const ScalarType& get_as_scalar_type() const {
const SymFloatArgument& get_as_sym_float() const {
return std::get<12>(variant_);
}
const MemoryFormat& get_as_memory_format() const {
const std::vector<SymFloatArgument>& get_as_sym_floats() const {
return std::get<13>(variant_);
}
const Layout& get_as_layout() const {
const ScalarType& get_as_scalar_type() const {
return std::get<14>(variant_);
}
const Device& get_as_device() const {
const MemoryFormat& get_as_memory_format() const {
return std::get<15>(variant_);
}
const bool& get_as_bool() const {
const Layout& get_as_layout() const {
return std::get<16>(variant_);
}
const std::vector<bool>& get_as_bools() const {
const Device& get_as_device() const {
return std::get<17>(variant_);
}
const SymBoolArgument& get_as_sym_bool() const {
const bool& get_as_bool() const {
return std::get<18>(variant_);
}
const std::vector<SymBoolArgument>& get_as_sym_bools() const {
const std::vector<bool>& get_as_bools() const {
return std::get<19>(variant_);
}
const GraphArgument& get_as_graph() const {
const SymBoolArgument& get_as_sym_bool() const {
return std::get<20>(variant_);
}
const std::vector<OptionalTensorArgument>& get_as_optional_tensors() const {
const std::vector<SymBoolArgument>& get_as_sym_bools() const {
return std::get<21>(variant_);
}
const CustomObjArgument& get_as_custom_obj() const {
const GraphArgument& get_as_graph() const {
return std::get<22>(variant_);
}
const std::string& get_as_operator() const {
const std::vector<OptionalTensorArgument>& get_as_optional_tensors() const {
return std::get<23>(variant_);
}
const CustomObjArgument& get_as_custom_obj() const {
return std::get<24>(variant_);
}
const std::string& get_as_operator() const {
return std::get<25>(variant_);
}
friend void to_json(nlohmann::json& nlohmann_json_j, const Argument& nlohmann_json_t) {
if (nlohmann_json_t.tag_ == Tag::AS_NONE) {
@ -793,6 +907,14 @@ class Argument {
nlohmann_json_j["as_sym_ints"] = nlohmann_json_t.get_as_sym_ints();
return;
}
if (nlohmann_json_t.tag_ == Tag::AS_SYM_FLOAT) {
nlohmann_json_j["as_sym_float"] = nlohmann_json_t.get_as_sym_float();
return;
}
if (nlohmann_json_t.tag_ == Tag::AS_SYM_FLOATS) {
nlohmann_json_j["as_sym_floats"] = nlohmann_json_t.get_as_sym_floats();
return;
}
if (nlohmann_json_t.tag_ == Tag::AS_SCALAR_TYPE) {
nlohmann_json_j["as_scalar_type"] = nlohmann_json_t.get_as_scalar_type();
return;
@ -900,63 +1022,73 @@ class Argument {
nlohmann_json_t.tag_ = Tag::AS_SYM_INTS;
return;
}
if (nlohmann_json_j.contains("as_sym_float")) {
nlohmann_json_t.variant_.emplace<12>(nlohmann_json_j.at("as_sym_float").template get<SymFloatArgument>());
nlohmann_json_t.tag_ = Tag::AS_SYM_FLOAT;
return;
}
if (nlohmann_json_j.contains("as_sym_floats")) {
nlohmann_json_t.variant_.emplace<13>(nlohmann_json_j.at("as_sym_floats").template get<std::vector<SymFloatArgument>>());
nlohmann_json_t.tag_ = Tag::AS_SYM_FLOATS;
return;
}
if (nlohmann_json_j.contains("as_scalar_type")) {
nlohmann_json_t.variant_.emplace<12>(nlohmann_json_j.at("as_scalar_type").template get<ScalarType>());
nlohmann_json_t.variant_.emplace<14>(nlohmann_json_j.at("as_scalar_type").template get<ScalarType>());
nlohmann_json_t.tag_ = Tag::AS_SCALAR_TYPE;
return;
}
if (nlohmann_json_j.contains("as_memory_format")) {
nlohmann_json_t.variant_.emplace<13>(nlohmann_json_j.at("as_memory_format").template get<MemoryFormat>());
nlohmann_json_t.variant_.emplace<15>(nlohmann_json_j.at("as_memory_format").template get<MemoryFormat>());
nlohmann_json_t.tag_ = Tag::AS_MEMORY_FORMAT;
return;
}
if (nlohmann_json_j.contains("as_layout")) {
nlohmann_json_t.variant_.emplace<14>(nlohmann_json_j.at("as_layout").template get<Layout>());
nlohmann_json_t.variant_.emplace<16>(nlohmann_json_j.at("as_layout").template get<Layout>());
nlohmann_json_t.tag_ = Tag::AS_LAYOUT;
return;
}
if (nlohmann_json_j.contains("as_device")) {
nlohmann_json_t.variant_.emplace<15>(nlohmann_json_j.at("as_device").template get<Device>());
nlohmann_json_t.variant_.emplace<17>(nlohmann_json_j.at("as_device").template get<Device>());
nlohmann_json_t.tag_ = Tag::AS_DEVICE;
return;
}
if (nlohmann_json_j.contains("as_bool")) {
nlohmann_json_t.variant_.emplace<16>(nlohmann_json_j.at("as_bool").template get<bool>());
nlohmann_json_t.variant_.emplace<18>(nlohmann_json_j.at("as_bool").template get<bool>());
nlohmann_json_t.tag_ = Tag::AS_BOOL;
return;
}
if (nlohmann_json_j.contains("as_bools")) {
nlohmann_json_t.variant_.emplace<17>(nlohmann_json_j.at("as_bools").template get<std::vector<bool>>());
nlohmann_json_t.variant_.emplace<19>(nlohmann_json_j.at("as_bools").template get<std::vector<bool>>());
nlohmann_json_t.tag_ = Tag::AS_BOOLS;
return;
}
if (nlohmann_json_j.contains("as_sym_bool")) {
nlohmann_json_t.variant_.emplace<18>(nlohmann_json_j.at("as_sym_bool").template get<SymBoolArgument>());
nlohmann_json_t.variant_.emplace<20>(nlohmann_json_j.at("as_sym_bool").template get<SymBoolArgument>());
nlohmann_json_t.tag_ = Tag::AS_SYM_BOOL;
return;
}
if (nlohmann_json_j.contains("as_sym_bools")) {
nlohmann_json_t.variant_.emplace<19>(nlohmann_json_j.at("as_sym_bools").template get<std::vector<SymBoolArgument>>());
nlohmann_json_t.variant_.emplace<21>(nlohmann_json_j.at("as_sym_bools").template get<std::vector<SymBoolArgument>>());
nlohmann_json_t.tag_ = Tag::AS_SYM_BOOLS;
return;
}
if (nlohmann_json_j.contains("as_graph")) {
nlohmann_json_t.variant_.emplace<20>(nlohmann_json_j.at("as_graph").template get<GraphArgument>());
nlohmann_json_t.variant_.emplace<22>(nlohmann_json_j.at("as_graph").template get<GraphArgument>());
nlohmann_json_t.tag_ = Tag::AS_GRAPH;
return;
}
if (nlohmann_json_j.contains("as_optional_tensors")) {
nlohmann_json_t.variant_.emplace<21>(nlohmann_json_j.at("as_optional_tensors").template get<std::vector<OptionalTensorArgument>>());
nlohmann_json_t.variant_.emplace<23>(nlohmann_json_j.at("as_optional_tensors").template get<std::vector<OptionalTensorArgument>>());
nlohmann_json_t.tag_ = Tag::AS_OPTIONAL_TENSORS;
return;
}
if (nlohmann_json_j.contains("as_custom_obj")) {
nlohmann_json_t.variant_.emplace<22>(nlohmann_json_j.at("as_custom_obj").template get<CustomObjArgument>());
nlohmann_json_t.variant_.emplace<24>(nlohmann_json_j.at("as_custom_obj").template get<CustomObjArgument>());
nlohmann_json_t.tag_ = Tag::AS_CUSTOM_OBJ;
return;
}
if (nlohmann_json_j.contains("as_operator")) {
nlohmann_json_t.variant_.emplace<23>(nlohmann_json_j.at("as_operator").template get<std::string>());
nlohmann_json_t.variant_.emplace<25>(nlohmann_json_j.at("as_operator").template get<std::string>());
nlohmann_json_t.tag_ = Tag::AS_OPERATOR;
return;
}
@ -1021,6 +1153,7 @@ class Graph {
std::unordered_map<std::string, SymBool> sym_bool_values;
bool is_single_tensor_return = false;
std::unordered_map<std::string, CustomObjArgument> custom_obj_values = {};
std::unordered_map<std::string, SymFloat> sym_float_values = {};
public:
@ -1056,6 +1189,10 @@ class Graph {
return custom_obj_values;
}
const std::unordered_map<std::string, SymFloat>& get_sym_float_values() const {
return sym_float_values;
}
friend void to_json(nlohmann::json& nlohmann_json_j, const Graph& nlohmann_json_t);
friend void from_json(const nlohmann::json& nlohmann_json_j, Graph& nlohmann_json_t);
};
@ -1892,6 +2029,7 @@ inline void to_json(nlohmann::json& nlohmann_json_j, const Graph& nlohmann_json_
nlohmann_json_j["sym_bool_values"] = nlohmann_json_t.sym_bool_values;
nlohmann_json_j["is_single_tensor_return"] = nlohmann_json_t.is_single_tensor_return;
nlohmann_json_j["custom_obj_values"] = nlohmann_json_t.custom_obj_values;
nlohmann_json_j["sym_float_values"] = nlohmann_json_t.sym_float_values;
}
inline void from_json(const nlohmann::json& nlohmann_json_j, Graph& nlohmann_json_t) {
@ -1904,6 +2042,7 @@ inline void from_json(const nlohmann::json& nlohmann_json_j, Graph& nlohmann_jso
nlohmann_json_t.sym_bool_values = nlohmann_json_j.value("sym_bool_values", nlohmann_json_default_obj.sym_bool_values);
nlohmann_json_t.is_single_tensor_return = nlohmann_json_j.value("is_single_tensor_return", nlohmann_json_default_obj.is_single_tensor_return);
nlohmann_json_t.custom_obj_values = nlohmann_json_j.value("custom_obj_values", nlohmann_json_default_obj.custom_obj_values);
nlohmann_json_t.sym_float_values = nlohmann_json_j.value("sym_float_values", nlohmann_json_default_obj.sym_float_values);
}
inline void to_json(nlohmann::json& nlohmann_json_j, const GraphArgument& nlohmann_json_t) {

View File

@ -80,6 +80,7 @@ from .graph_signature import ( # noqa: F401
OutputKind,
OutputSpec,
SymBoolArgument,
SymFloatArgument,
SymIntArgument,
TensorArgument,
TokenArgument,
@ -521,6 +522,8 @@ def _decompose_and_get_gm_with_new_signature_constants(
return TensorArgument(name=new_ph.name)
elif isinstance(old_arg, SymIntArgument):
return SymIntArgument(name=new_ph.name)
elif isinstance(old_arg, SymFloatArgument):
return SymFloatArgument(name=new_ph.name)
elif isinstance(old_arg, SymBoolArgument):
return SymBoolArgument(name=new_ph.name)
raise RuntimeError(f"Type of old_arg not supported: {type(old_arg)}")

View File

@ -20,6 +20,7 @@ __all__ = [
"OutputKind",
"OutputSpec",
"SymIntArgument",
"SymFloatArgument",
"SymBoolArgument",
"TensorArgument",
]
@ -40,6 +41,11 @@ class SymIntArgument:
name: str
@dataclasses.dataclass
class SymFloatArgument:
name: str
@dataclasses.dataclass
class SymBoolArgument:
name: str
@ -61,6 +67,7 @@ class ConstantArgument:
ArgumentSpec = Union[
TensorArgument,
SymIntArgument,
SymFloatArgument,
SymBoolArgument,
ConstantArgument,
CustomObjArgument,
@ -94,6 +101,7 @@ class InputSpec:
(
TensorArgument,
SymIntArgument,
SymFloatArgument,
SymBoolArgument,
ConstantArgument,
CustomObjArgument,
@ -124,6 +132,7 @@ class OutputSpec:
(
TensorArgument,
SymIntArgument,
SymFloatArgument,
SymBoolArgument,
ConstantArgument,
TokenArgument,
@ -273,7 +282,13 @@ class ExportGraphSignature:
if isinstance(
s.arg,
(TensorArgument, SymIntArgument, SymBoolArgument, CustomObjArgument),
(
TensorArgument,
SymIntArgument,
SymFloatArgument,
SymBoolArgument,
CustomObjArgument,
),
):
user_inputs.append(s.arg.name)
elif isinstance(s.arg, ConstantArgument):
@ -294,7 +309,10 @@ class ExportGraphSignature:
]:
continue
if isinstance(s.arg, (TensorArgument, SymIntArgument, SymBoolArgument)):
if isinstance(
s.arg,
(TensorArgument, SymIntArgument, SymFloatArgument, SymBoolArgument),
):
user_outputs.append(s.arg.name)
elif isinstance(s.arg, ConstantArgument):
user_outputs.append(s.arg.value)
@ -444,6 +462,7 @@ class ExportGraphSignature:
arg_types = (
TensorArgument,
SymIntArgument,
SymFloatArgument,
SymBoolArgument,
CustomObjArgument,
TokenArgument,
@ -478,7 +497,7 @@ def _immutable_dict(items):
def _make_argument_spec(node, token_names) -> ArgumentSpec:
from torch import ScriptObject, SymBool, SymInt
from torch import ScriptObject, SymBool, SymFloat, SymInt
from torch._library.fake_class_registry import FakeScriptObject
from torch._subclasses.fake_tensor import FakeTensor
@ -496,6 +515,8 @@ def _make_argument_spec(node, token_names) -> ArgumentSpec:
return TensorArgument(name=node.name)
elif isinstance(val, SymInt):
return SymIntArgument(name=node.name)
elif isinstance(val, SymFloat):
return SymFloatArgument(name=node.name)
elif isinstance(val, SymBool):
return SymBoolArgument(name=node.name)
elif isinstance(val, ScriptObject):

View File

@ -23,6 +23,7 @@ from torch.export.exported_program import (
InputKind,
ModuleCallSignature,
SymBoolArgument,
SymFloatArgument,
SymIntArgument,
TensorArgument,
)
@ -1023,7 +1024,13 @@ class _ModuleFrame:
input_nodes.append(None)
else:
assert isinstance(
input, (TensorArgument, SymIntArgument, SymBoolArgument)
input,
(
TensorArgument,
SymIntArgument,
SymBoolArgument,
SymFloatArgument,
),
)
input_nodes.append(
self.parent.remap_input(self.seen_nodes[input.name])
@ -1133,7 +1140,8 @@ class _ModuleFrame:
if signature is not None and self.parent is not None:
for output in signature.outputs:
if isinstance(
output, (TensorArgument, SymIntArgument, SymBoolArgument)
output,
(TensorArgument, SymIntArgument, SymBoolArgument, SymFloatArgument),
):
if output.name in self.seen_nodes:
orig_outputs.append(self.seen_nodes[output.name])

View File

@ -130,6 +130,7 @@ __all__ = [
"create_contiguous",
"ShapeEnv",
"is_concrete_int",
"is_concrete_float",
"guard_int",
"guard_float",
"guard_scalar",
@ -367,6 +368,25 @@ def is_concrete_int(a: Union[int, SymInt]) -> bool:
return False
def is_concrete_float(a: Union[float, SymFloat]) -> bool:
r"""Utility to check if underlying object
in SymInt is concrete value. Also returns
true if integer is passed in.
Args:
a (SymInt or float): Object to test if it float
"""
assert isinstance(a, (SymFloat, float))
if isinstance(a, float):
return True
if isinstance(a.node.expr, sympy.core.numbers.Float):
return True
return False
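# Usage sketch (assumed behavior, per the checks above):
#   is_concrete_float(3.14)  ->  True                (plain float)
#   is_concrete_float(f)     ->  True if f.node.expr is a sympy Float constant,
#                                False for a genuinely symbolic SymFloat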
def guard_size_oblivious(expr: Union[torch.SymBool, bool]) -> bool:
"""
Perform a guard on a symbolic boolean expression in a size oblivious way.