catches failure on nvprim speculative lowering (#85580)

Fixes #85517

Wrapped the `get_isolated_graphmodule` tracing call inside `_is_func_unsupported_nvfuser` in a try/except. When that query raises, a warning is emitted and the function is reported as unsupported, which stops the speculative lowering to nvprims.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85580
Approved by: https://github.com/mruberry, https://github.com/IvanYashchuk
This commit is contained in:
jjsjann123
2022-09-29 15:22:45 +00:00
committed by PyTorch MergeBot
parent a807f1987a
commit cab6ffa0f7
2 changed files with 39 additions and 21 deletions

View File

@ -33,6 +33,7 @@ if TEST_SCIPY:
import scipy.special
# Substring expected in a warning message; presumably emitted when an nvprim
# falls back to the aten executor (the emitting code is not visible here).
NVPRIM_ATEN_FALLBACK_WARNING = "fallback to aten executor"
# Substring of the warning raised by `_is_func_unsupported_nvfuser` when
# tracing with `get_isolated_graphmodule` throws; tests match against it.
GET_ISOLATED_GRAPHMODULE_ERROR = "get_isolated_graphmodule failed on decomposition"
class TestPrims(TestCase):
@onlyCUDA
@ -778,6 +779,32 @@ class TestDecomp(TestCase):
)
self.assertFalse(includes_aten_to_copy)
@onlyCUDA
@skipCUDAIfRocm
@dtypes(torch.float16, torch.float32)
def test_masked_fill_decomposition_under_nvprim_context(self, device, dtype):
    """Check that a tracing failure during nvprim speculative lowering warns.

    Re-traces a `masked_fill` graph under `TorchRefsNvfuserCapabilityMode`
    and asserts that a warning containing `GET_ISOLATED_GRAPHMODULE_ERROR`
    was recorded instead of an exception escaping from `make_fx`.
    """
    # masked_fill decomposition extracts cpu scalar tensor value when
    # filling out a cuda tensor. This triggers data-dependent control flow
    # on TorchRefsNvfuser speculative lowering.
    from torch.fx.experimental.proxy_tensor import make_fx
    from torch._prims.context import TorchRefsNvfuserCapabilityMode

    x = torch.empty(2, 3, device=device).to(dtype=dtype)
    mask = torch.ones_like(x).bool()
    y = torch.tensor(0.3)  # cpu scalar tensor

    def func(x, mask, y):
        return torch.masked_fill(x, mask, y)

    # mimics real use-case for TorchRefsNvfuserCapabilityMode context
    gm = make_fx(func, decomposition_table={})(x, mask, y)

    with warnings.catch_warnings(record=True) as caught:
        # Force every warning to be recorded; otherwise an ambient
        # "ignore"/"once" filter could hide the one this test looks for.
        warnings.simplefilter("always")
        with TorchRefsNvfuserCapabilityMode():
            gm = make_fx(gm)(x, mask, y)

    # masked_fill decomposition fails inside `get_isolated_graphmodule`
    self.assertTrue(
        any(GET_ISOLATED_GRAPHMODULE_ERROR in str(w.message) for w in caught),
        msg="expected a warning containing: " + GET_ISOLATED_GRAPHMODULE_ERROR,
    )
@ops([op for op in op_db if op.supports_varargs], dtypes=OpDTypes.any_one)
def test_decomposition_method_vararg(self, device, dtype, op):
# some ops have vararg variants for the methods. this tests it.

View File

@ -1,6 +1,6 @@
import functools
from contextlib import nullcontext
from typing import Any, Callable, Dict, Sequence, Union
from typing import Any, Callable, Dict, Sequence
from warnings import warn
import torch
@ -68,25 +68,6 @@ def torch_to_refs_map():
return r
@functools.lru_cache(None)
def nvfuser_decomp_table():
    """
    Return the decomposition table used for nvfuser.

    The table is whatever `torch._decomp.get_decompositions` produces for
    the ops listed below; the result is memoized by `functools.lru_cache`.
    """
    ops_to_decompose: Sequence[
        Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket]
    ] = {  # type: ignore[assignment]
        # AMP calls `to` in C++, which is not handled by torch mapping
        torch.ops.aten._to_copy,
    }

    # Imported inside the function, as in the original implementation.
    from torch._decomp import get_decompositions

    return get_decompositions(ops_to_decompose)
@functools.lru_cache(None)
def all_prims():
"""
@ -203,7 +184,17 @@ def _is_node_supported_nvfuser(node):
def _is_func_unsupported_nvfuser(torch_function_mode, func, args, kwargs):
with torch_function_mode:
gm = get_isolated_graphmodule(func, args, kwargs)
try:
gm = get_isolated_graphmodule(func, args, kwargs)
except Exception as e:
warn(
"get_isolated_graphmodule failed on decomposition: "
+ func.__name__
+ " with error message: "
+ str(e)
)
# returns unsupported when tracing fails.
return True
supported_ops = NvfuserPrimOperatorSupport()
call_function_nodes = filter(lambda n: n.op == "call_function", gm.graph.nodes)