Add nvFuser support for torch.Tensor.view (#84634)

This is an alternative to https://github.com/pytorch/pytorch/pull/83739. While PrimTorch has a reference implementation of `view`, we would like to use nvFuser's implementation of `view` for now; later we may transition to PrimTorch's `torch._refs.view`.
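As a minimal sketch of the new behavior (mirroring `test_nvprims_view` below, and assuming a CUDA build with nvFuser available), tracing a `view` call under `TorchRefsNvfuserCapabilityMode` produces a graph whose call targets `torch.ops.nvprims.view` rather than `aten.view`:

    import torch
    from torch.fx.experimental.proxy_tensor import make_fx
    from torch._prims.context import TorchRefsNvfuserCapabilityMode

    a = torch.randn(3, 4, 5, device="cuda")

    # Trace the function; the capability mode rewrites view into nvprims.view.
    with TorchRefsNvfuserCapabilityMode():
        gm = make_fx(lambda t: t.view(5, 4, 3))(a)

    # The traced graph now targets nvprims.view instead of aten.view.
    print(any(node.target == torch.ops.nvprims.view.default
              for node in gm.graph.nodes if node.op == "call_function"))  # True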

See `test_nvprims_view` for examples of the operations that are now sent to nvFuser. Note that nvFuser's `view` is a copy-like operation: its output does not alias the input's storage.
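For illustration, the eager analogue of this copy-like behavior is `torch.view_copy` (one of the ops covered by the test):

    import torch

    a = torch.randn(3, 4)
    b = torch.view_copy(a, (4, 3))  # copy-like: `b` owns fresh storage
    b.zero_()                       # mutating `b` ...
    print(a.abs().sum() > 0)        # ... leaves `a` untouched (True)

By contrast, `torch.Tensor.view` returns a tensor that aliases the input's storage.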
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84634
Approved by: https://github.com/kevinstephano, https://github.com/mruberry
Ivan Yashchuk
2022-10-14 12:08:02 +00:00
committed by PyTorch MergeBot
parent b48deedb77
commit fd80684784
7 changed files with 248 additions and 5 deletions

@@ -688,6 +688,53 @@ class TestPrims(TestCase):
        )
        self.assertTrue(includes_nvprims_var_mean)

    @onlyCUDA
    @skipCUDAIfRocm
    @dtypes(torch.float16, torch.float32)
    def test_nvprims_view(self, device, dtype):
        from torch.fx.experimental.proxy_tensor import make_fx
        from torch._prims.context import TorchRefsNvfuserCapabilityMode
        from torch._prims.executor import execute

        make_arg = partial(make_tensor, device=device, dtype=dtype)
        a = make_arg((3, 4, 5))

        def func1(a):
            return a.view(tuple(reversed(a.shape)))

        def func2(a):
            return a.reshape(tuple(reversed(a.shape)))

        def func3(a):
            return torch.view_copy(a, tuple(reversed(a.shape)))

        def func4(a):
            return torch.reshape(a, tuple(reversed(a.shape)))

        def func5(a):
            return torch.ops.aten.view.default(a, tuple(reversed(a.shape)))

        def func6(a):
            return torch.ops.aten._unsafe_view.default(a, tuple(reversed(a.shape)))

        def func7(a):
            return torch.ops.aten.view_copy.default(a, tuple(reversed(a.shape)))

        for func in (func1, func2, func3, func4, func5, func6, func7):
            with TorchRefsNvfuserCapabilityMode():
                gm = make_fx(func)(a)

            call_function_nodes = list(filter(lambda n: n.op == "call_function", gm.graph.nodes))
            includes_nvprims_view = any(
                torch.ops.nvprims.view.default == node.target
                for node in call_function_nodes
            )
            self.assertTrue(includes_nvprims_view)

            # Try executing the graph
            out = execute(gm, a, executor="strictly_nvfuser")
            self.assertEqual(out, func(a))

    @onlyCUDA
    @skipCUDAIfRocm
    @dtypes(torch.float32, torch.float16)