mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Python Dispatcher should respect FuncTorchBatchedDecomposition key (#98328)
Fixes https://github.com/pytorch/pytorch/issues/97425. Python Dispatcher's resolve_key function should be equivalent to computeDispatchTableEntryWithDebug. We added a section to computeDispatchTableEntryWithDebug but forgot to add it to resolve_key. This PR fixes that discrepancy. Test Plan: - new test Pull Request resolved: https://github.com/pytorch/pytorch/pull/98328 Approved by: https://github.com/Chillee, https://github.com/kshitij12345, https://github.com/Neilblaze
This commit is contained in:
committed by
PyTorch MergeBot
parent
78e991e575
commit
f21a176c03
@ -45,6 +45,7 @@ from common_utils import (
|
||||
compute_quantities_for_vmap_test,
|
||||
is_valid_inplace_sample_input,
|
||||
decorate,
|
||||
DisableVmapFallback,
|
||||
)
|
||||
import types
|
||||
from collections import namedtuple
|
||||
@ -1057,6 +1058,16 @@ class TestVmapAPI(TestCase):
|
||||
expected = torch.mv(y, torch.ones(2)).view(3, 1, 1) + x
|
||||
self.assertEqual(out, expected)
|
||||
|
||||
def test_decomposition_under_python_dispatcher(self):
|
||||
# This test will raise an error if the vmap fallback gets invoked.
|
||||
# Here we test that decomps registered to FuncTorchBatchedDecomposition
|
||||
# are respected by the Python Dispatcher.
|
||||
t = torch.ones(3, 3) * 5
|
||||
with DisableVmapFallback():
|
||||
with torch._dispatch.python.enable_python_dispatcher():
|
||||
o = torch.vmap(torch.square)(t)
|
||||
self.assertEqual(o, torch.square(t))
|
||||
|
||||
def _test_vmap_autocast(self, device):
|
||||
|
||||
if torch.device(device).type == "cpu":
|
||||
|
@ -176,6 +176,10 @@ def resolve_key(op: OperatorBase, k: DispatchKey): # type: ignore[valid-type]
|
||||
return cand
|
||||
# 2.4. For autograd backend keys, use kernel from DispatchKey::Autograd if available
|
||||
cand = DispatchKey.Autograd
|
||||
if is_included_in_alias(k, cand) and op.has_kernel_for_dispatch_key(cand):
|
||||
return cand
|
||||
# 2.5 Use kernel from DispatchKey::FuncTorchBatchedDecomposition if available
|
||||
cand = DispatchKey.FuncTorchBatchedDecomposition
|
||||
if is_included_in_alias(k, cand) and op.has_kernel_for_dispatch_key(cand):
|
||||
return cand
|
||||
# Backend fallback
|
||||
|
@ -511,6 +511,10 @@ void initDispatchBindings(PyObject* module) {
|
||||
DEF_ONE(Python)
|
||||
DEF_ONE(FuncTorchDynamicLayerFrontMode)
|
||||
DEF_ONE(FuncTorchDynamicLayerBackMode)
|
||||
DEF_ONE(FuncTorchBatchedDecomposition)
|
||||
DEF_ONE(FuncTorchBatched)
|
||||
DEF_ONE(FuncTorchVmapMode)
|
||||
DEF_ONE(FuncTorchGradWrapper)
|
||||
DEF_ONE(PythonDispatcher)
|
||||
DEF_ONE(Functionalize)
|
||||
DEF_ONE(AutocastCPU)
|
||||
|
@ -97,6 +97,9 @@ class DispatchKey(Enum):
|
||||
Autocast = auto()
|
||||
Batched = auto()
|
||||
VmapMode = auto()
|
||||
FuncTorchGradWrapper = auto()
|
||||
FuncTorchBatched = auto()
|
||||
FuncTorchVmapMode = auto()
|
||||
FuncTorchDynamicLayerFrontMode = auto()
|
||||
Functionalize = auto()
|
||||
TESTING_ONLY_GenericWrapper = auto()
|
||||
@ -108,6 +111,7 @@ class DispatchKey(Enum):
|
||||
CompositeImplicitAutogradNestedTensor = auto()
|
||||
CompositeExplicitAutograd = auto()
|
||||
CompositeExplicitAutogradNonFunctional = auto()
|
||||
FuncTorchBatchedDecomposition = auto()
|
||||
|
||||
# BEGIN autogenerated
|
||||
CPU = auto()
|
||||
|
Reference in New Issue
Block a user