mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
fix docstring issues in torch.utils (#113335)
Fixes #112634 Fixes all the issues listed except in `torch/utils/_pytree.py` as the file no longer exists. ### Error counts |File | Count Before | Count now| |---- | ---- | ---- | |`torch/utils/collect_env.py` | 39 | 25| |`torch/utils/cpp_extension.py` | 51 | 13| |`torch/utils/flop_counter.py` | 25 | 8| |`torch/utils/_foreach_utils.py.py` | 2 | 0| |`torch/utils/_python_dispatch.py.py` | 26 | 25| |`torch/utils/backend_registration.py` | 15 | 4| |`torch/utils/checkpoint.py` | 29 | 21| Pull Request resolved: https://github.com/pytorch/pytorch/pull/113335 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
44367c59b2
commit
5e10dd2c78
@ -90,7 +90,7 @@ def _nt_quote_args(args: Optional[List[str]]) -> List[str]:
|
||||
return [f'"{arg}"' if ' ' in arg else arg for arg in args]
|
||||
|
||||
def _find_cuda_home() -> Optional[str]:
|
||||
r'''Finds the CUDA install path.'''
|
||||
"""Find the CUDA install path."""
|
||||
# Guess #1
|
||||
cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')
|
||||
if cuda_home is None:
|
||||
@ -120,7 +120,7 @@ def _find_cuda_home() -> Optional[str]:
|
||||
return cuda_home
|
||||
|
||||
def _find_rocm_home() -> Optional[str]:
|
||||
r'''Finds the ROCm install path.'''
|
||||
"""Find the ROCm install path."""
|
||||
# Guess #1
|
||||
rocm_home = os.environ.get('ROCM_HOME') or os.environ.get('ROCM_PATH')
|
||||
if rocm_home is None:
|
||||
@ -144,12 +144,12 @@ def _find_rocm_home() -> Optional[str]:
|
||||
|
||||
|
||||
def _join_rocm_home(*paths) -> str:
|
||||
r'''
|
||||
Joins paths with ROCM_HOME, or raises an error if it ROCM_HOME is not set.
|
||||
"""
|
||||
Join paths with ROCM_HOME, or raises an error if it ROCM_HOME is not set.
|
||||
|
||||
This is basically a lazy way of raising an error for missing $ROCM_HOME
|
||||
only once we need to get any ROCm-specific path.
|
||||
'''
|
||||
"""
|
||||
if ROCM_HOME is None:
|
||||
raise OSError('ROCM_HOME environment variable is not set. '
|
||||
'Please set it to your ROCm install root.')
|
||||
@ -282,8 +282,8 @@ def _maybe_write(filename, new_content):
|
||||
source_file.write(new_content)
|
||||
|
||||
def get_default_build_root() -> str:
|
||||
r'''
|
||||
Returns the path to the root folder under which extensions will built.
|
||||
"""
|
||||
Return the path to the root folder under which extensions will built.
|
||||
|
||||
For each extension module built, there will be one folder underneath the
|
||||
folder returned by this function. For example, if ``p`` is the path
|
||||
@ -292,13 +292,13 @@ def get_default_build_root() -> str:
|
||||
|
||||
This directory is **user-specific** so that multiple users on the same
|
||||
machine won't meet permission issues.
|
||||
'''
|
||||
"""
|
||||
return os.path.realpath(torch._appdirs.user_cache_dir(appname='torch_extensions'))
|
||||
|
||||
|
||||
def check_compiler_ok_for_platform(compiler: str) -> bool:
|
||||
r'''
|
||||
Verifies that the compiler is the expected one for the current platform.
|
||||
"""
|
||||
Verify that the compiler is the expected one for the current platform.
|
||||
|
||||
Args:
|
||||
compiler (str): The compiler executable to check.
|
||||
@ -306,7 +306,7 @@ def check_compiler_ok_for_platform(compiler: str) -> bool:
|
||||
Returns:
|
||||
True if the compiler is gcc/g++ on Linux or clang/clang++ on macOS,
|
||||
and always True for Windows.
|
||||
'''
|
||||
"""
|
||||
if IS_WINDOWS:
|
||||
return True
|
||||
which = subprocess.check_output(['which', compiler], stderr=subprocess.STDOUT)
|
||||
@ -339,9 +339,8 @@ def check_compiler_ok_for_platform(compiler: str) -> bool:
|
||||
|
||||
|
||||
def get_compiler_abi_compatibility_and_version(compiler) -> Tuple[bool, TorchVersion]:
|
||||
r'''
|
||||
Determine if the given compiler is ABI-compatible with PyTorch alongside
|
||||
its version.
|
||||
"""
|
||||
Determine if the given compiler is ABI-compatible with PyTorch alongside its version.
|
||||
|
||||
Args:
|
||||
compiler (str): The compiler executable name to check (e.g. ``g++``).
|
||||
@ -350,7 +349,7 @@ def get_compiler_abi_compatibility_and_version(compiler) -> Tuple[bool, TorchVer
|
||||
Returns:
|
||||
A tuple that contains a boolean that defines if the compiler is (likely) ABI-incompatible with PyTorch,
|
||||
followed by a `TorchVersion` string that contains the compiler version separated by dots.
|
||||
'''
|
||||
"""
|
||||
if not _is_binary_build():
|
||||
return (True, TorchVersion('0.0.0'))
|
||||
if os.environ.get('TORCH_DONT_CHECK_COMPILER_ABI') in ['ON', '1', 'YES', 'TRUE', 'Y']:
|
||||
@ -452,7 +451,7 @@ def _check_cuda_version(compiler_name: str, compiler_version: TorchVersion) -> N
|
||||
|
||||
|
||||
class BuildExtension(build_ext):
|
||||
r'''
|
||||
"""
|
||||
A custom :mod:`setuptools` build extension .
|
||||
|
||||
This :class:`setuptools.build_ext` subclass takes care of passing the
|
||||
@ -475,14 +474,11 @@ class BuildExtension(build_ext):
|
||||
extension. This may use up too many resources on some systems. One
|
||||
can control the number of workers by setting the `MAX_JOBS` environment
|
||||
variable to a non-negative number.
|
||||
'''
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def with_options(cls, **options):
|
||||
r'''
|
||||
Returns a subclass with alternative constructor that extends any original keyword
|
||||
arguments to the original constructor with the given options.
|
||||
'''
|
||||
"""Return a subclass with alternative constructor that extends any original keyword arguments to the original constructor with the given options."""
|
||||
class cls_with_options(cls): # type: ignore[misc, valid-type]
|
||||
def __init__(self, *args, **kwargs):
|
||||
kwargs.update(options)
|
||||
@ -928,8 +924,8 @@ class BuildExtension(build_ext):
|
||||
|
||||
|
||||
def CppExtension(name, sources, *args, **kwargs):
|
||||
r'''
|
||||
Creates a :class:`setuptools.Extension` for C++.
|
||||
"""
|
||||
Create a :class:`setuptools.Extension` for C++.
|
||||
|
||||
Convenience method that creates a :class:`setuptools.Extension` with the
|
||||
bare minimum (but often sufficient) arguments to build a C++ extension.
|
||||
@ -953,7 +949,7 @@ def CppExtension(name, sources, *args, **kwargs):
|
||||
... cmdclass={
|
||||
... 'build_ext': BuildExtension
|
||||
... })
|
||||
'''
|
||||
"""
|
||||
include_dirs = kwargs.get('include_dirs', [])
|
||||
include_dirs += include_paths()
|
||||
kwargs['include_dirs'] = include_dirs
|
||||
@ -974,8 +970,8 @@ def CppExtension(name, sources, *args, **kwargs):
|
||||
|
||||
|
||||
def CUDAExtension(name, sources, *args, **kwargs):
|
||||
r'''
|
||||
Creates a :class:`setuptools.Extension` for CUDA/C++.
|
||||
"""
|
||||
Create a :class:`setuptools.Extension` for CUDA/C++.
|
||||
|
||||
Convenience method that creates a :class:`setuptools.Extension` with the
|
||||
bare minimum (but often sufficient) arguments to build a CUDA/C++
|
||||
@ -1072,7 +1068,7 @@ def CUDAExtension(name, sources, *args, **kwargs):
|
||||
... dlink_libraries=["dlink_lib"],
|
||||
... extra_compile_args={'cxx': ['-g'],
|
||||
... 'nvcc': ['-O2', '-rdc=true']})
|
||||
'''
|
||||
"""
|
||||
library_dirs = kwargs.get('library_dirs', [])
|
||||
library_dirs += library_paths(cuda=True)
|
||||
kwargs['library_dirs'] = library_dirs
|
||||
@ -1145,7 +1141,7 @@ def CUDAExtension(name, sources, *args, **kwargs):
|
||||
|
||||
|
||||
def include_paths(cuda: bool = False) -> List[str]:
|
||||
'''
|
||||
"""
|
||||
Get the include paths required to build a C++ or CUDA extension.
|
||||
|
||||
Args:
|
||||
@ -1153,7 +1149,7 @@ def include_paths(cuda: bool = False) -> List[str]:
|
||||
|
||||
Returns:
|
||||
A list of include path strings.
|
||||
'''
|
||||
"""
|
||||
lib_include = os.path.join(_TORCH_PATH, 'include')
|
||||
paths = [
|
||||
lib_include,
|
||||
@ -1179,7 +1175,7 @@ def include_paths(cuda: bool = False) -> List[str]:
|
||||
|
||||
|
||||
def library_paths(cuda: bool = False) -> List[str]:
|
||||
r'''
|
||||
"""
|
||||
Get the library paths required to build a C++ or CUDA extension.
|
||||
|
||||
Args:
|
||||
@ -1187,7 +1183,7 @@ def library_paths(cuda: bool = False) -> List[str]:
|
||||
|
||||
Returns:
|
||||
A list of library path strings.
|
||||
'''
|
||||
"""
|
||||
# We need to link against libtorch.so
|
||||
paths = [TORCH_LIB_PATH]
|
||||
|
||||
@ -1226,8 +1222,8 @@ def load(name,
|
||||
is_python_module=True,
|
||||
is_standalone=False,
|
||||
keep_intermediates=True):
|
||||
r'''
|
||||
Loads a PyTorch C++ extension just-in-time (JIT).
|
||||
"""
|
||||
Load a PyTorch C++ extension just-in-time (JIT).
|
||||
|
||||
To load an extension, a Ninja build file is emitted, which is used to
|
||||
compile the given sources into a dynamic library. This library is
|
||||
@ -1305,7 +1301,7 @@ def load(name,
|
||||
... sources=['extension.cpp', 'extension_kernel.cu'],
|
||||
... extra_cflags=['-O2'],
|
||||
... verbose=True)
|
||||
'''
|
||||
"""
|
||||
return _jit_compile(
|
||||
name,
|
||||
[sources] if isinstance(sources, str) else sources,
|
||||
@ -1513,7 +1509,7 @@ def load_inline(name,
|
||||
keep_intermediates=True,
|
||||
use_pch=False):
|
||||
r'''
|
||||
Loads a PyTorch C++ extension just-in-time (JIT) from string sources.
|
||||
Load a PyTorch C++ extension just-in-time (JIT) from string sources.
|
||||
|
||||
This function behaves exactly like :func:`load`, but takes its sources as
|
||||
strings rather than filenames. These strings are stored to files in the
|
||||
@ -1830,10 +1826,7 @@ def _write_ninja_file_and_build_library(
|
||||
|
||||
|
||||
def is_ninja_available():
|
||||
r'''
|
||||
Returns ``True`` if the `ninja <https://ninja-build.org/>`_ build system is
|
||||
available on the system, ``False`` otherwise.
|
||||
'''
|
||||
"""Return ``True`` if the `ninja <https://ninja-build.org/>`_ build system is available on the system, ``False`` otherwise."""
|
||||
try:
|
||||
subprocess.check_output('ninja --version'.split())
|
||||
except Exception:
|
||||
@ -1843,10 +1836,7 @@ def is_ninja_available():
|
||||
|
||||
|
||||
def verify_ninja_availability():
|
||||
r'''
|
||||
Raises ``RuntimeError`` if `ninja <https://ninja-build.org/>`_ build system is not
|
||||
available on the system, does nothing otherwise.
|
||||
'''
|
||||
"""Raise ``RuntimeError`` if `ninja <https://ninja-build.org/>`_ build system is not available on the system, does nothing otherwise."""
|
||||
if not is_ninja_available():
|
||||
raise RuntimeError("Ninja is required to load C++ extensions")
|
||||
|
||||
@ -1916,7 +1906,7 @@ def _prepare_ldflags(extra_ldflags, with_cuda, verbose, is_standalone):
|
||||
|
||||
|
||||
def _get_cuda_arch_flags(cflags: Optional[List[str]] = None) -> List[str]:
|
||||
r'''
|
||||
"""
|
||||
Determine CUDA arch flags to use.
|
||||
|
||||
For an arch, say "6.1", the added compile flag will be
|
||||
@ -1926,7 +1916,7 @@ def _get_cuda_arch_flags(cflags: Optional[List[str]] = None) -> List[str]:
|
||||
|
||||
See select_compute_arch.cmake for corresponding named and supported arches
|
||||
when building with CMake.
|
||||
'''
|
||||
"""
|
||||
# If cflags is given, there may already be user-provided arch flags in it
|
||||
# (from `extra_compile_args`)
|
||||
if cflags is not None:
|
||||
@ -2406,12 +2396,12 @@ def _write_ninja_file(path,
|
||||
_maybe_write(path, content)
|
||||
|
||||
def _join_cuda_home(*paths) -> str:
|
||||
r'''
|
||||
Joins paths with CUDA_HOME, or raises an error if it CUDA_HOME is not set.
|
||||
"""
|
||||
Join paths with CUDA_HOME, or raises an error if it CUDA_HOME is not set.
|
||||
|
||||
This is basically a lazy way of raising an error for missing $CUDA_HOME
|
||||
only once we need to get any CUDA-specific path.
|
||||
'''
|
||||
"""
|
||||
if CUDA_HOME is None:
|
||||
raise OSError('CUDA_HOME environment variable is not set. '
|
||||
'Please set it to your CUDA install root.')
|
||||
|
Reference in New Issue
Block a user