mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Make dispatcher registrations of SymInt functions backwards compatible (#84557)
Previously, when we SymInt-ify a schema, this is a BC-breaking change for all people who registered functions for that function; they must accept c10::SymInt where they previously accepted int64_t. This is not great. With this change, I accept old type registrations transparently. The idea is in several parts: - At the registration site, at compile time I have no idea whether or not if the function being registered has a SymInt schema or not. So I must defer the exact compatibility check. What I do instead is check if the function pointer registered to me has SymInt in the argument or not. If it does, I assume it is new-style and ensure it is also registered to a special sym_ slot on KernelFunction. If not, it only goes in the conventional slot. - At the dispatcher site, I know at compile time whether or not this is a SymInt function. If it is, I check for a sym_ slot on the KernelFunction, and preferentially use that. If no such slot exists, I then fall back to the regular slot... but I convert all SymInt arguments to int64_t arguments (doing assertions that no true symbolic integer was passed.) I can skip this test entirely if the function doesn't have any SymInts in it; in that case I know that only the original slot could have been registered. Fortunately, both branches of the short circuit typecheck, so I didn't have to use SFINAE or if-constexpr to make it work; just a plain if statement that I expect the compiler to optimize away. - Schema validation is now modestly more complicated. There are two parts. First, function schema validation proceeds by checking if the signature in question has any SymInt-like types in it or not. If it does, we do function schema validation against the real types; if it doesn't, we do validation against the fake types (but only for symint; MemoryFormat is always MemoryFormat). Second, cpp signature validation also keeps track of a "symint" cpp signature and a "non-symint" cpp signature. We only compare symint with symint, and non-symint with non-symint. I did not implement checking a conflict between a symint and non-symint cpp signature, though in principle you could try converting the SymInt types to non-SymInt types and doing the comparison that way. To show it is working, I remove a bunch of c10::asIntArrayRefSlow shims, as the dispatcher is able to insert them automatically now. I didn't update the Metal registrations (though they can get similar treatment) as OSS CI coverage is insufficient for this case. Signed-off-by: Edward Z. Yang <ezyang@fb.com> Differential Revision: [D39280965](https://our.internmc.facebook.com/intern/diff/D39280965) Pull Request resolved: https://github.com/pytorch/pytorch/pull/84557 Approved by: https://github.com/wconstab
This commit is contained in:
committed by
PyTorch MergeBot
parent
ed46b9670e
commit
19e27b1556
@ -185,11 +185,6 @@ Tensor expand_batching_rule(const Tensor& self, IntArrayRef size, bool implicit)
|
||||
return self_physical.getPhysicalToLogicalMap().apply(result);
|
||||
}
|
||||
|
||||
Tensor expand_symint_batching_rule(const Tensor& self, SymIntArrayRef psize, bool implicit) {
|
||||
// TODO: properly support this
|
||||
return expand_batching_rule(self, asIntArrayRefSlow(psize), implicit);
|
||||
}
|
||||
|
||||
std::vector<Tensor> chunk_batching_rule(const Tensor& self, int64_t chunks, int64_t dim) {
|
||||
auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
|
||||
auto dim_physical = self_physical.getPhysicalDim(dim);
|
||||
@ -469,11 +464,6 @@ Tensor view_batching_rule(const Tensor& self, IntArrayRef size) {
|
||||
return self_physical.getPhysicalToLogicalMap().apply(result);
|
||||
}
|
||||
|
||||
Tensor view_symint_batching_rule(const Tensor& self, c10::SymIntArrayRef size) {
|
||||
// TODO: properly support this
|
||||
return view_batching_rule(self, asIntArrayRefSlow(size));
|
||||
}
|
||||
|
||||
Tensor view_as_complex_batching_rule(const Tensor& self) {
|
||||
// guard against the user passing in a batch of scalar tensors with batch
|
||||
// size equal to 2.
|
||||
@ -1004,17 +994,6 @@ Tensor new_empty_batching_rule(
|
||||
return physical_view.getPhysicalToLogicalMap().apply(result);
|
||||
}
|
||||
|
||||
Tensor new_empty_symint_batching_rule(
|
||||
const Tensor& self,
|
||||
c10::SymIntArrayRef size,
|
||||
c10::optional<ScalarType> dtype,
|
||||
c10::optional<Layout> layout,
|
||||
c10::optional<Device> device,
|
||||
c10::optional<bool> pin_memory) {
|
||||
// TODO: properly support this
|
||||
return new_empty_batching_rule(self, asIntArrayRefSlow(size), dtype, layout, device, pin_memory);
|
||||
}
|
||||
|
||||
Tensor new_empty_strided_batching_rule(
|
||||
const Tensor& self,
|
||||
IntArrayRef size,
|
||||
@ -1112,7 +1091,7 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
|
||||
m.impl("tensor_split.sections", tensor_split_sections_batching_rule);
|
||||
m.impl("tensor_split.indices", tensor_split_indices_batching_rule);
|
||||
m.impl("diagonal", diagonal_batching_rule);
|
||||
m.impl("expand", expand_symint_batching_rule);
|
||||
m.impl("expand", expand_batching_rule);
|
||||
m.impl("expand_as", native::expand_as); // composite wrt autograd
|
||||
m.impl("movedim.intlist", movedim_batching_rule);
|
||||
m.impl("movedim.int", static_cast<Tensor(*)(const Tensor&,int64_t,int64_t)>(native::movedim)); // composite wrt autograd
|
||||
@ -1140,7 +1119,7 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
|
||||
m.impl("unbind.int", unbind_batching_rule);
|
||||
m.impl("unfold", unfold_batching_rule);
|
||||
m.impl("unsqueeze", unsqueeze_batching_rule);
|
||||
m.impl("view", view_symint_batching_rule);
|
||||
m.impl("view", view_batching_rule);
|
||||
m.impl("view_as", native::view_as); // composite wrt autograd
|
||||
|
||||
// clamp operations
|
||||
@ -1278,7 +1257,7 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
|
||||
m.impl("diagonal_backward", diagonal_backward_batching_rule);
|
||||
|
||||
// Tensor.new_* operators
|
||||
m.impl("new_empty", new_empty_symint_batching_rule);
|
||||
m.impl("new_empty", new_empty_batching_rule);
|
||||
m.impl("new_empty_strided", new_empty_strided_batching_rule);
|
||||
m.impl("new_zeros", new_zeros_batching_rule);
|
||||
|
||||
|
||||
@ -14,6 +14,40 @@ class OperatorHandle;
|
||||
struct OperatorKernel;
|
||||
class KernelFunction;
|
||||
|
||||
template <typename T>
|
||||
using has_symint =
|
||||
guts::disjunction<
|
||||
std::is_same<c10::SymInt, std::decay_t<T>>,
|
||||
std::is_same<c10::SymIntArrayRef, std::decay_t<T>>,
|
||||
std::is_same<c10::optional<c10::SymInt>, std::decay_t<T>>
|
||||
>;
|
||||
|
||||
template <typename T>
|
||||
struct remove_symint {
|
||||
using type = T;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct remove_symint<c10::SymInt> {
|
||||
using type = int64_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct remove_symint<c10::SymIntArrayRef> {
|
||||
using type = c10::IntArrayRef;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct remove_symint<c10::optional<c10::SymInt>> {
|
||||
using type = c10::optional<int64_t>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
using fn_has_symint = typename guts::typelist::true_for_any_type<
|
||||
has_symint,
|
||||
typename guts::infer_function_traits<T>::type::parameter_types
|
||||
>;
|
||||
|
||||
/**
|
||||
* KernelFunction is similar to std::function but stores a kernel function.
|
||||
* You can create a KernelFunction from a boxed or unboxed function/functor/lambda
|
||||
@ -31,6 +65,7 @@ public:
|
||||
// Fast path for dispatch to allow not touching the boxed kernel in
|
||||
// the common case where unboxed is available.
|
||||
bool isValidUnboxed() const;
|
||||
bool isValidSymUnboxed() const;
|
||||
bool isValid() const;
|
||||
bool isFallthrough() const;
|
||||
|
||||
@ -182,13 +217,16 @@ private:
|
||||
explicit KernelFunction(
|
||||
std::unique_ptr<OperatorKernel> functor,
|
||||
InternalBoxedKernelFunction* boxed_kernel_func,
|
||||
void* unboxed_kernel_func);
|
||||
void* unboxed_kernel_func,
|
||||
void* sym_unboxed_kernel_func);
|
||||
explicit KernelFunction(
|
||||
BoxedKernel boxed_fn,
|
||||
void* unboxed_kernel_func);
|
||||
void* unboxed_kernel_func,
|
||||
void* sym_unboxed_kernel_func);
|
||||
|
||||
BoxedKernel boxed_kernel_func_;
|
||||
void* unboxed_kernel_func_;
|
||||
void* sym_unboxed_kernel_func_;
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
@ -8,22 +8,29 @@ namespace c10 {
|
||||
inline KernelFunction::KernelFunction()
|
||||
: boxed_kernel_func_()
|
||||
, unboxed_kernel_func_(nullptr)
|
||||
, sym_unboxed_kernel_func_(nullptr)
|
||||
{}
|
||||
|
||||
inline KernelFunction::KernelFunction(std::unique_ptr<OperatorKernel> functor, InternalBoxedKernelFunction* boxed_kernel_func, void* unboxed_kernel_func)
|
||||
inline KernelFunction::KernelFunction(std::unique_ptr<OperatorKernel> functor, InternalBoxedKernelFunction* boxed_kernel_func, void* unboxed_kernel_func, void* sym_unboxed_kernel_func = nullptr)
|
||||
: boxed_kernel_func_(std::move(functor), boxed_kernel_func)
|
||||
, unboxed_kernel_func_(unboxed_kernel_func)
|
||||
, sym_unboxed_kernel_func_(sym_unboxed_kernel_func)
|
||||
{}
|
||||
|
||||
inline KernelFunction::KernelFunction(BoxedKernel boxed_fn, void* unboxed_kernel_func)
|
||||
inline KernelFunction::KernelFunction(BoxedKernel boxed_fn, void* unboxed_kernel_func, void* sym_unboxed_kernel_func = nullptr)
|
||||
: boxed_kernel_func_(std::move(boxed_fn))
|
||||
, unboxed_kernel_func_(unboxed_kernel_func)
|
||||
, sym_unboxed_kernel_func_(sym_unboxed_kernel_func)
|
||||
{}
|
||||
|
||||
inline bool KernelFunction::isValidUnboxed() const {
|
||||
return unboxed_kernel_func_ != nullptr;
|
||||
}
|
||||
|
||||
inline bool KernelFunction::isValidSymUnboxed() const {
|
||||
return sym_unboxed_kernel_func_ != nullptr;
|
||||
}
|
||||
|
||||
inline bool KernelFunction::isValid() const {
|
||||
return boxed_kernel_func_.isValid();
|
||||
}
|
||||
@ -43,16 +50,52 @@ inline Return callUnboxedKernelFunction(void* unboxed_kernel_func, OperatorKerne
|
||||
return (*func)(functor, dispatchKeySet, std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
// This template requires you to explicitly specify the argument you want to
|
||||
// forward; it doesn't work if you try to deduce it
|
||||
|
||||
template <typename T>
|
||||
inline typename remove_symint<T>::type unpackSymInt(T x) { return x; }
|
||||
|
||||
template <>
|
||||
inline typename remove_symint<c10::SymInt>::type unpackSymInt(c10::SymInt x) {
|
||||
return x.expect_int();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline typename remove_symint<c10::SymIntArrayRef>::type unpackSymInt(c10::SymIntArrayRef x) {
|
||||
return c10::asIntArrayRefSlow(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline typename remove_symint<c10::optional<c10::SymInt>>::type unpackSymInt(c10::optional<c10::SymInt> x) {
|
||||
return x.has_value() ? c10::make_optional(x->expect_int()) : c10::nullopt;
|
||||
}
|
||||
|
||||
template<class Return, class... Args>
|
||||
C10_ALWAYS_INLINE Return KernelFunction::call(const OperatorHandle& opHandle, DispatchKeySet dispatchKeySet, Args... args) const {
|
||||
// note: Args above is intentionally not Args&&. We don't want perfect
|
||||
// forwarding, which would require Args to be deduced, but instead we
|
||||
// want callers to explicitly specify the Args.
|
||||
|
||||
if (C10_LIKELY(unboxed_kernel_func_ != nullptr)) {
|
||||
auto *functor = boxed_kernel_func_.getFunctor();
|
||||
return callUnboxedKernelFunction<Return, Args...>(
|
||||
unboxed_kernel_func_, functor, dispatchKeySet, std::forward<Args>(args)...);
|
||||
// This should get inlined by compiler
|
||||
if (guts::disjunction<has_symint<Args>...>::value) {
|
||||
if (sym_unboxed_kernel_func_ != nullptr) {
|
||||
auto *functor = boxed_kernel_func_.getFunctor();
|
||||
return callUnboxedKernelFunction<Return, Args...>(
|
||||
sym_unboxed_kernel_func_, functor, dispatchKeySet, std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
if (unboxed_kernel_func_ != nullptr) {
|
||||
auto *functor = boxed_kernel_func_.getFunctor();
|
||||
return callUnboxedKernelFunction<Return, typename remove_symint<Args>::type...>(
|
||||
unboxed_kernel_func_, functor, dispatchKeySet, unpackSymInt<Args>(args)...);
|
||||
}
|
||||
} else {
|
||||
if (C10_LIKELY(unboxed_kernel_func_ != nullptr)) {
|
||||
auto *functor = boxed_kernel_func_.getFunctor();
|
||||
return callUnboxedKernelFunction<Return, Args...>(
|
||||
unboxed_kernel_func_, functor, dispatchKeySet, std::forward<Args>(args)...);
|
||||
}
|
||||
}
|
||||
|
||||
return impl::BoxedKernelWrapper<Return(Args...)>::call(
|
||||
@ -102,10 +145,14 @@ inline KernelFunction KernelFunction::makeFromUnboxedFunctor(std::unique_ptr<Ope
|
||||
#endif
|
||||
static_assert(std::is_base_of<OperatorKernel, KernelFunctor>::value, "Tried to call KernelFunction::makeFromUnboxedFunctor<KernelFunctor>, but the functor doesn't inherit from c10::OperatorKernel. Please have the functor inherit from it.");
|
||||
|
||||
auto* unboxed_fn = &impl::wrap_kernel_functor_unboxed<KernelFunctor>::call;
|
||||
void* void_unboxed_fn = reinterpret_cast<void*>(unboxed_fn);
|
||||
bool is_symint = fn_has_symint<decltype(unboxed_fn)>::value;
|
||||
return KernelFunction(
|
||||
std::move(kernelFunctor),
|
||||
&impl::make_boxed_from_unboxed_functor<KernelFunctor, AllowLegacyTypes>::call,
|
||||
reinterpret_cast<void*>(&impl::wrap_kernel_functor_unboxed<KernelFunctor>::call)
|
||||
is_symint ? nullptr : void_unboxed_fn,
|
||||
is_symint ? void_unboxed_fn : nullptr
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@ -26,6 +26,7 @@ OperatorEntry::OperatorEntry(OperatorName&& operator_name)
|
||||
, dispatchKeyExtractor_(DispatchKeyExtractor::makeUninitialized())
|
||||
, kernels_()
|
||||
, cpp_signature_()
|
||||
, sym_cpp_signature_()
|
||||
, is_observed_(ObservedOperators::isObserved(name_))
|
||||
{
|
||||
// Pick up any backend fallbacks that were registered prior to this
|
||||
@ -34,12 +35,11 @@ OperatorEntry::OperatorEntry(OperatorName&& operator_name)
|
||||
}
|
||||
|
||||
namespace {
|
||||
void checkSchema(const OperatorName& name, const FunctionSchema& from_def, const std::string& from_def_debug, const FunctionSchema& inferred, const std::string& inferred_debug) {
|
||||
void checkSchema(const OperatorName& name, const FunctionSchema& from_def_, const std::string& from_def_debug, const KernelFunction& kernel, const FunctionSchema& inferred_, const std::string& inferred_debug) {
|
||||
// TODO: figure out if we can just directly save real schema at def time
|
||||
c10::optional<std::string> schema_difference = findSchemaDifferences(
|
||||
from_def.cloneWithRealTypes(),
|
||||
inferred.cloneWithRealTypes()
|
||||
);
|
||||
FunctionSchema from_def = from_def_.cloneWithRealTypes(kernel.isValidSymUnboxed());
|
||||
FunctionSchema inferred = inferred_.cloneWithRealTypes();
|
||||
c10::optional<std::string> schema_difference = findSchemaDifferences(from_def, inferred);
|
||||
if (schema_difference.has_value()) {
|
||||
TORCH_CHECK(false,
|
||||
"Inferred operator schema for a C++ kernel function doesn't match the expected function schema.\n"
|
||||
@ -64,12 +64,24 @@ const AnnotatedKernel& OperatorEntry::ambiguousAutogradOtherKernel() const {
|
||||
return kernel;
|
||||
}
|
||||
|
||||
void OperatorEntry::assertSignatureIsCorrect(const CppSignature call_signature, bool has_symint) const {
|
||||
if (has_symint) {
|
||||
if (C10_UNLIKELY(sym_cpp_signature_.has_value() && (call_signature != sym_cpp_signature_->signature))) {
|
||||
reportSignatureError(call_signature, *sym_cpp_signature_);
|
||||
}
|
||||
} else {
|
||||
if (C10_UNLIKELY(cpp_signature_.has_value() && (call_signature != cpp_signature_->signature))) {
|
||||
reportSignatureError(call_signature, *cpp_signature_);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void OperatorEntry::registerSchema(FunctionSchema&& schema, std::string&& debug, std::vector<at::Tag> tags) {
|
||||
TORCH_INTERNAL_ASSERT(!schema_.has_value());
|
||||
for (const auto& kernel : kernels_) {
|
||||
for (const auto &j : kernel.second) {
|
||||
if (j.inferred_function_schema != nullptr) {
|
||||
checkSchema(name_, schema, debug, *j.inferred_function_schema, j.debug);
|
||||
checkSchema(name_, schema, debug, j.kernel, *j.inferred_function_schema, j.debug);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -103,25 +115,26 @@ OperatorEntry::AnnotatedKernelContainerIterator OperatorEntry::registerKernel(
|
||||
// which means if you could validly change the type of a cpp_signature, then
|
||||
// that would also invalidate the old TypedOperatorHandles.
|
||||
if (cpp_signature.has_value()) {
|
||||
if (cpp_signature_.has_value()) {
|
||||
TORCH_CHECK(*cpp_signature == cpp_signature_->signature,
|
||||
auto& local_cpp_signature = kernel.isValidSymUnboxed() ? sym_cpp_signature_ : cpp_signature_;
|
||||
if (local_cpp_signature.has_value()) {
|
||||
TORCH_CHECK(*cpp_signature == local_cpp_signature->signature,
|
||||
"\nMismatch in kernel C++ signatures\n",
|
||||
" operator: ", (this->schema_.has_value() ? toString(this->schema_->schema) : toString(name_)), "\n",
|
||||
" ", (this->schema_.has_value() ? this->schema_->debug : "no debug info"), "\n",
|
||||
" kernel 1: ", cpp_signature_->signature.name(), "\n",
|
||||
" dispatch key: ", toString(cpp_signature_->dispatch_key), "\n",
|
||||
" ", cpp_signature_->debug, "\n",
|
||||
" kernel 1: ", local_cpp_signature->signature.name(), "\n",
|
||||
" dispatch key: ", toString(local_cpp_signature->dispatch_key), "\n",
|
||||
" ", local_cpp_signature->debug, "\n",
|
||||
" kernel 2: ", cpp_signature->name(), "\n",
|
||||
" dispatch key: ", toString(dispatch_key), "\n",
|
||||
" ", debug, "\n"
|
||||
);
|
||||
} else {
|
||||
cpp_signature_ = CppSignatureWithDebug { *cpp_signature, debug, dispatch_key };
|
||||
local_cpp_signature = CppSignatureWithDebug { *cpp_signature, debug, dispatch_key };
|
||||
}
|
||||
}
|
||||
|
||||
if (schema_ && inferred_function_schema) {
|
||||
checkSchema(name_, schema_->schema, schema_->debug, *inferred_function_schema, debug);
|
||||
checkSchema(name_, schema_->schema, schema_->debug, kernel, *inferred_function_schema, debug);
|
||||
}
|
||||
|
||||
// Add the kernel to the kernels list,
|
||||
@ -138,7 +151,7 @@ OperatorEntry::AnnotatedKernelContainerIterator OperatorEntry::registerKernel(
|
||||
" operator: ", (schema_.has_value() ? toString(schema_->schema) : toString(name_)), "\n",
|
||||
" ", (this->schema_.has_value() ? this->schema_->debug : "no debug info"), "\n",
|
||||
" dispatch key: ", toString(dispatch_key), "\n",
|
||||
" previous kernel: ", (cpp_signature_.has_value() ? cpp_signature_->debug : "no debug info"), "\n",
|
||||
" previous kernel: ", (cpp_signature_.has_value() ? cpp_signature_->debug : (sym_cpp_signature_.has_value() ? sym_cpp_signature_->debug : "no debug info")), "\n",
|
||||
" new kernel: ", debug
|
||||
);
|
||||
}
|
||||
@ -471,13 +484,13 @@ std::string OperatorEntry::listAllDispatchKeys() const {
|
||||
return str.str();
|
||||
}
|
||||
|
||||
void OperatorEntry::reportSignatureError(const CppSignature call_signature) const {
|
||||
void OperatorEntry::reportSignatureError(const CppSignature& call_signature, const CppSignatureWithDebug& saved_signature) const {
|
||||
TORCH_CHECK(false,
|
||||
"\nTried to access or call an operator with a wrong signature.\n",
|
||||
" operator: ", (schema_.has_value() ? toString(schema_->schema) : toString(name_)), "\n",
|
||||
" ", (schema_.has_value() ? schema_->debug : "unknown debug info"), "\n",
|
||||
" correct signature: ", cpp_signature_->signature.name(), "\n",
|
||||
" ", cpp_signature_->debug, "\n",
|
||||
" correct signature: ", saved_signature.signature.name(), "\n",
|
||||
" ", saved_signature.debug, "\n",
|
||||
" accessed/called as: ", call_signature.name(), "\n",
|
||||
"This likely happened in a call to OperatorHandle::typed<Return (Args...)>(). ",
|
||||
"Please make sure that the function signature matches the signature in the operator registration call."
|
||||
|
||||
@ -163,14 +163,10 @@ public:
|
||||
// Asserts that the given FuncType is correct for calling this operator in an unboxed way.
|
||||
template<class FuncType>
|
||||
inline void assertSignatureIsCorrect() {
|
||||
assertSignatureIsCorrect(CppSignature::make<FuncType>());
|
||||
assertSignatureIsCorrect(CppSignature::make<FuncType>(), fn_has_symint<FuncType>::value);
|
||||
}
|
||||
|
||||
void assertSignatureIsCorrect(const CppSignature call_signature) {
|
||||
if (C10_UNLIKELY(cpp_signature_.has_value() && (call_signature != cpp_signature_->signature))) {
|
||||
reportSignatureError(call_signature);
|
||||
}
|
||||
}
|
||||
void assertSignatureIsCorrect(const CppSignature call_signature, bool has_symint) const;
|
||||
|
||||
[[noreturn]] void reportError(DispatchKey dispatchKey) const;
|
||||
|
||||
@ -280,11 +276,12 @@ private:
|
||||
c10::optional<DispatchKey> dispatch_key;
|
||||
};
|
||||
c10::optional<CppSignatureWithDebug> cpp_signature_;
|
||||
c10::optional<CppSignatureWithDebug> sym_cpp_signature_;
|
||||
|
||||
// Whether this operator needs to be observed with RecordFunction
|
||||
const bool is_observed_;
|
||||
|
||||
[[noreturn]] void reportSignatureError(CppSignature call_signature) const;
|
||||
[[noreturn]] void reportSignatureError(const CppSignature& call_signature, const CppSignatureWithDebug& saved_signature) const;
|
||||
const KernelFunction& computeDispatchTableEntry(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key) const;
|
||||
std::pair<const AnnotatedKernel&, const char*> computeDispatchTableEntryWithDebug(
|
||||
const c10::Dispatcher& dispatcher, DispatchKey dispatch_key
|
||||
|
||||
@ -17,9 +17,23 @@ const std::vector<Argument>& FunctionSchema::getCorrectList(SchemaArgType type)
|
||||
}
|
||||
}
|
||||
|
||||
FunctionSchema FunctionSchema::cloneWithRealTypes() const {
|
||||
auto cloneWithRealTypes = [](const Argument& a) {
|
||||
return a.cloneWithType(a.real_type());
|
||||
FunctionSchema FunctionSchema::cloneWithRealTypes(bool with_symint) const {
|
||||
auto cloneWithRealTypes = [&](const Argument& a) {
|
||||
if (with_symint) {
|
||||
return a.cloneWithType(a.real_type());
|
||||
}
|
||||
// Don't use real type if it looks like a SymInt
|
||||
// NB: keep this in sync with unpackSymInt in KernelFunction_impl.h
|
||||
if (
|
||||
*a.real_type() == *getTypePtr<c10::SymInt>() ||
|
||||
*a.real_type() == *getTypePtr<c10::optional<c10::SymInt>>() ||
|
||||
*a.real_type() == *getTypePtr<c10::SymIntArrayRef>()
|
||||
) {
|
||||
// Keep the fake type
|
||||
return a.cloneWithType(a.type());
|
||||
} else {
|
||||
return a.cloneWithType(a.real_type());
|
||||
}
|
||||
};
|
||||
std::vector<Argument> new_arguments, new_returns;
|
||||
std::transform(arguments().begin(), arguments().end(), std::back_inserter(new_arguments), cloneWithRealTypes);
|
||||
|
||||
@ -474,7 +474,7 @@ struct TORCH_API FunctionSchema {
|
||||
FunctionSchema cloneWithRemappedTypes(
|
||||
const std::function<TypePtr(TypePtr)> type_map) const;
|
||||
|
||||
FunctionSchema cloneWithRealTypes() const;
|
||||
FunctionSchema cloneWithRealTypes(bool with_symint=true) const;
|
||||
|
||||
// Check that inputs have the correct types and appends any missing default
|
||||
// values.
|
||||
|
||||
@ -29,13 +29,12 @@ Tensor _empty_affine_quantized(
|
||||
}
|
||||
|
||||
Tensor empty_memory_format(
|
||||
const SymIntArrayRef sym_sizes,
|
||||
const IntArrayRef sizes,
|
||||
const c10::optional<ScalarType> dtype,
|
||||
const c10::optional<c10::Layout> layout,
|
||||
const c10::optional<Device> device,
|
||||
const c10::optional<bool> pin_memory,
|
||||
const optional<MemoryFormat> memory_format) {
|
||||
auto sizes = c10::asIntArrayRefSlow(sym_sizes);
|
||||
return convert(vTensor{
|
||||
api::context(),
|
||||
sizes,
|
||||
@ -56,12 +55,7 @@ Tensor empty_strided(
|
||||
const optional<Device> device,
|
||||
const optional<bool> pin_memory) {
|
||||
return empty_memory_format(
|
||||
c10::SymIntArrayRef::fromIntArrayRef(sizes),
|
||||
dtype,
|
||||
layout,
|
||||
device,
|
||||
pin_memory,
|
||||
c10::MemoryFormat::Contiguous);
|
||||
sizes, dtype, layout, device, pin_memory, c10::MemoryFormat::Contiguous);
|
||||
}
|
||||
|
||||
#ifdef USE_VULKAN_API
|
||||
|
||||
@ -42,8 +42,7 @@ Tensor view_internal(const Tensor& self_arg, const IntArrayRef shape) {
|
||||
return convert(v_output);
|
||||
}
|
||||
|
||||
inline Tensor view(const Tensor& self_arg, const SymIntArrayRef sym_shape) {
|
||||
auto shape = c10::asIntArrayRefSlow(sym_shape);
|
||||
inline Tensor view(const Tensor& self_arg, IntArrayRef shape) {
|
||||
return view_internal(self_arg, shape);
|
||||
}
|
||||
|
||||
|
||||
@ -439,12 +439,6 @@ std::tuple<Tensor, optional<int64_t>> view_batching_rule(
|
||||
return std::make_tuple(self_.view_symint(size_), 0);
|
||||
}
|
||||
|
||||
Tensor view_symint_decomposition(const Tensor& self,
|
||||
c10::SymIntArrayRef size) {
|
||||
return self.view( c10::asIntArrayRefSlow(size));
|
||||
}
|
||||
|
||||
|
||||
template <typename F, F Func>
|
||||
std::tuple<Tensor, optional<int64_t>> expand_batch_rule(
|
||||
const Tensor &self, optional<int64_t> self_bdim, SymIntArrayRef size, bool implicit)
|
||||
@ -512,14 +506,6 @@ std::tuple<Tensor, optional<int64_t>> diag_embed_batch_rule(const Tensor& self,
|
||||
return std::make_tuple(at::diag_embed(self_, offset, dim1, dim2), 0);
|
||||
}
|
||||
|
||||
// We need to write a real batching rule to fully support symint.
|
||||
// This requires symint variants of other operations, like `view`,
|
||||
// which don't exist yet.
|
||||
Tensor expand_symint_decomp_hack(const Tensor& self, SymIntArrayRef packed_size, bool implicit) {
|
||||
auto size = asIntArrayRefSlow(packed_size);
|
||||
return self.expand(size, implicit);
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
|
||||
VMAP_SUPPORT(diag, diag_batch_rule);
|
||||
VMAP_SUPPORT(chunk, chunk_batching_rule);
|
||||
|
||||
@ -49,9 +49,9 @@ at::Tensor custom_empty_memory_format(at::IntArrayRef size, c10::optional<at::Sc
|
||||
constexpr c10::DispatchKeySet private_use_ks(c10::DispatchKey::PrivateUse1);
|
||||
return at::detail::empty_generic(size, &global_custom_alloc, private_use_ks, c10::dtype_or_default(dtype), memory_format);
|
||||
}
|
||||
at::Tensor custom_empty_symint(c10::SymIntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory, c10::optional<at::MemoryFormat> memory_format) {
|
||||
at::Tensor custom_empty_symint(c10::IntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory, c10::optional<at::MemoryFormat> memory_format) {
|
||||
constexpr c10::DispatchKeySet private_use_ks(c10::DispatchKey::PrivateUse1);
|
||||
return at::detail::empty_generic(c10::asIntArrayRefSlow(size), &global_custom_alloc, private_use_ks, c10::dtype_or_default(dtype), memory_format);
|
||||
return at::detail::empty_generic(size, &global_custom_alloc, private_use_ks, c10::dtype_or_default(dtype), memory_format);
|
||||
}
|
||||
|
||||
at::Tensor & custom_fill__scalar(at::Tensor & self, const at::Scalar & value) {
|
||||
|
||||
@ -20,10 +20,10 @@ Tensor get_tensor(caffe2::TypeMeta dtype, IntArrayRef size) {
|
||||
return Tensor(std::move(tensor_impl));
|
||||
}
|
||||
|
||||
Tensor empty_override(SymIntArrayRef size, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device,
|
||||
Tensor empty_override(IntArrayRef size, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device,
|
||||
c10::optional<bool> pin_memory, c10::optional<c10::MemoryFormat> optional_memory_format) {
|
||||
test_int = 0;
|
||||
return get_tensor(scalarTypeToTypeMeta(dtype_or_default(dtype)), c10::asIntArrayRefSlow(size));
|
||||
return get_tensor(scalarTypeToTypeMeta(dtype_or_default(dtype)), size);
|
||||
}
|
||||
|
||||
Tensor& add_out_override(const Tensor & a, const Tensor & b , const Scalar& c, Tensor & out) {
|
||||
|
||||
Reference in New Issue
Block a user