[BE] enable UFMT for torch/nn/functional.py (#128592)

Part of #123062

- #123062

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128592
Approved by: https://github.com/mikaylagawarecki
ghstack dependencies: #128596, #128594
This commit is contained in:
Xuehai Pan
2024-06-17 03:58:20 +08:00
committed by PyTorch MergeBot
parent 95ac2d6482
commit f6e6e55fa7
18 changed files with 1360 additions and 492 deletions

View File

@ -13,7 +13,8 @@ out_dims_t = Union[int, Tuple[int, ...]]
# Checks that all args-to-be-batched have the same batch dim size
def _validate_and_get_batch_size(
flat_in_dims: List[Optional[int]], flat_args: List
flat_in_dims: List[Optional[int]],
flat_args: List,
) -> int:
batch_sizes = [
arg.size(in_dim)
@ -37,7 +38,9 @@ def _num_outputs(batched_outputs: Union[Tensor, Tuple[Tensor, ...]]) -> int:
# If value is a tuple, check it has length `num_elements`.
# If value is not a tuple, make a tuple with `value` repeated `num_elements` times
def _as_tuple(
value: Any, num_elements: int, error_message_lambda: Callable[[], str]
value: Any,
num_elements: int,
error_message_lambda: Callable[[], str],
) -> Tuple:
if not isinstance(value, tuple):
return (value,) * num_elements
@ -49,7 +52,10 @@ def _as_tuple(
# Creates BatchedTensors for every Tensor in arg that should be batched.
# Returns the (potentially) batched arguments and the batch_size.
def _create_batched_inputs(
in_dims: in_dims_t, args: Tuple, vmap_level: int, func: Callable
in_dims: in_dims_t,
args: Tuple,
vmap_level: int,
func: Callable,
) -> Tuple[Tuple, int]:
if not isinstance(in_dims, int) and not isinstance(in_dims, tuple):
raise ValueError(