mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Update caching of tensor arguments for nvFuser's fusion creation (#87860)
Previously, nvFuser's fusion definition was cached on the concrete shapes and strides of the tensor inputs, for simplicity and correctness. This PR changes the Python-side cache to key instead on the number of dimensions, the positions of size-1 dimensions, and contiguity information derived from the given shapes and strides. Pull Request resolved: https://github.com/pytorch/pytorch/pull/87860 Approved by: https://github.com/kevinstephano, https://github.com/jjsjann123, https://github.com/ngimel
This commit is contained in:
committed by
PyTorch MergeBot
parent
ccf6b558a4
commit
4a8382b58e
@ -368,7 +368,7 @@ class TestPrims(TestCase):
|
||||
def test_nvfuser_executor_cached_noncontiguous(self, device):
    """Ensure nvFuser computes correct results for noncontiguous tensors.

    A fusion is first created (and cached) from a contiguous input; the
    same cached fusion is then executed with a noncontiguous (transposed)
    input, and the result is compared against the ATen executor. The
    ``use_python_fusion_cache`` executor parameter is exercised in both
    states to make sure correctness does not depend on the Python-side cache.
    """
    # NOTE(review): reconstructed from a flattened diff view (add/remove
    # markers were lost in the scrape); statement order is preserved as
    # shown — confirm against upstream test/test_prims.py.
    from torch.fx.experimental.proxy_tensor import make_fx
    from torch._prims.context import TorchRefsMode
    from torch._prims.context import TorchRefsNvfuserCapabilityMode
    from torch._prims.executor import execute

    a = torch.randn(3, 3, device=device)

    def func(a):
        return torch.sigmoid(a)

    # Trace the function into an FX graph with torch->refs translation active.
    with TorchRefsMode():
        with TorchRefsNvfuserCapabilityMode():
            gm = make_fx(func)(a)

    # First run to create the cache
    execute(gm, a, executor="nvfuser")
    execute(gm, a, executor="strictly_nvfuser")

    # a.mT is noncontiguous, but it shouldn't affect correctness
    expected = execute(gm, a.mT, executor="aten")
    actual = execute(gm, a.mT, executor="nvfuser")
    self.assertEqual(expected, actual)
    for use_python_cache in [True, False]:
        params = {"use_python_fusion_cache": use_python_cache}
        actual = execute(gm, a.mT, executor="strictly_nvfuser", executor_parameters=params)
        self.assertEqual(expected, actual)
|
||||
def test_nvfuser_capability_context(self, device):
|
||||
# This test is to ensure that the torch calls are replaced with refs
|
||||
@ -506,7 +508,7 @@ class TestPrims(TestCase):
|
||||
self.assertTrue(getattr(torch.ops.nvprims, "digamma", None) is None)
|
||||
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._prims.context import TorchRefsMode
|
||||
from torch._prims.context import TorchRefsNvfuserCapabilityMode
|
||||
from torch._prims.executor import execute
|
||||
|
||||
a = torch.randn(3, 4, device=device)
|
||||
@ -519,7 +521,7 @@ class TestPrims(TestCase):
|
||||
dd = torch.sqrt(d)
|
||||
return torch.mul(aa, dd.digamma())
|
||||
|
||||
with TorchRefsMode():
|
||||
with TorchRefsNvfuserCapabilityMode():
|
||||
gm = make_fx(func)(a, b, c)
|
||||
|
||||
expected = execute(gm, a, b, c, executor="aten")
|
||||
@ -535,7 +537,7 @@ class TestPrims(TestCase):
|
||||
self.assertTrue(getattr(torch.ops.nvprims, "digamma", None) is None)
|
||||
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._prims.context import TorchRefsMode
|
||||
from torch._prims.context import TorchRefsNvfuserCapabilityMode
|
||||
from torch._prims.executor import execute
|
||||
|
||||
a = torch.randn(3, 4, device=device)
|
||||
@ -543,7 +545,7 @@ class TestPrims(TestCase):
|
||||
def func(a):
|
||||
return torch.digamma(a) # not supported by nvfuser
|
||||
|
||||
with TorchRefsMode():
|
||||
with TorchRefsNvfuserCapabilityMode():
|
||||
gm = make_fx(func)(a)
|
||||
|
||||
with catch_warnings(record=True) as w:
|
||||
|
||||
Reference in New Issue
Block a user