Support symbolic builtin round in export (#139549)

Differential Revision: D65380866

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139549
Approved by: https://github.com/digantdesai, https://github.com/angelayi
This commit is contained in:
Gregory Comer
2024-11-07 02:49:44 +00:00
committed by PyTorch MergeBot
parent 54e680151b
commit 617b4538f1
3 changed files with 31 additions and 1 deletions

View File

@ -9252,6 +9252,35 @@ class GraphModule(torch.nn.Module):
state_dict.keys(),
)
@testing.expectedFailureSerDer # T202237665
@testing.expectedFailureSerDerNonStrict
def test_dynamic_sym_round(self):
class ModuleWithSymRound(torch.nn.Module):
def forward(self, x):
out_size = round(x.shape[0] / 2.0)
return x[:out_size]
dim_min = 5
dim_max = 10
dynamic_shapes = {"x": {0: Dim("n", min=dim_min, max=dim_max)}}
module = ModuleWithSymRound()
inp = (torch.randn(8),)
ep = export(module, inp, dynamic_shapes=dynamic_shapes)
# Expect builtin round in the export graph
round_nodes = [
n for n in ep.graph.nodes if n.op == "call_function" and n.target == round
]
self.assertEqual(len(round_nodes), 1)
# Check pre/post-export equality
for i in range(dim_min, dim_max + 1):
dyn_inp = (torch.randn(i),)
export_res = ep.module()(*dyn_inp)
ref_res = module(*dyn_inp)
self.assertEqual(export_res, ref_res)
@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support")
class TestOneOffModelExportResult(TestCase):

View File

@ -186,7 +186,7 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
if target == operator.getitem:
value, key = args
return self.callback.call_getitem(value, key, meta)
elif getattr(target, "__module__", None) in {"_operator", "math"}:
elif getattr(target, "__module__", None) in {"_operator", "builtins", "math"}:
assert callable(target)
return self.callback.call_sym(target, args, meta)
elif target in _TORCH_SYM_OPS:

View File

@ -135,6 +135,7 @@ class Verifier(metaclass=_VerifierMeta):
math.ceil,
math.floor,
math.trunc,
round,
]
def allowed_op_types(self) -> Tuple[Type[Any], ...]: