mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
This is a lot of files changed! Don't panic! Here's how it works: * Previously, we set `follow_imports = silent` for our mypy.ini configuration. Per https://mypy.readthedocs.io/en/stable/running_mypy.html#follow-imports, what this does is whenever we have an import to a module which is not listed as a file to be typechecked in mypy, we typecheck it as normal but suppress all errors that occurred in that file. * When mypy is run inside lintrunner, the list of files is precisely the files covered by the glob in lintrunner.toml, but with files in excludes excluded. * The top-level directive `# mypy: ignore-errors` instructs mypy to typecheck the file as normal, but ignore all errors. * Therefore, it should be equivalent to set `follow_imports = normal`, if we put `# mypy: ignore-errors` on all files that were previously excluded from the file list. * Having done this, we can remove the exclude list from .lintrunner.toml, since excluding a file from typechecking is baked into the files themselves. * torch/_dynamo and torch/_inductor were previously in the exclude list, because they were covered by MYPYINDUCTOR. It is not OK to mark these as `# mypy: ignore-errors` as this will impede typechecking on the alternate configuration. So they are temporarily being checked twice, but I am suppressing the errors in these files as the configurations are not quite the same. I plan to unify the configurations so this is only a temporary state. * There were some straggler type errors after these changes somehow, so I fixed them as needed. There weren't that many. In the future, to start type checking a file, just remove the ignore-errors directive from the top of the file. The codemod was done with this script authored by GPT-4: ``` import glob exclude_patterns = [ ... 
] for pattern in exclude_patterns: for filepath in glob.glob(pattern, recursive=True): if filepath.endswith('.py'): with open(filepath, 'r+') as f: content = f.read() f.seek(0, 0) f.write('# mypy: ignore-errors\n\n' + content) ``` Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/118414 Approved by: https://github.com/thiagocrepaldi, https://github.com/albanD
79 lines
2.3 KiB
Python
79 lines
2.3 KiB
Python
# mypy: ignore-errors
|
|
|
|
import contextlib
|
|
import functools
|
|
import inspect
|
|
|
|
import torch
|
|
|
|
|
|
# Test whether hardware BF32 math mode enabled. It is enabled only on:
|
|
# - MKLDNN is available
|
|
# - BF16 is supported by MKLDNN
|
|
def bf32_is_not_fp32():
|
|
if not torch.backends.mkldnn.is_available():
|
|
return False
|
|
if not torch.ops.mkldnn._is_mkldnn_bf16_supported():
|
|
return False
|
|
return True
|
|
|
|
|
|
@contextlib.contextmanager
def bf32_off():
    """Temporarily force float32 matmuls to full ("highest") precision.

    The global setting returned by ``torch.get_float32_matmul_precision()``
    on entry is restored on exit, including when the body raises.
    """
    saved_precision = torch.get_float32_matmul_precision()
    try:
        torch.set_float32_matmul_precision("highest")
        yield
    finally:
        # Always hand the caller back the setting it had before entry.
        torch.set_float32_matmul_precision(saved_precision)
|
|
|
|
|
|
@contextlib.contextmanager
def bf32_on(self, bf32_precision=1e-5):
    """Temporarily enable reduced-precision float32 matmuls for a test.

    While active, the float32 matmul precision is set to "medium" (which,
    per the PyTorch docs, permits reduced-precision internal math). In
    addition, ``self.precision`` is overridden with ``bf32_precision``;
    ``self.precision`` is presumably the test case's comparison tolerance,
    so confirm against the caller.

    Args:
        self: Object whose ``precision`` attribute is temporarily replaced.
        bf32_precision: Value assigned to ``self.precision`` inside the block.

    Both the global matmul precision and ``self.precision`` are restored
    on exit, including when the body raises.
    """
    saved_matmul = torch.get_float32_matmul_precision()
    saved_tolerance = self.precision
    try:
        torch.set_float32_matmul_precision("medium")
        self.precision = bf32_precision
        yield
    finally:
        torch.set_float32_matmul_precision(saved_matmul)
        self.precision = saved_tolerance
|
|
|
|
|
|
def bf32_on_and_off(bf32_precision=1e-5):
    """Build a decorator that may run a test twice: BF32 off, then BF32 on.

    The decorated test runs twice only when all of these hold:
      * ``bf32_is_not_fp32()`` is true,
      * any ``device`` argument resolves to a CPU device, and
      * any ``dtype`` argument equals ``torch.float``.

    In that case the first run uses ``bf32_off()``. The second uses
    ``bf32_on(self, bf32_precision)``, so it runs with the reduced precision
    given here. In every other case the test runs exactly once, unchanged.

    The decorated test must have a parameter named ``self``, because it is
    looked up by name when the test runs twice. The wrapper always returns
    ``None``.
    """

    def with_bf32_disabled(self, function_call):
        # ``self`` is unused; it keeps both runners' signatures parallel.
        with bf32_off():
            function_call()

    def with_bf32_enabled(self, function_call):
        with bf32_on(self, bf32_precision):
            function_call()

    def wrapper(f):
        # Parameter names in declaration order, used to name positional args.
        arg_names = tuple(inspect.signature(f).parameters)

        @functools.wraps(f)
        def wrapped(*args, **kwargs):
            # Fold positional args into kwargs by name so the device/dtype
            # checks below can see them. A positional value overrides a
            # keyword of the same name.
            kwargs.update(zip(arg_names, args))

            run_both = bf32_is_not_fp32()
            # Each check is only evaluated while run_both is still truthy.
            if run_both and "device" in kwargs:
                run_both = torch.device(kwargs["device"]).type == "cpu"
            if run_both and "dtype" in kwargs:
                run_both = kwargs["dtype"] == torch.float

            if not run_both:
                f(**kwargs)
                return

            def run_test():
                return f(**kwargs)

            with_bf32_disabled(kwargs["self"], run_test)
            with_bf32_enabled(kwargs["self"], run_test)

        return wrapped

    return wrapper
|