catches failure on nvprim speculative lowering (#85580)

Fixes #85517

Added a try/catch exception during tracing `get_isolated_graphmodule` inside `_is_func_unsupported_nvfuser`. Stops speculative lowering to nvprim when query errors out.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85580
Approved by: https://github.com/mruberry, https://github.com/IvanYashchuk
This commit is contained in:
jjsjann123
2022-09-29 15:22:45 +00:00
committed by PyTorch MergeBot
parent a807f1987a
commit cab6ffa0f7
2 changed files with 39 additions and 21 deletions

View File

@ -33,6 +33,7 @@ if TEST_SCIPY:
import scipy.special
NVPRIM_ATEN_FALLBACK_WARNING = "fallback to aten executor"
GET_ISOLATED_GRAPHMODULE_ERROR = "get_isolated_graphmodule failed on decomposition"
class TestPrims(TestCase):
@onlyCUDA
@ -778,6 +779,32 @@ class TestDecomp(TestCase):
)
self.assertFalse(includes_aten_to_copy)
@onlyCUDA
@skipCUDAIfRocm
@dtypes(torch.float16, torch.float32)
def test_masked_fill_decomposition_under_nvprim_context(self, device, dtype):
# masked_fill decomposition extracts cpu scalar tensor value when
# filling out a cuda tensor. This triggers data-dependent control flow
# on TorchRefsNvfuser speculative lowering.
from torch.fx.experimental.proxy_tensor import make_fx
from torch._prims.context import TorchRefsNvfuserCapabilityMode
x = torch.empty(2, 3, device=device).to(dtype=dtype)
mask = torch.ones_like(x).bool()
y = torch.tensor(0.3) # cpu scalar tensor
def func(x, mask, y):
return torch.masked_fill(x, mask, y)
# mimics real use-case for TorchRefsNvfuserCapabilityMode context
gm = make_fx(func, decomposition_table={})(x, mask, y)
with warnings.catch_warnings(record=True) as caught:
with TorchRefsNvfuserCapabilityMode():
gm = make_fx(gm)(x, mask, y)
# masked_fill decomposition fails inside `get_isolated_graphmodule`
self.assertTrue(any(GET_ISOLATED_GRAPHMODULE_ERROR in str(w.message) for w in caught))
@ops([op for op in op_db if op.supports_varargs], dtypes=OpDTypes.any_one)
def test_decomposition_method_vararg(self, device, dtype, op):
# some ops have vararg variants for the methods. this tests it.