mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add nvprims rand_like support for Dropout (#85077)
NM Pull Request resolved: https://github.com/pytorch/pytorch/pull/85077 Approved by: https://github.com/IvanYashchuk, https://github.com/mruberry
This commit is contained in:
committed by
PyTorch MergeBot
parent
1e4c88518c
commit
c7b17d7eb1
@ -176,6 +176,24 @@ class TestPrims(TestCase):
|
||||
out = execute(gm, a, a, a, executor="nvfuser")
|
||||
self.assertEqual(out, (a, a, a))
|
||||
|
||||
@onlyCUDA
|
||||
@skipCUDAIfRocm
|
||||
def test_nvfuser_rand_like_fusion(self, device):
|
||||
from torch._prims.context import TorchRefsNvfuserCapabilityMode
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._prims.executor import execute
|
||||
|
||||
a = torch.randn(3, 3, device=device)
|
||||
|
||||
def func(a):
|
||||
return torch.rand_like(a)
|
||||
|
||||
with TorchRefsNvfuserCapabilityMode():
|
||||
gm = make_fx(func)(a)
|
||||
|
||||
out = execute(gm, a, executor="strictly_nvfuser")
|
||||
self.assertEqual(out.size(), a.size())
|
||||
|
||||
@skipCUDAMemoryLeakCheckIf(True) # https://github.com/pytorch/pytorch/issues/84529
|
||||
@onlyCUDA
|
||||
@skipCUDAIfRocm
|
||||
|
||||
Reference in New Issue
Block a user