mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Summary: Fixes: https://github.com/pytorch/pytorch/issues/19045 Please review: VitalyFedyunin ngimel This is independent on the #18649 series. This will cause merge conflicts in #18649 series, but please merge this first, and I will resolve the merge conflicts there. The new feature is exposed in `_unique2_temporary_will_remove_soon` and `_unique_dim2_temporary_will_remove_soon`. But not at `torch.unique` yet. I will take care of the API after #18649 series get merged completely. Benchmark on a tensor of shape `torch.Size([15320, 2])`: ```python print(torch.__version__) print() a = tensor.sort().values.to('cpu') print('cpu, sorted_input=False:') %timeit torch._unique2_temporary_will_remove_soon(a) %timeit torch._unique2_temporary_will_remove_soon(a, return_inverse=True) %timeit torch._unique2_temporary_will_remove_soon(a, return_counts=True) %timeit torch._unique2_temporary_will_remove_soon(a, return_inverse=True, return_counts=True) print() print('cpu, sorted_input=True:') %timeit torch._unique2_temporary_will_remove_soon(a, sorted_input=True) %timeit torch._unique2_temporary_will_remove_soon(a, sorted_input=True, return_inverse=True) %timeit torch._unique2_temporary_will_remove_soon(a, sorted_input=True, return_counts=True) %timeit torch._unique2_temporary_will_remove_soon(a, sorted_input=True, return_inverse=True, return_counts=True) print() a = a.to('cuda') print('cuda, sorted_input=False:') %timeit torch._unique2_temporary_will_remove_soon(a); torch.cuda.synchronize() %timeit torch._unique2_temporary_will_remove_soon(a, return_inverse=True); torch.cuda.synchronize() %timeit torch._unique2_temporary_will_remove_soon(a, return_counts=True); torch.cuda.synchronize() %timeit torch._unique2_temporary_will_remove_soon(a, return_inverse=True, return_counts=True); torch.cuda.synchronize() print() print('cuda, sorted_input=True:') %timeit torch._unique2_temporary_will_remove_soon(a, sorted_input=True); torch.cuda.synchronize() %timeit 
torch._unique2_temporary_will_remove_soon(a, sorted_input=True, return_inverse=True); torch.cuda.synchronize() %timeit torch._unique2_temporary_will_remove_soon(a, sorted_input=True, return_counts=True); torch.cuda.synchronize() %timeit torch._unique2_temporary_will_remove_soon(a, sorted_input=True, return_inverse=True, return_counts=True); torch.cuda.synchronize() ``` ``` 1.1.0a0+2addccc cpu, sorted_input=False: 340 µs ± 5.88 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) 717 µs ± 14.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) 52.3 ms ± 2.75 ms per loop (mean ± std. dev. of 7 runs, 10 loops each) 52.3 ms ± 1.79 ms per loop (mean ± std. dev. of 7 runs, 10 loops each) cpu, sorted_input=True: 32.8 µs ± 285 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 49.9 µs ± 557 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 51.6 µs ± 1.08 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each) 78 µs ± 782 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) cuda, sorted_input=False: 213 µs ± 1.52 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) 291 µs ± 3.81 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) 250 µs ± 1.05 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) 321 µs ± 1.59 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) cuda, sorted_input=True: 45.6 µs ± 2.13 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each) 110 µs ± 2.47 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each) 82 µs ± 857 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 143 µs ± 409 ns per loop (mean ± std. dev. 
of 7 runs, 10000 loops each) ``` ```python print(torch.__version__) print() a1, a2 = tensor.unbind(1) indices = (a1 * tensor.max() + a2).sort().indices a = tensor.index_select(0, indices).to('cpu') print('cpu, sorted_input=False:') %timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0) %timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0, return_inverse=True) %timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0, return_counts=True) %timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0, return_inverse=True, return_counts=True) print() print('cpu, sorted_input=True:') %timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0, sorted_input=True) %timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0, sorted_input=True, return_inverse=True) %timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0, sorted_input=True, return_counts=True) %timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0, sorted_input=True, return_inverse=True, return_counts=True) print() a = a.to('cuda') print('cuda, sorted_input=False:') %timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0); torch.cuda.synchronize() %timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0, return_inverse=True); torch.cuda.synchronize() %timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0, return_counts=True); torch.cuda.synchronize() %timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0, return_inverse=True, return_counts=True); torch.cuda.synchronize() print() print('cuda, sorted_input=True:') %timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0, sorted_input=True); torch.cuda.synchronize() %timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0, sorted_input=True, return_inverse=True); torch.cuda.synchronize() %timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0, sorted_input=True, return_counts=True); torch.cuda.synchronize() %timeit torch._unique_dim2_temporary_will_remove_soon(a, 
dim=0, sorted_input=True, return_inverse=True, return_counts=True); torch.cuda.synchronize() ``` ``` cpu, sorted_input=False: 55.4 ms ± 1.12 ms per loop (mean ± std. dev. of 7 runs, 10 loops each) 55.8 ms ± 616 µs per loop (mean ± std. dev. of 7 runs, 10 loops each) 55.2 ms ± 402 µs per loop (mean ± std. dev. of 7 runs, 10 loops each) 55.1 ms ± 725 µs per loop (mean ± std. dev. of 7 runs, 10 loops each) cpu, sorted_input=True: 54.7 ms ± 585 µs per loop (mean ± std. dev. of 7 runs, 10 loops each) 55.2 ms ± 1.23 ms per loop (mean ± std. dev. of 7 runs, 10 loops each) 54.5 ms ± 865 µs per loop (mean ± std. dev. of 7 runs, 10 loops each) 54.9 ms ± 577 µs per loop (mean ± std. dev. of 7 runs, 10 loops each) cuda, sorted_input=False: 171 µs ± 783 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 220 µs ± 1.65 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) 203 µs ± 2.95 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) 251 µs ± 2.83 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) cuda, sorted_input=True: 59.6 µs ± 757 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 113 µs ± 431 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 93.2 µs ± 2.13 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each) 147 µs ± 2.81 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each) ``` The CPU implementation of `unique_dim` is super slow, see https://github.com/pytorch/pytorch/issues/18987, but this PR will not worry about this issue. Pull Request resolved: https://github.com/pytorch/pytorch/pull/19060 Differential Revision: D14866909 Pulled By: ezyang fbshipit-source-id: d20012cec68c37b05cf770a6f4d6524f910b950f
108 lines
3.6 KiB
Python
108 lines
3.6 KiB
Python
# ${generated_comment}
|
|
|
|
from typing import List, Tuple, Optional, Union, Any, ContextManager, Callable, overload
|
|
from torch._six import inf
|
|
|
|
import builtins
|
|
|
|
# These identifiers are reexported from other modules. These modules
|
|
# are not mypy-clean yet, so in order to use this stub file usefully
|
|
# from mypy you will need to specify --follow-imports=silent.
|
|
# Not all is lost: these imports still enable IDEs like PyCharm to offer
|
|
# autocomplete.
|
|
#
|
|
# Note: Why does the syntax here look so strange? Import visibility
|
|
# rules in stubs are different from normal Python files! You must use
|
|
# 'from ... import ... as ...' syntax to cause an identifier to be
|
|
# exposed (or use a wildcard); regular syntax is not exposed.
|
|
from .random import set_rng_state as set_rng_state, get_rng_state as get_rng_state, \
|
|
manual_seed as manual_seed, initial_seed as initial_seed
|
|
from ._tensor_str import set_printoptions as set_printoptions
|
|
from .functional import *
|
|
from .serialization import save as save, load as load
|
|
from .autograd import no_grad as no_grad, enable_grad as enable_grad, \
|
|
set_grad_enabled as set_grad_enabled
|
|
from . import cuda as cuda
|
|
from . import optim as optim
|
|
|
|
class dtype:
    """Placeholder for torch.dtype; the concrete class is provided natively at runtime."""
|
|
|
|
class layout:
    """Placeholder for torch.layout (memory-layout tag); implemented natively."""


# The strided (dense) layout singleton.  The `...` here is only a stub value;
# presumably the runtime replaces it with a real `layout` instance -- TODO confirm.
# (PEP 8 fix: no space before the colon in an annotated assignment.)
strided: layout = ...
|
|
|
|
# See https://github.com/python/mypy/issues/4146 for why these workarounds
# are necessary
|
|
# Aliases for the built-in int/float.  Inside the class bodies below, method
# names can shadow the plain builtin names (see the mypy issue referenced
# above), so stub signatures use these underscored aliases instead.
_int = builtins.int
_float = builtins.float
|
|
|
|
class device:
    """Stub for torch.device, identifying where a Tensor lives."""

    # Device kind string -- presumably "cpu", "cuda", etc.; confirm against
    # the native implementation.
    type: str
    # Ordinal of the device within its kind.
    index: _int

    # Construct from a single value: either a device ordinal or a string
    # (e.g. "cuda:0") -- inferred from the Union type; TODO confirm.
    @overload
    def __init__(self, device: Union[_int, str]) -> None: ...

    # Construct from an explicit (type, index) pair.
    @overload
    def __init__(self, type: str, index: _int) -> None: ...
|
|
|
|
class Generator:
    """Placeholder for torch.Generator (presumably the RNG state object); implemented natively."""
|
|
|
|
class Size(tuple):
    """torch.Size: a tuple subclass used for tensor shapes (see Tensor.shape)."""
|
|
|
|
class Storage:
    """Placeholder for torch Storage; implemented natively at runtime."""
|
|
|
|
# See https://github.com/python/mypy/issues/4146 for why these workarounds
# are necessary
|
|
# Workaround aliases (see the mypy issue referenced above): Tensor attributes
# below are annotated with types named `dtype` / `device`, which would
# otherwise be shadowed inside the class body.
_dtype = dtype
_device = device
# Anything acceptable where a size is expected: a Size, or a list/tuple of ints.
_size = Union[Size, List[_int], Tuple[_int, ...]]

# Meta-type for "numeric" things; matches our docs
Number = Union[builtins.int, builtins.float]
|
|
|
|
# TODO: One downside of doing it this way, is direct use of
|
|
# torch.tensor.Tensor doesn't get type annotations. Nobody
|
|
# should really do that, so maybe this is not so bad.
|
|
class Tensor:
    """Type-annotation stub for torch.Tensor.

    Most method signatures are generated and spliced in at the
    ``${tensor_method_hints}`` placeholder below; the remainder mirror the
    methods defined by hand in torch/tensor.py.
    """
    # Attributes available on every Tensor; the `...` defaults are stub
    # values only -- the real values come from the native implementation.
    dtype: _dtype = ...
    shape: Size = ...
    device: _device = ...
    requires_grad: bool = ...
    # Optional: presumably None until autograd populates it -- TODO confirm.
    grad: Optional[Tensor] = ...

    ${tensor_method_hints}

    # Manually defined methods from torch/tensor.py
    def backward(self, gradient: Optional[Tensor]=None, retain_graph: Optional[bool]=None, create_graph: bool=False) -> None: ...
    def register_hook(self, hook: Callable) -> Any: ...
    def retain_grad(self) -> None: ...
    def is_pinned(self) -> bool: ...
    def is_shared(self) -> bool: ...
    def share_memory_(self) -> None: ...
    # TODO: fill in the types for these, or otherwise figure out some
    # way to not have to write these out again...
    def norm(self, p="fro", dim=None, keepdim=False): ...
    def stft(self, n_fft, hop_length=None, win_length=None, window=None,
             center=True, pad_mode='reflect', normalized=False, onesided=True): ...
    def split(self, split_size, dim=0): ...
    def unique(self, sorted=True, return_inverse=False, dim=None): ...
    def unique_consecutive(self, sorted=True, return_inverse=False, return_counts=False, dim=None): ...
    def lu(self, pivot=True, get_infos=False): ...
|
|
|
|
${function_hints}
|
|
|
|
${legacy_class_hints}
|
|
|
|
${dtype_class_hints}
|
|
|
|
# Pure Python functions defined in torch/__init__.py

# Return a string naming the type of `obj`.
def typename(obj) -> str: ...
# True iff `obj` is a torch Tensor.
def is_tensor(obj) -> bool: ...
# True iff `obj` is a torch Storage.
def is_storage(obj) -> bool: ...
def set_default_tensor_type(type) -> None: ...  # ick, what a bad legacy API
# Set the default floating-point dtype.
# (PEP 8 fix: no space before the colon in an annotated parameter.)
def set_default_dtype(d: _dtype) -> None: ...
def manager_path() -> str: ...
def compiled_with_cxx11_abi() -> bool: ...
|