pyfmt lint torch/_export/passes/replace_view_ops_with_view_copy_ops_pass.py (#154488)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154488
Approved by: https://github.com/Skylion007
ghstack dependencies: #154483, #154484, #154485, #154487
This commit is contained in:
Laith Sakka
2025-05-27 21:39:07 -07:00
committed by PyTorch MergeBot
parent dfe0f48123
commit 66ac724b56
2 changed files with 4 additions and 6 deletions

View File

@ -1316,8 +1316,6 @@ exclude_patterns = [
'torch/_export/passes/const_prop_pass.py',
'torch/_export/passes/functionalize_side_effectful_ops_pass.py',
'torch/_export/passes/replace_sym_size_ops_pass.py',
'torch/_export/passes/replace_view_ops_with_view_copy_ops_pass.py',
'torch/_export/serde/__init__.py',
'torch/testing/_internal/__init__.py',
'torch/testing/_internal/autocast_test_lists.py',
'torch/testing/_internal/autograd_function_db.py',

View File

@ -1,9 +1,10 @@
# mypy: allow-untyped-defs
from typing import Optional
import torch
from torch._ops import OpOverload, HigherOrderOperator
from torch._export.error import InternalError
from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse
from torch._ops import HigherOrderOperator, OpOverload
__all__ = ["ReplaceViewOpsWithViewCopyOpsPass"]
@ -25,9 +26,7 @@ def get_view_copy_of_view_op(schema: torch._C.FunctionSchema) -> Optional[OpOver
if is_view_op(schema) and schema.name.startswith("aten::"):
view_op_name = schema.name.split("::")[1]
view_op_overload = (
schema.overload_name
if schema.overload_name != ""
else "default"
schema.overload_name if schema.overload_name != "" else "default"
)
view_copy_op_name = view_op_name + "_copy"
if not hasattr(torch.ops.aten, view_copy_op_name):
@ -50,6 +49,7 @@ class ReplaceViewOpsWithViewCopyOpsPass(_ExportPassBaseDeprecatedDoNotUse):
program. This pass replaces view ops with view copy ops for backends that
need AOT memory planning.
"""
def call_operator(self, op, args, kwargs, meta):
if op in _NON_FUNCTIONAL_OPS_TO_FUNCTIONAL_OPS:
return super().call_operator(