Add ScalarList overload to _foreach_lerp (#134482)

Related:
- https://github.com/pytorch/pytorch/issues/133367

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134482
Approved by: https://github.com/janeyx99
This commit is contained in:
Masaki Kozuki
2024-11-12 19:03:38 +00:00
committed by PyTorch MergeBot
parent 7624d625c0
commit 6a368b3fc5
10 changed files with 245 additions and 30 deletions

View File

@ -1534,6 +1534,8 @@ def check_autodiff_sample(op, sample, dtype, is_inplace):
or (isinstance(sample.args[-1], complex))
)
if rhs_arg_has_complex_number and dtype == torch.float64:
if op.name == "_foreach_lerp":
return False, "value cannot be converted to type double without overflow"
if op.name in (
"_foreach_clamp_max",
"_foreach_clamp_min",