Python Dispatcher should respect FuncTorchBatchedDecomposition key (#98328)

Fixes https://github.com/pytorch/pytorch/issues/97425.

The Python Dispatcher's resolve_key function should be equivalent to the
C++ computeDispatchTableEntryWithDebug. A section handling the
FuncTorchBatchedDecomposition alias key was added to
computeDispatchTableEntryWithDebug but never to resolve_key, so under the
Python Dispatcher vmap ignored decompositions registered to that key and
fell through to the batching fallback.

This PR fixes that discrepancy.
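
For reference, a minimal sketch of the failure mode from the linked issue
(hedged: before this PR, the call below routed through the slow vmap
batching fallback instead of the registered decomposition):

    import torch
    from torch._dispatch.python import enable_python_dispatcher

    t = torch.ones(3, 3) * 5
    with enable_python_dispatcher():
        # Prior to this fix, resolve_key never considered
        # FuncTorchBatchedDecomposition, so vmap hit the batching fallback.
        out = torch.vmap(torch.square)(t)
    assert torch.equal(out, torch.square(t))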

Test Plan:
- new test: test_decomposition_under_python_dispatcher
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98328
Approved by: https://github.com/Chillee, https://github.com/kshitij12345, https://github.com/Neilblaze
Author: Richard Zou
Date: 2023-04-04 12:37:42 -07:00
Committed by: PyTorch MergeBot
Commit: f21a176c03 (parent 78e991e575)

4 changed files with 23 additions and 0 deletions

In the functorch vmap test suite:

@@ -45,6 +45,7 @@ from common_utils import (
     compute_quantities_for_vmap_test,
     is_valid_inplace_sample_input,
     decorate,
+    DisableVmapFallback,
 )
 import types
 from collections import namedtuple
@@ -1057,6 +1058,16 @@ class TestVmapAPI(TestCase):
         expected = torch.mv(y, torch.ones(2)).view(3, 1, 1) + x
         self.assertEqual(out, expected)
 
+    def test_decomposition_under_python_dispatcher(self):
+        # This test will raise an error if the vmap fallback gets invoked.
+        # Here we test that decomps registered to FuncTorchBatchedDecomposition
+        # are respected by the Python Dispatcher.
+        t = torch.ones(3, 3) * 5
+        with DisableVmapFallback():
+            with torch._dispatch.python.enable_python_dispatcher():
+                o = torch.vmap(torch.square)(t)
+        self.assertEqual(o, torch.square(t))
+
     def _test_vmap_autocast(self, device):
         if torch.device(device).type == "cpu":

In the Python Dispatcher's resolve_key:

@@ -176,6 +176,10 @@ def resolve_key(op: OperatorBase, k: DispatchKey):  # type: ignore[valid-type]
         return cand
     # 2.4. For autograd backend keys, use kernel from DispatchKey::Autograd if available
     cand = DispatchKey.Autograd
     if is_included_in_alias(k, cand) and op.has_kernel_for_dispatch_key(cand):
         return cand
+    # 2.5 Use kernel from DispatchKey::FuncTorchBatchedDecomposition if available
+    cand = DispatchKey.FuncTorchBatchedDecomposition
+    if is_included_in_alias(k, cand) and op.has_kernel_for_dispatch_key(cand):
+        return cand
     # Backend fallback

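For context, here is a self-contained sketch of the resolution order that
the new step 2.5 slots into; Key, ALIASES, and resolve_key below are toy
stand-ins, not the real torch internals:

    from enum import Enum, auto

    class Key(Enum):
        FuncTorchBatched = auto()                # runtime key
        AutogradCPU = auto()                     # runtime key
        Autograd = auto()                        # alias key
        FuncTorchBatchedDecomposition = auto()   # alias key

    # Alias key -> runtime keys it covers (toy table; the real one is in C++).
    ALIASES = {
        Key.Autograd: {Key.AutogradCPU},
        Key.FuncTorchBatchedDecomposition: {Key.FuncTorchBatched},
    }

    def is_included_in_alias(k, alias):
        return k in ALIASES.get(alias, set())

    def resolve_key(kernels, k):
        if k in kernels:  # direct kernel for the runtime key
            return k
        # Alias keys, in the same order as the real resolve_key (2.4, then 2.5).
        for cand in (Key.Autograd, Key.FuncTorchBatchedDecomposition):
            if is_included_in_alias(k, cand) and cand in kernels:
                return cand
        raise KeyError("fall through to the backend fallback")

    # A decomposition registered to the alias key is now found when
    # dispatching on the FuncTorchBatched runtime key:
    assert resolve_key(
        {Key.FuncTorchBatchedDecomposition}, Key.FuncTorchBatched
    ) is Key.FuncTorchBatchedDecomposition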
In the C++ dispatch bindings (initDispatchBindings):

@@ -511,6 +511,10 @@ void initDispatchBindings(PyObject* module) {
   DEF_ONE(Python)
   DEF_ONE(FuncTorchDynamicLayerFrontMode)
   DEF_ONE(FuncTorchDynamicLayerBackMode)
+  DEF_ONE(FuncTorchBatchedDecomposition)
+  DEF_ONE(FuncTorchBatched)
+  DEF_ONE(FuncTorchVmapMode)
+  DEF_ONE(FuncTorchGradWrapper)
   DEF_ONE(PythonDispatcher)
   DEF_ONE(Functionalize)
   DEF_ONE(AutocastCPU)

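With the DEF_ONE entries above, the functorch keys become addressable from
Python. A quick check, assuming the enum is exposed as torch._C.DispatchKey
(as the surrounding bindings suggest):

    import torch

    # These attributes exist once the bindings above are registered.
    print(torch._C.DispatchKey.FuncTorchBatchedDecomposition)
    print(torch._C.DispatchKey.FuncTorchBatched)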
In the Python-side DispatchKey enum:

@@ -97,6 +97,9 @@ class DispatchKey(Enum):
     Autocast = auto()
     Batched = auto()
     VmapMode = auto()
+    FuncTorchGradWrapper = auto()
+    FuncTorchBatched = auto()
+    FuncTorchVmapMode = auto()
     FuncTorchDynamicLayerFrontMode = auto()
     Functionalize = auto()
     TESTING_ONLY_GenericWrapper = auto()
@@ -108,6 +111,7 @@ class DispatchKey(Enum):
     CompositeImplicitAutogradNestedTensor = auto()
     CompositeExplicitAutograd = auto()
     CompositeExplicitAutogradNonFunctional = auto()
+    FuncTorchBatchedDecomposition = auto()
     # BEGIN autogenerated
     CPU = auto()
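
This Python enum mirrors the C++ DispatchKey list so code generation can
refer to the new key by name. A small round-trip sketch, assuming this is
torchgen's DispatchKey enum with its parse helper:

    from torchgen.model import DispatchKey

    # Parse the new key from its string name and round-trip it.
    k = DispatchKey.parse("FuncTorchBatchedDecomposition")
    assert k is DispatchKey.FuncTorchBatchedDecomposition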