Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00.
Add nvFuser support for torch.Tensor.view (#84634)
This is an alternative to https://github.com/pytorch/pytorch/pull/83739. While PrimTorch has `view` as a reference, we would like to use nvFuser's implementation for `view` for now. Later we might transition to PrimTorch's `torch._refs.view`. See `test_nvprims_view` for examples of things that are now sent to nvFuser. Note that nvFuser's `view` is a copy-like operation. Pull Request resolved: https://github.com/pytorch/pytorch/pull/84634 Approved by: https://github.com/kevinstephano, https://github.com/mruberry
This commit was committed by PyTorch MergeBot. Parent commit: b48deedb77; this commit: fd80684784.
@ -688,6 +688,53 @@ class TestPrims(TestCase):
|
||||
)
|
||||
self.assertTrue(includes_nvprims_var_mean)
|
||||
|
||||
@onlyCUDA
@skipCUDAIfRocm
@dtypes(torch.float16, torch.float32)
def test_nvprims_view(self, device, dtype):
    """Every spelling of a view/reshape below should be captured as a
    nvprims.view call when traced under the nvFuser capability mode, and
    the resulting graph should execute end-to-end on the nvFuser executor.

    NOTE(review): nvFuser's ``view`` is a copy-like operation, so tracing
    the in-place-aliasing variants to it is intentional here.
    """
    from torch.fx.experimental.proxy_tensor import make_fx
    from torch._prims.context import TorchRefsNvfuserCapabilityMode
    from torch._prims.executor import execute

    make_arg = partial(make_tensor, device=device, dtype=dtype)
    t = make_arg((3, 4, 5))

    # Seven equivalent ways of reversing the tensor's shape; each one is a
    # distinct entry point that should funnel into nvprims.view.
    def view_method(x):
        return x.view(tuple(reversed(x.shape)))

    def reshape_method(x):
        return x.reshape(tuple(reversed(x.shape)))

    def view_copy_fn(x):
        return torch.view_copy(x, tuple(reversed(x.shape)))

    def reshape_fn(x):
        return torch.reshape(x, tuple(reversed(x.shape)))

    def aten_view(x):
        return torch.ops.aten.view.default(x, tuple(reversed(x.shape)))

    def aten_unsafe_view(x):
        return torch.ops.aten._unsafe_view.default(x, tuple(reversed(x.shape)))

    def aten_view_copy(x):
        return torch.ops.aten.view_copy.default(x, tuple(reversed(x.shape)))

    variants = (
        view_method,
        reshape_method,
        view_copy_fn,
        reshape_fn,
        aten_view,
        aten_unsafe_view,
        aten_view_copy,
    )

    for func in variants:
        # Trace under the capability mode so torch refs are rewritten to
        # nvprims where nvFuser supports them.
        with TorchRefsNvfuserCapabilityMode():
            gm = make_fx(func)(t)

        call_targets = [
            node.target
            for node in gm.graph.nodes
            if node.op == "call_function"
        ]
        self.assertTrue(torch.ops.nvprims.view.default in call_targets)

        # Run the traced graph strictly on nvFuser and compare with eager.
        out = execute(gm, t, executor="strictly_nvfuser")
        self.assertEqual(out, func(t))
|
||||
|
||||
@onlyCUDA
|
||||
@skipCUDAIfRocm
|
||||
@dtypes(torch.float32, torch.float16)
|
||||
|
Reference in New Issue
Block a user