Fix empty matrix handling of addmv in inductor (#143792)
This is a resubmission of my previous PR, which I accidentally deleted; apologies in advance for any inconvenience caused. Details of this PR are below.
This PR fixes an issue where `torch.addmv` behaves inconsistently between `torch.compile` mode and eager mode. Here is the code to reproduce it:
```
import torch
import numpy as np
@torch.compile
def test_optimized(input, mat, vec):
return torch.addmv(input, mat, vec)
def test(input, mat, vec):
return torch.addmv(input, mat, vec)
input = torch.tensor([2], dtype=torch.int32)
mat = torch.tensor(np.random.randn(0, 0), dtype=torch.int32)  # empty 0x0 matrix
vec = torch.tensor([])  # empty vector
origin_out = test(input, mat, vec)
optimized_out = test_optimized(input, mat, vec)
print(origin_out)     # eager: tensor([2.])
print(optimized_out)  # compiled: tensor([])
```
According to the equation in the documentation (https://pytorch.org/docs/stable/generated/torch.addmv.html), when the matrix and vector are empty, returning `[2.]` seems more reasonable to me.
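For context, the documentation linked above defines `addmv` as

$$
\text{out} = \beta\,\text{input} + \alpha\,(\text{mat} \mathbin{@} \text{vec})
$$

With the default `beta = alpha = 1` and an empty `mat @ vec` product, the only surviving term is `beta * input`, which matches the `tensor([2.])` that eager mode returns for the inputs above.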
Following the CPU implementation of this API (e97b97af56/aten/src/ATen/native/Blas.cpp (L62)), I added an additional branch to handle the empty-matrix case.
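For readability, here is a self-contained sketch of the patched decomposition logic, reconstructed from the diff further down; the type annotations and any registration boilerplate of the real source are omitted, and `self` is the `input` argument of `addmv`:

```
import torch


def addmv(self, mat1, vec, beta=1, alpha=1):
    # Inductor decomposes addmv into a matrix-vector product plus a scaled add.
    out = alpha * torch.mv(mat1, vec)
    if beta == 0:
        # With beta == 0, `input` is ignored by definition.
        return out
    if out.numel() == 0:  # handle empty matrix
        # New branch: when the mv result is empty, only the beta * self term
        # remains, mirroring the eager CPU shortcut in Blas.cpp.
        return beta * self
    return out + beta * self
```

In other words, the new branch makes the decomposition fall back to `beta * self` whenever the matrix is empty, which is what the eager CPU path does.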
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143792
Approved by: https://github.com/jgong5, https://github.com/leslie-fang-intel, https://github.com/jansel
Committed by: PyTorch MergeBot
Parent: 38b3375a81
Commit: 43e1284c96
New test added to `CommonTemplate`:
```
@@ -3556,6 +3556,17 @@ class CommonTemplate:
             ),
         )
 
+    def test_addmv(self):
+        def fn(a, b, c):
+            return torch.addmv(a, b, c)
+
+        cfn = torch.compile(backend="inductor")(fn)
+        input = torch.tensor([2], dtype=torch.int32)
+        mat = torch.tensor(np.random.randn(0, 0), dtype=torch.int32)
+        vec = torch.tensor([])
+        with torch.no_grad():
+            self.assertEqual(cfn(input, mat, vec), fn(input, mat, vec))
+
     # https://github.com/pytorch/pytorch/issues/98979
     @skipCUDAIf(True, "cuda failed for float64 linear")
     @skipIfXpu(msg="Double and complex datatype matmul is not supported in oneDNN")
```
Empty-matrix branch added to `addmv`:
```
@@ -1508,6 +1508,8 @@ def addmv(self: Tensor, mat1: Tensor, vec: Tensor, beta: int = 1, alpha: int = 1
     out = alpha * torch.mv(mat1, vec)
     if beta == 0:
         return out
+    if out.numel() == 0:  # handle empty matrix
+        return beta * self
     return out + beta * self
```