mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
functionalization <> LTC integration (take 3) (#80251)
new PR for https://github.com/pytorch/pytorch/pull/75527. It looks like there's a bug in the windows CI scripts that was causing flaky failures, that disappear when I create a new PR. example failure: https://github.com/pytorch/pytorch/runs/6999272635?check_suite_focus=true Pull Request resolved: https://github.com/pytorch/pytorch/pull/80251 Approved by: https://github.com/wconstab
This commit is contained in:
committed by
PyTorch MergeBot
parent
33761c80d2
commit
c2d395cf8e
@ -22,6 +22,9 @@ void FunctionalTensorWrapper::set_constructor_metadata() {
|
||||
refresh_numel();
|
||||
refresh_contiguous();
|
||||
storage_access_should_throw_ = false;
|
||||
// In general, the sizes/stride metadata on a tensor can change as it is mutated,
|
||||
// and these changes need to be reflected in the metadata of the wrapper.
|
||||
set_allow_tensor_metadata_change(true);
|
||||
key_set_ = c10::DispatchKeySet(c10::DispatchKey::Functionalize) | value_.key_set();
|
||||
// All of the keys corresponding to functorch transforms should not be copied over.
|
||||
// Functorch transforms all have their own wrapper tensors (e.g. BatchedTensorImpl) which expect
|
||||
@ -180,8 +183,12 @@ void FunctionalTensorWrapper::replace_(const Tensor& other) {
|
||||
value_ = other;
|
||||
// out= ops are allowed to resize the output tensors, mutating both the data and metadata of the tensor.
|
||||
// We need to propagate that metadata mutation to the wrapper (new size).
|
||||
set_sizes_and_strides(value_.sizes(), value_.strides());
|
||||
set_storage_offset(value_.storage_offset());
|
||||
if (sizes() != value_.sizes() || strides() != value_.strides()) {
|
||||
set_sizes_and_strides(value_.sizes(), value_.strides());
|
||||
}
|
||||
if (storage_offset() != value_.storage_offset()) {
|
||||
set_storage_offset(value_.storage_offset());
|
||||
}
|
||||
if (dtype() != value_.unsafeGetTensorImpl()->dtype() || layout() != value_.unsafeGetTensorImpl()->layout()) {
|
||||
value_ = value_.to(c10::TensorOptions().dtype(dtype()).layout(layout()));
|
||||
}
|
||||
@ -260,6 +267,44 @@ const char* FunctionalTensorWrapper::tensorimpl_type_name() const {
|
||||
return "FunctionalTensorWrapper";
|
||||
}
|
||||
|
||||
template <typename VariableVersion>
|
||||
c10::intrusive_ptr<TensorImpl> FunctionalTensorWrapper::shallow_copy_and_detach_core(
|
||||
VariableVersion&& version_counter,
|
||||
bool allow_tensor_metadata_change) const {
|
||||
if (key_set_.has(DispatchKey::Python) &&
|
||||
!c10::impl::tls_is_dispatch_key_excluded(DispatchKey::Python)) {
|
||||
auto r = pyobj_interpreter_.load(std::memory_order_acquire)->detach(this);
|
||||
if (r) {
|
||||
r->set_version_counter(std::forward<VariableVersion>(version_counter));
|
||||
r->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
|
||||
return r;
|
||||
}
|
||||
}
|
||||
auto impl = c10::make_intrusive<FunctionalTensorWrapper>(value_);
|
||||
copy_tensor_metadata(
|
||||
/*src_impl=*/this,
|
||||
/*dest_impl=*/impl.get(),
|
||||
/*version_counter=*/std::forward<VariableVersion>(version_counter),
|
||||
/*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
|
||||
impl->refresh_numel();
|
||||
impl->refresh_contiguous();
|
||||
return impl;
|
||||
}
|
||||
|
||||
c10::intrusive_ptr<TensorImpl> FunctionalTensorWrapper::shallow_copy_and_detach(
|
||||
const c10::VariableVersion& version_counter,
|
||||
bool allow_tensor_metadata_change) const {
|
||||
return shallow_copy_and_detach_core(
|
||||
version_counter, allow_tensor_metadata_change);
|
||||
}
|
||||
|
||||
c10::intrusive_ptr<TensorImpl> FunctionalTensorWrapper::shallow_copy_and_detach(
|
||||
c10::VariableVersion&& version_counter,
|
||||
bool allow_tensor_metadata_change) const {
|
||||
return shallow_copy_and_detach_core(
|
||||
std::move(version_counter), allow_tensor_metadata_change);
|
||||
}
|
||||
|
||||
at::IntArrayRef FunctionalTensorWrapper::sizes_custom() const {
|
||||
return value_.unsafeGetTensorImpl()->sizes();
|
||||
}
|
||||
@ -275,6 +320,9 @@ int64_t FunctionalTensorWrapper::numel_custom() const {
|
||||
bool FunctionalTensorWrapper::is_contiguous_custom(at::MemoryFormat memory_format) const {
|
||||
return value_.unsafeGetTensorImpl()->is_contiguous();
|
||||
}
|
||||
c10::SymIntArrayRef FunctionalTensorWrapper::sym_sizes() const {
|
||||
return value_.unsafeGetTensorImpl()->sym_sizes();
|
||||
}
|
||||
c10::SymIntArrayRef FunctionalTensorWrapper::sym_sizes_custom() const {
|
||||
return value_.unsafeGetTensorImpl()->sym_sizes();
|
||||
}
|
||||
@ -329,7 +377,7 @@ std::vector<Tensor> to_functional_tensor(const TensorList& t_list) {
|
||||
|
||||
Tensor from_functional_tensor(const Tensor& tensor, bool assert_functional) {
|
||||
// Note [Wrapped Numbers <> Functionalization]
|
||||
if (tensor.unsafeGetTensorImpl()->is_wrapped_number()) {
|
||||
if (!tensor.defined() || tensor.unsafeGetTensorImpl()->is_wrapped_number()) {
|
||||
return tensor;
|
||||
}
|
||||
if (isFunctionalTensor(tensor)) {
|
||||
@ -454,41 +502,59 @@ bool isFunctionalTensor(const c10::optional<Tensor>& t) {
|
||||
|
||||
bool isFunctionalTensor(const c10::List<Tensor>& t_list) {
|
||||
if (t_list.size() == 0) return false;
|
||||
bool any_functional = isFunctionalTensor(t_list[0]);
|
||||
for (const auto i : c10::irange(1, t_list.size())) {
|
||||
auto curr_functional = isFunctionalTensor(t_list[i]);
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
curr_functional == any_functional,
|
||||
"Functionalization encountered a list of tensors where some are functional",
|
||||
"and some are not, which is not currently unsupported.");
|
||||
auto functional_count = 0;
|
||||
auto nonfunctional_count = 0;
|
||||
for (const auto i : c10::irange(t_list.size())) {
|
||||
if (!t_list[i].defined()) continue;
|
||||
if (isFunctionalTensor(t_list[i])) {
|
||||
++functional_count;
|
||||
} else {
|
||||
++nonfunctional_count;
|
||||
}
|
||||
}
|
||||
return any_functional;
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
functional_count == 0 || nonfunctional_count == 0,
|
||||
"Functionalization encountered a list of tensors where some are functional",
|
||||
"and some are not, which is not currently unsupported.");
|
||||
return functional_count > 0;
|
||||
}
|
||||
|
||||
bool isFunctionalTensor(const c10::List<c10::optional<Tensor>>& t_list) {
|
||||
if (t_list.size() == 0) return false;
|
||||
bool any_functional = isFunctionalTensor(t_list[0]);
|
||||
for (const auto i : c10::irange(1, t_list.size())) {
|
||||
auto curr_functional = isFunctionalTensor(t_list[i]);
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
curr_functional == any_functional,
|
||||
"Functionalization encountered a list of tensors where some are functional",
|
||||
"and some are not, which is not currently unsupported.");
|
||||
auto functional_count = 0;
|
||||
auto nonfunctional_count = 0;
|
||||
for (const auto i : c10::irange(t_list.size())) {
|
||||
if (!t_list[i].has_value() || !t_list[i]->defined()) continue;
|
||||
if (isFunctionalTensor(t_list[i])) {
|
||||
++functional_count;
|
||||
} else {
|
||||
++nonfunctional_count;
|
||||
}
|
||||
}
|
||||
return any_functional;
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
functional_count == 0 || nonfunctional_count == 0,
|
||||
"Functionalization encountered a list of tensors where some are functional",
|
||||
"and some are not, which is not currently unsupported.");
|
||||
return functional_count > 0;
|
||||
}
|
||||
|
||||
bool isFunctionalTensor(const c10::ArrayRef<Tensor> t_list) {
|
||||
if (t_list.size() == 0) return false;
|
||||
bool any_functional = isFunctionalTensor(t_list[0]);
|
||||
for (const auto i : c10::irange(1, t_list.size())) {
|
||||
auto curr_functional = isFunctionalTensor(t_list[i]);
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
curr_functional == any_functional,
|
||||
"Functionalization encountered a list of tensors where some are functional",
|
||||
"and some are not, which is not currently unsupported.");
|
||||
auto functional_count = 0;
|
||||
auto nonfunctional_count = 0;
|
||||
for (const auto i : c10::irange(t_list.size())) {
|
||||
if (!t_list[i].defined()) continue;
|
||||
if (isFunctionalTensor(t_list[i])) {
|
||||
++functional_count;
|
||||
} else {
|
||||
++nonfunctional_count;
|
||||
}
|
||||
}
|
||||
return any_functional;
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
functional_count == 0 || nonfunctional_count == 0,
|
||||
"Functionalization encountered a list of tensors where some are functional",
|
||||
"and some are not, which is not currently unsupported.");
|
||||
return functional_count > 0;
|
||||
}
|
||||
|
||||
Tensor create_functional_tensor_with_view_meta(const at::Tensor& view_to_wrap, const at::Tensor& base, functionalization::ViewMeta meta, int64_t out_idx) {
|
||||
@ -552,5 +618,93 @@ void setFunctionalizationReapplyViewsTLS(bool reapply_views) {
|
||||
}
|
||||
|
||||
} // namespace impl
|
||||
|
||||
|
||||
// Given an **out-of-place** op that might internally call view/inplace ops,
|
||||
// This function will "functionalize" it.
|
||||
// That is, it will call the operator, but removing any intermediate views/mutations
|
||||
// that are performed inside of it.
|
||||
// This is useful for LTC/XLA, which would like to re-use some of our composite kernels
|
||||
// from pytorch core but not have to worry about the view ops that they might call.
|
||||
// e.g. at::block_diag
|
||||
void functionalize_op_helper(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
|
||||
const auto& schema = op.schema();
|
||||
const auto num_arguments = schema.arguments().size();
|
||||
const auto arguments_begin = stack->size() - num_arguments;
|
||||
auto arguments = torch::jit::last(stack, num_arguments);
|
||||
|
||||
// Wrap all tensor-like inputs into FunctionalTensorWrappers.
|
||||
// When we re-invoke the dispatcher, this will automatically enable the functionalization pass.
|
||||
for (uint64_t idx = 0; idx < num_arguments; ++idx) {
|
||||
const auto& ivalue = arguments[idx];
|
||||
if (ivalue.isTensor()) {
|
||||
auto t = ivalue.toTensor();
|
||||
if (t.defined()) {
|
||||
TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t),
|
||||
"The composite op functionalization fallback expects its inputs all not to be functional tensors");
|
||||
auto t_new = c10::IValue(at::functionalization::impl::to_functional_tensor(t));
|
||||
(*stack)[arguments_begin + idx] = t_new;
|
||||
}
|
||||
} else if (ivalue.isTensorList()) {
|
||||
auto tensors = ivalue.toTensorList();
|
||||
TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(tensors),
|
||||
"The composite op functionalization fallback expects its inputs all not to be functional tensors");
|
||||
auto t_new = c10::IValue(at::functionalization::impl::to_functional_tensor(tensors));
|
||||
(*stack)[arguments_begin + idx] = t_new;
|
||||
} else if (ivalue.isOptionalTensorList()) {
|
||||
auto opt_tensors = ivalue.toOptionalTensorList();
|
||||
TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(opt_tensors),
|
||||
"The composite op functionalization fallback expects its inputs all not to be functional tensors");
|
||||
auto t_new = c10::IValue(at::functionalization::impl::to_functional_tensor(opt_tensors));
|
||||
(*stack)[arguments_begin + idx] = t_new;
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
// Today when you call at::empty(device=lazy), the lazy backend decides whether or not to wrap
|
||||
// the output in a functional tensor based on TLS.
|
||||
// In this code, we're re-entrantly entering functionalization in the same call-stack,
|
||||
// so we need to manually fix up TLS as if it hadn't already been called.
|
||||
auto curr_tls = c10::impl::tls_local_dispatch_key_set();
|
||||
auto tls_reenable_functionalize = c10::impl::PODLocalDispatchKeySet();
|
||||
tls_reenable_functionalize.set_included(curr_tls.included_);
|
||||
tls_reenable_functionalize.set_excluded(curr_tls.excluded_.remove(c10::DispatchKey::Functionalize));
|
||||
c10::impl::ForceDispatchKeyGuard guard_(tls_reenable_functionalize);
|
||||
// So, we should probably provide a way to directly call a kernel registered to
|
||||
// the `CompositeExplicitAutograd` key.
|
||||
// We can't do that today, so this should be a reasonably good proxy
|
||||
// (It won't work in cases where an op has both a CompositeExplicitAutograd kernel
|
||||
// AND a dedicated meta kernel, but that probably shouldn't ever happen).
|
||||
op.redispatchBoxed(c10::DispatchKeySet(c10::DispatchKey::Meta), stack);
|
||||
}
|
||||
|
||||
const auto num_returns = schema.returns().size();
|
||||
const auto returns_begin = stack->size() - num_returns;
|
||||
auto returns = torch::jit::last(stack, num_returns);
|
||||
|
||||
for (const auto idx : c10::irange(num_returns)) {
|
||||
const auto& ivalue = returns[idx];
|
||||
if (ivalue.isTensor()) {
|
||||
auto t = ivalue.toTensor();
|
||||
if (!t.defined()) continue;
|
||||
at::functionalization::impl::sync(t);
|
||||
auto t_new = c10::IValue(at::functionalization::impl::from_functional_tensor(t));
|
||||
(*stack)[returns_begin + idx] = t_new;
|
||||
} else if (ivalue.isTensorList()) {
|
||||
auto tensors = ivalue.toTensorList();
|
||||
at::functionalization::impl::sync(tensors);
|
||||
auto t_new = c10::IValue(at::functionalization::impl::from_functional_tensor(tensors));
|
||||
(*stack)[returns_begin + idx] = t_new;
|
||||
} else if (ivalue.isOptionalTensorList()) {
|
||||
auto opt_tensors = ivalue.toOptionalTensorList();
|
||||
at::functionalization::impl::sync(opt_tensors);
|
||||
auto t_new = c10::IValue(at::functionalization::impl::from_functional_tensor(opt_tensors));
|
||||
(*stack)[returns_begin + idx] = t_new;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
} // namespace functionalization
|
||||
} // namespace at
|
||||
|
@ -4,6 +4,8 @@
|
||||
#include <ATen/ArrayRef.h>
|
||||
#include <ATen/FunctionalStorageImpl.h>
|
||||
#include <ATen/core/List.h>
|
||||
#include <ATen/core/boxing/KernelFunction.h>
|
||||
#include <ATen/core/dispatch/Dispatcher.h>
|
||||
|
||||
#include <c10/core/DispatchKey.h>
|
||||
|
||||
@ -120,6 +122,14 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
|
||||
// See Note[resize_() in functionalization pass]
|
||||
void maybe_replace_storage(const Tensor& other);
|
||||
|
||||
c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
|
||||
const c10::VariableVersion& version_counter,
|
||||
bool allow_tensor_metadata_change) const override;
|
||||
|
||||
c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
|
||||
c10::VariableVersion&& version_counter,
|
||||
bool allow_tensor_metadata_change) const override;
|
||||
|
||||
~FunctionalTensorWrapper() override = default;
|
||||
|
||||
// FunctionalTensorWrapper overrides all custom size/stride function,
|
||||
@ -130,6 +140,7 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
|
||||
int64_t dim_custom() const override;
|
||||
int64_t numel_custom() const override;
|
||||
bool is_contiguous_custom(at::MemoryFormat memory_format) const override;
|
||||
c10::SymIntArrayRef sym_sizes() const override;
|
||||
c10::SymIntArrayRef sym_sizes_custom() const override;
|
||||
|
||||
private:
|
||||
@ -137,6 +148,16 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
|
||||
void set_constructor_metadata();
|
||||
functionalization::FunctionalStorageImpl* functional_storage_impl() const;
|
||||
|
||||
// This is used to re-implement shallow_copy_and_detach for
|
||||
// FunctionalTensorWrapper. The implementation is identical, but we just need
|
||||
// to return a subclass instead of a plain TensorImpl.
|
||||
// TODO: maybe it's possible to arrange for that to happen automatically
|
||||
// without an override here?
|
||||
template <typename VariableVersion>
|
||||
c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach_core(
|
||||
VariableVersion&& version_counter,
|
||||
bool allow_tensor_metadata_change) const;
|
||||
|
||||
// Note that value is not taken by reference: internally, the wrapper will
|
||||
// change the value tensor that it points to over time.
|
||||
Tensor value_;
|
||||
@ -251,5 +272,37 @@ class TORCH_API FunctionalizationReapplyViewsGuard {
|
||||
};
|
||||
|
||||
} // namespace impl
|
||||
|
||||
// Helper function to call an out-of-place composite aten kernel that may use
|
||||
// mutations / views internally, and functionalize them.
|
||||
TORCH_API void functionalize_op_helper(
|
||||
const c10::OperatorHandle& op,
|
||||
torch::jit::Stack* stack);
|
||||
|
||||
template <class Op, class ReturnType, class... ParameterTypes>
|
||||
struct _functionalize_aten_op final {};
|
||||
|
||||
template <class Op, class ReturnType, class... ParameterTypes>
|
||||
struct _functionalize_aten_op<Op, ReturnType(ParameterTypes...)> final {
|
||||
static ReturnType call(ParameterTypes... args) {
|
||||
auto op = c10::Dispatcher::singleton()
|
||||
.findSchemaOrThrow(
|
||||
(const char*)Op::name, (const char*)Op::overload_name)
|
||||
.typed<ReturnType(ParameterTypes...)>();
|
||||
|
||||
return c10::impl::BoxedKernelWrapper<ReturnType(ParameterTypes...)>::call(
|
||||
c10::KernelFunction::make_boxed_function<functionalize_op_helper>,
|
||||
nullptr,
|
||||
op,
|
||||
// BoxedKernelWrapper knows to ignore this keyset argument,
|
||||
// because functionalize_op_helper doesn't take in a DispatchKeySet
|
||||
c10::DispatchKeySet(),
|
||||
args...);
|
||||
}
|
||||
};
|
||||
|
||||
template <class Op>
|
||||
using functionalize_aten_op = _functionalize_aten_op<Op, typename Op::schema>;
|
||||
|
||||
} // namespace functionalization
|
||||
} // namespace at
|
||||
|
@ -10,6 +10,7 @@
|
||||
#include <ATen/Functions.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#else
|
||||
#include <ATen/ops/_to_copy.h>
|
||||
#include <ATen/ops/to_native.h>
|
||||
#include <ATen/ops/resize.h>
|
||||
#include <ATen/ops/as_strided.h>
|
||||
@ -32,7 +33,7 @@ namespace {
|
||||
if (ivalue.isTensor()) {
|
||||
any_tensor_inputs = true;
|
||||
auto t = ivalue.toTensor();
|
||||
if (at::functionalization::impl::isFunctionalTensor(t)) {
|
||||
if (t.defined() && at::functionalization::impl::isFunctionalTensor(t)) {
|
||||
any_functional_inputs = true;
|
||||
at::functionalization::impl::sync(t);
|
||||
auto t_new = c10::IValue(at::functionalization::impl::from_functional_tensor(t));
|
||||
@ -73,6 +74,7 @@ namespace {
|
||||
const auto& ivalue = returns[idx];
|
||||
if (ivalue.isTensor() && should_wrap_outputs) {
|
||||
auto t = ivalue.toTensor();
|
||||
if (!t.defined()) continue;
|
||||
auto t_new = c10::IValue(at::functionalization::impl::to_functional_tensor(t));
|
||||
(*stack)[returns_begin + idx] = t_new;
|
||||
} else if (ivalue.isTensorList() && should_wrap_outputs) {
|
||||
@ -171,6 +173,48 @@ at::Tensor lift_functionalize(const at::Tensor & self) {
|
||||
return at::functionalization::impl::to_functional_tensor(self);
|
||||
}
|
||||
|
||||
bool device_opted_into_functionalization(c10::Device self_device, c10::optional<c10::Device> tgt_device) {
|
||||
// If the target device is empty, then the output tensor should be on the same device as the input
|
||||
auto real_tgt_device = tgt_device.has_value() ? tgt_device.value() : self_device;
|
||||
return real_tgt_device.type() == c10::DeviceType::XLA || real_tgt_device.type() == c10::DeviceType::Lazy;
|
||||
}
|
||||
|
||||
// note I only need this because the to.dtype/to.dtype_layout overload calls this, so we skip the op above.
|
||||
// We should probably get rid of this though.
|
||||
at::Tensor _to_copy_functionalize(
|
||||
const at::Tensor & self,
|
||||
c10::optional<at::ScalarType> dtype,
|
||||
c10::optional<at::Layout> layout,
|
||||
c10::optional<at::Device> device,
|
||||
c10::optional<bool> pin_memory,
|
||||
bool non_blocking,
|
||||
c10::optional<at::MemoryFormat> memory_format) {
|
||||
at::Tensor self_;
|
||||
if (at::functionalization::impl::isFunctionalTensor(self)) {
|
||||
// sync any pending updates
|
||||
at::functionalization::impl::sync(self);
|
||||
// pass the unwrapped tensor to the backend
|
||||
self_ = at::functionalization::impl::from_functional_tensor(self);
|
||||
} else {
|
||||
self_ = self;
|
||||
}
|
||||
|
||||
at::AutoDispatchSkipFunctionalize guard;
|
||||
auto out = at::_to_copy(self_, dtype, layout, device, pin_memory, non_blocking, memory_format);
|
||||
|
||||
// Special case: if the Functionalize key is not in TLS, we assume that we're running
|
||||
// on a lazy backend (LTC).
|
||||
// In that case, if we're copying to a non-functionalize-enabled device,
|
||||
// then the functionalization pass should "end". We need to sync any updates on the input
|
||||
// tensor, but we shouldn't wrap the output.
|
||||
if (!c10::impl::tls_local_dispatch_key_set().included_.has(c10::DispatchKey::Functionalize)) {
|
||||
if (!device_opted_into_functionalization(self.device(), device)) {
|
||||
return out;
|
||||
}
|
||||
}
|
||||
return at::functionalization::impl::to_functional_tensor(out);
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL(_, Functionalize, m) {
|
||||
m.fallback(torch::CppFunction::makeFromBoxedFunction<&functionalizeFallback>());
|
||||
}
|
||||
@ -178,4 +222,5 @@ TORCH_LIBRARY_IMPL(_, Functionalize, m) {
|
||||
TORCH_LIBRARY_IMPL(aten, Functionalize, m) {
|
||||
m.impl("resize_", TORCH_FN(resize__functionalization));
|
||||
m.impl("lift", TORCH_FN(lift_functionalize));
|
||||
m.impl("_to_copy", TORCH_FN(_to_copy_functionalize));
|
||||
}
|
||||
|
@ -20,5 +20,6 @@ bool to_will_alias(
|
||||
Tensor to_meta(const Tensor& tensor);
|
||||
c10::optional<Tensor> to_meta(const c10::optional<Tensor>& tensor);
|
||||
std::vector<Tensor> to_meta(const at::TensorList& t_list);
|
||||
|
||||
} // namespace native
|
||||
} // namespace at
|
||||
|
@ -765,7 +765,7 @@
|
||||
device_guard: False
|
||||
tags: inplace_view
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: as_strided_
|
||||
CompositeExplicitAutogradNonFunctional: as_strided_
|
||||
|
||||
- func: asin(Tensor self) -> Tensor
|
||||
device_check: NoCheck # TensorIterator
|
||||
@ -1976,7 +1976,7 @@
|
||||
- func: new_empty_strided(Tensor self, int[] size, int[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
|
||||
variants: method
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: new_empty_strided
|
||||
CompositeExplicitAutogradNonFunctional: new_empty_strided
|
||||
|
||||
- func: new_full(Tensor self, int[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
|
||||
variants: method
|
||||
@ -3027,7 +3027,8 @@
|
||||
- func: logsumexp.out(Tensor self, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
|
||||
device_check: NoCheck # TensorIterator
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: logsumexp_out
|
||||
# calls squeeze
|
||||
CompositeExplicitAutogradNonFunctional: logsumexp_out
|
||||
|
||||
- func: logsumexp.names(Tensor self, Dimname[1] dim, bool keepdim=False) -> Tensor
|
||||
device_check: NoCheck # TensorIterator
|
||||
@ -3637,12 +3638,12 @@
|
||||
- func: pixel_shuffle(Tensor self, int upscale_factor) -> Tensor
|
||||
dispatch:
|
||||
CPU: pixel_shuffle_cpu
|
||||
CompositeExplicitAutograd: math_pixel_shuffle
|
||||
CompositeExplicitAutogradNonFunctional: math_pixel_shuffle
|
||||
|
||||
- func: pixel_unshuffle(Tensor self, int downscale_factor) -> Tensor
|
||||
dispatch:
|
||||
CPU: pixel_unshuffle_cpu
|
||||
CompositeExplicitAutograd: math_pixel_unshuffle
|
||||
CompositeExplicitAutogradNonFunctional: math_pixel_unshuffle
|
||||
|
||||
- func: channel_shuffle(Tensor self, int groups) -> Tensor
|
||||
dispatch:
|
||||
@ -4088,7 +4089,7 @@
|
||||
device_check: NoCheck
|
||||
device_guard: False
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: select_backward
|
||||
CompositeExplicitAutogradNonFunctional: select_backward
|
||||
|
||||
- func: selu(Tensor self) -> Tensor
|
||||
device_check: NoCheck # TensorIterator
|
||||
@ -4920,7 +4921,8 @@
|
||||
|
||||
- func: _trilinear(Tensor i1, Tensor i2, Tensor i3, int[] expand1, int[] expand2, int[] expand3, int[] sumdim, int unroll_dim=1) -> Tensor
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: _trilinear
|
||||
# calls unsqueeze
|
||||
CompositeExplicitAutogradNonFunctional: _trilinear
|
||||
|
||||
- func: triplet_margin_loss(Tensor anchor, Tensor positive, Tensor negative, float margin=1.0, float p=2, float eps=1e-06, bool swap=False, int reduction=Mean) -> Tensor
|
||||
|
||||
@ -11648,13 +11650,15 @@
|
||||
python_module: linalg
|
||||
variants: function
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: linalg_inv_ex
|
||||
# calls transpose_
|
||||
CompositeExplicitAutogradNonFunctional: linalg_inv_ex
|
||||
|
||||
- func: linalg_inv_ex.inverse(Tensor self, *, bool check_errors=False, Tensor(a!) inverse, Tensor(b!) info) -> (Tensor(a!) inverse, Tensor(b!) info)
|
||||
python_module: linalg
|
||||
variants: function
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: linalg_inv_ex_out
|
||||
# calls transpose_
|
||||
CompositeExplicitAutogradNonFunctional: linalg_inv_ex_out
|
||||
|
||||
- func: linalg_inv(Tensor self) -> Tensor
|
||||
python_module: linalg
|
||||
@ -11766,7 +11770,9 @@
|
||||
python_module: linalg
|
||||
variants: function
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: linalg_pinv
|
||||
# calls svd, which calls mH() (view op)
|
||||
# also calls narrow()
|
||||
CompositeExplicitAutogradNonFunctional: linalg_pinv
|
||||
|
||||
- func: linalg_pinv.atol_rtol_tensor_out(Tensor self, *, Tensor? atol=None, Tensor? rtol=None, bool hermitian=False, Tensor(a!) out) -> Tensor(a!)
|
||||
python_module: linalg
|
||||
@ -12018,7 +12024,7 @@
|
||||
- func: expand_copy.SymInt(Tensor self, SymInt[] size, *, bool implicit=False) -> Tensor
|
||||
variants: function
|
||||
dispatch:
|
||||
CompositeExplicitAutogradNonFunctional: expand_copy_SymInt
|
||||
CompositeExplicitAutograd: expand_copy_SymInt
|
||||
tags: view_copy
|
||||
|
||||
- func: permute_copy(Tensor self, int[] dims) -> Tensor
|
||||
|
@ -143,40 +143,65 @@ full_codegen:
|
||||
- upsample_nearest2d_backward
|
||||
- zero.functional
|
||||
- narrow_copy.SymInt
|
||||
- alias_copy
|
||||
- as_strided_copy
|
||||
- diagonal_copy
|
||||
- expand_copy
|
||||
- permute_copy
|
||||
- _reshape_alias_copy
|
||||
- select_copy.int
|
||||
- detach_copy
|
||||
- slice_copy.Tensor
|
||||
# Not implemented yet because LTC codegen doesn't currently work
|
||||
# for ops that return lists of tensors.
|
||||
#- split_copy.Tensor
|
||||
#- split_with_sizes_copy
|
||||
#- unbind_copy.int
|
||||
- squeeze_copy
|
||||
- squeeze_copy.dim
|
||||
- t_copy
|
||||
- transpose_copy.int
|
||||
- unsqueeze_copy
|
||||
- view_copy
|
||||
- view_copy.dtype
|
||||
- unfold_copy
|
||||
- select_scatter
|
||||
- slice_scatter
|
||||
- diagonal_scatter
|
||||
- as_strided_scatter
|
||||
supported:
|
||||
- as_strided
|
||||
- as_strided_
|
||||
- clone
|
||||
- _copy_from
|
||||
- _copy_from_and_resize
|
||||
- diagonal
|
||||
- empty.memory_format
|
||||
- empty_strided
|
||||
- expand
|
||||
- fill_.Scalar
|
||||
- narrow
|
||||
- normal_
|
||||
- max_pool3d_with_indices
|
||||
- max_pool3d_with_indices_backward
|
||||
- permute
|
||||
- select.int
|
||||
- slice.Tensor
|
||||
- squeeze
|
||||
- squeeze.dim
|
||||
- squeeze_
|
||||
- squeeze_.dim
|
||||
- t
|
||||
- t_
|
||||
- _to_copy
|
||||
- transpose.int
|
||||
- transpose_
|
||||
- unsqueeze
|
||||
- unsqueeze_
|
||||
- view
|
||||
- alias
|
||||
- _unsafe_view
|
||||
- lift
|
||||
# Below are all operators that are "composite" in core,
|
||||
# but require us to explicitly re-enable functionalization in order to use them.
|
||||
# Why? These operators are all CompositeExplicitAutograd, which mean that they run
|
||||
# after functionalization,
|
||||
# but their implementations call view operators (which we need to functionalize away).
|
||||
- block_diag
|
||||
- diagonal_backward
|
||||
- slice_backward
|
||||
- new_empty_strided
|
||||
- narrow_copy
|
||||
- pixel_shuffle
|
||||
- pixel_unshuffle
|
||||
- select_backward
|
||||
- _trilinear
|
||||
- linalg_inv_ex
|
||||
- linalg_pinv.atol_rtol_tensor
|
||||
- logsumexp.out
|
||||
autograd:
|
||||
- max_pool3d
|
||||
- native_group_norm
|
||||
|
||||
# Ops that don't have a native schema definitions and are dispatched within Lazy Tensor Core
|
||||
non_native:
|
||||
|
@ -54,6 +54,8 @@ void resize_out_helper(const at::TensorList& dst, const at::TensorList& src) {
|
||||
|
||||
${CompositeViewCopyKernel_Definitions}
|
||||
|
||||
${SymIntViewCopyKernel_Definitions}
|
||||
|
||||
${GeneratedCompositeFunctional_Definitions}
|
||||
|
||||
${GeneratedCompositeOut_Definitions}
|
||||
|
@ -2,6 +2,10 @@
|
||||
${includes}
|
||||
${native_functions_include}
|
||||
|
||||
namespace {
|
||||
${helper_fns}
|
||||
} // namespace
|
||||
|
||||
${namespace_prologue}
|
||||
|
||||
${native_function_definitions}
|
||||
|
@ -217,7 +217,6 @@ enum class DispatchKey : uint16_t {
|
||||
// Out-of-core key for Fake Tensor in torchdistx.
|
||||
// See https://pytorch.org/torchdistx/latest/fake_tensor.html
|
||||
Fake,
|
||||
|
||||
// See Note [Out-of-tree vmap+grad prototype]. The purpose of this key
|
||||
// is to insert code after the "autograd subsystem" runs, so this key should
|
||||
// be directly after ADInplaceOrView and all of the autograd keys.
|
||||
|
@ -54,10 +54,25 @@ def init_lists():
|
||||
'pow', # incorrect results
|
||||
'addcdiv', # incorrect results (on CI not locally?)
|
||||
])
|
||||
# The following ops all show up directly in ts_native_functions.yaml,
|
||||
# but run functionalized versions of the composite kernels in core.
|
||||
# This means that we don't expect the ops to show directly in the LTC metrics.
|
||||
FUNCTIONAL_DECOMPOSE_LIST = set([
|
||||
'block_diag',
|
||||
'new_empty_strided',
|
||||
'narrow_copy',
|
||||
'pixel_shuffle',
|
||||
'pixel_unshuffle',
|
||||
'select_backward',
|
||||
'_trilinear',
|
||||
'linalg_inv_ex',
|
||||
'linalg_pinv.atol_rtol_tensor',
|
||||
'logsumexp',
|
||||
])
|
||||
|
||||
return (LAZY_OPS_LIST, FALLBACK_LIST, SKIP_RUNTIME_ERROR_LIST, SKIP_INCORRECT_RESULTS_LIST)
|
||||
return (LAZY_OPS_LIST, FALLBACK_LIST, SKIP_RUNTIME_ERROR_LIST, SKIP_INCORRECT_RESULTS_LIST, FUNCTIONAL_DECOMPOSE_LIST)
|
||||
|
||||
(LAZY_OPS_LIST, FALLBACK_LIST, SKIP_RUNTIME_ERROR_LIST, SKIP_INCORRECT_RESULTS_LIST) = init_lists()
|
||||
(LAZY_OPS_LIST, FALLBACK_LIST, SKIP_RUNTIME_ERROR_LIST, SKIP_INCORRECT_RESULTS_LIST, FUNCTIONAL_DECOMPOSE_LIST) = init_lists()
|
||||
|
||||
torch.manual_seed(42)
|
||||
|
||||
@ -96,9 +111,36 @@ class TestLazyTensor(JitTestCase):
|
||||
torch.testing.assert_close(weight_copy_grad.cpu(), weight_grad.cpu())
|
||||
torch.testing.assert_close(inp_copy_grad.cpu(), inp_grad.cpu())
|
||||
|
||||
def test_view_mark_step_preserved(self):
|
||||
test_device = get_test_device()
|
||||
inp = torch.rand(4, device=test_device)
|
||||
inp_lazy = clone_move(inp)
|
||||
|
||||
def foo(x, *, mark_step):
|
||||
y = x.view(2, 2)
|
||||
y.add_(1)
|
||||
z = x + x
|
||||
|
||||
if mark_step:
|
||||
torch._lazy.mark_step()
|
||||
|
||||
# y and x should contiue to be aliased after the mark_step call.
|
||||
y.add_(1)
|
||||
return x
|
||||
|
||||
|
||||
out_ref = foo(inp, mark_step=False)
|
||||
out = foo(inp_lazy, mark_step=True)
|
||||
# out will have some pending mutations, which will be synced by the .cpu() call.
|
||||
torch.testing.assert_close(out_ref.cpu(), out.cpu())
|
||||
|
||||
class TestLazyOpInfo(TestCase):
|
||||
|
||||
@ops([op for op in op_db if op.name in LAZY_OPS_LIST and op.name not in SKIP_RUNTIME_ERROR_LIST], allowed_dtypes=(torch.float,))
|
||||
@ops([op for op in op_db
|
||||
if op.name in LAZY_OPS_LIST
|
||||
and op.name not in SKIP_RUNTIME_ERROR_LIST
|
||||
and op.name not in FUNCTIONAL_DECOMPOSE_LIST
|
||||
], allowed_dtypes=(torch.float,))
|
||||
def test_dispatched_to_lazy(self, device, dtype, op):
|
||||
def get_name(op):
|
||||
l = [op.name]
|
||||
|
@ -1,5 +1,4 @@
|
||||
# Owner(s): ["module: tests"]
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
@ -475,6 +474,8 @@ class TestViewOps(TestCase):
|
||||
v[0] = 0
|
||||
self.assertEqual(t[2, 0], v[0])
|
||||
|
||||
# Lazy hasn't implemented unbind yet.
|
||||
@onlyNativeDeviceTypes
|
||||
def test_unbind_view(self, device) -> None:
|
||||
t = torch.zeros((5, 5), device=device)
|
||||
tup = torch.unbind(t)
|
||||
@ -505,6 +506,9 @@ class TestViewOps(TestCase):
|
||||
stacked = torch.randn(3, 10, 10, dtype=torch.double, requires_grad=True)
|
||||
gradcheck(lambda x: x.unbind(), (stacked,), check_forward_ad=True)
|
||||
|
||||
# TODO: Fix this test for LTC. There is an interaction with dynamic shapes here that is broken,
|
||||
# causing asserts to trigger.
|
||||
@onlyNativeDeviceTypes
|
||||
def test_expand_view(self, device) -> None:
|
||||
t = torch.ones((5, 1), device=device)
|
||||
v = t.expand(5, 5)
|
||||
@ -718,6 +722,8 @@ class TestViewOps(TestCase):
|
||||
self.assertTrue(s is t)
|
||||
|
||||
@skipMeta
|
||||
# self.is_view_of reports false positives for lazy
|
||||
@onlyNativeDeviceTypes
|
||||
def test_contiguous_nonview(self, device):
|
||||
t = torch.ones(5, 5, device=device)
|
||||
nv = t.t().contiguous()
|
||||
@ -744,6 +750,8 @@ class TestViewOps(TestCase):
|
||||
self.assertEqual(t[1, 1], v[6])
|
||||
|
||||
@skipMeta
|
||||
# self.is_view_of reports false positives for lazy
|
||||
@onlyNativeDeviceTypes
|
||||
def test_reshape_nonview(self, device):
|
||||
t = torch.ones(5, 5, device=device)
|
||||
nv = torch.reshape(t.t(), (25,))
|
||||
@ -752,6 +760,9 @@ class TestViewOps(TestCase):
|
||||
nv[6] = 0
|
||||
self.assertNotEqual(t[1, 1], nv[6])
|
||||
|
||||
# This test use as_strided to construct a tensor with overlapping memory,
|
||||
# which is not handled by the functionalization pass.
|
||||
@onlyNativeDeviceTypes
|
||||
def test_flatten_view(self, device):
|
||||
def test_writes_propagate(t, v):
|
||||
idx_t = (0,) * t.ndim
|
||||
@ -1820,7 +1831,7 @@ class TestOldViewOps(TestCase):
|
||||
t.crow_indices()
|
||||
t.col_indices()
|
||||
|
||||
instantiate_device_type_tests(TestViewOps, globals())
|
||||
instantiate_device_type_tests(TestViewOps, globals(), include_lazy=True)
|
||||
instantiate_device_type_tests(TestOldViewOps, globals())
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -50,10 +50,12 @@
|
||||
#include <torch/csrc/lazy/core/shape_inference.h>
|
||||
|
||||
#include <ATen/AccumulateType.h>
|
||||
#include <ATen/CompositeExplicitAutogradFunctions.h>
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <ATen/ExpandUtils.h>
|
||||
#include <ATen/Functions.h>
|
||||
#include <ATen/InferSize.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#include <ATen/WrapDimUtils.h>
|
||||
#include <ATen/native/ConvUtils.h>
|
||||
#include <ATen/native/ReduceOpsUtils.h>
|
||||
@ -599,6 +601,17 @@ std::vector<Shape> compute_shape_mean(
|
||||
return {Shape(self.scalar_type(), {})};
|
||||
}
|
||||
|
||||
std::vector<Shape> compute_shape_new_empty_strided(
|
||||
const at::Tensor& self,
|
||||
at::IntArrayRef size,
|
||||
at::IntArrayRef stride,
|
||||
c10::optional<at::ScalarType> dtype,
|
||||
c10::optional<at::Layout> layout,
|
||||
c10::optional<at::Device> device,
|
||||
c10::optional<bool> pin_memory) {
|
||||
return {Shape(dtype.has_value() ? *dtype : self.scalar_type(), size.vec())};
|
||||
}
|
||||
|
||||
std::vector<Shape> compute_shape_mv(
|
||||
const at::Tensor& self,
|
||||
const at::Tensor& vec) {
|
||||
@ -949,6 +962,12 @@ std::vector<Shape> compute_shape__to_copy(
|
||||
return {Shape(self.scalar_type(), self.sizes().vec())};
|
||||
}
|
||||
|
||||
TORCH_API std::vector<Shape> compute_shape_clone(
|
||||
const at::Tensor& self,
|
||||
c10::optional<at::MemoryFormat> memory_format) {
|
||||
return {Shape(self.scalar_type(), self.sizes().vec())};
|
||||
}
|
||||
|
||||
std::vector<Shape> compute_shape_stack(at::TensorList tensors, int64_t dim) {
|
||||
TORCH_CHECK(tensors.size() > 0, "stack expects a non-empty TensorList");
|
||||
auto wrapped_dim = at::maybe_wrap_dim(dim, tensors[0].ndimension() + 1);
|
||||
@ -1108,6 +1127,106 @@ std::vector<Shape> compute_shape_unsqueeze(
|
||||
BuildUnsqueezedDimensions(input_shape.sizes(), dim))};
|
||||
}
|
||||
|
||||
std::vector<Shape> compute_shape_select_scatter(
|
||||
const at::Tensor& self,
|
||||
const at::Tensor& src,
|
||||
int64_t dim,
|
||||
int64_t index) {
|
||||
auto self_meta = at::native::empty_strided_meta(
|
||||
self.sizes(),
|
||||
self.strides(),
|
||||
/*dtype=*/c10::make_optional(self.scalar_type()),
|
||||
/*layout=*/c10::make_optional(self.layout()),
|
||||
/*device=*/c10::make_optional(c10::Device(c10::kMeta)),
|
||||
/*pin_memory=*/c10::nullopt);
|
||||
auto src_meta = at::native::empty_strided_meta(
|
||||
src.sizes(),
|
||||
src.strides(),
|
||||
/*dtype=*/c10::make_optional(src.scalar_type()),
|
||||
/*layout=*/c10::make_optional(src.layout()),
|
||||
/*device=*/c10::make_optional(c10::Device(c10::kMeta)),
|
||||
/*pin_memory=*/c10::nullopt);
|
||||
auto out_meta = at::compositeexplicitautograd::select_scatter(
|
||||
self_meta, src_meta, dim, index);
|
||||
return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
|
||||
}
|
||||
|
||||
std::vector<Shape> compute_shape_diagonal_scatter(
|
||||
const at::Tensor& self,
|
||||
const at::Tensor& src,
|
||||
int64_t offset,
|
||||
int64_t dim1,
|
||||
int64_t dim2) {
|
||||
auto self_meta = at::native::empty_strided_meta(
|
||||
self.sizes(),
|
||||
self.strides(),
|
||||
/*dtype=*/c10::make_optional(self.scalar_type()),
|
||||
/*layout=*/c10::make_optional(self.layout()),
|
||||
/*device=*/c10::make_optional(c10::Device(c10::kMeta)),
|
||||
/*pin_memory=*/c10::nullopt);
|
||||
auto src_meta = at::native::empty_strided_meta(
|
||||
src.sizes(),
|
||||
src.strides(),
|
||||
/*dtype=*/c10::make_optional(src.scalar_type()),
|
||||
/*layout=*/c10::make_optional(src.layout()),
|
||||
/*device=*/c10::make_optional(c10::Device(c10::kMeta)),
|
||||
/*pin_memory=*/c10::nullopt);
|
||||
auto out_meta = at::compositeexplicitautograd::diagonal_scatter(
|
||||
self_meta, src_meta, offset, dim1, dim2);
|
||||
return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
|
||||
}
|
||||
|
||||
std::vector<Shape> compute_shape_slice_scatter(
|
||||
const at::Tensor& self,
|
||||
const at::Tensor& src,
|
||||
int64_t dim,
|
||||
c10::optional<int64_t> start,
|
||||
c10::optional<int64_t> end,
|
||||
int64_t step) {
|
||||
auto self_meta = at::native::empty_strided_meta(
|
||||
self.sizes(),
|
||||
self.strides(),
|
||||
/*dtype=*/c10::make_optional(self.scalar_type()),
|
||||
/*layout=*/c10::make_optional(self.layout()),
|
||||
/*device=*/c10::make_optional(c10::Device(c10::kMeta)),
|
||||
/*pin_memory=*/c10::nullopt);
|
||||
auto src_meta = at::native::empty_strided_meta(
|
||||
src.sizes(),
|
||||
src.strides(),
|
||||
/*dtype=*/c10::make_optional(src.scalar_type()),
|
||||
/*layout=*/c10::make_optional(src.layout()),
|
||||
/*device=*/c10::make_optional(c10::Device(c10::kMeta)),
|
||||
/*pin_memory=*/c10::nullopt);
|
||||
auto out_meta = at::compositeexplicitautograd::slice_scatter(
|
||||
self_meta, src_meta, dim, start, end, step);
|
||||
return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
|
||||
}
|
||||
|
||||
std::vector<Shape> compute_shape_as_strided_scatter(
|
||||
const at::Tensor& self,
|
||||
const at::Tensor& src,
|
||||
at::IntArrayRef size,
|
||||
at::IntArrayRef stride,
|
||||
c10::optional<int64_t> storage_offset) {
|
||||
auto self_meta = at::native::empty_strided_meta(
|
||||
self.sizes(),
|
||||
self.strides(),
|
||||
/*dtype=*/c10::make_optional(self.scalar_type()),
|
||||
/*layout=*/c10::make_optional(self.layout()),
|
||||
/*device=*/c10::make_optional(c10::Device(c10::kMeta)),
|
||||
/*pin_memory=*/c10::nullopt);
|
||||
auto src_meta = at::native::empty_strided_meta(
|
||||
src.sizes(),
|
||||
src.strides(),
|
||||
/*dtype=*/c10::make_optional(src.scalar_type()),
|
||||
/*layout=*/c10::make_optional(src.layout()),
|
||||
/*device=*/c10::make_optional(c10::Device(c10::kMeta)),
|
||||
/*pin_memory=*/c10::nullopt);
|
||||
auto out_meta = at::compositeexplicitautograd::as_strided_scatter(
|
||||
self_meta, src_meta, size, stride, storage_offset);
|
||||
return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
|
||||
}
|
||||
|
||||
// Restore unused-parameters warnings
|
||||
#pragma GCC diagnostic pop
|
||||
|
||||
|
@ -24,6 +24,7 @@ TORCH_API std::vector<torch::lazy::Shape> compute_shape_binary_cross_entropy(con
|
||||
TORCH_API std::vector<torch::lazy::Shape> compute_shape_binary_cross_entropy_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight, int64_t reduction);
|
||||
TORCH_API std::vector<torch::lazy::Shape> compute_shape_cat(at::TensorList tensors, int64_t dim);
|
||||
TORCH_API std::vector<torch::lazy::Shape> compute_shape_clamp_min(const at::Tensor & self, const at::Scalar & min);
|
||||
TORCH_API std::vector<torch::lazy::Shape> compute_shape_clone(const at::Tensor & self, c10::optional<at::MemoryFormat> memory_format);
|
||||
TORCH_API std::vector<torch::lazy::Shape> compute_shape_constant_pad_nd(const at::Tensor & self, at::IntArrayRef pad, const at::Scalar & value);
|
||||
TORCH_API std::vector<torch::lazy::Shape> compute_shape_convolution(const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups);
|
||||
TORCH_API std::vector<torch::lazy::Shape> compute_shape_convolution_backward(const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & weight, at::OptionalIntArrayRef bias_sizes, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups, ::std::array<bool,3> output_mask);
|
||||
@ -56,6 +57,7 @@ TORCH_API std::vector<torch::lazy::Shape> compute_shape_native_dropout(const at:
|
||||
TORCH_API std::vector<torch::lazy::Shape> compute_shape_native_dropout_backward(const at::Tensor & grad_output, const at::Tensor & mask, double scale);
|
||||
TORCH_API std::vector<torch::lazy::Shape> compute_shape_native_layer_norm(const at::Tensor & input, at::IntArrayRef normalized_shape, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, double eps);
|
||||
TORCH_API std::vector<torch::lazy::Shape> compute_shape_native_layer_norm_backward(const at::Tensor & grad_out, const at::Tensor & input, at::IntArrayRef normalized_shape, const at::Tensor & mean, const at::Tensor & rstd, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, ::std::array<bool,3> output_mask);
|
||||
TORCH_API std::vector<torch::lazy::Shape> compute_shape_new_empty_strided(const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory);
|
||||
TORCH_API std::vector<torch::lazy::Shape> compute_shape_nll_loss2d_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight, int64_t reduction, int64_t ignore_index, const at::Tensor & total_weight);
|
||||
TORCH_API std::vector<torch::lazy::Shape> compute_shape_nll_loss2d_forward(const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight, int64_t reduction, int64_t ignore_index);
|
||||
TORCH_API std::vector<torch::lazy::Shape> compute_shape_nonzero(const at::Tensor & self);
|
||||
@ -84,6 +86,7 @@ TORCH_API std::vector<Shape> compute_shape_view(const Output& input0, const std:
|
||||
TORCH_API std::vector<Shape> compute_shape_cast(const Output& input0, const at::ScalarType& dtype, const c10::optional<at::ScalarType>& stype);
|
||||
|
||||
// View Ops
|
||||
// (Now that functionalization pass is used, we should kill these in a later PR)
|
||||
TORCH_API std::vector<Shape> compute_shape_as_strided_view_update(const Output& target, const Output& input, const std::vector<int64_t>& size, const std::vector<int64_t>& stride, const int64_t& storage_offset);
|
||||
TORCH_API std::vector<Shape> compute_shape_as_strided(const Output& input, const std::vector<int64_t>& size, const std::vector<int64_t>& stride, const int64_t& storage_offset);
|
||||
TORCH_API std::vector<Shape> compute_shape_diagonal_view_update(const Output& target, const Output& input, const int64_t& offset, const int64_t& dim1, const int64_t& dim2);
|
||||
@ -96,6 +99,12 @@ TORCH_API std::vector<Shape> compute_shape_select_view_update(const Output& targ
|
||||
TORCH_API std::vector<Shape> compute_shape_select(const Output& input, const int64_t& dim, const int64_t& start, const int64_t& end, const int64_t& stride);
|
||||
TORCH_API std::vector<Shape> compute_shape_squeeze(const Output& input, const int& dim);
|
||||
TORCH_API std::vector<Shape> compute_shape_unsqueeze(const Output& input, const int& dim);
|
||||
|
||||
TORCH_API std::vector<torch::lazy::Shape> compute_shape_select_scatter(const at::Tensor & self, const at::Tensor & src, int64_t dim, int64_t index);
|
||||
TORCH_API std::vector<torch::lazy::Shape> compute_shape_diagonal_scatter(const at::Tensor & self, const at::Tensor & src, int64_t offset, int64_t dim1, int64_t dim2);
|
||||
TORCH_API std::vector<torch::lazy::Shape> compute_shape_slice_scatter(const at::Tensor & self, const at::Tensor & src, int64_t dim, c10::optional<int64_t> start, c10::optional<int64_t> end, int64_t step);
|
||||
TORCH_API std::vector<torch::lazy::Shape> compute_shape_slice_scatter(const at::Tensor & self, const at::Tensor & src, int64_t dim, c10::optional<int64_t> start, c10::optional<int64_t> end, int64_t step);
|
||||
TORCH_API std::vector<torch::lazy::Shape> compute_shape_as_strided_scatter(const at::Tensor & self, const at::Tensor & src, at::IntArrayRef size, at::IntArrayRef stride, c10::optional<int64_t> storage_offset);
|
||||
// clang-format on
|
||||
} // namespace lazy
|
||||
} // namespace torch
|
||||
|
@ -10,6 +10,8 @@
|
||||
#include <torch/csrc/lazy/core/tensor_impl.h>
|
||||
#include <torch/csrc/lazy/core/tensor_util.h>
|
||||
|
||||
#include <ATen/FunctionalTensorWrapper.h>
|
||||
|
||||
namespace torch {
|
||||
namespace lazy {
|
||||
namespace {
|
||||
@ -482,7 +484,8 @@ torch::lazy::Value GetTensorList(c10::ArrayRef<at::Tensor> tensors) {
|
||||
}
|
||||
|
||||
LazyTensorPtr TryGetLtcTensor(const at::Tensor& tensor) {
|
||||
auto* impl = dynamic_cast<LTCTensorImpl*>(tensor.unsafeGetTensorImpl());
|
||||
auto* impl = dynamic_cast<LTCTensorImpl*>(
|
||||
maybe_unwrap_functional(tensor).unsafeGetTensorImpl());
|
||||
if (impl == nullptr) {
|
||||
// return c10::make_intrusive<LazyTensor>();
|
||||
return LazyTensorPtr();
|
||||
@ -532,5 +535,27 @@ at::Tensor CreateAtenFromLtcTensor(LazyTensor&& ltc_tensor) {
|
||||
return at::Tensor(c10::make_intrusive<LTCTensorImpl>(std::move(ltc_tensor)));
|
||||
}
|
||||
|
||||
at::Tensor to_lazy_tensor(
|
||||
const at::Tensor& self,
|
||||
const c10::TensorOptions& options,
|
||||
at::Device device,
|
||||
bool non_blocking,
|
||||
bool functionalize_output) {
|
||||
TORCH_INTERNAL_ASSERT(self.device().type() != c10::kLazy);
|
||||
TORCH_INTERNAL_ASSERT(device.type() == c10::kLazy);
|
||||
|
||||
auto eager_tensor =
|
||||
self.to(options, /*non_blocking=*/non_blocking, /*copy=*/true);
|
||||
auto lazy_self = torch::lazy::GetOrCreateLtcTensor(
|
||||
eager_tensor, torch::lazy::atenDeviceToBackendDevice(device));
|
||||
auto out = torch::lazy::CreateAtenFromLtcTensor(lazy_self);
|
||||
if (functionalize_output) {
|
||||
// See Note [Lazy Tensor Functionalization]
|
||||
return at::functionalization::impl::to_functional_tensor(out);
|
||||
} else {
|
||||
return out;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace lazy
|
||||
} // namespace torch
|
||||
|
@ -242,6 +242,46 @@ TORCH_API LazyTensorPtr GetLtcTensorOrCreateForWrappedNumber(
|
||||
TORCH_API at::Tensor CreateAtenFromLtcTensor(const LazyTensorPtr& ltc_tensor);
|
||||
TORCH_API at::Tensor CreateAtenFromLtcTensor(LazyTensor&& ltc_tensor);
|
||||
|
||||
// Note [Lazy Tensor Functionalization]
|
||||
// The functionalization pass is implemented by wrapping all TensorImpl
|
||||
// objects in C++ with an extra FunctionalTensorWrapper object,
|
||||
// that knows how to perform functionalization
|
||||
//
|
||||
// Certain functions in the aten API serve as entry/exit points for
|
||||
// functionalization, where we need to perform the wrapping/unwrapping:
|
||||
// - aten::to.device
|
||||
// - aten::empty
|
||||
|
||||
// Given a non-lazy tensor, this function creates a lazy tensor on the specified
|
||||
// (lazy) device. The functionalize_output determines whether or not we should
|
||||
// wrap the output in a "functional wrapper".
|
||||
//
|
||||
// How do you know whether to pass true/false for functionalize_output?
|
||||
//
|
||||
// Case 1: nonlazy -> lazy
|
||||
// If you're implementing a function that takes in nonlazy tensors and returns
|
||||
// lazy tensors, then you should think of that function as an "entrypoint" to
|
||||
// functionalization, and use functionalize_output=true Examples include:
|
||||
// - factory functions (the LTC kernel for at::empty)
|
||||
// - CPU -> Lazy device converions (the LTC kernel for at::to_device)
|
||||
//
|
||||
// Case 2: lazy -> lazy
|
||||
// If you're implementing a function that takes in lazy tensors and returns
|
||||
// lazy tensors,
|
||||
// **but** requires creating lazy tensors internally,
|
||||
// then you can assume that the current function is running inside of some
|
||||
// outer context where functionalization is already running, that will take
|
||||
// care of doing the wrapping for you, and use functionalize_output=true
|
||||
// Examples include:
|
||||
// - CPU fallback (takes in lazy tensors, converts to cpu, calls kernel,
|
||||
// converts returns back to lazy tensors).
|
||||
TORCH_API at::Tensor to_lazy_tensor(
|
||||
const at::Tensor& self,
|
||||
const c10::TensorOptions& options,
|
||||
at::Device device,
|
||||
bool non_blocking,
|
||||
bool functionalize_output);
|
||||
|
||||
template <size_t... Indices>
|
||||
auto TupleAtenFromLtcTensorsImpl(
|
||||
const std::vector<LazyTensorPtr>& tensors,
|
||||
|
@ -3,6 +3,8 @@
|
||||
#include <torch/csrc/lazy/backend/backend_interface.h>
|
||||
#include <torch/csrc/lazy/core/shape.h>
|
||||
|
||||
#include <ATen/FunctionalTensorWrapper.h>
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
@ -62,5 +64,15 @@ at::Scalar MakeIntScalar(T value) {
|
||||
// API returns true.
|
||||
TORCH_API bool IsSpecialScalar(const at::Scalar& value);
|
||||
|
||||
// Note: returns a reference instead of a fresh tensor to avoid refcount bumps.
|
||||
inline const at::Tensor& maybe_unwrap_functional(const at::Tensor& tensor) {
|
||||
if (at::functionalization::impl::isFunctionalTensor(tensor)) {
|
||||
return at::functionalization::impl::unsafeGetFunctionalWrapper(tensor)
|
||||
->value();
|
||||
} else {
|
||||
return tensor;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace lazy
|
||||
} // namespace torch
|
||||
|
@ -1,5 +1,6 @@
|
||||
#include <torch/csrc/lazy/python/init.h>
|
||||
|
||||
#include <ATen/FunctionalTensorWrapper.h>
|
||||
#include <c10/core/Device.h>
|
||||
#include <torch/csrc/jit/python/pybind.h>
|
||||
#include <torch/csrc/lazy/backend/backend_device.h>
|
||||
@ -46,8 +47,9 @@ std::string GetTensorsDump(
|
||||
std::vector<torch::lazy::Node*> nodes;
|
||||
std::vector<torch::lazy::Value> values;
|
||||
for (auto& tensor : tensors) {
|
||||
auto inner = at::functionalization::impl::from_functional_tensor(tensor);
|
||||
torch::lazy::LazyTensorPtr lazy_tensor =
|
||||
torch::lazy::TryGetLtcTensor(tensor);
|
||||
torch::lazy::TryGetLtcTensor(inner);
|
||||
values.push_back(lazy_tensor->GetIrValue());
|
||||
nodes.push_back(values.back().node.get());
|
||||
}
|
||||
|
@ -337,7 +337,11 @@ void ts_eager_fallback(
|
||||
} else {
|
||||
dev_str << "<none>";
|
||||
}
|
||||
TORCH_WARN(
|
||||
// We should never hit this for a view op,
|
||||
// because LazyTensor should provide a lowering for the
|
||||
// corresponding view_copy operator. The functionalization pass will
|
||||
// take care of calling the view_copy operator intead of the view.
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"The operator ",
|
||||
op.schema().operator_name(),
|
||||
|
@ -1,5 +1,7 @@
|
||||
#include <ATen/FunctionalTensorWrapper.h>
|
||||
#include <ATen/Functions.h>
|
||||
#include <ATen/MetaFunctions.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#include <ATen/Operators.h>
|
||||
#include <ATen/native/BinaryOps.h>
|
||||
#include <ATen/native/CPUFallback.h>
|
||||
@ -19,6 +21,8 @@
|
||||
#include <torch/csrc/lazy/ts_backend/ts_eager_fallback.h>
|
||||
#include <torch/library.h>
|
||||
|
||||
using at::Tensor;
|
||||
|
||||
namespace torch {
|
||||
namespace lazy {
|
||||
namespace {
|
||||
@ -46,48 +50,8 @@ c10::optional<torch::lazy::BackendDevice> GetLtcDevice(
|
||||
|
||||
} // namespace
|
||||
|
||||
at::Tensor LazyNativeFunctions::alias(const at::Tensor& self) {
|
||||
TORCH_LAZY_FN_COUNTER("lazy::");
|
||||
return self;
|
||||
}
|
||||
|
||||
at::Tensor LazyNativeFunctions::as_strided(
|
||||
const at::Tensor& self,
|
||||
at::IntArrayRef size,
|
||||
at::IntArrayRef stride,
|
||||
c10::optional<int64_t> storage_offset) {
|
||||
TORCH_LAZY_FN_COUNTER("lazy::");
|
||||
torch::lazy::LazyTensorPtr self_tensor = torch::lazy::TryGetLtcTensor(self);
|
||||
auto xsize = torch::lazy::ToI64Vector(size);
|
||||
auto xstride = torch::lazy::ToI64Vector(stride);
|
||||
if (!torch::lazy::StrideIsSupported(xstride)) {
|
||||
return at::native::
|
||||
call_fallback_fn<<c_eager_fallback, ATEN_OP(as_strided)>::call(
|
||||
self, size, stride, storage_offset);
|
||||
}
|
||||
return torch::lazy::CreateAtenFromLtcTensor(torch::lazy::as_strided(
|
||||
self_tensor, std::move(xsize), std::move(xstride), storage_offset));
|
||||
}
|
||||
|
||||
const at::Tensor& LazyNativeFunctions::as_strided_(
|
||||
const at::Tensor& self,
|
||||
at::IntArrayRef size,
|
||||
at::IntArrayRef stride,
|
||||
c10::optional<int64_t> storage_offset) {
|
||||
TORCH_LAZY_FN_COUNTER("lazy::");
|
||||
auto self_tensor = torch::lazy::TryGetLtcTensor(self);
|
||||
auto xsize = torch::lazy::ToI64Vector(size);
|
||||
auto xstride = torch::lazy::ToI64Vector(stride);
|
||||
if (!torch::lazy::StrideIsSupported(xstride)) {
|
||||
return at::native::
|
||||
call_fallback_fn<<c_eager_fallback, ATEN_OP(as_strided_)>::call(
|
||||
self, size, stride, storage_offset);
|
||||
}
|
||||
torch::lazy::as_strided_(
|
||||
self_tensor, std::move(xsize), std::move(xstride), storage_offset);
|
||||
return self;
|
||||
}
|
||||
|
||||
// clone is special in LT because we make it a no-op.
|
||||
// This should be safe to do, because every operator in the LT is functional.
|
||||
at::Tensor LazyNativeFunctions::clone(
|
||||
const at::Tensor& self,
|
||||
c10::optional<at::MemoryFormat> memory_format) {
|
||||
@ -211,12 +175,19 @@ at::Tensor LazyNativeFunctions::_to_copy(
|
||||
auto lazy_self = torch::lazy::TryGetLtcTensor(self);
|
||||
if (!lazy_self && device && device->type() == c10::kLazy) {
|
||||
// Case 1: eager->lazy (we create a new lazy tensor)
|
||||
|
||||
auto eager_tensor =
|
||||
self.to(options, /*non_blocking=*/non_blocking, /*copy=*/true);
|
||||
lazy_self = torch::lazy::GetOrCreateLtcTensor(
|
||||
eager_tensor, torch::lazy::atenDeviceToBackendDevice(*device));
|
||||
return torch::lazy::CreateAtenFromLtcTensor(lazy_self);
|
||||
// See Note [Lazy Tensor Functionalization]
|
||||
// Invariant: if the functionalization key is in the exclude set, then we're
|
||||
// expected to return an ordinary tensor, which will be "lifted" into a
|
||||
// functional wrapper later.
|
||||
bool functionalize_output =
|
||||
!c10::impl::tls_local_dispatch_key_set().excluded_.has(
|
||||
c10::DispatchKey::Functionalize);
|
||||
return torch::lazy::to_lazy_tensor(
|
||||
self,
|
||||
options,
|
||||
*device,
|
||||
/*non_blocking=*/non_blocking,
|
||||
/*functionalize_output=*/functionalize_output);
|
||||
} else if (device && device->type() != c10::kLazy) {
|
||||
// Case 2: lazy->eager (forces a graph break since we are materializing a
|
||||
// tensor)
|
||||
@ -298,24 +269,6 @@ at::Tensor LazyNativeFunctions::_to_copy(
|
||||
}
|
||||
};

at::Tensor LazyNativeFunctions::diagonal(
    const at::Tensor& self,
    int64_t offset,
    int64_t dim1,
    int64_t dim2) {
  TORCH_LAZY_FN_COUNTER("lazy::");

  auto input = GetLtcTensor(self);
  auto input_shape = input->shape();
  dim1 = at::maybe_wrap_dim(dim1, self);
  dim2 = at::maybe_wrap_dim(dim2, self);
  auto diagonal_info = DiagonalInfo{offset, dim1, dim2};
  auto view_info =
      ViewInfo(ViewInfo::Type::kDiagonal, input_shape, diagonal_info);

  return CreateAtenFromLtcTensor(input->CreateViewTensor(std::move(view_info)));
}

at::Tensor LazyNativeFunctions::empty(
    at::IntArrayRef size,
    c10::optional<at::ScalarType> dtype,
@@ -330,7 +283,18 @@ at::Tensor LazyNativeFunctions::empty(
          .pinned_memory(pin_memory)
          .dtype(dtype);
  auto x_result = at::empty(size, options, memory_format);
  return CreateLtcTensor(x_result, GetLtcDevice(device));
  auto tensor = CreateLtcTensor(x_result, GetLtcDevice(device));
  // See Note [Lazy Tensor Functionalization]
  if (c10::impl::tls_local_dispatch_key_set().excluded_.has(
          c10::DispatchKey::Functionalize)) {
    // Invariant: if the functionalization key is in the exclude set, then we're
    // expected to return an ordinary tensor, which will be "lifted" into a
    // functional wrapper later.
    return tensor;
  } else {
    auto wrapped = at::functionalization::impl::to_functional_tensor(tensor);
    return wrapped;
  }
}

at::Tensor LazyNativeFunctions::empty_strided(
@@ -342,16 +306,7 @@ at::Tensor LazyNativeFunctions::empty_strided(
    c10::optional<bool> pin_memory) {
  TORCH_LAZY_FN_COUNTER("lazy::");
  at::Tensor t = empty(size, dtype, layout, device, pin_memory, c10::nullopt);
  return LazyNativeFunctions::as_strided(t, size, stride, /*storage_offset=*/0);
}

at::Tensor LazyNativeFunctions::expand(
    const at::Tensor& self,
    at::IntArrayRef size,
    bool implicit) {
  TORCH_LAZY_FN_COUNTER("lazy::");
  return torch::lazy::CreateAtenFromLtcTensor(
      torch::lazy::expand(torch::lazy::TryGetLtcTensor(self), size.vec()));
  return t.as_strided(size, stride, /*storage_offset=*/0);
}

at::Tensor& LazyNativeFunctions::fill_(
@@ -412,17 +367,6 @@ at::Tensor LazyNativeFunctions::max_pool3d_with_indices_backward(
      indices);
}

at::Tensor LazyNativeFunctions::narrow(
    const at::Tensor& self,
    int64_t dim,
    int64_t start,
    int64_t length) {
  TORCH_LAZY_FN_COUNTER("lazy::");
  auto self_tensor = torch::lazy::TryGetLtcTensor(self);
  return torch::lazy::CreateAtenFromLtcTensor(
      torch::lazy::narrow(self_tensor, dim, start, length));
}

at::Tensor& LazyNativeFunctions::normal_(
    at::Tensor& self,
    double mean,
@@ -453,124 +397,160 @@ at::Tensor& LazyNativeFunctions::normal_(
    // std::move(shapes)); lazy_self.SetInPlaceIrValue(node); return self;
};

at::Tensor LazyNativeFunctions::permute(
    const at::Tensor& self,
    at::IntArrayRef dims) {
  TORCH_LAZY_FN_COUNTER("lazy::");
  auto self_tensor = torch::lazy::TryGetLtcTensor(self);
  return torch::lazy::CreateAtenFromLtcTensor(
      torch::lazy::permute(self_tensor, torch::lazy::ToI64Vector(dims)));
}

at::Tensor LazyNativeFunctions::select(
    const at::Tensor& self,
    int64_t dim,
    int64_t index) {
  TORCH_LAZY_FN_COUNTER("lazy::");
  return torch::lazy::CreateAtenFromLtcTensor(
      torch::lazy::select(torch::lazy::TryGetLtcTensor(self), dim, index));
}

at::Tensor LazyNativeFunctions::slice(
    const at::Tensor& self,
    int64_t dim,
    c10::optional<int64_t> start,
    c10::optional<int64_t> end,
    int64_t step) {
  int64_t start_val = start.has_value() ? start.value() : 0;
  int64_t end_val = end.has_value() ? end.value() : INT64_MAX;
  TORCH_LAZY_FN_COUNTER("lazy::");
  return torch::lazy::CreateAtenFromLtcTensor(torch::lazy::slice(
      torch::lazy::TryGetLtcTensor(self), dim, start_val, end_val, step));
}

at::Tensor LazyNativeFunctions::squeeze(const at::Tensor& self) {
  TORCH_LAZY_FN_COUNTER("lazy::");
  return torch::lazy::CreateAtenFromLtcTensor(
      torch::lazy::squeeze(torch::lazy::TryGetLtcTensor(self)));
}

at::Tensor LazyNativeFunctions::squeeze(const at::Tensor& self, int64_t dim) {
  TORCH_LAZY_FN_COUNTER("lazy::");
  return torch::lazy::CreateAtenFromLtcTensor(
      torch::lazy::squeeze(torch::lazy::TryGetLtcTensor(self), dim));
}

at::Tensor& LazyNativeFunctions::squeeze_(at::Tensor& self) {
  TORCH_LAZY_FN_COUNTER("lazy::");
  auto self_tensor = torch::lazy::TryGetLtcTensor(self);
  torch::lazy::squeeze_(self_tensor);
  return self;
}

at::Tensor& LazyNativeFunctions::squeeze_(at::Tensor& self, int64_t dim) {
  TORCH_LAZY_FN_COUNTER("lazy::");
  auto self_tensor = torch::lazy::TryGetLtcTensor(self);
  torch::lazy::squeeze_(self_tensor, dim);
  return self;
}

at::Tensor LazyNativeFunctions::t(const at::Tensor& self) {
  TORCH_LAZY_FN_COUNTER("lazy::");
  return torch::lazy::CreateAtenFromLtcTensor(
      torch::lazy::transpose(torch::lazy::TryGetLtcTensor(self), 0, 1));
}

at::Tensor& LazyNativeFunctions::t_(at::Tensor& self) {
  TORCH_LAZY_FN_COUNTER("lazy::");
  auto self_tensor = torch::lazy::TryGetLtcTensor(self);
  torch::lazy::transpose_(self_tensor, 0, 1);
  return self;
}

at::Tensor LazyNativeFunctions::transpose(
    const at::Tensor& self,
    int64_t dim0,
    int64_t dim1) {
  TORCH_LAZY_FN_COUNTER("lazy::");
  return torch::lazy::CreateAtenFromLtcTensor(
      torch::lazy::transpose(torch::lazy::TryGetLtcTensor(self), dim0, dim1));
}

at::Tensor& LazyNativeFunctions::transpose_(
    at::Tensor& self,
    int64_t dim0,
    int64_t dim1) {
  TORCH_LAZY_FN_COUNTER("lazy::");
  auto self_tensor = torch::lazy::TryGetLtcTensor(self);
  torch::lazy::transpose_(self_tensor, dim0, dim1);
  return self;
}

at::Tensor LazyNativeFunctions::unsqueeze(const at::Tensor& self, int64_t dim) {
  TORCH_LAZY_FN_COUNTER("lazy::");
  return torch::lazy::CreateAtenFromLtcTensor(
      torch::lazy::unsqueeze(torch::lazy::TryGetLtcTensor(self), dim));
}

at::Tensor& LazyNativeFunctions::unsqueeze_(at::Tensor& self, int64_t dim) {
  TORCH_LAZY_FN_COUNTER("lazy::");
  auto self_tensor = torch::lazy::TryGetLtcTensor(self);
  torch::lazy::unsqueeze_(self_tensor, dim);
  return self;
}

at::Tensor LazyNativeFunctions::view(
    const at::Tensor& self,
    at::IntArrayRef size) {
  TORCH_LAZY_FN_COUNTER("lazy::");
  auto self_tensor = torch::lazy::TryGetLtcTensor(self);
  return torch::lazy::CreateAtenFromLtcTensor(
      torch::lazy::view(self_tensor, torch::lazy::ToI64Vector(size)));
}

at::Tensor LazyNativeFunctions::_unsafe_view(
    const at::Tensor& self,
    at::IntArrayRef size) {
  TORCH_LAZY_FN_COUNTER("lazy::");
  auto self_tensor = torch::lazy::TryGetLtcTensor(self);
  return torch::lazy::CreateAtenFromLtcTensor(
      torch::lazy::view(self_tensor, torch::lazy::ToI64Vector(size)));
  return LazyNativeFunctions::view_copy(self, size);
}

// This is needed by the torch.tensor constructor.
// LazyTensor always opts into functionalization.
// "lifting" a tensor for functionalization means wrapping it in a
// FunctionalTensorWrapper object.
at::Tensor LazyNativeFunctions::lift(const at::Tensor& tensor) {
  TORCH_INTERNAL_ASSERT(
      !at::functionalization::impl::isFunctionalTensor(tensor));
  return at::functionalization::impl::to_functional_tensor(tensor);
}
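
Editor's illustration (not part of the diff): lift() is what allows the torch.tensor constructor to hand back a wrapped lazy tensor. A rough sketch of the user-visible behavior, assuming the lazy backend has been initialized as in the earlier sketch:

```python
import torch

t = torch.tensor([1.0, 2.0, 3.0], device="lazy")  # internally: build eagerly, then lift() into a wrapper
assert t.device.type == "lazy"
# The wrapper is transparent to the caller; ordinary ops keep working.
print((t * 2).cpu())
```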

// All of the below ops correspond to CompositeExplicitAutograd kernels from
// core that call into view operators internally. These are all composite ops
// that LTC can technically re-use / get for free, but we need to
// "functionalize" them to remove the view ops before we can use them.
at::Tensor LazyNativeFunctions::block_diag(at::TensorList tensors) {
  return at::functionalization::functionalize_aten_op<ATEN_OP(
      block_diag)>::call(tensors);
}
at::Tensor LazyNativeFunctions::new_empty_strided(
    const at::Tensor& self,
    at::IntArrayRef size,
    at::IntArrayRef stride,
    c10::optional<at::ScalarType> dtype,
    c10::optional<at::Layout> layout,
    c10::optional<at::Device> device,
    c10::optional<bool> pin_memory) {
  return at::functionalization::
      functionalize_aten_op<ATEN_OP(new_empty_strided)>::call(
          self, size, stride, dtype, layout, device, pin_memory);
}

at::Tensor LazyNativeFunctions::narrow_copy(
    const at::Tensor& self,
    int64_t dim,
    int64_t start,
    int64_t length) {
  return at::functionalization::functionalize_aten_op<ATEN_OP(
      narrow_copy)>::call(self, dim, start, length);
}
at::Tensor LazyNativeFunctions::pixel_shuffle(
    const at::Tensor& self,
    int64_t upscale_factor) {
  return at::functionalization::functionalize_aten_op<ATEN_OP(
      pixel_shuffle)>::call(self, upscale_factor);
}
at::Tensor LazyNativeFunctions::pixel_unshuffle(
    const at::Tensor& self,
    int64_t downscale_factor) {
  return at::functionalization::functionalize_aten_op<ATEN_OP(
      pixel_unshuffle)>::call(self, downscale_factor);
}
at::Tensor LazyNativeFunctions::select_backward(
    const at::Tensor& grad_output,
    at::IntArrayRef input_sizes,
    int64_t dim,
    int64_t index) {
  return at::functionalization::functionalize_aten_op<ATEN_OP(
      select_backward)>::call(grad_output, input_sizes, dim, index);
}
at::Tensor LazyNativeFunctions::_trilinear(
    const at::Tensor& i1,
    const at::Tensor& i2,
    const at::Tensor& i3,
    at::IntArrayRef expand1,
    at::IntArrayRef expand2,
    at::IntArrayRef expand3,
    at::IntArrayRef sumdim,
    int64_t unroll_dim) {
  return at::functionalization::functionalize_aten_op<ATEN_OP(_trilinear)>::
      call(i1, i2, i3, expand1, expand2, expand3, sumdim, unroll_dim);
}
::std::tuple<at::Tensor, at::Tensor> LazyNativeFunctions::linalg_inv_ex(
    const at::Tensor& self,
    bool check_errors) {
  return at::functionalization::functionalize_aten_op<ATEN_OP(
      linalg_inv_ex)>::call(self, check_errors);
}
at::Tensor LazyNativeFunctions::linalg_pinv(
    const at::Tensor& self,
    const c10::optional<at::Tensor>& atol,
    const c10::optional<at::Tensor>& rtol,
    bool hermitian) {
  return at::functionalization::functionalize_aten_op<ATEN_OP2(
      linalg_pinv, atol_rtol_tensor)>::call(self, atol, rtol, hermitian);
}
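
Editor's aside (not part of the diff): routing these composite kernels through functionalize_aten_op means ops such as block_diag or pixel_shuffle should now work on lazy tensors without hand-written lowerings. A hedged usage sketch, assuming an initialized lazy backend:

```python
import torch

a = torch.randn(2, 2).to(device="lazy")
b = torch.randn(3, 3).to(device="lazy")
out = torch.block_diag(a, b)          # dispatched through the functionalized composite kernel
shuffled = torch.pixel_shuffle(torch.randn(1, 4, 2, 2).to(device="lazy"), 2)
print(out.shape, shuffled.shape)      # shapes are known up front, without materializing data
```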

// functionalize_aten_op can't handle out= ops directly.
// Instead, we can call the composite kernel from core, and copy any mutations
// back to the inputs.
at::Tensor& LazyNativeFunctions::logsumexp_out(
    const at::Tensor& self,
    at::IntArrayRef dim,
    bool keepdim,
    at::Tensor& out) {
  auto self_wrapped = at::functionalization::impl::to_functional_tensor(self);
  auto out_wrapped = at::functionalization::impl::to_functional_tensor(out);
  // directly call the composite kernel from core.
  // Make sure to re-enable functionalization first.
  auto curr_tls = c10::impl::tls_local_dispatch_key_set();
  auto tls_reenable_functionalize = c10::impl::PODLocalDispatchKeySet();
  tls_reenable_functionalize.set_included(curr_tls.included_);
  tls_reenable_functionalize.set_excluded(
      curr_tls.excluded_.remove(c10::DispatchKey::Functionalize));
  c10::impl::ForceDispatchKeyGuard guard_(tls_reenable_functionalize);
  at::native::logsumexp_out(self_wrapped, dim, keepdim, out_wrapped);
  auto out_unwrapped =
      at::functionalization::impl::from_functional_tensor(out_wrapped);
  // propagate mutations back to the inputs (including resizing)
  out.resize_(out_unwrapped.sizes());
  out.copy_(out_unwrapped);
  return out;
}
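
Editor's illustration (not part of the diff): the out= handling above is what a call like the following relies on; note that `out` may be resized if its shape does not match. Sketch assumes the lazy backend is initialized:

```python
import torch

x = torch.randn(4, 5).to(device="lazy")
out = torch.empty(0, device="lazy")   # deliberately the wrong shape
torch.logsumexp(x, dim=1, out=out)    # the kernel resizes `out` and copies the result back
print(out.shape)                      # expected: torch.Size([4])
```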

at::Tensor LazyNativeFunctions::diagonal_backward(
    const at::Tensor& grad_output,
    at::IntArrayRef input_sizes,
    int64_t offset,
    int64_t dim1,
    int64_t dim2) {
  return at::functionalization::functionalize_aten_op<ATEN_OP(
      diagonal_backward)>::call(grad_output, input_sizes, offset, dim1, dim2);
}

at::Tensor LazyNativeFunctions::slice_backward(
    const at::Tensor& grad_output,
    at::IntArrayRef input_sizes,
    int64_t dim,
    int64_t start,
    int64_t end,
    int64_t step) {
  return at::functionalization::functionalize_aten_op<ATEN_OP(
      slice_backward)>::call(grad_output, input_sizes, dim, start, end, step);
}

// re-use the composite kernel from core, that way we don't need to provide a
// backwards formula for native_group_norm
std::tuple<Tensor, Tensor, Tensor> LazyNativeFunctions::native_group_norm(
    const at::Tensor& input,
    const c10::optional<at::Tensor>& weight,
    const c10::optional<at::Tensor>& bias,
    int64_t N,
    int64_t C,
    int64_t HxW,
    int64_t group,
    double eps) {
  return at::native::math_group_norm(
      input, weight, bias, N, C, HxW, group, eps);
}
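
Editor's note (not part of the diff): because native_group_norm defers to the decomposed math_group_norm from core, group norm should work on the lazy device, autograd included, without a dedicated backward lowering. A hedged sketch:

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 6, 4, 4).to(device="lazy").requires_grad_(True)
y = F.group_norm(x, num_groups=3)
y.sum().backward()        # gradients come from the decomposition, not a custom backward formula
print(x.grad.shape)
```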

void InitializeAtenBindings() {}

@@ -486,6 +486,25 @@ class CUDATestBase(DeviceTypeTestBase):
        # Acquires the current device as the primary (test) device
        cls.primary_device = 'cuda:{0}'.format(torch.cuda.current_device())

# See Note [Lazy Tensor tests in device agnostic testing]
lazy_ts_backend_init = False
class LazyTestBase(DeviceTypeTestBase):
    device_type = 'lazy'

    def _should_stop_test_suite(self):
        return False

    @classmethod
    def setUpClass(cls):
        import torch._lazy
        import torch._lazy.metrics
        import torch._lazy.ts_backend
        global lazy_ts_backend_init
        if not lazy_ts_backend_init:
            # Need to connect the TS backend to lazy key before running tests
            torch._lazy.ts_backend.init()
            lazy_ts_backend_init = True

class MPSTestBase(DeviceTypeTestBase):
    device_type = 'mps'

@@ -570,7 +589,7 @@ PYTORCH_TESTING_DEVICE_EXCEPT_FOR_KEY = 'PYTORCH_TESTING_DEVICE_EXCEPT_FOR'
# The tests in these test cases are derived from the generic tests in
# generic_test_class.
# See note "Generic Device Type Testing."
def instantiate_device_type_tests(generic_test_class, scope, except_for=None, only_for=None):
def instantiate_device_type_tests(generic_test_class, scope, except_for=None, only_for=None, include_lazy=False):
    # Removes the generic test class from its enclosing scope so its tests
    # are not discoverable.
    del scope[generic_test_class.__name__]
@@ -592,6 +611,14 @@ def instantiate_device_type_tests(generic_test_class, scope, except_for=None, on
    # Filter out the device types based on user inputs
    desired_device_type_test_bases = filter_desired_device_types(device_type_test_bases,
                                                                 except_for, only_for)
    if include_lazy:
        # Note [Lazy Tensor tests in device agnostic testing]
        # Right now, test_view_ops.py runs with LazyTensor.
        # We don't want to opt every device-agnostic test into using the lazy device,
        # because many of them will fail.
        # So instead, the only way to opt a specific device-agnostic test file into
        # lazy tensor testing is with include_lazy=True
        desired_device_type_test_bases.append(LazyTestBase)
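
Editor's note (not part of the diff): a device-generic test file opts into the lazy device explicitly, roughly like this (test_view_ops.py is the file opted in by this PR):

```python
from torch.testing._internal.common_utils import TestCase
from torch.testing._internal.common_device_type import instantiate_device_type_tests

class TestViewOps(TestCase):
    def test_expand(self, device):
        ...  # device-generic test body

instantiate_device_type_tests(TestViewOps, globals(), include_lazy=True)
```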

def split_if_not_empty(x: str):
    return x.split(",") if len(x) != 0 else []

@@ -34,11 +34,12 @@ from torchgen.api.types import (
    tensorT,
    voidT,
    longT,
    SymIntT,
    symIntArrayRefT,
    BaseTypeToCppMapping,
    intArrayRefT,
    optionalIntArrayRefT,
    tensorOptionsT,
    symIntArrayRefT,
)
from torchgen import local
from torchgen.utils import assert_never
@@ -155,6 +156,11 @@ def argumenttype_type(
            return NamedCType(binds, VectorCType(BaseCType(longT)))
        else:
            return NamedCType(binds, BaseCType(intArrayRefT))
    if str(t.elem) == "SymInt":
        if remove_non_owning_ref_types:
            return NamedCType(binds, VectorCType(BaseCType(SymIntT)))
        else:
            return NamedCType(binds, BaseCType(symIntArrayRefT))
    elif str(t.elem) == "Tensor":
        return NamedCType(binds, BaseCType(tensorListT))
    elif str(t.elem) == "Scalar":

@@ -14,6 +14,8 @@ from torchgen.api.types import (
    memoryFormatT,
    tensorOptionsT,
    scalarTypeT,
    SymIntT,
    symIntArrayRefT,
    boolT,
    deviceT,
    layoutT,
@@ -63,6 +65,7 @@ options_ctype = NamedCType("options", ConstRefCType(BaseCType(tensorOptionsT)))
out_tensor_ctype = NamedCType("out", ConstRefCType(BaseCType(tensorT)))

longVec_ctype = VectorCType(BaseCType(longT))
longSymVec_ctype = VectorCType(BaseCType(SymIntT))
optionalLongVec_ctype = OptionalCType(VectorCType(BaseCType(longT)))
optionalScalar_ctype = OptionalCType(BaseCType(scalarT))
optionalTensor_ctype = OptionalCType(BaseCType(tensorT))
@@ -324,7 +327,19 @@ Check this module for more information.

    # We can always do translations from value types to reference types, like vector<int> -> IntArrayRef
    elif goal.type == BaseCType(intArrayRefT):
        return direct_solve(NamedCType(goal.name, longVec_ctype))
        try:
            return direct_solve(NamedCType(goal.name, longVec_ctype))
        except UnsatError:
            # We can also go SymIntArrayRef -> IntArrayRef
            symIntArrayRef_type = direct_solve(
                NamedCType(goal.name, BaseCType(symIntArrayRefT))
            )
            return f"c10::asIntArrayRefSlow({symIntArrayRef_type})"
    elif goal.type == BaseCType(symIntArrayRefT):
        return direct_solve(NamedCType(goal.name, longSymVec_ctype))
    elif goal.type == BaseCType(longT):
        symInt_type = direct_solve(NamedCType(goal.name, BaseCType(SymIntT)))
        return f"{symInt_type}.expectInt()"
    elif goal.type == BaseCType(optionalIntArrayRefT):
        return direct_solve(NamedCType(goal.name, optionalLongVec_ctype))
    elif goal.type == BaseCType(optionalScalarRefT):
@@ -345,6 +360,10 @@ Check this module for more information.
        intArrayRef_ctype = NamedCType(goal.name, BaseCType(intArrayRefT))
        argname = direct_solve(intArrayRef_ctype)
        return f"{argname}.vec()"
    if goal.type == VectorCType(BaseCType(SymIntT)):
        symIntArrayRef_ctype = NamedCType(goal.name, BaseCType(symIntArrayRefT))
        argname = direct_solve(symIntArrayRef_ctype)
        return f"{argname}.vec()"
    elif goal.type == OptionalCType(VectorCType(BaseCType(longT))):
        optionalIntArrayRef_ctype = NamedCType(
            goal.name, BaseCType(optionalIntArrayRefT)
@@ -1,21 +1,26 @@
from abc import ABC
import itertools
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union, Tuple
from torchgen.context import method_with_native_function
from torchgen.model import (
    FunctionSchema,
    Argument,
    BackendIndex,
    NativeFunction,
    NativeFunctionsGroup,
    FunctionSchema,
)
from torchgen.api.types import (
    BaseCType,
    Binding,
    DispatcherSignature,
    OptionalCType,
    VectorCType,
    kernel_signature,
    deviceT,
)
import torchgen.api.dispatcher as dispatcher
from torchgen.api.translate import translate
from torchgen.api.lazy import (
    LazyIrProperties,
    LazyIrSchema,
@@ -121,6 +126,25 @@ def aten_symbol(schema: LazyIrSchema) -> str:
    return schema.aten_name


# converts all tensor-like arguments to meta tensors. Returns:
# (1) a string containing all of the logic that does the conversions.
# (2) a context, to be used by translate(), with all of the relevant bindings.
def convert_to_meta_tensors(sig: DispatcherSignature) -> Tuple[str, List[Binding]]:
    context: List[Binding] = []
    unwrapped_tensor_args: List[str] = []
    for arg in sig.arguments():
        if isinstance(arg.argument, Argument) and arg.argument.type.is_tensor_like():
            unwrapped_name = f"{arg.name}_meta"
            unwrapped_tensor_args.append(
                f"auto {unwrapped_name} = to_meta({arg.name});"
            )
            context.append(arg.with_name(unwrapped_name))
        else:
            context.append(arg)
    unwrap_tensor_args_str = "\n ".join(unwrapped_tensor_args)
    return unwrap_tensor_args_str, context
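
To make the helper above concrete, here is an editor's toy illustration (not the real torchgen objects): for a hypothetical schema add(Tensor self, Tensor other, Scalar alpha), the generated conversion block and the renamed bindings would look roughly like this:

```python
# Toy stand-in for DispatcherSignature.arguments(): (name, is_tensor_like) pairs.
args = [("self", True), ("other", True), ("alpha", False)]

conversions, context = [], []
for name, is_tensor_like in args:
    if is_tensor_like:
        conversions.append(f"auto {name}_meta = to_meta({name});")
        context.append(f"{name}_meta")
    else:
        context.append(name)

print("\n".join(conversions))
# auto self_meta = to_meta(self);
# auto other_meta = to_meta(other);
print(context)  # ['self_meta', 'other_meta', 'alpha']
```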


@dataclass(frozen=True)
class GenLazyIR(ABC):
    backend_index: BackendIndex
@@ -206,7 +230,13 @@ class GenLazyIR(ABC):
        node_ctor_args = ", ".join(ctor_args)

        scalar_initializers = ",\n ".join(
            f"{a.name}({a.name})" for a in scalar_args
            [
                # This code is just special casing the mapping from string_view -> strings
                f"{a.name}({a.name}.has_value() ? c10::make_optional(std::string(*{a.name})) : c10::nullopt)"
                if a.lazy_type.cpp_type() == "c10::optional<c10::string_view>"
                else f"{a.name}({a.name})"
                for a in scalar_args
            ]
        )
        if len(scalar_initializers):
            scalar_initializers = f",\n {scalar_initializers}"
@@ -214,6 +244,8 @@ class GenLazyIR(ABC):
            [
                f"std::string {a.name};"
                if a.lazy_type.cpp_type() == "c10::string_view"
                else f"c10::optional<std::string> {a.name};"
                if a.lazy_type.cpp_type() == "c10::optional<c10::string_view>"
                else f"{a.lazy_type.cpp_type()} {a.name};"
                for a in scalar_args
            ]
@@ -314,19 +346,20 @@ class GenTSLazyIR(GenLazyIR):
        elif not schema.properties.CanBeReused:
            return ""
        value_comparison = []
        for arg in schema.positional_values:
        for arg in itertools.chain(schema.positional_values, schema.keyword_values):
            if isinstance(arg.lazy_type, OptionalCType):
                value_comparison.append(
                    f"nullable_operand(i++) == {arg.name}.value_or(kNullValue)"
                )
            else:
                value_comparison.append(f"operand(i++) == {arg.name}")
        for arg in schema.positional_scalars:
            value_comparison.append(f"this->{arg.name} == {arg.name}")
        for arg in schema.keyword_values:
            value_comparison.append(f"operand(i++) == {arg.name}")
        for arg in schema.keyword_scalars:
            value_comparison.append(f"this->{arg.name} == {arg.name}")
        for arg in itertools.chain(schema.positional_scalars, schema.keyword_scalars):
            if isinstance(arg.lazy_type, OptionalCType):
                value_comparison.append(
                    f"((!this->{arg.name}&&!{arg.name}) || (this->{arg.name}&&{arg.name} && *(this->{arg.name}) == *{arg.name}))"
                )
            else:
                value_comparison.append(f"this->{arg.name} == {arg.name}")
        value_comparison_str = " &&\n ".join(value_comparison)

        return f"""{signature} {{
@@ -428,9 +461,20 @@ class GenLazyNativeFuncDefinition:
        all_args = schema.filtered_args()
        returns_length = len(schema.returns)
        # call the meta kernel if it exists, to compute output shape/dtype for our IR
        if func.structured or func.structured_delegate is not None:
            meta_out = """std::vector<torch::lazy::Shape> shapes{
        torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};"""
        # Note [Generated LTC Shape Functions]
        # LTC uses meta tensors from core to do shape inference when possible, and otherwise
        # we generate a shape function declaration that needs to be manually implemented.
        # How do we detect which ops are eligible to use meta tensors?
        # In general we should be able to use meta tensors not just on structured operators,
        # but also on composite operators that are implemented in terms of structured kernels.
        # We don't currently have a way of knowing at codegen time which ops are implemented that way.
        # This is the case for all view and view_copy operators however, so we're going to
        # use them specifically for all of the view_copy ops (instead of manually writing shape rules for all of them).
        is_view_copy_op = "view_copy" in func.tags
        is_structured = func.structured or func.structured_delegate is not None
        if is_structured or is_view_copy_op:
            meta_out = """
std::vector<torch::lazy::Shape> shapes{torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};"""
            if returns_length > 1:

                def this_shape(i: int) -> str:
@@ -439,8 +483,28 @@ class GenLazyNativeFuncDefinition:
                shapes_str = ",".join([this_shape(i) for i in range(returns_length)])
                meta_out = "std::vector<torch::lazy::Shape> shapes{" + shapes_str + "};"

            shape_str = f"""auto out_meta = at::meta::{schema.aten_name}({', '.join(str(a.name) for a in all_args)});
        {meta_out}"""
            # Convert tensor args to the meta device and call it.
            # (We can't pass in the input tensors directly, because they are "functional wrappers".
            # If any of the meta kernels call a tensor op and redispatch, we don't want to hit the functionalize kernels.)
            # Even at::meta:: functions might redispatch, e.g. if they call into view ops.
            dispatcher_sig = DispatcherSignature.from_schema(func.func)
            meta_conversion_str, meta_call_ctx = convert_to_meta_tensors(dispatcher_sig)
            meta_call_args = [
                e.expr
                for e in translate(
                    meta_call_ctx, dispatcher_sig.arguments(), method=False
                )
            ]
            if is_view_copy_op:
                # view_copy ops always have a CompositeExplicitAutogradNonFunctional kernel
                assert func.has_composite_explicit_autograd_non_functional_kernel
                dispatch_ns = "compositeexplicitautogradnonfunctional"
            else:
                dispatch_ns = "meta"
            shape_str = f"""\
        {meta_conversion_str}
        auto out_meta = at::{dispatch_ns}::{schema.aten_name}({', '.join(meta_call_args)});
        {meta_out}"""
        else:
            shape_sig = ComputeShapeSignature(metadata.kernel, func)
            shape_str = f"""
@@ -571,13 +635,14 @@ class GenLazyShapeInferenceDefinition:
        metadata = self.backend_index.get_kernel(f)
        assert metadata is not None

        # Only generate shape/dtype fn for non-structured kernels,
        # since we just use the meta function for structured kernels
        if not f.structured and f.structured_delegate is None:
        # See Note [Generated LTC Shape Functions]
        is_view_copy_op = "view_copy" in f.tags
        is_structured = f.structured or f.structured_delegate is not None
        if is_structured or is_view_copy_op:
            return []
        else:
            shape_sig = ComputeShapeSignature(metadata.kernel, f)
            return ["\n".join([f"{shape_sig.shape_decl};"])]
        else:
            return []
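
Editor's illustration of the effect of Note [Generated LTC Shape Functions] (not part of the diff): for structured and view_copy ops the codegen emits a meta-device call instead of declaring a shape function; for everything else only a shape-inference declaration is produced and must be implemented by hand. The strings below are a rough sketch with illustrative names, not the exact generated output:

```python
# Roughly what `shape_str` expands to for a view_copy op:
generated_for_view_copy = """\
auto self_meta = to_meta(self);
auto out_meta = at::compositeexplicitautogradnonfunctional::view_copy(self_meta, size);
std::vector<torch::lazy::Shape> shapes{
    torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};"""

# Roughly what GenLazyShapeInferenceDefinition emits for a non-structured op instead:
generated_decl = "TORCH_API std::vector<torch::lazy::Shape> compute_shape_my_op(const at::Tensor& self);"
```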


def generate_non_native_lazy_ir_nodes(

@@ -77,6 +77,7 @@ from torchgen.gen_functionalization_type import (
    gen_functionalization_registration,
    gen_functionalization_view_inverse_declaration,
    gen_composite_view_copy_kernel,
    gen_symint_view_copy_kernel,
)

T = TypeVar("T")
@@ -2281,6 +2282,24 @@ TORCH_LIBRARY({custom_namespace}, m) {{
            )
        },
    )
    view_copy_with_symint_pairs: List[Tuple[NativeFunction, NativeFunction]] = []
    for g1 in view_groups:
        for g2 in view_groups:
            if g1.view_copy is None or g2.view_copy is None:
                continue
            # TODO: make this more first class in the data model
            same_base_op = str(g1.view_copy.func.name.name) == str(
                g2.view_copy.func.name.name
            )
            op1_not_symint = "SymInt" not in str(g1.view_copy.func.name.overload_name)
            op2_symint = "SymInt" in str(g2.view_copy.func.name.overload_name)
            if same_base_op and op1_not_symint and op2_symint:
                view_copy_with_symint_pairs.append(
                    (
                        g1.view_copy,
                        g2.view_copy,
                    )
                )

    # Note [view_copy NativeFunctions]
    # Every view operator in native_functions.yaml that is not CompositeImplicitAutograd
@@ -2321,6 +2340,12 @@ TORCH_LIBRARY({custom_namespace}, m) {{
            "CompositeViewCopyKernel_Definitions": list(
                mapMaybe(gen_composite_view_copy_kernel, view_groups)
            ),
            "SymIntViewCopyKernel_Definitions": list(
                mapMaybe(
                    lambda pair: gen_symint_view_copy_kernel(pair[0], pair[1]),
                    view_copy_with_symint_pairs,
                )
            ),
            "GeneratedCompositeFunctional_Definitions": list(
                mapMaybe(
                    gen_composite_functional_kernel,
@@ -265,7 +265,10 @@ def error_on_missing_kernels(
                native_f
            )

    kernel_defn_regex = rf"{class_name}::([\w\d]*)\([^\)]*\)\s*{{"
    # This just looks for lines containing "foo(", and assumes that the kernel foo has been implemented.
    # It might cause false negatives (we won't catch all cases), but that's ok - if we catch a missing kernel
    # here, then we get a nicer error message. If we miss it, you get a linker error.
    kernel_defn_regex = rf"{class_name}::\s*([\w\d]*)\("
    actual_backend_kernel_name_counts = Counter(
        re.findall(kernel_defn_regex, backend_defns)
    )
@@ -75,8 +75,30 @@ def gen_composite_view_copy_kernel(g: NativeFunctionsViewGroup) -> Optional[str]

    if g.view_copy is None:
        return None

    # For view_copy.SymInt overloads,
    # See gen_symint_view_copy_kernel.
    if g.view_copy.func.name.overload_name == "SymInt":
        return None

    # We can make view_copy work in more cases by using reshape()
    # when a normal view call would ordinarily fail.
    # This also makes LTC more efficient, because they don't need to include
    # clone() calls in their graph (which is normally needed by reshape).
    if str(g.view_copy.func.name) == "view_copy":
        return """\
at::Tensor view_copy(const at::Tensor & self, at::IntArrayRef size) {
  if (!at::detail::computeStride(self.sizes(), self.strides(), size).has_value()) {
    return self.reshape(size);
  } else {
    auto output = at::_ops::view::call(self, size);
    return output.clone();
  }
}
"""
    # view_copy is a native signature, since we're generating an at::native:: kernel
    view_copy_sig = NativeSignature(g.view_copy.func)

    # view is a dispatcher signature, since we're calling into the at::_ops API
    view_sig = DispatcherSignature(g.view.func)
@@ -113,6 +135,34 @@ def gen_composite_view_copy_kernel(g: NativeFunctionsViewGroup) -> Optional[str]
"""


# For symint view copy kernels, we want to generate them to call into
# their concrete view_copy counterparts.
@with_native_function_and
def gen_symint_view_copy_kernel(
    view_copy: NativeFunction, view_copy_symint: NativeFunction
) -> str:
    # view_copy.symint is a native signature, since we're generating an at::native:: kernel
    view_copy_symint_sig = NativeSignature(view_copy_symint.func)

    # view_copy is a dispatcher signature, since we're calling into the at::_ops API
    view_copy_sig = DispatcherSignature(view_copy.func)

    exprs = ", ".join(
        [
            e.expr
            for e in translate(
                view_copy_symint_sig.arguments(), view_copy_sig.arguments()
            )
        ]
    )

    return f"""
{view_copy_symint_sig.defn()} {{
  return at::_ops::{view_copy.func.name.unambiguous_name()}::call({exprs});
}}
"""


def return_str(rets: Tuple[Return, ...], names: List[str]) -> str:
    assert len(rets) == len(names)
    if len(rets) == 0:
@@ -162,6 +162,39 @@ and implement it in the the corresponding shape_inference.cpp file.\n
    )


# Some helper functions for the codegen.
def get_ltc_helper_fns() -> str:
    return """\
at::Tensor to_meta(const at::Tensor& tensor) {
  // undefined tensors can't be converted to the meta device, since they don't have sizes/strides
  if (!tensor.defined()) return tensor;
  auto out = at::native::empty_strided_meta(tensor.sizes(), tensor.strides(), \
/*dtype=*/c10::make_optional(tensor.scalar_type()), /*layout=*/c10::make_optional(tensor.layout()), \
/*device=*/c10::make_optional(c10::Device(c10::kMeta)), /*pin_memory=*/c10::nullopt);
  // needs to handle wrapped numbers, so dtype promotion works properly.
  if (tensor.unsafeGetTensorImpl()->is_wrapped_number()) {
    out.unsafeGetTensorImpl()->set_wrapped_number(true);
  }
  return out;
}
c10::optional<at::Tensor> to_meta(const c10::optional<at::Tensor>& tensor) {
  if (tensor.has_value()) {
    return to_meta(*tensor);
  }
  return c10::nullopt;
}

std::vector<at::Tensor> to_meta(const at::TensorList& t_list) {
  std::vector<at::Tensor> outs;
  outs.reserve(t_list.size());
  for (const auto& i : c10::irange(t_list.size())) {
    outs.push_back(to_meta(t_list[i]));
  }
  return outs;
}
"""


class default_args:
    node_base: str = "Node"
    node_base_hdr: Optional[str] = None
@@ -436,6 +469,9 @@ def run_gen_lazy_tensor(
                tensor_class_hdr,
                shape_inference_hdr,
                "ATen/Functions.h",
                "ATen/native/TensorConversions.h",
                "ATen/NativeFunctions.h",
                "ATen/CompositeExplicitAutogradNonFunctionalFunctions.h",
                "ATen/MetaFunctions.h",
                "ATen/Operators.h",
                "ATen/native/CPUFallback.h",
@@ -452,6 +488,7 @@ def run_gen_lazy_tensor(
                    else []
                )
            ],
            "helper_fns": get_ltc_helper_fns(),
            "native_functions_include": "",
            "namespace_prologue": ns_helper.prologue,
            "namespace_epilogue": ns_helper.epilogue,