From 4a8382b58eeca9eed09c7c3b801b81befc2f75ce Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Wed, 2 Nov 2022 09:29:20 +0000 Subject: [PATCH] Update caching of tensor arguments for nvFuser's fusion creation (#87860) Previously nvFuser's fusion definition was cached based on concrete shape and strides of tensor inputs for simplicity and correctness. This PR changes Python's cache to check the number of dimensions, size-1 dimensions, and contiguity information based on given strides and shapes. Pull Request resolved: https://github.com/pytorch/pytorch/pull/87860 Approved by: https://github.com/kevinstephano, https://github.com/jjsjann123, https://github.com/ngimel --- test/test_prims.py | 20 +++++++------ torch/_prims/nvfuser_executor.py | 29 +++++++++++++++---- .../cuda/python_frontend/python_bindings.cpp | 18 ++++++++++++ 3 files changed, 53 insertions(+), 14 deletions(-) diff --git a/test/test_prims.py b/test/test_prims.py index 6223a34e0a3a..b6833352d0cf 100644 --- a/test/test_prims.py +++ b/test/test_prims.py @@ -368,7 +368,7 @@ class TestPrims(TestCase): def test_nvfuser_executor_cached_noncontiguous(self, device): # This test is to ensure that nvfuser computes correct results for noncontiguous tensors from torch.fx.experimental.proxy_tensor import make_fx - from torch._prims.context import TorchRefsMode + from torch._prims.context import TorchRefsNvfuserCapabilityMode from torch._prims.executor import execute a = torch.randn(3, 3, device=device) @@ -376,16 +376,18 @@ class TestPrims(TestCase): def func(a): return torch.sigmoid(a) - with TorchRefsMode(): + with TorchRefsNvfuserCapabilityMode(): gm = make_fx(func)(a) # First run to create the cache - execute(gm, a, executor="nvfuser") + execute(gm, a, executor="strictly_nvfuser") # a.mT is noncontiguous, but it shouldn't affect correctness expected = execute(gm, a.mT, executor="aten") - actual = execute(gm, a.mT, executor="nvfuser") - self.assertEqual(expected, actual) + for use_python_cache in [True, False]: + params = {"use_python_fusion_cache": use_python_cache} + actual = execute(gm, a.mT, executor="strictly_nvfuser", executor_parameters=params) + self.assertEqual(expected, actual) def test_nvfuser_capability_context(self, device): # This test is to ensure that the torch calls are replaced with refs @@ -506,7 +508,7 @@ class TestPrims(TestCase): self.assertTrue(getattr(torch.ops.nvprims, "digamma", None) is None) from torch.fx.experimental.proxy_tensor import make_fx - from torch._prims.context import TorchRefsMode + from torch._prims.context import TorchRefsNvfuserCapabilityMode from torch._prims.executor import execute a = torch.randn(3, 4, device=device) @@ -519,7 +521,7 @@ class TestPrims(TestCase): dd = torch.sqrt(d) return torch.mul(aa, dd.digamma()) - with TorchRefsMode(): + with TorchRefsNvfuserCapabilityMode(): gm = make_fx(func)(a, b, c) expected = execute(gm, a, b, c, executor="aten") @@ -535,7 +537,7 @@ class TestPrims(TestCase): self.assertTrue(getattr(torch.ops.nvprims, "digamma", None) is None) from torch.fx.experimental.proxy_tensor import make_fx - from torch._prims.context import TorchRefsMode + from torch._prims.context import TorchRefsNvfuserCapabilityMode from torch._prims.executor import execute a = torch.randn(3, 4, device=device) @@ -543,7 +545,7 @@ class TestPrims(TestCase): def func(a): return torch.digamma(a) # not supported by nvfuser - with TorchRefsMode(): + with TorchRefsNvfuserCapabilityMode(): gm = make_fx(func)(a) with catch_warnings(record=True) as w: diff --git a/torch/_prims/nvfuser_executor.py b/torch/_prims/nvfuser_executor.py index 227e1847265b..ae9dbfff781d 100644 --- a/torch/_prims/nvfuser_executor.py +++ b/torch/_prims/nvfuser_executor.py @@ -40,8 +40,8 @@ DEFAULT_NVFUSER_PYTHON_CONFIG = MappingProxyType( # https://github.com/pytorch/pytorch/issues/80551 @dataclass(frozen=True) class nvFuserTensorTemplate: - size: tuple - stride: tuple + symbolic_shape: tuple + contiguity: tuple dtype: DataType is_cpu: bool @@ -51,12 +51,29 @@ class nvFuserScalarTemplate: dtype: DataType +@lru_cache(maxsize=2048) +def compute_symbolic_shape(shape): + """Computes the symbolic shape of a tensor. + nvFuser specializes on size-1 dimensions as broadcasted dimensions. + -1 is used to represent any size.""" + return tuple(1 if s == 1 else -1 for s in shape) + + +@lru_cache(maxsize=2048) +def compute_contiguity(shape, strides): + """Computes the contiguity information to simplify internal indexing. + Contiguous dimensions are represented by True, strided dimensions + are represented by False. + """ + return torch._C._nvfuser.compute_contiguity(shape, strides) + + def to_nvfuser_template_args(args): def to_nvfuser(arg): if isinstance(arg, torch.Tensor): return nvFuserTensorTemplate( - arg.size(), - arg.stride(), + compute_symbolic_shape(arg.size()), + compute_contiguity(arg.size(), arg.stride()), getnvFuserDtype(arg.dtype), arg.is_cpu, # type: ignore[attr-defined] ) @@ -163,7 +180,9 @@ def make_nvfuser_fusion(gm: GraphModule, *nv_args_templates): def templates_to_nvfuser_inputs(arg): if isinstance(arg, nvFuserTensorTemplate): - x = fd.define_tensor(arg.size, arg.stride, arg.dtype, arg.is_cpu) + x = fd.define_tensor( + arg.symbolic_shape, arg.contiguity, arg.dtype, arg.is_cpu + ) return x elif isinstance(arg, nvFuserScalarTemplate): x = fd.define_scalar(arg.dtype) diff --git a/torch/csrc/jit/codegen/cuda/python_frontend/python_bindings.cpp b/torch/csrc/jit/codegen/cuda/python_frontend/python_bindings.cpp index b633732f8926..12672d898598 100644 --- a/torch/csrc/jit/codegen/cuda/python_frontend/python_bindings.cpp +++ b/torch/csrc/jit/codegen/cuda/python_frontend/python_bindings.cpp @@ -40,6 +40,24 @@ void initNvFuserPythonBindings(PyObject* module) { .value("ComplexDouble", Nvf::DataType::ComplexDouble) .value("Null", Nvf::DataType::Null); + nvfuser.def( + "compute_contiguity", + [](const std::vector& sizes, + const std::vector& strides) { + py::tuple contiguity(sizes.size()); + TORCH_CHECK( + sizes.size() == strides.size(), + "compute_contiguity: Sizes and strides must have the same number of dimensions"); + if (sizes.size() == 0) { + return contiguity; + } + contiguity[sizes.size() - 1] = strides.back() == 1; + for (int64_t i = static_cast(sizes.size()) - 2; i >= 0; --i) { + contiguity[i] = strides[i] == strides[i + 1] * sizes[i + 1]; + } + return contiguity; + }); + //! Binding the FusionCache that holds a cache of Fusions //! This is only bound to provide an interface to get the number of fusions //! that are cached.