Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-13 15:45:15 +08:00)

Compare commits: 5 commits, ciflow/mps ..., author lucaskabel

Commits:
- 19e52556fa
- 1d43f171d6
- 910471526d
- edd611f3b0
- aded2ebb90
@@ -207,9 +207,9 @@ case "$tag" in
     NINJA_VERSION=1.9.0
     TRITON=yes
     ;;
-  pytorch-linux-jammy-xpu-n-py3 | pytorch-linux-jammy-xpu-n-py3-inductor-benchmarks)
+  pytorch-linux-noble-xpu-n-py3 | pytorch-linux-noble-xpu-n-py3-inductor-benchmarks)
     ANACONDA_PYTHON_VERSION=3.10
-    GCC_VERSION=11
+    GCC_VERSION=13
     VISION=yes
     XPU_VERSION=2025.2
     NINJA_VERSION=1.9.0
@@ -9,7 +9,7 @@ set -xe

 function install_ubuntu() {
     . /etc/os-release
-    if [[ ! " jammy " =~ " ${VERSION_CODENAME} " ]]; then
+    if [[ ! " jammy noble " =~ " ${VERSION_CODENAME} " ]]; then
         echo "Ubuntu version ${VERSION_CODENAME} not supported"
         exit
     fi

@@ -35,25 +35,24 @@ function install_ubuntu() {
     # The xpu-smi packages
     apt-get install -y flex bison xpu-smi

-    if [[ "${XPU_DRIVER_TYPE,,}" == "lts" ]]; then
-        # Compute and Media Runtimes
+    # Compute and Media Runtimes
+    if [[ " ${VERSION_CODENAME} " =~ " noble " ]]; then
         apt-get install -y \
-            intel-opencl-icd intel-level-zero-gpu level-zero \
-            intel-media-va-driver-non-free libmfx1 libmfxgen1 libvpl2 \
-            libegl-mesa0 libegl1-mesa libegl1-mesa-dev libgbm1 libgl1-mesa-dev libgl1-mesa-dri \
+            intel-opencl-icd libze-intel-gpu1 libze1 \
+            intel-media-va-driver-non-free libmfx-gen1 libvpl2 \
+            libegl-mesa0 libegl1-mesa-dev libgbm1 libgl1-mesa-dev libgl1-mesa-dri \
             libglapi-mesa libgles2-mesa-dev libglx-mesa0 libigdgmm12 libxatracker2 mesa-va-drivers \
-            mesa-vdpau-drivers mesa-vulkan-drivers va-driver-all vainfo hwinfo clinfo
-        # Development Packages
-        apt-get install -y libigc-dev intel-igc-cm libigdfcl-dev libigfxcmrt-dev level-zero-dev
-    else # rolling driver
+            mesa-vdpau-drivers mesa-vulkan-drivers va-driver-all vainfo hwinfo clinfo intel-ocloc
+    else # jammy
         apt-get install -y \
             intel-opencl-icd libze-intel-gpu1 libze1 \
             intel-media-va-driver-non-free libmfx-gen1 libvpl2 \
             libegl-mesa0 libegl1-mesa libegl1-mesa-dev libgbm1 libgl1-mesa-dev libgl1-mesa-dri \
             libglapi-mesa libglx-mesa0 libigdgmm12 libxatracker2 mesa-va-drivers \
             mesa-vdpau-drivers mesa-vulkan-drivers va-driver-all vainfo hwinfo clinfo intel-ocloc
-        apt-get install -y libigc-dev intel-igc-cm libigdfcl-dev libigfxcmrt-dev libze-dev
     fi
+    # Development Packages
+    apt-get install -y libigc-dev intel-igc-cm libigdfcl-dev libigfxcmrt-dev libze-dev

     # Install Intel Support Packages
     apt-get install -y ${XPU_PACKAGES}

@@ -66,7 +65,7 @@ function install_ubuntu() {
 function install_rhel() {
     . /etc/os-release
     if [[ "${ID}" == "rhel" ]]; then
-        if [[ ! " 8.8 8.9 9.0 9.2 9.3 " =~ " ${VERSION_ID} " ]]; then
+        if [[ ! " 8.8 8.10 9.0 9.2 9.3 " =~ " ${VERSION_ID} " ]]; then
            echo "RHEL version ${VERSION_ID} not supported"
            exit
        fi

@@ -147,7 +146,7 @@ function install_sles() {
     XPU_DRIVER_VERSION=""
     if [[ "${XPU_DRIVER_TYPE,,}" == "lts" ]]; then
         # Use GPU driver LTS releases
-        XPU_DRIVER_VERSION="/lts/2350"
+        XPU_DRIVER_VERSION="/lts/2523"
     fi

     # Default use Intel® oneAPI Deep Learning Essentials 2025.1
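The Ubuntu guard added in the first hunk relies on the padded-spaces membership idiom: both the list and the candidate codename are wrapped in spaces so only whole words match. The same check, restated in Python purely for illustration:

```python
SUPPORTED = " jammy noble "

def is_supported(codename: str) -> bool:
    # Wrapping both sides in spaces prevents partial matches such as "nob" or "noble-x".
    return f" {codename} " in SUPPORTED

assert is_supported("jammy") and is_supported("noble")
assert not is_supported("nob")
assert not is_supported("focal")
```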
.github/workflows/docker-builds.yml (vendored, 4 lines changed)

@@ -68,8 +68,8 @@ jobs:
           pytorch-linux-jammy-py3-gcc11-inductor-benchmarks,
           pytorch-linux-jammy-py3.12-halide,
           pytorch-linux-jammy-xpu-n-1-py3,
-          pytorch-linux-jammy-xpu-n-py3,
-          pytorch-linux-jammy-xpu-n-py3-inductor-benchmarks,
+          pytorch-linux-noble-xpu-n-py3,
+          pytorch-linux-noble-xpu-n-py3-inductor-benchmarks,
           pytorch-linux-jammy-py3-clang18-asan,
           pytorch-linux-jammy-py3-clang12-onnx,
           pytorch-linux-jammy-linter,

@@ -83,8 +83,8 @@ jobs:
     needs: get-label-type
     with:
       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
-      build-environment: linux-jammy-xpu-n-py3.10
-      docker-image-name: ci-image:pytorch-linux-jammy-xpu-n-py3-inductor-benchmarks
+      build-environment: linux-noble-xpu-n-py3.10
+      docker-image-name: ci-image:pytorch-linux-noble-xpu-n-py3-inductor-benchmarks
       runner: linux.c7i.12xlarge
       test-matrix: |
         { include: [

@@ -117,7 +117,7 @@ jobs:
     uses: ./.github/workflows/_xpu-test.yml
     needs: xpu-n-py3_10-inductor-benchmark-build
     with:
-      build-environment: linux-jammy-xpu-n-py3.10
+      build-environment: linux-noble-xpu-n-py3.10
       dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-false-cppwrapper-true-aotinductor-true-freezing_cudagraphs-false-cudagraphs_low_precision-false
       docker-image: ${{ needs.xpu-n-py3_10-inductor-benchmark-build.outputs.docker-image }}
       test-matrix: ${{ needs.xpu-n-py3_10-inductor-benchmark-build.outputs.test-matrix }}

@@ -137,7 +137,7 @@ jobs:
     uses: ./.github/workflows/_xpu-test.yml
     needs: xpu-n-py3_10-inductor-benchmark-build
     with:
-      build-environment: linux-jammy-xpu-n-py3.10
+      build-environment: linux-noble-xpu-n-py3.10
       dashboard-tag: training-${{ inputs.training }}-inference-${{ inputs.inference }}-default-${{ inputs.default }}-dynamic-${{ inputs.dynamic }}-cudagraphs-${{ inputs.cudagraphs }}-cppwrapper-${{ inputs.cppwrapper }}-aotinductor-${{ inputs.aotinductor }}-maxautotune-${{ inputs.maxautotune }}-freezing_cudagraphs-${{ inputs.freezing_cudagraphs }}-cudagraphs_low_precision-${{ inputs.cudagraphs }}
       docker-image: ${{ needs.xpu-n-py3_10-inductor-benchmark-build.outputs.docker-image }}
       test-matrix: ${{ needs.xpu-n-py3_10-inductor-benchmark-build.outputs.test-matrix }}
.github/workflows/pull.yml (vendored, 8 lines changed)

@@ -342,16 +342,16 @@ jobs:
       test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc9-inductor-build.outputs.test-matrix }}
     secrets: inherit

-  linux-jammy-xpu-n-py3_10-build:
-    name: linux-jammy-xpu-n-py3.10
+  linux-noble-xpu-n-py3_10-build:
+    name: linux-noble-xpu-n-py3.10
     uses: ./.github/workflows/_linux-build.yml
     needs: get-label-type
     with:
       # This should sync with the build in xpu.yml but xpu uses a larger runner
       # sync-tag: linux-xpu-n-build
       runner_prefix: ${{ needs.get-label-type.outputs.label-type }}
-      build-environment: linux-jammy-xpu-n-py3.10
-      docker-image-name: ci-image:pytorch-linux-jammy-xpu-n-py3
+      build-environment: linux-noble-xpu-n-py3.10
+      docker-image-name: ci-image:pytorch-linux-noble-xpu-n-py3
       test-matrix: |
         { include: [
           { config: "default", shard: 1, num_shards: 4, runner: "linux.idc.xpu" },
.github/workflows/xpu.yml (vendored, 20 lines changed)

@@ -47,15 +47,15 @@ jobs:
        ]}
     secrets: inherit

-  linux-jammy-xpu-n-py3_10-build:
-    name: linux-jammy-xpu-n-py3.10
+  linux-noble-xpu-n-py3_10-build:
+    name: linux-noble-xpu-n-py3.10
     uses: ./.github/workflows/_linux-build.yml
     needs: get-label-type
     with:
       sync-tag: linux-xpu-n-build
       runner_prefix: ${{ needs.get-label-type.outputs.label-type }}
-      build-environment: linux-jammy-xpu-n-py3.10
-      docker-image-name: ci-image:pytorch-linux-jammy-xpu-n-py3
+      build-environment: linux-noble-xpu-n-py3.10
+      docker-image-name: ci-image:pytorch-linux-noble-xpu-n-py3
       runner: linux.c7i.12xlarge
       test-matrix: |
         { include: [

@@ -74,17 +74,17 @@ jobs:
        ]}
     secrets: inherit

-  linux-jammy-xpu-n-py3_10-test:
-    name: linux-jammy-xpu-n-py3.10
+  linux-noble-xpu-n-py3_10-test:
+    name: linux-noble-xpu-n-py3.10
     uses: ./.github/workflows/_xpu-test.yml
-    needs: linux-jammy-xpu-n-py3_10-build
+    needs: linux-noble-xpu-n-py3_10-build
     permissions:
       id-token: write
       contents: read
     with:
-      build-environment: linux-jammy-xpu-n-py3.10
-      docker-image: ${{ needs.linux-jammy-xpu-n-py3_10-build.outputs.docker-image }}
-      test-matrix: ${{ needs.linux-jammy-xpu-n-py3_10-build.outputs.test-matrix }}
+      build-environment: linux-noble-xpu-n-py3.10
+      docker-image: ${{ needs.linux-noble-xpu-n-py3_10-build.outputs.docker-image }}
+      test-matrix: ${{ needs.linux-noble-xpu-n-py3_10-build.outputs.test-matrix }}
     secrets: inherit

   windows-xpu-n-1-build:
@@ -133,7 +133,7 @@ at::Tensor quantized_convolution(
   // supported in conv.
   mask_weight = weight_zero_points.numel() > 1 ? 1 : 0;
   if (groups > 1 && weight_zero_points.numel() > 1)
-    mask_weight = (2 ^ 0) | (2 ^ 1); // 2^0 (group) | 2^1 (output channel)
+    mask_weight = (1 << 0) | (1 << 1); // 2^0 (group) | 2^1 (output channel)
   dnnl::primitive_attr pattr;

   bool src_need_zp = (act_zero_point != 0);
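A note on the one-line fix above: in C++ `^` is bitwise XOR, not exponentiation, so `(2 ^ 0) | (2 ^ 1)` only produced the intended value 3 by coincidence. The sketch below (in Python, where these operators behave the same way) shows why the shift form expresses the per-group and per-output-channel mask bits directly and would stay correct if more bits were ever needed:

```python
# '^' is XOR in both C++ and Python; it is not a power operator.
old = (2 ^ 0) | (2 ^ 1)    # (2 XOR 0) = 2, (2 XOR 1) = 3, so 2 | 3 == 3 (right value by accident)
new = (1 << 0) | (1 << 1)  # bit 0 (group) | bit 1 (output channel) == 3 (right value by design)
assert old == new == 3

# The coincidence does not extend to a hypothetical third mask bit:
assert (2 ^ 2) == 0        # XOR cancels the bit out
assert (1 << 2) == 4       # the shift yields the intended bit
```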
@@ -1941,6 +1941,7 @@ if(BUILD_TEST)
     foreach(test_src ${Caffe2_XPU_TEST_SRCS})
       get_filename_component(test_name ${test_src} NAME_WE)
       add_executable(${test_name} "${test_src}")
+      torch_compile_options(${test_name})
       target_link_libraries(${test_name} torch_library gtest_main)
       target_include_directories(${test_name} PRIVATE $<INSTALL_INTERFACE:include>)
       target_include_directories(${test_name} PRIVATE ${Caffe2_CPU_INCLUDE})
@@ -1991,7 +1991,7 @@ class BuiltinVariable(VariableTracker):
             # If the object implements a __getitem__ method, iter(...) will call obj.__getitem__()
             # with an integer argument starting at 0, until __getitem__ raises IndexError
             ret = variables.UserFunctionVariable(
-                polyfills.builtins.iter_
+                polyfills.builtins.iter_  # type: ignore[arg-type]
             ).call_function(tx, [obj, *args], {})

             if args:
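The comment in this hunk refers to Python's legacy sequence protocol: when a class defines `__getitem__` but not `__iter__`, `iter()` probes indices 0, 1, 2, ... until `IndexError` is raised. A minimal, self-contained illustration (the class is hypothetical, not part of the diff):

```python
class Squares:
    # No __iter__ defined: iter() falls back to the __getitem__ protocol.
    def __getitem__(self, i):
        if i >= 4:
            raise IndexError  # ends the iteration
        return i * i

assert list(iter(Squares())) == [0, 1, 4, 9]
```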
(One file diff is suppressed because it is too large.)
@@ -590,7 +590,7 @@ class FilterVariable(IteratorVariable):
            else:
                res = self.fn.call_function(tx, [item], {})
                pred_res = variables.UserFunctionVariable(
-                    polyfills.predicate
+                    polyfills.predicate  # type: ignore[arg-type]
                ).call_function(tx, [res], {})
            if pred_res.as_python_constant():
                return item
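For reference on what the `predicate` polyfill has to reproduce: built-in `filter` falls back to plain truthiness when its function argument is `None`. A quick refresher, not taken from the diff:

```python
items = [0, 1, "", "a", None, []]
# filter(None, ...) keeps only truthy values; an explicit bool predicate is equivalent.
assert list(filter(None, items)) == [1, "a"]
assert list(filter(bool, items)) == [1, "a"]
```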
@@ -1498,6 +1498,7 @@ class NamedTupleVariable(TupleVariable):
                variables.UserDefinedClassVariable(self.tuple_cls),
            )
        elif isinstance(method, staticmethod):
+            # pyrefly: ignore[bad-argument-type]
            return UserFunctionVariable(method.__func__)
        elif inspect.isfunction(method):
            return UserMethodVariable(method, self)
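The branch above tells `staticmethod` descriptors (unwrapped through `.__func__`) apart from plain functions found on a class. A standalone sketch of those two lookups, using a made-up class:

```python
import inspect

class Point:
    @staticmethod
    def origin():
        return (0, 0)

    def norm(self):
        return 0.0

# Reading from __dict__ keeps the raw descriptor objects.
raw_static = Point.__dict__["origin"]
assert isinstance(raw_static, staticmethod)
assert raw_static.__func__() == (0, 0)   # unwrap to the underlying function

raw_func = Point.__dict__["norm"]
assert inspect.isfunction(raw_func)      # a plain function until it is bound
```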
@@ -472,7 +472,12 @@ class TorchCtxManagerClassVariable(BaseTorchVariable):
            )
        elif self.value is torch.nn.attention.sdpa_kernel.__wrapped__:  # type: ignore[attr-defined]
            name_to_arg_map = bind_args_cached(
-                self.value, tx, self.source, args, kwargs
+                # pyrefly: ignore[bad-argument-type]
+                self.value,
+                tx,
+                self.source,
+                args,
+                kwargs,
            )
            backends = name_to_arg_map["backends"].as_python_constant()
            set_priority = name_to_arg_map["set_priority"].as_python_constant()

@@ -1429,7 +1434,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
            packed_input_vt = TupleVariable.build(
                tx, (TupleVariable.build(tx, args), ConstDictVariable.build(tx, kwargs))
            )
-            out_vt = variables.UserFunctionVariable(tree_flatten).call_function(
+            out_vt = variables.UserFunctionVariable(tree_flatten).call_function(  # type: ignore[arg-type]
                tx, [packed_input_vt], {}
            )
            assert isinstance(out_vt, TupleVariable) and len(out_vt.items) == 2
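The second hunk traces `tree_flatten` over a packed `(args, kwargs)` tuple and expects the usual `(leaves, spec)` pair back. Outside Dynamo, the equivalent call looks like the sketch below (the input values are made up; `torch.utils._pytree` is the helper being wrapped here):

```python
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

packed = ((torch.ones(2), 3), {"scale": 0.5})
leaves, spec = tree_flatten(packed)      # leaves == [tensor([1., 1.]), 3, 0.5]
assert len(leaves) == 3
rebuilt = tree_unflatten(leaves, spec)   # round-trips back to the original structure
assert rebuilt[1] == {"scale": 0.5}
```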
@@ -279,7 +279,7 @@ def _hide_source_ranges() -> Iterator[None]:
        torch._C.Graph.set_global_print_source_ranges(old_enable_source_ranges)  # type: ignore[attr-defined]


-def enable_onednn_fusion(enabled: bool):
+def enable_onednn_fusion(enabled: bool) -> None:
    """Enable or disables onednn JIT fusion based on the parameter `enabled`."""
    torch._C._jit_set_llga_enabled(enabled)
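This hunk opens a long run of mechanical typing changes in the remaining files (torch/jit, torch/nested, torch/onnx, torch/utils): procedures that implicitly return `None` gain explicit `-> None`, `-> bool`, `-> str`, or `-> NoReturn` annotations so type checkers no longer infer `Any` for their results. A small sketch of the effect, with illustrative names that are not from the diff:

```python
from typing import NoReturn

def reset_counter(counter: dict) -> None:
    # Explicit "-> None": using the result now gets flagged by a type checker.
    counter.clear()

def unsupported(op: str) -> NoReturn:
    # "NoReturn" tells the checker this function never returns normally.
    raise NotImplementedError(f"{op} is not supported")

state = {"calls": 3}
reset_counter(state)
assert state == {}
```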
@ -162,7 +162,7 @@ def _get_builtin_table():
|
||||
return _builtin_table
|
||||
_builtin_table = {}
|
||||
|
||||
def register_all(mod):
|
||||
def register_all(mod) -> None:
|
||||
for name in dir(mod):
|
||||
v = getattr(mod, name)
|
||||
if (
|
||||
@ -196,7 +196,7 @@ def _get_builtin_table():
|
||||
return _builtin_table
|
||||
|
||||
|
||||
def _register_builtin(fn, op):
|
||||
def _register_builtin(fn, op) -> None:
|
||||
_get_builtin_table()[id(fn)] = op
|
||||
|
||||
|
||||
|
||||
@ -116,7 +116,7 @@ class AttributeTypeIsSupportedChecker(ast.NodeVisitor):
|
||||
|
||||
return True
|
||||
|
||||
def visit_Assign(self, node):
|
||||
def visit_Assign(self, node) -> None:
|
||||
"""Store assignment state when assigning to a Call Node.
|
||||
|
||||
If we're visiting a Call Node (the right-hand side of an
|
||||
@ -139,7 +139,7 @@ class AttributeTypeIsSupportedChecker(ast.NodeVisitor):
|
||||
self.generic_visit(node)
|
||||
self.visiting_class_level_ann = False
|
||||
|
||||
def visit_AnnAssign(self, node):
|
||||
def visit_AnnAssign(self, node) -> None:
|
||||
"""Visit an AnnAssign node in an ``nn.Module``'s ``__init__`` method.
|
||||
|
||||
It checks if it conforms to our attribute annotation rules."""
|
||||
@ -194,7 +194,7 @@ class AttributeTypeIsSupportedChecker(ast.NodeVisitor):
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
def visit_Call(self, node):
|
||||
def visit_Call(self, node) -> None:
|
||||
"""Determine if a Call node is 'torch.jit.annotate' in __init__.
|
||||
|
||||
Visit a Call node in an ``nn.Module``'s ``__init__``
|
||||
|
||||
@ -3,7 +3,7 @@ import torch
|
||||
from torch._ops import OpOverload, OpOverloadPacket
|
||||
|
||||
|
||||
def _register_decomposition(op: OpOverload, graph: torch._C.Graph):
|
||||
def _register_decomposition(op: OpOverload, graph: torch._C.Graph) -> None:
|
||||
assert not isinstance(op, OpOverloadPacket), (
|
||||
f"Must pass specific op overload, not overload packet, found {op}"
|
||||
)
|
||||
|
||||
@ -20,7 +20,7 @@ _T = TypeVar("_T")
|
||||
_P = ParamSpec("_P")
|
||||
|
||||
|
||||
def check_decomposition_has_type_annotations(f):
|
||||
def check_decomposition_has_type_annotations(f) -> None:
|
||||
inspect_empty = inspect._empty # type: ignore[attr-defined]
|
||||
sig = inspect.signature(f)
|
||||
for param in sig.parameters.values():
|
||||
|
||||
@ -125,7 +125,7 @@ def freeze(
|
||||
|
||||
def run_frozen_optimizations(
|
||||
mod, optimize_numerics: bool = True, preserved_methods: Optional[list[str]] = None
|
||||
):
|
||||
) -> None:
|
||||
r"""
|
||||
Run a series of optimizations looking for patterns that occur in frozen graphs.
|
||||
|
||||
|
||||
@ -83,7 +83,7 @@ def fuser(name):
|
||||
last_executed_optimized_graph = torch._C._last_executed_optimized_graph
|
||||
|
||||
|
||||
def _get_differentiable_graph_node(node, diff_node):
|
||||
def _get_differentiable_graph_node(node, diff_node) -> None:
|
||||
if node.kind() == "prim::DifferentiableGraph":
|
||||
diff_node.append(node)
|
||||
else:
|
||||
|
||||
@ -9,7 +9,7 @@ class _InsertPoint:
|
||||
self,
|
||||
insert_point_graph: torch._C.Graph,
|
||||
insert_point: Union[torch._C.Node, torch._C.Block],
|
||||
):
|
||||
) -> None:
|
||||
self.insert_point = insert_point
|
||||
self.g = insert_point_graph
|
||||
self.guard = None
|
||||
|
||||
@ -85,7 +85,7 @@ if _IS_MONKEYTYPE_INSTALLED:
|
||||
class JitTypeTraceStoreLogger(CallTraceStoreLogger):
|
||||
"""A JitTypeCallTraceLogger that stores logged traces in a CallTraceStore."""
|
||||
|
||||
def __init__(self, store: CallTraceStore):
|
||||
def __init__(self, store: CallTraceStore) -> None:
|
||||
super().__init__(store)
|
||||
|
||||
def log(self, trace: CallTrace) -> None:
|
||||
@ -100,7 +100,7 @@ if _IS_MONKEYTYPE_INSTALLED:
|
||||
# value is list of all CallTrace
|
||||
self.trace_records: dict[str, list] = defaultdict(list)
|
||||
|
||||
def add(self, traces: Iterable[CallTrace]):
|
||||
def add(self, traces: Iterable[CallTrace]) -> None:
|
||||
for t in traces:
|
||||
qualified_name = get_qualified_name(t.func)
|
||||
self.trace_records[qualified_name].append(t)
|
||||
@ -145,7 +145,7 @@ if _IS_MONKEYTYPE_INSTALLED:
|
||||
return self.consolidate_types(qualified_name)
|
||||
|
||||
class JitTypeTraceConfig(monkeytype.config.Config):
|
||||
def __init__(self, s: JitTypeTraceStore):
|
||||
def __init__(self, s: JitTypeTraceStore) -> None:
|
||||
super().__init__()
|
||||
self.s = s
|
||||
|
||||
|
||||
@ -152,7 +152,7 @@ def _get_valid_constant(attr, v, owner_type):
|
||||
|
||||
|
||||
class SourceContext(torch._C._jit_tree_views.SourceRangeFactory):
|
||||
def __init__(self, source, filename, file_lineno, leading_whitespace_len):
|
||||
def __init__(self, source, filename, file_lineno, leading_whitespace_len) -> None:
|
||||
super().__init__(source, filename, file_lineno, leading_whitespace_len)
|
||||
|
||||
|
||||
@ -454,7 +454,7 @@ concrete_type_store = ConcreteTypeStore()
|
||||
|
||||
def create_methods_and_properties_from_stubs(
|
||||
concrete_type, method_stubs, property_stubs
|
||||
):
|
||||
) -> None:
|
||||
method_defs = [m.def_ for m in method_stubs]
|
||||
method_rcbs = [m.resolution_callback for m in method_stubs]
|
||||
method_defaults = [get_default_args(m.original_method) for m in method_stubs]
|
||||
@ -467,7 +467,7 @@ def create_methods_and_properties_from_stubs(
|
||||
)
|
||||
|
||||
|
||||
def create_hooks_from_stubs(concrete_type, hook_stubs, pre_hook_stubs):
|
||||
def create_hooks_from_stubs(concrete_type, hook_stubs, pre_hook_stubs) -> None:
|
||||
hook_defs = [h.def_ for h in hook_stubs]
|
||||
hook_rcbs = [h.resolution_callback for h in hook_stubs]
|
||||
|
||||
@ -571,7 +571,7 @@ def create_script_module_impl(nn_module, concrete_type, stubs_fn):
|
||||
hook_stubs, pre_hook_stubs = get_hook_stubs(nn_module)
|
||||
ignored_properties = jit_ignored_properties(nn_module)
|
||||
|
||||
def init_fn(script_module):
|
||||
def init_fn(script_module) -> None:
|
||||
# Initialize the ScriptModule:
|
||||
# 1. Copy the attributes/parameters/buffers from the original `nn_module` to the new ScriptModule.
|
||||
for name in concrete_type.get_attributes():
|
||||
@ -725,7 +725,7 @@ def script_model_defines_attr(script_model, attr):
|
||||
return script_attr != default_attr
|
||||
|
||||
|
||||
def add_python_attr_to_scripted_model(script_model, orig, attr):
|
||||
def add_python_attr_to_scripted_model(script_model, orig, attr) -> None:
|
||||
if hasattr(orig, attr) and script_model_defines_attr(script_model, attr):
|
||||
setattr(script_model, attr, getattr(orig, attr))
|
||||
|
||||
@ -777,7 +777,7 @@ def get_overload_name_mapping(overload_info):
|
||||
return overload_name_mappings
|
||||
|
||||
|
||||
def _check_no_signature(func):
|
||||
def _check_no_signature(func) -> None:
|
||||
signature = torch.jit.annotations.get_signature(
|
||||
func, None, fake_range(), inspect.ismethod(func)
|
||||
)
|
||||
@ -807,7 +807,7 @@ def make_stubs_for_overloads(overload_info):
|
||||
return overload_stubs
|
||||
|
||||
|
||||
def check_module_initialized(mod):
|
||||
def check_module_initialized(mod) -> None:
|
||||
assert isinstance(mod, torch.nn.Module)
|
||||
if not hasattr(mod, "_parameters"):
|
||||
raise RuntimeError(
|
||||
@ -1002,7 +1002,7 @@ def wrap_cpp_class(cpp_class):
|
||||
def wrap_cpp_module(cpp_module):
|
||||
"""Wrap this torch._C.ScriptModule in a Python ScriptModule, recursively for all submodules."""
|
||||
|
||||
def init_fn(script_module):
|
||||
def init_fn(script_module) -> None:
|
||||
for name, cpp_module in torch._C.ModuleDict(script_module._c).items():
|
||||
setattr(script_module, name, wrap_cpp_module(cpp_module))
|
||||
script_module._concrete_type = torch._C.ConcreteModuleType.from_jit_type(
|
||||
@ -1037,7 +1037,7 @@ def lazy_bind(concrete_type, unbound_method):
|
||||
"""
|
||||
|
||||
def lazy_binding_method(cpp_module, *args):
|
||||
def init_fn(script_module):
|
||||
def init_fn(script_module) -> None:
|
||||
orig_class = concrete_type.py_class
|
||||
|
||||
# Copy @ignored/@unused methods from the original module to the new one.
|
||||
|
||||
@ -18,7 +18,7 @@ from torch.jit._recursive import wrap_cpp_module
|
||||
from torch.serialization import validate_cuda_device
|
||||
|
||||
|
||||
def save(m, f, _extra_files=None):
|
||||
def save(m, f, _extra_files=None) -> None:
|
||||
r"""
|
||||
Save an offline version of this module for use in a separate process.
|
||||
|
||||
@ -213,7 +213,7 @@ def jit_module_from_flatbuffer(f):
|
||||
return wrap_cpp_module(torch._C._load_jit_module_from_bytes(f.read()))
|
||||
|
||||
|
||||
def save_jit_module_to_flatbuffer(m, f, _extra_files=None):
|
||||
def save_jit_module_to_flatbuffer(m, f, _extra_files=None) -> None:
|
||||
r"""
|
||||
Save an offline version of this module for use in a separate process.
|
||||
|
||||
|
||||
@ -41,18 +41,18 @@ class EnabledProxy:
|
||||
return False
|
||||
raise ValueError(f"Unknown setting of {name}. Try using 0 or 1.")
|
||||
|
||||
def __bool__(self):
|
||||
def __bool__(self) -> bool:
|
||||
return self.enabled
|
||||
|
||||
|
||||
_enabled = EnabledProxy()
|
||||
|
||||
|
||||
def disable():
|
||||
def disable() -> None:
|
||||
_enabled.enabled = False
|
||||
|
||||
|
||||
def enable():
|
||||
def enable() -> None:
|
||||
_enabled.enabled = True
|
||||
|
||||
|
||||
@ -67,7 +67,7 @@ _script_classes: dict[type[Any], type[Any]] = {}
|
||||
_name_to_pyclass: dict[str, type[Any]] = {}
|
||||
|
||||
|
||||
def _add_script_class(python_class, script_class):
|
||||
def _add_script_class(python_class, script_class) -> None:
|
||||
_script_classes[python_class] = script_class
|
||||
_name_to_pyclass[script_class.qualified_name()] = python_class
|
||||
|
||||
@ -83,7 +83,7 @@ def _get_python_class(qualified_name):
|
||||
return _name_to_pyclass.get(qualified_name)
|
||||
|
||||
|
||||
def _clear_class_state():
|
||||
def _clear_class_state() -> None:
|
||||
_script_classes.clear()
|
||||
_name_to_pyclass.clear()
|
||||
|
||||
@ -108,7 +108,7 @@ def _try_get_jit_cached_overloads(key):
|
||||
return None
|
||||
|
||||
|
||||
def _set_jit_overload_cache(key, compiled_fns):
|
||||
def _set_jit_overload_cache(key, compiled_fns) -> None:
|
||||
_jit_function_overload_caching[key] = [fn.qualified_name for fn in compiled_fns]
|
||||
|
||||
|
||||
@ -122,7 +122,7 @@ def _try_get_jit_cached_function(key):
|
||||
return None
|
||||
|
||||
|
||||
def _set_jit_function_cache(key, value):
|
||||
def _set_jit_function_cache(key, value) -> None:
|
||||
# only free functions currently supported
|
||||
assert isinstance(value, torch.jit.ScriptFunction)
|
||||
_jit_caching_layer[key] = value.qualified_name
|
||||
|
||||
@ -68,7 +68,7 @@ from torch._ops import OpOverloadPacket
|
||||
|
||||
|
||||
class Module:
|
||||
def __init__(self, name, members):
|
||||
def __init__(self, name, members) -> None:
|
||||
self.name = name
|
||||
self.members = members
|
||||
|
||||
@ -95,7 +95,7 @@ class EvalEnv:
|
||||
"Await": _Await,
|
||||
}
|
||||
|
||||
def __init__(self, rcb):
|
||||
def __init__(self, rcb) -> None:
|
||||
self.rcb = rcb
|
||||
if torch.distributed.rpc.is_available():
|
||||
# pyrefly: ignore [unsupported-operation]
|
||||
@ -178,7 +178,7 @@ def get_param_names(fn, n_args):
|
||||
return [str(i) for i in range(n_args)]
|
||||
|
||||
|
||||
def check_fn(fn, loc):
|
||||
def check_fn(fn, loc) -> None:
|
||||
# Make sure the function definition is not a class instantiation
|
||||
try:
|
||||
source = dedent("".join(get_source_lines_and_file(fn)[0]))
|
||||
@ -368,7 +368,7 @@ def get_enum_value_type(e: type[enum.Enum], loc):
|
||||
return res
|
||||
|
||||
|
||||
def is_tensor(ann):
|
||||
def is_tensor(ann) -> bool:
|
||||
if issubclass(ann, torch.Tensor):
|
||||
return True
|
||||
|
||||
@ -397,7 +397,7 @@ def is_tensor(ann):
|
||||
return False
|
||||
|
||||
|
||||
def _fake_rcb(inp):
|
||||
def _fake_rcb(inp) -> None:
|
||||
return None
|
||||
|
||||
|
||||
|
||||
@ -147,7 +147,7 @@ pretty_node_names.update(
|
||||
|
||||
|
||||
class FrontendError(Exception):
|
||||
def __init__(self, source_range, msg):
|
||||
def __init__(self, source_range, msg) -> None:
|
||||
self.source_range = source_range
|
||||
self.msg = msg
|
||||
|
||||
@ -155,7 +155,7 @@ class FrontendError(Exception):
|
||||
# call stack when the FrontendError was raised
|
||||
self.error_report = torch._C.ErrorReport(self.source_range)
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return self.msg + self.error_report.what().lstrip()
|
||||
|
||||
|
||||
@ -164,7 +164,7 @@ class NotSupportedError(FrontendError):
|
||||
|
||||
|
||||
class UnsupportedNodeError(NotSupportedError):
|
||||
def __init__(self, ctx, offending_node, reason=""):
|
||||
def __init__(self, ctx, offending_node, reason="") -> None:
|
||||
# If we don't have a specific token, we default to length of 1
|
||||
node_type = type(offending_node)
|
||||
range_len = len(node_start_tokens.get(node_type, " "))
|
||||
@ -229,7 +229,7 @@ def get_class_properties(cls, self_name):
|
||||
def get_class_assigns(ctx, cls_ast):
|
||||
assigns = []
|
||||
|
||||
def maybe_build_assign(builder, entry):
|
||||
def maybe_build_assign(builder, entry) -> None:
|
||||
nonlocal assigns
|
||||
try:
|
||||
assigns.append(builder(ctx, entry))
|
||||
@ -385,7 +385,7 @@ def get_jit_def(fn, def_name, self_name=None, is_classmethod=False):
|
||||
|
||||
|
||||
# TODO: more robust handling of recognizing ignore context manager
|
||||
def is_torch_jit_ignore_context_manager(stmt):
|
||||
def is_torch_jit_ignore_context_manager(stmt) -> bool:
|
||||
# checks if the statement is torch.jit.ignore context manager
|
||||
if isinstance(stmt.items[0].context_expr, ast.Call):
|
||||
# extract torch part
|
||||
@ -535,7 +535,7 @@ def build_ignore_context_manager(ctx, stmt):
|
||||
outputs.append(OutputType(var_name, var_ann))
|
||||
return inputs, outputs
|
||||
|
||||
def create_unique_name_ext(ctx, stmt):
|
||||
def create_unique_name_ext(ctx, stmt) -> str:
|
||||
# extension will be based on the full path filename plus
|
||||
# the line number of original context manager
|
||||
fn = re.sub(r"[^a-zA-Z0-9_]", "_", ctx.filename)
|
||||
|
||||
@ -56,7 +56,7 @@ def _load_for_lite_interpreter(f, map_location=None):
|
||||
|
||||
|
||||
class LiteScriptModule:
|
||||
def __init__(self, cpp_module):
|
||||
def __init__(self, cpp_module) -> None:
|
||||
self._c = cpp_module
|
||||
super().__init__()
|
||||
|
||||
|
||||
@ -57,7 +57,7 @@ def _emit_schema(mod, name, schema, arg_start=0, padding=4):
|
||||
|
||||
|
||||
def _get_tensor_ops():
|
||||
def is_tensor_method(schema):
|
||||
def is_tensor_method(schema) -> bool:
|
||||
if len(schema.arguments) == 0:
|
||||
return False
|
||||
self = schema.arguments[0]
|
||||
|
||||
@ -5,7 +5,7 @@ from typing import Any
|
||||
import torch.jit
|
||||
|
||||
|
||||
def execWrapper(code, glob, loc):
|
||||
def execWrapper(code, glob, loc) -> None:
|
||||
exec(code, glob, loc)
|
||||
|
||||
|
||||
|
||||
@ -35,7 +35,7 @@ def _ge(lhs: Any, rhs: Any) -> bool:
|
||||
|
||||
|
||||
class NestedIntNode:
|
||||
def __init__(self, t_id: int, coeff: int):
|
||||
def __init__(self, t_id: int, coeff: int) -> None:
|
||||
self.t_id = t_id
|
||||
self.coeff = coeff
|
||||
|
||||
|
||||
@ -131,7 +131,7 @@ class NestedTensor(torch.Tensor):
|
||||
|
||||
return r
|
||||
|
||||
def __init__(self, values, offsets, *, lengths=None, **kwargs):
|
||||
def __init__(self, values, offsets, *, lengths=None, **kwargs) -> None:
|
||||
super().__init__()
|
||||
|
||||
self._values = values
|
||||
@ -243,7 +243,7 @@ class NestedTensor(torch.Tensor):
|
||||
self._values, memory_format=torch.contiguous_format
|
||||
)
|
||||
|
||||
def __repr__(self): # type: ignore[override]
|
||||
def __repr__(self) -> str: # type: ignore[override]
|
||||
# We should implement this in torch/_tensor_str.py instead
|
||||
grad_fn_str = (
|
||||
f", requires_grad={self.requires_grad}" if self.requires_grad else ""
|
||||
|
||||
@ -400,7 +400,7 @@ def jagged_torch_function(func, *args, **kwargs):
|
||||
# Handle flatten() here because it's CompositeImplicit.
|
||||
if func.__name__ == "flatten":
|
||||
|
||||
def _flatten_sig(input, start_dim=0, end_dim=-1):
|
||||
def _flatten_sig(input, start_dim=0, end_dim=-1) -> None:
|
||||
pass
|
||||
|
||||
_, new_kwargs = normalize_function( # type: ignore[misc]
|
||||
@ -466,7 +466,7 @@ def jagged_torch_function(func, *args, **kwargs):
|
||||
# Handle nested-specific input validation for CompositeImplicit rms_norm
|
||||
if func.__name__ == "rms_norm":
|
||||
|
||||
def _rms_norm_sig(input, normalized_shape, weight=None, eps=None):
|
||||
def _rms_norm_sig(input, normalized_shape, weight=None, eps=None) -> None:
|
||||
pass
|
||||
|
||||
_, new_kwargs = normalize_function( # type: ignore[misc]
|
||||
@ -532,7 +532,7 @@ def prim_layout_default(func, *args, **kwargs):
|
||||
[torch.ops.aten.size.default],
|
||||
"self: jt_all",
|
||||
)
|
||||
def tensor_attr_unsupported_getter(func, *args, **kwargs):
|
||||
def tensor_attr_unsupported_getter(func, *args, **kwargs) -> None:
|
||||
if func is torch.ops.aten.size.default:
|
||||
raise RuntimeError(
|
||||
"NestedTensor does not support directly calling torch.ops.aten.size; "
|
||||
@ -1138,7 +1138,7 @@ def unbind_int(func, *args, **kwargs):
|
||||
lengths = inp.lengths()
|
||||
ragged_idx = inp._ragged_idx
|
||||
|
||||
def _torch_check(_lengths: list[int], _offsets: Optional[list[int]] = None):
|
||||
def _torch_check(_lengths: list[int], _offsets: Optional[list[int]] = None) -> None:
|
||||
# This torch._check are needed for torch.compile
|
||||
# symbolic shapes processing.
|
||||
# offsets and lengths are symbolic variables during compilation,
|
||||
@ -2615,7 +2615,7 @@ def _nested_select_backward_default(func, *args, **kwargs):
|
||||
|
||||
|
||||
@register_jagged_func(torch.ops.aten.record_stream.default, "self: jt_all, s: any")
|
||||
def record_stream_default(func, *args, **kwargs):
|
||||
def record_stream_default(func, *args, **kwargs) -> None:
|
||||
inp = args[0]
|
||||
stream = args[1]
|
||||
# ensure all components live until stream computation completes
|
||||
|
||||
@ -31,7 +31,7 @@ def _validate_sdpa_input(
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
scale=None,
|
||||
):
|
||||
) -> None:
|
||||
if (
|
||||
not isinstance(query, NestedTensor)
|
||||
or not isinstance(key, NestedTensor)
|
||||
@ -364,7 +364,7 @@ def _cumulative_and_max_seq_len_nnz(qkv: torch.Tensor) -> tuple[torch.Tensor, in
|
||||
return cumulative_seqlen, max_seqlen, n_elem
|
||||
|
||||
|
||||
def _is_safe_to_get_storage_as_tensor(tensor: torch.Tensor):
|
||||
def _is_safe_to_get_storage_as_tensor(tensor: torch.Tensor) -> bool:
|
||||
# This function checks if a nested tensor is valid for
|
||||
# use with the flash-attention and efficient_attention kernels without
|
||||
# needing to call contiguous on the nested tensor input.
|
||||
|
||||
@ -537,7 +537,7 @@ class OpRecorder(evaluator.Evaluator):
|
||||
|
||||
def __init__(
|
||||
self, opset: onnxscript.values.Opset, constant_farm: dict[Any, ir.Value]
|
||||
):
|
||||
) -> None:
|
||||
self.nodes: list[ir.Node] = []
|
||||
self.opset = opset
|
||||
self.functions: dict[
|
||||
|
||||
@ -92,7 +92,7 @@ class CaptureStrategy(abc.ABC):
|
||||
dump: bool = False,
|
||||
artifacts_dir: str | os.PathLike = ".",
|
||||
timestamp: str | None = None,
|
||||
):
|
||||
) -> None:
|
||||
"""Initialize the strategy.
|
||||
|
||||
Args:
|
||||
|
||||
@ -109,7 +109,7 @@ def torch_dtype_to_onnx_dtype(dtype: torch.dtype) -> ir.DataType:
|
||||
|
||||
|
||||
class TorchTensor(ir.Tensor):
|
||||
def __init__(self, tensor: torch.Tensor, name: str | None = None):
|
||||
def __init__(self, tensor: torch.Tensor, name: str | None = None) -> None:
|
||||
# Pass the tensor as the raw data to ir.Tensor's constructor
|
||||
if tensor.dtype == torch.float4_e2m1fn_x2:
|
||||
# Change the shape to the unpacked shape
|
||||
|
||||
@ -211,7 +211,7 @@ class ONNXProgram:
|
||||
|
||||
def __init__(
|
||||
self, model: ir.Model, exported_program: torch.export.ExportedProgram | None
|
||||
):
|
||||
) -> None:
|
||||
"""Initialize the ONNX program with the specified model and exported program.
|
||||
Args:
|
||||
model: The ONNX model.
|
||||
@ -327,7 +327,7 @@ ONNXProgram(
|
||||
include_initializers: bool = True,
|
||||
keep_initializers_as_inputs: bool = False,
|
||||
external_data: bool | None = None,
|
||||
):
|
||||
) -> None:
|
||||
"""Save the ONNX model to the specified destination.
|
||||
|
||||
When ``external_data`` is ``True`` or the model is larger than 2GB,
|
||||
|
||||
@ -149,7 +149,7 @@ def create_torch_export_error_report(
|
||||
*,
|
||||
export_status: ExportStatus,
|
||||
profile_result: str | None,
|
||||
):
|
||||
) -> None:
|
||||
with open(filename, "w", encoding="utf-8") as f:
|
||||
f.write("# PyTorch ONNX Conversion Error Report\n\n")
|
||||
f.write(_format_export_status(export_status))
|
||||
@ -175,7 +175,7 @@ def create_onnx_export_report(
|
||||
model: ir.Model | None = None,
|
||||
registry: _registration.ONNXRegistry | None = None,
|
||||
verification_result: str | None = None,
|
||||
):
|
||||
) -> None:
|
||||
with open(filename, "w", encoding="utf-8") as f:
|
||||
f.write("# PyTorch ONNX Conversion Report\n\n")
|
||||
f.write(_format_export_status(export_status))
|
||||
|
||||
@ -21,7 +21,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
# A special value to indicate that the default value is not specified
|
||||
class _Empty:
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return "_EMPTY_DEFAULT"
|
||||
|
||||
|
||||
|
||||
@ -18,7 +18,7 @@ class SymbolicTensor(ir.Value):
|
||||
type: ir.TypeProtocol | None = None,
|
||||
doc_string: str | None = None,
|
||||
const_value: ir.TensorProtocol | None = None,
|
||||
):
|
||||
) -> None:
|
||||
super().__init__(
|
||||
name=name,
|
||||
shape=shape,
|
||||
|
||||
@ -66,7 +66,7 @@ def _patch_difflib_sequence_matcher_init():
|
||||
"""
|
||||
original_init = difflib.SequenceMatcher.__init__
|
||||
|
||||
def patched_init(self, isjunk=None, a="", b="", autojunk=True):
|
||||
def patched_init(self, isjunk=None, a="", b="", autojunk=True) -> None:
|
||||
original_init(self, isjunk, a, b, autojunk=False)
|
||||
|
||||
difflib.SequenceMatcher.__init__ = patched_init # type: ignore[assignment]
|
||||
@ -192,7 +192,7 @@ class Transform(abc.ABC):
|
||||
def __init__(
|
||||
self,
|
||||
module: torch.fx.GraphModule,
|
||||
):
|
||||
) -> None:
|
||||
"""Initialize the transform.
|
||||
|
||||
Args:
|
||||
|
||||
@ -63,7 +63,7 @@ class TypePromotionSnapshot:
|
||||
class TypePromotionRule(abc.ABC):
|
||||
"""Base class for type promotion rule per 'torch.ops.{namespace}.{op_name}'."""
|
||||
|
||||
def __init__(self, namespace: str, op_name: str):
|
||||
def __init__(self, namespace: str, op_name: str) -> None:
|
||||
self.namespace = namespace
|
||||
self.op_name = op_name
|
||||
|
||||
@ -74,7 +74,7 @@ class TypePromotionRule(abc.ABC):
|
||||
def __hash__(self) -> int: ...
|
||||
|
||||
@abc.abstractmethod
|
||||
def __repr__(self): ...
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@abc.abstractmethod
|
||||
def __eq__(self, other: object) -> bool: ...
|
||||
@ -128,7 +128,7 @@ class ElementwiseTypePromotionRule(TypePromotionRule):
|
||||
promote_args_positions: Sequence[int],
|
||||
promote_kwargs_names: Sequence[str],
|
||||
promotion_kind: _prims_common.ELEMENTWISE_TYPE_PROMOTION_KIND,
|
||||
):
|
||||
) -> None:
|
||||
"""Constructs a TypePromotionRule for elementwise operators.
|
||||
|
||||
Args:
|
||||
@ -143,7 +143,7 @@ class ElementwiseTypePromotionRule(TypePromotionRule):
|
||||
self.promote_kwargs_names = promote_kwargs_names
|
||||
self.promotion_kind = promotion_kind
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"ElementwiseTypePromotionRule('{self.namespace}', '{self.op_name}', "
|
||||
f"{self.promote_args_positions}, {self.promote_kwargs_names}, {self.promotion_kind})"
|
||||
@ -216,7 +216,7 @@ class DivElementwiseTypePromotionRule(ElementwiseTypePromotionRule):
|
||||
Rule depends on the value of the `rounding_mode` argument.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(
|
||||
"aten",
|
||||
"div",
|
||||
@ -252,7 +252,7 @@ class ReductionTypePromotionRule(TypePromotionRule):
|
||||
namespace: str,
|
||||
op_name: str,
|
||||
promotion_kind: _prims_common.REDUCTION_OUTPUT_TYPE_KIND,
|
||||
):
|
||||
) -> None:
|
||||
"""Constructs a TypePromotionRule for reduction operators.
|
||||
|
||||
Args:
|
||||
@ -263,7 +263,7 @@ class ReductionTypePromotionRule(TypePromotionRule):
|
||||
super().__init__(namespace, op_name)
|
||||
self.promotion_kind = promotion_kind
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return f"ReductionTypePromotionRule('{self.namespace}', '{self.op_name}', {self.promotion_kind})"
|
||||
|
||||
# pyrefly: ignore [bad-override]
|
||||
@ -311,7 +311,7 @@ class AllOrAnyReductionTypePromotionRule(ReductionTypePromotionRule):
|
||||
The result dtype is always uint8 if `dtype` kwarg is uint8, otherwise torch.bool.
|
||||
"""
|
||||
|
||||
def __init__(self, op_name: str):
|
||||
def __init__(self, op_name: str) -> None:
|
||||
super().__init__(
|
||||
"aten",
|
||||
op_name,
|
||||
@ -1205,7 +1205,7 @@ class ElementwiseTypePromotionRuleSetGenerator:
|
||||
class TypePromotionTable:
|
||||
"""Type promotion table for torch.ops."""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self._rule_table = {}
|
||||
for rule in _GENERATED_ATEN_TYPE_PROMOTION_RULE_SET:
|
||||
self.add_rule(rule)
|
||||
@ -1262,7 +1262,7 @@ class _OpTraceDispatchMode(_python_dispatch.TorchDispatchMode):
|
||||
op overload for a given op overload packet for different set of args and kwargs.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.traced_ops = []
|
||||
|
||||
@ -1331,7 +1331,7 @@ class _TypePromotionInterpreter(torch.fx.Interpreter):
|
||||
self,
|
||||
module: torch.fx.GraphModule,
|
||||
type_promotion_table: TypePromotionTable,
|
||||
):
|
||||
) -> None:
|
||||
super().__init__(module)
|
||||
self.type_promotion_table = type_promotion_table
|
||||
|
||||
@ -1603,7 +1603,7 @@ class InsertTypePromotion(_pass.Transform):
|
||||
self,
|
||||
module: torch.fx.GraphModule,
|
||||
type_promotion_table: TypePromotionTable | None = None,
|
||||
):
|
||||
) -> None:
|
||||
super().__init__(module)
|
||||
self.interpreter = _TypePromotionInterpreter(
|
||||
module, type_promotion_table or TypePromotionTable()
|
||||
|
||||
@ -810,7 +810,7 @@ def _reduce_with_dtype(onnx_op: str, name: str, allow_multi_dim_support: bool =
|
||||
|
||||
@_onnx_symbolic("aten::cumsum")
|
||||
@symbolic_helper.parse_args("v", "i", "none")
|
||||
def cumsum(g: jit_utils.GraphContext, input, dim, dtype):
|
||||
def cumsum(g: jit_utils.GraphContext, input, dim, dtype) -> None:
|
||||
symbolic_helper._onnx_opset_unsupported("cumsum", 9, 11, input)
|
||||
|
||||
|
||||
@ -3332,7 +3332,9 @@ def _unique(g: jit_utils.GraphContext, input, sorted, return_inverse):
|
||||
|
||||
@_onnx_symbolic("aten::_unique2")
|
||||
@symbolic_helper.parse_args("v", "i", "i", "i")
|
||||
def _unique2(g: jit_utils.GraphContext, input, sorted, return_inverse, return_counts):
|
||||
def _unique2(
|
||||
g: jit_utils.GraphContext, input, sorted, return_inverse, return_counts
|
||||
) -> None:
|
||||
symbolic_helper._onnx_opset_unsupported("_unique2", 9, 11, input)
|
||||
|
||||
|
||||
@ -6289,7 +6291,7 @@ def broadcast_tensors(g: jit_utils.GraphContext, self):
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::is_pinned")
|
||||
def is_pinned(g: jit_utils.GraphContext, self, device=None):
|
||||
def is_pinned(g: jit_utils.GraphContext, self, device=None) -> None:
|
||||
# Unused by ONNX.
|
||||
return None
|
||||
|
||||
@ -6357,7 +6359,7 @@ def prim_layout(g: jit_utils.GraphContext, self):
|
||||
|
||||
|
||||
@_onnx_symbolic("prim::ListConstruct")
|
||||
def prim_list_construct(g: jit_utils.GraphContext, *inputs, **kwargs):
|
||||
def prim_list_construct(g: jit_utils.GraphContext, *inputs, **kwargs) -> None:
|
||||
return None
|
||||
|
||||
|
||||
@ -6374,12 +6376,12 @@ def prim_list_unpack(
|
||||
|
||||
|
||||
@_onnx_symbolic("prim::TupleConstruct")
|
||||
def prim_tuple_construct(g: jit_utils.GraphContext, *inputs, **kwargs):
|
||||
def prim_tuple_construct(g: jit_utils.GraphContext, *inputs, **kwargs) -> None:
|
||||
return None
|
||||
|
||||
|
||||
@_onnx_symbolic("prim::Uninitialized")
|
||||
def prim_uninitialized(g: jit_utils.GraphContext, *inputs, **kwargs):
|
||||
def prim_uninitialized(g: jit_utils.GraphContext, *inputs, **kwargs) -> None:
|
||||
return None
|
||||
|
||||
|
||||
|
||||
@ -571,7 +571,7 @@ def export(
|
||||
return None
|
||||
|
||||
|
||||
def _is_constant_tensor_list(node):
|
||||
def _is_constant_tensor_list(node) -> bool | None:
|
||||
if node.kind() != "prim::Constant":
|
||||
return False
|
||||
output_type = node.output().type()
|
||||
@ -585,7 +585,7 @@ def _is_constant_tensor_list(node):
|
||||
# get generated in constant prop. So we split them back into prim::ListConstructs
|
||||
|
||||
|
||||
def _split_tensor_list_constants(g, block):
|
||||
def _split_tensor_list_constants(g, block) -> None:
|
||||
for node in block.nodes():
|
||||
for subblock in node.blocks():
|
||||
_split_tensor_list_constants(g, subblock)
|
||||
@ -722,7 +722,7 @@ def _optimize_graph(
|
||||
return graph
|
||||
|
||||
|
||||
def warn_on_static_input_change(input_states):
|
||||
def warn_on_static_input_change(input_states) -> None:
|
||||
"""Warns that changes to input dictionaries and strings won't take effect in the traced ONNX graph.
|
||||
|
||||
We accept dictionaries and strings as ONNX inputs, but they should be only for
|
||||
@ -932,7 +932,7 @@ def _get_param_count_list(method_graph, args_params):
|
||||
return param_count_list
|
||||
|
||||
|
||||
def _check_flatten_did_not_remove(original, jit_flattened):
|
||||
def _check_flatten_did_not_remove(original, jit_flattened) -> None:
|
||||
"""torch.jit._flatten removes None. Check if it did so in this case."""
|
||||
|
||||
def flatten(x):
|
||||
@ -1286,13 +1286,13 @@ def _setup_trace_module_map(
|
||||
model: torch.nn.Module | torch.jit.ScriptModule,
|
||||
export_modules_as_functions: bool | Collection[type[torch.nn.Module]],
|
||||
) -> set[str]:
|
||||
def __register_attribute_hook():
|
||||
def __register_attribute_hook() -> None:
|
||||
attr_name = "_onnx_attrs"
|
||||
|
||||
def _track_module_attributes_forward_pre_hook(module, input):
|
||||
def _track_module_attributes_forward_pre_hook(module, input) -> None:
|
||||
setattr(module, attr_name, _get_module_attributes(module))
|
||||
|
||||
def _track_module_attributes_forward_hook(module, input, output):
|
||||
def _track_module_attributes_forward_hook(module, input, output) -> None:
|
||||
tracing_state = _C._get_tracing_state()
|
||||
if not tracing_state:
|
||||
return
|
||||
@ -1359,7 +1359,7 @@ def _setup_trace_module_map(
|
||||
return module_typenames
|
||||
|
||||
|
||||
def _reset_trace_module_map():
|
||||
def _reset_trace_module_map() -> None:
|
||||
torch.jit._trace._trace_module_map = None
|
||||
_C._jit_pass_onnx_clear_scope_records()
|
||||
|
||||
@ -1388,7 +1388,7 @@ def _get_module_attributes(module):
|
||||
return attrs
|
||||
|
||||
|
||||
def _trigger_symbolic_function_registration():
|
||||
def _trigger_symbolic_function_registration() -> None:
|
||||
"""Trigger the registration of symbolic functions for all supported opsets."""
|
||||
|
||||
from torch.onnx._internal.torchscript_exporter import ( # noqa: F401
|
||||
@ -1599,7 +1599,7 @@ def _export(
|
||||
return torch_out
|
||||
|
||||
|
||||
def _apply_friendly_debug_names(graph, params):
|
||||
def _apply_friendly_debug_names(graph, params) -> None:
|
||||
for n in graph.nodes():
|
||||
for v in n.inputs():
|
||||
old_name = v.debugName()
|
||||
@ -1611,8 +1611,8 @@ def _apply_friendly_debug_names(graph, params):
|
||||
params[new_name] = params.pop(old_name)
|
||||
|
||||
|
||||
def _set_input_and_output_names(graph, input_names, output_names):
|
||||
def set_names(node_list, name_list, descriptor):
|
||||
def _set_input_and_output_names(graph, input_names, output_names) -> None:
|
||||
def set_names(node_list, name_list, descriptor) -> None:
|
||||
if name_list is None:
|
||||
return
|
||||
if len(name_list) > len(node_list):
|
||||
@ -1681,7 +1681,7 @@ def _add_output_to_block(block: _C.Block, value: _C.Value) -> int:
|
||||
|
||||
def _should_aten_fallback(
|
||||
name: str, opset_version: int, operator_export_type: _C_onnx.OperatorExportTypes
|
||||
):
|
||||
) -> bool:
|
||||
# For all builds, if domain=="aten" and operator_export_type==ONNX_ATEN,
|
||||
# an aten::ATen operator is created regardless of symbolics existence
|
||||
|
||||
@ -1822,7 +1822,7 @@ def _run_symbolic_function(
|
||||
raise
|
||||
|
||||
|
||||
def _verify_custom_op_name(symbolic_name: str):
|
||||
def _verify_custom_op_name(symbolic_name: str) -> None:
|
||||
if not re.match(r"^[a-zA-Z0-9-_]+::[a-zA-Z-_]+[a-zA-Z0-9-_]*$", symbolic_name):
|
||||
raise errors.OnnxExporterError(
|
||||
f"Failed to register operator {symbolic_name}. "
|
||||
@ -1842,7 +1842,7 @@ def register_custom_op_symbolic(
|
||||
symbolic_name: str,
|
||||
symbolic_fn: Callable,
|
||||
opset_version: int,
|
||||
):
|
||||
) -> None:
|
||||
"""Registers a symbolic function for a custom operator.
|
||||
|
||||
When the user registers symbolic for custom/contrib ops,
|
||||
@ -1868,7 +1868,7 @@ def register_custom_op_symbolic(
|
||||
registration.custom_onnx_symbolic(symbolic_name, opset_version)(symbolic_fn)
|
||||
|
||||
|
||||
def unregister_custom_op_symbolic(symbolic_name: str, opset_version: int):
|
||||
def unregister_custom_op_symbolic(symbolic_name: str, opset_version: int) -> None:
|
||||
"""Unregisters ``symbolic_name``.
|
||||
|
||||
See "Custom Operators" in the module documentation for an example usage.
|
||||
@ -1886,7 +1886,7 @@ def unregister_custom_op_symbolic(symbolic_name: str, opset_version: int):
|
||||
registration.registry.unregister(symbolic_name, opset_version)
|
||||
|
||||
|
||||
def _validate_dynamic_axes(dynamic_axes, model, input_names, output_names):
|
||||
def _validate_dynamic_axes(dynamic_axes, model, input_names, output_names) -> None:
|
||||
"""Ensures dynamic axes argument is follows the expected format."""
|
||||
if len(dynamic_axes) == 0:
|
||||
return
|
||||
|
||||
@ -209,7 +209,7 @@ def _compare_onnx_pytorch_outputs_in_np(
|
||||
onnx_outs: _OutputsType,
|
||||
pt_outs: _OutputsType,
|
||||
options: VerificationOptions,
|
||||
):
|
||||
) -> None:
|
||||
assert len(onnx_outs) == len(pt_outs), (
|
||||
f"Number of outputs differ ONNX runtime: ({len(onnx_outs)}) PyTorch: ({len(pt_outs)})"
|
||||
)
|
||||
@ -261,7 +261,7 @@ def _compare_onnx_pytorch_outputs(
|
||||
onnx_outs: _OutputsType,
|
||||
pt_outs: Any,
|
||||
options: VerificationOptions,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Compare ONNX and PyTorch outputs.
|
||||
|
||||
@ -383,7 +383,7 @@ def _compare_onnx_pytorch_model(
|
||||
input_kwargs: _InputKwargsType | None,
|
||||
additional_test_inputs: Sequence[_InputArgsType] | None,
|
||||
options: VerificationOptions,
|
||||
):
|
||||
) -> None:
|
||||
"""Compare outputs from ONNX model runs with outputs from PyTorch model runs.
|
||||
|
||||
Args:
|
||||
@ -401,7 +401,7 @@ def _compare_onnx_pytorch_model(
|
||||
"""
|
||||
onnx_session = _onnx_backend_session(onnx_model_f, options.backend)
|
||||
|
||||
def compare_onnx_pytorch_model_with_input(input_args, input_kwargs):
|
||||
def compare_onnx_pytorch_model_with_input(input_args, input_kwargs) -> None:
|
||||
pt_args, pt_kwargs = _prepare_input_for_pytorch(input_args, input_kwargs)
|
||||
# TODO: remove this and treat mutating model separately. See #77679
|
||||
pt_model_copy = _try_clone_model(pt_model)
|
||||
@ -443,7 +443,7 @@ def verify(
|
||||
use_external_data: bool = False,
|
||||
additional_test_inputs: Sequence[_InputArgsType] | None = None,
|
||||
options: VerificationOptions | None = None,
|
||||
):
|
||||
) -> None:
|
||||
"""Verify model export to ONNX against original PyTorch model.
|
||||
|
||||
.. deprecated:: 2.7
|
||||
|
||||
@ -30,7 +30,7 @@ class UnsupportedOperatorError(OnnxExporterError):
|
||||
|
||||
# NOTE: This is legacy and is only used by the torchscript exporter
|
||||
# Clean up when the torchscript exporter is removed
|
||||
def __init__(self, name: str, version: int, supported_version: int | None):
|
||||
def __init__(self, name: str, version: int, supported_version: int | None) -> None:
|
||||
if supported_version is not None:
|
||||
msg = (
|
||||
f"Exporting the operator '{name}' to ONNX opset version {version} "
|
||||
@ -57,7 +57,7 @@ class SymbolicValueError(OnnxExporterError):
|
||||
|
||||
# NOTE: This is legacy and is only used by the torchscript exporter
|
||||
# Clean up when the torchscript exporter is removed
|
||||
def __init__(self, msg: str, value: _C.Value):
|
||||
def __init__(self, msg: str, value: _C.Value) -> None:
|
||||
message = (
|
||||
f"{msg} [Caused by the value '{value}' (type '{value.type()}') in the "
|
||||
f"TorchScript graph. The containing node has kind '{value.node().kind()}'.] "
|
||||
|
||||
@ -299,7 +299,7 @@ class _ConfigEntry:
|
||||
hide: bool = False
|
||||
alias: Optional[str] = None
|
||||
|
||||
def __init__(self, config: _Config):
|
||||
def __init__(self, config: _Config) -> None:
|
||||
self.default = config.default
|
||||
self.value_type = (
|
||||
config.value_type if config.value_type is not None else type(self.default)
|
||||
@ -792,7 +792,7 @@ class SubConfigProxy:
|
||||
`config.triton.cudagraphs` maps to _config["triton.cudagraphs"]
|
||||
"""
|
||||
|
||||
def __init__(self, config: object, prefix: str):
|
||||
def __init__(self, config: object, prefix: str) -> None:
|
||||
# `super().__setattr__` to bypass custom `__setattr__`
|
||||
super().__setattr__("_config", config)
|
||||
super().__setattr__("_prefix", prefix)
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import math
|
||||
import operator
|
||||
from typing import Union
|
||||
from typing import NoReturn, Union
|
||||
|
||||
import sympy
|
||||
|
||||
@ -139,7 +139,7 @@ class ReferenceAnalysis:
|
||||
return FloorDiv(a, b)
|
||||
|
||||
@staticmethod
|
||||
def truncdiv(a, b):
|
||||
def truncdiv(a, b) -> NoReturn:
|
||||
raise NotImplementedError("TODO: truncdiv")
|
||||
|
||||
@staticmethod
|
||||
@ -257,11 +257,11 @@ class PythonReferenceAnalysis(ReferenceAnalysis):
|
||||
raise NotImplementedError(f"to_dtype {dtype} NYI")
|
||||
|
||||
@staticmethod
|
||||
def exp(x):
|
||||
def exp(x) -> NoReturn:
|
||||
raise AssertionError("exp is not valid shape sympy expr")
|
||||
|
||||
@staticmethod
|
||||
def log(x):
|
||||
def log(x) -> NoReturn:
|
||||
raise AssertionError("log is not valid shape sympy expr")
|
||||
|
||||
@staticmethod
|
||||
@ -448,7 +448,7 @@ class TensorReferenceAnalysis:
|
||||
return _to_dtype(x, dtype)
|
||||
|
||||
@staticmethod
|
||||
def mod(x, y):
|
||||
def mod(x, y) -> NoReturn:
|
||||
# TODO: https://github.com/pytorch/pytorch/pull/133654
|
||||
raise NotImplementedError(
|
||||
"no C-style modulus operation available from frontend atm"
|
||||
@ -484,7 +484,7 @@ class TensorReferenceAnalysis:
|
||||
return torch.ops.aten.div.Tensor_mode(a, b, rounding_mode="floor")
|
||||
|
||||
@staticmethod
|
||||
def truncdiv(a, b):
|
||||
def truncdiv(a, b) -> NoReturn:
|
||||
raise NotImplementedError(
|
||||
"no C-style truncdiv operation available from frontend atm"
|
||||
)
|
||||
@ -575,7 +575,7 @@ class TensorReferenceAnalysis:
|
||||
return torch.ops.aten.round.default(a)
|
||||
|
||||
@staticmethod
|
||||
def round_decimal(a, b):
|
||||
def round_decimal(a, b) -> NoReturn:
|
||||
raise NotImplementedError(
|
||||
"round decimal doesn't support Tensor second argument atm"
|
||||
)
|
||||
|
||||
@ -188,7 +188,7 @@ class Timer:
|
||||
env: Optional[str] = None,
|
||||
num_threads: int = 1,
|
||||
language: Union[Language, str] = Language.PYTHON,
|
||||
):
|
||||
) -> None:
|
||||
if not isinstance(stmt, str):
|
||||
raise ValueError("Currently only a `str` stmt is supported.")
|
||||
|
||||
|
||||
@ -14,6 +14,7 @@ import torch.fx.traceback as fx_traceback
|
||||
from torch.utils._pytree import tree_map
|
||||
from torch.testing._internal.logging_tensor import capture_logs, LoggingTensorMode
from torch.utils._python_dispatch import TorchDispatchMode
from typing import NoReturn

__all__ = [
"checkpoint",
@ -107,7 +108,7 @@ class DefaultDeviceType:
_default_device_type = "cuda"

@staticmethod
def set_device_type(device: str = "cuda"):
def set_device_type(device: str = "cuda") -> None:
"""
Set the default device type for checkpointing.

@ -130,7 +131,7 @@ class DefaultDeviceType:
def _infer_device_type(*args):
device_types = []

def add_device_types(arg):
def add_device_types(arg) -> None:
nonlocal device_types
if isinstance(arg, torch.Tensor) and arg.device.type != "cpu":
device_types.append(arg.device.type)
@ -166,7 +167,7 @@ def get_device_states(*args) -> Tuple[List[int], List[torch.Tensor]]:
# the conditionals short-circuit.
fwd_device_ids = []

def add_device_ids(arg):
def add_device_ids(arg) -> None:
nonlocal fwd_device_ids
if isinstance(arg, torch.Tensor) and arg.device.type not in {"cpu", "meta"}:
fwd_device_ids.append(arg.get_device())
@ -601,7 +602,7 @@ def checkpoint_sequential(functions, segments, input, use_reentrant=None, **kwar
return run_function(end + 1, len(functions) - 1, functions)(input)


def _internal_assert(cond):
def _internal_assert(cond) -> None:
if not cond:
raise AssertionError(
"Something went unexpectedly wrong in activation checkpoint. "
@ -779,7 +780,7 @@ class _Handle:


class _Holder:
def __init__(self):
def __init__(self) -> None:
self.handles: Dict[int, Optional[_Handle]] = {}


@ -817,12 +818,12 @@ class _NoopSaveInputs(torch.autograd.Function):
ctx.save_for_backward(*tensors)

@staticmethod
def backward(ctx, *grad_outputs):
def backward(ctx, *grad_outputs) -> NoReturn:
raise AssertionError("Did not expect to backward on this graph")


class _CheckpointFrame:
def __init__(self, recompute_fn, early_stop, unpack_error_cb, metadata_fn):
def __init__(self, recompute_fn, early_stop, unpack_error_cb, metadata_fn) -> None:
self.recompute_fn = recompute_fn
self.input_saver = None
self.weak_holders: List[ReferenceType] = []
@ -847,7 +848,7 @@ class _CheckpointFrame:
self.forward_completed = False
self.ignore_saved_mismatch = False

def check_recomputed_tensors_match(self, gid):
def check_recomputed_tensors_match(self, gid) -> None:
if self.ignore_saved_mismatch:
# TODO: we can probably make this check stricter by checking that
# the metadata of the first tensors still match.
@ -999,7 +1000,7 @@ def _get_debug_context_and_cb() -> Tuple[Callable[[], Any], Callable[[Checkpoint
cpp_tb = platform.machine() == 'x86_64' and platform.system() == 'Linux'

class CaptureLogs:
def __init__(self):
def __init__(self) -> None:
self.logs = None
self.tbs = None

@ -1016,7 +1017,7 @@ def _get_debug_context_and_cb() -> Tuple[Callable[[], Any], Callable[[Checkpoint
capture_logs_fwd = CaptureLogs()
capture_logs_recompute = CaptureLogs()

def unpack_error_cb(e: CheckpointError):
def unpack_error_cb(e: CheckpointError) -> NoReturn:
def get_str_tb(label, capture_logs):
out = ""
total_len = len(capture_logs.logs)
@ -1071,7 +1072,7 @@ class _StopRecomputationError(Exception):


class _recomputation_hook(torch.autograd.graph.saved_tensors_hooks):
def __init__(self, target_frame_ref: ReferenceType, gid: int):
def __init__(self, target_frame_ref: ReferenceType, gid: int) -> None:
def pack_hook(x):
x = x.detach() if x.requires_grad else x
target_frame = target_frame_ref()
@ -1132,7 +1133,7 @@ def _run_fn_with_dynamo_disabled(fn, *args, **kwargs):


class _checkpoint_hook(torch.autograd.graph.saved_tensors_hooks):
def __init__(self, frame):
def __init__(self, frame) -> None:
def pack_hook(x):
# See Rule 4 above
holder = _Holder()
@ -1196,7 +1197,7 @@ def _is_compiling(func, args, kwargs):

class _VersionWrapper:
# Check that cached tensors are not mutated.
def __init__(self, val):
def __init__(self, val) -> None:
self.val: Union[torch.Tensor, Any] = val
self.version: Optional[int] = val._version if isinstance(val, torch.Tensor) else None

@ -1251,7 +1252,7 @@ class SelectiveCheckpointContext:
>>> context_fn=context_fn,
>>> )
"""
def __init__(self, *, is_recompute):
def __init__(self, *, is_recompute) -> None:
self.is_recompute = is_recompute


@ -1301,7 +1302,7 @@ SAC_IGNORED_OPS = {

class _CachingTorchDispatchMode(TorchDispatchMode):
# Used together with _CachedTorchDispatchMode to implement SAC.
def __init__(self, policy_fn, storage):
def __init__(self, policy_fn, storage) -> None:
self.policy_fn = policy_fn
self.storage = storage

@ -1337,7 +1338,7 @@ class _CachingTorchDispatchMode(TorchDispatchMode):

class _CachedTorchDispatchMode(TorchDispatchMode):
# Used together with _CachedTorchDispatchMode to implement SAC.
def __init__(self, policy_fn, storage, allow_cache_entry_mutation):
def __init__(self, policy_fn, storage, allow_cache_entry_mutation) -> None:
self.policy_fn = policy_fn
self.storage = storage
self.allow_cache_entry_mutation = allow_cache_entry_mutation
@ -1542,7 +1543,7 @@ def _checkpoint_without_reentrant_generator(
had_device_in_fwd = True
fwd_devices, fwd_device_states = get_device_states(*args)

def recompute_fn(*inputs):
def recompute_fn(*inputs) -> None:
kwargs, *args = inputs
# This will be called later during recomputation. This wrapping enables
# the necessary global state to be captured.

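The hunks above consistently add return-type annotations: procedures that only produce side effects gain "-> None", and callbacks that can only raise gain "-> NoReturn". A minimal, hypothetical sketch of that pattern follows (not part of the patch; the names are invented for illustration):

from typing import NoReturn

def reset_counters(counters: dict) -> None:
    # Mutates its argument in place; callers should not rely on a return value.
    for key in counters:
        counters[key] = 0

def fail_unreachable(msg: str) -> NoReturn:
    # Never returns normally, mirroring the always-raising hooks annotated above.
    raise AssertionError(msg)
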
@ -4,20 +4,22 @@ r"""Contains definitions of the methods used by the _BaseDataLoaderIter to fetch
This logic is shared in both single- and multi-processing data loading.
"""

from typing import NoReturn


class _BaseDatasetFetcher:
def __init__(self, dataset, auto_collation, collate_fn, drop_last):
def __init__(self, dataset, auto_collation, collate_fn, drop_last) -> None:
self.dataset = dataset
self.auto_collation = auto_collation
self.collate_fn = collate_fn
self.drop_last = drop_last

def fetch(self, possibly_batched_index):
def fetch(self, possibly_batched_index) -> NoReturn:
raise NotImplementedError


class _IterableDatasetFetcher(_BaseDatasetFetcher):
def __init__(self, dataset, auto_collation, collate_fn, drop_last):
def __init__(self, dataset, auto_collation, collate_fn, drop_last) -> None:
super().__init__(dataset, auto_collation, collate_fn, drop_last)
self.dataset_iter = iter(dataset)
self.ended = False

@ -17,7 +17,7 @@ import queue
import threading
import warnings
from collections.abc import Callable
from typing import Any, Generic, Optional, TYPE_CHECKING, TypeVar, Union
from typing import Any, Generic, NoReturn, Optional, TYPE_CHECKING, TypeVar, Union
from typing_extensions import Self

import torch
@ -108,7 +108,7 @@ def _get_distributed_settings():
return 1, 0


def _sharding_worker_init_fn(worker_init_fn, world_size, rank_id, worker_id):
def _sharding_worker_init_fn(worker_init_fn, world_size, rank_id, worker_id) -> None:
global_worker_id = worker_id
info = torch.utils.data.get_worker_info()
if info is None:
@ -436,7 +436,7 @@ class DataLoader(Generic[_T_co]):
return self.__multiprocessing_context

@multiprocessing_context.setter
def multiprocessing_context(self, multiprocessing_context):
def multiprocessing_context(self, multiprocessing_context) -> None:
if multiprocessing_context is not None:
if self.num_workers > 0:
if isinstance(multiprocessing_context, str):
@ -468,7 +468,7 @@ class DataLoader(Generic[_T_co]):

self.__multiprocessing_context = multiprocessing_context

def __setattr__(self, attr, val):
def __setattr__(self, attr, val) -> None:
if self.__initialized and attr in (
"batch_size",
"batch_sampler",
@ -546,7 +546,7 @@ class DataLoader(Generic[_T_co]):
else:
return len(self._index_sampler)

def check_worker_number_rationality(self):
def check_worker_number_rationality(self) -> None:
# This function check whether the dataloader's worker number is rational based on
# current system's resource. Current rule is that if the number of workers this
# Dataloader will create is bigger than the number of logical cpus that is allowed to
@ -714,7 +714,7 @@ class _BaseDataLoaderIter:
def __iter__(self) -> Self:
return self

def _reset(self, loader, first_iter=False):
def _reset(self, loader, first_iter=False) -> None:
self._sampler_iter = iter(self._index_sampler)
self._num_yielded = 0
self._IterableDataset_len_called = loader._IterableDataset_len_called
@ -729,7 +729,7 @@ class _BaseDataLoaderIter:
def _next_index(self):
return next(self._sampler_iter) # may raise StopIteration

def _next_data(self):
def _next_data(self) -> NoReturn:
raise NotImplementedError

def __next__(self) -> Any:
@ -770,7 +770,7 @@ class _BaseDataLoaderIter:


class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
def __init__(self, loader):
def __init__(self, loader) -> None:
super().__init__(loader)
if self._timeout != 0:
raise AssertionError("_SingleProcessDataLoaderIter requires timeout == 0")
@ -1113,7 +1113,7 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
# processing indices already in `index_queue` if we are already shutting
# down.

def __init__(self, loader):
def __init__(self, loader) -> None:
super().__init__(loader)

self._prefetch_factor = loader.prefetch_factor
@ -1235,7 +1235,7 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
self._worker_pids_set = True
self._reset(loader, first_iter=True)

def _reset(self, loader, first_iter=False):
def _reset(self, loader, first_iter=False) -> None:
super()._reset(loader, first_iter)
self._send_idx = 0 # idx of the next task to be sent to workers
self._rcvd_idx = 0 # idx of the next task to be returned in __next__
@ -1529,7 +1529,7 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
self._rcvd_idx += 1
return self._process_data(data, worker_id)

def _try_put_index(self):
def _try_put_index(self) -> None:
max_tasks = self._prefetch_factor * self._num_workers
if self._tasks_outstanding >= max_tasks:
raise AssertionError(
@ -1568,7 +1568,7 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
data.reraise()
return data

def _mark_worker_as_unavailable(self, worker_id, shutdown=False):
def _mark_worker_as_unavailable(self, worker_id, shutdown=False) -> None:
# Mark a worker as having finished its work e.g., due to
# exhausting an `IterableDataset`. This should be used only when this
# `_MultiProcessingDataLoaderIter` is going to continue running.
@ -1604,7 +1604,7 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
"_workers_done_event state does not match shutdown flag"
)

def _shutdown_workers(self):
def _shutdown_workers(self) -> None:
# Called when shutting down this `_MultiProcessingDataLoaderIter`.
# See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on
# the logic of this function.
@ -1678,12 +1678,12 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):

# staticmethod is used to remove reference to `_MultiProcessingDataLoaderIter`
@staticmethod
def _clean_up_worker(w):
def _clean_up_worker(w) -> None:
try:
w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
finally:
if w.is_alive():
w.terminate()

def __del__(self):
def __del__(self) -> None:
self._shutdown_workers()

@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
from typing import Any, Optional
from typing import Any, NoReturn, Optional

from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data.datapipes.dataframe.structures import DataChunkDF
@ -33,7 +33,7 @@ __all__ = [
]


def disable_capture():
def disable_capture() -> None:
CaptureControl.disabled = True


@ -42,7 +42,7 @@ class CaptureControl:


class DataFrameTracedOps(DFIterDataPipe):
def __init__(self, source_datapipe, output_var):
def __init__(self, source_datapipe, output_var) -> None:
self.source_datapipe = source_datapipe
self.output_var = output_var

@ -72,10 +72,10 @@ UNIMPLEMENTED_ATTR = ["__deepcopy__", "__setstate__", "is_shardable", "apply_sha
class Capture:
# TODO: All operations are shared across entire InitialCapture, need to figure out what if we join two captures

def __init__(self, schema_df=None):
def __init__(self, schema_df=None) -> None:
self.ctx = {"operations": [], "variables": [], "schema_df": schema_df}

def __str__(self):
def __str__(self) -> str:
return self._ops_str()

def _ops_str(self):
@ -113,7 +113,7 @@ class Capture:
def __getitem__(self, key):
return CaptureGetItem(self, key, ctx=self.ctx)

def __setitem__(self, key, value):
def __setitem__(self, key, value) -> None:
# pyrefly: ignore [missing-attribute]
self.ctx["operations"].append(CaptureSetItem(self, key, value, ctx=self.ctx))

@ -147,7 +147,7 @@ class Capture:
# pyrefly: ignore [bad-argument-type]
return len(self.ctx["operations"]) == 0 and len(self.ctx["variables"]) == 0

def apply_ops_2(self, dataframe):
def apply_ops_2(self, dataframe) -> None:
# TODO(VitalyFedyunin): Make this calculation thread safe (as currently it updates pointer)
# pyrefly: ignore [unsupported-operation]
self.ctx["variables"][0].calculated_value = dataframe
@ -190,7 +190,7 @@ class Capture:


class CaptureF(Capture):
def __init__(self, ctx=None, **kwargs):
def __init__(self, ctx=None, **kwargs) -> None:
if ctx is None:
self.ctx = {"operations": [], "variables": []}
else:
@ -199,7 +199,7 @@ class CaptureF(Capture):


class CaptureA(CaptureF):
def __str__(self):
def __str__(self) -> str:
return f"{self.kwargs['name']}"

def execute(self):
@ -208,7 +208,7 @@ class CaptureA(CaptureF):


class CaptureLikeMock:
def __init__(self, name):
def __init__(self, name) -> None:
import unittest.mock as mock

# TODO(VitalyFedyunin): Do not use private function here, copy own implementation instead.
@ -227,7 +227,7 @@ class CaptureLikeMock:


class CaptureCall(Capture):
def __init__(self, callable, ctx=None, **kwargs):
def __init__(self, callable, ctx=None, **kwargs) -> None:
if ctx is None:
self.ctx = {"operations": [], "variables": []}
else:
@ -235,7 +235,7 @@ class CaptureCall(Capture):
self.kwargs = kwargs
self.callable = callable

def __str__(self):
def __str__(self) -> str:
return "{callable}({args},{kwargs})".format(
callable=self.callable, **self.kwargs
)
@ -253,12 +253,12 @@ class CaptureCall(Capture):


class CaptureVariableAssign(CaptureF):
def __str__(self):
def __str__(self) -> str:
variable = self.kwargs["variable"]
value = self.kwargs["value"]
return f"{variable} = {value}"

def execute(self):
def execute(self) -> None:
self.kwargs["variable"].calculated_value = self.kwargs["value"].execute()


@ -266,7 +266,7 @@ class CaptureVariable(Capture):
# TODO(VitalyFedyunin): This should be atomic and thread safe
names_idx = 0

def __init__(self, value, ctx):
def __init__(self, value, ctx) -> None:
if CaptureControl.disabled:
raise RuntimeError("Attempting to create capture variable with capture off")
self.ctx = ctx
@ -275,7 +275,7 @@ class CaptureVariable(Capture):
CaptureVariable.names_idx += 1
self.ctx["variables"].append(self)

def __str__(self):
def __str__(self) -> str:
return self.name

def execute(self):
@ -292,12 +292,12 @@ class CaptureVariable(Capture):


class CaptureGetItem(Capture):
def __init__(self, left, key, ctx):
def __init__(self, left, key, ctx) -> None:
self.ctx = ctx
self.left = left
self.key = key

def __str__(self):
def __str__(self) -> str:
return f"{self.left}[{get_val(self.key)}]"

def execute(self):
@ -306,28 +306,28 @@ class CaptureGetItem(Capture):


class CaptureSetItem(Capture):
def __init__(self, left, key, value, ctx):
def __init__(self, left, key, value, ctx) -> None:
self.ctx = ctx
self.left = left
self.key = key
self.value = value

def __str__(self):
def __str__(self) -> str:
return f"{self.left}[{get_val(self.key)}] = {self.value}"

def execute(self):
def execute(self) -> None:
left = self.left.execute()
value = self.value.execute()
left[self.key] = value


class CaptureAdd(Capture):
def __init__(self, left, right, ctx):
def __init__(self, left, right, ctx) -> None:
self.ctx = ctx
self.left = left
self.right = right

def __str__(self):
def __str__(self) -> str:
return f"{self.left} + {self.right}"

def execute(self):
@ -335,12 +335,12 @@ class CaptureAdd(Capture):


class CaptureMul(Capture):
def __init__(self, left, right, ctx):
def __init__(self, left, right, ctx) -> None:
self.ctx = ctx
self.left = left
self.right = right

def __str__(self):
def __str__(self) -> str:
return f"{self.left} * {self.right}"

def execute(self):
@ -348,12 +348,12 @@ class CaptureMul(Capture):


class CaptureSub(Capture):
def __init__(self, left, right, ctx):
def __init__(self, left, right, ctx) -> None:
self.ctx = ctx
self.left = left
self.right = right

def __str__(self):
def __str__(self) -> str:
return f"{self.left} - {self.right}"

def execute(self):
@ -361,12 +361,12 @@ class CaptureSub(Capture):


class CaptureGetAttr(Capture):
def __init__(self, src, name, ctx):
def __init__(self, src, name, ctx) -> None:
self.ctx = ctx
self.src = src
self.name = name

def __str__(self):
def __str__(self) -> str:
return f"{self.src}.{self.name}"

def execute(self):
@ -384,7 +384,7 @@ def get_val(capture):


class CaptureInitial(CaptureVariable):
def __init__(self, schema_df=None):
def __init__(self, schema_df=None) -> None:
# pyrefly: ignore [bad-assignment]
new_ctx: dict[str, list[Any]] = {
"operations": [],
@ -441,7 +441,7 @@ class CaptureDataFrameWithDataPipeOps(CaptureDataFrame):
def filter(self, *args, **kwargs):
return self._dataframes_filter(*args, **kwargs)

def collate(self, *args, **kwargs):
def collate(self, *args, **kwargs) -> NoReturn:
raise RuntimeError("Can't collate unbatched DataFrames stream")

def __getattr__(self, attrname): # ?
@ -458,13 +458,13 @@ class DataFrameTracer(CaptureDataFrameWithDataPipeOps, IterDataPipe): # type: i

# TODO(VitalyFedyunin): Must implement all special functions of datapipes

def set_shuffle_settings(self, *args, **kwargs):
def set_shuffle_settings(self, *args, **kwargs) -> None:
pass

def is_shardable(self):
def is_shardable(self) -> bool:
return False

def __init__(self, source_datapipe, schema_df=None):
def __init__(self, source_datapipe, schema_df=None) -> None:
self.source_datapipe = source_datapipe
if schema_df is None:
schema_df = next(iter(self.source_datapipe))

@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
from collections import defaultdict
from collections.abc import Callable, Iterator, Sized
from typing import Any, Optional, TypeVar
from typing import Any, NoReturn, Optional, TypeVar

from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data.datapipes.datapipe import DataChunk, IterDataPipe
@ -18,7 +18,7 @@ __all__ = [
_T_co = TypeVar("_T_co", covariant=True)


def __getattr__(name: str):
def __getattr__(name: str) -> NoReturn:
raise AttributeError(f"module {__name__} has no attribute {name}")


@ -110,7 +110,7 @@ class UnBatcherIterDataPipe(IterDataPipe):
[0, 1, 2, 3, 4, 5, 6]
"""

def __init__(self, datapipe: IterDataPipe, unbatch_level: int = 1):
def __init__(self, datapipe: IterDataPipe, unbatch_level: int = 1) -> None:
self.datapipe = datapipe
self.unbatch_level = unbatch_level

@ -202,7 +202,7 @@ class GrouperIterDataPipe(IterDataPipe[DataChunk]):
group_size: Optional[int] = None,
guaranteed_group_size: Optional[int] = None,
drop_remaining: bool = False,
):
) -> None:
_check_unpickable_fn(group_key_fn)
# pyrefly: ignore [invalid-type-var]
self.datapipe = datapipe
@ -322,5 +322,5 @@ class GrouperIterDataPipe(IterDataPipe[DataChunk]):
self.curr_buffer_size = 0
self.buffer_elements = defaultdict(list)

def __del__(self):
def __del__(self) -> None:
self.buffer_elements.clear()

@ -1,6 +1,7 @@
# mypy: allow-untyped-defs
from collections.abc import Sized
from enum import IntEnum
from typing import NoReturn

from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data.datapipes.datapipe import IterDataPipe
@ -24,7 +25,7 @@ class _ShardingIterDataPipe(IterDataPipe):
num_of_instances: int,
instance_id: int,
sharding_group: SHARDING_PRIORITIES,
):
) -> NoReturn:
raise NotImplementedError


@ -40,7 +41,9 @@ class ShardingFilterIterDataPipe(_ShardingIterDataPipe):
source_datapipe: Iterable DataPipe that will be sharded
"""

def __init__(self, source_datapipe: IterDataPipe, sharding_group_filter=None):
def __init__(
self, source_datapipe: IterDataPipe, sharding_group_filter=None
) -> None:
self.source_datapipe = source_datapipe
self.sharding_group_filter = sharding_group_filter
self.groups: dict[int, tuple[int, int]] = {}
@ -68,7 +71,7 @@ class ShardingFilterIterDataPipe(_ShardingIterDataPipe):
self.groups[sharding_group] = (num_of_instances, instance_id)
self._update_num_of_instances()

def _update_num_of_instances(self):
def _update_num_of_instances(self) -> None:
sorted_sharding_groups = [
self.groups[key]
for key in sorted(self.groups.keys())
@ -89,7 +92,7 @@ class ShardingFilterIterDataPipe(_ShardingIterDataPipe):
if i % self.num_of_instances == self.instance_id:
yield item

def __len__(self):
def __len__(self) -> int:
if isinstance(self.source_datapipe, Sized):
return len(self.source_datapipe) // self.num_of_instances + (
1

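A brief, hypothetical illustration of the round-robin sharding visible in the hunk above (element i is yielded by the instance whose id equals i % num_of_instances). It is a standalone sketch, not code from the patch:

def shard(items, num_of_instances, instance_id):
    # Keep every element whose index maps to this instance, as in the hunk above.
    for i, item in enumerate(items):
        if i % num_of_instances == instance_id:
            yield item

# With 10 elements split across 3 instances, instance 0 keeps 4 items and the others keep 3.
assert list(shard(range(10), 3, 0)) == [0, 3, 6, 9]
assert list(shard(range(10), 3, 1)) == [1, 4, 7]
assert list(shard(range(10), 3, 2)) == [2, 5, 8]
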
@ -6,7 +6,7 @@ import os
import warnings
from collections.abc import Callable, Iterable
from io import IOBase
from typing import Any, Optional, Union
from typing import Any, NoReturn, Optional, Union

from torch.utils._import_utils import dill_available

@ -25,7 +25,9 @@ __all__ = [
DILL_AVAILABLE = dill_available()


def validate_input_col(fn: Callable, input_col: Optional[Union[int, tuple, list]]):
def validate_input_col(
fn: Callable, input_col: Optional[Union[int, tuple, list]]
) -> None:
"""
Check that function used in a callable datapipe works with the input column.

@ -131,7 +133,7 @@ def _is_local_fn(fn):
return False


def _check_unpickable_fn(fn: Callable):
def _check_unpickable_fn(fn: Callable) -> None:
"""
Check function is pickable or not.

@ -186,7 +188,7 @@ def get_file_pathnames_from_root(
non_deterministic: bool = False,
) -> Iterable[str]:
# print out an error message and raise the error out
def onerror(err: OSError):
def onerror(err: OSError) -> NoReturn:
warnings.warn(err.filename + " : " + err.strerror, stacklevel=2)
raise err

@ -235,7 +237,7 @@ def get_file_binaries_from_pathnames(
yield pathname, StreamWrapper(open(pathname, mode, encoding=encoding))


def validate_pathname_binary_tuple(data: tuple[str, IOBase]):
def validate_pathname_binary_tuple(data: tuple[str, IOBase]) -> None:
if not isinstance(data, tuple):
raise TypeError(
f"pathname binary data should be tuple type, but it is type {type(data)}"
@ -326,7 +328,7 @@ class StreamWrapper:
session_streams: dict[Any, int] = {}
debug_unclosed_streams: bool = False

def __init__(self, file_obj, parent_stream=None, name=None):
def __init__(self, file_obj, parent_stream=None, name=None) -> None:
self.file_obj = file_obj
self.child_counter = 0
self.parent_stream = parent_stream
@ -344,7 +346,7 @@ class StreamWrapper:
StreamWrapper.session_streams[self] = 1

@classmethod
def close_streams(cls, v, depth=0):
def close_streams(cls, v, depth=0) -> None:
"""Traverse structure and attempts to close all found StreamWrappers on best effort basis."""
if depth > 10:
return
@ -363,7 +365,7 @@ class StreamWrapper:
file_obj = self.__dict__["file_obj"]
return getattr(file_obj, name)

def close(self, *args, **kwargs):
def close(self, *args, **kwargs) -> None:
if self.closed:
return
if StreamWrapper.debug_unclosed_streams:
@ -381,7 +383,7 @@ class StreamWrapper:
pass
self.closed = True

def autoclose(self):
def autoclose(self) -> None:
"""Automatically close stream when all child streams are closed or if there are none."""
self.close_on_last_child = True
if self.child_counter == 0:
@ -392,7 +394,7 @@ class StreamWrapper:
attrs += dir(self.file_obj)
return list(set(attrs))

def __del__(self):
def __del__(self) -> None:
if not self.closed:
self.close()

@ -402,7 +404,7 @@ class StreamWrapper:
def __next__(self):
return next(self.file_obj)

def __repr__(self):
def __repr__(self) -> str:
if self.name is None:
return f"StreamWrapper<{self.file_obj!r}>"
else:

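A hedged usage sketch based only on the StreamWrapper signatures shown above (the file name is hypothetical): the wrapper forwards attribute access to the underlying file object, and close() returns early when the stream is already closed.

from torch.utils.data.datapipes.utils.common import StreamWrapper

with open("example.txt", "w") as f:
    f.write("hello\n")  # hypothetical file created only for this demo

stream = StreamWrapper(open("example.txt"))
print(stream.readline())  # attribute access is forwarded to the wrapped file object
stream.close()
stream.close()  # safe: close() checks self.closed before doing any work
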