Add factory functions to python frontend (#89230)

- Add a `full` nvprim to support factory functions, since the `full` reference decomposes into `empty` and `fill` while nvFuser provides a native `full` factory function.
- Change the `full_like` reference to call `full` to avoid defining another nvprim (see the sketch below).
- Enable support for `new_zeros` to allow the `cudnn_batch_norm` decomposition.
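
The `full_like` change amounts to the refactor sketched below. This is a minimal illustration, not the actual PyTorch reference implementation (the real one lives in `torch._refs` and handles more keyword arguments); it only shows how `full_like` can be expressed in terms of `full` so that a single `full` nvprim suffices.

```python
import torch

def full_like_sketch(a: torch.Tensor, fill_value, *, dtype=None, device=None):
    # Reuse the input's shape, and fall back to its dtype/device when not given,
    # then defer to full -- mirroring the "full_like calls full" refactor.
    return torch.full(
        a.shape,
        fill_value,
        dtype=dtype if dtype is not None else a.dtype,
        device=device if device is not None else a.device,
    )
```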

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89230
Approved by: https://github.com/kevinstephano, https://github.com/mruberry
Ryan Spring
2022-12-06 07:16:19 +00:00
committed by PyTorch MergeBot
parent e645771e95
commit 3c9431f505
6 changed files with 258 additions and 3 deletions


@@ -234,6 +234,46 @@ class TestPrims(TestCase):
        partitions = partitioner.propose_partitions()
        self.assertEqual(len(partitions), 1)

    @onlyCUDA
    @skipCUDAIfRocm
    @dtypes(torch.float32)
    def test_full(self, device, dtype):
        from torch.fx.experimental.proxy_tensor import make_fx
        from torch._prims.context import TorchRefsNvfuserCapabilityMode
        from torch._prims.executor import execute

        def func1(size, value, b):
            return (torch.full(size, value, dtype=dtype, device=device),)

        def func2(size, value, b):
            a = torch.full(size, value, dtype=dtype, device=device)
            b_sin = b.sin()
            return (torch.add(a, b_sin),)

        def func3(size, value, b):
            return (torch.full(size, value, dtype=dtype, device=device), b)

        def func4(size, value, b):
            b_sin = b.sin()
            return (torch.full(size, value, dtype=dtype, device=device), b_sin)

        def func5(size, value, b):
            b_sin = b.sin()
            a = torch.full(size, value, dtype=dtype, device=device)
            a_sin = a.sin()
            return (a, b_sin, a_sin)

        for func in (func1, func3, func2, func3, func4, func5):
            size = (3, 3)
            value = 10
            b = torch.randn(*size, dtype=dtype, device=device)

            with TorchRefsNvfuserCapabilityMode():
                gm = make_fx(func)(size, value, b)

            out = execute(gm, size, value, b, executor="strictly_nvfuser")
            self.assertEqual(out, func(size, value, b))

    @onlyCUDA
    @skipCUDAIfRocm
    def test_nvfuser_empty_fusion(self, device):
@@ -687,7 +727,13 @@ class TestPrims(TestCase):
            # Check that the graph can be executed with nvFuser
            out = execute(gm, sample.input, *sample.args, executor="nvfuser")
            self.assertEqual(out, gm(sample.input, *sample.args))
            ref_out = gm(sample.input, *sample.args)
            for idx, (left, right) in enumerate(zip(out, ref_out)):
                # nvFuser does not support the torch.uint8 dtype, so check the reserve output against a 0 scalar
                if idx == 3:
                    self.assertTrue(torch.all(torch.eq(left, 0)))
                else:
                    self.assertEqual(left, right)

    # The decomposition of native_batch_norm_backward uses a cast, which prevents nvprim lowering on CPU builds
    @onlyCUDA