Update caching of tensor arguments for nvFuser's fusion creation (#87860)

Previously, nvFuser's fusion definition was cached based on the concrete shapes and strides of the tensor inputs, for simplicity and correctness. This PR changes the Python-side cache key to instead use the number of dimensions, the locations of size-1 dimensions, and contiguity information derived from the given shapes and strides.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87860
Approved by: https://github.com/kevinstephano, https://github.com/jjsjann123, https://github.com/ngimel
This commit is contained in:
Ivan Yashchuk
2022-11-02 09:29:20 +00:00
committed by PyTorch MergeBot
parent ccf6b558a4
commit 4a8382b58e
3 changed files with 53 additions and 14 deletions

View File

@ -368,7 +368,7 @@ class TestPrims(TestCase):
def test_nvfuser_executor_cached_noncontiguous(self, device):
# This test is to ensure that nvfuser computes correct results for noncontiguous tensors
from torch.fx.experimental.proxy_tensor import make_fx
from torch._prims.context import TorchRefsMode
from torch._prims.context import TorchRefsNvfuserCapabilityMode
from torch._prims.executor import execute
a = torch.randn(3, 3, device=device)
@ -376,16 +376,18 @@ class TestPrims(TestCase):
def func(a):
return torch.sigmoid(a)
with TorchRefsMode():
with TorchRefsNvfuserCapabilityMode():
gm = make_fx(func)(a)
# First run to create the cache
execute(gm, a, executor="nvfuser")
execute(gm, a, executor="strictly_nvfuser")
# a.mT is noncontiguous, but it shouldn't affect correctness
expected = execute(gm, a.mT, executor="aten")
actual = execute(gm, a.mT, executor="nvfuser")
self.assertEqual(expected, actual)
for use_python_cache in [True, False]:
params = {"use_python_fusion_cache": use_python_cache}
actual = execute(gm, a.mT, executor="strictly_nvfuser", executor_parameters=params)
self.assertEqual(expected, actual)
def test_nvfuser_capability_context(self, device):
# This test is to ensure that the torch calls are replaced with refs
@ -506,7 +508,7 @@ class TestPrims(TestCase):
self.assertTrue(getattr(torch.ops.nvprims, "digamma", None) is None)
from torch.fx.experimental.proxy_tensor import make_fx
from torch._prims.context import TorchRefsMode
from torch._prims.context import TorchRefsNvfuserCapabilityMode
from torch._prims.executor import execute
a = torch.randn(3, 4, device=device)
@ -519,7 +521,7 @@ class TestPrims(TestCase):
dd = torch.sqrt(d)
return torch.mul(aa, dd.digamma())
with TorchRefsMode():
with TorchRefsNvfuserCapabilityMode():
gm = make_fx(func)(a, b, c)
expected = execute(gm, a, b, c, executor="aten")
@ -535,7 +537,7 @@ class TestPrims(TestCase):
self.assertTrue(getattr(torch.ops.nvprims, "digamma", None) is None)
from torch.fx.experimental.proxy_tensor import make_fx
from torch._prims.context import TorchRefsMode
from torch._prims.context import TorchRefsNvfuserCapabilityMode
from torch._prims.executor import execute
a = torch.randn(3, 4, device=device)
@ -543,7 +545,7 @@ class TestPrims(TestCase):
def func(a):
return torch.digamma(a) # not supported by nvfuser
with TorchRefsMode():
with TorchRefsNvfuserCapabilityMode():
gm = make_fx(func)(a)
with catch_warnings(record=True) as w: