mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	Make dispatcher registrations of SymInt functions backwards compatible (#84557)
Previously, when we SymInt-ify a schema, this is a BC-breaking change for all people who registered functions for that function; they must accept c10::SymInt where they previously accepted int64_t. This is not great. With this change, I accept old type registrations transparently. The idea is in several parts: - At the registration site, at compile time I have no idea whether or not if the function being registered has a SymInt schema or not. So I must defer the exact compatibility check. What I do instead is check if the function pointer registered to me has SymInt in the argument or not. If it does, I assume it is new-style and ensure it is also registered to a special sym_ slot on KernelFunction. If not, it only goes in the conventional slot. - At the dispatcher site, I know at compile time whether or not this is a SymInt function. If it is, I check for a sym_ slot on the KernelFunction, and preferentially use that. If no such slot exists, I then fall back to the regular slot... but I convert all SymInt arguments to int64_t arguments (doing assertions that no true symbolic integer was passed.) I can skip this test entirely if the function doesn't have any SymInts in it; in that case I know that only the original slot could have been registered. Fortunately, both branches of the short circuit typecheck, so I didn't have to use SFINAE or if-constexpr to make it work; just a plain if statement that I expect the compiler to optimize away. - Schema validation is now modestly more complicated. There are two parts. First, function schema validation proceeds by checking if the signature in question has any SymInt-like types in it or not. If it does, we do function schema validation against the real types; if it doesn't, we do validation against the fake types (but only for symint; MemoryFormat is always MemoryFormat). Second, cpp signature validation also keeps track of a "symint" cpp signature and a "non-symint" cpp signature. We only compare symint with symint, and non-symint with non-symint. I did not implement checking a conflict between a symint and non-symint cpp signature, though in principle you could try converting the SymInt types to non-SymInt types and doing the comparison that way. To show it is working, I remove a bunch of c10::asIntArrayRefSlow shims, as the dispatcher is able to insert them automatically now. I didn't update the Metal registrations (though they can get similar treatment) as OSS CI coverage is insufficient for this case. Signed-off-by: Edward Z. Yang <ezyang@fb.com> Differential Revision: [D39280965](https://our.internmc.facebook.com/intern/diff/D39280965) Pull Request resolved: https://github.com/pytorch/pytorch/pull/84557 Approved by: https://github.com/wconstab
This commit is contained in:
		
				
					committed by
					
						 PyTorch MergeBot
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							ed46b9670e
						
					
				
				
					commit
					19e27b1556
				
			| @ -185,11 +185,6 @@ Tensor expand_batching_rule(const Tensor& self, IntArrayRef size, bool implicit) | ||||
|   return self_physical.getPhysicalToLogicalMap().apply(result); | ||||
| } | ||||
|  | ||||
| Tensor expand_symint_batching_rule(const Tensor& self, SymIntArrayRef psize, bool implicit) { | ||||
|   // TODO: properly support this | ||||
|   return expand_batching_rule(self, asIntArrayRefSlow(psize), implicit); | ||||
| } | ||||
|  | ||||
| std::vector<Tensor> chunk_batching_rule(const Tensor& self, int64_t chunks, int64_t dim) { | ||||
|   auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); | ||||
|   auto dim_physical = self_physical.getPhysicalDim(dim); | ||||
| @ -469,11 +464,6 @@ Tensor view_batching_rule(const Tensor& self, IntArrayRef size) { | ||||
|   return self_physical.getPhysicalToLogicalMap().apply(result); | ||||
| } | ||||
|  | ||||
| Tensor view_symint_batching_rule(const Tensor& self, c10::SymIntArrayRef size) { | ||||
|   // TODO: properly support this | ||||
|   return view_batching_rule(self, asIntArrayRefSlow(size)); | ||||
| } | ||||
|  | ||||
| Tensor view_as_complex_batching_rule(const Tensor& self) { | ||||
|   // guard against the user passing in a batch of scalar tensors with batch | ||||
|   // size equal to 2. | ||||
| @ -1004,17 +994,6 @@ Tensor new_empty_batching_rule( | ||||
|   return physical_view.getPhysicalToLogicalMap().apply(result); | ||||
| } | ||||
|  | ||||
| Tensor new_empty_symint_batching_rule( | ||||
|     const Tensor& self, | ||||
|     c10::SymIntArrayRef size, | ||||
|     c10::optional<ScalarType> dtype, | ||||
|     c10::optional<Layout> layout, | ||||
|     c10::optional<Device> device, | ||||
|     c10::optional<bool> pin_memory) { | ||||
|   // TODO: properly support this | ||||
|   return new_empty_batching_rule(self, asIntArrayRefSlow(size), dtype, layout, device, pin_memory); | ||||
| } | ||||
|  | ||||
| Tensor new_empty_strided_batching_rule( | ||||
|     const Tensor& self, | ||||
|     IntArrayRef size, | ||||
| @ -1112,7 +1091,7 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) { | ||||
|   m.impl("tensor_split.sections", tensor_split_sections_batching_rule); | ||||
|   m.impl("tensor_split.indices", tensor_split_indices_batching_rule); | ||||
|   m.impl("diagonal", diagonal_batching_rule); | ||||
|   m.impl("expand", expand_symint_batching_rule); | ||||
|   m.impl("expand", expand_batching_rule); | ||||
|   m.impl("expand_as", native::expand_as); // composite wrt autograd | ||||
|   m.impl("movedim.intlist", movedim_batching_rule); | ||||
|   m.impl("movedim.int", static_cast<Tensor(*)(const Tensor&,int64_t,int64_t)>(native::movedim)); // composite wrt autograd | ||||
| @ -1140,7 +1119,7 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) { | ||||
|   m.impl("unbind.int", unbind_batching_rule); | ||||
|   m.impl("unfold", unfold_batching_rule); | ||||
|   m.impl("unsqueeze", unsqueeze_batching_rule); | ||||
|   m.impl("view", view_symint_batching_rule); | ||||
|   m.impl("view", view_batching_rule); | ||||
|   m.impl("view_as", native::view_as); // composite wrt autograd | ||||
|  | ||||
|   // clamp operations | ||||
| @ -1278,7 +1257,7 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) { | ||||
|   m.impl("diagonal_backward", diagonal_backward_batching_rule); | ||||
|  | ||||
|   // Tensor.new_* operators | ||||
|   m.impl("new_empty", new_empty_symint_batching_rule); | ||||
|   m.impl("new_empty", new_empty_batching_rule); | ||||
|   m.impl("new_empty_strided", new_empty_strided_batching_rule); | ||||
|   m.impl("new_zeros", new_zeros_batching_rule); | ||||
|  | ||||
|  | ||||
| @ -14,6 +14,40 @@ class OperatorHandle; | ||||
| struct OperatorKernel; | ||||
| class KernelFunction; | ||||
|  | ||||
| template <typename T> | ||||
| using has_symint = | ||||
|   guts::disjunction< | ||||
|     std::is_same<c10::SymInt, std::decay_t<T>>, | ||||
|     std::is_same<c10::SymIntArrayRef, std::decay_t<T>>, | ||||
|     std::is_same<c10::optional<c10::SymInt>, std::decay_t<T>> | ||||
|   >; | ||||
|  | ||||
| template <typename T> | ||||
| struct remove_symint { | ||||
|   using type = T; | ||||
| }; | ||||
|  | ||||
| template <> | ||||
| struct remove_symint<c10::SymInt> { | ||||
|   using type = int64_t; | ||||
| }; | ||||
|  | ||||
| template <> | ||||
| struct remove_symint<c10::SymIntArrayRef> { | ||||
|   using type = c10::IntArrayRef; | ||||
| }; | ||||
|  | ||||
| template <> | ||||
| struct remove_symint<c10::optional<c10::SymInt>> { | ||||
|   using type = c10::optional<int64_t>; | ||||
| }; | ||||
|  | ||||
| template <typename T> | ||||
| using fn_has_symint = typename guts::typelist::true_for_any_type< | ||||
|   has_symint, | ||||
|   typename guts::infer_function_traits<T>::type::parameter_types | ||||
| >; | ||||
|  | ||||
| /** | ||||
|  * KernelFunction is similar to std::function but stores a kernel function. | ||||
|  * You can create a KernelFunction from a boxed or unboxed function/functor/lambda | ||||
| @ -31,6 +65,7 @@ public: | ||||
|   // Fast path for dispatch to allow not touching the boxed kernel in | ||||
|   // the common case where unboxed is available. | ||||
|   bool isValidUnboxed() const; | ||||
|   bool isValidSymUnboxed() const; | ||||
|   bool isValid() const; | ||||
|   bool isFallthrough() const; | ||||
|  | ||||
| @ -182,13 +217,16 @@ private: | ||||
|   explicit KernelFunction( | ||||
|       std::unique_ptr<OperatorKernel> functor, | ||||
|       InternalBoxedKernelFunction* boxed_kernel_func, | ||||
|       void* unboxed_kernel_func); | ||||
|       void* unboxed_kernel_func, | ||||
|       void* sym_unboxed_kernel_func); | ||||
|   explicit KernelFunction( | ||||
|       BoxedKernel boxed_fn, | ||||
|       void* unboxed_kernel_func); | ||||
|       void* unboxed_kernel_func, | ||||
|       void* sym_unboxed_kernel_func); | ||||
|  | ||||
|   BoxedKernel boxed_kernel_func_; | ||||
|   void* unboxed_kernel_func_; | ||||
|   void* sym_unboxed_kernel_func_; | ||||
| }; | ||||
|  | ||||
| } | ||||
|  | ||||
| @ -8,22 +8,29 @@ namespace c10 { | ||||
| inline KernelFunction::KernelFunction() | ||||
|     : boxed_kernel_func_() | ||||
|     , unboxed_kernel_func_(nullptr) | ||||
|     , sym_unboxed_kernel_func_(nullptr) | ||||
| {} | ||||
|  | ||||
| inline KernelFunction::KernelFunction(std::unique_ptr<OperatorKernel> functor, InternalBoxedKernelFunction* boxed_kernel_func, void* unboxed_kernel_func) | ||||
| inline KernelFunction::KernelFunction(std::unique_ptr<OperatorKernel> functor, InternalBoxedKernelFunction* boxed_kernel_func, void* unboxed_kernel_func, void* sym_unboxed_kernel_func = nullptr) | ||||
|   : boxed_kernel_func_(std::move(functor), boxed_kernel_func) | ||||
|   , unboxed_kernel_func_(unboxed_kernel_func) | ||||
|   , sym_unboxed_kernel_func_(sym_unboxed_kernel_func) | ||||
| {} | ||||
|  | ||||
| inline KernelFunction::KernelFunction(BoxedKernel boxed_fn, void* unboxed_kernel_func) | ||||
| inline KernelFunction::KernelFunction(BoxedKernel boxed_fn, void* unboxed_kernel_func, void* sym_unboxed_kernel_func = nullptr) | ||||
|   : boxed_kernel_func_(std::move(boxed_fn)) | ||||
|   , unboxed_kernel_func_(unboxed_kernel_func) | ||||
|   , sym_unboxed_kernel_func_(sym_unboxed_kernel_func) | ||||
| {} | ||||
|  | ||||
| inline bool KernelFunction::isValidUnboxed() const { | ||||
|   return unboxed_kernel_func_ != nullptr; | ||||
| } | ||||
|  | ||||
| inline bool KernelFunction::isValidSymUnboxed() const { | ||||
|   return sym_unboxed_kernel_func_ != nullptr; | ||||
| } | ||||
|  | ||||
| inline bool KernelFunction::isValid() const { | ||||
|   return boxed_kernel_func_.isValid(); | ||||
| } | ||||
| @ -43,16 +50,52 @@ inline Return callUnboxedKernelFunction(void* unboxed_kernel_func, OperatorKerne | ||||
|     return (*func)(functor, dispatchKeySet, std::forward<Args>(args)...); | ||||
| } | ||||
|  | ||||
| // This template requires you to explicitly specify the argument you want to | ||||
| // forward; it doesn't work if you try to deduce it | ||||
|  | ||||
| template <typename T> | ||||
| inline typename remove_symint<T>::type unpackSymInt(T x) { return x; } | ||||
|  | ||||
| template <> | ||||
| inline typename remove_symint<c10::SymInt>::type unpackSymInt(c10::SymInt x) { | ||||
|   return x.expect_int(); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| inline typename remove_symint<c10::SymIntArrayRef>::type unpackSymInt(c10::SymIntArrayRef x) { | ||||
|   return c10::asIntArrayRefSlow(x); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| inline typename remove_symint<c10::optional<c10::SymInt>>::type unpackSymInt(c10::optional<c10::SymInt> x) { | ||||
|   return x.has_value() ? c10::make_optional(x->expect_int()) : c10::nullopt; | ||||
| } | ||||
|  | ||||
| template<class Return, class... Args> | ||||
| C10_ALWAYS_INLINE Return KernelFunction::call(const OperatorHandle& opHandle, DispatchKeySet dispatchKeySet, Args... args) const { | ||||
|     // note: Args above is intentionally not Args&&. We don't want perfect | ||||
|     // forwarding, which would require Args to be deduced, but instead we | ||||
|     // want callers to explicitly specify the Args. | ||||
|  | ||||
|     if (C10_LIKELY(unboxed_kernel_func_ != nullptr)) { | ||||
|       auto *functor = boxed_kernel_func_.getFunctor(); | ||||
|       return callUnboxedKernelFunction<Return, Args...>( | ||||
|           unboxed_kernel_func_, functor, dispatchKeySet, std::forward<Args>(args)...); | ||||
|     // This should get inlined by compiler | ||||
|     if (guts::disjunction<has_symint<Args>...>::value) { | ||||
|       if (sym_unboxed_kernel_func_ != nullptr) { | ||||
|         auto *functor = boxed_kernel_func_.getFunctor(); | ||||
|         return callUnboxedKernelFunction<Return, Args...>( | ||||
|             sym_unboxed_kernel_func_, functor, dispatchKeySet, std::forward<Args>(args)...); | ||||
|       } | ||||
|  | ||||
|       if (unboxed_kernel_func_ != nullptr) { | ||||
|         auto *functor = boxed_kernel_func_.getFunctor(); | ||||
|         return callUnboxedKernelFunction<Return, typename remove_symint<Args>::type...>( | ||||
|             unboxed_kernel_func_, functor, dispatchKeySet, unpackSymInt<Args>(args)...); | ||||
|       } | ||||
|     } else { | ||||
|       if (C10_LIKELY(unboxed_kernel_func_ != nullptr)) { | ||||
|         auto *functor = boxed_kernel_func_.getFunctor(); | ||||
|         return callUnboxedKernelFunction<Return, Args...>( | ||||
|             unboxed_kernel_func_, functor, dispatchKeySet, std::forward<Args>(args)...); | ||||
|       } | ||||
|     } | ||||
|  | ||||
|     return impl::BoxedKernelWrapper<Return(Args...)>::call( | ||||
| @ -102,10 +145,14 @@ inline KernelFunction KernelFunction::makeFromUnboxedFunctor(std::unique_ptr<Ope | ||||
| #endif | ||||
|     static_assert(std::is_base_of<OperatorKernel, KernelFunctor>::value, "Tried to call KernelFunction::makeFromUnboxedFunctor<KernelFunctor>, but the functor doesn't inherit from c10::OperatorKernel. Please have the functor inherit from it."); | ||||
|  | ||||
|     auto* unboxed_fn = &impl::wrap_kernel_functor_unboxed<KernelFunctor>::call; | ||||
|     void* void_unboxed_fn = reinterpret_cast<void*>(unboxed_fn); | ||||
|     bool is_symint = fn_has_symint<decltype(unboxed_fn)>::value; | ||||
|     return KernelFunction( | ||||
|         std::move(kernelFunctor), | ||||
|         &impl::make_boxed_from_unboxed_functor<KernelFunctor, AllowLegacyTypes>::call, | ||||
|         reinterpret_cast<void*>(&impl::wrap_kernel_functor_unboxed<KernelFunctor>::call) | ||||
|         is_symint ? nullptr : void_unboxed_fn, | ||||
|         is_symint ? void_unboxed_fn : nullptr | ||||
|     ); | ||||
| } | ||||
|  | ||||
|  | ||||
| @ -26,6 +26,7 @@ OperatorEntry::OperatorEntry(OperatorName&& operator_name) | ||||
| , dispatchKeyExtractor_(DispatchKeyExtractor::makeUninitialized()) | ||||
| , kernels_() | ||||
| , cpp_signature_() | ||||
| , sym_cpp_signature_() | ||||
| , is_observed_(ObservedOperators::isObserved(name_)) | ||||
| { | ||||
|   // Pick up any backend fallbacks that were registered prior to this | ||||
| @ -34,12 +35,11 @@ OperatorEntry::OperatorEntry(OperatorName&& operator_name) | ||||
| } | ||||
|  | ||||
| namespace { | ||||
|   void checkSchema(const OperatorName& name, const FunctionSchema& from_def, const std::string& from_def_debug, const FunctionSchema& inferred, const std::string& inferred_debug) { | ||||
|   void checkSchema(const OperatorName& name, const FunctionSchema& from_def_, const std::string& from_def_debug, const KernelFunction& kernel, const FunctionSchema& inferred_, const std::string& inferred_debug) { | ||||
|     // TODO: figure out if we can just directly save real schema at def time | ||||
|     c10::optional<std::string> schema_difference = findSchemaDifferences( | ||||
|       from_def.cloneWithRealTypes(), | ||||
|       inferred.cloneWithRealTypes() | ||||
|     ); | ||||
|     FunctionSchema from_def = from_def_.cloneWithRealTypes(kernel.isValidSymUnboxed()); | ||||
|     FunctionSchema inferred = inferred_.cloneWithRealTypes(); | ||||
|     c10::optional<std::string> schema_difference = findSchemaDifferences(from_def, inferred); | ||||
|     if (schema_difference.has_value()) { | ||||
|       TORCH_CHECK(false, | ||||
|         "Inferred operator schema for a C++ kernel function doesn't match the expected function schema.\n" | ||||
| @ -64,12 +64,24 @@ const AnnotatedKernel& OperatorEntry::ambiguousAutogradOtherKernel() const { | ||||
|   return kernel; | ||||
| } | ||||
|  | ||||
| void OperatorEntry::assertSignatureIsCorrect(const CppSignature call_signature, bool has_symint) const { | ||||
|   if (has_symint) { | ||||
|     if (C10_UNLIKELY(sym_cpp_signature_.has_value() && (call_signature != sym_cpp_signature_->signature))) { | ||||
|       reportSignatureError(call_signature, *sym_cpp_signature_); | ||||
|     } | ||||
|   } else { | ||||
|     if (C10_UNLIKELY(cpp_signature_.has_value() && (call_signature != cpp_signature_->signature))) { | ||||
|       reportSignatureError(call_signature, *cpp_signature_); | ||||
|     } | ||||
|   } | ||||
| } | ||||
|  | ||||
| void OperatorEntry::registerSchema(FunctionSchema&& schema, std::string&& debug, std::vector<at::Tag> tags) { | ||||
|   TORCH_INTERNAL_ASSERT(!schema_.has_value()); | ||||
|   for (const auto& kernel : kernels_) { | ||||
|     for (const auto &j : kernel.second) { | ||||
|       if (j.inferred_function_schema != nullptr) { | ||||
|         checkSchema(name_, schema, debug, *j.inferred_function_schema, j.debug); | ||||
|         checkSchema(name_, schema, debug, j.kernel, *j.inferred_function_schema, j.debug); | ||||
|       } | ||||
|     } | ||||
|   } | ||||
| @ -103,25 +115,26 @@ OperatorEntry::AnnotatedKernelContainerIterator OperatorEntry::registerKernel( | ||||
|   // which means if you could validly change the type of a cpp_signature, then | ||||
|   // that would also invalidate the old TypedOperatorHandles. | ||||
|   if (cpp_signature.has_value()) { | ||||
|     if (cpp_signature_.has_value()) { | ||||
|       TORCH_CHECK(*cpp_signature == cpp_signature_->signature, | ||||
|     auto& local_cpp_signature = kernel.isValidSymUnboxed() ? sym_cpp_signature_ : cpp_signature_; | ||||
|     if (local_cpp_signature.has_value()) { | ||||
|       TORCH_CHECK(*cpp_signature == local_cpp_signature->signature, | ||||
|         "\nMismatch in kernel C++ signatures\n", | ||||
|         "  operator: ", (this->schema_.has_value() ? toString(this->schema_->schema) : toString(name_)), "\n", | ||||
|         "    ", (this->schema_.has_value() ? this->schema_->debug : "no debug info"), "\n", | ||||
|         "  kernel 1: ", cpp_signature_->signature.name(), "\n", | ||||
|         "    dispatch key: ", toString(cpp_signature_->dispatch_key), "\n", | ||||
|         "    ", cpp_signature_->debug, "\n", | ||||
|         "  kernel 1: ", local_cpp_signature->signature.name(), "\n", | ||||
|         "    dispatch key: ", toString(local_cpp_signature->dispatch_key), "\n", | ||||
|         "    ", local_cpp_signature->debug, "\n", | ||||
|         "  kernel 2: ", cpp_signature->name(), "\n", | ||||
|         "    dispatch key: ", toString(dispatch_key), "\n", | ||||
|         "    ", debug, "\n" | ||||
|       ); | ||||
|     } else { | ||||
|       cpp_signature_ = CppSignatureWithDebug { *cpp_signature, debug, dispatch_key }; | ||||
|       local_cpp_signature = CppSignatureWithDebug { *cpp_signature, debug, dispatch_key }; | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   if (schema_ && inferred_function_schema) { | ||||
|     checkSchema(name_, schema_->schema, schema_->debug, *inferred_function_schema, debug); | ||||
|     checkSchema(name_, schema_->schema, schema_->debug, kernel, *inferred_function_schema, debug); | ||||
|   } | ||||
|  | ||||
|   // Add the kernel to the kernels list, | ||||
| @ -138,7 +151,7 @@ OperatorEntry::AnnotatedKernelContainerIterator OperatorEntry::registerKernel( | ||||
|                "  operator: ", (schema_.has_value() ? toString(schema_->schema) : toString(name_)), "\n", | ||||
|                "    ", (this->schema_.has_value() ? this->schema_->debug : "no debug info"), "\n", | ||||
|                "  dispatch key: ", toString(dispatch_key), "\n", | ||||
|                "  previous kernel: ", (cpp_signature_.has_value() ? cpp_signature_->debug : "no debug info"), "\n", | ||||
|                "  previous kernel: ", (cpp_signature_.has_value() ? cpp_signature_->debug : (sym_cpp_signature_.has_value() ? sym_cpp_signature_->debug : "no debug info")), "\n", | ||||
|                "       new kernel: ", debug | ||||
|     ); | ||||
|   } | ||||
| @ -471,13 +484,13 @@ std::string OperatorEntry::listAllDispatchKeys() const { | ||||
|   return str.str(); | ||||
| } | ||||
|  | ||||
| void OperatorEntry::reportSignatureError(const CppSignature call_signature) const { | ||||
| void OperatorEntry::reportSignatureError(const CppSignature& call_signature, const CppSignatureWithDebug& saved_signature) const { | ||||
|   TORCH_CHECK(false, | ||||
|         "\nTried to access or call an operator with a wrong signature.\n", | ||||
|         "  operator: ", (schema_.has_value() ? toString(schema_->schema) : toString(name_)), "\n", | ||||
|         "    ", (schema_.has_value() ? schema_->debug : "unknown debug info"), "\n", | ||||
|         "  correct signature:  ", cpp_signature_->signature.name(), "\n", | ||||
|         "    ", cpp_signature_->debug, "\n", | ||||
|         "  correct signature:  ", saved_signature.signature.name(), "\n", | ||||
|         "    ", saved_signature.debug, "\n", | ||||
|         "  accessed/called as: ", call_signature.name(), "\n", | ||||
|         "This likely happened in a call to OperatorHandle::typed<Return (Args...)>(). ", | ||||
|         "Please make sure that the function signature matches the signature in the operator registration call." | ||||
|  | ||||
| @ -163,14 +163,10 @@ public: | ||||
|   // Asserts that the given FuncType is correct for calling this operator in an unboxed way. | ||||
|   template<class FuncType> | ||||
|   inline void assertSignatureIsCorrect() { | ||||
|     assertSignatureIsCorrect(CppSignature::make<FuncType>()); | ||||
|     assertSignatureIsCorrect(CppSignature::make<FuncType>(), fn_has_symint<FuncType>::value); | ||||
|   } | ||||
|  | ||||
|   void assertSignatureIsCorrect(const CppSignature call_signature) { | ||||
|     if (C10_UNLIKELY(cpp_signature_.has_value() && (call_signature != cpp_signature_->signature))) { | ||||
|       reportSignatureError(call_signature); | ||||
|     } | ||||
|   } | ||||
|   void assertSignatureIsCorrect(const CppSignature call_signature, bool has_symint) const; | ||||
|  | ||||
|   [[noreturn]] void reportError(DispatchKey dispatchKey) const; | ||||
|  | ||||
| @ -280,11 +276,12 @@ private: | ||||
|     c10::optional<DispatchKey> dispatch_key; | ||||
|   }; | ||||
|   c10::optional<CppSignatureWithDebug> cpp_signature_; | ||||
|   c10::optional<CppSignatureWithDebug> sym_cpp_signature_; | ||||
|  | ||||
|   // Whether this operator needs to be observed with RecordFunction | ||||
|   const bool is_observed_; | ||||
|  | ||||
|   [[noreturn]] void reportSignatureError(CppSignature call_signature) const; | ||||
|   [[noreturn]] void reportSignatureError(const CppSignature& call_signature, const CppSignatureWithDebug& saved_signature) const; | ||||
|   const KernelFunction& computeDispatchTableEntry(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key) const; | ||||
|   std::pair<const AnnotatedKernel&, const char*> computeDispatchTableEntryWithDebug( | ||||
|     const c10::Dispatcher& dispatcher, DispatchKey dispatch_key | ||||
|  | ||||
| @ -17,9 +17,23 @@ const std::vector<Argument>& FunctionSchema::getCorrectList(SchemaArgType type) | ||||
|   } | ||||
| } | ||||
|  | ||||
| FunctionSchema FunctionSchema::cloneWithRealTypes() const { | ||||
|   auto cloneWithRealTypes = [](const Argument& a) { | ||||
|     return a.cloneWithType(a.real_type()); | ||||
| FunctionSchema FunctionSchema::cloneWithRealTypes(bool with_symint) const { | ||||
|   auto cloneWithRealTypes = [&](const Argument& a) { | ||||
|     if (with_symint) { | ||||
|       return a.cloneWithType(a.real_type()); | ||||
|     } | ||||
|     // Don't use real type if it looks like a SymInt | ||||
|     // NB: keep this in sync with unpackSymInt in KernelFunction_impl.h | ||||
|     if ( | ||||
|       *a.real_type() == *getTypePtr<c10::SymInt>() || | ||||
|       *a.real_type() == *getTypePtr<c10::optional<c10::SymInt>>() || | ||||
|       *a.real_type() == *getTypePtr<c10::SymIntArrayRef>() | ||||
|     ) { | ||||
|       // Keep the fake type | ||||
|       return a.cloneWithType(a.type()); | ||||
|     } else { | ||||
|       return a.cloneWithType(a.real_type()); | ||||
|     } | ||||
|   }; | ||||
|   std::vector<Argument> new_arguments, new_returns; | ||||
|   std::transform(arguments().begin(), arguments().end(), std::back_inserter(new_arguments), cloneWithRealTypes); | ||||
|  | ||||
| @ -474,7 +474,7 @@ struct TORCH_API FunctionSchema { | ||||
|   FunctionSchema cloneWithRemappedTypes( | ||||
|       const std::function<TypePtr(TypePtr)> type_map) const; | ||||
|  | ||||
|   FunctionSchema cloneWithRealTypes() const; | ||||
|   FunctionSchema cloneWithRealTypes(bool with_symint=true) const; | ||||
|  | ||||
|   // Check that inputs have the correct types and appends any missing default | ||||
|   // values. | ||||
|  | ||||
| @ -29,13 +29,12 @@ Tensor _empty_affine_quantized( | ||||
| } | ||||
|  | ||||
| Tensor empty_memory_format( | ||||
|     const SymIntArrayRef sym_sizes, | ||||
|     const IntArrayRef sizes, | ||||
|     const c10::optional<ScalarType> dtype, | ||||
|     const c10::optional<c10::Layout> layout, | ||||
|     const c10::optional<Device> device, | ||||
|     const c10::optional<bool> pin_memory, | ||||
|     const optional<MemoryFormat> memory_format) { | ||||
|   auto sizes = c10::asIntArrayRefSlow(sym_sizes); | ||||
|   return convert(vTensor{ | ||||
|       api::context(), | ||||
|       sizes, | ||||
| @ -56,12 +55,7 @@ Tensor empty_strided( | ||||
|     const optional<Device> device, | ||||
|     const optional<bool> pin_memory) { | ||||
|   return empty_memory_format( | ||||
|       c10::SymIntArrayRef::fromIntArrayRef(sizes), | ||||
|       dtype, | ||||
|       layout, | ||||
|       device, | ||||
|       pin_memory, | ||||
|       c10::MemoryFormat::Contiguous); | ||||
|       sizes, dtype, layout, device, pin_memory, c10::MemoryFormat::Contiguous); | ||||
| } | ||||
|  | ||||
| #ifdef USE_VULKAN_API | ||||
|  | ||||
| @ -42,8 +42,7 @@ Tensor view_internal(const Tensor& self_arg, const IntArrayRef shape) { | ||||
|   return convert(v_output); | ||||
| } | ||||
|  | ||||
| inline Tensor view(const Tensor& self_arg, const SymIntArrayRef sym_shape) { | ||||
|   auto shape = c10::asIntArrayRefSlow(sym_shape); | ||||
| inline Tensor view(const Tensor& self_arg, IntArrayRef shape) { | ||||
|   return view_internal(self_arg, shape); | ||||
| } | ||||
|  | ||||
|  | ||||
| @ -439,12 +439,6 @@ std::tuple<Tensor, optional<int64_t>> view_batching_rule( | ||||
|   return std::make_tuple(self_.view_symint(size_), 0); | ||||
| } | ||||
|  | ||||
| Tensor view_symint_decomposition(const Tensor& self, | ||||
|             c10::SymIntArrayRef size) { | ||||
|   return self.view( c10::asIntArrayRefSlow(size)); | ||||
| } | ||||
|  | ||||
|  | ||||
| template <typename F, F Func> | ||||
| std::tuple<Tensor, optional<int64_t>> expand_batch_rule( | ||||
|     const Tensor &self, optional<int64_t> self_bdim, SymIntArrayRef size, bool implicit) | ||||
| @ -512,14 +506,6 @@ std::tuple<Tensor, optional<int64_t>> diag_embed_batch_rule(const Tensor& self, | ||||
|   return std::make_tuple(at::diag_embed(self_, offset, dim1, dim2), 0); | ||||
| } | ||||
|  | ||||
| // We need to write a real batching rule to fully support symint. | ||||
| // This requires symint variants of other operations, like `view`, | ||||
| // which don't exist yet. | ||||
| Tensor expand_symint_decomp_hack(const Tensor& self, SymIntArrayRef packed_size, bool implicit) { | ||||
|    auto size = asIntArrayRefSlow(packed_size); | ||||
|    return self.expand(size, implicit); | ||||
| } | ||||
|  | ||||
| TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) { | ||||
|   VMAP_SUPPORT(diag, diag_batch_rule); | ||||
|   VMAP_SUPPORT(chunk, chunk_batching_rule); | ||||
|  | ||||
| @ -49,9 +49,9 @@ at::Tensor custom_empty_memory_format(at::IntArrayRef size, c10::optional<at::Sc | ||||
|   constexpr c10::DispatchKeySet private_use_ks(c10::DispatchKey::PrivateUse1); | ||||
|   return at::detail::empty_generic(size, &global_custom_alloc, private_use_ks, c10::dtype_or_default(dtype), memory_format); | ||||
| } | ||||
| at::Tensor custom_empty_symint(c10::SymIntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory, c10::optional<at::MemoryFormat> memory_format) { | ||||
| at::Tensor custom_empty_symint(c10::IntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory, c10::optional<at::MemoryFormat> memory_format) { | ||||
|   constexpr c10::DispatchKeySet private_use_ks(c10::DispatchKey::PrivateUse1); | ||||
|   return at::detail::empty_generic(c10::asIntArrayRefSlow(size), &global_custom_alloc, private_use_ks, c10::dtype_or_default(dtype), memory_format); | ||||
|   return at::detail::empty_generic(size, &global_custom_alloc, private_use_ks, c10::dtype_or_default(dtype), memory_format); | ||||
| } | ||||
|  | ||||
| at::Tensor & custom_fill__scalar(at::Tensor & self, const at::Scalar & value) { | ||||
|  | ||||
| @ -20,10 +20,10 @@ Tensor get_tensor(caffe2::TypeMeta dtype, IntArrayRef size) { | ||||
|   return Tensor(std::move(tensor_impl)); | ||||
| } | ||||
|  | ||||
| Tensor empty_override(SymIntArrayRef size, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, | ||||
| Tensor empty_override(IntArrayRef size, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, | ||||
|                       c10::optional<bool> pin_memory, c10::optional<c10::MemoryFormat> optional_memory_format) { | ||||
|   test_int = 0; | ||||
|   return get_tensor(scalarTypeToTypeMeta(dtype_or_default(dtype)), c10::asIntArrayRefSlow(size)); | ||||
|   return get_tensor(scalarTypeToTypeMeta(dtype_or_default(dtype)), size); | ||||
| } | ||||
|  | ||||
| Tensor& add_out_override(const Tensor & a, const Tensor & b , const Scalar& c, Tensor & out) { | ||||
|  | ||||
		Reference in New Issue
	
	Block a user