Compare commits

...

1 Commits

Author SHA1 Message Date
028132ed9e [Inductor] Add decomposition for aten.mul 2025-06-09 23:13:23 -07:00
2 changed files with 36 additions and 0 deletions

View File

@ -1441,6 +1441,15 @@ class CommonTemplate:
self.common(fn, (x, y, 2))
def test_mul_complex(self):
def fn(a, b):
return torch.mul(a, b)
x = torch.tensor([1 + 1j, -1 + 1j, -2 + 2j, 3 - 3j, 0, 1j, 1, -1])
y = torch.tensor([1 + 1j, -1 + 1j, -2 + 2j, 3 - 3j, 0, 1j, 1, -1])
self.common(fn, (x, y))
def test_concat_add_inplace(self):
def fn(x, y, z):
return torch.cat([x, y], dim=1).add_(z)

View File

@ -502,6 +502,33 @@ def add(
return result
@register_decomposition([aten.mul])
def mul(
x: torch.Tensor,
y: torch.Tensor,
) -> torch.Tensor:
# Require both x and y to be complex tensors.
x_is_complex_tensor = torch.is_tensor(x) and x.is_complex()
y_is_complex_tensor = torch.is_tensor(y) and y.is_complex()
if not x_is_complex_tensor or not y_is_complex_tensor:
return NotImplemented
complex_type = torch.promote_types(x.dtype, y.dtype)
x_real = x.real
x_imag = x.imag
y_real = y.real
y_imag = y.imag
real = x_real * y_real - x_imag * y_imag
imag = x_real * y_imag + x_imag * y_real
result = torch.flatten(torch.stack([real, imag], dim=-1), start_dim=-2).view(
complex_type
)
return result
@register_decomposition([aten.conj_physical])
def conj_physical(self: torch.Tensor) -> torch.Tensor:
assert not self.is_complex(), "TODO: implement this"