[privateuse1] _refs.masked_fill support privateuse1 when value.device.type is cpu (#124835)

Make _refs.masked_fill support privateuse1 devices when value.device.type is cpu.
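
The corresponding change in torch/_refs/__init__.py is not part of the hunk shown below, so the following is only a sketch of the intended device check, not the exact diff; the helper name _cpu_scalar_ok is hypothetical.

import torch

def _cpu_scalar_ok(a, value):
    # masked_fill allows a 0-dim CPU scalar tensor as `value` for inputs that
    # live on an accelerator; the fix adds the privateuse1 backend name to the
    # list of device types permitting this, alongside "cuda" and "xpu".
    is_cpu_scalar = (
        a.device.type in ["cuda", "xpu", torch._C._get_privateuse1_backend_name()]
        and value.device.type == "cpu"
    )
    return is_cpu_scalar or value.device == a.device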

1. To confirm: whether this modification meets the expectations of other privateuse1 devices.
2. Added a test case.

Fixes #124693

Co-authored-by: albanD <desmaison.alban@gmail.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124835
Approved by: https://github.com/albanD
mashaobin
2024-05-01 18:57:14 +00:00
committed by PyTorch MergeBot
parent 07422fd0b9
commit af67704dcc
2 changed files with 21 additions and 1 deletion


@@ -682,6 +682,23 @@ def forward(self, x_1, start_1):
    convert_element_type = torch.ops.prims.convert_element_type.default(add, torch.float32); add = None
    return convert_element_type""")

    def test_masked_fill(self, device):
        from torch.fx.experimental.proxy_tensor import make_fx

        if torch.device(device).type not in ["xpu", "cuda", torch._C._get_privateuse1_backend_name()]:
            self.skipTest("only runs on XPU and CUDA and PrivateUse1.")

        def func(scores, mask, value):
            return scores.masked_fill(mask, value)

        scores_t = torch.tensor([1, 2, 3, 4], device=device)
        mask_t = torch.tensor([True, True, True, True], device=device)
        value_t = torch.tensor(0, dtype=scores_t.dtype)
        cfunc = make_fx(func, decomposition_table=decomposition_table)
        fx_g = cfunc(scores_t, mask_t, value_t)
        self.assertExpectedInline(fx_g.code.strip(), """\
def forward(self, scores_1, mask_1, value_1):
    where = torch.ops.prims.where.default(mask_1, value_1, scores_1); mask_1 = value_1 = scores_1 = None
    return where""")

class DecompCrossRefMode(TorchDispatchMode):
    def __init__(self, test_case, saved_precision, saved_rel_tol, dtype, run_all):