mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix aten.logspace decomposition (#105201)
Fixes #104118 Pull Request resolved: https://github.com/pytorch/pytorch/pull/105201 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
5afc2f5069
commit
0ad93a3d56
@ -1212,6 +1212,16 @@ class TestRefs(TestCase):
|
||||
expect = torch.unbind(a, 1)
|
||||
self.assertEqual(actual, expect)
|
||||
|
||||
def test_logspace_with_complex_input(self):
|
||||
actual = refs.logspace(2, 10 + 5j, steps=5)
|
||||
expect = torch.logspace(2, 10 + 5j, steps=5)
|
||||
self.assertEqual(actual, expect)
|
||||
|
||||
def test_linspace_with_complex_input(self):
|
||||
actual = refs.linspace(2, 10 + 5j, steps=5)
|
||||
expect = torch.linspace(2, 10 + 5j, steps=5)
|
||||
self.assertEqual(actual, expect)
|
||||
|
||||
|
||||
instantiate_device_type_tests(TestRefs, globals())
|
||||
|
||||
|
@ -4619,20 +4619,29 @@ def logspace(
|
||||
if isinstance(end, FloatLike):
|
||||
end = sym_int(end)
|
||||
|
||||
if py_any(isinstance(arg, complex) for arg in (start, end, steps)):
|
||||
default_complex_dtype = utils.corresponding_complex_dtype(
|
||||
torch.get_default_dtype()
|
||||
)
|
||||
dtype = default_complex_dtype
|
||||
_dtype = None # torch.linspace will update the correct dtype
|
||||
else:
|
||||
_dtype = torch.float64
|
||||
|
||||
assert not isinstance(base, complex) # for mypy
|
||||
if base < 0:
|
||||
raise NotImplementedError
|
||||
ret = torch.linspace(
|
||||
ret = torch.linspace( # type: ignore[misc]
|
||||
start,
|
||||
end,
|
||||
steps, # type: ignore[arg-type]
|
||||
dtype=torch.float64,
|
||||
dtype=_dtype,
|
||||
layout=layout,
|
||||
device=device,
|
||||
pin_memory=pin_memory,
|
||||
requires_grad=requires_grad,
|
||||
)
|
||||
return _maybe_convert_to_dtype(torch.pow(base, ret), dtype)
|
||||
return _maybe_convert_to_dtype(torch.pow(base, ret), dtype) # type: ignore[arg-type,return-value]
|
||||
|
||||
|
||||
@overload
|
||||
|
@ -1004,7 +1004,7 @@ def sample_inputs_linspace(op, device, dtype, requires_grad, **kwargs):
|
||||
yield SampleInput(1, args=(3, 1))
|
||||
|
||||
|
||||
def sample_inputs_logpace(op, device, dtype, requires_grad, **kwargs):
|
||||
def sample_inputs_logspace(op, device, dtype, requires_grad, **kwargs):
|
||||
ends = (-3, 0, 1.2, 2, 4)
|
||||
starts = (-2., 0, 1, 2, 4.3)
|
||||
nsteps = (0, 1, 2, 4)
|
||||
@ -11150,7 +11150,7 @@ op_db: List[OpInfo] = [
|
||||
supports_out=True,
|
||||
supports_autograd=False,
|
||||
error_inputs_func=error_inputs_linspace,
|
||||
sample_inputs_func=sample_inputs_logpace,
|
||||
sample_inputs_func=sample_inputs_logspace,
|
||||
skips=(
|
||||
# Tests that assume input is a tensor or sequence of tensors
|
||||
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'),
|
||||
|
Reference in New Issue
Block a user