Compare commits

...

5 Commits

Author SHA1 Message Date
19e52556fa Fix unintended updates to submodules 2025-11-07 08:57:52 -08:00
1d43f171d6 Fix signals 2025-11-07 06:51:15 -08:00
910471526d Type functions 2025-11-07 06:49:52 -08:00
edd611f3b0 [CI] Upgrade Ubuntu 24.04 for XPU CI tests (#162475)
As the title

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162475
Approved by: https://github.com/EikanWang, https://github.com/atalman
2025-11-07 14:05:16 +00:00
aded2ebb90 [3/N] Add return types of Python functions (#167287)
This PR adds return types to some Python functions.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167287
Approved by: https://github.com/mlazos
2025-11-07 13:50:33 +00:00
57 changed files with 739 additions and 544 deletions

View File

@ -207,9 +207,9 @@ case "$tag" in
NINJA_VERSION=1.9.0
TRITON=yes
;;
pytorch-linux-jammy-xpu-n-py3 | pytorch-linux-jammy-xpu-n-py3-inductor-benchmarks)
pytorch-linux-noble-xpu-n-py3 | pytorch-linux-noble-xpu-n-py3-inductor-benchmarks)
ANACONDA_PYTHON_VERSION=3.10
GCC_VERSION=11
GCC_VERSION=13
VISION=yes
XPU_VERSION=2025.2
NINJA_VERSION=1.9.0

View File

