mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Support symbolic builtin round in export (#139549)
Differential Revision: D65380866 Pull Request resolved: https://github.com/pytorch/pytorch/pull/139549 Approved by: https://github.com/digantdesai, https://github.com/angelayi
This commit is contained in:
committed by
PyTorch MergeBot
parent
54e680151b
commit
617b4538f1
@ -9252,6 +9252,35 @@ class GraphModule(torch.nn.Module):
|
||||
state_dict.keys(),
|
||||
)
|
||||
|
||||
@testing.expectedFailureSerDer # T202237665
|
||||
@testing.expectedFailureSerDerNonStrict
|
||||
def test_dynamic_sym_round(self):
|
||||
class ModuleWithSymRound(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
out_size = round(x.shape[0] / 2.0)
|
||||
return x[:out_size]
|
||||
|
||||
dim_min = 5
|
||||
dim_max = 10
|
||||
dynamic_shapes = {"x": {0: Dim("n", min=dim_min, max=dim_max)}}
|
||||
|
||||
module = ModuleWithSymRound()
|
||||
inp = (torch.randn(8),)
|
||||
ep = export(module, inp, dynamic_shapes=dynamic_shapes)
|
||||
|
||||
# Expect builtin round in the export graph
|
||||
round_nodes = [
|
||||
n for n in ep.graph.nodes if n.op == "call_function" and n.target == round
|
||||
]
|
||||
self.assertEqual(len(round_nodes), 1)
|
||||
|
||||
# Check pre/post-export equality
|
||||
for i in range(dim_min, dim_max + 1):
|
||||
dyn_inp = (torch.randn(i),)
|
||||
export_res = ep.module()(*dyn_inp)
|
||||
ref_res = module(*dyn_inp)
|
||||
self.assertEqual(export_res, ref_res)
|
||||
|
||||
|
||||
@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support")
|
||||
class TestOneOffModelExportResult(TestCase):
|
||||
|
@ -186,7 +186,7 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
|
||||
if target == operator.getitem:
|
||||
value, key = args
|
||||
return self.callback.call_getitem(value, key, meta)
|
||||
elif getattr(target, "__module__", None) in {"_operator", "math"}:
|
||||
elif getattr(target, "__module__", None) in {"_operator", "builtins", "math"}:
|
||||
assert callable(target)
|
||||
return self.callback.call_sym(target, args, meta)
|
||||
elif target in _TORCH_SYM_OPS:
|
||||
|
@ -135,6 +135,7 @@ class Verifier(metaclass=_VerifierMeta):
|
||||
math.ceil,
|
||||
math.floor,
|
||||
math.trunc,
|
||||
round,
|
||||
]
|
||||
|
||||
def allowed_op_types(self) -> Tuple[Type[Any], ...]:
|
||||
|
Reference in New Issue
Block a user