[BE][PYFMT] migrate PYFMT for torch/[p-z]*/ to ruff format (#144552)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144552
Approved by: https://github.com/ezyang

Committed by: PyTorch MergeBot
Parent: fd606a3a91
Commit: 5cedc5a0ff
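
For context: ruff format, like Black's current stable style, collapses a function body consisting solely of `...` onto the signature line, and that is the only change the hunks below make to these protocol stubs. A minimal before/after illustration (hypothetical stub, not taken from the PR):

# before formatting
def dim(self) -> int:
    ...

# after ruff format
def dim(self) -> int: ...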
torch/utils/_python_dispatch.py

@@ -302,14 +302,12 @@ class BaseTorchDispatchMode(TorchDispatchMode):
 
 # Subtypes which have __tensor_flatten__ and __tensor_unflatten__.
 class TensorWithFlatten(Protocol):
-    def __tensor_flatten__(self) -> tuple[Sequence[str], object]:
-        ...
+    def __tensor_flatten__(self) -> tuple[Sequence[str], object]: ...
 
     @staticmethod
     def __tensor_unflatten__(
         inner_tensors: int, flatten_spec: int, outer_size: int, outer_stride: int
-    ) -> torch.Tensor:
-        ...
+    ) -> torch.Tensor: ...
 
     # It would be really nice to be able to say that the return of
     # is_traceable_wrapper_subclass() is Intersection[torch.Tensor,
@@ -318,26 +316,20 @@ class TensorWithFlatten(Protocol):
     shape: torch._C.Size
 
     @overload
-    def stride(self, dim: None = None) -> tuple[int, ...]:
-        ...
+    def stride(self, dim: None = None) -> tuple[int, ...]: ...
 
     @overload
-    def stride(self, dim: int) -> int:
-        ...
+    def stride(self, dim: int) -> int: ...
 
     @overload
-    def size(self, dim: None = None) -> tuple[int, ...]:
-        ...
+    def size(self, dim: None = None) -> tuple[int, ...]: ...
 
     @overload
-    def size(self, dim: int) -> int:
-        ...
+    def size(self, dim: int) -> int: ...
 
-    def storage_offset(self) -> int:
-        ...
+    def storage_offset(self) -> int: ...
 
-    def dim(self) -> int:
-        ...
+    def dim(self) -> int: ...
 
     @overload
     def to(
@@ -347,8 +339,7 @@ class TensorWithFlatten(Protocol):
         copy: bool = False,
         *,
         memory_format: Optional[torch.memory_format] = None,
-    ) -> torch.Tensor:
-        ...
+    ) -> torch.Tensor: ...
 
     @overload
     def to(
@@ -359,8 +350,7 @@ class TensorWithFlatten(Protocol):
         copy: bool = False,
         *,
         memory_format: Optional[torch.memory_format] = None,
-    ) -> torch.Tensor:
-        ...
+    ) -> torch.Tensor: ...
 
     @overload
     def to(
@@ -370,8 +360,7 @@ class TensorWithFlatten(Protocol):
         copy: bool = False,
         *,
         memory_format: Optional[torch.memory_format] = None,
-    ) -> torch.Tensor:
-        ...
+    ) -> torch.Tensor: ...
 
 
 def is_traceable_wrapper_subclass(t: object) -> TypeIs[TensorWithFlatten]:
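
The hunks above are purely cosmetic; the TensorWithFlatten protocol itself is unchanged. For readers unfamiliar with it, here is a minimal sketch of a wrapper subclass that satisfies the protocol, so that is_traceable_wrapper_subclass() returns True for its instances. This is not part of the PR: WrapperTensor is a hypothetical name, and the pattern is modeled on PyTorch's standard wrapper-subclass recipe (_make_wrapper_subclass plus __torch_dispatch__).

import torch
import torch.utils._pytree as pytree
from torch.utils._python_dispatch import is_traceable_wrapper_subclass


class WrapperTensor(torch.Tensor):
    """Hypothetical wrapper subclass holding one inner tensor, `elem`."""

    @staticmethod
    def __new__(cls, elem: torch.Tensor):
        # Build a wrapper whose metadata (shape/dtype/device) mirrors `elem`.
        return torch.Tensor._make_wrapper_subclass(
            cls, elem.shape, dtype=elem.dtype, device=elem.device
        )

    def __init__(self, elem: torch.Tensor):
        self.elem = elem

    def __tensor_flatten__(self):
        # (attribute names of the inner tensors, opaque constant metadata)
        return ["elem"], None

    @staticmethod
    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
        # Inverse of __tensor_flatten__: rebuild the wrapper from its parts.
        return WrapperTensor(inner_tensors["elem"])

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # Unwrap arguments, run the op on the inner tensors, rewrap results.
        def unwrap(t):
            return t.elem if isinstance(t, WrapperTensor) else t

        out = func(
            *pytree.tree_map(unwrap, args),
            **pytree.tree_map(unwrap, kwargs or {}),
        )
        return WrapperTensor(out) if isinstance(out, torch.Tensor) else out


t = WrapperTensor(torch.ones(3))
print(is_traceable_wrapper_subclass(t))  # True: both flatten hooks present
print((t + 1).elem)  # tensor([2., 2., 2.]), computed via __torch_dispatch__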