Fix https://github.com/pytorch/pytorch/issues/145838

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146845
Approved by: https://github.com/Skylion007
Committed by: PyTorch MergeBot
Parent: cb853945a7
Commit: e2c9d8d641
@@ -6,15 +6,15 @@ add_loop_eager_dynamic,compile_time_instruction_count,4300194436,0.025
-add_loop_inductor,compile_time_instruction_count,29630000000,0.015
+add_loop_inductor,compile_time_instruction_count,29200000000,0.015
-add_loop_inductor_dynamic_gpu,compile_time_instruction_count,39110000000,0.025
+add_loop_inductor_dynamic_gpu,compile_time_instruction_count,38440000000,0.025
-add_loop_inductor_gpu,compile_time_instruction_count,26180000000,0.015
+add_loop_inductor_gpu,compile_time_instruction_count,25750000000,0.015
@@ -22,11 +22,11 @@ basic_modules_ListOfLinears_eager,compile_time_instruction_count,942514329,0.015
-basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18660000000,0.015
+basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18430000000,0.015
-basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,16750000000,0.015
+basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,16550000000,0.015
@@ -34,7 +34,7 @@ basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,10370000
-update_hint_regression,compile_time_instruction_count,1677000000,0.02
+update_hint_regression,compile_time_instruction_count,1661000000,0.02
@@ -50,7 +50,7 @@ symint_sum_loop,compile_time_instruction_count,4216000000,0.015
-aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2113000000,0.015
+aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2101000000,0.015
@@ -58,11 +58,11 @@ aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,6022000000,0
-aotdispatcher_partitioner_cpu,compile_time_instruction_count,8844000000,0.015
+aotdispatcher_partitioner_cpu,compile_time_instruction_count,8775000000,0.015
-aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1963000000,0.015
+aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1925000000,0.015
@@ -1,4 +1,4 @@
-from typing import Any, Union
+from typing import Union
 from typing_extensions import assert_type, TypeAlias

 from torch import randn, Tensor
@@ -27,7 +27,7 @@ assert_type(TENSOR >= TENSOR, Tensor)
 assert_type(TENSOR + TENSOR, Tensor)
 assert_type(TENSOR - TENSOR, Tensor)
 assert_type(TENSOR * TENSOR, Tensor)
-assert_type(TENSOR // TENSOR, Any)
+assert_type(TENSOR // TENSOR, Tensor)
 assert_type(TENSOR / TENSOR, Tensor)
 assert_type(TENSOR % TENSOR, Tensor)
 assert_type(TENSOR**TENSOR, Tensor)
@@ -46,7 +46,7 @@ assert_type(TENSOR >= BOOL, Tensor)
 assert_type(TENSOR + BOOL, Tensor)
 assert_type(TENSOR - BOOL, Tensor)
 assert_type(TENSOR * BOOL, Tensor)
-assert_type(TENSOR // BOOL, Any)
+assert_type(TENSOR // BOOL, Tensor)
 assert_type(TENSOR / BOOL, Tensor)
 assert_type(TENSOR % BOOL, Tensor)
 assert_type(TENSOR**BOOL, Tensor)
@@ -63,14 +63,14 @@ assert_type(BOOL > TENSOR, Tensor)
 assert_type(BOOL <= TENSOR, Tensor)
 assert_type(BOOL >= TENSOR, Tensor)
 assert_type(BOOL + TENSOR, Tensor)
-assert_type(BOOL - TENSOR, Any)
+assert_type(BOOL - TENSOR, Tensor)
 assert_type(BOOL * TENSOR, Tensor)
-assert_type(BOOL // TENSOR, Any)
-assert_type(BOOL / TENSOR, Any)
-assert_type(BOOL % TENSOR, Any)
-assert_type(BOOL**TENSOR, Any)
-assert_type(BOOL << TENSOR, Any)
-assert_type(BOOL >> TENSOR, Any)
+assert_type(BOOL // TENSOR, Tensor)
+assert_type(BOOL / TENSOR, Tensor)
+assert_type(BOOL % TENSOR, Tensor)
+assert_type(BOOL**TENSOR, Tensor)
+assert_type(BOOL << TENSOR, Tensor)
+assert_type(BOOL >> TENSOR, Tensor)
 assert_type(BOOL & TENSOR, Tensor)
 assert_type(BOOL | TENSOR, Tensor)
 assert_type(BOOL ^ TENSOR, Tensor)
@@ -84,7 +84,7 @@ assert_type(TENSOR >= INT, Tensor)
 assert_type(TENSOR + INT, Tensor)
 assert_type(TENSOR - INT, Tensor)
 assert_type(TENSOR * INT, Tensor)
-assert_type(TENSOR // INT, Any)
+assert_type(TENSOR // INT, Tensor)
 assert_type(TENSOR / INT, Tensor)
 assert_type(TENSOR % INT, Tensor)
 assert_type(TENSOR**INT, Tensor)
@@ -101,14 +101,14 @@ assert_type(INT > TENSOR, Tensor)
 assert_type(INT <= TENSOR, Tensor)
 assert_type(INT >= TENSOR, Tensor)
 assert_type(INT + TENSOR, Tensor)
-assert_type(INT - TENSOR, Any)
+assert_type(INT - TENSOR, Tensor)
 assert_type(INT * TENSOR, Tensor)
-assert_type(INT // TENSOR, Any)
-assert_type(INT / TENSOR, Any)
-assert_type(INT % TENSOR, Any)
-assert_type(INT**TENSOR, Any)
-assert_type(INT << TENSOR, Any)
-assert_type(INT >> TENSOR, Any)
+assert_type(INT // TENSOR, Tensor)
+assert_type(INT / TENSOR, Tensor)
+assert_type(INT % TENSOR, Tensor)
+assert_type(INT**TENSOR, Tensor)
+assert_type(INT << TENSOR, Tensor)
+assert_type(INT >> TENSOR, Tensor)
 assert_type(INT & TENSOR, Tensor)
 assert_type(INT | TENSOR, Tensor)
 assert_type(INT ^ TENSOR, Tensor)
@@ -122,7 +122,7 @@ assert_type(TENSOR >= FLOAT, Tensor)
 assert_type(TENSOR + FLOAT, Tensor)
 assert_type(TENSOR - FLOAT, Tensor)
 assert_type(TENSOR * FLOAT, Tensor)
-assert_type(TENSOR // FLOAT, Any)
+assert_type(TENSOR // FLOAT, Tensor)
 assert_type(TENSOR / FLOAT, Tensor)
 assert_type(TENSOR % FLOAT, Tensor)
 assert_type(TENSOR**FLOAT, Tensor)
@@ -139,14 +139,17 @@ assert_type(FLOAT > TENSOR, Tensor)
 assert_type(FLOAT <= TENSOR, Tensor)
 assert_type(FLOAT >= TENSOR, Tensor)
 assert_type(FLOAT + TENSOR, Tensor)
-assert_type(FLOAT - TENSOR, Any)
+assert_type(FLOAT - TENSOR, Tensor)
 assert_type(FLOAT * TENSOR, Tensor)
-assert_type(FLOAT // TENSOR, Any)
-assert_type(FLOAT / TENSOR, Any)
-assert_type(FLOAT % TENSOR, Any)
-assert_type(FLOAT**TENSOR, Any)
-assert_type(FLOAT << TENSOR, Any)
-assert_type(FLOAT >> TENSOR, Any)
+assert_type(FLOAT // TENSOR, Tensor)
+assert_type(FLOAT / TENSOR, Tensor)
+assert_type(FLOAT % TENSOR, Tensor)
+assert_type(FLOAT**TENSOR, Tensor)
+assert_type(FLOAT >> TENSOR, Tensor)
+assert_type(FLOAT << TENSOR, Tensor)
+assert_type(FLOAT & TENSOR, Tensor)  # type: ignore[operator]
+assert_type(FLOAT | TENSOR, Tensor)  # type: ignore[operator]
+assert_type(FLOAT ^ TENSOR, Tensor)  # type: ignore[operator]

 NUMBER: TypeAlias = Union[int, float, bool]
@@ -370,38 +373,3 @@ assert_type(BOOL**BINARY, Binary)
 assert_type(BOOL >> BINARY, Binary)
 assert_type(BOOL - BINARY, Binary)
 assert_type(BOOL ^ BINARY, Binary)
-
-# Tensor operators whose types could be improved
-# This is the "diff" of the first and second sections.
-
-assert_type(BOOL // TENSOR, Any)
-assert_type(FLOAT // TENSOR, Any)
-assert_type(INT // TENSOR, Any)
-assert_type(TENSOR // BOOL, Any)
-assert_type(TENSOR // FLOAT, Any)
-assert_type(TENSOR // INT, Any)
-assert_type(TENSOR // TENSOR, Any)
-
-assert_type(BOOL**TENSOR, Any)
-assert_type(FLOAT**TENSOR, Any)
-assert_type(INT**TENSOR, Any)
-
-assert_type(BOOL - TENSOR, Any)
-assert_type(FLOAT - TENSOR, Any)
-assert_type(INT - TENSOR, Any)
-
-assert_type(BOOL / TENSOR, Any)
-assert_type(FLOAT / TENSOR, Any)
-assert_type(INT / TENSOR, Any)
-
-assert_type(BOOL % TENSOR, Any)
-assert_type(FLOAT % TENSOR, Any)
-assert_type(INT % TENSOR, Any)
-
-assert_type(BOOL << TENSOR, Any)
-assert_type(FLOAT << TENSOR, Any)
-assert_type(INT << TENSOR, Any)
-
-assert_type(BOOL >> TENSOR, Any)
-assert_type(FLOAT >> TENSOR, Any)
-assert_type(INT >> TENSOR, Any)
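Note: the hunks above tighten the expected operator return types from Any to Tensor and drop the trailing "operators whose types could be improved" section. As a rough illustration of what the tightened stubs mean for a type checker (a sketch, not part of the test file; the tensor below is made up), assert_type from typing_extensions passes only when the checker infers exactly the stated type and does nothing at runtime:

from typing_extensions import assert_type

import torch
from torch import Tensor

t = torch.randn(3)

# Run a static checker (e.g. mypy) over this snippet; nothing is verified at runtime.
assert_type(t // 2, Tensor)  # floor division no longer degrades to Any
assert_type(2 - t, Tensor)   # reverse operators now advertise Tensor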
@@ -814,7 +814,7 @@ def slice_scatter(
     if start == 0 and end == dim_size and step == 1:
         return src.clone()

-    indices = [None] * input.dim()
+    indices: list[Optional[Tensor]] = [None] * input.dim()
     idx = torch.arange(dim_size, device=input.device)
     indices[dim] = (idx - start) // step
@@ -1677,6 +1677,7 @@ def native_layer_norm_backward(
     )
     mean = _unsqueeze_to_dim(mean, input_cast.dim())  # type: ignore[union-attr]
     rstd = _unsqueeze_to_dim(rstd, input_cast.dim())  # type: ignore[union-attr]
+    assert input_cast is not None
     x_hat = (input_cast - mean) * rstd
     if weight_cast is not None:
         grad_x_hat = grad_out_cast * weight_cast
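Note: the slice_scatter change annotates the scratch list up front so that a Tensor can later be stored in a slot of a list built from [None] * n; without the annotation mypy infers list[None]. A small sketch with made-up names:

from typing import Optional

from torch import Tensor


def placeholder_indices(n: int, dim: int, idx: Tensor) -> list[Optional[Tensor]]:
    # Without the annotation, [None] * n is inferred as list[None], and
    # assigning a Tensor into one slot would be a type error.
    indices: list[Optional[Tensor]] = [None] * n
    indices[dim] = idx
    return indices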
@@ -1,5 +1,5 @@
 # mypy: allow-untyped-defs
-from typing import Optional
+from typing import cast, Optional

 import torch
 import torch.utils._pytree as pytree
@@ -69,12 +69,10 @@ def philox_rand_offset(
     curand4_engine_calls = 4
     device_property = torch.cuda.get_device_properties(torch.cuda.current_device())
     blocks_per_sm = device_property.max_threads_per_multi_processor // block_size
-    grid_size = (numel + block_size - 1) // block_size
+    num = cast(int, numel)
+    grid_size = (num + block_size - 1) // block_size
     grid_size = min(grid_size, device_property.multi_processor_count * blocks_per_sm)
-    offset = (
-        (numel - 1) // (block_size * grid_size * unroll) + 1
-    ) * curand4_engine_calls
-    return offset
+    return ((num - 1) // (block_size * grid_size * unroll) + 1) * curand4_engine_calls


 def register_philox_rand():
|
@ -6,7 +6,8 @@ import warnings
|
||||
from collections import OrderedDict
|
||||
from copy import deepcopy
|
||||
from numbers import Number
|
||||
from typing import Any, Callable, cast, Optional, Union
|
||||
from typing import Any, Callable, cast, Optional, TypeVar, Union
|
||||
from typing_extensions import Concatenate, ParamSpec
|
||||
|
||||
import torch
|
||||
import torch._C as _C
|
||||
@@ -27,16 +28,21 @@ from torch.overrides import (
 )


-def _handle_torch_function_and_wrap_type_error_to_not_implemented(f):
-    assigned = functools.WRAPPER_ASSIGNMENTS
+_P = ParamSpec("_P")
+_TensorLike = TypeVar("_TensorLike", bound=_C.TensorBase)

-    @functools.wraps(f, assigned=assigned)
-    def wrapped(*args, **kwargs):
+
+def _handle_torch_function_and_wrap_type_error_to_not_implemented(
+    f: Callable[Concatenate[_TensorLike, _P], "Tensor"],
+) -> Callable[Concatenate[_TensorLike, _P], "Tensor"]:
+    @functools.wraps(f)
+    def wrapped(self: _TensorLike, *args: _P.args, **kwargs: _P.kwargs) -> "Tensor":
         try:
             # See https://github.com/pytorch/pytorch/issues/75462
-            if has_torch_function(args):
-                return handle_torch_function(wrapped, args, *args, **kwargs)
-            return f(*args, **kwargs)
+            sargs = self, *args
+            if has_torch_function(sargs):
+                return handle_torch_function(wrapped, sargs, *sargs, **kwargs)
+            return f(self, *args, **kwargs)
         except TypeError:
             return NotImplemented
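Note: the decorator above now uses ParamSpec and Concatenate so the wrapped dunder keeps its parameter and return types instead of degrading to untyped *args/**kwargs. A minimal standalone sketch of the same typing pattern, with hypothetical names and a generic return type (not PyTorch's actual implementation):

import functools
from typing import Callable, TypeVar

from typing_extensions import Concatenate, ParamSpec

_P = ParamSpec("_P")
_T = TypeVar("_T")
_R = TypeVar("_R")


def wrap_type_error_to_not_implemented(
    f: Callable[Concatenate[_T, _P], _R],
) -> Callable[Concatenate[_T, _P], _R]:
    # functools.wraps preserves the name and docstring; Concatenate/ParamSpec
    # preserve the (self, *args, **kwargs) signature for the type checker.
    @functools.wraps(f)
    def wrapped(self: _T, *args: _P.args, **kwargs: _P.kwargs) -> _R:
        try:
            return f(self, *args, **kwargs)
        except TypeError:
            # NotImplemented makes Python try the reflected operator; a checker
            # may want an ignore on this line since it is not of type _R.
            return NotImplemented  # type: ignore[return-value]

    return wrapped

Applied to a method such as __rsub__, the decorated attribute keeps the declared (self, other) parameter types, so expressions like 2 - tensor check against the annotated overloads.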
@@ -1093,11 +1099,11 @@ class Tensor(torch._C.TensorBase):
         )

     @_handle_torch_function_and_wrap_type_error_to_not_implemented
-    def __rsub__(self, other):
+    def __rsub__(self, other: Union["Tensor", int, float, bool, complex]) -> "Tensor":
         return _C._VariableFunctions.rsub(self, other)

     @_handle_torch_function_and_wrap_type_error_to_not_implemented
-    def __rdiv__(self, other):
+    def __rdiv__(self, other: Union["Tensor", int, float, bool, complex]) -> "Tensor":
         return self.reciprocal() * other

     __rtruediv__ = __rdiv__
@@ -1112,12 +1118,13 @@ class Tensor(torch._C.TensorBase):
             _C.TensorBase.pow
         ),
     )

     __ipow__ = _handle_torch_function_and_wrap_type_error_to_not_implemented(
         _C.TensorBase.pow_
     )

     @_handle_torch_function_and_wrap_type_error_to_not_implemented
-    def __rmod__(self, other):
+    def __rmod__(self, other: Union["Tensor", int, float, bool, complex]) -> "Tensor":
         return torch.remainder(other, self)

     def __format__(self, format_spec):
@@ -1130,27 +1137,33 @@ class Tensor(torch._C.TensorBase):
             return object.__format__(self, format_spec)

     @_handle_torch_function_and_wrap_type_error_to_not_implemented
-    def __rpow__(self, other):
+    def __rpow__(self, other: Union["Tensor", int, float, bool, complex]) -> "Tensor":
         return torch.pow(other, self)

     @_handle_torch_function_and_wrap_type_error_to_not_implemented
-    def __floordiv__(self, other):
+    def __floordiv__(self, other: Union["Tensor", int, float, bool]) -> "Tensor":  # type: ignore[override]
+        # TODO(rec): the superclass says it accepts complex here,
+        # but torch.floor_divide says it doesn't.
         return torch.floor_divide(self, other)

     @_handle_torch_function_and_wrap_type_error_to_not_implemented
-    def __rfloordiv__(self, other):
+    def __rfloordiv__(self, other: Union["Tensor", int, float, bool]) -> "Tensor":  # type: ignore[override]
         return torch.floor_divide(other, self)

     @_handle_torch_function_and_wrap_type_error_to_not_implemented
-    def __rlshift__(self, other):
+    def __rlshift__(
+        self, other: Union["Tensor", int, float, bool, complex]
+    ) -> "Tensor":
         return torch.bitwise_left_shift(other, self)

     @_handle_torch_function_and_wrap_type_error_to_not_implemented
-    def __rrshift__(self, other):
+    def __rrshift__(
+        self, other: Union["Tensor", int, float, bool, complex]
+    ) -> "Tensor":
         return torch.bitwise_right_shift(other, self)

     @_handle_torch_function_and_wrap_type_error_to_not_implemented
-    def __rmatmul__(self, other):
+    def __rmatmul__(self, other: "Tensor") -> "Tensor":
         return torch.matmul(other, self)

     __pos__ = _C.TensorBase.positive
@@ -631,6 +631,7 @@ def powerSGD_hook(

     if state.use_error_feedback:
         # Memorize the local errors.
+        assert input_tensor_cp is not None
         state.error_dict[bucket_index] = input_tensor_cp - input_tensor
     if not state.warm_start:
         state.p_memory_dict.clear()
@@ -843,6 +844,7 @@ def batched_powerSGD_hook(

     if state.use_error_feedback:
         # Memorize the local errors.
+        assert input_tensor_cp is not None
         state.error_dict[bucket_index] = input_tensor_cp - input_tensor
     # Removing this seemingly unnecessary sync somehow may cause failures.
     # See: https://github.com/pytorch/pytorch/pull/54838
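Note: the added assert lines are type-narrowing guards. In the error-feedback branch the Optional tensor is known to have been populated earlier, and the assertion lets the checker treat it as a plain Tensor in the subtraction. A minimal sketch of the pattern, with made-up names:

from typing import Optional

from torch import Tensor


def error_feedback_delta(buffer: Optional[Tensor], update: Tensor) -> Tensor:
    # buffer is Optional[Tensor] here; without the assert, "buffer - update"
    # is rejected by the type checker. The assert narrows it to Tensor and
    # doubles as a cheap runtime sanity check.
    assert buffer is not None
    return buffer - update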
@@ -1,4 +1,6 @@
 # mypy: allow-untyped-defs
+from typing import Union
+
 import torch
 from torch import Tensor
 from torch.distributions.distribution import Distribution
@@ -55,7 +57,7 @@ class ExponentialFamily(Distribution):
         """
        Method to compute the entropy using Bregman divergence of the log normalizer.
         """
-        result = -self._mean_carrier_measure
+        result: Union[Tensor, float] = -self._mean_carrier_measure
         nparams = [p.detach().requires_grad_() for p in self._natural_params]
         lg_normal = self._log_normalizer(*nparams)
         gradients = torch.autograd.grad(lg_normal.sum(), nparams, create_graph=True)
@@ -170,7 +170,7 @@ class TransformedDistribution(Distribution):
         if self._validate_args:
             self._validate_sample(value)
         event_dim = len(self.event_shape)
-        log_prob = 0.0
+        log_prob: Union[Tensor, float] = 0.0
         y = value
         for transform in reversed(self.transforms):
             x = transform.inv(y)
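Note: both distribution changes annotate an accumulator that starts as a Python float and may later be replaced by a Tensor; without the explicit Union annotation, mypy infers plain float from the initializer and rejects the reassignment. A small sketch of the idiom with illustrative names:

from typing import Union

from torch import Tensor


def accumulated_log_prob(terms: list[Tensor]) -> Union[Tensor, float]:
    # Annotating the accumulator up front lets it start as 0.0 and later hold
    # a Tensor once the first term is added.
    log_prob: Union[Tensor, float] = 0.0
    for term in terms:
        log_prob = log_prob + term
    return log_prob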
@@ -459,9 +459,11 @@ def _single_tensor_adam(
                 # expavg.lerp(grad^2, 1-beta2)
                 exp_avg_sq.lerp_(torch.square(grad), weight=1 - beta2)
             else:
-                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+                exp_avg_sq.mul_(beta2).addcmul_(
+                    grad, grad, value=cast(float, 1 - beta2)
+                )
         else:
-            exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+            exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)  # type: ignore[arg-type]

         if capturable or differentiable:
             step = step_t
@@ -532,7 +534,7 @@ def _single_tensor_adam(
             else:
                 denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps)

-            param.addcdiv_(exp_avg, denom, value=-step_size)
+            param.addcdiv_(exp_avg, denom, value=-step_size)  # type: ignore[arg-type]

         # Lastly, switch back to complex view
         if amsgrad and torch.is_complex(params[i]):
@@ -686,7 +688,9 @@ def _multi_tensor_adam(
         # Decay the first and second moment running average coefficient
         # Use device beta1 if beta1 is a tensor to ensure all
         # tensors are on the same device
-        torch._foreach_lerp_(device_exp_avgs, device_grads, 1 - device_beta1)
+        torch._foreach_lerp_(
+            device_exp_avgs, device_grads, cast(float, 1 - device_beta1)
+        )

         torch._foreach_mul_(device_exp_avg_sqs, beta2)
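Note: the cast(float, ...) wrappers added throughout the optimizer hunks are purely static. typing.cast performs no conversion at runtime; it only tells the checker to treat the expression as the float the in-place ops are annotated to take, even though the value can be a Tensor when a hyperparameter is captured as one. A minimal sketch under that assumption, with a hypothetical helper name:

from typing import Union, cast

from torch import Tensor


def update_second_moment(
    exp_avg_sq: Tensor, grad: Tensor, beta2: Union[float, Tensor]
) -> None:
    # cast() returns its argument unchanged at runtime; it only changes what
    # the type checker believes, here treating 1 - beta2 as a float so the
    # call matches the declared signature of addcmul_.
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=cast(float, 1 - beta2))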
@@ -371,7 +371,9 @@ def _single_tensor_nadam(
             grad, denom, value=(-lr * (1.0 - mu) / (1.0 - _get_value(mu_product)))
         )
         param.addcdiv_(
-            exp_avg, denom, value=(-lr * mu_next) / (1.0 - mu_product_next)
+            exp_avg,
+            denom,
+            value=cast(float, (-lr * mu_next) / (1.0 - mu_product_next)),
         )
@@ -6,7 +6,7 @@ import math
 import warnings
 from collections.abc import Iterable
 from copy import deepcopy
-from typing import Any, Callable, Literal, Optional, Union
+from typing import Any, Callable, cast, Literal, Optional, Union

 import torch
 from torch import Tensor
@@ -69,7 +69,9 @@ def get_swa_multi_avg_fn():
             averaged_param_list[0]
         ):
             torch._foreach_lerp_(
-                averaged_param_list, current_param_list, 1 / (num_averaged + 1)
+                averaged_param_list,
+                current_param_list,
+                cast(float, 1 / (num_averaged + 1)),
             )
         else:
             diffs = torch._foreach_sub(current_param_list, averaged_param_list)