mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-11-04 08:00:58 +08:00 
			
		
		
		
	[WIP] Move catchAll to Math (#45939)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/45939 Test Plan: Imported from OSS Reviewed By: bhosmer Differential Revision: D24165890 Pulled By: ailzhang fbshipit-source-id: 72fe71ea95a738251b2fafc9eea4ab3831cf426b
This commit is contained in:
		
				
					committed by
					
						
						Facebook GitHub Bot
					
				
			
			
				
	
			
			
			
						parent
						
							d1ca7ef33e
						
					
				
				
					commit
					8c629ecc9a
				
			@ -105,7 +105,8 @@ std::list<AnnotatedKernel>::iterator OperatorEntry::registerKernel(
 | 
			
		||||
 | 
			
		||||
  // Add the kernel to the kernels list,
 | 
			
		||||
  // possibly creating the list if this is the first kernel.
 | 
			
		||||
  auto& k = dispatch_key.has_value() ? kernels_[*dispatch_key] : catchAllKernel_;
 | 
			
		||||
  // Redirect catchAll registrations to Math.
 | 
			
		||||
  auto& k = dispatch_key.has_value() ? kernels_[*dispatch_key] : kernels_[DispatchKey::Math];
 | 
			
		||||
 | 
			
		||||
  if (k.size() > 0) {
 | 
			
		||||
    TORCH_WARN("Registering a kernel (", debug, ") for operator ", name_, " for dispatch key ", toString(dispatch_key), " that overwrote a previously registered kernel with the same dispatch key for the same operator.");
 | 
			
		||||
@ -132,8 +133,9 @@ void OperatorEntry::deregisterKernel_(
 | 
			
		||||
  c10::optional<DispatchKey> dispatch_key,
 | 
			
		||||
  std::list<AnnotatedKernel>::iterator kernel
 | 
			
		||||
) {
 | 
			
		||||
  if (dispatch_key.has_value()) {
 | 
			
		||||
    auto found = kernels_.find(*dispatch_key);
 | 
			
		||||
  // Redirect catchAll deregistrations to Math.
 | 
			
		||||
  DispatchKey dk = dispatch_key.has_value() ? *dispatch_key : DispatchKey::Math;
 | 
			
		||||
  auto found = kernels_.find(dk);
 | 
			
		||||
  TORCH_INTERNAL_ASSERT(found != kernels_.end(), "Tried to deregister a kernel for dispatch key ", toString(dispatch_key), " but there are no kernels registered for this dispatch key. The operator is ", toString(name_));
 | 
			
		||||
  auto& k = found->second;
 | 
			
		||||
  k.erase(kernel);
 | 
			
		||||
@ -141,11 +143,7 @@ void OperatorEntry::deregisterKernel_(
 | 
			
		||||
    // the invariant says we don't want empty lists but instead remove the list from the map
 | 
			
		||||
    kernels_.erase(found);
 | 
			
		||||
  }
 | 
			
		||||
    updateDispatchTable_(dispatcher, *dispatch_key);
 | 
			
		||||
  } else {
 | 
			
		||||
    catchAllKernel_.erase(kernel);
 | 
			
		||||
    updateDispatchTableFull_(dispatcher);
 | 
			
		||||
  }
 | 
			
		||||
  updateDispatchTable_(dispatcher, dk);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void OperatorEntry::updateFallback(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key) {
 | 
			
		||||
@ -259,6 +257,8 @@ std::pair<const AnnotatedKernel&, const char*> OperatorEntry::computeDispatchTab
 | 
			
		||||
    //      fit 2.1 and we can remove 2.4 entirely.
 | 
			
		||||
    if (!has_backend_kernel && !catchAllKernel_.empty()) {
 | 
			
		||||
      TORCH_INTERNAL_ASSERT(catchAllKernel_.front().kernel.isValid());
 | 
			
		||||
      // Prepare for catchAll removal, make sure it's not used in dispatchTable
 | 
			
		||||
      TORCH_INTERNAL_ASSERT(false);
 | 
			
		||||
      return {catchAllKernel_.front(), "catch all"};
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
@ -272,6 +272,8 @@ std::pair<const AnnotatedKernel&, const char*> OperatorEntry::computeDispatchTab
 | 
			
		||||
  // 4. Catch all
 | 
			
		||||
  if (!catchAllKernel_.empty()) {
 | 
			
		||||
    TORCH_INTERNAL_ASSERT(catchAllKernel_.front().kernel.isValid());
 | 
			
		||||
    // Prepare for catchAll removal, make sure it's not used in dispatchTable
 | 
			
		||||
    TORCH_INTERNAL_ASSERT(false);
 | 
			
		||||
    return {catchAllKernel_.front(), "catch all"};
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -777,22 +777,6 @@ TEST(OperatorRegistrationTest, whenRegisteringBackendFallbackKernelAndRegularKer
 | 
			
		||||
  EXPECT_TRUE(called);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST(OperatorRegistrationTest, whenRegisteringBackendFallbackKernelAndCatchallKernelForSameBackend_thenCallsFallbackKernel) {
 | 
			
		||||
  auto registrar = c10::Dispatcher::singleton().registerFallback(c10::DispatchKey::CPU, c10::KernelFunction::makeFromBoxedFunction<&backend_fallback_kernel>(), "");
 | 
			
		||||
 | 
			
		||||
  auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy, str input) -> ()", c10::RegisterOperators::options()
 | 
			
		||||
      .catchAllKernel([] (Tensor, std::string) {
 | 
			
		||||
        called = true;
 | 
			
		||||
      }));
 | 
			
		||||
  auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
 | 
			
		||||
  ASSERT_TRUE(op.has_value());
 | 
			
		||||
 | 
			
		||||
  called = false;
 | 
			
		||||
  auto stack = callOp(*op, dummyTensor(c10::DispatchKey::CPU), "hello ");
 | 
			
		||||
  EXPECT_FALSE(called);
 | 
			
		||||
  EXPECT_EQ("hello _test::dummy", stack[1].toString()->string());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
bool called_autograd = false;
 | 
			
		||||
bool called_nonautograd = false;
 | 
			
		||||
 | 
			
		||||
@ -835,20 +819,6 @@ TEST(OperatorRegistrationTest, whenRegisteringAutogradKernelWithRegularKernel_th
 | 
			
		||||
  EXPECT_TRUE(called_autograd);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST(OperatorRegistrationTest, whenRegisteringAutogradKernelWithRegularKernel_thenCanCallRegularKernel) {
 | 
			
		||||
  auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options()
 | 
			
		||||
    .kernel<decltype(nonautograd_kernel), nonautograd_kernel>(DispatchKey::CPU)
 | 
			
		||||
    .kernel<decltype(autograd_kernel), &autograd_kernel>(DispatchKey::Autograd));
 | 
			
		||||
 | 
			
		||||
  auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
 | 
			
		||||
  ASSERT_TRUE(op.has_value());
 | 
			
		||||
 | 
			
		||||
  called_nonautograd = called_autograd = false;
 | 
			
		||||
  op->typed<void (Tensor)>().call(dummyTensor(DispatchKey::CPU));
 | 
			
		||||
  EXPECT_TRUE(called_nonautograd);
 | 
			
		||||
  EXPECT_FALSE(called_autograd);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST(OperatorRegistrationTest, whenRegisteringAutogradKernelWithCatchAllKernel_thenCanCallAutogradKernel) {
 | 
			
		||||
  auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options()
 | 
			
		||||
    .catchAllKernel<decltype(nonautograd_kernel), nonautograd_kernel>()
 | 
			
		||||
@ -857,10 +827,11 @@ TEST(OperatorRegistrationTest, whenRegisteringAutogradKernelWithCatchAllKernel_t
 | 
			
		||||
  auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
 | 
			
		||||
  ASSERT_TRUE(op.has_value());
 | 
			
		||||
 | 
			
		||||
  // catchAll now maps to Math which has higher precedence than Autograd
 | 
			
		||||
  called_nonautograd = called_autograd = false;
 | 
			
		||||
  op->typed<void (Tensor)>().call(dummyTensor(DispatchKey::CPU, /*requires_grad=*/true));
 | 
			
		||||
  EXPECT_FALSE(called_nonautograd);
 | 
			
		||||
  EXPECT_TRUE(called_autograd);
 | 
			
		||||
  EXPECT_TRUE(called_nonautograd);
 | 
			
		||||
  EXPECT_FALSE(called_autograd);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST(OperatorRegistrationTest, whenRegisteringAutogradKernelWithCatchAllKernel_thenCanCallCatchallKernel) {
 | 
			
		||||
@ -1627,6 +1598,39 @@ TEST(NewOperatorRegistrationTest, schema) {
 | 
			
		||||
  ASSERT_TRUE(Dispatcher::singleton().findSchema({"test::def4", ""})->schema().isDefaultAliasAnalysisKind());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST(NewOperatorRegistrationTest, whenRegisteringBackendFallbackKernelAndCatchallKernelForSameBackend_thenCallsFallbackKernel) {
 | 
			
		||||
  auto m1 = MAKE_TORCH_LIBRARY_IMPL(_, CPU);
 | 
			
		||||
  m1.fallback(CppFunction::makeFromBoxedFunction<&backend_fallback_kernel>());
 | 
			
		||||
 | 
			
		||||
  bool called = false;
 | 
			
		||||
  auto m = MAKE_TORCH_LIBRARY(test);
 | 
			
		||||
  m.def("fn(Tensor t, str input) -> ()");
 | 
			
		||||
  m.impl("fn", [&] (Tensor, std::string) { called = true; });
 | 
			
		||||
 | 
			
		||||
  auto op = Dispatcher::singleton().findSchema({"test::fn", ""});
 | 
			
		||||
  ASSERT_TRUE(op.has_value());
 | 
			
		||||
 | 
			
		||||
  called = false;
 | 
			
		||||
  auto stack = callOp(*op, dummyTensor(c10::DispatchKey::CPU), "hello ");
 | 
			
		||||
  // CatchAll now maps to Math and has higher precedence than backend fallback.
 | 
			
		||||
  EXPECT_TRUE(called);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST(NewOperatorRegistrationTest, whenRegisteringAutogradKernelWithRegularKernel_thenCanCallRegularKernel) {
 | 
			
		||||
  auto m = MAKE_TORCH_LIBRARY(test);
 | 
			
		||||
  m.def("fn(Tensor dummy) -> ()");
 | 
			
		||||
  m.impl("fn", c10::DispatchKey::CPU, nonautograd_kernel);
 | 
			
		||||
  m.impl("fn", c10::DispatchKey::Autograd, autograd_kernel);
 | 
			
		||||
 | 
			
		||||
  auto op = Dispatcher::singleton().findSchema({"test::fn", ""});
 | 
			
		||||
  ASSERT_TRUE(op.has_value());
 | 
			
		||||
 | 
			
		||||
  called_nonautograd = called_autograd = false;
 | 
			
		||||
  callOp(*op, dummyTensor(DispatchKey::CPU));
 | 
			
		||||
  EXPECT_TRUE(called_nonautograd);
 | 
			
		||||
  EXPECT_FALSE(called_autograd);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST(NewOperatorRegistrationTest, dispatchWithMathKernel) {
 | 
			
		||||
  bool math_called = false;
 | 
			
		||||
  auto m = MAKE_TORCH_LIBRARY(test);
 | 
			
		||||
@ -1708,18 +1712,20 @@ TEST(NewOperatorRegistrationTest, dispatchWithMathAndCatchAllKernel) {
 | 
			
		||||
  auto op = Dispatcher::singleton().findSchema({"test::fn", ""});
 | 
			
		||||
  ASSERT_TRUE(op.has_value());
 | 
			
		||||
 | 
			
		||||
  // catchAll now maps to Math, which means we have two registrations to Math key.
 | 
			
		||||
  // The last registration is used.
 | 
			
		||||
  {
 | 
			
		||||
    catchall_called = math_called = false;
 | 
			
		||||
    callOp(*op, dummyTensor(c10::DispatchKey::CPU));
 | 
			
		||||
    ASSERT_TRUE(math_called);
 | 
			
		||||
    ASSERT_FALSE(catchall_called);
 | 
			
		||||
    ASSERT_FALSE(math_called);
 | 
			
		||||
    ASSERT_TRUE(catchall_called);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  {
 | 
			
		||||
    catchall_called = math_called = false;
 | 
			
		||||
    callOp(*op, dummyTensor(c10::DispatchKey::CPU, /*requires_grad=*/true));
 | 
			
		||||
    ASSERT_TRUE(math_called);
 | 
			
		||||
    ASSERT_FALSE(catchall_called);
 | 
			
		||||
    ASSERT_FALSE(math_called);
 | 
			
		||||
    ASSERT_TRUE(catchall_called);
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -2055,6 +2061,10 @@ TEST(NewOperatorRegistrationTest, throwsWhenRegisterToBackendMapsToAutogradOther
 | 
			
		||||
TEST(NewOperatorRegistrationTest, dispatchMultipleTensors) {
 | 
			
		||||
  bool privateuse1_called = false;
 | 
			
		||||
  bool catchall_called = false;
 | 
			
		||||
  // Similar to in-tree AutogradCPU/AutogradCUDA etc, out-of-tree backends usually register
 | 
			
		||||
  // a fallthrough kernel for AutogradPrivateUse1.
 | 
			
		||||
  auto m1 = MAKE_TORCH_LIBRARY_IMPL(_, AutogradPrivateUse1);
 | 
			
		||||
  m1.fallback(CppFunction::makeFallthrough());
 | 
			
		||||
 | 
			
		||||
  auto m = MAKE_TORCH_LIBRARY(test);
 | 
			
		||||
  m.def("fn", torch::dispatch(c10::DispatchKey::PrivateUse1, [&](const Tensor& x, const Tensor& y) { privateuse1_called = true; return x; }));
 | 
			
		||||
 | 
			
		||||
@ -67,9 +67,4 @@ TORCH_LIBRARY(aten, m) {
 | 
			
		||||
TORCH_LIBRARY_IMPL(aten, DefaultBackend, m) {
 | 
			
		||||
  ${default_backend_function_registrations};
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TORCH_LIBRARY_IMPL(aten, Math, m) {
 | 
			
		||||
  ${math_function_registrations};
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  // namespace at
 | 
			
		||||
 | 
			
		||||
@ -244,7 +244,7 @@ alias analysis kind: FROM_SCHEMA
 | 
			
		||||
CPU: impl_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
 | 
			
		||||
AutogradCPU: impl_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
 | 
			
		||||
Autograd[alias]: impl_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
 | 
			
		||||
catchall: impl_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
 | 
			
		||||
Math[alias]: impl_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
 | 
			
		||||
''')
 | 
			
		||||
 | 
			
		||||
    def test_def_impl_schema_mismatch(self):
 | 
			
		||||
@ -277,7 +277,7 @@ alias analysis kind: CONSERVATIVE
 | 
			
		||||
CPU: impl_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
 | 
			
		||||
AutogradCPU: impl_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
 | 
			
		||||
Autograd[alias]: impl_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
 | 
			
		||||
catchall: default_def_name_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
 | 
			
		||||
Math[alias]: default_def_name_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
 | 
			
		||||
''')
 | 
			
		||||
 | 
			
		||||
    def test_def_only(self):
 | 
			
		||||
@ -309,7 +309,7 @@ schema: (none)
 | 
			
		||||
CPU: impl_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
 | 
			
		||||
AutogradCPU: impl_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
 | 
			
		||||
Autograd[alias]: impl_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
 | 
			
		||||
catchall: impl_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
 | 
			
		||||
Math[alias]: impl_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
 | 
			
		||||
''')
 | 
			
		||||
 | 
			
		||||
    def test_computed_table(self):
 | 
			
		||||
@ -335,24 +335,24 @@ CPU: fn_cpu :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
 | 
			
		||||
XLA: fn_xla :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
 | 
			
		||||
AutogradCPU: fn_autogradcpu :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
 | 
			
		||||
Autograd[alias]: fn_autograd :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
 | 
			
		||||
catchall: default_def_name_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
 | 
			
		||||
Math[alias]: default_def_name_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
 | 
			
		||||
''')
 | 
			
		||||
 | 
			
		||||
        # computed dispatch table is too big, so we only check on a few entries we're interested in.
 | 
			
		||||
        extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check)
 | 
			
		||||
 | 
			
		||||
        self.assertExpectedInline(extracted_table, '''\
 | 
			
		||||
Undefined: default_def_name_t_t [catch all]
 | 
			
		||||
Undefined: default_def_name_t_t [math kernel]
 | 
			
		||||
CPU: fn_cpu [kernel]
 | 
			
		||||
CUDA: default_def_name_t_t [catch all]
 | 
			
		||||
CUDA: default_def_name_t_t [math kernel]
 | 
			
		||||
XLA: fn_xla [kernel]
 | 
			
		||||
AutogradOther: fn_autograd [autograd kernel]
 | 
			
		||||
AutogradOther: default_def_name_t_t [math kernel]
 | 
			
		||||
AutogradCPU: fn_autogradcpu [kernel]
 | 
			
		||||
AutogradCUDA: fn_autograd [autograd kernel]
 | 
			
		||||
AutogradCUDA: default_def_name_t_t [math kernel]
 | 
			
		||||
AutogradXLA: fn_autograd [autograd kernel]
 | 
			
		||||
''')
 | 
			
		||||
 | 
			
		||||
    def test_computed_table_with_cpu_catchall(self):
 | 
			
		||||
    def test_computed_table_with_cpu_math_autogradcpu_fallthrough(self):
 | 
			
		||||
        global_m = C._dispatch_library("IMPL", "_", "AutogradCPU")
 | 
			
		||||
        result = self.commute("foo", [
 | 
			
		||||
            # m.def("foo", [](const Tensor & x) { return x })
 | 
			
		||||
@ -367,21 +367,21 @@ schema: test::foo(Tensor _0) -> (Tensor _0)
 | 
			
		||||
debug: registered at /dev/null:0
 | 
			
		||||
alias analysis kind: CONSERVATIVE
 | 
			
		||||
CPU: impl_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
 | 
			
		||||
catchall: default_def_name_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
 | 
			
		||||
Math[alias]: default_def_name_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
 | 
			
		||||
''')
 | 
			
		||||
 | 
			
		||||
        # computed dispatch table is too big, so we only check on a few entries we're interested in.
 | 
			
		||||
        extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check)
 | 
			
		||||
 | 
			
		||||
        self.assertExpectedInline(extracted_table, '''\
 | 
			
		||||
Undefined: default_def_name_t_t [catch all]
 | 
			
		||||
Undefined: default_def_name_t_t [math kernel]
 | 
			
		||||
CPU: impl_t_t [kernel]
 | 
			
		||||
CUDA: default_def_name_t_t [catch all]
 | 
			
		||||
XLA: default_def_name_t_t [catch all]
 | 
			
		||||
AutogradOther: default_def_name_t_t [catch all]
 | 
			
		||||
CUDA: default_def_name_t_t [math kernel]
 | 
			
		||||
XLA: default_def_name_t_t [math kernel]
 | 
			
		||||
AutogradOther: default_def_name_t_t [math kernel]
 | 
			
		||||
AutogradCPU: fallthrough registered in pytorch framework [backend fallback]
 | 
			
		||||
AutogradCUDA: default_def_name_t_t [catch all]
 | 
			
		||||
AutogradXLA: default_def_name_t_t [catch all]
 | 
			
		||||
AutogradCUDA: default_def_name_t_t [math kernel]
 | 
			
		||||
AutogradXLA: default_def_name_t_t [math kernel]
 | 
			
		||||
''')
 | 
			
		||||
 | 
			
		||||
    def test_computed_table_with_math(self):
 | 
			
		||||
@ -476,10 +476,11 @@ AutogradCUDA: impl_t_t [autograd kernel]
 | 
			
		||||
AutogradXLA: impl_t_t [autograd kernel]
 | 
			
		||||
''')
 | 
			
		||||
 | 
			
		||||
    def test_computed_table_with_cpu_autograd_math_catchall(self):
 | 
			
		||||
    # Now that catchAll maps to Math, registering to both catchAll and Math breaks commutativity.
 | 
			
		||||
    def test_computed_table_with_cpu_autograd_math(self):
 | 
			
		||||
        result = self.commute("foo", [
 | 
			
		||||
            # m.def("foo", [](const Tensor & x) { return x })
 | 
			
		||||
            lambda m: m.def_name_t_t("foo"),
 | 
			
		||||
            # m.def("foo(Tensor x) -> Tensor")
 | 
			
		||||
            lambda m: m.def_("foo(Tensor x) -> Tensor"),
 | 
			
		||||
            # m.impl("foo", torch::kCPU, [](const Tensor & x) { return x })
 | 
			
		||||
            lambda m: m.impl_t_t("foo", "CPU", debug="fn_cpu"),
 | 
			
		||||
            # m.impl("foo", torch::kAutograd, [](const Tensor & x) { return x })
 | 
			
		||||
@ -490,13 +491,12 @@ AutogradXLA: impl_t_t [autograd kernel]
 | 
			
		||||
        state, table = result.state, result.table
 | 
			
		||||
        self.assertExpectedInline(state, '''\
 | 
			
		||||
name: test::foo
 | 
			
		||||
schema: test::foo(Tensor _0) -> (Tensor _0)
 | 
			
		||||
schema: test::foo(Tensor x) -> (Tensor)
 | 
			
		||||
debug: registered at /dev/null:0
 | 
			
		||||
alias analysis kind: CONSERVATIVE
 | 
			
		||||
alias analysis kind: FROM_SCHEMA
 | 
			
		||||
CPU: fn_cpu :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
 | 
			
		||||
Autograd[alias]: fn_autograd :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
 | 
			
		||||
Math[alias]: fn_math :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
 | 
			
		||||
catchall: default_def_name_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
 | 
			
		||||
''')
 | 
			
		||||
 | 
			
		||||
        # computed dispatch table is too big, so we only check on a few entries we're interested in.
 | 
			
		||||
@ -511,46 +511,12 @@ AutogradOther: fn_math [math kernel]
 | 
			
		||||
AutogradCPU: fn_autograd [autograd kernel]
 | 
			
		||||
AutogradCUDA: fn_math [math kernel]
 | 
			
		||||
AutogradXLA: fn_math [math kernel]
 | 
			
		||||
''')
 | 
			
		||||
 | 
			
		||||
    def test_computed_table_with_cpu_autograd_catchall(self):
 | 
			
		||||
        result = self.commute("foo", [
 | 
			
		||||
            # m.def("foo", [](const Tensor & x) { return x })
 | 
			
		||||
            lambda m: m.def_name_t_t("foo"),
 | 
			
		||||
            # m.impl("foo", torch::kCPU, [](const Tensor & x) { return x })
 | 
			
		||||
            lambda m: m.impl_t_t("foo", "CPU", debug="fn_cpu"),
 | 
			
		||||
            # m.impl("foo", torch::kAutograd, [](const Tensor & x) { return x })
 | 
			
		||||
            lambda m: m.impl_t_t("foo", "Autograd", debug="fn_autograd"),
 | 
			
		||||
        ])
 | 
			
		||||
        state, table = result.state, result.table
 | 
			
		||||
        self.assertExpectedInline(state, '''\
 | 
			
		||||
name: test::foo
 | 
			
		||||
schema: test::foo(Tensor _0) -> (Tensor _0)
 | 
			
		||||
debug: registered at /dev/null:0
 | 
			
		||||
alias analysis kind: CONSERVATIVE
 | 
			
		||||
CPU: fn_cpu :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
 | 
			
		||||
Autograd[alias]: fn_autograd :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
 | 
			
		||||
catchall: default_def_name_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
 | 
			
		||||
''')
 | 
			
		||||
 | 
			
		||||
        # computed dispatch table is too big, so we only check on a few entries we're interested in.
 | 
			
		||||
        extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check)
 | 
			
		||||
 | 
			
		||||
        self.assertExpectedInline(extracted_table, '''\
 | 
			
		||||
Undefined: default_def_name_t_t [catch all]
 | 
			
		||||
CPU: fn_cpu [kernel]
 | 
			
		||||
CUDA: default_def_name_t_t [catch all]
 | 
			
		||||
XLA: default_def_name_t_t [catch all]
 | 
			
		||||
AutogradOther: fn_autograd [autograd kernel]
 | 
			
		||||
AutogradCPU: fn_autograd [autograd kernel]
 | 
			
		||||
AutogradCUDA: fn_autograd [autograd kernel]
 | 
			
		||||
AutogradXLA: fn_autograd [autograd kernel]
 | 
			
		||||
''')
 | 
			
		||||
 | 
			
		||||
    def test_computed_table_with_ambiguous_autogradother(self):
 | 
			
		||||
        result = self.commute("foo", [
 | 
			
		||||
            # m.def("foo", [](const Tensor & x) { return x })
 | 
			
		||||
            lambda m: m.def_name_t_t("foo"),
 | 
			
		||||
            # m.def("foo(Tensor x) -> Tensor")
 | 
			
		||||
            lambda m: m.def_("foo(Tensor x) -> Tensor"),
 | 
			
		||||
            # m.impl("foo", torch::kMath, [](const Tensor & x) { return x })
 | 
			
		||||
            lambda m: m.impl_t_t("foo", "Math", debug="fn_math"),
 | 
			
		||||
            # m.impl("foo", torch::kQuantizedCPU, [](const Tensor & x) { return x })
 | 
			
		||||
@ -559,12 +525,11 @@ AutogradXLA: fn_autograd [autograd kernel]
 | 
			
		||||
        state, table = result.state, result.table
 | 
			
		||||
        self.assertExpectedInline(state, '''\
 | 
			
		||||
name: test::foo
 | 
			
		||||
schema: test::foo(Tensor _0) -> (Tensor _0)
 | 
			
		||||
schema: test::foo(Tensor x) -> (Tensor)
 | 
			
		||||
debug: registered at /dev/null:0
 | 
			
		||||
alias analysis kind: CONSERVATIVE
 | 
			
		||||
alias analysis kind: FROM_SCHEMA
 | 
			
		||||
QuantizedCPU: fn_quantizedcpu :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
 | 
			
		||||
Math[alias]: fn_math :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
 | 
			
		||||
catchall: default_def_name_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
 | 
			
		||||
''')
 | 
			
		||||
 | 
			
		||||
        # computed dispatch table is too big, so we only check on a few entries we're interested in.
 | 
			
		||||
@ -794,7 +759,7 @@ alias analysis kind: PURE_FUNCTION
 | 
			
		||||
        else:
 | 
			
		||||
            self.assertTrue(False)
 | 
			
		||||
 | 
			
		||||
    def test_overwrite_catchall(self):
 | 
			
		||||
    def test_overwrite_math(self):
 | 
			
		||||
        ops = [
 | 
			
		||||
            lambda m: m.impl_t_t("foo", debug="fn1"),
 | 
			
		||||
            lambda m: m.impl_t_t("foo", debug="fn2"),
 | 
			
		||||
@ -805,8 +770,8 @@ alias analysis kind: PURE_FUNCTION
 | 
			
		||||
            '''\
 | 
			
		||||
name: test::foo
 | 
			
		||||
schema: (none)
 | 
			
		||||
catchall: fn2 :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
 | 
			
		||||
catchall (inactive): fn1 :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
 | 
			
		||||
Math[alias]: fn2 :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
 | 
			
		||||
Math[alias] (inactive): fn1 :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
 | 
			
		||||
'''
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -4452,6 +4452,7 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j,  ..., 1.+1.j, 1.+1.j, 1.+1.j],
 | 
			
		||||
                    self.assertEqual(output3, output1)
 | 
			
		||||
                    self.assertEqual(output3, output2)
 | 
			
		||||
 | 
			
		||||
        @unittest.skipIf(True, "Skip due to catchAll -> Math")
 | 
			
		||||
        def test_empty_meta(self):
 | 
			
		||||
            x = torch.empty_meta(2 ** 20, 2 ** 20)
 | 
			
		||||
            y = torch.empty_meta(2 ** 20)
 | 
			
		||||
 | 
			
		||||
@ -1049,15 +1049,13 @@ def main() -> None:
 | 
			
		||||
 | 
			
		||||
        'function_registrations': list(mapMaybe(
 | 
			
		||||
            compute_type_method(None, target=Target.REGISTRATION, op_registration_whitelist=op_registration_whitelist),
 | 
			
		||||
            native_functions)) + list(mapMaybe(
 | 
			
		||||
                compute_type_method('Math', target=Target.REGISTRATION, op_registration_whitelist=op_registration_whitelist),
 | 
			
		||||
                native_functions)),
 | 
			
		||||
 | 
			
		||||
        'default_backend_function_registrations': list(mapMaybe(
 | 
			
		||||
            compute_type_method('DefaultBackend', target=Target.REGISTRATION, op_registration_whitelist=op_registration_whitelist),
 | 
			
		||||
            native_functions)),
 | 
			
		||||
 | 
			
		||||
        'math_function_registrations': list(mapMaybe(
 | 
			
		||||
            compute_type_method('Math', target=Target.REGISTRATION, op_registration_whitelist=op_registration_whitelist),
 | 
			
		||||
            native_functions)),
 | 
			
		||||
    })
 | 
			
		||||
    cpu_fm.write('Functions.h', lambda: {
 | 
			
		||||
        'function_declarations': list(mapMaybe(compute_function(target=Target.DECLARATION), native_functions)),
 | 
			
		||||
 | 
			
		||||
@ -6,6 +6,7 @@
 | 
			
		||||
#include <torch/csrc/autograd/autograd.h>
 | 
			
		||||
#include <ATen/TracerMode.h>
 | 
			
		||||
#include <ATen/core/op_registration/op_registration.h>
 | 
			
		||||
#include <torch/library.h>
 | 
			
		||||
 | 
			
		||||
using namespace at;
 | 
			
		||||
using namespace torch::autograd::generated;
 | 
			
		||||
 | 
			
		||||
		Reference in New Issue
	
	Block a user