[export] Handle non OpNamespace type during decomposition. (#149431)

Summary:
Turns out we can have non OpNamespace object in torch.ops._dir.

We should just throw away those during iteration.

Test Plan: eyes

Differential Revision: D71417992

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149431
Approved by: https://github.com/tugsbayasgalan
This commit is contained in:
Zhengxu Chen
2025-03-20 15:36:15 +00:00
committed by PyTorch MergeBot
parent d67c1a027e
commit 80dfce2cc3

View File

@ -1111,16 +1111,16 @@ def _check_valid_to_preserve(op_overload: "OperatorBase"):
@functools.lru_cache(maxsize=1)
def _collect_all_valid_cia_ops_for_aten_namespace() -> set["OperatorBase"]:
return _collect_all_valid_cia_ops_for_namespace("aten")
return _collect_all_valid_cia_ops_for_namespace(torch.ops.aten)
def _collect_all_valid_cia_ops_for_namespace(namespace: str) -> set["OperatorBase"]:
def _collect_all_valid_cia_ops_for_namespace(
op_namespace: torch._ops._OpNamespace,
) -> set["OperatorBase"]:
# Step 1: Materialize all ops from C++ dispatcher
_materialize_cpp_cia_ops()
# Step 2: Query all ops from python dispatcher
assert hasattr(torch.ops, namespace)
op_namespace = getattr(torch.ops, namespace)
cia_ops = set()
for op in op_namespace:
op_packet = getattr(op_namespace, op)
@ -1150,7 +1150,10 @@ def _collect_all_valid_cia_ops() -> set["OperatorBase"]:
for op_namespace_name in torch.ops._dir:
# The reason we split here is because aten ops are safe to cache.
if op_namespace_name != "aten":
cia_ops |= _collect_all_valid_cia_ops_for_namespace(op_namespace_name)
assert hasattr(torch.ops, op_namespace_name)
op_namespace = getattr(torch.ops, op_namespace_name)
if isinstance(op_namespace, torch._ops._OpNamespace):
cia_ops |= _collect_all_valid_cia_ops_for_namespace(op_namespace)
else:
cia_ops |= _collect_all_valid_cia_ops_for_aten_namespace()
return cia_ops