Use strict to toggle strict options in MYPYSTRICT (#118479)
As we force a specific version of mypy, it's OK to use the agglomerated flag.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118479
Approved by: https://github.com/Skylion007, https://github.com/albanD
ghstack dependencies: #118414, #118418, #118432, #118467, #118468, #118469, #118475
Committed by: PyTorch MergeBot
Parent: ecca533872
Commit: 119b66ba16
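The rationale above hinges on the version pin: --strict expands to a different flag set from release to release, so it is only safe to rely on once mypy is fixed to one version. A minimal sketch (not part of this commit; it assumes a mypy executable on PATH and that the help text describes --strict in the usual "Strict mode; enables the following flags: ..." wording) that prints what the pinned version's --strict actually enables:

import subprocess

# Read the --strict entry out of the pinned mypy's help text. The
# trailing space in the needle avoids matching --strict-equality etc.
help_text = subprocess.run(
    ["mypy", "--help"], capture_output=True, text=True, check=True
).stdout
start = help_text.find("--strict ")
print(help_text[start : start + 600] if start != -1 else "--strict entry not found")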
@@ -79,7 +79,7 @@ def _generate_torchscript_file(model_src: str, name: str) -> Optional[str]:
     assert isinstance(
         jit_model, (torch.jit.ScriptFunction, torch.jit.ScriptModule)
     ), f"Expected ScriptFunction or ScriptModule, got: {type(jit_model)}"
-    jit_model.save(artifact_path)
+    jit_model.save(artifact_path)  # type: ignore[call-arg]
 
     # Cleanup now that we have the actual serialized model.
     os.remove(module_path)
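Every remaining hunk in this commit applies the same pattern as the one above: a type: ignore scoped to a specific error code rather than a bare "# type: ignore". A hedged illustration of why the scoped form matters (the add function is made up, not from the PyTorch sources):

from typing import TYPE_CHECKING


def add(a: int, b: int) -> int:
    return a + b


if TYPE_CHECKING:
    # One line, two diagnostics: the scoped comment suppresses only
    # [call-arg] (too many arguments), while the [arg-type] error for
    # "one" is still reported; a bare "# type: ignore" would hide both.
    add("one", 2, 3)  # type: ignore[call-arg]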
@@ -17,21 +17,8 @@ show_column_numbers = True
 warn_no_return = True
 disallow_any_unimported = True
-
-# Across versions of mypy, the flags toggled by --strict vary. To ensure
-# we have reproducible type check, we instead manually specify the flags
-warn_unused_configs = True
-disallow_any_generics = True
-disallow_subclassing_any = True
-disallow_untyped_calls = True
-disallow_untyped_defs = True
-disallow_incomplete_defs = True
-check_untyped_defs = True
-disallow_untyped_decorators = True
-no_implicit_optional = True
-warn_redundant_casts = True
-warn_return_any = True
+strict = True
 implicit_reexport = False
 strict_equality = True
 
 # do not reenable this:
 # https://github.com/pytorch/pytorch/pull/60006#issuecomment-866130657
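With the config reduced to strict = True, reproducibility now comes from the version pin rather than from spelling out each flag. A sketch of running a file against the strict config programmatically, via mypy's documented Python API (the config-file and target paths here are illustrative, not taken from the diff):

from mypy import api

# api.run mirrors the CLI and returns (stdout, stderr, exit_status).
stdout, stderr, exit_status = api.run(
    ["--config-file", "mypy-strict.ini", "some_module.py"]
)
print(stdout, end="")
print(f"exit status: {exit_status}")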
@@ -41,7 +41,7 @@ def main() -> None:
 
     for func_name in sorted(torch.masked._ops.__all__):
         func = getattr(torch.masked._ops, func_name)
-        func_doc = torch.masked._generate_docstring(func)
+        func_doc = torch.masked._generate_docstring(func)  # type: ignore[no-untyped-call, attr-defined]
         _new_content.append(f'{func_name}_docstring = """{func_doc}"""\n')
 
     new_content = "\n".join(_new_content)
@@ -29,7 +29,7 @@ from typing import (
 
 import torch
 
-if torch._running_with_deploy():
+if torch._running_with_deploy():  # type: ignore[no-untyped-call]
     raise ImportError("C++ pytree utilities do not work with torch::deploy.")
 
 import optree
@@ -13,7 +13,7 @@ from torch.utils.benchmark.utils.valgrind_wrapper import timer_interface as valg
 __all__ = ["Timer", "timer", "Language"]
 
 
-if torch.backends.cuda.is_built() and torch.cuda.is_available():
+if torch.backends.cuda.is_built() and torch.cuda.is_available():  # type: ignore[no-untyped-call]
     def timer() -> float:
         torch.cuda.synchronize()
         return timeit.default_timer()
@@ -465,7 +465,7 @@ class GlobalsBridge:
             path = os.path.join(self._data_dir, f"{name}.pt")
             load_lines.append(f"{name} = torch.jit.load({repr(path)})")
             with open(path, "wb") as f:
-                torch.jit.save(wrapped_value.value, f)
+                torch.jit.save(wrapped_value.value, f)  # type: ignore[no-untyped-call]
 
         else:
             raise NotImplementedError(
@@ -502,7 +502,7 @@ class _ValgrindWrapper:
         ).returncode
 
         self._build_type: Optional[str] = None
-        build_search = re.search("BUILD_TYPE=(.+),", torch.__config__.show())
+        build_search = re.search("BUILD_TYPE=(.+),", torch.__config__.show())  # type: ignore[no-untyped-call]
         if build_search is not None:
             self._build_type = build_search.groups()[0].split(",")[0]
 
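The [no-untyped-call] suppressions above and below are the mechanical alternative to annotating the callees; the commit opts for suppression so the switch to strict = True stays self-contained. A hedged sketch of the other route, using a stand-in for torch.__config__.show:

import re


def show() -> str:
    # Annotated stand-in for torch.__config__.show; once the callee has
    # a return annotation, the [no-untyped-call] error (and with it the
    # ignore comment) disappears at the call site.
    return "BUILD_TYPE=Release,"


build_search = re.search("BUILD_TYPE=(.+),", show())  # no ignore needed
if build_search is not None:
    print(build_search.groups()[0].split(",")[0])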
@@ -55,7 +55,7 @@ DECOMPOSITION_UTIL_FILE_NAME = "decomposition_registry_util.cpp"
 
 def gen_serialized_decompisitions() -> str:
     return "\n".join(
-        [scripted_func.code for scripted_func in decomposition_table.values()]
+        [scripted_func.code for scripted_func in decomposition_table.values()]  # type: ignore[misc]
     )
 
 
@@ -63,7 +63,7 @@ def gen_decomposition_mappings() -> str:
     decomposition_mappings = []
     for schema, scripted_func in decomposition_table.items():
         decomposition_mappings.append(
-            ' {"' + schema + '", "' + scripted_func.name + '"},'
+            ' {"' + schema + '", "' + scripted_func.name + '"},'  # type: ignore[operator]
         )
     return "\n".join(decomposition_mappings)
 
@@ -89,7 +89,7 @@ def serialize_functions() -> None:
     for (
         key,
         kwargs,
-    ) in _get_sfdp_patterns():
+    ) in _get_sfdp_patterns():  # type: ignore[no-untyped-call]
         pattern_name = kwargs["search_fn"].__name__
         gen_kwargs = {
             key: kwargs[key]
@@ -134,5 +134,5 @@ def serialize_functions() -> None:
 
 
 if __name__ == "__main__":
-    with torch._subclasses.FakeTensorMode():
+    with torch._subclasses.FakeTensorMode():  # type: ignore[no-untyped-call]
         serialize_functions()