Mirror of https://github.com/pytorch/pytorch.git
I applied some flake8 fixes and enabled checking for them in the linter. I also enabled some checks for my previous comprehensions PR. This is a follow-up to #94323, enabling the flake8 checkers for the fixes I made and fixing a few more of them.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94601
Approved by: https://github.com/ezyang
230 lines · 4.9 KiB · Python
import _collections_abc
import _weakrefset
import abc
import collections
import contextlib
import copy
import copyreg
import dataclasses
import enum
import functools
import importlib.util  # imported as importlib.util so find_spec() below is guaranteed available
import inspect
import linecache
import logging
import multiprocessing
import operator
import os
import posixpath
import random
import re
import selectors
import signal
import tempfile
import threading
import tokenize
import traceback
import types
import typing
import unittest
import weakref

import torch
import torch._inductor.test_operators

try:
    import torch._prims

    # isort: split
    # TODO: Hack to unblock simultaneous landing changes. Fix after https://github.com/pytorch/pytorch/pull/81088 lands
    import torch._prims.utils
    import torch._prims.wrappers
    import torch._refs
    import torch._refs.nn
    import torch._refs.nn.functional
    import torch._refs.special

    HAS_PRIMS_REFS = True
except ImportError:
    HAS_PRIMS_REFS = False

from . import comptime, config, external_utils

"""
A note on skipfiles:

Dynamo consults this file to determine whether code should be compiled or skipped.

A skip applies at the frame boundary, meaning dynamo either triggers a graph break
at the beginning of the frame or attempts to trace the whole frame. When skipping
a frame, recursively called frames are still traced by dynamo unless also skipped.

Skipfiles (skipped at the file level instead of the function level) still apply on a
frame-by-frame basis as dynamo traces, but they apply to all functions in that file.

@skip is a helper decorator that can be applied to your function to cause it to be
included here.
"""


def _strip_init_py(s):
    return re.sub(r"__init__.py$", "", s)


def _module_dir(m: types.ModuleType):
    return _strip_init_py(m.__file__)
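
# For a package, __file__ points at its __init__.py, so stripping that suffix
# yields the package directory; a single-file module keeps its full path.
# Hypothetical example: _module_dir(torch) -> ".../site-packages/torch/".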


SKIP_DIRS = [
    # torch.*
    _module_dir(torch),
    # torchdynamo.*
    os.path.dirname(__file__) + "/",
    "<frozen importlib",
    "<__array_function__ internals>",
] + [
    # skip some standard libs
    _module_dir(m)
    for m in (
        abc,
        collections,
        contextlib,
        copy,
        copyreg,
        dataclasses,
        enum,
        functools,
        importlib,
        inspect,
        linecache,
        logging,
        multiprocessing,
        operator,
        os,
        posixpath,
        random,
        re,
        selectors,
        signal,
        tempfile,
        threading,
        tokenize,
        traceback,
        types,
        typing,
        unittest,
        weakref,
        _collections_abc,
        _weakrefset,
    )
]

FILENAME_ALLOWLIST = {
    torch.nn.Sequential.__init__.__code__.co_filename,
    torch.set_rng_state.__code__.co_filename,
    torch._inductor.test_operators.__file__,
    # These are dynamo files!
    external_utils.__file__,
    comptime.__file__,  # Want to inline these helpers
}

# Include optimizer code for tracing
FILENAME_ALLOWLIST |= {
    inspect.getfile(obj)
    for obj in torch.optim.__dict__.values()
    if inspect.isclass(obj)
}
FILENAME_ALLOWLIST |= {torch.optim._functional.__file__}

if HAS_PRIMS_REFS:
    FILENAME_ALLOWLIST |= {
        torch._prims.__file__,
        torch._prims.utils.__file__,
        torch._prims.wrappers.__file__,
        torch._refs.__file__,
        torch._refs.special.__file__,
        torch._refs.nn.functional.__file__,
    }
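
# Note: check() below consults FILENAME_ALLOWLIST before SKIP_DIRS_RE, so an
# exact filename match here overrides any directory-prefix skip in SKIP_DIRS.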


SKIP_DIRS_RE = None


def _recompile_re():
    global SKIP_DIRS_RE
    SKIP_DIRS_RE = re.compile(f"^({'|'.join(map(re.escape, SKIP_DIRS))})")
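
# The compiled pattern is one anchored alternation of re.escape()d prefixes,
# matched against the start of a frame's filename. Hypothetical shape:
#   ^(/usr/lib/python3\.10/abc\.py|/opt/py/site-packages/torch/)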


def add(import_name: str):
    if isinstance(import_name, types.ModuleType):
        return add(import_name.__name__)
    assert isinstance(import_name, str)
    module_spec = importlib.util.find_spec(import_name)
    if not module_spec:
        return
    origin = module_spec.origin
    if origin is None:
        return
    global SKIP_DIRS_RE
    SKIP_DIRS.append(_strip_init_py(origin))
    _recompile_re()
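
# Hedged usage sketch: add() takes a module object or an importable name and
# registers its location as a skip prefix, e.g.:
#   add("scipy")  # no-op if scipy is absent (find_spec returns None)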


def check(filename, allow_torch=False):
    """Should skip this file?"""
    if filename is None:
        return True
    if filename in FILENAME_ALLOWLIST:
        return False
    if allow_torch and is_torch(filename):
        return False
    return bool(SKIP_DIRS_RE.match(filename))
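
# Dynamo typically feeds this frame.f_code.co_filename when deciding whether
# to trace a frame; a None filename is conservatively treated as skippable.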


# skip common third party libs
for _name in (
    "functorch",
    "intel_extension_for_pytorch",
    "networkx",
    "numpy",
    "omegaconf",
    "onnx",
    "onnxruntime",
    "onnx_tf",
    "pandas",
    "sklearn",
    "tabulate",
    "tensorflow",
    "tensorrt",
    "torch2trt",
    "tqdm",
    "tree",
    "tvm",
    "fx2trt_oss",
    "xarray",
):
    add(_name)

_recompile_re()
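
# The trailing _recompile_re() ensures SKIP_DIRS_RE is compiled even if every
# add() above returned early (e.g. none of these libs are installed), since
# check() assumes the pattern exists.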


def is_torch_inline_allowed(filename):
    return any(
        filename.startswith(_module_dir(mod))
        for mod in config.skipfiles_inline_module_allowlist
    )


@functools.lru_cache(None)
def dynamo_dir():
    import torch._dynamo

    return _module_dir(torch._dynamo)


def is_torch(filename):
    if filename.startswith(dynamo_dir()):
        return False
    return filename.startswith(_module_dir(torch))
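
# Ordering matters: torch/_dynamo lives under the torch install directory, so
# the dynamo_dir() prefix check must run first for dynamo's own files to
# report False here.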