[BE] enable UFMT for torch/nn/functional.py (#128592)
Part of #123062

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128592
Approved by: https://github.com/mikaylagawarecki
ghstack dependencies: #128596, #128594
Committed by: PyTorch MergeBot
Parent: 95ac2d6482
Commit: f6e6e55fa7
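For context, UFMT runs µsort (import sorting) plus Black (code formatting). The gist of the change in the hunks below: once a signature no longer fits within the line-length limit, Black places each parameter on its own line and adds a trailing comma. The snippet below is illustrative only; the function and parameter names are invented, not taken from the patch.

# Before formatting: the whole signature sits on one long line.
def normalize_rows(input_values: list, scale_factors: list, clamp_minimum: float = 0.0, clamp_maximum: float = 1.0) -> list: ...


# After Black-style formatting: one parameter per line, with a trailing comma.
def normalize_rows(
    input_values: list,
    scale_factors: list,
    clamp_minimum: float = 0.0,
    clamp_maximum: float = 1.0,
) -> list: ...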
@@ -13,7 +13,8 @@ out_dims_t = Union[int, Tuple[int, ...]]
 
 # Checks that all args-to-be-batched have the same batch dim size
 def _validate_and_get_batch_size(
-    flat_in_dims: List[Optional[int]], flat_args: List
+    flat_in_dims: List[Optional[int]],
+    flat_args: List,
 ) -> int:
     batch_sizes = [
         arg.size(in_dim)
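The hunk above only shows the top of _validate_and_get_batch_size. As context for the reformatted signature, here is a hedged sketch of what the whole helper plausibly does, reconstructed from the visible comment and fragment rather than copied verbatim from the PyTorch source; it assumes at least one argument is actually batched.

from typing import List, Optional

def _validate_and_get_batch_size(
    flat_in_dims: List[Optional[int]],
    flat_args: List,
) -> int:
    # Collect each batched argument's size along its batch dimension
    # (entries with in_dim=None are not batched and are skipped).
    batch_sizes = [
        arg.size(in_dim)
        for in_dim, arg in zip(flat_in_dims, flat_args)
        if in_dim is not None
    ]
    # All batched arguments must agree on the batch dim size.
    if any(size != batch_sizes[0] for size in batch_sizes):
        raise ValueError(
            f"vmap: Expected all tensors to have the same size in the mapped "
            f"dimension, got sizes {batch_sizes} for the mapped dimension"
        )
    # Assumes at least one in_dim is not None, so batch_sizes is non-empty.
    return batch_sizes[0]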
@@ -37,7 +38,9 @@ def _num_outputs(batched_outputs: Union[Tensor, Tuple[Tensor, ...]]) -> int:
 # If value is a tuple, check it has length `num_elements`.
 # If value is not a tuple, make a tuple with `value` repeated `num_elements` times
 def _as_tuple(
-    value: Any, num_elements: int, error_message_lambda: Callable[[], str]
+    value: Any,
+    num_elements: int,
+    error_message_lambda: Callable[[], str],
 ) -> Tuple:
     if not isinstance(value, tuple):
         return (value,) * num_elements
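Similarly, the second hunk truncates _as_tuple right after its early return. The two comments fully describe its behavior, so the complete helper plausibly reads as below (a sketch, not a verbatim excerpt from the file).

from typing import Any, Callable, Tuple

def _as_tuple(
    value: Any,
    num_elements: int,
    error_message_lambda: Callable[[], str],
) -> Tuple:
    # Non-tuples are repeated to form a tuple of the requested length.
    if not isinstance(value, tuple):
        return (value,) * num_elements
    # Tuples must already have exactly `num_elements` entries.
    if len(value) != num_elements:
        raise ValueError(error_message_lambda())
    return value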
@@ -49,7 +52,10 @@ def _as_tuple(
 # Creates BatchedTensors for every Tensor in arg that should be batched.
 # Returns the (potentially) batched arguments and the batch_size.
 def _create_batched_inputs(
-    in_dims: in_dims_t, args: Tuple, vmap_level: int, func: Callable
+    in_dims: in_dims_t,
+    args: Tuple,
+    vmap_level: int,
+    func: Callable,
 ) -> Tuple[Tuple, int]:
     if not isinstance(in_dims, int) and not isinstance(in_dims, tuple):
         raise ValueError(
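The third hunk cuts off _create_batched_inputs just inside its in_dims validation. As a rough, simplified sketch of its contract only: it assumes a flat tuple of inputs, reuses the _validate_and_get_batch_size sketch above, and omits the BatchedTensor wrapping at vmap_level that the real helper performs.

from typing import Callable, List, Optional, Tuple, Union

in_dims_t = Union[int, Tuple[int, ...]]

def _create_batched_inputs(
    in_dims: in_dims_t,
    args: Tuple,
    vmap_level: int,
    func: Callable,
) -> Tuple[Tuple, int]:
    # `in_dims` must be an int (applied to every arg) or a tuple with one
    # entry per argument.
    if not isinstance(in_dims, int) and not isinstance(in_dims, tuple):
        raise ValueError(
            f"vmap({getattr(func, '__name__', repr(func))}, in_dims={in_dims}, ...): "
            f"expected `in_dims` to be an int or a tuple, got {type(in_dims)}"
        )
    flat_in_dims: List[Optional[int]] = (
        list(in_dims) if isinstance(in_dims, tuple) else [in_dims] * len(args)
    )
    batch_size = _validate_and_get_batch_size(flat_in_dims, list(args))
    # Placeholder: return the args unchanged instead of BatchedTensor-wrapped
    # versions, since the wrapping machinery is out of scope for this sketch.
    return args, batch_size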