Fix aten.logspace decomposition (#105201)

Fixes #104118

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105201
Approved by: https://github.com/ezyang
This commit is contained in:
Yanbo Liang
2023-07-22 04:10:20 +00:00
committed by PyTorch MergeBot
parent 5afc2f5069
commit 0ad93a3d56
3 changed files with 24 additions and 5 deletions

View File

@ -1212,6 +1212,16 @@ class TestRefs(TestCase):
expect = torch.unbind(a, 1)
self.assertEqual(actual, expect)
def test_logspace_with_complex_input(self):
    # A complex `end` must be accepted by the reference implementation
    # and agree with eager-mode torch.logspace.
    result = refs.logspace(2, 10 + 5j, steps=5)
    reference = torch.logspace(2, 10 + 5j, steps=5)
    self.assertEqual(result, reference)
def test_linspace_with_complex_input(self):
    # A complex `end` must be accepted by the reference implementation
    # and agree with eager-mode torch.linspace.
    result = refs.linspace(2, 10 + 5j, steps=5)
    reference = torch.linspace(2, 10 + 5j, steps=5)
    self.assertEqual(result, reference)
instantiate_device_type_tests(TestRefs, globals())

View File

@ -4619,20 +4619,29 @@ def logspace(
if isinstance(end, FloatLike):
end = sym_int(end)
if py_any(isinstance(arg, complex) for arg in (start, end, steps)):
default_complex_dtype = utils.corresponding_complex_dtype(
torch.get_default_dtype()
)
dtype = default_complex_dtype
_dtype = None # torch.linspace will update the correct dtype
else:
_dtype = torch.float64
assert not isinstance(base, complex) # for mypy
if base < 0:
raise NotImplementedError
ret = torch.linspace(
ret = torch.linspace( # type: ignore[misc]
start,
end,
steps, # type: ignore[arg-type]
dtype=torch.float64,
dtype=_dtype,
layout=layout,
device=device,
pin_memory=pin_memory,
requires_grad=requires_grad,
)
return _maybe_convert_to_dtype(torch.pow(base, ret), dtype)
return _maybe_convert_to_dtype(torch.pow(base, ret), dtype) # type: ignore[arg-type,return-value]
@overload

View File

@ -1004,7 +1004,7 @@ def sample_inputs_linspace(op, device, dtype, requires_grad, **kwargs):
yield SampleInput(1, args=(3, 1))
def sample_inputs_logpace(op, device, dtype, requires_grad, **kwargs):
def sample_inputs_logspace(op, device, dtype, requires_grad, **kwargs):
ends = (-3, 0, 1.2, 2, 4)
starts = (-2., 0, 1, 2, 4.3)
nsteps = (0, 1, 2, 4)
@ -11150,7 +11150,7 @@ op_db: List[OpInfo] = [
supports_out=True,
supports_autograd=False,
error_inputs_func=error_inputs_linspace,
sample_inputs_func=sample_inputs_logpace,
sample_inputs_func=sample_inputs_logspace,
skips=(
# Tests that assume input is a tensor or sequence of tensors
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'),