Fix empty matrix handling of addmv in inductor (#143792)

This is a resubmission of my previous PR that I accidentally deleted, apology in advance if any inconvenience caused. Below are details of this PR.

Fix an issue when torch.addmv behaves inconsistent between torch.compile mode and eager mode. Here is the code to reproduce:

```
import torch
import numpy as np

@torch.compile
def test_optimized(input, mat, vec):
    return torch.addmv(input, mat, vec)

def test(input, mat, vec):
    return torch.addmv(input, mat, vec)

input = torch.tensor([2], dtype=torch.int32)
mat = torch.tensor(np.random.randn(0, 0), dtype=torch.int32)
vec = torch.tensor([])
origin_out = test(input, mat, vec)
optimized_out = test_optimized(input, mat, vec)
print(origin_out)  # tensor([2.])
print(optimized_out)  # tensor([])
```

According to the equation (https://pytorch.org/docs/stable/generated/torch.addmv.html), when matrix and vector is empty, returning `[2.]` seems more reasonable to me.

Following the cpu implementation of this API:e97b97af56/aten/src/ATen/native/Blas.cpp (L62)

I add an additional branch to handle empty matrix

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143792
Approved by: https://github.com/jgong5, https://github.com/leslie-fang-intel, https://github.com/jansel
This commit is contained in:
maybeLee
2025-03-06 02:09:22 +00:00
committed by PyTorch MergeBot
parent 38b3375a81
commit 43e1284c96
2 changed files with 13 additions and 0 deletions

View File

@ -3556,6 +3556,17 @@ class CommonTemplate:
),
)
def test_addmv(self):
def fn(a, b, c):
return torch.addmv(a, b, c)
cfn = torch.compile(backend="inductor")(fn)
input = torch.tensor([2], dtype=torch.int32)
mat = torch.tensor(np.random.randn(0, 0), dtype=torch.int32)
vec = torch.tensor([])
with torch.no_grad():
self.assertEqual(cfn(input, mat, vec), fn(input, mat, vec))
# https://github.com/pytorch/pytorch/issues/98979
@skipCUDAIf(True, "cuda failed for float64 linear")
@skipIfXpu(msg="Double and complex datatype matmul is not supported in oneDNN")

View File

@ -1508,6 +1508,8 @@ def addmv(self: Tensor, mat1: Tensor, vec: Tensor, beta: int = 1, alpha: int = 1
out = alpha * torch.mv(mat1, vec)
if beta == 0:
return out
if out.numel() == 0: # handle empty matrix
return beta * self
return out + beta * self