mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[privateuse1] _refs.masked_fill support privateuse1 when value.device.type is cpu (#124835)
_refs.masked_fill supports privateuse1 when value.device.type is cpu. 1. Consider whether this modification meets the expectations of other privateuse1 devices. 2. Add a test case. Fixes #124693 Co-authored-by: albanD <desmaison.alban@gmail.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/124835 Approved by: https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
07422fd0b9
commit
af67704dcc
@ -682,6 +682,23 @@ def forward(self, x_1, start_1):
|
||||
convert_element_type = torch.ops.prims.convert_element_type.default(add, torch.float32); add = None
|
||||
return convert_element_type""")
|
||||
|
||||
def test_masked_fill(self, device):
    """masked_fill must decompose to prims.where even when the fill value
    is a scalar tensor living on the CPU while the inputs are on an
    accelerator device (xpu / cuda / privateuse1)."""
    from torch.fx.experimental.proxy_tensor import make_fx

    supported = ("xpu", "cuda", torch._C._get_privateuse1_backend_name())
    if torch.device(device).type not in supported:
        self.skipTest("only runs on XPU and CUDA and PrivateUse1.")

    def func(scores, mask, value):
        return scores.masked_fill(mask, value)

    scores = torch.tensor([1, 2, 3, 4], device=device)
    mask = torch.tensor([True, True, True, True], device=device)
    # The fill value deliberately stays on the CPU: the decomposition has
    # to accept a cpu scalar tensor alongside accelerator-resident inputs.
    fill_value = torch.tensor(0, dtype=scores.dtype)

    traced = make_fx(func, decomposition_table=decomposition_table)
    graph_module = traced(scores, mask, fill_value)
    self.assertExpectedInline(graph_module.code.strip(), """\
def forward(self, scores_1, mask_1, value_1):
    where = torch.ops.prims.where.default(mask_1, value_1, scores_1); mask_1 = value_1 = scores_1 = None
    return where""")
|
||||
|
||||
class DecompCrossRefMode(TorchDispatchMode):
|
||||
def __init__(self, test_case, saved_precision, saved_rel_tol, dtype, run_all):
|
||||
|
Reference in New Issue
Block a user