Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
[BE]: Update ruff to 0.11.8 (#153249)
Fixes a ton of false negatives throughout the codebase. Ruff now also properly validates NOQA comments, and most of the changes here either fix typos in those comments or remove file-wide flake8 suppressions that were also silencing ruff issues.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153249
Approved by: https://github.com/cyyever, https://github.com/albanD, https://github.com/seemethere
Committed by: PyTorch MergeBot
Parent: 5c3fddb9cc
Commit: 3555ebb63d
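Most of the hunks below follow a handful of recurring patterns. The snippet that follows is a minimal illustrative sketch, not code from the commit itself (the function name, message strings, and file name are invented); only the rule codes TRY002, B950, and LOG015 are taken from the hunks below.

import logging

# Pattern 1: malformed suppression comments that ruff 0.11.8 now rejects are
# corrected, e.g. a stray backtick in "# noqa: TRY002`" becomes "# noqa: TRY002".
# (TRY002 bans raising a vanilla Exception, so the per-line suppression stays.)
def fail_fast(msg: str) -> None:
    raise Exception(msg)  # noqa: TRY002


# Pattern 2: file-wide "# flake8: noqa" suppressions, which also silenced ruff,
# are removed or narrowed to a single code such as "# flake8: noqa: B950".

# Pattern 3: module-level calls on the root logger are replaced with a named
# logger, which is what ruff's LOG015 rule flags.
log = logging.getLogger(__name__)
log.warning("failed parsing %s", "example.csv")  # instead of logging.warning(...)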
.flake8 (2 changed lines)
@@ -19,6 +19,8 @@ ignore =
 G100,G101,G200
 # these ignores are from flake8-simplify. please fix or ignore with commented reason
 SIM105,SIM108,SIM110,SIM111,SIM113,SIM114,SIM115,SIM116,SIM117,SIM118,SIM119,SIM12,
+# SIM104 is already covered by pyupgrade ruff
+SIM104,
 # flake8-simplify code styles
 SIM102,SIM103,SIM106,SIM112,
 # TorchFix codes that don't make sense for PyTorch itself:
.github/scripts/filter_test_configs.py (vendored, 1 changed line)
@@ -1,4 +1,5 @@
 #!/usr/bin/env python3
+# ruff: noqa: LOG015

 import json
 import logging
@@ -1456,7 +1456,7 @@ init_command = [
 'black==23.12.1',
 'usort==1.0.8.post1',
 'isort==5.13.2',
-'ruff==0.9.8', # sync with RUFF
+'ruff==0.11.8', # sync with RUFF
 ]
 is_formatter = true

@@ -1542,7 +1542,7 @@ init_command = [
 'python3',
 'tools/linter/adapters/pip_init.py',
 '--dry-run={{DRYRUN}}',
-'ruff==0.9.8', # sync with PYFMT
+'ruff==0.11.8', # sync with PYFMT
 ]
 is_formatter = true
@@ -1379,7 +1379,7 @@ def _produce_dynamic_shapes_for_export(path, x):

 if not isinstance(x, torch.Tensor):
 return None
-return {i: Dim.AUTO for i in getattr(x, "_dynamo_dynamic_indices", {})}
+return dict.fromkeys(getattr(x, "_dynamo_dynamic_indices", {}), Dim.AUTO)


 class AOTInductorModelCache:

@@ -1671,7 +1671,7 @@ def maybe_snapshot_memory(should_snapshot_memory, suffix):
 )
 )
 except Exception as e:
-logging.error("Failed to save memory snapshot, %s", e)
+log.error("Failed to save memory snapshot, %s", e)

 torch.cuda.memory._record_memory_history(enabled=None)

@@ -2687,7 +2687,7 @@ class BenchmarkRunner:
 experiment,
 tag,
 ):
-logging.info("Minifying %s...", name)
+log.info("Minifying %s...", name)
 os.environ["TORCH_COMPILE_DEBUG"] = "1"
 os.environ["TORCHDYNAMO_REPRO_AFTER"] = "dynamo"
 os.environ["TORCHDYNAMO_REPRO_LEVEL"] = "4"

@@ -2702,9 +2702,9 @@ class BenchmarkRunner:
 try:
 shutil.move("repro.py", f"{repro_dir}/{name}_repro.py")
 except OSError:
-logging.error("Could not find repro script for model %s", name)
+log.error("Could not find repro script for model %s", name)
 else:
-logging.info(
+log.info(
 "Repro script for model %s with minified graph saved to %s",
 name,
 repro_dir,
@@ -1,4 +1,5 @@
 #!/usr/bin/env python3
+# flake8: noqa: F821

 import importlib
 import logging

@@ -48,7 +49,6 @@ def pip_install(package):

 # Disable the flake warnings for the imports. Flake8 does not provide a way to
 # disable just warning for the entire file. Disabling flake8 entirely.
-# flake8: noqa
 imports = [
 "AlbertForPreTraining",
 "AutoConfig",

@@ -111,7 +111,7 @@ BATCH_SIZE_KNOWN_MODELS = {}
 # Get the list of models and their batch sizes
 MODELS_FILENAME = os.path.join(os.path.dirname(__file__), "huggingface_models_list.txt")
 assert os.path.exists(MODELS_FILENAME)
-with open(MODELS_FILENAME, "r") as fh:
+with open(MODELS_FILENAME) as fh:
 lines = fh.readlines()
 lines = [line.rstrip() for line in lines]
 for line in lines:

@@ -166,7 +166,7 @@ def get_sequence_length(model_cls, model_name):
 seq_length = 10000 # NB: a more realistic size is 155136
 else:
 log.info(
-f"Sequence Length not defined for {model_name}. Choosing 128 arbitrarily"
+f"Sequence Length not defined for {model_name}. Choosing 128 arbitrarily" # noqa: G004
 )
 seq_length = 128
 return seq_length

@@ -204,22 +204,16 @@ def generate_inputs_for_model(

 input_dict = {"input_ids": input}

-if (
-model_name.startswith("T5")
-or model_name.startswith("M2M100")
-or model_name.startswith("MT5")
-or model_cls
-in [
-BlenderbotModel,
-BlenderbotSmallModel,
-BlenderbotForConditionalGeneration,
-BlenderbotSmallForConditionalGeneration,
-PegasusModel,
-PegasusForConditionalGeneration,
-MarianModel,
-MarianMTModel,
-]
-):
+if model_name.startswith(("T5", "M2M100", "MT5")) or model_cls in [
+BlenderbotModel,
+BlenderbotSmallModel,
+BlenderbotForConditionalGeneration,
+BlenderbotSmallForConditionalGeneration,
+PegasusModel,
+PegasusForConditionalGeneration,
+MarianModel,
+MarianMTModel,
+]:
 input_dict["decoder_input_ids"] = input

 if model_name.startswith("Lxmert"):

@@ -251,11 +245,8 @@ def generate_inputs_for_model(
 device, 0, seq_length, (bs,)
 )
 input_dict["end_positions"] = rand_int_tensor(device, 0, seq_length, (bs,))
-elif (
-model_name.endswith("MaskedLM")
-or model_name.endswith("HeadModel")
-or model_name.endswith("CausalLM")
-or model_name.endswith("DoubleHeadsModel")
+elif model_name.endswith(
+("MaskedLM", "HeadModel", "CausalLM", "DoubleHeadsModel")
 ):
 input_dict["labels"] = rand_int_tensor(
 device, 0, vocab_size, (bs, seq_length)

@@ -429,7 +420,7 @@ class HuggingfaceRunner(BenchmarkRunner):
 elif batch_size is None:
 batch_size_default = 16
 log.info(
-f"Batch size not specified for {model_name}. Setting batch_size=16"
+f"Batch size not specified for {model_name}. Setting batch_size=16" # noqa: G004
 )

 if batch_size is None:

@@ -438,7 +429,7 @@ class HuggingfaceRunner(BenchmarkRunner):
 if model_name in batch_size_divisors:
 batch_size = max(int(batch_size / batch_size_divisors[model_name]), 1)
 log.info(
-f"Running smaller batch size={batch_size} for {model_name}, orig batch_size={batch_size_default}"
+f"Running smaller batch size={batch_size} for {model_name}, orig batch_size={batch_size_default}" # noqa: G004
 )

 example_inputs = generate_inputs_for_model(

@@ -474,8 +465,8 @@ class HuggingfaceRunner(BenchmarkRunner):
 if index < start or index >= end:
 continue
 if (
-not re.search("|".join(args.filter), model_name, re.I)
-or re.search("|".join(args.exclude), model_name, re.I)
+not re.search("|".join(args.filter), model_name, re.IGNORECASE)
+or re.search("|".join(args.exclude), model_name, re.IGNORECASE)
 or model_name in args.exclude_exact
 or model_name in self.skip_models
 ):

@@ -621,7 +612,7 @@ def refresh_model_names_and_batch_sizes():
 + [f"--output={MODELS_FILENAME}"]
 )
 except subprocess.SubprocessError:
-log.warning(f"Failed to find suitable batch size for {model_name}")
+log.warning(f"Failed to find suitable batch size for {model_name}") # noqa: G004


 def huggingface_main():
@@ -1,6 +1,5 @@
-# flake8: noqa
+# flake8: noqa: B902
 import triton
 from prettytable import PrettyTable

 import torch

@@ -18,7 +17,7 @@ torch.manual_seed(0)
 torch.backends.cuda.matmul.allow_tf32 = True


-class Func(object):
+class Func:
 # mm
 @torch._dynamo.optimize("inductor")
 def mm(a, b, bias):

@@ -45,7 +44,9 @@ class Func(object):
 return torch.relu(y)


-def bench(shape, layer_id, p, fusion_types=[""]):
+def bench(shape, layer_id, p, fusion_types=None):
+if fusion_types is None:
+fusion_types = [""]
 dtype = torch.float16
 M, K = shape[0]
 _, N = shape[1]

@@ -60,7 +61,7 @@ def bench(shape, layer_id, p, fusion_types=[""]):
 row = [layer_id]
 for fusion_type in fusion_types:
 if fusion_type == "":
-fn_mm = getattr(Func, "mm")
+fn_mm = Func.mm
 else:
 fn_mm = getattr(Func, f"mm_{fusion_type}")
@@ -1450,7 +1450,7 @@ class DashboardUpdater:
 try:
 RegressionTracker(self.args).diff()
 except Exception:
-logging.exception("")
+log.exception("")
 with open(f"{self.args.output_dir}/gh_regression.txt", "w") as gh_fh:
 gh_fh.write("")
@@ -8,6 +8,9 @@ import pandas as pd
 from tabulate import tabulate


+log = logging.getLogger(__name__)
+
+
 def gmean(s):
 return s.product() ** (1 / len(s))

@@ -67,7 +70,7 @@ def main(directory, amp, float32, perf_compare):
 try:
 dfs[os.path.basename(f)].append(pd.read_csv(f))
 except Exception:
-logging.warning("failed parsing %s", f)
+log.warning("failed parsing %s", f)
 raise

 # dtype -> statistic -> benchmark -> compiler -> value
@@ -43,7 +43,7 @@ def torchao_optimize_ctx(quantization: str):
 from torchao.quantization.autoquant import AUTOQUANT_CACHE

 if len(AUTOQUANT_CACHE) == 0:
-raise Exception( # noqa: TRY002`
+raise Exception( # noqa: TRY002
 "NotAutoquantizable"
 f"Found no autoquantizable layers in model {type(module)}, stopping autoquantized run"
 )
@@ -8,6 +8,8 @@ import pandas as pd
 from torch._functorch.benchmark_utils import compute_utilization


+log = logging.getLogger(__name__)
+
 # process the chrome traces output by the pytorch profiler
 # require the json input file's name to be in format {model_name}_chrome_trace_*.json
 # the runtimes file should have format (model_name, runtime)

@@ -65,7 +67,7 @@ def main():
 )
 print(f"{modelname}, {utilization}, {mm_conv_utilization}")
 except BaseException:
-logging.exception("%s, ERROR", filename)
+log.exception("%s, ERROR", filename)
 print(f"{filename}, ERROR")
@@ -73,6 +73,23 @@ quote-style = "double"

 [tool.ruff.lint]
 # NOTE: Synchoronize the ignores with .flake8
+external = [
+"B001",
+"B902",
+"B950",
+"E121",
+"E122",
+"E128",
+"E131",
+"E704",
+"E723",
+"F723",
+"F812",
+"P201",
+"P204",
+"T484",
+"TOR901",
+]
 ignore = [
 # these ignores are from flake8-bugbear; please fix!
 "B007", "B008", "B017",

@@ -108,6 +125,8 @@ ignore = [
 "SIM117",
 "SIM118",
 "UP007", # keep-runtime-typing
+"TC006",
+"TC007",
 ]
 select = [
 "B",

@@ -173,7 +192,7 @@ select = [
 "RUF030", # No print statement in assert
 "S324", # for hashlib FIPS compliance
 "SLOT",
-"TCH",
+"TC",
 "TRY002", # ban vanilla raise (todo fix NOQAs)
 "TRY203",
 "TRY401", # verbose-log-message

@@ -187,6 +206,12 @@ select = [
 "functorch/notebooks/**" = [
 "F401",
 ]
+"test/export/**" = [
+"PGH004"
+]
+"test/typing/**" = [
+"PGH004"
+]
 "test/typing/reveal/**" = [
 "F821",
 ]

@@ -200,6 +225,9 @@ select = [
 "test/dynamo/test_debug_utils.py" = [
 "UP037",
 ]
+"test/dynamo/test_misc.py" = [
+"PGH004",
+]
 "test/jit/**" = [
 "PLR0133", # tests require this for JIT
 "PYI",

@@ -212,12 +240,20 @@ select = [
 "RUF015",
 "UP", # We don't want to modify the jit test as they test specify syntax
 ]
 "test/inductor/s429861_repro.py" = [
 "PGH004",
 ]
 "test/inductor/test_torchinductor.py" = [
 "UP037",
 ]
 # autogenerated #TODO figure out why file level noqa is ignored
 "torch/_appdirs.py" = ["PGH004"]
 "torch/jit/_shape_functions.py" = ["PGH004"]
 "torch/_inductor/fx_passes/serialized_patterns/**" = ["F401", "F501"]
 "torch/_inductor/autoheuristic/artifacts/**" = ["F401", "F501"]
 "torch/_inductor/codegen/**" = [
 "PGH004"
 ]
 "torchgen/api/types/__init__.py" = [
 "F401",
 "F403",

@@ -232,3 +268,6 @@ select = [
 "torch/_vendor/**" = [
 "UP", # No need to mess with _vendor
 ]
+"tools/linter/**" = [
+"LOG015" # please fix
+]
@@ -24,9 +24,15 @@ from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, T

 # TODO: Once more test files are created, move the contents to a ao folder.

-logging.basicConfig(
-format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
-)
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+handler = logging.StreamHandler()
+formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
+handler.setFormatter(formatter)
+
+logger.addHandler(handler)
+logger.propagate = False # Prevent duplicate logs if root logger also has handlers


 class TestQuantizedSparseKernels(TestCase):

@@ -78,10 +84,10 @@ class TestQuantizedSparseKernels(TestCase):

 for use_channelwise, dynamic_mode in product([True, False], [True, False]):
 if qengine_is_fbgemm() and dynamic_mode:
-logging.info("dynamic sparse qlinear is only available in qnnpack")
+logger.info("dynamic sparse qlinear is only available in qnnpack")
 continue
 if qengine_is_qnnpack() and not dynamic_mode:
-logging.info("static sparse qlinear is only available in fbgemm")
+logger.info("static sparse qlinear is only available in fbgemm")
 continue
 if use_channelwise:
 W_q = torch.quantize_per_channel(
@@ -1,5 +1,5 @@
 # Owner(s): ["module: dynamo"]
-# flake8: noqa
+# flake8: noqa: B950

 import functools
 import itertools

@@ -13,7 +13,6 @@ from torch import _inductor as inductor
 from torch._dynamo import compiled_autograd
 from torch._dynamo._trace_wrapped_higher_order_op import trace_wrapped
 from torch._dynamo.testing import normalize_gm
 from torch._dynamo.utils import counters
 from torch.fx.experimental.proxy_tensor import make_fx
@@ -1,17 +1,13 @@
 # Owner(s): ["module: dynamo"]

 # ruff: noqa: TRY002
-# flake8: noqa

 import dataclasses
 import gc
 import itertools
 import types
 import unittest
 import weakref
 from collections import defaultdict, namedtuple, OrderedDict
 from dataclasses import dataclass, fields, is_dataclass
-from typing import Any, Optional, Tuple
+from typing import Any

 import torch
 import torch._dynamo.config

@@ -22,8 +18,6 @@ import torch.nn
 import torch.utils.checkpoint
 from torch._dynamo.testing import same
 from torch._dynamo.utils import dict_items
 from torch.testing._internal.common_device_type import instantiate_device_type_tests
 from torch.testing._internal.common_utils import TestCase


 class SimpleDict(dict):
@@ -435,7 +429,7 @@ class DictTests(torch._dynamo.test_case.TestCase):
 config = dotdict({"a": 1, "b": 2})

 def fn(x):
-x2 = x * 2
+x2 = x * 2 # noqa: F841
 x3 = x * config.get("a", 3)
 return x3

@@ -643,8 +637,8 @@ class DictTests(torch._dynamo.test_case.TestCase):
 ):

 class CustomDict(super_class):
-def __new__(self, *args, **kwargs):
-return super().__new__(self, *args, **kwargs)
+def __new__(cls, *args, **kwargs):
+return super().__new__(cls, *args, **kwargs)

 def __init__(self, *args, **kwargs):
 super().__init__(*args, **kwargs)

@@ -806,7 +800,7 @@ class DictTests(torch._dynamo.test_case.TestCase):
 d = {"a": 2, "b": 3, "c": 5 * x}
 mp = types.MappingProxyType(d)
 y = torch.sin(x * mp["a"])
-for k, v in mp.items():
+for k, v in mp.items(): # noqa: PERF102
 y += torch.cos(x * v)
 return mp

@@ -823,7 +817,7 @@ class DictTests(torch._dynamo.test_case.TestCase):
 def fn(x):
 mp = types.MappingProxyType(d)
 y = torch.sin(x * mp["a"])
-for k, v in mp.items():
+for k, v in mp.items(): # noqa: PERF102
 y += torch.cos(x * v)
 d["d"] = 4
 return mp

@@ -844,7 +838,7 @@ class DictTests(torch._dynamo.test_case.TestCase):

 def fn(x, mp):
 y = torch.sin(x * mp["a"])
-for k, v in mp.items():
+for k, v in mp.items(): # noqa: PERF102
 y += torch.cos(x * v)
 return y

@@ -939,7 +933,7 @@ class DictTests(torch._dynamo.test_case.TestCase):

 def test_items_type(self):
 def fn():
-d = dict({"a": 1, "b": "2", "c": torch.tensor(3)})
+d = dict({"a": 1, "b": "2", "c": torch.tensor(3)}) # noqa: C418
 return d.items()

 opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
@@ -1,5 +1,5 @@
 # Owner(s): ["module: dynamo"]
-# flake8: noqa
+# flake8: noqa: B950
 import torch
 import torch._dynamo
 import torch._dynamo.test_case
@@ -10,7 +10,16 @@ from torch._C import parse_schema, Tag


 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
-logging.basicConfig(level=logging.INFO, format=FORMAT)
+
+log = logging.getLogger("log")
+log.setLevel(logging.INFO)
+
+handler = logging.StreamHandler()
+formatter = logging.Formatter(FORMAT)
+handler.setFormatter(formatter)
+
+log.addHandler(handler)
+log.propagate = False # Avoid double logging if root logger has handlers

 # How to run this test locally:
 # 1 Have two virtual environments (eg conda env), one without PyTorch installed (venv_nightly)
@@ -259,10 +268,10 @@ def check_bc(existing_schemas):
 is_allow_list, trust_not_core_aten = allow_listed(existing_schema)
 if is_allow_list:
 if trust_not_core_aten or not is_core_aten_op(existing_schema):
-logging.info("schema: %s found on allowlist, skipping", existing_schema)
+log.info("schema: %s found on allowlist, skipping", existing_schema)
 continue
 else:
-logging.info(
+log.info(
 "schema: %s found on allowlist, but is a core ATen op, checking BC. "
 "NOTE: If you have removed an operator we will conservatively assume that "
 "it is a core ATen op. If the operator you removed is not a core ATen op, "

@@ -272,13 +281,13 @@ def check_bc(existing_schemas):
 )
 if has_valid_upgraders(existing_schema, version_map):
 if not is_core_aten_op(existing_schema):
-logging.info("schema: %s has valid upgrader, skipping", existing_schema)
+log.info("schema: %s has valid upgrader, skipping", existing_schema)
 continue
 else:
-logging.info(
+log.info(
 "schema: %s has a valid upgrader, but is a core ATen op, checking BC"
 )
-logging.debug("processing existing schema: %s", existing_schema)
+log.debug("processing existing schema: %s", existing_schema)
 matching_new_schemas = new_schema_dict.get(existing_schema.name, [])
 found = False
 for matching_new_schema in matching_new_schemas:

@@ -286,7 +295,7 @@ def check_bc(existing_schemas):
 found = True
 break
 if not found:
-logging.warning(
+log.warning(
 "Can NOT find backward compatible schemas after changes "
 "for schema %s from the following candidates:\n[\n%s\n]",
 str(existing_schema),

@@ -296,9 +305,9 @@ def check_bc(existing_schemas):
 broken_ops.append(str(existing_schema))
 is_bc = False
 if is_bc:
-logging.info("Found backward compatible schemas for all existing schemas")
+log.info("Found backward compatible schemas for all existing schemas")
 else:
-logging.warning(
+log.warning(
 "The PR is introducing backward incompatible changes to the "
 "operator library. Please contact PyTorch team to confirm "
 "whether this change is wanted or not. \n\nBroken ops: "

@@ -315,9 +324,9 @@ def check_fc(existing_schemas):
 for existing_schema in existing_schemas:
 is_allow_list, _ = allow_listed(existing_schema)
 if is_allow_list:
-logging.info("schema: %s found on allowlist, skipping", existing_schema)
+log.info("schema: %s found on allowlist, skipping", existing_schema)
 continue
-logging.info("processing existing schema: %s", existing_schema)
+log.info("processing existing schema: %s", existing_schema)
 matching_new_schemas = new_schema_dict.get(existing_schema.name, [])
 found = False
 possible_failure_reasons = []

@@ -331,13 +340,13 @@ def check_fc(existing_schemas):
 if reason != "":
 possible_failure_reasons.append(reason)
 if not found:
-logging.warning(
+log.warning(
 "Can NOT find forward compatible schemas after changes "
 "for schema %s from the following candidates:\n[\n\t%s\n]",
 str(existing_schema),
 "\n\t".join(str(s) for s in matching_new_schemas),
 )
-logging.warning(
+log.warning(
 "Refer to following reasons for failure "
 "to find FC schema:\n[\n%s\n]",
 "\n\t".join(str(r) for r in possible_failure_reasons),

@@ -345,9 +354,9 @@ def check_fc(existing_schemas):
 broken_ops.append(str(existing_schema))
 is_fc = False
 if is_fc:
-logging.info("Found forward compatible schemas for all existing schemas")
+log.info("Found forward compatible schemas for all existing schemas")
 else:
-logging.warning(
+log.warning(
 "The PR is introducing a potentially forward incompatible changes to the "
 "operator library. Please contact PyTorch team to confirm "
 "whether this change is wanted or not. \n\nBroken ops: "

@@ -374,7 +383,7 @@ if __name__ == "__main__":
 break

 if dont_parse(line.strip()):
-logging.info("Not parsing schema line: %s", line.strip())
+log.info("Not parsing schema line: %s", line.strip())
 continue
 s = parse_schema(line.strip())
 slist.append(s)
@@ -1,6 +1,5 @@
 # Owner(s): ["module: inductor"]
 # ruff: noqa: F841
-# flake8: noqa
 import collections
 import collections.abc
 import copy

@@ -296,7 +295,7 @@ class TestJointOps(TestCase):
 self.s.z = ["z"]
 p = pickle.dumps(self.s, i)
 dup = pickle.loads(p)
-self.assertEqual(self.s, dup, "%s != %s" % (self.s, dup))
+self.assertEqual(self.s, dup, "%s != %s" % (self.s, dup)) # noqa: UP031
 if type(self.s) not in (OrderedSet, frozenset):
 self.assertEqual(self.s.x, dup.x)
 self.assertEqual(self.s.z, dup.z)

@@ -390,7 +389,7 @@ class TestJointOps(TestCase):
 self.assertEqual(repr(s), "{OrderedSet(...)}")
 else:
 name = repr(s).partition("(")[0] # strip class name
-self.assertEqual(repr(s), "%s({%s(...)})" % (name, name))
+self.assertEqual(repr(s), "%s({%s(...)})" % (name, name)) # noqa: UP031

 @unittest.skip("Different hashing")
 def test_do_not_rehash_dict_keys(self):

@@ -454,7 +453,7 @@ class TestSet(TestJointOps, TestCase):

 def test_set_literal_insertion_order(self):
 # SF Issue #26020 -- Expect left to right insertion
-s = {1, 1.0, True}
+s = {1, 1.0, True} # noqa: B033
 self.assertEqual(len(s), 1)
 stored_value = s.pop()
 self.assertEqual(type(stored_value), int)

@@ -715,19 +714,19 @@ class TestSet(TestJointOps, TestCase):
 myset = {1, 2, 3}

 myobj = TestRichSetCompare()
-myset < myobj
+myset < myobj # noqa: B015
 self.assertTrue(myobj.gt_called)

 myobj = TestRichSetCompare()
-myset > myobj
+myset > myobj # noqa: B015
 self.assertTrue(myobj.lt_called)

 myobj = TestRichSetCompare()
-myset <= myobj
+myset <= myobj # noqa: B015
 self.assertTrue(myobj.ge_called)

 myobj = TestRichSetCompare()
-myset >= myobj
+myset >= myobj # noqa: B015
 self.assertTrue(myobj.le_called)

@@ -834,7 +833,9 @@ class TestBasicOps(TestCase):
 p = pickle.dumps(self.OrderedSet, proto)
 copy = pickle.loads(p)
 self.assertEqual(
-self.OrderedSet, copy, "%s != %s" % (self.OrderedSet, copy)
+self.OrderedSet,
+copy,
+"%s != %s" % (self.OrderedSet, copy), # noqa: UP031
 )

 def test_issue_37219(self):
@@ -1195,7 +1196,7 @@ class TestMutate(TestCase):
 expected_len = 0
 for v in self.values:
 tmp.add(v)
-expected_len += 1
+expected_len += 1 # noqa: SIM113
 self.assertEqual(len(tmp), expected_len)
 self.assertEqual(tmp, self.OrderedSet)

@@ -1518,7 +1519,7 @@ class TestOnlySetsString(TestOnlySetsInBinaryOps, TestCase):
 class TestOnlySetsGenerator(TestOnlySetsInBinaryOps, TestCase):
 def setUp(self):
 def gen():
-for i in range(0, 10, 2):
+for i in range(0, 10, 2): # noqa: UP028
 yield i

 self.OrderedSet = OrderedSet((1, 2, 3))

@@ -1541,7 +1542,7 @@ class TestCopying:

 def test_deep_copy(self):
 dup = copy.deepcopy(self.OrderedSet)
-##print type(dup), repr(dup)
+# print type(dup), repr(dup)
 dup_list = sorted(dup, key=repr)
 set_list = sorted(self.OrderedSet, key=repr)
 self.assertEqual(len(dup_list), len(set_list))

@@ -1641,7 +1642,7 @@ class TestIdentities(TestCase):

 def R(seqn):
 "Regular generator"
-for i in seqn:
+for i in seqn: # noqa: UP028
 yield i

@@ -1655,7 +1656,7 @@ class G:
 return self.seqn[i]


-class I:
+class I: # noqa: E742
 "Sequence using iterator protocol"

 def __init__(self, seqn):

@@ -1681,7 +1682,7 @@ class Ig:
 self.i = 0

 def __iter__(self):
-for val in self.seqn:
+for val in self.seqn: # noqa: UP028
 yield val

@@ -1743,7 +1744,7 @@ from itertools import chain

 def L(seqn):
 "Test multiple tiers of iterators"
-return chain(map(lambda x: x, R(Ig(G(seqn)))))
+return chain(map(lambda x: x, R(Ig(G(seqn))))) # noqa: C417


 class TestVariousIteratorArgs(TestCase):

@@ -1909,7 +1910,7 @@ def powerset(U):
 def cube(n):
 """Graph of n-dimensional hypercube."""
 singletons = [frozenset([x]) for x in range(n)]
-return dict(
+return dict( # noqa: C404
 [(x, frozenset([x ^ s for s in singletons])) for x in powerset(range(n))]
 )

@@ -1946,7 +1947,7 @@ def faces(G):
 f.add(frozenset([v1, v2, v3, v4]))
 else:
 for v5 in G[v4]:
-if v5 == v3 or v5 == v2:
+if v5 == v3 or v5 == v2: # noqa: SIM109
 continue
 if v1 in G[v5]:
 f.add(frozenset([v1, v2, v3, v4, v5]))
@@ -1,8 +1,5 @@
 # Owner(s): ["oncall: jit"]
-# flake8: noqa

 import sys
 import unittest
 from dataclasses import dataclass, field, InitVar
 from enum import Enum
 from typing import List, Optional
@@ -1,10 +1,5 @@
 # Owner(s): ["oncall: jit"]
-# flake8: noqa

 import sys
 import unittest
 from enum import Enum
 from typing import List, Optional

 import torch
 from jit.myfunction_a import my_function_a
@@ -1,4 +1,5 @@
 # Owner(s): ["module: onnx"]
+# flake8: noqa: B950
 """Test op correctness by comparing with PyTorch results.

 ## Usage

@@ -32,14 +33,13 @@ wrangler function. See `_mean_input_wrangler` for an example.
 op, use `ops_test_common.duplicate_opinfo` to create new OpInfo with new names and map each
 to one overload.
 """
-# flake8: noqa

 from __future__ import annotations

 import copy
 import dataclasses
 import functools
-from typing import Any, Callable, Collection, Optional
+from typing import Any, Callable, Optional, TYPE_CHECKING
 from typing_extensions import Self

 import numpy as np

@@ -51,6 +51,10 @@ from torch.testing._internal import common_methods_invocations
 from torch.testing._internal.opinfo import definitions as opinfo_definitions


+if TYPE_CHECKING:
+from collections.abc import Collection
+
+
 # Create a copy of the op_db to modify
 OPS_DB = copy.deepcopy(common_methods_invocations.op_db)
@@ -32,7 +32,7 @@ class TestGraphUtils(TestCase):
 example_inputs = (torch.randn(1, 3, 5, 5),)

 # program capture
-m, guards = torchdynamo.export( # noqa: F841©
+m, guards = torchdynamo.export( # noqa: F841
 m,
 *copy.deepcopy(example_inputs),
 aten_graph=True,
@@ -6,9 +6,17 @@ import torch
 import torch.distributed as c10d


-logging.basicConfig(
-format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
-)
+FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+
+log = logging.getLogger("log")
+log.setLevel(logging.INFO)
+
+handler = logging.StreamHandler()
+formatter = logging.Formatter(FORMAT)
+handler.setFormatter(formatter)
+
+log.addHandler(handler)
+log.propagate = False # Prevent log duplication

 if __name__ == "__main__":
 parser = argparse.ArgumentParser(

@@ -29,14 +37,14 @@ if __name__ == "__main__":

 store = c10d.TCPStore(args.addr, port, world_size, rank == 0)
 process_group = c10d.ProcessGroupNCCL(store, rank, world_size)
-logging.info("Running first allreduce")
+log.info("Running first allreduce")
 process_group.allreduce(torch.rand(10).cuda(rank)).wait()
 if rank == 0:
-logging.info("Running second allreduce only on rank 0")
+log.info("Running second allreduce only on rank 0")
 work = process_group.allreduce(torch.rand(10).cuda(rank))
-logging.info("Waiting for allreduce to complete...")
+log.info("Waiting for allreduce to complete...")
 work.wait()
-logging.info("Second allreduce successful: %s", work.is_success())
+log.info("Second allreduce successful: %s", work.is_success())
 else:
-logging.info("Aborting all other ranks.")
+log.info("Aborting all other ranks.")
 os.abort()
@@ -181,7 +181,7 @@ class Pair(NamedTuple):


 # for testing pytrees
-class Foo: # noqa: B209
+class Foo:
 def __init__(self, a, b):
 self.a = a
 self.b = b
@@ -38,12 +38,13 @@ from quantization.core.test_workflow_module import TestDistributed # noqa: F401
 from quantization.core.test_workflow_module import TestFusedObsFakeQuantModule # noqa: F401
 from quantization.core.test_backend_config import TestBackendConfig # noqa: F401
 from quantization.core.test_utils import TestUtils # noqa: F401
+log = logging.getLogger(__name__)
 try:
 # This test has extra data dependencies, so in some environments, e.g. Meta internal
 # Buck, it has its own test runner.
 from quantization.core.test_docs import TestQuantizationDocs # noqa: F401
 except ImportError as e:
-logging.warning(e)
+log.warning(e)

 # Eager Mode Workflow. Tests for the functionality of APIs and different features implemented
 # using eager mode.

@@ -77,7 +78,7 @@ try:
 except ImportError as e:
 # In FBCode we separate FX out into a separate target for the sake of dev
 # velocity. These are covered by a separate test target `quantization_fx`
-logging.warning(e)
+log.warning(e)

 # PyTorch 2 Export Quantization
 try:

@@ -99,7 +100,7 @@ try:
 except ImportError as e:
 # In FBCode we separate PT2 out into a separate target for the sake of dev
 # velocity. These are covered by a separate test target `quantization_pt2e`
-logging.warning(e)
+log.warning(e)

 try:
 from quantization.fx.test_numeric_suite_fx import TestFXGraphMatcher # noqa: F401

@@ -108,7 +109,7 @@ try:
 from quantization.fx.test_numeric_suite_fx import TestFXNumericSuiteNShadows # noqa: F401
 from quantization.fx.test_numeric_suite_fx import TestFXNumericSuiteCoreAPIsModels # noqa: F401
 except ImportError as e:
-logging.warning(e)
+log.warning(e)

 # Test the model report module
 try:

@@ -120,19 +121,19 @@ try:
 from quantization.fx.test_model_report_fx import TestFxDetectOutliers # noqa: F401
 from quantization.fx.test_model_report_fx import TestFxModelReportVisualizer # noqa: F401
 except ImportError as e:
-logging.warning(e)
+log.warning(e)

 # Equalization for FX mode
 try:
 from quantization.fx.test_equalize_fx import TestEqualizeFx # noqa: F401
 except ImportError as e:
-logging.warning(e)
+log.warning(e)

 # Backward Compatibility. Tests serialization and BC for quantized modules.
 try:
 from quantization.bc.test_backward_compatibility import TestSerialization # noqa: F401
 except ImportError as e:
-logging.warning(e)
+log.warning(e)

 # JIT Graph Mode Quantization
 from quantization.jit.test_quantize_jit import TestQuantizeJit # noqa: F401

@@ -151,29 +152,29 @@ from quantization.ao_migration.test_ao_migration import TestAOMigrationNNIntrins
 try:
 from quantization.ao_migration.test_quantization_fx import TestAOMigrationQuantizationFx # noqa: F401
 except ImportError as e:
-logging.warning(e)
+log.warning(e)

 # Experimental functionality
 try:
 from quantization.core.experimental.test_bits import TestBitsCPU # noqa: F401
 except ImportError as e:
-logging.warning(e)
+log.warning(e)
 try:
 from quantization.core.experimental.test_bits import TestBitsCUDA # noqa: F401
 except ImportError as e:
-logging.warning(e)
+log.warning(e)
 try:
 from quantization.core.experimental.test_floatx import TestFloat8DtypeCPU # noqa: F401
 except ImportError as e:
-logging.warning(e)
+log.warning(e)
 try:
 from quantization.core.experimental.test_floatx import TestFloat8DtypeCUDA # noqa: F401
 except ImportError as e:
-logging.warning(e)
+log.warning(e)
 try:
 from quantization.core.experimental.test_floatx import TestFloat8DtypeCPUOnlyCPU # noqa: F401
 except ImportError as e:
-logging.warning(e)
+log.warning(e)

 if __name__ == '__main__':
 run_tests()
@@ -595,7 +595,7 @@ def logging_manager(*, debug: bool = False) -> Generator[logging.Logger, None, N
 print(f"log file: {log_file}")
 yield root_logger
 except Exception as e:
-logging.exception("Fatal exception")
+logging.exception("Fatal exception") # noqa: LOG015
 logging_record_exception(e)
 print(f"log file: {log_file}")
 sys.exit(1)

@@ -603,7 +603,7 @@ def logging_manager(*, debug: bool = False) -> Generator[logging.Logger, None, N
 # You could logging.debug here to suppress the backtrace
 # entirely, but there is no reason to hide it from technically
 # savvy users.
-logging.info("", exc_info=True)
+logging.info("", exc_info=True) # noqa: LOG015
 logging_record_exception(e)
 print(f"log file: {log_file}")
 sys.exit(1)
@@ -112,7 +112,7 @@ except ModuleNotFoundError:
 try:
 import torch._logging
 import torch._numpy as tnp
-from torch._guards import detect_fake_mode # noqa: F401n
+from torch._guards import detect_fake_mode # noqa: F401
 from torch._logging import LazyString

 from . import config
@@ -939,7 +939,7 @@ class HalideKernel(SIMDKernel):

 # group the expression by variables used
 offset = sympy.S.Zero
-split_expr = {s: sympy.S.Zero for s in symbols}
+split_expr = dict.fromkeys(symbols, sympy.S.Zero)
 split_failed: list[tuple[list[sympy.Symbol], sympy.Expr]] = []
 index = sympy.expand(self.rename_indexing(index))
 for part in index.args if isinstance(index, sympy.Add) else [index]:
@@ -1,4 +1,4 @@
-import os # noqa: C101
+import os
 import sys
 from typing import Any, Callable, Literal, Optional, TYPE_CHECKING, Union
@@ -482,7 +482,7 @@ class StorageWeakRefWrapper:

 @classmethod
 def from_weakref_and_data_ptr(
-cls: type[S],
+cls: type[StorageWeakRefWrapper],
 cdata: Any,
 data_ptr: int,
 extra_ref_check: Optional[Callable[[], bool]] = None,
@@ -149,7 +149,7 @@ def grouped_gemm_lowering(
 has_bias=[bias is not None for bias in b],
 trans_w=True,
 epilogue_creator=None,
-act_mapping={num: x for num in range(num_gemm)},
+act_mapping=dict.fromkeys(range(num_gemm), x),
 )

 input_nodes = [x, *w]
@@ -3029,7 +3029,7 @@ class Scheduler:
 if fusion_log.isEnabledFor(logging.DEBUG):
 fusion_log.debug("fuse_nodes_once, candidates:")
 for node in fused_nodes:
-fusion_log.debug(" " + node.debug_str_short()) # noqa: G003
+fusion_log.debug(" %s", node.debug_str_short())

 # These are potential fusions which we are async compiling,
 # and which we will benchmark profitability of.
@@ -69,7 +69,7 @@ try:
 not_implemented_log = torch._logging.getArtifactLogger(__name__, "not_implemented")
 except ValueError as e:
 if "'not_implemented' not registered" in str(e):
-import logging as not_implemented_log
+not_implemented_log = logging.getLogger(__name__ + ".not_implemented")
 else:
 raise e
@@ -228,7 +228,7 @@ def convert_pt2e(
 # for detailed explanation of output quantized model
 quantized_model = convert_pt2e(prepared_model)

-""" # flake8: noqa
+"""
 torch._C._log_api_usage_once("quantization_api.quantize_pt2e.convert_pt2e")
 if not isinstance(use_reference_representation, bool):
 raise ValueError(
@@ -358,13 +358,11 @@ class SACEstimator(TorchDispatchMode):
 output_ids = tuple(hash(st) for st in out_storages)
 # 4. If the function is not inplace, return
 if not is_inplace(func):
-return curr_idx, output_ids, {mod_fqn: () for mod_fqn in active_mod_fqns}
+return curr_idx, output_ids, dict.fromkeys(active_mod_fqns, ())

 op_idx = curr_idx
 # 5. Initialize the parent op ids of the inplace op for each of the active modules
-mod_op_parent_idxs: dict[str, int] = {
-mod_fqn: -1 for mod_fqn in active_mod_fqns
-}
+mod_op_parent_idxs: dict[str, int] = dict.fromkeys(active_mod_fqns, -1)
 for i, d in enumerate(self._sac_metadata):
 # 6. Find the first occurence of a tensor corresponding to each module that
 # shares the same storage as the current tensor
@@ -1,5 +1,4 @@
 # mypy: allow-untyped-defs
-# flake8: noqa C101
 import itertools
 from collections.abc import Iterable, Iterator
 from typing import Union
@@ -6,6 +6,7 @@ import sys
 import types
 from collections.abc import Iterator, Mapping
 from typing import Any, Callable, Optional, TypeVar, Union
+from typing_extensions import Self

 import torch
 import torch.distributed.rpc as rpc

@@ -319,34 +320,34 @@ class _RemoteModule(nn.Module):
 def add_module(self, name: str, module: Optional[Module]) -> None:
 _raise_not_supported(self.add_module.__name__)

-def apply(self: T, fn: Callable[[Module], None]) -> T: # type: ignore[return]
+def apply(self, fn: Callable[[Module], None]) -> Self: # type: ignore[return]
 _raise_not_supported(self.apply.__name__)

-def cuda(self: T, device: Optional[Union[int, device]] = None) -> T: # type: ignore[return]
+def cuda(self, device: Optional[Union[int, device]] = None) -> Self: # type: ignore[return]
 _raise_not_supported(self.cuda.__name__)

-def ipu(self: T, device: Optional[Union[int, device]] = None) -> T: # type: ignore[return]
+def ipu(self, device: Optional[Union[int, device]] = None) -> Self: # type: ignore[return]
 _raise_not_supported(self.ipu.__name__)

-def xpu(self: T, device: Optional[Union[int, device]] = None) -> T: # type: ignore[return]
+def xpu(self, device: Optional[Union[int, device]] = None) -> Self: # type: ignore[return]
 _raise_not_supported(self.xpu.__name__)

-def cpu(self: T) -> T: # type: ignore[return]
+def cpu(self) -> Self: # type: ignore[return]
 _raise_not_supported(self.cpu.__name__)

-def type(self: T, dst_type: Union[dtype, str]) -> T: # type: ignore[return]
+def type(self, dst_type: Union[dtype, str]) -> Self: # type: ignore[return]
 _raise_not_supported(self.type.__name__)

-def float(self: T) -> T: # type: ignore[return]
+def float(self) -> Self: # type: ignore[return]
 _raise_not_supported(self.float.__name__)

-def double(self: T) -> T: # type: ignore[return]
+def double(self) -> Self: # type: ignore[return]
 _raise_not_supported(self.double.__name__)

-def half(self: T) -> T: # type: ignore[return]
+def half(self) -> Self: # type: ignore[return]
 _raise_not_supported(self.half.__name__)

-def bfloat16(self: T) -> T: # type: ignore[return]
+def bfloat16(self) -> Self: # type: ignore[return]
 _raise_not_supported(self.bfloat16.__name__)

 def to(self, *args, **kwargs) -> T: # type: ignore[misc, return, type-var]

@@ -428,19 +429,19 @@ class _RemoteModule(nn.Module):
 ):
 _raise_not_supported(self.named_modules.__name__)

-def train(self: T, mode: bool = True) -> T:
+def train(self, mode: bool = True) -> Self:
 return self.module_rref.rpc_sync().train() # type: ignore[operator, union-attr]

-def eval(self: T) -> T:
+def eval(self) -> Self:
 return self.module_rref.rpc_sync().eval() # type: ignore[operator, union-attr]

-def requires_grad_(self: T, requires_grad: bool = True) -> T: # type: ignore[return]
+def requires_grad_(self, requires_grad: bool = True) -> Self: # type: ignore[return]
 _raise_not_supported(self.requires_grad_.__name__)

 def zero_grad(self, set_to_none: bool = True) -> None:
 _raise_not_supported(self.zero_grad.__name__)

-def share_memory(self: T) -> T: # type: ignore[return]
+def share_memory(self) -> Self: # type: ignore[return]
 _raise_not_supported(self.share_memory.__name__)

 def extra_repr(self) -> str: # type: ignore[return]
@@ -1,5 +1,4 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates
-# flake8: noqa

 from .binary import _apply_native_binary, _is_native_binary
 from .core import is_masked_tensor, MaskedTensor
@@ -1,5 +1,5 @@
 # mypy: allow-untyped-defs
-# flake8: noqa C101
+# flake8: noqa: B950
 """This module implements the user facing API for flex_attention in PyTorch."""
 import functools
 import inspect
@@ -1004,7 +1004,7 @@ class Module:

 return self

-def apply(self: T, fn: Callable[["Module"], None]) -> T:
+def apply(self, fn: Callable[["Module"], None]) -> Self:
 r"""Apply ``fn`` recursively to every submodule (as returned by ``.children()``) as well as self.

 Typical use includes initializing the parameters of a model

@@ -1045,7 +1045,7 @@ class Module:
 fn(self)
 return self

-def cuda(self: T, device: Optional[Union[int, device]] = None) -> T:
+def cuda(self, device: Optional[Union[int, device]] = None) -> Self:
 r"""Move all model parameters and buffers to the GPU.

 This also makes associated parameters and buffers different objects. So

@@ -1064,7 +1064,7 @@ class Module:
 """
 return self._apply(lambda t: t.cuda(device))

-def ipu(self: T, device: Optional[Union[int, device]] = None) -> T:
+def ipu(self, device: Optional[Union[int, device]] = None) -> Self:
 r"""Move all model parameters and buffers to the IPU.

 This also makes associated parameters and buffers different objects. So

@@ -1083,7 +1083,7 @@ class Module:
 """
 return self._apply(lambda t: t.ipu(device))

-def xpu(self: T, device: Optional[Union[int, device]] = None) -> T:
+def xpu(self, device: Optional[Union[int, device]] = None) -> Self:
 r"""Move all model parameters and buffers to the XPU.

 This also makes associated parameters and buffers different objects. So

@@ -1102,7 +1102,7 @@ class Module:
 """
 return self._apply(lambda t: t.xpu(device))

-def mtia(self: T, device: Optional[Union[int, device]] = None) -> T:
+def mtia(self, device: Optional[Union[int, device]] = None) -> Self:
 r"""Move all model parameters and buffers to the MTIA.

 This also makes associated parameters and buffers different objects. So

@@ -1121,7 +1121,7 @@ class Module:
 """
 return self._apply(lambda t: t.mtia(device))

-def cpu(self: T) -> T:
+def cpu(self) -> Self:
 r"""Move all model parameters and buffers to the CPU.

 .. note::

@@ -1132,7 +1132,7 @@ class Module:
 """
 return self._apply(lambda t: t.cpu())

-def type(self: T, dst_type: Union[dtype, str]) -> T:
+def type(self, dst_type: Union[dtype, str]) -> Self:
 r"""Casts all parameters and buffers to :attr:`dst_type`.

 .. note::

@@ -1146,7 +1146,7 @@ class Module:
 """
 return self._apply(lambda t: t.type(dst_type))

-def float(self: T) -> T:
+def float(self) -> Self:
 r"""Casts all floating point parameters and buffers to ``float`` datatype.

 .. note::

@@ -1157,7 +1157,7 @@ class Module:
 """
 return self._apply(lambda t: t.float() if t.is_floating_point() else t)

-def double(self: T) -> T:
+def double(self) -> Self:
 r"""Casts all floating point parameters and buffers to ``double`` datatype.

 .. note::

@@ -1168,7 +1168,7 @@ class Module:
 """
 return self._apply(lambda t: t.double() if t.is_floating_point() else t)

-def half(self: T) -> T:
+def half(self) -> Self:
 r"""Casts all floating point parameters and buffers to ``half`` datatype.

 .. note::

@@ -1179,7 +1179,7 @@ class Module:
 """
 return self._apply(lambda t: t.half() if t.is_floating_point() else t)

-def bfloat16(self: T) -> T:
+def bfloat16(self) -> Self:
 r"""Casts all floating point parameters and buffers to ``bfloat16`` datatype.

 .. note::

@@ -1191,8 +1191,8 @@ class Module:
 return self._apply(lambda t: t.bfloat16() if t.is_floating_point() else t)

 def to_empty(
-self: T, *, device: Optional[DeviceLikeType], recurse: bool = True
-) -> T:
+self, *, device: Optional[DeviceLikeType], recurse: bool = True
+) -> Self:
 r"""Move the parameters and buffers to the specified device without copying storage.

 Args:

@@ -2837,7 +2837,7 @@ class Module:
 memo, submodule_prefix, remove_duplicate
 )

-def train(self: T, mode: bool = True) -> T:
+def train(self, mode: bool = True) -> Self:
 r"""Set the module in training mode.

 This has an effect only on certain modules. See the documentation of

@@ -2859,7 +2859,7 @@ class Module:
 module.train(mode)
 return self

-def eval(self: T) -> T:
+def eval(self) -> Self:
 r"""Set the module in evaluation mode.

 This has an effect only on certain modules. See the documentation of

@@ -2877,7 +2877,7 @@ class Module:
 """
 return self.train(False)

-def requires_grad_(self: T, requires_grad: bool = True) -> T:
+def requires_grad_(self, requires_grad: bool = True) -> Self:
 r"""Change if autograd should record operations on parameters in this module.

 This method sets the parameters' :attr:`requires_grad` attributes

@@ -2928,7 +2928,7 @@ class Module:
 p.grad.requires_grad_(False)
 p.grad.zero_()

-def share_memory(self: T) -> T:
+def share_memory(self) -> Self:
 r"""See :meth:`torch.Tensor.share_memory_`."""
 return self._apply(lambda t: t.share_memory_())
@@ -70,7 +70,7 @@ def from_dynamic_axes_to_dynamic_shapes(
 raise ValueError(
 "The axis in dynamic_axes must be in the form of: dict[int, str] or list[int]."
 )
-dynamic_shapes[input_name] = {k: torch.export.Dim.DYNAMIC for k in axes}
+dynamic_shapes[input_name] = dict.fromkeys(axes, torch.export.Dim.DYNAMIC)
 elif axes is None:
 dynamic_shapes[input_name] = None
 else:
@@ -1,7 +1,6 @@
 """torch.ops.aten operators under the `core` module."""
 # mypy: disable-error-code="misc,arg-type,type-arg,valid-type,assignment,return-value,type-var,operator,no-untyped-def,index"
 # ruff: noqa: TCH001,TCH002
-# flake8: noqa

 from __future__ import annotations
@@ -1,7 +1,7 @@
 """torch.ops.aten operators under the `core` module."""
 # mypy: disable-error-code="misc,arg-type,type-arg,valid-type,assignment,return-value,type-var,operator,no-untyped-def,index"
 # ruff: noqa: TCH001,TCH002
-# flake8: noqa
+# flake8: noqa: B950

 from __future__ import annotations
@@ -2,7 +2,6 @@

 # mypy: disable-error-code="misc,arg-type,type-arg,valid-type,assignment,return-value,type-var,operator,no-untyped-def,index"
 # ruff: noqa: TCH001,TCH002
-# flake8: noqa

 from __future__ import annotations

@@ -12,7 +11,6 @@ import torch
 from torch.onnx._internal.exporter._torchlib._tensor_typing import (
 BOOL,
 FLOAT,
 INT64,
 IntType,
 TensorType,
 )