diff --git a/torch/export/decomp_utils.py b/torch/export/decomp_utils.py index 932a11ab076c..d3097734c8a3 100644 --- a/torch/export/decomp_utils.py +++ b/torch/export/decomp_utils.py @@ -53,7 +53,7 @@ class CustomDecompTable(dict[torch._ops.OperatorBase, Callable]): self.decomp_table = _core_aten_decompositions_post_autograd() for op in _collect_all_valid_cia_ops_for_aten_namespace(): - if op not in PRESERVED_ATEN_CIA_OPS: + if op not in PRESERVED_ATEN_CIA_OPS and op not in self.decomp_table: self.decomp_table[op] = _get_decomp_for_cia(op) # This is to track the *pending* deleted custom ops that haven't been materialized yet