diff --git a/test/test_dispatch.py b/test/test_dispatch.py index ec9fd20797e3..45480d8916f0 100644 --- a/test/test_dispatch.py +++ b/test/test_dispatch.py @@ -229,11 +229,11 @@ class TestDispatch(TestCase): # m.impl("test_def", [](const Tensor& x) { return x }) lambda m: m.impl_t_t("foo"), # m.impl("test_def", kCPU, [](const Tensor& x) { return x }) - lambda m: m.impl_t_t("foo", dispatch="cpu"), + lambda m: m.impl_t_t("foo", dispatch="CPU"), # m.impl("test_def", kAutograd, [](const Tensor& x) { return x }) - lambda m: m.impl_t_t("foo", dispatch="autograd"), + lambda m: m.impl_t_t("foo", dispatch="Autograd"), # m.impl("test_def", kAutogradCPU, [](const Tensor& x) { return x }) - lambda m: m.impl_t_t("foo", dispatch="autogradcpu") + lambda m: m.impl_t_t("foo", dispatch="AutogradCPU") ]).state self.assertExpectedInline(state, '''\ name: test::foo @@ -262,11 +262,11 @@ catchall: impl_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] # m.def("foo", [](const Tensor & x) { return x }) lambda m: m.def_name_t_t("foo"), # m.impl("foo", torch::kCPU, [](const Tensor & x) { return x }) - lambda m: m.impl_t_t("foo", "cpu"), + lambda m: m.impl_t_t("foo", "CPU"), # m.impl("foo", torch::kAutograd, [](const Tensor & x) { return x }) - lambda m: m.impl_t_t("foo", "autograd"), + lambda m: m.impl_t_t("foo", "Autograd"), # m.impl("foo", torch::kAutogradCPU, [](const Tensor & x) { return x }) - lambda m: m.impl_t_t("foo", "autogradcpu") + lambda m: m.impl_t_t("foo", "AutogradCPU") ]).state self.assertExpectedInline(state, '''\ name: test::foo @@ -296,11 +296,11 @@ alias analysis kind: FROM_SCHEMA # m.impl("foo", [](const Tensor& x) { return x }) lambda m: m.impl_t_t("foo"), # m.impl("foo", torch::kCPU, [](const Tensor& x) { return x }) - lambda m: m.impl_t_t("foo", "cpu"), + lambda m: m.impl_t_t("foo", "CPU"), # m.impl("foo", torch::kAutograd, [](const Tensor& x) { return x }) - lambda m: m.impl_t_t("foo", "autograd"), + lambda m: m.impl_t_t("foo", "Autograd"), # m.impl("foo", torch::kAutogradCPU, [](const Tensor& x) { return x }) - lambda m: m.impl_t_t("foo", "autogradcpu") + lambda m: m.impl_t_t("foo", "AutogradCPU") ]).state self.assertExpectedInline(state, '''\ name: test::foo @@ -316,13 +316,13 @@ catchall: impl_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] # m.def("foo", [](const Tensor & x) { return x }) lambda m: m.def_name_t_t("foo"), # m.impl("foo", torch::kCPU, [](const Tensor & x) { return x }) - lambda m: m.impl_t_t("foo", "cpu", debug="fn_cpu"), + lambda m: m.impl_t_t("foo", "CPU", debug="fn_cpu"), # m.impl("foo", torch::kCUDA, [](const Tensor & x) { return x }) - lambda m: m.impl_t_t("foo", "xla", debug="fn_xla"), + lambda m: m.impl_t_t("foo", "XLA", debug="fn_xla"), # m.impl("foo", torch::kAutograd, [](const Tensor & x) { return x }) - lambda m: m.impl_t_t("foo", "autograd", debug="fn_autograd"), + lambda m: m.impl_t_t("foo", "Autograd", debug="fn_autograd"), # m.impl("foo", torch::kAutogradCPU, [](const Tensor & x) { return x }) - lambda m: m.impl_t_t("foo", "autogradcpu", debug="fn_autogradcpu") + lambda m: m.impl_t_t("foo", "AutogradCPU", debug="fn_autogradcpu") ]) state, table = result.state, result.table self.assertExpectedInline(state, '''\ @@ -351,12 +351,12 @@ AutogradXLA: fn_autograd [autograd kernel] ''') def test_computed_table_with_cpu_catchall(self): - global_m = C._dispatch_library("IMPL", "_", "autogradcpu") + global_m = C._dispatch_library("IMPL", "_", "AutogradCPU") result = self.commute("foo", [ # m.def("foo", [](const Tensor & x) { return x }) lambda m: m.def_name_t_t("foo"), # m.impl("foo", torch::kCPU, [](const Tensor & x) { return x }) - lambda m: m.impl_t_t("foo", "cpu"), + lambda m: m.impl_t_t("foo", "CPU"), ]) state, table = result.state, result.table self.assertExpectedInline(state, '''\ @@ -382,12 +382,12 @@ AutogradXLA: default_def_name_t_t [catch all] ''') def test_computed_table_with_math(self): - global_m = C._dispatch_library("IMPL", "_", "autogradcpu") + global_m = C._dispatch_library("IMPL", "_", "AutogradCPU") result = self.commute("foo", [ # m.def("foo(Tensor x) -> Tensor") lambda m: m.def_("foo(Tensor x) -> Tensor"), # m.impl("foo", torch::kMath, [](const Tensor & x) { return x }) - lambda m: m.impl_t_t("foo", "math"), + lambda m: m.impl_t_t("foo", "Math"), ]) state, table = result.state, result.table self.assertExpectedInline(state, '''\ @@ -412,14 +412,14 @@ AutogradXLA: impl_t_t [math kernel] ''') def test_computed_table_with_cpu_math(self): - global_m = C._dispatch_library("IMPL", "_", "autogradcpu") + global_m = C._dispatch_library("IMPL", "_", "AutogradCPU") result = self.commute("foo", [ # m.def("foo(Tensor x) -> Tensor") lambda m: m.def_("foo(Tensor x) -> Tensor"), # m.impl("foo", torch::kCPU, [](const Tensor & x) { return x }) - lambda m: m.impl_t_t("foo", "cpu", debug="fn_cpu"), + lambda m: m.impl_t_t("foo", "CPU", debug="fn_cpu"), # m.impl("foo", torch::kMath, [](const Tensor & x) { return x }) - lambda m: m.impl_t_t("foo", "math", debug="fn_math"), + lambda m: m.impl_t_t("foo", "Math", debug="fn_math"), ]) state, table = result.state, result.table self.assertExpectedInline(state, '''\ @@ -445,12 +445,12 @@ AutogradXLA: fn_math [math kernel] ''') def test_computed_table_with_autograd(self): - global_m = C._dispatch_library("IMPL", "_", "autogradcpu") + global_m = C._dispatch_library("IMPL", "_", "AutogradCPU") result = self.commute("foo", [ # m.def("foo(Tensor x) -> Tensor") lambda m: m.def_("foo(Tensor x) -> Tensor"), # m.impl("foo", torch::kAutograd, [](const Tensor & x) { return x }) - lambda m: m.impl_t_t("foo", "autograd"), + lambda m: m.impl_t_t("foo", "Autograd"), ]) state, table = result.state, result.table self.assertExpectedInline(state, '''\ @@ -476,11 +476,11 @@ AutogradXLA: impl_t_t [autograd kernel] # m.def("foo", [](const Tensor & x) { return x }) lambda m: m.def_name_t_t("foo"), # m.impl("foo", torch::kCPU, [](const Tensor & x) { return x }) - lambda m: m.impl_t_t("foo", "cpu", debug="fn_cpu"), + lambda m: m.impl_t_t("foo", "CPU", debug="fn_cpu"), # m.impl("foo", torch::kAutograd, [](const Tensor & x) { return x }) - lambda m: m.impl_t_t("foo", "autograd", debug="fn_autograd"), + lambda m: m.impl_t_t("foo", "Autograd", debug="fn_autograd"), # m.impl("foo", torch::kMath, [](const Tensor & x) { return x }) - lambda m: m.impl_t_t("foo", "math", debug="fn_math"), + lambda m: m.impl_t_t("foo", "Math", debug="fn_math"), ]) state, table = result.state, result.table self.assertExpectedInline(state, '''\ @@ -512,9 +512,9 @@ AutogradXLA: fn_math [math kernel] # m.def("foo", [](const Tensor & x) { return x }) lambda m: m.def_name_t_t("foo"), # m.impl("foo", torch::kCPU, [](const Tensor & x) { return x }) - lambda m: m.impl_t_t("foo", "cpu", debug="fn_cpu"), + lambda m: m.impl_t_t("foo", "CPU", debug="fn_cpu"), # m.impl("foo", torch::kAutograd, [](const Tensor & x) { return x }) - lambda m: m.impl_t_t("foo", "autograd", debug="fn_autograd"), + lambda m: m.impl_t_t("foo", "Autograd", debug="fn_autograd"), ]) state, table = result.state, result.table self.assertExpectedInline(state, '''\ @@ -538,6 +538,39 @@ AutogradOther: fn_autograd [autograd kernel] AutogradCPU: fn_autograd [autograd kernel] AutogradCUDA: fn_autograd [autograd kernel] AutogradXLA: fn_autograd [autograd kernel] +''') + + def test_computed_table_with_ambiguous_autogradother(self): + result = self.commute("foo", [ + # m.def("foo", [](const Tensor & x) { return x }) + lambda m: m.def_name_t_t("foo"), + # m.impl("foo", torch::kMath, [](const Tensor & x) { return x }) + lambda m: m.impl_t_t("foo", "Math", debug="fn_math"), + # m.impl("foo", torch::kQuantizedCPU, [](const Tensor & x) { return x }) + lambda m: m.impl_t_t("foo", "QuantizedCPU", debug="fn_quantizedcpu"), + ]) + state, table = result.state, result.table + self.assertExpectedInline(state, '''\ +name: test::foo +schema: test::foo(Tensor _0) -> (Tensor _0) +debug: registered at /dev/null:0 +alias analysis kind: CONSERVATIVE +QuantizedCPU: fn_quantizedcpu :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] +Math[alias]: fn_math :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] +catchall: default_def_name_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] +''') + + # computed dispatch table is too big, so we only check on a few entries we're interested in. + extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check) + + self.assertExpectedInline(extracted_table, '''\ +CPU: fn_math [math kernel] +CUDA: fn_math [math kernel] +XLA: fn_math [math kernel] +AutogradOther: ambiguous_autogradother [ambiguous autogradother] +AutogradCPU: fn_math [math kernel] +AutogradCUDA: fn_math [math kernel] +AutogradXLA: fn_math [math kernel] ''') # Can't do this yet for BC reasons @@ -631,7 +664,7 @@ alias analysis kind: PURE_FUNCTION ) def test_multiple_fallback(self): - global_m = C._dispatch_library("IMPL", "_", "xla") + global_m = C._dispatch_library("IMPL", "_", "XLA") global_m.fallback_fallthrough(), try: global_m.fallback_fallthrough(), diff --git a/torch/csrc/utils/python_dispatch.cpp b/torch/csrc/utils/python_dispatch.cpp index 21bf8e69adc4..f0f63bf7a2f0 100644 --- a/torch/csrc/utils/python_dispatch.cpp +++ b/torch/csrc/utils/python_dispatch.cpp @@ -27,12 +27,13 @@ torch::Library::Kind parseKind(const std::string& k) { c10::optional parseDispatchKey(const std::string& k) { static std::unordered_map key_map = { - {"cpu", c10::DispatchKey::CPU}, - {"cuda", c10::DispatchKey::CUDA}, - {"xla", c10::DispatchKey::XLA}, - {"math", c10::DispatchKey::Math}, - {"autograd", c10::DispatchKey::Autograd}, - {"autogradcpu", c10::DispatchKey::AutogradCPU}, + {"CPU", c10::DispatchKey::CPU}, + {"CUDA", c10::DispatchKey::CUDA}, + {"XLA", c10::DispatchKey::XLA}, + {"QuantizedCPU", c10::DispatchKey::QuantizedCPU}, + {"Math", c10::DispatchKey::Math}, + {"Autograd", c10::DispatchKey::Autograd}, + {"AutogradCPU", c10::DispatchKey::AutogradCPU}, {"", c10::DispatchKey::Undefined}, }; auto it = key_map.find(k);