Fix non-bitwise type annotations for Tensor operators (see #145838) (#146845)

Fix https://github.com/pytorch/pytorch/issues/145838

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146845
Approved by: https://github.com/Skylion007
Authored by Tom Ritchford on 2025-06-24 11:29:29 +00:00, committed by PyTorch MergeBot.
parent cb853945a7
commit e2c9d8d641
11 changed files with 95 additions and 103 deletions

View File

@ -6,15 +6,15 @@ add_loop_eager_dynamic,compile_time_instruction_count,4300194436,0.025
add_loop_inductor,compile_time_instruction_count,29630000000,0.015
add_loop_inductor,compile_time_instruction_count,29200000000,0.015
add_loop_inductor_dynamic_gpu,compile_time_instruction_count,39110000000,0.025
add_loop_inductor_dynamic_gpu,compile_time_instruction_count,38440000000,0.025
add_loop_inductor_gpu,compile_time_instruction_count,26180000000,0.015
add_loop_inductor_gpu,compile_time_instruction_count,25750000000,0.015
@ -22,11 +22,11 @@ basic_modules_ListOfLinears_eager,compile_time_instruction_count,942514329,0.015
basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18660000000,0.015
basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18430000000,0.015
basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,16750000000,0.015
basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,16550000000,0.015
@ -34,7 +34,7 @@ basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,10370000
update_hint_regression,compile_time_instruction_count,1677000000,0.02
update_hint_regression,compile_time_instruction_count,1661000000,0.02
@ -50,7 +50,7 @@ symint_sum_loop,compile_time_instruction_count,4216000000,0.015
aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2113000000,0.015
aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2101000000,0.015
@ -58,11 +58,11 @@ aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,6022000000,0
aotdispatcher_partitioner_cpu,compile_time_instruction_count,8844000000,0.015
aotdispatcher_partitioner_cpu,compile_time_instruction_count,8775000000,0.015
aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1963000000,0.015
aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1925000000,0.015


View File

@ -1,4 +1,4 @@
from typing import Any, Union
from typing import Union
from typing_extensions import assert_type, TypeAlias
from torch import randn, Tensor
@ -27,7 +27,7 @@ assert_type(TENSOR >= TENSOR, Tensor)
assert_type(TENSOR + TENSOR, Tensor)
assert_type(TENSOR - TENSOR, Tensor)
assert_type(TENSOR * TENSOR, Tensor)
assert_type(TENSOR // TENSOR, Any)
assert_type(TENSOR // TENSOR, Tensor)
assert_type(TENSOR / TENSOR, Tensor)
assert_type(TENSOR % TENSOR, Tensor)
assert_type(TENSOR**TENSOR, Tensor)
@ -46,7 +46,7 @@ assert_type(TENSOR >= BOOL, Tensor)
assert_type(TENSOR + BOOL, Tensor)
assert_type(TENSOR - BOOL, Tensor)
assert_type(TENSOR * BOOL, Tensor)
assert_type(TENSOR // BOOL, Any)
assert_type(TENSOR // BOOL, Tensor)
assert_type(TENSOR / BOOL, Tensor)
assert_type(TENSOR % BOOL, Tensor)
assert_type(TENSOR**BOOL, Tensor)
@ -63,14 +63,14 @@ assert_type(BOOL > TENSOR, Tensor)
assert_type(BOOL <= TENSOR, Tensor)
assert_type(BOOL >= TENSOR, Tensor)
assert_type(BOOL + TENSOR, Tensor)
assert_type(BOOL - TENSOR, Any)
assert_type(BOOL - TENSOR, Tensor)
assert_type(BOOL * TENSOR, Tensor)
assert_type(BOOL // TENSOR, Any)
assert_type(BOOL / TENSOR, Any)
assert_type(BOOL % TENSOR, Any)
assert_type(BOOL**TENSOR, Any)
assert_type(BOOL << TENSOR, Any)
assert_type(BOOL >> TENSOR, Any)
assert_type(BOOL // TENSOR, Tensor)
assert_type(BOOL / TENSOR, Tensor)
assert_type(BOOL % TENSOR, Tensor)
assert_type(BOOL**TENSOR, Tensor)
assert_type(BOOL << TENSOR, Tensor)
assert_type(BOOL >> TENSOR, Tensor)
assert_type(BOOL & TENSOR, Tensor)
assert_type(BOOL | TENSOR, Tensor)
assert_type(BOOL ^ TENSOR, Tensor)
@ -84,7 +84,7 @@ assert_type(TENSOR >= INT, Tensor)
assert_type(TENSOR + INT, Tensor)
assert_type(TENSOR - INT, Tensor)
assert_type(TENSOR * INT, Tensor)
assert_type(TENSOR // INT, Any)
assert_type(TENSOR // INT, Tensor)
assert_type(TENSOR / INT, Tensor)
assert_type(TENSOR % INT, Tensor)
assert_type(TENSOR**INT, Tensor)
@ -101,14 +101,14 @@ assert_type(INT > TENSOR, Tensor)
assert_type(INT <= TENSOR, Tensor)
assert_type(INT >= TENSOR, Tensor)
assert_type(INT + TENSOR, Tensor)
assert_type(INT - TENSOR, Any)
assert_type(INT - TENSOR, Tensor)
assert_type(INT * TENSOR, Tensor)
assert_type(INT // TENSOR, Any)
assert_type(INT / TENSOR, Any)
assert_type(INT % TENSOR, Any)
assert_type(INT**TENSOR, Any)
assert_type(INT << TENSOR, Any)
assert_type(INT >> TENSOR, Any)
assert_type(INT // TENSOR, Tensor)
assert_type(INT / TENSOR, Tensor)
assert_type(INT % TENSOR, Tensor)
assert_type(INT**TENSOR, Tensor)
assert_type(INT << TENSOR, Tensor)
assert_type(INT >> TENSOR, Tensor)
assert_type(INT & TENSOR, Tensor)
assert_type(INT | TENSOR, Tensor)
assert_type(INT ^ TENSOR, Tensor)
@ -122,7 +122,7 @@ assert_type(TENSOR >= FLOAT, Tensor)
assert_type(TENSOR + FLOAT, Tensor)
assert_type(TENSOR - FLOAT, Tensor)
assert_type(TENSOR * FLOAT, Tensor)
assert_type(TENSOR // FLOAT, Any)
assert_type(TENSOR // FLOAT, Tensor)
assert_type(TENSOR / FLOAT, Tensor)
assert_type(TENSOR % FLOAT, Tensor)
assert_type(TENSOR**FLOAT, Tensor)
@ -139,14 +139,17 @@ assert_type(FLOAT > TENSOR, Tensor)
assert_type(FLOAT <= TENSOR, Tensor)
assert_type(FLOAT >= TENSOR, Tensor)
assert_type(FLOAT + TENSOR, Tensor)
assert_type(FLOAT - TENSOR, Any)
assert_type(FLOAT - TENSOR, Tensor)
assert_type(FLOAT * TENSOR, Tensor)
assert_type(FLOAT // TENSOR, Any)
assert_type(FLOAT / TENSOR, Any)
assert_type(FLOAT % TENSOR, Any)
assert_type(FLOAT**TENSOR, Any)
assert_type(FLOAT << TENSOR, Any)
assert_type(FLOAT >> TENSOR, Any)
assert_type(FLOAT // TENSOR, Tensor)
assert_type(FLOAT / TENSOR, Tensor)
assert_type(FLOAT % TENSOR, Tensor)
assert_type(FLOAT**TENSOR, Tensor)
assert_type(FLOAT << TENSOR, Tensor)
assert_type(FLOAT >> TENSOR, Tensor)
assert_type(FLOAT & TENSOR, Tensor) # type: ignore[operator]
assert_type(FLOAT | TENSOR, Tensor) # type: ignore[operator]
assert_type(FLOAT ^ TENSOR, Tensor) # type: ignore[operator]
NUMBER: TypeAlias = Union[int, float, bool]
@ -370,38 +373,3 @@ assert_type(BOOL**BINARY, Binary)
assert_type(BOOL >> BINARY, Binary)
assert_type(BOOL - BINARY, Binary)
assert_type(BOOL ^ BINARY, Binary)
# Tensor operators whose types could be improved
# This is the "diff" of the first and second sections.
assert_type(BOOL // TENSOR, Any)
assert_type(FLOAT // TENSOR, Any)
assert_type(INT // TENSOR, Any)
assert_type(TENSOR // BOOL, Any)
assert_type(TENSOR // FLOAT, Any)
assert_type(TENSOR // INT, Any)
assert_type(TENSOR // TENSOR, Any)
assert_type(BOOL**TENSOR, Any)
assert_type(FLOAT**TENSOR, Any)
assert_type(INT**TENSOR, Any)
assert_type(BOOL - TENSOR, Any)
assert_type(FLOAT - TENSOR, Any)
assert_type(INT - TENSOR, Any)
assert_type(BOOL / TENSOR, Any)
assert_type(FLOAT / TENSOR, Any)
assert_type(INT / TENSOR, Any)
assert_type(BOOL % TENSOR, Any)
assert_type(FLOAT % TENSOR, Any)
assert_type(INT % TENSOR, Any)
assert_type(BOOL << TENSOR, Any)
assert_type(FLOAT << TENSOR, Any)
assert_type(INT << TENSOR, Any)
assert_type(BOOL >> TENSOR, Any)
assert_type(FLOAT >> TENSOR, Any)
assert_type(INT >> TENSOR, Any)
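
With the reverse-operator annotations fixed, these checks collapse into the main sections of the test, so the "could be improved" list above is deleted. Below is a minimal sketch of how such assert_type checks behave after this change, using a fresh tensor rather than the test file's TENSOR constant:

from typing_extensions import assert_type
from torch import randn, Tensor

t = randn(3)

# assert_type is a no-op at runtime; it only asks the static checker to verify
# that the inferred type of its first argument is exactly the second argument.
assert_type(3 - t, Tensor)      # __rsub__ now returns Tensor, not Any
assert_type(2.0 / t, Tensor)    # __rtruediv__ (aliased from __rdiv__)
assert_type(True // t, Tensor)  # __rfloordiv__
assert_type(4 ** t, Tensor)     # __rpow__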

View File

@ -814,7 +814,7 @@ def slice_scatter(
if start == 0 and end == dim_size and step == 1:
return src.clone()
indices = [None] * input.dim()
indices: list[Optional[Tensor]] = [None] * input.dim()
idx = torch.arange(dim_size, device=input.device)
indices[dim] = (idx - start) // step
@ -1677,6 +1677,7 @@ def native_layer_norm_backward(
)
mean = _unsqueeze_to_dim(mean, input_cast.dim()) # type: ignore[union-attr]
rstd = _unsqueeze_to_dim(rstd, input_cast.dim()) # type: ignore[union-attr]
assert input_cast is not None
x_hat = (input_cast - mean) * rstd
if weight_cast is not None:
grad_x_hat = grad_out_cast * weight_cast
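
The added assert exists only to narrow input_cast from Optional[Tensor] to Tensor for the type checker before it is used in arithmetic; on this path the value is already guaranteed to be non-None at runtime. A minimal sketch of the same narrowing pattern, using a hypothetical _scale helper:

from typing import Optional

import torch
from torch import Tensor

def _scale(x: Optional[Tensor], rstd: Tensor) -> Tensor:
    # Without this assert, mypy reports an error for "Optional[Tensor] * Tensor".
    assert x is not None  # narrows x to Tensor for the checker
    return x * rstd

print(_scale(torch.ones(2), torch.full((2,), 0.5)))  # tensor([0.5000, 0.5000])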

View File

@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
from typing import Optional
from typing import cast, Optional
import torch
import torch.utils._pytree as pytree
@ -69,12 +69,10 @@ def philox_rand_offset(
curand4_engine_calls = 4
device_property = torch.cuda.get_device_properties(torch.cuda.current_device())
blocks_per_sm = device_property.max_threads_per_multi_processor // block_size
grid_size = (numel + block_size - 1) // block_size
num = cast(int, numel)
grid_size = (num + block_size - 1) // block_size
grid_size = min(grid_size, device_property.multi_processor_count * blocks_per_sm)
offset = (
(numel - 1) // (block_size * grid_size * unroll) + 1
) * curand4_engine_calls
return offset
return ((num - 1) // (block_size * grid_size * unroll) + 1) * curand4_engine_calls
def register_philox_rand():
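
In this decomposition numel is not statically typed as a plain int, so the integer arithmetic above trips the checker; typing.cast changes only the declared type and performs no runtime conversion or check. A short sketch of the pattern, with a hypothetical grid_for helper:

from typing import cast

def grid_for(numel: object, block_size: int = 256) -> int:
    # cast() only tells the type checker to treat `numel` as an int from here on;
    # the value itself is passed through untouched.
    num = cast(int, numel)
    return (num + block_size - 1) // block_size  # ceiling division into blocks

print(grid_for(1000))  # -> 4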

View File

@ -6,7 +6,8 @@ import warnings
from collections import OrderedDict
from copy import deepcopy
from numbers import Number
from typing import Any, Callable, cast, Optional, Union
from typing import Any, Callable, cast, Optional, TypeVar, Union
from typing_extensions import Concatenate, ParamSpec
import torch
import torch._C as _C
@ -27,16 +28,21 @@ from torch.overrides import (
)
def _handle_torch_function_and_wrap_type_error_to_not_implemented(f):
assigned = functools.WRAPPER_ASSIGNMENTS
_P = ParamSpec("_P")
_TensorLike = TypeVar("_TensorLike", bound=_C.TensorBase)
@functools.wraps(f, assigned=assigned)
def wrapped(*args, **kwargs):
def _handle_torch_function_and_wrap_type_error_to_not_implemented(
f: Callable[Concatenate[_TensorLike, _P], "Tensor"],
) -> Callable[Concatenate[_TensorLike, _P], "Tensor"]:
@functools.wraps(f)
def wrapped(self: _TensorLike, *args: _P.args, **kwargs: _P.kwargs) -> "Tensor":
try:
# See https://github.com/pytorch/pytorch/issues/75462
if has_torch_function(args):
return handle_torch_function(wrapped, args, *args, **kwargs)
return f(*args, **kwargs)
sargs = self, *args
if has_torch_function(sargs):
return handle_torch_function(wrapped, sargs, *sargs, **kwargs)
return f(self, *args, **kwargs)
except TypeError:
return NotImplemented
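
The rewritten decorator uses ParamSpec plus Concatenate so the wrapper advertises the same signature as the wrapped method instead of erasing it to (*args, **kwargs). A reduced sketch of the same typing pattern, with the torch-function dispatch omitted and a toy Acc class standing in for Tensor:

import functools
from typing import Callable, TypeVar
from typing_extensions import Concatenate, ParamSpec

_P = ParamSpec("_P")
_T = TypeVar("_T")

def _wrap_type_error_to_not_implemented(
    f: Callable[Concatenate[_T, _P], float],
) -> Callable[Concatenate[_T, _P], float]:
    # Concatenate[_T, _P] keeps the explicit first parameter (self) and the
    # remaining positional/keyword parameters of `f` visible to the checker.
    @functools.wraps(f)
    def wrapped(self: _T, *args: _P.args, **kwargs: _P.kwargs) -> float:
        try:
            return f(self, *args, **kwargs)
        except TypeError:
            return NotImplemented  # type: ignore[return-value]

    return wrapped

class Acc:
    def __init__(self, v: float) -> None:
        self.v = v

    @_wrap_type_error_to_not_implemented
    def add(self, other: float) -> float:
        return self.v + other

print(Acc(1.0).add(2.0))  # -> 3.0, and mypy still sees add(self, other: float)
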
@ -1093,11 +1099,11 @@ class Tensor(torch._C.TensorBase):
)
@_handle_torch_function_and_wrap_type_error_to_not_implemented
def __rsub__(self, other):
def __rsub__(self, other: Union["Tensor", int, float, bool, complex]) -> "Tensor":
return _C._VariableFunctions.rsub(self, other)
@_handle_torch_function_and_wrap_type_error_to_not_implemented
def __rdiv__(self, other):
def __rdiv__(self, other: Union["Tensor", int, float, bool, complex]) -> "Tensor":
return self.reciprocal() * other
__rtruediv__ = __rdiv__
@ -1112,12 +1118,13 @@ class Tensor(torch._C.TensorBase):
_C.TensorBase.pow
),
)
__ipow__ = _handle_torch_function_and_wrap_type_error_to_not_implemented(
_C.TensorBase.pow_
)
@_handle_torch_function_and_wrap_type_error_to_not_implemented
def __rmod__(self, other):
def __rmod__(self, other: Union["Tensor", int, float, bool, complex]) -> "Tensor":
return torch.remainder(other, self)
def __format__(self, format_spec):
@ -1130,27 +1137,33 @@ class Tensor(torch._C.TensorBase):
return object.__format__(self, format_spec)
@_handle_torch_function_and_wrap_type_error_to_not_implemented
def __rpow__(self, other):
def __rpow__(self, other: Union["Tensor", int, float, bool, complex]) -> "Tensor":
return torch.pow(other, self)
@_handle_torch_function_and_wrap_type_error_to_not_implemented
def __floordiv__(self, other):
def __floordiv__(self, other: Union["Tensor", int, float, bool]) -> "Tensor": # type: ignore[override]
# TODO(rec): the superclass says it accepts complex here,
# but torch.floor_divide says it doesn't.
return torch.floor_divide(self, other)
@_handle_torch_function_and_wrap_type_error_to_not_implemented
def __rfloordiv__(self, other):
def __rfloordiv__(self, other: Union["Tensor", int, float, bool]) -> "Tensor": # type: ignore[override]
return torch.floor_divide(other, self)
@_handle_torch_function_and_wrap_type_error_to_not_implemented
def __rlshift__(self, other):
def __rlshift__(
self, other: Union["Tensor", int, float, bool, complex]
) -> "Tensor":
return torch.bitwise_left_shift(other, self)
@_handle_torch_function_and_wrap_type_error_to_not_implemented
def __rrshift__(self, other):
def __rrshift__(
self, other: Union["Tensor", int, float, bool, complex]
) -> "Tensor":
return torch.bitwise_right_shift(other, self)
@_handle_torch_function_and_wrap_type_error_to_not_implemented
def __rmatmul__(self, other):
def __rmatmul__(self, other: "Tensor") -> "Tensor":
return torch.matmul(other, self)
__pos__ = _C.TensorBase.positive

View File

@ -631,6 +631,7 @@ def powerSGD_hook(
if state.use_error_feedback:
# Memorize the local errors.
assert input_tensor_cp is not None
state.error_dict[bucket_index] = input_tensor_cp - input_tensor
if not state.warm_start:
state.p_memory_dict.clear()
@ -843,6 +844,7 @@ def batched_powerSGD_hook(
if state.use_error_feedback:
# Memorize the local errors.
assert input_tensor_cp is not None
state.error_dict[bucket_index] = input_tensor_cp - input_tensor
# Removing this seemingly unnecessary sync somehow may cause failures.
# See: https://github.com/pytorch/pytorch/pull/54838

View File

@ -1,4 +1,6 @@
# mypy: allow-untyped-defs
from typing import Union
import torch
from torch import Tensor
from torch.distributions.distribution import Distribution
@ -55,7 +57,7 @@ class ExponentialFamily(Distribution):
"""
Method to compute the entropy using Bregman divergence of the log normalizer.
"""
result = -self._mean_carrier_measure
result: Union[Tensor, float] = -self._mean_carrier_measure
nparams = [p.detach().requires_grad_() for p in self._natural_params]
lg_normal = self._log_normalizer(*nparams)
gradients = torch.autograd.grad(lg_normal.sum(), nparams, create_graph=True)

View File

@ -170,7 +170,7 @@ class TransformedDistribution(Distribution):
if self._validate_args:
self._validate_sample(value)
event_dim = len(self.event_shape)
log_prob = 0.0
log_prob: Union[Tensor, float] = 0.0
y = value
for transform in reversed(self.transforms):
x = transform.inv(y)
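
In both distribution files the accumulator starts as a Python float and is later reassigned Tensor results, so without the explicit Union annotation mypy pins it to float at the first assignment. A minimal sketch of the pattern, with a hypothetical total helper:

from typing import Union

import torch
from torch import Tensor

def total(parts: list[Tensor]) -> Union[Tensor, float]:
    # Annotate the accumulator up front; mypy would otherwise infer float from
    # the 0.0 initializer and reject the later Tensor assignment.
    acc: Union[Tensor, float] = 0.0
    for p in parts:
        acc = acc + p.sum()
    return acc

print(total([torch.ones(3), torch.ones(2)]))  # tensor(5.)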

View File

@ -459,9 +459,11 @@ def _single_tensor_adam(
# expavg.lerp(grad^2, 1-beta2)
exp_avg_sq.lerp_(torch.square(grad), weight=1 - beta2)
else:
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
exp_avg_sq.mul_(beta2).addcmul_(
grad, grad, value=cast(float, 1 - beta2)
)
else:
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) # type: ignore[arg-type]
if capturable or differentiable:
step = step_t
@ -532,7 +534,7 @@ def _single_tensor_adam(
else:
denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps)
param.addcdiv_(exp_avg, denom, value=-step_size)
param.addcdiv_(exp_avg, denom, value=-step_size) # type: ignore[arg-type]
# Lastly, switch back to complex view
if amsgrad and torch.is_complex(params[i]):
@ -686,7 +688,9 @@ def _multi_tensor_adam(
# Decay the first and second moment running average coefficient
# Use device beta1 if beta1 is a tensor to ensure all
# tensors are on the same device
torch._foreach_lerp_(device_exp_avgs, device_grads, 1 - device_beta1)
torch._foreach_lerp_(
device_exp_avgs, device_grads, cast(float, 1 - device_beta1)
)
torch._foreach_mul_(device_exp_avg_sqs, beta2)
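
Here beta2 (and device_beta1) may be either a plain float or a 0-dim Tensor on the capturable path, so 1 - beta2 is inferred as Tensor-or-float, while the stubs for addcmul_ and _foreach_lerp_ appear to expect a plain number for value/weight; cast(float, ...) narrows only the static type and passes the runtime value through unchanged (the other branch uses a targeted type: ignore instead). A minimal sketch under that assumption, with a hypothetical update_sq_avg helper and a float beta2 at runtime:

from typing import Union, cast

import torch
from torch import Tensor

def update_sq_avg(
    exp_avg_sq: Tensor, grad: Tensor, beta2: Union[float, Tensor]
) -> None:
    # 1 - beta2 is a float or a 0-dim Tensor depending on beta2; cast() makes
    # the checker accept it as the numeric `value` argument without changing
    # what is actually passed at runtime.
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=cast(float, 1 - beta2))

sq = torch.zeros(3)
update_sq_avg(sq, torch.ones(3), 0.999)
print(sq)  # tensor([0.0010, 0.0010, 0.0010])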

View File

@ -371,7 +371,9 @@ def _single_tensor_nadam(
grad, denom, value=(-lr * (1.0 - mu) / (1.0 - _get_value(mu_product)))
)
param.addcdiv_(
exp_avg, denom, value=(-lr * mu_next) / (1.0 - mu_product_next)
exp_avg,
denom,
value=cast(float, (-lr * mu_next) / (1.0 - mu_product_next)),
)

View File

@ -6,7 +6,7 @@ import math
import warnings
from collections.abc import Iterable
from copy import deepcopy
from typing import Any, Callable, Literal, Optional, Union
from typing import Any, Callable, cast, Literal, Optional, Union
import torch
from torch import Tensor
@ -69,7 +69,9 @@ def get_swa_multi_avg_fn():
averaged_param_list[0]
):
torch._foreach_lerp_(
averaged_param_list, current_param_list, 1 / (num_averaged + 1)
averaged_param_list,
current_param_list,
cast(float, 1 / (num_averaged + 1)),
)
else:
diffs = torch._foreach_sub(current_param_list, averaged_param_list)