Make dispatcher registrations of SymInt functions backwards compatible (#84557)

Previously, when we SymInt-ified a schema, it was a BC-breaking change
for everyone who had registered a kernel for that operator: they had to
accept c10::SymInt where they previously accepted int64_t.  This is not
great.

With this change, I transparently accept registrations that use the old
types.  The idea is in several parts:

- At the registration site, I have no way of knowing at compile time
  whether the function being registered is for a SymInt schema or not,
  so I must defer the exact compatibility check.  What I do instead is
  check whether the function pointer registered to me has SymInt in its
  arguments.  If it does, I assume it is new-style and ensure it is also
  registered to a special sym_ slot on KernelFunction.  If not, it only
  goes in the conventional slot.  (A minimal standalone sketch of this
  compile-time check appears after this list.)

- At the dispatch site, I know at compile time whether or not this is
  a SymInt function.  If it is, I check for a sym_ slot on the
  KernelFunction and preferentially use that.  If no such slot exists,
  I fall back to the regular slot, converting every SymInt argument to
  an int64_t argument (and asserting that no true symbolic integer was
  passed).  I can skip this test entirely if the function doesn't have
  any SymInts in it; in that case I know that only the original slot
  could have been registered.  Fortunately, both branches of the short
  circuit typecheck, so I didn't have to use SFINAE or if-constexpr to
  make it work; a plain if statement that I expect the compiler to
  optimize away is enough.  (The second sketch after this list
  illustrates the argument unpacking.)

- Schema validation is now modestly more complicated.  There are two
  parts.  First, function schema validation checks whether the
  signature in question has any SymInt-like types in it.  If it does,
  we validate against the real types; if it doesn't, we validate
  against the fake types (but only for SymInt; MemoryFormat is always
  MemoryFormat).  Second, cpp signature validation now keeps track of
  both a "symint" cpp signature and a "non-symint" cpp signature.  We
  only compare symint with symint, and non-symint with non-symint.  I
  did not implement checking for a conflict between a symint and a
  non-symint cpp signature, though in principle you could try
  converting the SymInt types to non-SymInt types and doing the
  comparison that way.  (The third sketch after this list shows which
  signature each kind of kernel is checked against.)
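
To make the first bullet concrete, here is a minimal standalone sketch
of the registration-side compile-time check.  It uses a stand-in SymInt
type and a simplified fn_has_symint trait rather than the real
c10::SymInt and guts::typelist machinery; the actual traits are in the
KernelFunction.h hunk below.

```cpp
#include <cstdint>
#include <type_traits>

// Stand-in for c10::SymInt; the real trait also covers SymIntArrayRef
// and optional<SymInt>.
struct SymInt { int64_t value; };

// True if any parameter of the plain function type F is SymInt.
template <typename F>
struct fn_has_symint;

template <typename Ret, typename... Args>
struct fn_has_symint<Ret(Args...)>
    : std::disjunction<std::is_same<SymInt, std::decay_t<Args>>...> {};

// One old-style and one new-style kernel (names are made up).
int64_t old_kernel(int64_t x) { return x; }
int64_t new_kernel(SymInt x) { return x.value; }

// The registration site only needs the function's type to decide which
// slot (conventional vs. sym_) the kernel belongs in.
static_assert(!fn_has_symint<decltype(old_kernel)>::value,
              "no SymInt arguments -> conventional slot");
static_assert(fn_has_symint<decltype(new_kernel)>::value,
              "SymInt argument -> sym_ slot");

int main() { return 0; }
```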
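
The second sketch covers the dispatch-side fallback from the second
bullet: when only an old-style kernel is available, each SymInt
argument is collapsed to a plain integer, asserting it is not truly
symbolic.  Again this uses a stand-in SymInt and plain overloading
instead of the explicit unpackSymInt specializations in
KernelFunction_impl.h; the kernel name is made up.

```cpp
#include <cassert>
#include <cstdint>

// Stand-in for c10::SymInt: either a concrete integer or a symbolic one.
struct SymInt {
  int64_t value;
  bool is_symbolic;
  // Mirrors SymInt::expect_int(): only legal for concrete values.
  int64_t expect_int() const { assert(!is_symbolic); return value; }
};

// Generic case: non-SymInt arguments pass through untouched.
template <typename T>
T unpackSymInt(T x) { return x; }

// SymInt collapses to int64_t, failing loudly if it was truly symbolic.
int64_t unpackSymInt(SymInt x) { return x.expect_int(); }

// An "old-style" kernel that still takes int64_t.
int64_t old_kernel(int64_t size, bool flag) { return flag ? size : 0; }

int main() {
  SymInt s{42, /*is_symbolic=*/false};
  // The dispatcher's fallback amounts to this argument-wise rewrite.
  int64_t out = old_kernel(unpackSymInt(s), unpackSymInt(true));
  assert(out == 42);
  return 0;
}
```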
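
The third sketch illustrates the schema-validation point: for an
operator whose schema uses SymInt, a new-style kernel is checked
against the real (SymInt) C++ types while an old-style kernel is
checked against the fake (int) ones.  The schema string and kernel
names are illustrative only, and the declarations assume ATen headers.

```cpp
#include <ATen/ATen.h>

// Declared schema (illustrative):
//   aten::expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a)

// A new-style kernel is validated against the real (SymInt) types and
// recorded in sym_cpp_signature_ ...
at::Tensor expand_sym_kernel(const at::Tensor& self,
                             c10::SymIntArrayRef size,
                             bool implicit);

// ... while an old-style kernel is validated against the fake (int)
// types and recorded in cpp_signature_.  The two flavors are never
// compared against each other.
at::Tensor expand_int_kernel(const at::Tensor& self,
                             at::IntArrayRef size,
                             bool implicit);
```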

To show it is working, I remove a bunch of c10::asIntArrayRefSlow
shims, as the dispatcher now performs that conversion automatically.
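
For example, a kernel that is still written against IntArrayRef can now
be registered directly for a SymInt-ified operator, with no hand-written
*_symint wrapper.  The operator, dispatch key, kernel name, and kernel
body below are purely illustrative; the pattern mirrors the Batched-key
changes in this diff.

```cpp
#include <ATen/ATen.h>
#include <torch/library.h>

namespace {

// Old-style kernel: still written against IntArrayRef, even though the
// aten::expand schema now takes SymInt[].  The body is a stub.
at::Tensor my_expand_kernel(const at::Tensor& self,
                            at::IntArrayRef size,
                            bool implicit) {
  (void)size;
  (void)implicit;
  return self;  // real logic elided; this only illustrates registration
}

}  // namespace

// Previously this would have needed a *_symint wrapper calling
// c10::asIntArrayRefSlow by hand.  Now the old-style pointer can be
// passed to m.impl directly: it has no SymInt arguments, so it lands in
// the conventional slot, and the dispatcher converts SymInt arguments
// to int64_t at call time.
TORCH_LIBRARY_IMPL(aten, CPU, m) {
  m.impl("expand", my_expand_kernel);
}
```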

I didn't update the Metal registrations (though they can get similar treatment) as OSS CI coverage is insufficient for this case.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Differential Revision: [D39280965](https://our.internmc.facebook.com/intern/diff/D39280965)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84557
Approved by: https://github.com/wconstab

View File

@@ -185,11 +185,6 @@ Tensor expand_batching_rule(const Tensor& self, IntArrayRef size, bool implicit)
return self_physical.getPhysicalToLogicalMap().apply(result);
}
Tensor expand_symint_batching_rule(const Tensor& self, SymIntArrayRef psize, bool implicit) {
// TODO: properly support this
return expand_batching_rule(self, asIntArrayRefSlow(psize), implicit);
}
std::vector<Tensor> chunk_batching_rule(const Tensor& self, int64_t chunks, int64_t dim) {
auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
auto dim_physical = self_physical.getPhysicalDim(dim);
@@ -469,11 +464,6 @@ Tensor view_batching_rule(const Tensor& self, IntArrayRef size) {
return self_physical.getPhysicalToLogicalMap().apply(result);
}
Tensor view_symint_batching_rule(const Tensor& self, c10::SymIntArrayRef size) {
// TODO: properly support this
return view_batching_rule(self, asIntArrayRefSlow(size));
}
Tensor view_as_complex_batching_rule(const Tensor& self) {
// guard against the user passing in a batch of scalar tensors with batch
// size equal to 2.
@@ -1004,17 +994,6 @@ Tensor new_empty_batching_rule(
return physical_view.getPhysicalToLogicalMap().apply(result);
}
Tensor new_empty_symint_batching_rule(
const Tensor& self,
c10::SymIntArrayRef size,
c10::optional<ScalarType> dtype,
c10::optional<Layout> layout,
c10::optional<Device> device,
c10::optional<bool> pin_memory) {
// TODO: properly support this
return new_empty_batching_rule(self, asIntArrayRefSlow(size), dtype, layout, device, pin_memory);
}
Tensor new_empty_strided_batching_rule(
const Tensor& self,
IntArrayRef size,
@@ -1112,7 +1091,7 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
m.impl("tensor_split.sections", tensor_split_sections_batching_rule);
m.impl("tensor_split.indices", tensor_split_indices_batching_rule);
m.impl("diagonal", diagonal_batching_rule);
m.impl("expand", expand_symint_batching_rule);
m.impl("expand", expand_batching_rule);
m.impl("expand_as", native::expand_as); // composite wrt autograd
m.impl("movedim.intlist", movedim_batching_rule);
m.impl("movedim.int", static_cast<Tensor(*)(const Tensor&,int64_t,int64_t)>(native::movedim)); // composite wrt autograd
@@ -1140,7 +1119,7 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
m.impl("unbind.int", unbind_batching_rule);
m.impl("unfold", unfold_batching_rule);
m.impl("unsqueeze", unsqueeze_batching_rule);
m.impl("view", view_symint_batching_rule);
m.impl("view", view_batching_rule);
m.impl("view_as", native::view_as); // composite wrt autograd
// clamp operations
@@ -1278,7 +1257,7 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
m.impl("diagonal_backward", diagonal_backward_batching_rule);
// Tensor.new_* operators
m.impl("new_empty", new_empty_symint_batching_rule);
m.impl("new_empty", new_empty_batching_rule);
m.impl("new_empty_strided", new_empty_strided_batching_rule);
m.impl("new_zeros", new_zeros_batching_rule);

View File

@@ -14,6 +14,40 @@ class OperatorHandle;
struct OperatorKernel;
class KernelFunction;
template <typename T>
using has_symint =
guts::disjunction<
std::is_same<c10::SymInt, std::decay_t<T>>,
std::is_same<c10::SymIntArrayRef, std::decay_t<T>>,
std::is_same<c10::optional<c10::SymInt>, std::decay_t<T>>
>;
template <typename T>
struct remove_symint {
using type = T;
};
template <>
struct remove_symint<c10::SymInt> {
using type = int64_t;
};
template <>
struct remove_symint<c10::SymIntArrayRef> {
using type = c10::IntArrayRef;
};
template <>
struct remove_symint<c10::optional<c10::SymInt>> {
using type = c10::optional<int64_t>;
};
template <typename T>
using fn_has_symint = typename guts::typelist::true_for_any_type<
has_symint,
typename guts::infer_function_traits<T>::type::parameter_types
>;
/**
* KernelFunction is similar to std::function but stores a kernel function.
* You can create a KernelFunction from a boxed or unboxed function/functor/lambda
@@ -31,6 +65,7 @@ public:
// Fast path for dispatch to allow not touching the boxed kernel in
// the common case where unboxed is available.
bool isValidUnboxed() const;
bool isValidSymUnboxed() const;
bool isValid() const;
bool isFallthrough() const;
@@ -182,13 +217,16 @@ private:
explicit KernelFunction(
std::unique_ptr<OperatorKernel> functor,
InternalBoxedKernelFunction* boxed_kernel_func,
void* unboxed_kernel_func);
void* unboxed_kernel_func,
void* sym_unboxed_kernel_func);
explicit KernelFunction(
BoxedKernel boxed_fn,
void* unboxed_kernel_func);
void* unboxed_kernel_func,
void* sym_unboxed_kernel_func);
BoxedKernel boxed_kernel_func_;
void* unboxed_kernel_func_;
void* sym_unboxed_kernel_func_;
};
}

View File

@@ -8,22 +8,29 @@ namespace c10 {
inline KernelFunction::KernelFunction()
: boxed_kernel_func_()
, unboxed_kernel_func_(nullptr)
, sym_unboxed_kernel_func_(nullptr)
{}
inline KernelFunction::KernelFunction(std::unique_ptr<OperatorKernel> functor, InternalBoxedKernelFunction* boxed_kernel_func, void* unboxed_kernel_func)
inline KernelFunction::KernelFunction(std::unique_ptr<OperatorKernel> functor, InternalBoxedKernelFunction* boxed_kernel_func, void* unboxed_kernel_func, void* sym_unboxed_kernel_func = nullptr)
: boxed_kernel_func_(std::move(functor), boxed_kernel_func)
, unboxed_kernel_func_(unboxed_kernel_func)
, sym_unboxed_kernel_func_(sym_unboxed_kernel_func)
{}
inline KernelFunction::KernelFunction(BoxedKernel boxed_fn, void* unboxed_kernel_func)
inline KernelFunction::KernelFunction(BoxedKernel boxed_fn, void* unboxed_kernel_func, void* sym_unboxed_kernel_func = nullptr)
: boxed_kernel_func_(std::move(boxed_fn))
, unboxed_kernel_func_(unboxed_kernel_func)
, sym_unboxed_kernel_func_(sym_unboxed_kernel_func)
{}
inline bool KernelFunction::isValidUnboxed() const {
return unboxed_kernel_func_ != nullptr;
}
inline bool KernelFunction::isValidSymUnboxed() const {
return sym_unboxed_kernel_func_ != nullptr;
}
inline bool KernelFunction::isValid() const {
return boxed_kernel_func_.isValid();
}
@@ -43,16 +50,52 @@ inline Return callUnboxedKernelFunction(void* unboxed_kernel_func, OperatorKerne
return (*func)(functor, dispatchKeySet, std::forward<Args>(args)...);
}
// This template requires you to explicitly specify the argument you want to
// forward; it doesn't work if you try to deduce it
template <typename T>
inline typename remove_symint<T>::type unpackSymInt(T x) { return x; }
template <>
inline typename remove_symint<c10::SymInt>::type unpackSymInt(c10::SymInt x) {
return x.expect_int();
}
template <>
inline typename remove_symint<c10::SymIntArrayRef>::type unpackSymInt(c10::SymIntArrayRef x) {
return c10::asIntArrayRefSlow(x);
}
template <>
inline typename remove_symint<c10::optional<c10::SymInt>>::type unpackSymInt(c10::optional<c10::SymInt> x) {
return x.has_value() ? c10::make_optional(x->expect_int()) : c10::nullopt;
}
template<class Return, class... Args>
C10_ALWAYS_INLINE Return KernelFunction::call(const OperatorHandle& opHandle, DispatchKeySet dispatchKeySet, Args... args) const {
// note: Args above is intentionally not Args&&. We don't want perfect
// forwarding, which would require Args to be deduced, but instead we
// want callers to explicitly specify the Args.
if (C10_LIKELY(unboxed_kernel_func_ != nullptr)) {
auto *functor = boxed_kernel_func_.getFunctor();
return callUnboxedKernelFunction<Return, Args...>(
unboxed_kernel_func_, functor, dispatchKeySet, std::forward<Args>(args)...);
// This should get inlined by compiler
if (guts::disjunction<has_symint<Args>...>::value) {
if (sym_unboxed_kernel_func_ != nullptr) {
auto *functor = boxed_kernel_func_.getFunctor();
return callUnboxedKernelFunction<Return, Args...>(
sym_unboxed_kernel_func_, functor, dispatchKeySet, std::forward<Args>(args)...);
}
if (unboxed_kernel_func_ != nullptr) {
auto *functor = boxed_kernel_func_.getFunctor();
return callUnboxedKernelFunction<Return, typename remove_symint<Args>::type...>(
unboxed_kernel_func_, functor, dispatchKeySet, unpackSymInt<Args>(args)...);
}
} else {
if (C10_LIKELY(unboxed_kernel_func_ != nullptr)) {
auto *functor = boxed_kernel_func_.getFunctor();
return callUnboxedKernelFunction<Return, Args...>(
unboxed_kernel_func_, functor, dispatchKeySet, std::forward<Args>(args)...);
}
}
return impl::BoxedKernelWrapper<Return(Args...)>::call(
@@ -102,10 +145,14 @@ inline KernelFunction KernelFunction::makeFromUnboxedFunctor(std::unique_ptr<Ope
#endif
static_assert(std::is_base_of<OperatorKernel, KernelFunctor>::value, "Tried to call KernelFunction::makeFromUnboxedFunctor<KernelFunctor>, but the functor doesn't inherit from c10::OperatorKernel. Please have the functor inherit from it.");
auto* unboxed_fn = &impl::wrap_kernel_functor_unboxed<KernelFunctor>::call;
void* void_unboxed_fn = reinterpret_cast<void*>(unboxed_fn);
bool is_symint = fn_has_symint<decltype(unboxed_fn)>::value;
return KernelFunction(
std::move(kernelFunctor),
&impl::make_boxed_from_unboxed_functor<KernelFunctor, AllowLegacyTypes>::call,
reinterpret_cast<void*>(&impl::wrap_kernel_functor_unboxed<KernelFunctor>::call)
is_symint ? nullptr : void_unboxed_fn,
is_symint ? void_unboxed_fn : nullptr
);
}

View File

@@ -26,6 +26,7 @@ OperatorEntry::OperatorEntry(OperatorName&& operator_name)
, dispatchKeyExtractor_(DispatchKeyExtractor::makeUninitialized())
, kernels_()
, cpp_signature_()
, sym_cpp_signature_()
, is_observed_(ObservedOperators::isObserved(name_))
{
// Pick up any backend fallbacks that were registered prior to this
@@ -34,12 +35,11 @@ OperatorEntry::OperatorEntry(OperatorName&& operator_name)
}
namespace {
void checkSchema(const OperatorName& name, const FunctionSchema& from_def, const std::string& from_def_debug, const FunctionSchema& inferred, const std::string& inferred_debug) {
void checkSchema(const OperatorName& name, const FunctionSchema& from_def_, const std::string& from_def_debug, const KernelFunction& kernel, const FunctionSchema& inferred_, const std::string& inferred_debug) {
// TODO: figure out if we can just directly save real schema at def time
c10::optional<std::string> schema_difference = findSchemaDifferences(
from_def.cloneWithRealTypes(),
inferred.cloneWithRealTypes()
);
FunctionSchema from_def = from_def_.cloneWithRealTypes(kernel.isValidSymUnboxed());
FunctionSchema inferred = inferred_.cloneWithRealTypes();
c10::optional<std::string> schema_difference = findSchemaDifferences(from_def, inferred);
if (schema_difference.has_value()) {
TORCH_CHECK(false,
"Inferred operator schema for a C++ kernel function doesn't match the expected function schema.\n"
@@ -64,12 +64,24 @@ const AnnotatedKernel& OperatorEntry::ambiguousAutogradOtherKernel() const {
return kernel;
}
void OperatorEntry::assertSignatureIsCorrect(const CppSignature call_signature, bool has_symint) const {
if (has_symint) {
if (C10_UNLIKELY(sym_cpp_signature_.has_value() && (call_signature != sym_cpp_signature_->signature))) {
reportSignatureError(call_signature, *sym_cpp_signature_);
}
} else {
if (C10_UNLIKELY(cpp_signature_.has_value() && (call_signature != cpp_signature_->signature))) {
reportSignatureError(call_signature, *cpp_signature_);
}
}
}
void OperatorEntry::registerSchema(FunctionSchema&& schema, std::string&& debug, std::vector<at::Tag> tags) {
TORCH_INTERNAL_ASSERT(!schema_.has_value());
for (const auto& kernel : kernels_) {
for (const auto &j : kernel.second) {
if (j.inferred_function_schema != nullptr) {
checkSchema(name_, schema, debug, *j.inferred_function_schema, j.debug);
checkSchema(name_, schema, debug, j.kernel, *j.inferred_function_schema, j.debug);
}
}
}
@@ -103,25 +115,26 @@ OperatorEntry::AnnotatedKernelContainerIterator OperatorEntry::registerKernel(
// which means if you could validly change the type of a cpp_signature, then
// that would also invalidate the old TypedOperatorHandles.
if (cpp_signature.has_value()) {
if (cpp_signature_.has_value()) {
TORCH_CHECK(*cpp_signature == cpp_signature_->signature,
auto& local_cpp_signature = kernel.isValidSymUnboxed() ? sym_cpp_signature_ : cpp_signature_;
if (local_cpp_signature.has_value()) {
TORCH_CHECK(*cpp_signature == local_cpp_signature->signature,
"\nMismatch in kernel C++ signatures\n",
" operator: ", (this->schema_.has_value() ? toString(this->schema_->schema) : toString(name_)), "\n",
" ", (this->schema_.has_value() ? this->schema_->debug : "no debug info"), "\n",
" kernel 1: ", cpp_signature_->signature.name(), "\n",
" dispatch key: ", toString(cpp_signature_->dispatch_key), "\n",
" ", cpp_signature_->debug, "\n",
" kernel 1: ", local_cpp_signature->signature.name(), "\n",
" dispatch key: ", toString(local_cpp_signature->dispatch_key), "\n",
" ", local_cpp_signature->debug, "\n",
" kernel 2: ", cpp_signature->name(), "\n",
" dispatch key: ", toString(dispatch_key), "\n",
" ", debug, "\n"
);
} else {
cpp_signature_ = CppSignatureWithDebug { *cpp_signature, debug, dispatch_key };
local_cpp_signature = CppSignatureWithDebug { *cpp_signature, debug, dispatch_key };
}
}
if (schema_ && inferred_function_schema) {
checkSchema(name_, schema_->schema, schema_->debug, *inferred_function_schema, debug);
checkSchema(name_, schema_->schema, schema_->debug, kernel, *inferred_function_schema, debug);
}
// Add the kernel to the kernels list,
@@ -138,7 +151,7 @@ OperatorEntry::AnnotatedKernelContainerIterator OperatorEntry::registerKernel(
" operator: ", (schema_.has_value() ? toString(schema_->schema) : toString(name_)), "\n",
" ", (this->schema_.has_value() ? this->schema_->debug : "no debug info"), "\n",
" dispatch key: ", toString(dispatch_key), "\n",
" previous kernel: ", (cpp_signature_.has_value() ? cpp_signature_->debug : "no debug info"), "\n",
" previous kernel: ", (cpp_signature_.has_value() ? cpp_signature_->debug : (sym_cpp_signature_.has_value() ? sym_cpp_signature_->debug : "no debug info")), "\n",
" new kernel: ", debug
);
}
@@ -471,13 +484,13 @@ std::string OperatorEntry::listAllDispatchKeys() const {
return str.str();
}
void OperatorEntry::reportSignatureError(const CppSignature call_signature) const {
void OperatorEntry::reportSignatureError(const CppSignature& call_signature, const CppSignatureWithDebug& saved_signature) const {
TORCH_CHECK(false,
"\nTried to access or call an operator with a wrong signature.\n",
" operator: ", (schema_.has_value() ? toString(schema_->schema) : toString(name_)), "\n",
" ", (schema_.has_value() ? schema_->debug : "unknown debug info"), "\n",
" correct signature: ", cpp_signature_->signature.name(), "\n",
" ", cpp_signature_->debug, "\n",
" correct signature: ", saved_signature.signature.name(), "\n",
" ", saved_signature.debug, "\n",
" accessed/called as: ", call_signature.name(), "\n",
"This likely happened in a call to OperatorHandle::typed<Return (Args...)>(). ",
"Please make sure that the function signature matches the signature in the operator registration call."

View File

@@ -163,14 +163,10 @@ public:
// Asserts that the given FuncType is correct for calling this operator in an unboxed way.
template<class FuncType>
inline void assertSignatureIsCorrect() {
assertSignatureIsCorrect(CppSignature::make<FuncType>());
assertSignatureIsCorrect(CppSignature::make<FuncType>(), fn_has_symint<FuncType>::value);
}
void assertSignatureIsCorrect(const CppSignature call_signature) {
if (C10_UNLIKELY(cpp_signature_.has_value() && (call_signature != cpp_signature_->signature))) {
reportSignatureError(call_signature);
}
}
void assertSignatureIsCorrect(const CppSignature call_signature, bool has_symint) const;
[[noreturn]] void reportError(DispatchKey dispatchKey) const;
@@ -280,11 +276,12 @@ private:
c10::optional<DispatchKey> dispatch_key;
};
c10::optional<CppSignatureWithDebug> cpp_signature_;
c10::optional<CppSignatureWithDebug> sym_cpp_signature_;
// Whether this operator needs to be observed with RecordFunction
const bool is_observed_;
[[noreturn]] void reportSignatureError(CppSignature call_signature) const;
[[noreturn]] void reportSignatureError(const CppSignature& call_signature, const CppSignatureWithDebug& saved_signature) const;
const KernelFunction& computeDispatchTableEntry(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key) const;
std::pair<const AnnotatedKernel&, const char*> computeDispatchTableEntryWithDebug(
const c10::Dispatcher& dispatcher, DispatchKey dispatch_key

View File

@@ -17,9 +17,23 @@ const std::vector<Argument>& FunctionSchema::getCorrectList(SchemaArgType type)
}
}
FunctionSchema FunctionSchema::cloneWithRealTypes() const {
auto cloneWithRealTypes = [](const Argument& a) {
return a.cloneWithType(a.real_type());
FunctionSchema FunctionSchema::cloneWithRealTypes(bool with_symint) const {
auto cloneWithRealTypes = [&](const Argument& a) {
if (with_symint) {
return a.cloneWithType(a.real_type());
}
// Don't use real type if it looks like a SymInt
// NB: keep this in sync with unpackSymInt in KernelFunction_impl.h
if (
*a.real_type() == *getTypePtr<c10::SymInt>() ||
*a.real_type() == *getTypePtr<c10::optional<c10::SymInt>>() ||
*a.real_type() == *getTypePtr<c10::SymIntArrayRef>()
) {
// Keep the fake type
return a.cloneWithType(a.type());
} else {
return a.cloneWithType(a.real_type());
}
};
std::vector<Argument> new_arguments, new_returns;
std::transform(arguments().begin(), arguments().end(), std::back_inserter(new_arguments), cloneWithRealTypes);

View File

@@ -474,7 +474,7 @@ struct TORCH_API FunctionSchema {
FunctionSchema cloneWithRemappedTypes(
const std::function<TypePtr(TypePtr)> type_map) const;
FunctionSchema cloneWithRealTypes() const;
FunctionSchema cloneWithRealTypes(bool with_symint=true) const;
// Check that inputs have the correct types and appends any missing default
// values.

View File

@@ -29,13 +29,12 @@ Tensor _empty_affine_quantized(
}
Tensor empty_memory_format(
const SymIntArrayRef sym_sizes,
const IntArrayRef sizes,
const c10::optional<ScalarType> dtype,
const c10::optional<c10::Layout> layout,
const c10::optional<Device> device,
const c10::optional<bool> pin_memory,
const optional<MemoryFormat> memory_format) {
auto sizes = c10::asIntArrayRefSlow(sym_sizes);
return convert(vTensor{
api::context(),
sizes,
@@ -56,12 +55,7 @@ Tensor empty_strided(
const optional<Device> device,
const optional<bool> pin_memory) {
return empty_memory_format(
c10::SymIntArrayRef::fromIntArrayRef(sizes),
dtype,
layout,
device,
pin_memory,
c10::MemoryFormat::Contiguous);
sizes, dtype, layout, device, pin_memory, c10::MemoryFormat::Contiguous);
}
#ifdef USE_VULKAN_API

View File

@@ -42,8 +42,7 @@ Tensor view_internal(const Tensor& self_arg, const IntArrayRef shape) {
return convert(v_output);
}
inline Tensor view(const Tensor& self_arg, const SymIntArrayRef sym_shape) {
auto shape = c10::asIntArrayRefSlow(sym_shape);
inline Tensor view(const Tensor& self_arg, IntArrayRef shape) {
return view_internal(self_arg, shape);
}

View File

@@ -439,12 +439,6 @@ std::tuple<Tensor, optional<int64_t>> view_batching_rule(
return std::make_tuple(self_.view_symint(size_), 0);
}
Tensor view_symint_decomposition(const Tensor& self,
c10::SymIntArrayRef size) {
return self.view( c10::asIntArrayRefSlow(size));
}
template <typename F, F Func>
std::tuple<Tensor, optional<int64_t>> expand_batch_rule(
const Tensor &self, optional<int64_t> self_bdim, SymIntArrayRef size, bool implicit)
@@ -512,14 +506,6 @@ std::tuple<Tensor, optional<int64_t>> diag_embed_batch_rule(const Tensor& self,
return std::make_tuple(at::diag_embed(self_, offset, dim1, dim2), 0);
}
// We need to write a real batching rule to fully support symint.
// This requires symint variants of other operations, like `view`,
// which don't exist yet.
Tensor expand_symint_decomp_hack(const Tensor& self, SymIntArrayRef packed_size, bool implicit) {
auto size = asIntArrayRefSlow(packed_size);
return self.expand(size, implicit);
}
TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
VMAP_SUPPORT(diag, diag_batch_rule);
VMAP_SUPPORT(chunk, chunk_batching_rule);

View File

@@ -49,9 +49,9 @@ at::Tensor custom_empty_memory_format(at::IntArrayRef size, c10::optional<at::Sc
constexpr c10::DispatchKeySet private_use_ks(c10::DispatchKey::PrivateUse1);
return at::detail::empty_generic(size, &global_custom_alloc, private_use_ks, c10::dtype_or_default(dtype), memory_format);
}
at::Tensor custom_empty_symint(c10::SymIntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory, c10::optional<at::MemoryFormat> memory_format) {
at::Tensor custom_empty_symint(c10::IntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory, c10::optional<at::MemoryFormat> memory_format) {
constexpr c10::DispatchKeySet private_use_ks(c10::DispatchKey::PrivateUse1);
return at::detail::empty_generic(c10::asIntArrayRefSlow(size), &global_custom_alloc, private_use_ks, c10::dtype_or_default(dtype), memory_format);
return at::detail::empty_generic(size, &global_custom_alloc, private_use_ks, c10::dtype_or_default(dtype), memory_format);
}
at::Tensor & custom_fill__scalar(at::Tensor & self, const at::Scalar & value) {

View File

@@ -20,10 +20,10 @@ Tensor get_tensor(caffe2::TypeMeta dtype, IntArrayRef size) {
return Tensor(std::move(tensor_impl));
}
Tensor empty_override(SymIntArrayRef size, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device,
Tensor empty_override(IntArrayRef size, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device,
c10::optional<bool> pin_memory, c10::optional<c10::MemoryFormat> optional_memory_format) {
test_int = 0;
return get_tensor(scalarTypeToTypeMeta(dtype_or_default(dtype)), c10::asIntArrayRefSlow(size));
return get_tensor(scalarTypeToTypeMeta(dtype_or_default(dtype)), size);
}
Tensor& add_out_override(const Tensor & a, const Tensor & b , const Scalar& c, Tensor & out) {