Add factory functions to python frontend (#89230)
- Add a `full` nvprim to support factory functions, because the `full` reference uses `empty` and `fill`, while nvFuser already has a `full` factory function.
- Change the `full_like` reference to call `full`, to avoid defining another nvprim.
- Enable support for `new_zeros` to enable the `cudnn_batch_norm` decomposition.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89230
Approved by: https://github.com/kevinstephano, https://github.com/mruberry
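A minimal sketch of the idea behind the first two bullets, assuming only the public torch API (this is not the actual PyTorch reference or nvprim code): once `full` exists as a single primitive, `full_like` can be routed through it instead of through an `empty`-plus-`fill` pair.

import torch

# Hypothetical helper, not the PyTorch reference implementation:
# full_like expressed through a single full call instead of empty + fill.
def full_like_sketch(a, fill_value, *, dtype=None, device=None):
    dtype = a.dtype if dtype is None else dtype
    device = a.device if device is None else device
    return torch.full(a.shape, fill_value, dtype=dtype, device=device)

# Usage: same result as torch.full_like(x, 7.0)
x = torch.randn(3, 3)
y = full_like_sketch(x, 7.0)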
Committed by: PyTorch MergeBot
Parent: e645771e95
Commit: 3c9431f505
@@ -234,6 +234,46 @@ class TestPrims(TestCase):
         partitions = partitioner.propose_partitions()
         self.assertEqual(len(partitions), 1)
 
+    @onlyCUDA
+    @skipCUDAIfRocm
+    @dtypes(torch.float32)
+    def test_full(self, device, dtype):
+        from torch.fx.experimental.proxy_tensor import make_fx
+        from torch._prims.context import TorchRefsNvfuserCapabilityMode
+        from torch._prims.executor import execute
+
+        def func1(size, value, b):
+            return (torch.full(size, value, dtype=dtype, device=device),)
+
+        def func2(size, value, b):
+            a = torch.full(size, value, dtype=dtype, device=device)
+            b_sin = b.sin()
+            return (torch.add(a, b_sin),)
+
+        def func3(size, value, b):
+            return (torch.full(size, value, dtype=dtype, device=device), b)
+
+        def func4(size, value, b):
+            b_sin = b.sin()
+            return (torch.full(size, value, dtype=dtype, device=device), b_sin)
+
+        def func5(size, value, b):
+            b_sin = b.sin()
+            a = torch.full(size, value, dtype=dtype, device=device)
+            a_sin = a.sin()
+            return (a, b_sin, a_sin)
+
+        for func in (func1, func3, func2, func3, func4, func5):
+            size = (3, 3)
+            value = 10
+            b = torch.randn(*size, dtype=dtype, device=device)
+
+            with TorchRefsNvfuserCapabilityMode():
+                gm = make_fx(func)(size, value, b)
+
+            out = execute(gm, size, value, b, executor="strictly_nvfuser")
+            self.assertEqual(out, func(size, value, b))
+
     @onlyCUDA
     @skipCUDAIfRocm
     def test_nvfuser_empty_fusion(self, device):
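Outside the unittest harness, the trace-and-execute pattern used in the hunk above looks roughly as follows. This is a sketch that reuses only the APIs imported in the test and assumes a CUDA build with nvFuser available; the "strictly_nvfuser" executor is used here, as in the test, on the assumption that it requires the whole traced graph to be handled by nvFuser.

import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._prims.context import TorchRefsNvfuserCapabilityMode
from torch._prims.executor import execute

def fn(b):
    # torch.full should now lower to the `full` nvprim instead of empty + fill
    a = torch.full((3, 3), 10.0, dtype=torch.float32, device="cuda")
    return (a + b.sin(),)

b = torch.randn(3, 3, device="cuda")

# Trace into an FX graph with torch ops mapped to nvprim-capable references
with TorchRefsNvfuserCapabilityMode():
    gm = make_fx(fn)(b)

# Execute the traced graph through nvFuser
out = execute(gm, b, executor="strictly_nvfuser")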
@@ -687,7 +727,13 @@ class TestPrims(TestCase):
 
         # Check that the graph can be executed with nvFuser
         out = execute(gm, sample.input, *sample.args, executor="nvfuser")
-        self.assertEqual(out, gm(sample.input, *sample.args))
+        ref_out = gm(sample.input, *sample.args)
+        for idx, (left, right) in enumerate(zip(out, ref_out)):
+            # Nvfuser does not support torch.uint8 dtype so check reserve output against 0 scalar
+            if idx == 3:
+                self.assertTrue(torch.all(torch.eq(left, 0)))
+            else:
+                self.assertEqual(left, right)
 
     # decomposition of native_batch_norm_backward uses a casting, which prevents nvprim lowering on CPU build
     @onlyCUDA
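The loop above special-cases index 3 because the fourth output of the `cudnn_batch_norm` decomposition is a reserve buffer of dtype torch.uint8, which nvFuser cannot produce, so it is only compared against zero. A minimal sketch, assuming only the public torch API (the helper name and the zero-sized reserve shape are illustrative, not taken from the decomposition source), of why `new_zeros` support is what unlocks this decomposition: a `new_zeros` call is just `full` with a zero fill value.

import torch

# Hypothetical sketch, not the PyTorch reference implementation:
# new_zeros routed through full, so only the `full` nvprim is needed.
def new_zeros_sketch(a, size, *, dtype=None, device=None):
    dtype = a.dtype if dtype is None else dtype
    device = a.device if device is None else device
    return torch.full(size, 0, dtype=dtype, device=device)

# Illustrative reserve-style buffer: a zero-filled uint8 tensor,
# similar in spirit to the output checked against 0 in the test above.
x = torch.randn(2, 3)
reserve = new_zeros_sketch(x, (0,), dtype=torch.uint8)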