@ -9,7 +9,7 @@ set -xe
function install_ubuntu() {
. /etc/os-release
if [[ ! " jammy " =~ " ${VERSION_CODENAME} " ]]; then
if [[ ! " jammy noble " =~ " ${VERSION_CODENAME} " ]]; then
echo "Ubuntu version ${VERSION_CODENAME} not supported"
exit
fi
@ -35,25 +35,24 @@ function install_ubuntu() {
# The xpu-smi packages
apt-get install -y flex bison xpu-smi
if [[ "${XPU_DRIVER_TYPE,,}" == "lts" ]]; then
# Compute and Media Runtimes
# Compute and Media Runtimes
if [[ " ${VERSION_CODENAME} " =~ " noble " ]]; then
apt-get install -y \
intel-opencl-icd intel-level-zero-gpu level-zero \
intel-media-va-driver-non-free libmfx1 libmfxgen1 libvpl2 \
libegl-mesa0 libegl1-mesa libegl1-mesa-dev libgbm1 libgl1-mesa-dev libgl1-mesa-dri \
intel-opencl-icd libze-intel-gpu1 libze1 \
intel-media-va-driver-non-free libmfx-gen1 libvpl2 \
libegl-mesa0 libegl1-mesa-dev libgbm1 libgl1-mesa-dev libgl1-mesa-dri \
libglapi-mesa libgles2-mesa-dev libglx-mesa0 libigdgmm12 libxatracker2 mesa-va-drivers \
mesa-vdpau-drivers mesa-vulkan-drivers va-driver-all vainfo hwinfo clinfo
# Development Packages
apt-get install -y libigc-dev intel-igc-cm libigdfcl-dev libigfxcmrt-dev level-zero-dev
else # rolling driver
mesa-vdpau-drivers mesa-vulkan-drivers va-driver-all vainfo hwinfo clinfo intel-ocloc
else # jammy
apt-get install -y \
intel-opencl-icd libze-intel-gpu1 libze1 \
intel-media-va-driver-non-free libmfx-gen1 libvpl2 \
libegl-mesa0 libegl1-mesa libegl1-mesa-dev libgbm1 libgl1-mesa-dev libgl1-mesa-dri \
libglapi-mesa libglx-mesa0 libigdgmm12 libxatracker2 mesa-va-drivers \
mesa-vdpau-drivers mesa-vulkan-drivers va-driver-all vainfo hwinfo clinfo intel-ocloc
apt-get install -y libigc-dev intel-igc-cm libigdfcl-dev libigfxcmrt-dev libze-dev
fi
# Development Packages
apt-get install -y libigc-dev intel-igc-cm libigdfcl-dev libigfxcmrt-dev libze-dev
# Install Intel Support Packages
apt-get install -y ${XPU_PACKAGES}
@ -66,7 +65,7 @@ function install_ubuntu() {
function install_rhel() {
. /etc/os-release
if [[ "${ID}" == "rhel" ]]; then
if [[ ! " 8.8 8.9 9.0 9.2 9.3 " =~ " ${VERSION_ID} " ]]; then
if [[ ! " 8.8 8.10 9.0 9.2 9.3 " =~ " ${VERSION_ID} " ]]; then
echo "RHEL version ${VERSION_ID} not supported"
exit
fi
@ -147,7 +146,7 @@ function install_sles() {
XPU_DRIVER_VERSION=""
if [[ "${XPU_DRIVER_TYPE,,}" == "lts" ]]; then
# Use GPU driver LTS releases
XPU_DRIVER_VERSION="/lts/2350"
XPU_DRIVER_VERSION="/lts/2523"
fi
# Default use Intel® oneAPI Deep Learning Essentials 2025.1

View File

@ -68,8 +68,8 @@ jobs:
pytorch-linux-jammy-py3-gcc11-inductor-benchmarks,
pytorch-linux-jammy-py3.12-halide,
pytorch-linux-jammy-xpu-n-1-py3,
pytorch-linux-jammy-xpu-n-py3,
pytorch-linux-jammy-xpu-n-py3-inductor-benchmarks,
pytorch-linux-noble-xpu-n-py3,
pytorch-linux-noble-xpu-n-py3-inductor-benchmarks,
pytorch-linux-jammy-py3-clang18-asan,
pytorch-linux-jammy-py3-clang12-onnx,
pytorch-linux-jammy-linter,

View File

@ -83,8 +83,8 @@ jobs:
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-jammy-xpu-n-py3.10
docker-image-name: ci-image:pytorch-linux-jammy-xpu-n-py3-inductor-benchmarks
build-environment: linux-noble-xpu-n-py3.10
docker-image-name: ci-image:pytorch-linux-noble-xpu-n-py3-inductor-benchmarks
runner: linux.c7i.12xlarge
test-matrix: |
{ include: [
@ -117,7 +117,7 @@ jobs:
uses: ./.github/workflows/_xpu-test.yml
needs: xpu-n-py3_10-inductor-benchmark-build
with:
build-environment: linux-jammy-xpu-n-py3.10
build-environment: linux-noble-xpu-n-py3.10
dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-false-cppwrapper-true-aotinductor-true-freezing_cudagraphs-false-cudagraphs_low_precision-false
docker-image: ${{ needs.xpu-n-py3_10-inductor-benchmark-build.outputs.docker-image }}
test-matrix: ${{ needs.xpu-n-py3_10-inductor-benchmark-build.outputs.test-matrix }}
@ -137,7 +137,7 @@ jobs:
uses: ./.github/workflows/_xpu-test.yml
needs: xpu-n-py3_10-inductor-benchmark-build
with:
build-environment: linux-jammy-xpu-n-py3.10
build-environment: linux-noble-xpu-n-py3.10
dashboard-tag: training-${{ inputs.training }}-inference-${{ inputs.inference }}-default-${{ inputs.default }}-dynamic-${{ inputs.dynamic }}-cudagraphs-${{ inputs.cudagraphs }}-cppwrapper-${{ inputs.cppwrapper }}-aotinductor-${{ inputs.aotinductor }}-maxautotune-${{ inputs.maxautotune }}-freezing_cudagraphs-${{ inputs.freezing_cudagraphs }}-cudagraphs_low_precision-${{ inputs.cudagraphs }}
docker-image: ${{ needs.xpu-n-py3_10-inductor-benchmark-build.outputs.docker-image }}
test-matrix: ${{ needs.xpu-n-py3_10-inductor-benchmark-build.outputs.test-matrix }}

View File

@ -342,16 +342,16 @@ jobs:
test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc9-inductor-build.outputs.test-matrix }}
secrets: inherit
linux-jammy-xpu-n-py3_10-build:
name: linux-jammy-xpu-n-py3.10
linux-noble-xpu-n-py3_10-build:
name: linux-noble-xpu-n-py3.10
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
# This should sync with the build in xpu.yml but xpu uses a larger runner
# sync-tag: linux-xpu-n-build
runner_prefix: ${{ needs.get-label-type.outputs.label-type }}
build-environment: linux-jammy-xpu-n-py3.10
docker-image-name: ci-image:pytorch-linux-jammy-xpu-n-py3
build-environment: linux-noble-xpu-n-py3.10
docker-image-name: ci-image:pytorch-linux-noble-xpu-n-py3
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 4, runner: "linux.idc.xpu" },

View File

@ -47,15 +47,15 @@ jobs:
]}
secrets: inherit
linux-jammy-xpu-n-py3_10-build:
name: linux-jammy-xpu-n-py3.10
linux-noble-xpu-n-py3_10-build:
name: linux-noble-xpu-n-py3.10
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
sync-tag: linux-xpu-n-build
runner_prefix: ${{ needs.get-label-type.outputs.label-type }}
build-environment: linux-jammy-xpu-n-py3.10
docker-image-name: ci-image:pytorch-linux-jammy-xpu-n-py3
build-environment: linux-noble-xpu-n-py3.10
docker-image-name: ci-image:pytorch-linux-noble-xpu-n-py3
runner: linux.c7i.12xlarge
test-matrix: |
{ include: [
@ -74,17 +74,17 @@ jobs:
]}
secrets: inherit
linux-jammy-xpu-n-py3_10-test:
name: linux-jammy-xpu-n-py3.10
linux-noble-xpu-n-py3_10-test:
name: linux-noble-xpu-n-py3.10
uses: ./.github/workflows/_xpu-test.yml
needs: linux-jammy-xpu-n-py3_10-build
needs: linux-noble-xpu-n-py3_10-build
permissions:
id-token: write
contents: read
with:
build-environment: linux-jammy-xpu-n-py3.10
docker-image: ${{ needs.linux-jammy-xpu-n-py3_10-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-jammy-xpu-n-py3_10-build.outputs.test-matrix }}
build-environment: linux-noble-xpu-n-py3.10
docker-image: ${{ needs.linux-noble-xpu-n-py3_10-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-noble-xpu-n-py3_10-build.outputs.test-matrix }}
secrets: inherit
windows-xpu-n-1-build:

View File

@ -133,7 +133,7 @@ at::Tensor quantized_convolution(
// supported in conv.
mask_weight = weight_zero_points.numel() > 1 ? 1 : 0;
if (groups > 1 && weight_zero_points.numel() > 1)
mask_weight = (2 ^ 0) | (2 ^ 1); // 2^0 (group) | 2^1 (output channel)
mask_weight = (1 << 0) | (1 << 1); // 2^0 (group) | 2^1 (output channel)
dnnl::primitive_attr pattr;
bool src_need_zp = (act_zero_point != 0);

View File

@ -1941,6 +1941,7 @@ if(BUILD_TEST)
foreach(test_src ${Caffe2_XPU_TEST_SRCS})
get_filename_component(test_name ${test_src} NAME_WE)
add_executable(${test_name} "${test_src}")
torch_compile_options(${test_name})
target_link_libraries(${test_name} torch_library gtest_main)
target_include_directories(${test_name} PRIVATE $<INSTALL_INTERFACE:include>)
target_include_directories(${test_name} PRIVATE ${Caffe2_CPU_INCLUDE})

View File

@ -1991,7 +1991,7 @@ class BuiltinVariable(VariableTracker):
# If the object implements a __getitem__ method, iter(...) will call obj.__getitem__()
# with an integer argument starting at 0, until __getitem__ raises IndexError
ret = variables.UserFunctionVariable(
polyfills.builtins.iter_
polyfills.builtins.iter_ # type: ignore[arg-type]
).call_function(tx, [obj, *args], {})
if args:

File diff suppressed because it is too large Load Diff

View File

@ -590,7 +590,7 @@ class FilterVariable(IteratorVariable):
else:
res = self.fn.call_function(tx, [item], {})
pred_res = variables.UserFunctionVariable(
polyfills.predicate
polyfills.predicate # type: ignore[arg-type]
).call_function(tx, [res], {})
if pred_res.as_python_constant():
return item

View File

@ -1498,6 +1498,7 @@ class NamedTupleVariable(TupleVariable):
variables.UserDefinedClassVariable(self.tuple_cls),
)
elif isinstance(method, staticmethod):
# pyrefly: ignore[bad-argument-type]
return UserFunctionVariable(method.__func__)
elif inspect.isfunction(method):
return UserMethodVariable(method, self)

View File

@ -472,7 +472,12 @@ class TorchCtxManagerClassVariable(BaseTorchVariable):
)
elif self.value is torch.nn.attention.sdpa_kernel.__wrapped__: # type: ignore[attr-defined]
name_to_arg_map = bind_args_cached(
self.value, tx, self.source, args, kwargs
# pyrefly: ignore[bad-argument-type]
self.value,
tx,
self.source,
args,
kwargs,
)
backends = name_to_arg_map["backends"].as_python_constant()
set_priority = name_to_arg_map["set_priority"].as_python_constant()
@ -1429,7 +1434,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
packed_input_vt = TupleVariable.build(
tx, (TupleVariable.build(tx, args), ConstDictVariable.build(tx, kwargs))
)
out_vt = variables.UserFunctionVariable(tree_flatten).call_function(
out_vt = variables.UserFunctionVariable(tree_flatten).call_function( # type: ignore[arg-type]
tx, [packed_input_vt], {}
)
assert isinstance(out_vt, TupleVariable) and len(out_vt.items) == 2

View File

@ -279,7 +279,7 @@ def _hide_source_ranges() -> Iterator[None]:
torch._C.Graph.set_global_print_source_ranges(old_enable_source_ranges) # type: ignore[attr-defined]
def enable_onednn_fusion(enabled: bool):
def enable_onednn_fusion(enabled: bool) -> None:
"""Enable or disables onednn JIT fusion based on the parameter `enabled`."""
torch._C._jit_set_llga_enabled(enabled)

View File

@ -162,7 +162,7 @@ def _get_builtin_table():
return _builtin_table
_builtin_table = {}
def register_all(mod):
def register_all(mod) -> None:
for name in dir(mod):
v = getattr(mod, name)
if (
@ -196,7 +196,7 @@ def _get_builtin_table():
return _builtin_table
def _register_builtin(fn, op):
def _register_builtin(fn, op) -> None:
_get_builtin_table()[id(fn)] = op

View File

@ -116,7 +116,7 @@ class AttributeTypeIsSupportedChecker(ast.NodeVisitor):
return True
def visit_Assign(self, node):
def visit_Assign(self, node) -> None:
"""Store assignment state when assigning to a Call Node.
If we're visiting a Call Node (the right-hand side of an
@ -139,7 +139,7 @@ class AttributeTypeIsSupportedChecker(ast.NodeVisitor):
self.generic_visit(node)
self.visiting_class_level_ann = False
def visit_AnnAssign(self, node):
def visit_AnnAssign(self, node) -> None:
"""Visit an AnnAssign node in an ``nn.Module``'s ``__init__`` method.
It checks if it conforms to our attribute annotation rules."""
@ -194,7 +194,7 @@ class AttributeTypeIsSupportedChecker(ast.NodeVisitor):
stacklevel=2,
)
def visit_Call(self, node):
def visit_Call(self, node) -> None:
"""Determine if a Call node is 'torch.jit.annotate' in __init__.
Visit a Call node in an ``nn.Module``'s ``__init__``

View File

@ -3,7 +3,7 @@ import torch
from torch._ops import OpOverload, OpOverloadPacket
def _register_decomposition(op: OpOverload, graph: torch._C.Graph):
def _register_decomposition(op: OpOverload, graph: torch._C.Graph) -> None:
assert not isinstance(op, OpOverloadPacket), (
f"Must pass specific op overload, not overload packet, found {op}"
)

View File

@ -20,7 +20,7 @@ _T = TypeVar("_T")
_P = ParamSpec("_P")
def check_decomposition_has_type_annotations(f):
def check_decomposition_has_type_annotations(f) -> None:
inspect_empty = inspect._empty # type: ignore[attr-defined]
sig = inspect.signature(f)
for param in sig.parameters.values():

View File

@ -125,7 +125,7 @@ def freeze(
def run_frozen_optimizations(
mod, optimize_numerics: bool = True, preserved_methods: Optional[list[str]] = None
):
) -> None:
r"""
Run a series of optimizations looking for patterns that occur in frozen graphs.

View File

@ -83,7 +83,7 @@ def fuser(name):
last_executed_optimized_graph = torch._C._last_executed_optimized_graph
def _get_differentiable_graph_node(node, diff_node):
def _get_differentiable_graph_node(node, diff_node) -> None:
if node.kind() == "prim::DifferentiableGraph":
diff_node.append(node)
else:

View File

@ -9,7 +9,7 @@ class _InsertPoint:
self,
insert_point_graph: torch._C.Graph,
insert_point: Union[torch._C.Node, torch._C.Block],
):
) -> None:
self.insert_point = insert_point
self.g = insert_point_graph
self.guard = None

View File

@ -85,7 +85,7 @@ if _IS_MONKEYTYPE_INSTALLED:
class JitTypeTraceStoreLogger(CallTraceStoreLogger):
"""A JitTypeCallTraceLogger that stores logged traces in a CallTraceStore."""
def __init__(self, store: CallTraceStore):
def __init__(self, store: CallTraceStore) -> None:
super().__init__(store)
def log(self, trace: CallTrace) -> None:
@ -100,7 +100,7 @@ if _IS_MONKEYTYPE_INSTALLED:
# value is list of all CallTrace
self.trace_records: dict[str, list] = defaultdict(list)
def add(self, traces: Iterable[CallTrace]):
def add(self, traces: Iterable[CallTrace]) -> None:
for t in traces:
qualified_name = get_qualified_name(t.func)
self.trace_records[qualified_name].append(t)
@ -145,7 +145,7 @@ if _IS_MONKEYTYPE_INSTALLED:
return self.consolidate_types(qualified_name)
class JitTypeTraceConfig(monkeytype.config.Config):
def __init__(self, s: JitTypeTraceStore):
def __init__(self, s: JitTypeTraceStore) -> None:
super().__init__()
self.s = s

View File

@ -152,7 +152,7 @@ def _get_valid_constant(attr, v, owner_type):
class SourceContext(torch._C._jit_tree_views.SourceRangeFactory):
def __init__(self, source, filename, file_lineno, leading_whitespace_len):
def __init__(self, source, filename, file_lineno, leading_whitespace_len) -> None:
super().__init__(source, filename, file_lineno, leading_whitespace_len)
@ -454,7 +454,7 @@ concrete_type_store = ConcreteTypeStore()
def create_methods_and_properties_from_stubs(
concrete_type, method_stubs, property_stubs
):
) -> None:
method_defs = [m.def_ for m in method_stubs]
method_rcbs = [m.resolution_callback for m in method_stubs]
method_defaults = [get_default_args(m.original_method) for m in method_stubs]
@ -467,7 +467,7 @@ def create_methods_and_properties_from_stubs(
)
def create_hooks_from_stubs(concrete_type, hook_stubs, pre_hook_stubs):
def create_hooks_from_stubs(concrete_type, hook_stubs, pre_hook_stubs) -> None:
hook_defs = [h.def_ for h in hook_stubs]
hook_rcbs = [h.resolution_callback for h in hook_stubs]
@ -571,7 +571,7 @@ def create_script_module_impl(nn_module, concrete_type, stubs_fn):
hook_stubs, pre_hook_stubs = get_hook_stubs(nn_module)
ignored_properties = jit_ignored_properties(nn_module)
def init_fn(script_module):
def init_fn(script_module) -> None:
# Initialize the ScriptModule:
# 1. Copy the attributes/parameters/buffers from the original `nn_module` to the new ScriptModule.
for name in concrete_type.get_attributes():
@ -725,7 +725,7 @@ def script_model_defines_attr(script_model, attr):
return script_attr != default_attr
def add_python_attr_to_scripted_model(script_model, orig, attr):
def add_python_attr_to_scripted_model(script_model, orig, attr) -> None:
if hasattr(orig, attr) and script_model_defines_attr(script_model, attr):
setattr(script_model, attr, getattr(orig, attr))
@ -777,7 +777,7 @@ def get_overload_name_mapping(overload_info):
return overload_name_mappings
def _check_no_signature(func):
def _check_no_signature(func) -> None:
signature = torch.jit.annotations.get_signature(
func, None, fake_range(), inspect.ismethod(func)
)
@ -807,7 +807,7 @@ def make_stubs_for_overloads(overload_info):
return overload_stubs
def check_module_initialized(mod):
def check_module_initialized(mod) -> None:
assert isinstance(mod, torch.nn.Module)
if not hasattr(mod, "_parameters"):
raise RuntimeError(
@ -1002,7 +1002,7 @@ def wrap_cpp_class(cpp_class):
def wrap_cpp_module(cpp_module):
"""Wrap this torch._C.ScriptModule in a Python ScriptModule, recursively for all submodules."""
def init_fn(script_module):
def init_fn(script_module) -> None:
for name, cpp_module in torch._C.ModuleDict(script_module._c).items():
setattr(script_module, name, wrap_cpp_module(cpp_module))
script_module._concrete_type = torch._C.ConcreteModuleType.from_jit_type(
@ -1037,7 +1037,7 @@ def lazy_bind(concrete_type, unbound_method):
"""
def lazy_binding_method(cpp_module, *args):
def init_fn(script_module):
def init_fn(script_module) -> None:
orig_class = concrete_type.py_class
# Copy @ignored/@unused methods from the original module to the new one.

View File

@ -18,7 +18,7 @@ from torch.jit._recursive import wrap_cpp_module
from torch.serialization import validate_cuda_device
def save(m, f, _extra_files=None):
def save(m, f, _extra_files=None) -> None:
r"""
Save an offline version of this module for use in a separate process.
@ -213,7 +213,7 @@ def jit_module_from_flatbuffer(f):
return wrap_cpp_module(torch._C._load_jit_module_from_bytes(f.read()))
def save_jit_module_to_flatbuffer(m, f, _extra_files=None):
def save_jit_module_to_flatbuffer(m, f, _extra_files=None) -> None:
r"""
Save an offline version of this module for use in a separate process.

View File

@ -41,18 +41,18 @@ class EnabledProxy:
return False
raise ValueError(f"Unknown setting of {name}. Try using 0 or 1.")
def __bool__(self):
def __bool__(self) -> bool:
return self.enabled
_enabled = EnabledProxy()
def disable():
def disable() -> None:
_enabled.enabled = False
def enable():
def enable() -> None:
_enabled.enabled = True
@ -67,7 +67,7 @@ _script_classes: dict[type[Any], type[Any]] = {}
_name_to_pyclass: dict[str, type[Any]] = {}
def _add_script_class(python_class, script_class):
def _add_script_class(python_class, script_class) -> None:
_script_classes[python_class] = script_class
_name_to_pyclass[script_class.qualified_name()] = python_class
@ -83,7 +83,7 @@ def _get_python_class(qualified_name):
return _name_to_pyclass.get(qualified_name)
def _clear_class_state():
def _clear_class_state() -> None:
_script_classes.clear()
_name_to_pyclass.clear()
@ -108,7 +108,7 @@ def _try_get_jit_cached_overloads(key):
return None
def _set_jit_overload_cache(key, compiled_fns):
def _set_jit_overload_cache(key, compiled_fns) -> None:
_jit_function_overload_caching[key] = [fn.qualified_name for fn in compiled_fns]
@ -122,7 +122,7 @@ def _try_get_jit_cached_function(key):
return None
def _set_jit_function_cache(key, value):
def _set_jit_function_cache(key, value) -> None:
# only free functions currently supported
assert isinstance(value, torch.jit.ScriptFunction)
_jit_caching_layer[key] = value.qualified_name

View File

@ -68,7 +68,7 @@ from torch._ops import OpOverloadPacket
class Module:
def __init__(self, name, members):
def __init__(self, name, members) -> None:
self.name = name
self.members = members
@ -95,7 +95,7 @@ class EvalEnv:
"Await": _Await,
}
def __init__(self, rcb):
def __init__(self, rcb) -> None:
self.rcb = rcb
if torch.distributed.rpc.is_available():
# pyrefly: ignore [unsupported-operation]
@ -178,7 +178,7 @@ def get_param_names(fn, n_args):
return [str(i) for i in range(n_args)]
def check_fn(fn, loc):
def check_fn(fn, loc) -> None:
# Make sure the function definition is not a class instantiation
try:
source = dedent("".join(get_source_lines_and_file(fn)[0]))
@ -368,7 +368,7 @@ def get_enum_value_type(e: type[enum.Enum], loc):
return res
def is_tensor(ann):
def is_tensor(ann) -> bool:
if issubclass(ann, torch.Tensor):
return True
@ -397,7 +397,7 @@ def is_tensor(ann):
return False
def _fake_rcb(inp):
def _fake_rcb(inp) -> None:
return None

View File

@ -147,7 +147,7 @@ pretty_node_names.update(
class FrontendError(Exception):
def __init__(self, source_range, msg):
def __init__(self, source_range, msg) -> None:
self.source_range = source_range
self.msg = msg
@ -155,7 +155,7 @@ class FrontendError(Exception):
# call stack when the FrontendError was raised
self.error_report = torch._C.ErrorReport(self.source_range)
def __str__(self):
def __str__(self) -> str:
return self.msg + self.error_report.what().lstrip()
@ -164,7 +164,7 @@ class NotSupportedError(FrontendError):
class UnsupportedNodeError(NotSupportedError):
def __init__(self, ctx, offending_node, reason=""):
def __init__(self, ctx, offending_node, reason="") -> None:
# If we don't have a specific token, we default to length of 1
node_type = type(offending_node)
range_len = len(node_start_tokens.get(node_type, " "))
@ -229,7 +229,7 @@ def get_class_properties(cls, self_name):
def get_class_assigns(ctx, cls_ast):
assigns = []
def maybe_build_assign(builder, entry):
def maybe_build_assign(builder, entry) -> None:
nonlocal assigns
try:
assigns.append(builder(ctx, entry))
@ -385,7 +385,7 @@ def get_jit_def(fn, def_name, self_name=None, is_classmethod=False):
# TODO: more robust handling of recognizing ignore context manager
def is_torch_jit_ignore_context_manager(stmt):
def is_torch_jit_ignore_context_manager(stmt) -> bool:
# checks if the statement is torch.jit.ignore context manager
if isinstance(stmt.items[0].context_expr, ast.Call):
# extract torch part
@ -535,7 +535,7 @@ def build_ignore_context_manager(ctx, stmt):
outputs.append(OutputType(var_name, var_ann))
return inputs, outputs
def create_unique_name_ext(ctx, stmt):
def create_unique_name_ext(ctx, stmt) -> str:
# extension will be based on the full path filename plus
# the line number of original context manager
fn = re.sub(r"[^a-zA-Z0-9_]", "_", ctx.filename)

View File

@ -56,7 +56,7 @@ def _load_for_lite_interpreter(f, map_location=None):
class LiteScriptModule:
def __init__(self, cpp_module):
def __init__(self, cpp_module) -> None:
self._c = cpp_module
super().__init__()

View File

@ -57,7 +57,7 @@ def _emit_schema(mod, name, schema, arg_start=0, padding=4):
def _get_tensor_ops():
def is_tensor_method(schema):
def is_tensor_method(schema) -> bool:
if len(schema.arguments) == 0:
return False
self = schema.arguments[0]

View File

@ -5,7 +5,7 @@ from typing import Any
import torch.jit
def execWrapper(code, glob, loc):
def execWrapper(code, glob, loc) -> None:
exec(code, glob, loc)

View File

@ -35,7 +35,7 @@ def _ge(lhs: Any, rhs: Any) -> bool:
class NestedIntNode:
def __init__(self, t_id: int, coeff: int):
def __init__(self, t_id: int, coeff: int) -> None:
self.t_id = t_id
self.coeff = coeff

View File

@ -131,7 +131,7 @@ class NestedTensor(torch.Tensor):
return r
def __init__(self, values, offsets, *, lengths=None, **kwargs):
def __init__(self, values, offsets, *, lengths=None, **kwargs) -> None:
super().__init__()
self._values = values
@ -243,7 +243,7 @@ class NestedTensor(torch.Tensor):
self._values, memory_format=torch.contiguous_format
)
def __repr__(self): # type: ignore[override]
def __repr__(self) -> str: # type: ignore[override]
# We should implement this in torch/_tensor_str.py instead
grad_fn_str = (
f", requires_grad={self.requires_grad}" if self.requires_grad else ""

View File

@ -400,7 +400,7 @@ def jagged_torch_function(func, *args, **kwargs):
# Handle flatten() here because it's CompositeImplicit.
if func.__name__ == "flatten":
def _flatten_sig(input, start_dim=0, end_dim=-1):
def _flatten_sig(input, start_dim=0, end_dim=-1) -> None:
pass
_, new_kwargs = normalize_function( # type: ignore[misc]
@ -466,7 +466,7 @@ def jagged_torch_function(func, *args, **kwargs):
# Handle nested-specific input validation for CompositeImplicit rms_norm
if func.__name__ == "rms_norm":
def _rms_norm_sig(input, normalized_shape, weight=None, eps=None):
def _rms_norm_sig(input, normalized_shape, weight=None, eps=None) -> None:
pass
_, new_kwargs = normalize_function( # type: ignore[misc]
@ -532,7 +532,7 @@ def prim_layout_default(func, *args, **kwargs):
[torch.ops.aten.size.default],
"self: jt_all",
)
def tensor_attr_unsupported_getter(func, *args, **kwargs):
def tensor_attr_unsupported_getter(func, *args, **kwargs) -> None:
if func is torch.ops.aten.size.default:
raise RuntimeError(
"NestedTensor does not support directly calling torch.ops.aten.size; "
@ -1138,7 +1138,7 @@ def unbind_int(func, *args, **kwargs):
lengths = inp.lengths()
ragged_idx = inp._ragged_idx
def _torch_check(_lengths: list[int], _offsets: Optional[list[int]] = None):
def _torch_check(_lengths: list[int], _offsets: Optional[list[int]] = None) -> None:
# This torch._check are needed for torch.compile
# symbolic shapes processing.
# offsets and lengths are symbolic variables during compilation,
@ -2615,7 +2615,7 @@ def _nested_select_backward_default(func, *args, **kwargs):
@register_jagged_func(torch.ops.aten.record_stream.default, "self: jt_all, s: any")
def record_stream_default(func, *args, **kwargs):
def record_stream_default(func, *args, **kwargs) -> None:
inp = args[0]
stream = args[1]
# ensure all components live until stream computation completes

View File

@ -31,7 +31,7 @@ def _validate_sdpa_input(
dropout_p=0.0,
is_causal=False,
scale=None,
):
) -> None:
if (
not isinstance(query, NestedTensor)
or not isinstance(key, NestedTensor)
@ -364,7 +364,7 @@ def _cumulative_and_max_seq_len_nnz(qkv: torch.Tensor) -> tuple[torch.Tensor, in
return cumulative_seqlen, max_seqlen, n_elem
def _is_safe_to_get_storage_as_tensor(tensor: torch.Tensor):
def _is_safe_to_get_storage_as_tensor(tensor: torch.Tensor) -> bool:
# This function checks if a nested tensor is valid for
# use with the flash-attention and efficient_attention kernels without
# needing to call contiguous on the nested tensor input.

View File

@ -537,7 +537,7 @@ class OpRecorder(evaluator.Evaluator):
def __init__(
self, opset: onnxscript.values.Opset, constant_farm: dict[Any, ir.Value]
):
) -> None:
self.nodes: list[ir.Node] = []
self.opset = opset
self.functions: dict[

View File

@ -92,7 +92,7 @@ class CaptureStrategy(abc.ABC):
dump: bool = False,
artifacts_dir: str | os.PathLike = ".",
timestamp: str | None = None,
):
) -> None:
"""Initialize the strategy.
Args:

View File

@ -109,7 +109,7 @@ def torch_dtype_to_onnx_dtype(dtype: torch.dtype) -> ir.DataType:
class TorchTensor(ir.Tensor):
def __init__(self, tensor: torch.Tensor, name: str | None = None):
def __init__(self, tensor: torch.Tensor, name: str | None = None) -> None:
# Pass the tensor as the raw data to ir.Tensor's constructor
if tensor.dtype == torch.float4_e2m1fn_x2:
# Change the shape to the unpacked shape

View File

@ -211,7 +211,7 @@ class ONNXProgram:
def __init__(
self, model: ir.Model, exported_program: torch.export.ExportedProgram | None
):
) -> None:
"""Initialize the ONNX program with the specified model and exported program.
Args:
model: The ONNX model.
@ -327,7 +327,7 @@ ONNXProgram(
include_initializers: bool = True,
keep_initializers_as_inputs: bool = False,
external_data: bool | None = None,
):
) -> None:
"""Save the ONNX model to the specified destination.
When ``external_data`` is ``True`` or the model is larger than 2GB,

View File

@ -149,7 +149,7 @@ def create_torch_export_error_report(
*,
export_status: ExportStatus,
profile_result: str | None,
):
) -> None:
with open(filename, "w", encoding="utf-8") as f:
f.write("# PyTorch ONNX Conversion Error Report\n\n")
f.write(_format_export_status(export_status))
@ -175,7 +175,7 @@ def create_onnx_export_report(
model: ir.Model | None = None,
registry: _registration.ONNXRegistry | None = None,
verification_result: str | None = None,
):
) -> None:
with open(filename, "w", encoding="utf-8") as f:
f.write("# PyTorch ONNX Conversion Report\n\n")
f.write(_format_export_status(export_status))

View File

@ -21,7 +21,7 @@ logger = logging.getLogger(__name__)
# A special value to indicate that the default value is not specified
class _Empty:
def __repr__(self):
def __repr__(self) -> str:
return "_EMPTY_DEFAULT"

View File

@ -18,7 +18,7 @@ class SymbolicTensor(ir.Value):
type: ir.TypeProtocol | None = None,
doc_string: str | None = None,
const_value: ir.TensorProtocol | None = None,
):
) -> None:
super().__init__(
name=name,
shape=shape,

View File

@ -66,7 +66,7 @@ def _patch_difflib_sequence_matcher_init():
"""
original_init = difflib.SequenceMatcher.__init__
def patched_init(self, isjunk=None, a="", b="", autojunk=True):
def patched_init(self, isjunk=None, a="", b="", autojunk=True) -> None:
original_init(self, isjunk, a, b, autojunk=False)
difflib.SequenceMatcher.__init__ = patched_init # type: ignore[assignment]
@ -192,7 +192,7 @@ class Transform(abc.ABC):
def __init__(
self,
module: torch.fx.GraphModule,
):
) -> None:
"""Initialize the transform.
Args:

View File

@ -63,7 +63,7 @@ class TypePromotionSnapshot:
class TypePromotionRule(abc.ABC):
"""Base class for type promotion rule per 'torch.ops.{namespace}.{op_name}'."""
def __init__(self, namespace: str, op_name: str):
def __init__(self, namespace: str, op_name: str) -> None:
self.namespace = namespace
self.op_name = op_name
@ -74,7 +74,7 @@ class TypePromotionRule(abc.ABC):
def __hash__(self) -> int: ...
@abc.abstractmethod
def __repr__(self): ...
def __repr__(self) -> str: ...
@abc.abstractmethod
def __eq__(self, other: object) -> bool: ...
@ -128,7 +128,7 @@ class ElementwiseTypePromotionRule(TypePromotionRule):
promote_args_positions: Sequence[int],
promote_kwargs_names: Sequence[str],
promotion_kind: _prims_common.ELEMENTWISE_TYPE_PROMOTION_KIND,
):
) -> None:
"""Constructs a TypePromotionRule for elementwise operators.
Args:
@ -143,7 +143,7 @@ class ElementwiseTypePromotionRule(TypePromotionRule):
self.promote_kwargs_names = promote_kwargs_names
self.promotion_kind = promotion_kind
def __repr__(self):
def __repr__(self) -> str:
return (
f"ElementwiseTypePromotionRule('{self.namespace}', '{self.op_name}', "
f"{self.promote_args_positions}, {self.promote_kwargs_names}, {self.promotion_kind})"
@ -216,7 +216,7 @@ class DivElementwiseTypePromotionRule(ElementwiseTypePromotionRule):
Rule depends on the value of the `rounding_mode` argument.
"""
def __init__(self):
def __init__(self) -> None:
super().__init__(
"aten",
"div",
@ -252,7 +252,7 @@ class ReductionTypePromotionRule(TypePromotionRule):
namespace: str,
op_name: str,
promotion_kind: _prims_common.REDUCTION_OUTPUT_TYPE_KIND,
):
) -> None:
"""Constructs a TypePromotionRule for reduction operators.
Args:
@ -263,7 +263,7 @@ class ReductionTypePromotionRule(TypePromotionRule):
super().__init__(namespace, op_name)
self.promotion_kind = promotion_kind
def __repr__(self):
def __repr__(self) -> str:
return f"ReductionTypePromotionRule('{self.namespace}', '{self.op_name}', {self.promotion_kind})"
# pyrefly: ignore [bad-override]
@ -311,7 +311,7 @@ class AllOrAnyReductionTypePromotionRule(ReductionTypePromotionRule):
The result dtype is always uint8 if `dtype` kwarg is uint8, otherwise torch.bool.
"""
def __init__(self, op_name: str):
def __init__(self, op_name: str) -> None:
super().__init__(
"aten",
op_name,
@ -1205,7 +1205,7 @@ class ElementwiseTypePromotionRuleSetGenerator:
class TypePromotionTable:
"""Type promotion table for torch.ops."""
def __init__(self):
def __init__(self) -> None:
self._rule_table = {}
for rule in _GENERATED_ATEN_TYPE_PROMOTION_RULE_SET:
self.add_rule(rule)
@ -1262,7 +1262,7 @@ class _OpTraceDispatchMode(_python_dispatch.TorchDispatchMode):
op overload for a given op overload packet for different set of args and kwargs.
"""
def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.traced_ops = []
@ -1331,7 +1331,7 @@ class _TypePromotionInterpreter(torch.fx.Interpreter):
self,
module: torch.fx.GraphModule,
type_promotion_table: TypePromotionTable,
):
) -> None:
super().__init__(module)
self.type_promotion_table = type_promotion_table
@ -1603,7 +1603,7 @@ class InsertTypePromotion(_pass.Transform):
self,
module: torch.fx.GraphModule,
type_promotion_table: TypePromotionTable | None = None,
):
) -> None:
super().__init__(module)
self.interpreter = _TypePromotionInterpreter(
module, type_promotion_table or TypePromotionTable()

View File

@ -810,7 +810,7 @@ def _reduce_with_dtype(onnx_op: str, name: str, allow_multi_dim_support: bool =
@_onnx_symbolic("aten::cumsum")
@symbolic_helper.parse_args("v", "i", "none")
def cumsum(g: jit_utils.GraphContext, input, dim, dtype):
def cumsum(g: jit_utils.GraphContext, input, dim, dtype) -> None:
symbolic_helper._onnx_opset_unsupported("cumsum", 9, 11, input)
@ -3332,7 +3332,9 @@ def _unique(g: jit_utils.GraphContext, input, sorted, return_inverse):
@_onnx_symbolic("aten::_unique2")
@symbolic_helper.parse_args("v", "i", "i", "i")
def _unique2(g: jit_utils.GraphContext, input, sorted, return_inverse, return_counts):
def _unique2(
g: jit_utils.GraphContext, input, sorted, return_inverse, return_counts
) -> None:
symbolic_helper._onnx_opset_unsupported("_unique2", 9, 11, input)
@ -6289,7 +6291,7 @@ def broadcast_tensors(g: jit_utils.GraphContext, self):
@_onnx_symbolic("aten::is_pinned")
def is_pinned(g: jit_utils.GraphContext, self, device=None):
def is_pinned(g: jit_utils.GraphContext, self, device=None) -> None:
# Unused by ONNX.
return None
@ -6357,7 +6359,7 @@ def prim_layout(g: jit_utils.GraphContext, self):
@_onnx_symbolic("prim::ListConstruct")
def prim_list_construct(g: jit_utils.GraphContext, *inputs, **kwargs):
def prim_list_construct(g: jit_utils.GraphContext, *inputs, **kwargs) -> None:
return None
@ -6374,12 +6376,12 @@ def prim_list_unpack(
@_onnx_symbolic("prim::TupleConstruct")
def prim_tuple_construct(g: jit_utils.GraphContext, *inputs, **kwargs):
def prim_tuple_construct(g: jit_utils.GraphContext, *inputs, **kwargs) -> None:
return None
@_onnx_symbolic("prim::Uninitialized")
def prim_uninitialized(g: jit_utils.GraphContext, *inputs, **kwargs):
def prim_uninitialized(g: jit_utils.GraphContext, *inputs, **kwargs) -> None:
return None

View File

@ -571,7 +571,7 @@ def export(
return None
def _is_constant_tensor_list(node):
def _is_constant_tensor_list(node) -> bool | None:
if node.kind() != "prim::Constant":
return False
output_type = node.output().type()
@ -585,7 +585,7 @@ def _is_constant_tensor_list(node):
# get generated in constant prop. So we split them back into prim::ListConstructs
def _split_tensor_list_constants(g, block):
def _split_tensor_list_constants(g, block) -> None:
for node in block.nodes():
for subblock in node.blocks():
_split_tensor_list_constants(g, subblock)
@ -722,7 +722,7 @@ def _optimize_graph(
return graph
def warn_on_static_input_change(input_states):
def warn_on_static_input_change(input_states) -> None:
"""Warns that changes to input dictionaries and strings won't take effect in the traced ONNX graph.
We accept dictionaries and strings as ONNX inputs, but they should be only for
@ -932,7 +932,7 @@ def _get_param_count_list(method_graph, args_params):
return param_count_list
def _check_flatten_did_not_remove(original, jit_flattened):
def _check_flatten_did_not_remove(original, jit_flattened) -> None:
"""torch.jit._flatten removes None. Check if it did so in this case."""
def flatten(x):
@ -1286,13 +1286,13 @@ def _setup_trace_module_map(
model: torch.nn.Module | torch.jit.ScriptModule,
export_modules_as_functions: bool | Collection[type[torch.nn.Module]],
) -> set[str]:
def __register_attribute_hook():
def __register_attribute_hook() -> None:
attr_name = "_onnx_attrs"
def _track_module_attributes_forward_pre_hook(module, input):
def _track_module_attributes_forward_pre_hook(module, input) -> None:
setattr(module, attr_name, _get_module_attributes(module))
def _track_module_attributes_forward_hook(module, input, output):
def _track_module_attributes_forward_hook(module, input, output) -> None:
tracing_state = _C._get_tracing_state()
if not tracing_state:
return
@ -1359,7 +1359,7 @@ def _setup_trace_module_map(
return module_typenames
def _reset_trace_module_map():
def _reset_trace_module_map() -> None:
torch.jit._trace._trace_module_map = None
_C._jit_pass_onnx_clear_scope_records()
@ -1388,7 +1388,7 @@ def _get_module_attributes(module):
return attrs
def _trigger_symbolic_function_registration():
def _trigger_symbolic_function_registration() -> None:
"""Trigger the registration of symbolic functions for all supported opsets."""
from torch.onnx._internal.torchscript_exporter import ( # noqa: F401
@ -1599,7 +1599,7 @@ def _export(
return torch_out
def _apply_friendly_debug_names(graph, params):
def _apply_friendly_debug_names(graph, params) -> None:
for n in graph.nodes():
for v in n.inputs():
old_name = v.debugName()
@ -1611,8 +1611,8 @@ def _apply_friendly_debug_names(graph, params):
params[new_name] = params.pop(old_name)
def _set_input_and_output_names(graph, input_names, output_names):
def set_names(node_list, name_list, descriptor):
def _set_input_and_output_names(graph, input_names, output_names) -> None:
def set_names(node_list, name_list, descriptor) -> None:
if name_list is None:
return
if len(name_list) > len(node_list):
@ -1681,7 +1681,7 @@ def _add_output_to_block(block: _C.Block, value: _C.Value) -> int:
def _should_aten_fallback(
name: str, opset_version: int, operator_export_type: _C_onnx.OperatorExportTypes
):
) -> bool:
# For all builds, if domain=="aten" and operator_export_type==ONNX_ATEN,
# an aten::ATen operator is created regardless of symbolics existence
@ -1822,7 +1822,7 @@ def _run_symbolic_function(
raise
def _verify_custom_op_name(symbolic_name: str):
def _verify_custom_op_name(symbolic_name: str) -> None:
if not re.match(r"^[a-zA-Z0-9-_]+::[a-zA-Z-_]+[a-zA-Z0-9-_]*$", symbolic_name):
raise errors.OnnxExporterError(
f"Failed to register operator {symbolic_name}. "
@ -1842,7 +1842,7 @@ def register_custom_op_symbolic(
symbolic_name: str,
symbolic_fn: Callable,
opset_version: int,
):
) -> None:
"""Registers a symbolic function for a custom operator.
When the user registers symbolic for custom/contrib ops,
@ -1868,7 +1868,7 @@ def register_custom_op_symbolic(
registration.custom_onnx_symbolic(symbolic_name, opset_version)(symbolic_fn)
def unregister_custom_op_symbolic(symbolic_name: str, opset_version: int):
def unregister_custom_op_symbolic(symbolic_name: str, opset_version: int) -> None:
"""Unregisters ``symbolic_name``.
See "Custom Operators" in the module documentation for an example usage.
@ -1886,7 +1886,7 @@ def unregister_custom_op_symbolic(symbolic_name: str, opset_version: int):
registration.registry.unregister(symbolic_name, opset_version)
def _validate_dynamic_axes(dynamic_axes, model, input_names, output_names):
def _validate_dynamic_axes(dynamic_axes, model, input_names, output_names) -> None:
"""Ensures dynamic axes argument is follows the expected format."""
if len(dynamic_axes) == 0:
return

View File

@ -209,7 +209,7 @@ def _compare_onnx_pytorch_outputs_in_np(
onnx_outs: _OutputsType,
pt_outs: _OutputsType,
options: VerificationOptions,
):
) -> None:
assert len(onnx_outs) == len(pt_outs), (
f"Number of outputs differ ONNX runtime: ({len(onnx_outs)}) PyTorch: ({len(pt_outs)})"
)
@ -261,7 +261,7 @@ def _compare_onnx_pytorch_outputs(
onnx_outs: _OutputsType,
pt_outs: Any,
options: VerificationOptions,
):
) -> None:
"""
Compare ONNX and PyTorch outputs.
@ -383,7 +383,7 @@ def _compare_onnx_pytorch_model(
input_kwargs: _InputKwargsType | None,
additional_test_inputs: Sequence[_InputArgsType] | None,
options: VerificationOptions,
):
) -> None:
"""Compare outputs from ONNX model runs with outputs from PyTorch model runs.
Args:
@ -401,7 +401,7 @@ def _compare_onnx_pytorch_model(
"""
onnx_session = _onnx_backend_session(onnx_model_f, options.backend)
def compare_onnx_pytorch_model_with_input(input_args, input_kwargs):
def compare_onnx_pytorch_model_with_input(input_args, input_kwargs) -> None:
pt_args, pt_kwargs = _prepare_input_for_pytorch(input_args, input_kwargs)
# TODO: remove this and treat mutating model separately. See #77679
pt_model_copy = _try_clone_model(pt_model)
@ -443,7 +443,7 @@ def verify(
use_external_data: bool = False,
additional_test_inputs: Sequence[_InputArgsType] | None = None,
options: VerificationOptions | None = None,
):
) -> None:
"""Verify model export to ONNX against original PyTorch model.
.. deprecated:: 2.7

View File

@ -30,7 +30,7 @@ class UnsupportedOperatorError(OnnxExporterError):
# NOTE: This is legacy and is only used by the torchscript exporter
# Clean up when the torchscript exporter is removed
def __init__(self, name: str, version: int, supported_version: int | None):
def __init__(self, name: str, version: int, supported_version: int | None) -> None:
if supported_version is not None:
msg = (
f"Exporting the operator '{name}' to ONNX opset version {version} "
@ -57,7 +57,7 @@ class SymbolicValueError(OnnxExporterError):
# NOTE: This is legacy and is only used by the torchscript exporter
# Clean up when the torchscript exporter is removed
def __init__(self, msg: str, value: _C.Value):
def __init__(self, msg: str, value: _C.Value) -> None:
message = (
f"{msg} [Caused by the value '{value}' (type '{value.type()}') in the "
f"TorchScript graph. The containing node has kind '{value.node().kind()}'.] "

View File

@ -299,7 +299,7 @@ class _ConfigEntry:
hide: bool = False
alias: Optional[str] = None
def __init__(self, config: _Config):
def __init__(self, config: _Config) -> None:
self.default = config.default
self.value_type = (
config.value_type if config.value_type is not None else type(self.default)
@ -792,7 +792,7 @@ class SubConfigProxy:
`config.triton.cudagraphs` maps to _config["triton.cudagraphs"]
"""
def __init__(self, config: object, prefix: str):
def __init__(self, config: object, prefix: str) -> None:
# `super().__setattr__` to bypass custom `__setattr__`
super().__setattr__("_config", config)
super().__setattr__("_prefix", prefix)

View File

@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
import math
import operator
from typing import Union
from typing import NoReturn, Union
import sympy
@ -139,7 +139,7 @@ class ReferenceAnalysis:
return FloorDiv(a, b)
@staticmethod
def truncdiv(a, b):
def truncdiv(a, b) -> NoReturn:
raise NotImplementedError("TODO: truncdiv")
@staticmethod
@ -257,11 +257,11 @@ class PythonReferenceAnalysis(ReferenceAnalysis):
raise NotImplementedError(f"to_dtype {dtype} NYI")
@staticmethod
def exp(x):
def exp(x) -> NoReturn:
raise AssertionError("exp is not valid shape sympy expr")
@staticmethod
def log(x):
def log(x) -> NoReturn:
raise AssertionError("log is not valid shape sympy expr")
@staticmethod
@ -448,7 +448,7 @@ class TensorReferenceAnalysis:
return _to_dtype(x, dtype)
@staticmethod
def mod(x, y):
def mod(x, y) -> NoReturn:
# TODO: https://github.com/pytorch/pytorch/pull/133654
raise NotImplementedError(
"no C-style modulus operation available from frontend atm"
@ -484,7 +484,7 @@ class TensorReferenceAnalysis:
return torch.ops.aten.div.Tensor_mode(a, b, rounding_mode="floor")
@staticmethod
def truncdiv(a, b):
def truncdiv(a, b) -> NoReturn:
raise NotImplementedError(
"no C-style truncdiv operation available from frontend atm"
)
@ -575,7 +575,7 @@ class TensorReferenceAnalysis:
return torch.ops.aten.round.default(a)
@staticmethod
def round_decimal(a, b):
def round_decimal(a, b) -> NoReturn:
raise NotImplementedError(
"round decimal doesn't support Tensor second argument atm"
)

View File

@ -188,7 +188,7 @@ class Timer:
env: Optional[str] = None,
num_threads: int = 1,
language: Union[Language, str] = Language.PYTHON,
):
) -> None:
if not isinstance(stmt, str):
raise ValueError("Currently only a `str` stmt is supported.")

View File

@ -14,6 +14,7 @@ import torch.fx.traceback as fx_traceback
from torch.utils._pytree import tree_map
from torch.testing._internal.logging_tensor import capture_logs, LoggingTensorMode
from torch.utils._python_dispatch import TorchDispatchMode
from typing import NoReturn
__all__ = [
"checkpoint",
@ -107,7 +108,7 @@ class DefaultDeviceType:
_default_device_type = "cuda"
@staticmethod
def set_device_type(device: str = "cuda"):
def set_device_type(device: str = "cuda") -> None:
"""
Set the default device type for checkpointing.
@ -130,7 +131,7 @@ class DefaultDeviceType:
def _infer_device_type(*args):
device_types = []
def add_device_types(arg):
def add_device_types(arg) -> None:
nonlocal device_types
if isinstance(arg, torch.Tensor) and arg.device.type != "cpu":
device_types.append(arg.device.type)
@ -166,7 +167,7 @@ def get_device_states(*args) -> Tuple[List[int], List[torch.Tensor]]:
# the conditionals short-circuit.
fwd_device_ids = []
def add_device_ids(arg):
def add_device_ids(arg) -> None:
nonlocal fwd_device_ids
if isinstance(arg, torch.Tensor) and arg.device.type not in {"cpu", "meta"}:
fwd_device_ids.append(arg.get_device())
@ -601,7 +602,7 @@ def checkpoint_sequential(functions, segments, input, use_reentrant=None, **kwar
return run_function(end + 1, len(functions) - 1, functions)(input)
def _internal_assert(cond):
def _internal_assert(cond) -> None:
if not cond:
raise AssertionError(
"Something went unexpectedly wrong in activation checkpoint. "
@ -779,7 +780,7 @@ class _Handle:
class _Holder:
def __init__(self):
def __init__(self) -> None:
self.handles: Dict[int, Optional[_Handle]] = {}
@ -817,12 +818,12 @@ class _NoopSaveInputs(torch.autograd.Function):
ctx.save_for_backward(*tensors)
@staticmethod
def backward(ctx, *grad_outputs):
def backward(ctx, *grad_outputs) -> NoReturn:
raise AssertionError("Did not expect to backward on this graph")
class _CheckpointFrame:
def __init__(self, recompute_fn, early_stop, unpack_error_cb, metadata_fn):
def __init__(self, recompute_fn, early_stop, unpack_error_cb, metadata_fn) -> None:
self.recompute_fn = recompute_fn
self.input_saver = None
self.weak_holders: List[ReferenceType] = []
@ -847,7 +848,7 @@ class _CheckpointFrame:
self.forward_completed = False
self.ignore_saved_mismatch = False
def check_recomputed_tensors_match(self, gid):
def check_recomputed_tensors_match(self, gid) -> None:
if self.ignore_saved_mismatch:
# TODO: we can probably make this check stricter by checking that
# the metadata of the first tensors still match.
@ -999,7 +1000,7 @@ def _get_debug_context_and_cb() -> Tuple[Callable[[], Any], Callable[[Checkpoint
cpp_tb = platform.machine() == 'x86_64' and platform.system() == 'Linux'
class CaptureLogs:
def __init__(self):
def __init__(self) -> None:
self.logs = None
self.tbs = None
@ -1016,7 +1017,7 @@ def _get_debug_context_and_cb() -> Tuple[Callable[[], Any], Callable[[Checkpoint
capture_logs_fwd = CaptureLogs()
capture_logs_recompute = CaptureLogs()
def unpack_error_cb(e: CheckpointError):
def unpack_error_cb(e: CheckpointError) -> NoReturn:
def get_str_tb(label, capture_logs):
out = ""
total_len = len(capture_logs.logs)
@ -1071,7 +1072,7 @@ class _StopRecomputationError(Exception):
class _recomputation_hook(torch.autograd.graph.saved_tensors_hooks):
def __init__(self, target_frame_ref: ReferenceType, gid: int):
def __init__(self, target_frame_ref: ReferenceType, gid: int) -> None:
def pack_hook(x):
x = x.detach() if x.requires_grad else x
target_frame = target_frame_ref()
@ -1132,7 +1133,7 @@ def _run_fn_with_dynamo_disabled(fn, *args, **kwargs):
class _checkpoint_hook(torch.autograd.graph.saved_tensors_hooks):
def __init__(self, frame):
def __init__(self, frame) -> None:
def pack_hook(x):
# See Rule 4 above
holder = _Holder()
@ -1196,7 +1197,7 @@ def _is_compiling(func, args, kwargs):
class _VersionWrapper:
# Check that cached tensors are not mutated.
def __init__(self, val):
def __init__(self, val) -> None:
self.val: Union[torch.Tensor, Any] = val
self.version: Optional[int] = val._version if isinstance(val, torch.Tensor) else None
@ -1251,7 +1252,7 @@ class SelectiveCheckpointContext:
>>> context_fn=context_fn,
>>> )
"""
def __init__(self, *, is_recompute):
def __init__(self, *, is_recompute) -> None:
self.is_recompute = is_recompute
@ -1301,7 +1302,7 @@ SAC_IGNORED_OPS = {
class _CachingTorchDispatchMode(TorchDispatchMode):
# Used together with _CachedTorchDispatchMode to implement SAC.
def __init__(self, policy_fn, storage):
def __init__(self, policy_fn, storage) -> None:
self.policy_fn = policy_fn
self.storage = storage
@ -1337,7 +1338,7 @@ class _CachingTorchDispatchMode(TorchDispatchMode):
class _CachedTorchDispatchMode(TorchDispatchMode):
# Used together with _CachedTorchDispatchMode to implement SAC.
def __init__(self, policy_fn, storage, allow_cache_entry_mutation):
def __init__(self, policy_fn, storage, allow_cache_entry_mutation) -> None:
self.policy_fn = policy_fn
self.storage = storage
self.allow_cache_entry_mutation = allow_cache_entry_mutation
@ -1542,7 +1543,7 @@ def _checkpoint_without_reentrant_generator(
had_device_in_fwd = True
fwd_devices, fwd_device_states = get_device_states(*args)
def recompute_fn(*inputs):
def recompute_fn(*inputs) -> None:
kwargs, *args = inputs
# This will be called later during recomputation. This wrapping enables
# the necessary global state to be captured.

View File

@ -4,20 +4,22 @@ r"""Contains definitions of the methods used by the _BaseDataLoaderIter to fetch
This logic is shared in both single- and multi-processing data loading.
"""
from typing import NoReturn
class _BaseDatasetFetcher:
def __init__(self, dataset, auto_collation, collate_fn, drop_last):
def __init__(self, dataset, auto_collation, collate_fn, drop_last) -> None:
self.dataset = dataset
self.auto_collation = auto_collation
self.collate_fn = collate_fn
self.drop_last = drop_last
def fetch(self, possibly_batched_index):
def fetch(self, possibly_batched_index) -> NoReturn:
raise NotImplementedError
class _IterableDatasetFetcher(_BaseDatasetFetcher):
def __init__(self, dataset, auto_collation, collate_fn, drop_last):
def __init__(self, dataset, auto_collation, collate_fn, drop_last) -> None:
super().__init__(dataset, auto_collation, collate_fn, drop_last)
self.dataset_iter = iter(dataset)
self.ended = False

View File

@ -17,7 +17,7 @@ import queue
import threading
import warnings
from collections.abc import Callable
from typing import Any, Generic, Optional, TYPE_CHECKING, TypeVar, Union
from typing import Any, Generic, NoReturn, Optional, TYPE_CHECKING, TypeVar, Union
from typing_extensions import Self
import torch
@ -108,7 +108,7 @@ def _get_distributed_settings():
return 1, 0
def _sharding_worker_init_fn(worker_init_fn, world_size, rank_id, worker_id):
def _sharding_worker_init_fn(worker_init_fn, world_size, rank_id, worker_id) -> None:
global_worker_id = worker_id
info = torch.utils.data.get_worker_info()
if info is None:
@ -436,7 +436,7 @@ class DataLoader(Generic[_T_co]):
return self.__multiprocessing_context
@multiprocessing_context.setter
def multiprocessing_context(self, multiprocessing_context):
def multiprocessing_context(self, multiprocessing_context) -> None:
if multiprocessing_context is not None:
if self.num_workers > 0:
if isinstance(multiprocessing_context, str):
@ -468,7 +468,7 @@ class DataLoader(Generic[_T_co]):
self.__multiprocessing_context = multiprocessing_context
def __setattr__(self, attr, val):
def __setattr__(self, attr, val) -> None:
if self.__initialized and attr in (
"batch_size",
"batch_sampler",
@ -546,7 +546,7 @@ class DataLoader(Generic[_T_co]):
else:
return len(self._index_sampler)
def check_worker_number_rationality(self):
def check_worker_number_rationality(self) -> None:
# This function check whether the dataloader's worker number is rational based on
# current system's resource. Current rule is that if the number of workers this
# Dataloader will create is bigger than the number of logical cpus that is allowed to
@ -714,7 +714,7 @@ class _BaseDataLoaderIter:
def __iter__(self) -> Self:
return self
def _reset(self, loader, first_iter=False):
def _reset(self, loader, first_iter=False) -> None:
self._sampler_iter = iter(self._index_sampler)
self._num_yielded = 0
self._IterableDataset_len_called = loader._IterableDataset_len_called
@ -729,7 +729,7 @@ class _BaseDataLoaderIter:
def _next_index(self):
return next(self._sampler_iter) # may raise StopIteration
def _next_data(self):
def _next_data(self) -> NoReturn:
raise NotImplementedError
def __next__(self) -> Any:
@ -770,7 +770,7 @@ class _BaseDataLoaderIter:
class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
def __init__(self, loader):
def __init__(self, loader) -> None:
super().__init__(loader)
if self._timeout != 0:
raise AssertionError("_SingleProcessDataLoaderIter requires timeout == 0")
@ -1113,7 +1113,7 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
# processing indices already in `index_queue` if we are already shutting
# down.
def __init__(self, loader):
def __init__(self, loader) -> None:
super().__init__(loader)
self._prefetch_factor = loader.prefetch_factor
@ -1235,7 +1235,7 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
self._worker_pids_set = True
self._reset(loader, first_iter=True)
def _reset(self, loader, first_iter=False):
def _reset(self, loader, first_iter=False) -> None:
super()._reset(loader, first_iter)
self._send_idx = 0 # idx of the next task to be sent to workers
self._rcvd_idx = 0 # idx of the next task to be returned in __next__
@ -1529,7 +1529,7 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
self._rcvd_idx += 1
return self._process_data(data, worker_id)
def _try_put_index(self):
def _try_put_index(self) -> None:
max_tasks = self._prefetch_factor * self._num_workers
if self._tasks_outstanding >= max_tasks:
raise AssertionError(
@ -1568,7 +1568,7 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
data.reraise()
return data
def _mark_worker_as_unavailable(self, worker_id, shutdown=False):
def _mark_worker_as_unavailable(self, worker_id, shutdown=False) -> None:
# Mark a worker as having finished its work e.g., due to
# exhausting an `IterableDataset`. This should be used only when this
# `_MultiProcessingDataLoaderIter` is going to continue running.
@ -1604,7 +1604,7 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
"_workers_done_event state does not match shutdown flag"
)
def _shutdown_workers(self):
def _shutdown_workers(self) -> None:
# Called when shutting down this `_MultiProcessingDataLoaderIter`.
# See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on
# the logic of this function.
@ -1678,12 +1678,12 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
# staticmethod is used to remove reference to `_MultiProcessingDataLoaderIter`
@staticmethod
def _clean_up_worker(w):
def _clean_up_worker(w) -> None:
try:
w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
finally:
if w.is_alive():
w.terminate()
def __del__(self):
def __del__(self) -> None:
self._shutdown_workers()

View File

@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
from typing import Any, Optional
from typing import Any, NoReturn, Optional
from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data.datapipes.dataframe.structures import DataChunkDF
@ -33,7 +33,7 @@ __all__ = [
]
def disable_capture():
def disable_capture() -> None:
CaptureControl.disabled = True
@ -42,7 +42,7 @@ class CaptureControl:
class DataFrameTracedOps(DFIterDataPipe):
def __init__(self, source_datapipe, output_var):
def __init__(self, source_datapipe, output_var) -> None:
self.source_datapipe = source_datapipe
self.output_var = output_var
@ -72,10 +72,10 @@ UNIMPLEMENTED_ATTR = ["__deepcopy__", "__setstate__", "is_shardable", "apply_sha
class Capture:
# TODO: All operations are shared across entire InitialCapture, need to figure out what if we join two captures
def __init__(self, schema_df=None):
def __init__(self, schema_df=None) -> None:
self.ctx = {"operations": [], "variables": [], "schema_df": schema_df}
def __str__(self):
def __str__(self) -> str:
return self._ops_str()
def _ops_str(self):
@ -113,7 +113,7 @@ class Capture:
def __getitem__(self, key):
return CaptureGetItem(self, key, ctx=self.ctx)
def __setitem__(self, key, value):
def __setitem__(self, key, value) -> None:
# pyrefly: ignore [missing-attribute]
self.ctx["operations"].append(CaptureSetItem(self, key, value, ctx=self.ctx))
@ -147,7 +147,7 @@ class Capture:
# pyrefly: ignore [bad-argument-type]
return len(self.ctx["operations"]) == 0 and len(self.ctx["variables"]) == 0
def apply_ops_2(self, dataframe):
def apply_ops_2(self, dataframe) -> None:
# TODO(VitalyFedyunin): Make this calculation thread safe (as currently it updates pointer)
# pyrefly: ignore [unsupported-operation]
self.ctx["variables"][0].calculated_value = dataframe
@ -190,7 +190,7 @@ class Capture:
class CaptureF(Capture):
def __init__(self, ctx=None, **kwargs):
def __init__(self, ctx=None, **kwargs) -> None:
if ctx is None:
self.ctx = {"operations": [], "variables": []}
else:
@ -199,7 +199,7 @@ class CaptureF(Capture):
class CaptureA(CaptureF):
def __str__(self):
def __str__(self) -> str:
return f"{self.kwargs['name']}"
def execute(self):
@ -208,7 +208,7 @@ class CaptureA(CaptureF):
class CaptureLikeMock:
def __init__(self, name):
def __init__(self, name) -> None:
import unittest.mock as mock
# TODO(VitalyFedyunin): Do not use private function here, copy own implementation instead.
@ -227,7 +227,7 @@ class CaptureLikeMock:
class CaptureCall(Capture):
def __init__(self, callable, ctx=None, **kwargs):
def __init__(self, callable, ctx=None, **kwargs) -> None:
if ctx is None:
self.ctx = {"operations": [], "variables": []}
else:
@ -235,7 +235,7 @@ class CaptureCall(Capture):
self.kwargs = kwargs
self.callable = callable
def __str__(self):
def __str__(self) -> str:
return "{callable}({args},{kwargs})".format(
callable=self.callable, **self.kwargs
)
@ -253,12 +253,12 @@ class CaptureCall(Capture):
class CaptureVariableAssign(CaptureF):
def __str__(self):
def __str__(self) -> str:
variable = self.kwargs["variable"]
value = self.kwargs["value"]
return f"{variable} = {value}"
def execute(self):
def execute(self) -> None:
self.kwargs["variable"].calculated_value = self.kwargs["value"].execute()
@ -266,7 +266,7 @@ class CaptureVariable(Capture):
# TODO(VitalyFedyunin): This should be atomic and thread safe
names_idx = 0
def __init__(self, value, ctx):
def __init__(self, value, ctx) -> None:
if CaptureControl.disabled:
raise RuntimeError("Attempting to create capture variable with capture off")
self.ctx = ctx
@ -275,7 +275,7 @@ class CaptureVariable(Capture):
CaptureVariable.names_idx += 1
self.ctx["variables"].append(self)
def __str__(self):
def __str__(self) -> str:
return self.name
def execute(self):
@ -292,12 +292,12 @@ class CaptureVariable(Capture):
class CaptureGetItem(Capture):
def __init__(self, left, key, ctx):
def __init__(self, left, key, ctx) -> None:
self.ctx = ctx
self.left = left
self.key = key
def __str__(self):
def __str__(self) -> str:
return f"{self.left}[{get_val(self.key)}]"
def execute(self):
@ -306,28 +306,28 @@ class CaptureGetItem(Capture):
class CaptureSetItem(Capture):
def __init__(self, left, key, value, ctx):
def __init__(self, left, key, value, ctx) -> None:
self.ctx = ctx
self.left = left
self.key = key
self.value = value
def __str__(self):
def __str__(self) -> str:
return f"{self.left}[{get_val(self.key)}] = {self.value}"
def execute(self):
def execute(self) -> None:
left = self.left.execute()
value = self.value.execute()
left[self.key] = value
class CaptureAdd(Capture):
def __init__(self, left, right, ctx):
def __init__(self, left, right, ctx) -> None:
self.ctx = ctx
self.left = left
self.right = right
def __str__(self):
def __str__(self) -> str:
return f"{self.left} + {self.right}"
def execute(self):
@ -335,12 +335,12 @@ class CaptureAdd(Capture):
class CaptureMul(Capture):
def __init__(self, left, right, ctx):
def __init__(self, left, right, ctx) -> None:
self.ctx = ctx
self.left = left
self.right = right
def __str__(self):
def __str__(self) -> str:
return f"{self.left} * {self.right}"
def execute(self):
@ -348,12 +348,12 @@ class CaptureMul(Capture):
class CaptureSub(Capture):
def __init__(self, left, right, ctx):
def __init__(self, left, right, ctx) -> None:
self.ctx = ctx
self.left = left
self.right = right
def __str__(self):
def __str__(self) -> str:
return f"{self.left} - {self.right}"
def execute(self):
@ -361,12 +361,12 @@ class CaptureSub(Capture):
class CaptureGetAttr(Capture):
def __init__(self, src, name, ctx):
def __init__(self, src, name, ctx) -> None:
self.ctx = ctx
self.src = src
self.name = name
def __str__(self):
def __str__(self) -> str:
return f"{self.src}.{self.name}"
def execute(self):
@ -384,7 +384,7 @@ def get_val(capture):
class CaptureInitial(CaptureVariable):
def __init__(self, schema_df=None):
def __init__(self, schema_df=None) -> None:
# pyrefly: ignore [bad-assignment]
new_ctx: dict[str, list[Any]] = {
"operations": [],
@ -441,7 +441,7 @@ class CaptureDataFrameWithDataPipeOps(CaptureDataFrame):
def filter(self, *args, **kwargs):
return self._dataframes_filter(*args, **kwargs)
def collate(self, *args, **kwargs):
def collate(self, *args, **kwargs) -> NoReturn:
raise RuntimeError("Can't collate unbatched DataFrames stream")
def __getattr__(self, attrname): # ?
@ -458,13 +458,13 @@ class DataFrameTracer(CaptureDataFrameWithDataPipeOps, IterDataPipe): # type: i
# TODO(VitalyFedyunin): Must implement all special functions of datapipes
def set_shuffle_settings(self, *args, **kwargs):
def set_shuffle_settings(self, *args, **kwargs) -> None:
pass
def is_shardable(self):
def is_shardable(self) -> bool:
return False
def __init__(self, source_datapipe, schema_df=None):
def __init__(self, source_datapipe, schema_df=None) -> None:
self.source_datapipe = source_datapipe
if schema_df is None:
schema_df = next(iter(self.source_datapipe))

View File

@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
from collections import defaultdict
from collections.abc import Callable, Iterator, Sized
from typing import Any, Optional, TypeVar
from typing import Any, NoReturn, Optional, TypeVar
from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data.datapipes.datapipe import DataChunk, IterDataPipe
@ -18,7 +18,7 @@ __all__ = [
_T_co = TypeVar("_T_co", covariant=True)
def __getattr__(name: str):
def __getattr__(name: str) -> NoReturn:
raise AttributeError(f"module {__name__} has no attribute {name}")
@ -110,7 +110,7 @@ class UnBatcherIterDataPipe(IterDataPipe):
[0, 1, 2, 3, 4, 5, 6]
"""
def __init__(self, datapipe: IterDataPipe, unbatch_level: int = 1):
def __init__(self, datapipe: IterDataPipe, unbatch_level: int = 1) -> None:
self.datapipe = datapipe
self.unbatch_level = unbatch_level
@ -202,7 +202,7 @@ class GrouperIterDataPipe(IterDataPipe[DataChunk]):
group_size: Optional[int] = None,
guaranteed_group_size: Optional[int] = None,
drop_remaining: bool = False,
):
) -> None:
_check_unpickable_fn(group_key_fn)
# pyrefly: ignore [invalid-type-var]
self.datapipe = datapipe
@ -322,5 +322,5 @@ class GrouperIterDataPipe(IterDataPipe[DataChunk]):
self.curr_buffer_size = 0
self.buffer_elements = defaultdict(list)
def __del__(self):
def __del__(self) -> None:
self.buffer_elements.clear()

View File

@ -1,6 +1,7 @@
# mypy: allow-untyped-defs
from collections.abc import Sized
from enum import IntEnum
from typing import NoReturn
from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data.datapipes.datapipe import IterDataPipe
@ -24,7 +25,7 @@ class _ShardingIterDataPipe(IterDataPipe):
num_of_instances: int,
instance_id: int,
sharding_group: SHARDING_PRIORITIES,
):
) -> NoReturn:
raise NotImplementedError
@ -40,7 +41,9 @@ class ShardingFilterIterDataPipe(_ShardingIterDataPipe):
source_datapipe: Iterable DataPipe that will be sharded
"""
def __init__(self, source_datapipe: IterDataPipe, sharding_group_filter=None):
def __init__(
self, source_datapipe: IterDataPipe, sharding_group_filter=None
) -> None:
self.source_datapipe = source_datapipe
self.sharding_group_filter = sharding_group_filter
self.groups: dict[int, tuple[int, int]] = {}
@ -68,7 +71,7 @@ class ShardingFilterIterDataPipe(_ShardingIterDataPipe):
self.groups[sharding_group] = (num_of_instances, instance_id)
self._update_num_of_instances()
def _update_num_of_instances(self):
def _update_num_of_instances(self) -> None:
sorted_sharding_groups = [
self.groups[key]
for key in sorted(self.groups.keys())
@ -89,7 +92,7 @@ class ShardingFilterIterDataPipe(_ShardingIterDataPipe):
if i % self.num_of_instances == self.instance_id:
yield item
def __len__(self):
def __len__(self) -> int:
if isinstance(self.source_datapipe, Sized):
return len(self.source_datapipe) // self.num_of_instances + (
1

View File

@ -6,7 +6,7 @@ import os
import warnings
from collections.abc import Callable, Iterable
from io import IOBase
from typing import Any, Optional, Union
from typing import Any, NoReturn, Optional, Union
from torch.utils._import_utils import dill_available
@ -25,7 +25,9 @@ __all__ = [
DILL_AVAILABLE = dill_available()
def validate_input_col(fn: Callable, input_col: Optional[Union[int, tuple, list]]):
def validate_input_col(
fn: Callable, input_col: Optional[Union[int, tuple, list]]
) -> None:
"""
Check that function used in a callable datapipe works with the input column.
@ -131,7 +133,7 @@ def _is_local_fn(fn):
return False
def _check_unpickable_fn(fn: Callable):
def _check_unpickable_fn(fn: Callable) -> None:
"""
Check function is pickable or not.
@ -186,7 +188,7 @@ def get_file_pathnames_from_root(
non_deterministic: bool = False,
) -> Iterable[str]:
# print out an error message and raise the error out
def onerror(err: OSError):
def onerror(err: OSError) -> NoReturn:
warnings.warn(err.filename + " : " + err.strerror, stacklevel=2)
raise err
@ -235,7 +237,7 @@ def get_file_binaries_from_pathnames(
yield pathname, StreamWrapper(open(pathname, mode, encoding=encoding))
def validate_pathname_binary_tuple(data: tuple[str, IOBase]):
def validate_pathname_binary_tuple(data: tuple[str, IOBase]) -> None:
if not isinstance(data, tuple):
raise TypeError(
f"pathname binary data should be tuple type, but it is type {type(data)}"
@ -326,7 +328,7 @@ class StreamWrapper:
session_streams: dict[Any, int] = {}
debug_unclosed_streams: bool = False
def __init__(self, file_obj, parent_stream=None, name=None):
def __init__(self, file_obj, parent_stream=None, name=None) -> None:
self.file_obj = file_obj
self.child_counter = 0
self.parent_stream = parent_stream
@ -344,7 +346,7 @@ class StreamWrapper:
StreamWrapper.session_streams[self] = 1
@classmethod
def close_streams(cls, v, depth=0):
def close_streams(cls, v, depth=0) -> None:
"""Traverse structure and attempts to close all found StreamWrappers on best effort basis."""
if depth > 10:
return
@ -363,7 +365,7 @@ class StreamWrapper:
file_obj = self.__dict__["file_obj"]
return getattr(file_obj, name)
def close(self, *args, **kwargs):
def close(self, *args, **kwargs) -> None:
if self.closed:
return
if StreamWrapper.debug_unclosed_streams:
@ -381,7 +383,7 @@ class StreamWrapper:
pass
self.closed = True
def autoclose(self):
def autoclose(self) -> None:
"""Automatically close stream when all child streams are closed or if there are none."""
self.close_on_last_child = True
if self.child_counter == 0:
@ -392,7 +394,7 @@ class StreamWrapper:
attrs += dir(self.file_obj)
return list(set(attrs))
def __del__(self):
def __del__(self) -> None:
if not self.closed:
self.close()
@ -402,7 +404,7 @@ class StreamWrapper:
def __next__(self):
return next(self.file_obj)
def __repr__(self):
def __repr__(self) -> str:
if self.name is None:
return f"StreamWrapper<{self.file_obj!r}>"
else: