Add a way to skip lowering to nvprims (#85811)

This PR adds a `skip_ops` argument to `TorchRefsNvfuserCapabilityMode` and `NvfuserPrimsMode`: an iterable of function names to be skipped during the translation to nvprims.
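A minimal usage sketch (not part of the original description; it assumes a CUDA build with nvFuser available, and the function and tensor below are illustrative):

```python
import torch
from torch._prims.context import NvfuserPrimsMode, TorchRefsNvfuserCapabilityMode
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    return torch.sin(x)

x = torch.randn(5, 5, device="cuda")

# Without skip_ops, torch.sin is normally lowered to its nvprims counterpart.
with TorchRefsNvfuserCapabilityMode():
    gm = make_fx(f)(x)

# With skip_ops, torch.sin is left as an aten op in the traced graph.
with TorchRefsNvfuserCapabilityMode(skip_ops={"torch.sin"}):
    gm_skipped = make_fx(f)(x)

# NvfuserPrimsMode accepts the same keyword for the prims -> nvprims translation.
with NvfuserPrimsMode(skip_ops={"prims.sin.default"}):
    gm_prims = make_fx(lambda t: torch.ops.prims.sin.default(t))(x)
```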

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85811
Approved by: https://github.com/mruberry, https://github.com/jjsjann123
Author: Ivan Yashchuk
Date: 2022-09-30 12:01:45 +00:00
Committed by: PyTorch MergeBot
Parent: 787028cadb
Commit: b00a5359f7
3 changed files with 103 additions and 8 deletions


@@ -27,6 +27,7 @@ from torch.testing._internal.logging_tensor import LoggingTensor, capture_logs,
 import torch._prims as prims
 from torch._prims.executor import make_traced
 import torch._refs as refs
+from torch.fx.experimental.proxy_tensor import make_fx
 
 if TEST_SCIPY:
@@ -165,6 +166,55 @@ class TestPrims(TestCase):
         ), (f"The following prims do not have 'impl_nvfuser' defined: {ops_without_nvfuser_impl} ",
             "while there exists nvfuser implementations for them.")
 
+    def test_skip_ops_nvfuser_prims_mode(self, device):
+        # This test verifies that NvfuserPrimsMode skips the specified
+        # functions. Skipping a function means that it is not converted into
+        # its nvprims counterpart.
+        from torch._prims.context import NvfuserPrimsMode
+
+        a = make_tensor(5, 5, device=device, dtype=torch.float32)
+
+        def func(a):
+            return torch.ops.prims.sin.default(a)
+
+        skip_ops = {"prims.sin.default", }
+        with NvfuserPrimsMode(skip_ops=skip_ops):
+            gm = make_fx(func)(a)
+
+        includes_any_prims_sin = any(
+            node.target == torch.ops.prims.sin.default for node in gm.graph.nodes
+        )
+        self.assertTrue(includes_any_prims_sin)
+        include_any_nvprims_sin = any(
+            node.target == torch.ops.nvprims.sin.default for node in gm.graph.nodes
+        )
+        self.assertFalse(include_any_nvprims_sin)
+
+    def test_skip_ops_nvfuser_capability_mode(self, device):
+        # This test verifies that TorchRefsNvfuserCapabilityMode skips the
+        # specified functions. Skipping a function means that its specific
+        # reference/decomposition is not traced and there is no attempt to
+        # lower it to nvprims.
+        from torch._prims.context import TorchRefsNvfuserCapabilityMode
+
+        a = make_tensor(5, 5, device=device, dtype=torch.float32)
+
+        def func(a):
+            return torch.sin(a)
+
+        skip_ops = {"torch.sin", }
+        with TorchRefsNvfuserCapabilityMode(skip_ops=skip_ops):
+            gm = make_fx(func)(a)
+
+        includes_any_aten_sin = any(
+            node.target == torch.ops.aten.sin.default for node in gm.graph.nodes
+        )
+        self.assertTrue(includes_any_aten_sin)
+        include_any_nvprims_sin = any(
+            node.target == torch.ops.nvprims.sin.default for node in gm.graph.nodes
+        )
+        self.assertFalse(include_any_nvprims_sin)
+
     @onlyCUDA
     @skipCUDAIfRocm
     def test_nvfuser_empty_fusion(self, device):
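Aside (not part of the diff): the strings in `skip_ops` are matched against `torch.overrides.resolve_name`, so one way to discover the exact names to skip is to inspect an already traced graph. A hypothetical helper, sketched under that assumption:

```python
import torch

def print_skip_names(gm):
    # Hypothetical helper: print the resolve_name of every call target in an
    # FX graph; these strings are what skip_ops is matched against.
    for node in gm.graph.nodes:
        if node.op == "call_function":
            print(torch.overrides.resolve_name(node.target))
```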
@@ -748,10 +798,11 @@ class TestDecomp(TestCase):
         from torch._prims.context import TorchRefsNvfuserCapabilityMode, _is_func_unsupported_nvfuser
         from torch.fx.experimental.proxy_tensor import make_fx
 
-        op = torch._decomp.decomposition_table.get(torch.ops.aten.leaky_relu_backward.default)
+        op = torch.ops.aten.leaky_relu_backward.default
+        op_decomp = torch._decomp.decomposition_table.get(op)
 
         def fn0(*arg):
-            return _is_func_unsupported_nvfuser(TorchRefsNvfuserCapabilityMode(), op, arg, {})
+            return _is_func_unsupported_nvfuser(TorchRefsNvfuserCapabilityMode(), op, op_decomp, arg, {})
 
         def fn1(x):
             x = x * 2


@@ -87,8 +87,19 @@ class NvfuserPrimsMode(torch.overrides.TorchFunctionMode):
     By default, this context manager will fall back on the torch.ops.prims* if the
     nvprim does not exist.
 
+    It's possible to skip certain prims by passing their names to the skip_ops
+    argument. skip_ops is expected to be a sequence of strings, e.g.,
+    ["prims.add.default"]. To check the expected name of a prim, one can use
+    `torch.overrides.resolve_name`.
+
+    >>> # xdoctest: +SKIP("undefined vars")
+    >>> with NvfuserPrimsMode(skip_ops=("prims.add.default",)):
+    ...     torch.ops.prims.add.default(x, y)  # does not call torch.ops.nvprims.add.default(x, y)
     """
 
+    def __init__(self, *, skip_ops=()):
+        self.skip_ops = skip_ops
+
     def __torch_function__(
         self,
         orig_func: Callable,
@@ -98,6 +109,12 @@ class NvfuserPrimsMode(torch.overrides.TorchFunctionMode):
     ):
         if kwargs is None:
             kwargs = {}
+
+        # If the function is in the skip list, then we don't want to
+        # remap it to the nvprims.
+        if torch.overrides.resolve_name(orig_func) in self.skip_ops:
+            return orig_func(*args, **kwargs)
+
         if isinstance(orig_func, torch._ops.OpOverload) or isinstance(
             orig_func, torch._ops.OpOverloadPacket
         ):
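As the docstring above suggests, the string to pass in `skip_ops` for a given prim can be checked directly with `torch.overrides.resolve_name`; with this PR's change to `resolve_name`, an `OpOverload` resolves to its `str()` form:

```python
import torch

# Expected to print "prims.sin.default" (the name used by the new test above);
# the exact output is an assumption based on str() of the OpOverload.
print(torch.overrides.resolve_name(torch.ops.prims.sin.default))
```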
@@ -163,7 +180,7 @@ class TorchRefsMode(torch.overrides.TorchFunctionMode):
         if func is not None:
             # If the ref exists query whether we should use it or not
-            if self.should_fallback_fn(self, func, args, kwargs):
+            if self.should_fallback_fn(self, orig_func, func, args, kwargs):
                 return orig_func(*args, **kwargs)
             # torch calls inside func should be interpreted as refs calls
             with self:
@@ -182,7 +199,32 @@ def _is_node_supported_nvfuser(node):
     )
 
 
-def _is_func_unsupported_nvfuser(torch_function_mode, func, args, kwargs):
+def _is_func_unsupported_nvfuser(
+    torch_function_mode, orig_func, func, args, kwargs, *, skip_ops=()
+):
+    """
+    This function traces `func` under `torch_function_mode` and checks if any
+    of the traced nodes are not supported by nvFuser. If so, we should fall
+    back to the original function.
+
+    The `skip_ops` argument is expected to be a list of strings of function
+    names that would match with `torch.overrides.resolve_name`.
+
+    Args:
+        torch_function_mode: The torch_function_mode context manager.
+        orig_func: The original function; its name is used to check whether
+            it should be skipped.
+        func: The function to be traced.
+        args: The args to be passed to the function.
+        kwargs: The kwargs to be passed to the function.
+
+    Keyword args:
+        skip_ops: A list of ops to skip when checking if the function is
+            supported.
+    """
+    # One case is easy to check: if the resolved name of the original
+    # function is in the skip list, skip it.
+    if torch.overrides.resolve_name(orig_func) in skip_ops:
+        return True
+
     with torch_function_mode:
         try:
             gm = get_isolated_graphmodule(func, args, kwargs)
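A simplified sketch of the decision this function makes (illustrative only; it mirrors, but is not, the implementation in the diff, and omits the error handling around tracing):

```python
import torch
from torch._prims.context import _is_node_supported_nvfuser
from torch.fx.experimental.proxy_tensor import get_isolated_graphmodule

def is_func_unsupported_sketch(mode, orig_func, func, args, kwargs, *, skip_ops=()):
    # Skipped ops are reported as "unsupported" so the caller falls back to
    # the original function instead of tracing the reference/decomposition.
    if torch.overrides.resolve_name(orig_func) in skip_ops:
        return True
    # Otherwise trace func in isolation and check every call node.
    with mode:
        gm = get_isolated_graphmodule(func, args, kwargs)
    call_nodes = (n for n in gm.graph.nodes if n.op == "call_function")
    return not all(_is_node_supported_nvfuser(n) for n in call_nodes)
```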
@@ -205,11 +247,13 @@ def _is_func_unsupported_nvfuser(torch_function_mode, func, args, kwargs):
 
 
 class TorchRefsNvfuserCapabilityMode(TorchRefsMode):
-    def __init__(self):
+    def __init__(self, *, skip_ops=()):
         super().__init__(
             strict=False,
-            should_fallback_fn=_is_func_unsupported_nvfuser,
-            prims_mode_cls=NvfuserPrimsMode,
+            should_fallback_fn=functools.partial(
+                _is_func_unsupported_nvfuser, skip_ops=skip_ops
+            ),
+            prims_mode_cls=functools.partial(NvfuserPrimsMode, skip_ops=skip_ops),
         )
 
     def _is_var_mean(self, func):
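The `functools.partial` wiring above only pre-binds the `skip_ops` keyword, so the existing `should_fallback_fn` and `prims_mode_cls` hooks keep their call signatures. A generic illustration of that pattern (plain Python, unrelated to the PyTorch internals):

```python
import functools

def check(name, *, skip_ops=()):
    return name in skip_ops

# Pre-bind the keyword; callers still pass only the positional argument.
check_with_skips = functools.partial(check, skip_ops=("torch.sin",))
print(check_with_skips("torch.sin"))  # True
```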


@@ -1695,7 +1695,7 @@ def resolve_name(f):
         Name of the function; if eval'ed it should give back the input
         function.
     """
-    if isinstance(f, torch._ops.OpOverload):
+    if isinstance(f, torch._ops.OpOverload) or isinstance(f, torch._ops.OpOverloadPacket):
         return str(f)
     return _get_overridable_functions()[1].get(f)
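Before this change, an `OpOverloadPacket` fell through to the overridable-functions table lookup; now both `OpOverload` and `OpOverloadPacket` resolve to their `str()` form. A quick illustration (the printed strings are assumptions based on the ops' `str()`; "torch.sin" is the name the new capability-mode test relies on):

```python
import torch

print(torch.overrides.resolve_name(torch.ops.aten.sin.default))  # expected: "aten.sin.default" (OpOverload)
print(torch.overrides.resolve_name(torch.ops.aten.sin))          # expected: "aten.sin" (OpOverloadPacket)
print(torch.overrides.resolve_name(torch.sin))                   # expected: "torch.sin" (overridable torch function)
```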