Update caching of tensor arguments for nvFuser's fusion creation (#87860)

Previously, nvFuser's fusion definition was cached on the concrete shapes and strides of its tensor inputs, for simplicity and correctness. This PR changes the Python-side cache to instead key on the number of dimensions, the locations of size-1 dimensions, and contiguity information computed from the given shapes and strides.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87860
Approved by: https://github.com/kevinstephano, https://github.com/jjsjann123, https://github.com/ngimel
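In other words, the cache key no longer distinguishes tensors that differ only in concrete extent. A minimal sketch of the idea (the helper body is taken verbatim from this commit's compute_symbolic_shape; the example shapes are illustrative):

    # Collapse every extent to -1 ("any size"), keeping size-1 dimensions,
    # which nvFuser specializes as broadcast dimensions.
    def compute_symbolic_shape(shape):
        return tuple(1 if s == 1 else -1 for s in shape)

    compute_symbolic_shape((3, 3))  # (-1, -1)
    compute_symbolic_shape((5, 7))  # (-1, -1) -- same key, fusion definition is reused
    compute_symbolic_shape((1, 3))  # (1, -1)  -- size-1 dim gets its own specialization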
committed by PyTorch MergeBot
parent ccf6b558a4
commit 4a8382b58e
test/test_prims.py

@@ -368,7 +368,7 @@ class TestPrims(TestCase):
     def test_nvfuser_executor_cached_noncontiguous(self, device):
         # This test is to ensure that nvfuser computes correct results for noncontiguous tensors
         from torch.fx.experimental.proxy_tensor import make_fx
-        from torch._prims.context import TorchRefsMode
+        from torch._prims.context import TorchRefsNvfuserCapabilityMode
         from torch._prims.executor import execute

         a = torch.randn(3, 3, device=device)
@@ -376,15 +376,17 @@ class TestPrims(TestCase):
         def func(a):
             return torch.sigmoid(a)

-        with TorchRefsMode():
+        with TorchRefsNvfuserCapabilityMode():
             gm = make_fx(func)(a)

         # First run to create the cache
-        execute(gm, a, executor="nvfuser")
+        execute(gm, a, executor="strictly_nvfuser")

         # a.mT is noncontiguous, but it shouldn't affect correctness
         expected = execute(gm, a.mT, executor="aten")
-        actual = execute(gm, a.mT, executor="nvfuser")
-        self.assertEqual(expected, actual)
+        for use_python_cache in [True, False]:
+            params = {"use_python_fusion_cache": use_python_cache}
+            actual = execute(gm, a.mT, executor="strictly_nvfuser", executor_parameters=params)
+            self.assertEqual(expected, actual)

     def test_nvfuser_capability_context(self, device):
@@ -506,7 +508,7 @@ class TestPrims(TestCase):
         self.assertTrue(getattr(torch.ops.nvprims, "digamma", None) is None)

         from torch.fx.experimental.proxy_tensor import make_fx
-        from torch._prims.context import TorchRefsMode
+        from torch._prims.context import TorchRefsNvfuserCapabilityMode
         from torch._prims.executor import execute

         a = torch.randn(3, 4, device=device)
@@ -519,7 +521,7 @@ class TestPrims(TestCase):
             dd = torch.sqrt(d)
             return torch.mul(aa, dd.digamma())

-        with TorchRefsMode():
+        with TorchRefsNvfuserCapabilityMode():
             gm = make_fx(func)(a, b, c)

         expected = execute(gm, a, b, c, executor="aten")
@@ -535,7 +537,7 @@ class TestPrims(TestCase):
         self.assertTrue(getattr(torch.ops.nvprims, "digamma", None) is None)

         from torch.fx.experimental.proxy_tensor import make_fx
-        from torch._prims.context import TorchRefsMode
+        from torch._prims.context import TorchRefsNvfuserCapabilityMode
         from torch._prims.executor import execute

         a = torch.randn(3, 4, device=device)
@@ -543,7 +545,7 @@ class TestPrims(TestCase):
         def func(a):
             return torch.digamma(a)  # not supported by nvfuser

-        with TorchRefsMode():
+        with TorchRefsNvfuserCapabilityMode():
             gm = make_fx(func)(a)

         with catch_warnings(record=True) as w:
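The tests above exercise `a.mT` because transposing a 2D tensor swaps its strides without copying data, yielding a noncontiguous view. A quick illustrative check:

    import torch

    a = torch.randn(3, 3)
    a.stride()            # (3, 1): row-major, contiguous
    a.mT.stride()         # (1, 3): transposed view
    a.mT.is_contiguous()  # False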
torch/_prims/nvfuser_executor.py

@@ -40,8 +40,8 @@ DEFAULT_NVFUSER_PYTHON_CONFIG = MappingProxyType(
 # https://github.com/pytorch/pytorch/issues/80551
 @dataclass(frozen=True)
 class nvFuserTensorTemplate:
-    size: tuple
-    stride: tuple
+    symbolic_shape: tuple
+    contiguity: tuple
     dtype: DataType
     is_cpu: bool

@@ -51,12 +51,29 @@ class nvFuserScalarTemplate:
     dtype: DataType


+@lru_cache(maxsize=2048)
+def compute_symbolic_shape(shape):
+    """Computes the symbolic shape of a tensor.
+    nvFuser specializes on size-1 dimensions as broadcasted dimensions.
+    -1 is used to represent any size."""
+    return tuple(1 if s == 1 else -1 for s in shape)
+
+
+@lru_cache(maxsize=2048)
+def compute_contiguity(shape, strides):
+    """Computes the contiguity information to simplify internal indexing.
+    Contiguous dimensions are represented by True, strided dimensions
+    are represented by False.
+    """
+    return torch._C._nvfuser.compute_contiguity(shape, strides)
+
+
 def to_nvfuser_template_args(args):
     def to_nvfuser(arg):
         if isinstance(arg, torch.Tensor):
             return nvFuserTensorTemplate(
-                arg.size(),
-                arg.stride(),
+                compute_symbolic_shape(arg.size()),
+                compute_contiguity(arg.size(), arg.stride()),
                 getnvFuserDtype(arg.dtype),
                 arg.is_cpu,  # type: ignore[attr-defined]
             )
@@ -163,7 +180,9 @@ def make_nvfuser_fusion(gm: GraphModule, *nv_args_templates):

     def templates_to_nvfuser_inputs(arg):
         if isinstance(arg, nvFuserTensorTemplate):
-            x = fd.define_tensor(arg.size, arg.stride, arg.dtype, arg.is_cpu)
+            x = fd.define_tensor(
+                arg.symbolic_shape, arg.contiguity, arg.dtype, arg.is_cpu
+            )
             return x
         elif isinstance(arg, nvFuserScalarTemplate):
             x = fd.define_scalar(arg.dtype)
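On the Python side, `compute_contiguity` forwards to a new C++ binding, shown in the next hunk. For illustration only, a pure-Python sketch of the same algorithm (not part of the commit):

    def compute_contiguity_py(sizes, strides):
        # A dimension is contiguous when its stride equals the next
        # dimension's stride times the next dimension's size; the
        # innermost dimension must have stride 1.
        assert len(sizes) == len(strides)
        if not sizes:
            return ()
        contiguity = [False] * len(sizes)
        contiguity[-1] = strides[-1] == 1
        for i in range(len(sizes) - 2, -1, -1):
            contiguity[i] = strides[i] == strides[i + 1] * sizes[i + 1]
        return tuple(contiguity)

    compute_contiguity_py((3, 3), (3, 1))  # (True, True):  contiguous layout
    compute_contiguity_py((3, 3), (1, 3))  # (False, False): a.mT layout

Note that both new helpers are decorated with `lru_cache(maxsize=2048)`, so repeated calls with the same shape/stride tuples return memoized results instead of recomputing.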
torch/csrc/jit/codegen/cuda/python_frontend/python_bindings.cpp

@@ -40,6 +40,24 @@ void initNvFuserPythonBindings(PyObject* module) {
       .value("ComplexDouble", Nvf::DataType::ComplexDouble)
       .value("Null", Nvf::DataType::Null);

+  nvfuser.def(
+      "compute_contiguity",
+      [](const std::vector<int64_t>& sizes,
+         const std::vector<int64_t>& strides) {
+        py::tuple contiguity(sizes.size());
+        TORCH_CHECK(
+            sizes.size() == strides.size(),
+            "compute_contiguity: Sizes and strides must have the same number of dimensions");
+        if (sizes.size() == 0) {
+          return contiguity;
+        }
+        contiguity[sizes.size() - 1] = strides.back() == 1;
+        for (int64_t i = static_cast<int64_t>(sizes.size()) - 2; i >= 0; --i) {
+          contiguity[i] = strides[i] == strides[i + 1] * sizes[i + 1];
+        }
+        return contiguity;
+      });
+
   //! Binding the FusionCache that holds a cache of Fusions
   //! This is only bound to provide an interface to get the number of fusions
   //! that are cached.
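Assuming a build that includes this commit, the binding is reachable from Python via `torch._C._nvfuser`, which is exactly how the Python-side helper above uses it:

    import torch

    torch._C._nvfuser.compute_contiguity([3, 3], [3, 1])  # (True, True)
    torch._C._nvfuser.compute_contiguity([3, 3], [1, 3])  # (False, False)
    torch._C._nvfuser.compute_contiguity([], [])          # (): zero-dim input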