functionalization <> LTC integration (take 3) (#80251)

new PR for https://github.com/pytorch/pytorch/pull/75527.

It looks like there's a bug in the Windows CI scripts that was causing
flaky failures that disappear when I create a new PR. Example failure:
https://github.com/pytorch/pytorch/runs/6999272635?check_suite_focus=true
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80251
Approved by: https://github.com/wconstab
Brian Hirsh
2022-06-26 07:42:42 -07:00
committed by PyTorch MergeBot
parent 33761c80d2
commit c2d395cf8e
27 changed files with 1059 additions and 294 deletions

View File

@ -22,6 +22,9 @@ void FunctionalTensorWrapper::set_constructor_metadata() {
refresh_numel();
refresh_contiguous();
storage_access_should_throw_ = false;
// In general, the sizes/stride metadata on a tensor can change as it is mutated,
// and these changes need to be reflected in the metadata of the wrapper.
set_allow_tensor_metadata_change(true);
key_set_ = c10::DispatchKeySet(c10::DispatchKey::Functionalize) | value_.key_set();
// All of the keys corresponding to functorch transforms should not be copied over.
// Functorch transforms all have their own wrapper tensors (e.g. BatchedTensorImpl) which expect
@ -180,8 +183,12 @@ void FunctionalTensorWrapper::replace_(const Tensor& other) {
value_ = other;
// out= ops are allowed to resize the output tensors, mutating both the data and metadata of the tensor.
// We need to propagate that metadata mutation to the wrapper (new size).
set_sizes_and_strides(value_.sizes(), value_.strides());
set_storage_offset(value_.storage_offset());
if (sizes() != value_.sizes() || strides() != value_.strides()) {
set_sizes_and_strides(value_.sizes(), value_.strides());
}
if (storage_offset() != value_.storage_offset()) {
set_storage_offset(value_.storage_offset());
}
if (dtype() != value_.unsafeGetTensorImpl()->dtype() || layout() != value_.unsafeGetTensorImpl()->layout()) {
value_ = value_.to(c10::TensorOptions().dtype(dtype()).layout(layout()));
}
@ -260,6 +267,44 @@ const char* FunctionalTensorWrapper::tensorimpl_type_name() const {
return "FunctionalTensorWrapper";
}
template <typename VariableVersion>
c10::intrusive_ptr<TensorImpl> FunctionalTensorWrapper::shallow_copy_and_detach_core(
VariableVersion&& version_counter,
bool allow_tensor_metadata_change) const {
if (key_set_.has(DispatchKey::Python) &&
!c10::impl::tls_is_dispatch_key_excluded(DispatchKey::Python)) {
auto r = pyobj_interpreter_.load(std::memory_order_acquire)->detach(this);
if (r) {
r->set_version_counter(std::forward<VariableVersion>(version_counter));
r->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
return r;
}
}
auto impl = c10::make_intrusive<FunctionalTensorWrapper>(value_);
copy_tensor_metadata(
/*src_impl=*/this,
/*dest_impl=*/impl.get(),
/*version_counter=*/std::forward<VariableVersion>(version_counter),
/*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
impl->refresh_numel();
impl->refresh_contiguous();
return impl;
}
c10::intrusive_ptr<TensorImpl> FunctionalTensorWrapper::shallow_copy_and_detach(
const c10::VariableVersion& version_counter,
bool allow_tensor_metadata_change) const {
return shallow_copy_and_detach_core(
version_counter, allow_tensor_metadata_change);
}
c10::intrusive_ptr<TensorImpl> FunctionalTensorWrapper::shallow_copy_and_detach(
c10::VariableVersion&& version_counter,
bool allow_tensor_metadata_change) const {
return shallow_copy_and_detach_core(
std::move(version_counter), allow_tensor_metadata_change);
}
at::IntArrayRef FunctionalTensorWrapper::sizes_custom() const {
return value_.unsafeGetTensorImpl()->sizes();
}
@ -275,6 +320,9 @@ int64_t FunctionalTensorWrapper::numel_custom() const {
bool FunctionalTensorWrapper::is_contiguous_custom(at::MemoryFormat memory_format) const {
return value_.unsafeGetTensorImpl()->is_contiguous();
}
c10::SymIntArrayRef FunctionalTensorWrapper::sym_sizes() const {
return value_.unsafeGetTensorImpl()->sym_sizes();
}
c10::SymIntArrayRef FunctionalTensorWrapper::sym_sizes_custom() const {
return value_.unsafeGetTensorImpl()->sym_sizes();
}
@ -329,7 +377,7 @@ std::vector<Tensor> to_functional_tensor(const TensorList& t_list) {
Tensor from_functional_tensor(const Tensor& tensor, bool assert_functional) {
// Note [Wrapped Numbers <> Functionalization]
if (tensor.unsafeGetTensorImpl()->is_wrapped_number()) {
if (!tensor.defined() || tensor.unsafeGetTensorImpl()->is_wrapped_number()) {
return tensor;
}
if (isFunctionalTensor(tensor)) {
@ -454,41 +502,59 @@ bool isFunctionalTensor(const c10::optional<Tensor>& t) {
bool isFunctionalTensor(const c10::List<Tensor>& t_list) {
if (t_list.size() == 0) return false;
bool any_functional = isFunctionalTensor(t_list[0]);
for (const auto i : c10::irange(1, t_list.size())) {
auto curr_functional = isFunctionalTensor(t_list[i]);
TORCH_INTERNAL_ASSERT(
curr_functional == any_functional,
"Functionalization encountered a list of tensors where some are functional",
"and some are not, which is not currently unsupported.");
auto functional_count = 0;
auto nonfunctional_count = 0;
for (const auto i : c10::irange(t_list.size())) {
if (!t_list[i].defined()) continue;
if (isFunctionalTensor(t_list[i])) {
++functional_count;
} else {
++nonfunctional_count;
}
}
return any_functional;
TORCH_INTERNAL_ASSERT(
functional_count == 0 || nonfunctional_count == 0,
"Functionalization encountered a list of tensors where some are functional",
"and some are not, which is not currently unsupported.");
return functional_count > 0;
}
bool isFunctionalTensor(const c10::List<c10::optional<Tensor>>& t_list) {
if (t_list.size() == 0) return false;
bool any_functional = isFunctionalTensor(t_list[0]);
for (const auto i : c10::irange(1, t_list.size())) {
auto curr_functional = isFunctionalTensor(t_list[i]);
TORCH_INTERNAL_ASSERT(
curr_functional == any_functional,
"Functionalization encountered a list of tensors where some are functional",
"and some are not, which is not currently unsupported.");
auto functional_count = 0;
auto nonfunctional_count = 0;
for (const auto i : c10::irange(t_list.size())) {
if (!t_list[i].has_value() || !t_list[i]->defined()) continue;
if (isFunctionalTensor(t_list[i])) {
++functional_count;
} else {
++nonfunctional_count;
}
}
return any_functional;
TORCH_INTERNAL_ASSERT(
functional_count == 0 || nonfunctional_count == 0,
"Functionalization encountered a list of tensors where some are functional",
"and some are not, which is not currently unsupported.");
return functional_count > 0;
}
bool isFunctionalTensor(const c10::ArrayRef<Tensor> t_list) {
if (t_list.size() == 0) return false;
bool any_functional = isFunctionalTensor(t_list[0]);
for (const auto i : c10::irange(1, t_list.size())) {
auto curr_functional = isFunctionalTensor(t_list[i]);
TORCH_INTERNAL_ASSERT(
curr_functional == any_functional,
"Functionalization encountered a list of tensors where some are functional",
"and some are not, which is not currently unsupported.");
auto functional_count = 0;
auto nonfunctional_count = 0;
for (const auto i : c10::irange(t_list.size())) {
if (!t_list[i].defined()) continue;
if (isFunctionalTensor(t_list[i])) {
++functional_count;
} else {
++nonfunctional_count;
}
}
return any_functional;
TORCH_INTERNAL_ASSERT(
functional_count == 0 || nonfunctional_count == 0,
"Functionalization encountered a list of tensors where some are functional",
"and some are not, which is not currently unsupported.");
return functional_count > 0;
}
Tensor create_functional_tensor_with_view_meta(const at::Tensor& view_to_wrap, const at::Tensor& base, functionalization::ViewMeta meta, int64_t out_idx) {
@ -552,5 +618,93 @@ void setFunctionalizationReapplyViewsTLS(bool reapply_views) {
}
} // namespace impl
// Given an **out-of-place** op that might internally call view/inplace ops,
// this function will "functionalize" it.
// That is, it will call the operator, while removing any intermediate views/mutations
// that are performed inside of it.
// This is useful for LTC/XLA, which would like to re-use some of our composite kernels
// from pytorch core but not have to worry about the view ops that they might call.
// e.g. at::block_diag
void functionalize_op_helper(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
const auto& schema = op.schema();
const auto num_arguments = schema.arguments().size();
const auto arguments_begin = stack->size() - num_arguments;
auto arguments = torch::jit::last(stack, num_arguments);
// Wrap all tensor-like inputs into FunctionalTensorWrappers.
// When we re-invoke the dispatcher, this will automatically enable the functionalization pass.
for (uint64_t idx = 0; idx < num_arguments; ++idx) {
const auto& ivalue = arguments[idx];
if (ivalue.isTensor()) {
auto t = ivalue.toTensor();
if (t.defined()) {
TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t),
"The composite op functionalization fallback expects its inputs all not to be functional tensors");
auto t_new = c10::IValue(at::functionalization::impl::to_functional_tensor(t));
(*stack)[arguments_begin + idx] = t_new;
}
} else if (ivalue.isTensorList()) {
auto tensors = ivalue.toTensorList();
TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(tensors),
"The composite op functionalization fallback expects its inputs all not to be functional tensors");
auto t_new = c10::IValue(at::functionalization::impl::to_functional_tensor(tensors));
(*stack)[arguments_begin + idx] = t_new;
} else if (ivalue.isOptionalTensorList()) {
auto opt_tensors = ivalue.toOptionalTensorList();
TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(opt_tensors),
"The composite op functionalization fallback expects its inputs all not to be functional tensors");
auto t_new = c10::IValue(at::functionalization::impl::to_functional_tensor(opt_tensors));
(*stack)[arguments_begin + idx] = t_new;
}
}
{
// Today when you call at::empty(device=lazy), the lazy backend decides whether or not to wrap
// the output in a functional tensor based on TLS.
// In this code, we're re-entrantly entering functionalization in the same call-stack,
// so we need to manually fix up TLS as if it hadn't already been called.
auto curr_tls = c10::impl::tls_local_dispatch_key_set();
auto tls_reenable_functionalize = c10::impl::PODLocalDispatchKeySet();
tls_reenable_functionalize.set_included(curr_tls.included_);
tls_reenable_functionalize.set_excluded(curr_tls.excluded_.remove(c10::DispatchKey::Functionalize));
c10::impl::ForceDispatchKeyGuard guard_(tls_reenable_functionalize);
// We should probably provide a way to directly call a kernel registered to
// the `CompositeExplicitAutograd` key.
// We can't do that today, so redispatching to the Meta key below should be a reasonably good proxy
// (it won't work in cases where an op has both a CompositeExplicitAutograd kernel
// AND a dedicated meta kernel, but that probably shouldn't ever happen).
op.redispatchBoxed(c10::DispatchKeySet(c10::DispatchKey::Meta), stack);
}
const auto num_returns = schema.returns().size();
const auto returns_begin = stack->size() - num_returns;
auto returns = torch::jit::last(stack, num_returns);
for (const auto idx : c10::irange(num_returns)) {
const auto& ivalue = returns[idx];
if (ivalue.isTensor()) {
auto t = ivalue.toTensor();
if (!t.defined()) continue;
at::functionalization::impl::sync(t);
auto t_new = c10::IValue(at::functionalization::impl::from_functional_tensor(t));
(*stack)[returns_begin + idx] = t_new;
} else if (ivalue.isTensorList()) {
auto tensors = ivalue.toTensorList();
at::functionalization::impl::sync(tensors);
auto t_new = c10::IValue(at::functionalization::impl::from_functional_tensor(tensors));
(*stack)[returns_begin + idx] = t_new;
} else if (ivalue.isOptionalTensorList()) {
auto opt_tensors = ivalue.toOptionalTensorList();
at::functionalization::impl::sync(opt_tensors);
auto t_new = c10::IValue(at::functionalization::impl::from_functional_tensor(opt_tensors));
(*stack)[returns_begin + idx] = t_new;
}
}
}
} // namespace functionalization
} // namespace at

View File

@ -4,6 +4,8 @@
#include <ATen/ArrayRef.h>
#include <ATen/FunctionalStorageImpl.h>
#include <ATen/core/List.h>
#include <ATen/core/boxing/KernelFunction.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <c10/core/DispatchKey.h>
@ -120,6 +122,14 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
// See Note[resize_() in functionalization pass]
void maybe_replace_storage(const Tensor& other);
c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
const c10::VariableVersion& version_counter,
bool allow_tensor_metadata_change) const override;
c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
c10::VariableVersion&& version_counter,
bool allow_tensor_metadata_change) const override;
~FunctionalTensorWrapper() override = default;
// FunctionalTensorWrapper overrides all custom size/stride function,
@ -130,6 +140,7 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
int64_t dim_custom() const override;
int64_t numel_custom() const override;
bool is_contiguous_custom(at::MemoryFormat memory_format) const override;
c10::SymIntArrayRef sym_sizes() const override;
c10::SymIntArrayRef sym_sizes_custom() const override;
private:
@ -137,6 +148,16 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
void set_constructor_metadata();
functionalization::FunctionalStorageImpl* functional_storage_impl() const;
// This is used to re-implement shallow_copy_and_detach for
// FunctionalTensorWrapper. The implementation is identical, but we just need
// to return a subclass instead of a plain TensorImpl.
// TODO: maybe it's possible to arrange for that to happen automatically
// without an override here?
template <typename VariableVersion>
c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach_core(
VariableVersion&& version_counter,
bool allow_tensor_metadata_change) const;
// Note that value is not taken by reference: internally, the wrapper will
// change the value tensor that it points to over time.
Tensor value_;
@ -251,5 +272,37 @@ class TORCH_API FunctionalizationReapplyViewsGuard {
};
} // namespace impl
// Helper function to call an out-of-place composite aten kernel that may use
// mutations / views internally, and functionalize them.
TORCH_API void functionalize_op_helper(
const c10::OperatorHandle& op,
torch::jit::Stack* stack);
template <class Op, class ReturnType, class... ParameterTypes>
struct _functionalize_aten_op final {};
template <class Op, class ReturnType, class... ParameterTypes>
struct _functionalize_aten_op<Op, ReturnType(ParameterTypes...)> final {
static ReturnType call(ParameterTypes... args) {
auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow(
(const char*)Op::name, (const char*)Op::overload_name)
.typed<ReturnType(ParameterTypes...)>();
return c10::impl::BoxedKernelWrapper<ReturnType(ParameterTypes...)>::call(
c10::KernelFunction::make_boxed_function<functionalize_op_helper>,
nullptr,
op,
// BoxedKernelWrapper knows to ignore this keyset argument,
// because functionalize_op_helper doesn't take in a DispatchKeySet
c10::DispatchKeySet(),
args...);
}
};
template <class Op>
using functionalize_aten_op = _functionalize_aten_op<Op, typename Op::schema>;
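// Example usage (illustrative sketch; the kernel name below is hypothetical): a backend
// that wants to reuse the composite at::block_diag kernel from core, with its internal
// view/inplace ops functionalized away, could write its kernel as:
//
//   at::Tensor my_block_diag(at::TensorList tensors) {
//     return at::functionalization::functionalize_aten_op<ATEN_OP(
//         block_diag)>::call(tensors);
//   }
//
// (The LTC kernels later in this PR use exactly this pattern.)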
} // namespace functionalization
} // namespace at

View File

@ -10,6 +10,7 @@
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_to_copy.h>
#include <ATen/ops/to_native.h>
#include <ATen/ops/resize.h>
#include <ATen/ops/as_strided.h>
@ -32,7 +33,7 @@ namespace {
if (ivalue.isTensor()) {
any_tensor_inputs = true;
auto t = ivalue.toTensor();
if (at::functionalization::impl::isFunctionalTensor(t)) {
if (t.defined() && at::functionalization::impl::isFunctionalTensor(t)) {
any_functional_inputs = true;
at::functionalization::impl::sync(t);
auto t_new = c10::IValue(at::functionalization::impl::from_functional_tensor(t));
@ -73,6 +74,7 @@ namespace {
const auto& ivalue = returns[idx];
if (ivalue.isTensor() && should_wrap_outputs) {
auto t = ivalue.toTensor();
if (!t.defined()) continue;
auto t_new = c10::IValue(at::functionalization::impl::to_functional_tensor(t));
(*stack)[returns_begin + idx] = t_new;
} else if (ivalue.isTensorList() && should_wrap_outputs) {
@ -171,6 +173,48 @@ at::Tensor lift_functionalize(const at::Tensor & self) {
return at::functionalization::impl::to_functional_tensor(self);
}
bool device_opted_into_functionalization(c10::Device self_device, c10::optional<c10::Device> tgt_device) {
// If the target device is empty, then the output tensor should be on the same device as the input
auto real_tgt_device = tgt_device.has_value() ? tgt_device.value() : self_device;
return real_tgt_device.type() == c10::DeviceType::XLA || real_tgt_device.type() == c10::DeviceType::Lazy;
}
// Note: I only need this because the to.dtype/to.dtype_layout overloads call this, so we skip the op above.
// We should probably get rid of this though.
at::Tensor _to_copy_functionalize(
const at::Tensor & self,
c10::optional<at::ScalarType> dtype,
c10::optional<at::Layout> layout,
c10::optional<at::Device> device,
c10::optional<bool> pin_memory,
bool non_blocking,
c10::optional<at::MemoryFormat> memory_format) {
at::Tensor self_;
if (at::functionalization::impl::isFunctionalTensor(self)) {
// sync any pending updates
at::functionalization::impl::sync(self);
// pass the unwrapped tensor to the backend
self_ = at::functionalization::impl::from_functional_tensor(self);
} else {
self_ = self;
}
at::AutoDispatchSkipFunctionalize guard;
auto out = at::_to_copy(self_, dtype, layout, device, pin_memory, non_blocking, memory_format);
// Special case: if the Functionalize key is not in TLS, we assume that we're running
// on a lazy backend (LTC).
// In that case, if we're copying to a non-functionalize-enabled device,
// then the functionalization pass should "end". We need to sync any updates on the input
// tensor, but we shouldn't wrap the output.
if (!c10::impl::tls_local_dispatch_key_set().included_.has(c10::DispatchKey::Functionalize)) {
if (!device_opted_into_functionalization(self.device(), device)) {
return out;
}
}
return at::functionalization::impl::to_functional_tensor(out);
}
TORCH_LIBRARY_IMPL(_, Functionalize, m) {
m.fallback(torch::CppFunction::makeFromBoxedFunction<&functionalizeFallback>());
}
@ -178,4 +222,5 @@ TORCH_LIBRARY_IMPL(_, Functionalize, m) {
TORCH_LIBRARY_IMPL(aten, Functionalize, m) {
m.impl("resize_", TORCH_FN(resize__functionalization));
m.impl("lift", TORCH_FN(lift_functionalize));
m.impl("_to_copy", TORCH_FN(_to_copy_functionalize));
}

View File

@ -20,5 +20,6 @@ bool to_will_alias(
Tensor to_meta(const Tensor& tensor);
c10::optional<Tensor> to_meta(const c10::optional<Tensor>& tensor);
std::vector<Tensor> to_meta(const at::TensorList& t_list);
} // namespace native
} // namespace at

View File

@ -765,7 +765,7 @@
device_guard: False
tags: inplace_view
dispatch:
CompositeExplicitAutograd: as_strided_
CompositeExplicitAutogradNonFunctional: as_strided_
- func: asin(Tensor self) -> Tensor
device_check: NoCheck # TensorIterator
@ -1976,7 +1976,7 @@
- func: new_empty_strided(Tensor self, int[] size, int[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
variants: method
dispatch:
CompositeExplicitAutograd: new_empty_strided
CompositeExplicitAutogradNonFunctional: new_empty_strided
- func: new_full(Tensor self, int[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
variants: method
@ -3027,7 +3027,8 @@
- func: logsumexp.out(Tensor self, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
dispatch:
CompositeExplicitAutograd: logsumexp_out
# calls squeeze
CompositeExplicitAutogradNonFunctional: logsumexp_out
- func: logsumexp.names(Tensor self, Dimname[1] dim, bool keepdim=False) -> Tensor
device_check: NoCheck # TensorIterator
@ -3637,12 +3638,12 @@
- func: pixel_shuffle(Tensor self, int upscale_factor) -> Tensor
dispatch:
CPU: pixel_shuffle_cpu
CompositeExplicitAutograd: math_pixel_shuffle
CompositeExplicitAutogradNonFunctional: math_pixel_shuffle
- func: pixel_unshuffle(Tensor self, int downscale_factor) -> Tensor
dispatch:
CPU: pixel_unshuffle_cpu
CompositeExplicitAutograd: math_pixel_unshuffle
CompositeExplicitAutogradNonFunctional: math_pixel_unshuffle
- func: channel_shuffle(Tensor self, int groups) -> Tensor
dispatch:
@ -4088,7 +4089,7 @@
device_check: NoCheck
device_guard: False
dispatch:
CompositeExplicitAutograd: select_backward
CompositeExplicitAutogradNonFunctional: select_backward
- func: selu(Tensor self) -> Tensor
device_check: NoCheck # TensorIterator
@ -4920,7 +4921,8 @@
- func: _trilinear(Tensor i1, Tensor i2, Tensor i3, int[] expand1, int[] expand2, int[] expand3, int[] sumdim, int unroll_dim=1) -> Tensor
dispatch:
CompositeExplicitAutograd: _trilinear
# calls unsqueeze
CompositeExplicitAutogradNonFunctional: _trilinear
- func: triplet_margin_loss(Tensor anchor, Tensor positive, Tensor negative, float margin=1.0, float p=2, float eps=1e-06, bool swap=False, int reduction=Mean) -> Tensor
@ -11648,13 +11650,15 @@
python_module: linalg
variants: function
dispatch:
CompositeExplicitAutograd: linalg_inv_ex
# calls transpose_
CompositeExplicitAutogradNonFunctional: linalg_inv_ex
- func: linalg_inv_ex.inverse(Tensor self, *, bool check_errors=False, Tensor(a!) inverse, Tensor(b!) info) -> (Tensor(a!) inverse, Tensor(b!) info)
python_module: linalg
variants: function
dispatch:
CompositeExplicitAutograd: linalg_inv_ex_out
# calls transpose_
CompositeExplicitAutogradNonFunctional: linalg_inv_ex_out
- func: linalg_inv(Tensor self) -> Tensor
python_module: linalg
@ -11766,7 +11770,9 @@
python_module: linalg
variants: function
dispatch:
CompositeExplicitAutograd: linalg_pinv
# calls svd, which calls mH() (view op)
# also calls narrow()
CompositeExplicitAutogradNonFunctional: linalg_pinv
- func: linalg_pinv.atol_rtol_tensor_out(Tensor self, *, Tensor? atol=None, Tensor? rtol=None, bool hermitian=False, Tensor(a!) out) -> Tensor(a!)
python_module: linalg
@ -12018,7 +12024,7 @@
- func: expand_copy.SymInt(Tensor self, SymInt[] size, *, bool implicit=False) -> Tensor
variants: function
dispatch:
CompositeExplicitAutogradNonFunctional: expand_copy_SymInt
CompositeExplicitAutograd: expand_copy_SymInt
tags: view_copy
- func: permute_copy(Tensor self, int[] dims) -> Tensor

View File

@ -143,40 +143,65 @@ full_codegen:
- upsample_nearest2d_backward
- zero.functional
- narrow_copy.SymInt
- alias_copy
- as_strided_copy
- diagonal_copy
- expand_copy
- permute_copy
- _reshape_alias_copy
- select_copy.int
- detach_copy
- slice_copy.Tensor
# Not implemented yet because LTC codegen doesn't currently work
# for ops that return lists of tensors.
#- split_copy.Tensor
#- split_with_sizes_copy
#- unbind_copy.int
- squeeze_copy
- squeeze_copy.dim
- t_copy
- transpose_copy.int
- unsqueeze_copy
- view_copy
- view_copy.dtype
- unfold_copy
- select_scatter
- slice_scatter
- diagonal_scatter
- as_strided_scatter
supported:
- as_strided
- as_strided_
- clone
- _copy_from
- _copy_from_and_resize
- diagonal
- empty.memory_format
- empty_strided
- expand
- fill_.Scalar
- narrow
- normal_
- max_pool3d_with_indices
- max_pool3d_with_indices_backward
- permute
- select.int
- slice.Tensor
- squeeze
- squeeze.dim
- squeeze_
- squeeze_.dim
- t
- t_
- _to_copy
- transpose.int
- transpose_
- unsqueeze
- unsqueeze_
- view
- alias
- _unsafe_view
- lift
# Below are all operators that are "composite" in core,
# but require us to explicitly re-enable functionalization in order to use them.
# Why? These operators are all CompositeExplicitAutograd, which means that they run
# after functionalization,
# but their implementations call view operators (which we need to functionalize away).
- block_diag
- diagonal_backward
- slice_backward
- new_empty_strided
- narrow_copy
- pixel_shuffle
- pixel_unshuffle
- select_backward
- _trilinear
- linalg_inv_ex
- linalg_pinv.atol_rtol_tensor
- logsumexp.out
autograd:
- max_pool3d
- native_group_norm
# Ops that don't have a native schema definitions and are dispatched within Lazy Tensor Core
non_native:

View File

@ -54,6 +54,8 @@ void resize_out_helper(const at::TensorList& dst, const at::TensorList& src) {
${CompositeViewCopyKernel_Definitions}
${SymIntViewCopyKernel_Definitions}
${GeneratedCompositeFunctional_Definitions}
${GeneratedCompositeOut_Definitions}

View File

@ -2,6 +2,10 @@
${includes}
${native_functions_include}
namespace {
${helper_fns}
} // namespace
${namespace_prologue}
${native_function_definitions}

View File

@ -217,7 +217,6 @@ enum class DispatchKey : uint16_t {
// Out-of-core key for Fake Tensor in torchdistx.
// See https://pytorch.org/torchdistx/latest/fake_tensor.html
Fake,
// See Note [Out-of-tree vmap+grad prototype]. The purpose of this key
// is to insert code after the "autograd subsystem" runs, so this key should
// be directly after ADInplaceOrView and all of the autograd keys.

View File

@ -54,10 +54,25 @@ def init_lists():
'pow', # incorrect results
'addcdiv', # incorrect results (on CI not locally?)
])
# The following ops all show up directly in ts_native_functions.yaml,
# but run functionalized versions of the composite kernels in core.
# This means that we don't expect the ops to show directly in the LTC metrics.
FUNCTIONAL_DECOMPOSE_LIST = set([
'block_diag',
'new_empty_strided',
'narrow_copy',
'pixel_shuffle',
'pixel_unshuffle',
'select_backward',
'_trilinear',
'linalg_inv_ex',
'linalg_pinv.atol_rtol_tensor',
'logsumexp',
])
return (LAZY_OPS_LIST, FALLBACK_LIST, SKIP_RUNTIME_ERROR_LIST, SKIP_INCORRECT_RESULTS_LIST)
return (LAZY_OPS_LIST, FALLBACK_LIST, SKIP_RUNTIME_ERROR_LIST, SKIP_INCORRECT_RESULTS_LIST, FUNCTIONAL_DECOMPOSE_LIST)
(LAZY_OPS_LIST, FALLBACK_LIST, SKIP_RUNTIME_ERROR_LIST, SKIP_INCORRECT_RESULTS_LIST) = init_lists()
(LAZY_OPS_LIST, FALLBACK_LIST, SKIP_RUNTIME_ERROR_LIST, SKIP_INCORRECT_RESULTS_LIST, FUNCTIONAL_DECOMPOSE_LIST) = init_lists()
torch.manual_seed(42)
@ -96,9 +111,36 @@ class TestLazyTensor(JitTestCase):
torch.testing.assert_close(weight_copy_grad.cpu(), weight_grad.cpu())
torch.testing.assert_close(inp_copy_grad.cpu(), inp_grad.cpu())
def test_view_mark_step_preserved(self):
test_device = get_test_device()
inp = torch.rand(4, device=test_device)
inp_lazy = clone_move(inp)
def foo(x, *, mark_step):
y = x.view(2, 2)
y.add_(1)
z = x + x
if mark_step:
torch._lazy.mark_step()
# y and x should continue to be aliased after the mark_step call.
y.add_(1)
return x
out_ref = foo(inp, mark_step=False)
out = foo(inp_lazy, mark_step=True)
# out will have some pending mutations, which will be synced by the .cpu() call.
torch.testing.assert_close(out_ref.cpu(), out.cpu())
class TestLazyOpInfo(TestCase):
@ops([op for op in op_db if op.name in LAZY_OPS_LIST and op.name not in SKIP_RUNTIME_ERROR_LIST], allowed_dtypes=(torch.float,))
@ops([op for op in op_db
if op.name in LAZY_OPS_LIST
and op.name not in SKIP_RUNTIME_ERROR_LIST
and op.name not in FUNCTIONAL_DECOMPOSE_LIST
], allowed_dtypes=(torch.float,))
def test_dispatched_to_lazy(self, device, dtype, op):
def get_name(op):
l = [op.name]

View File

@ -1,5 +1,4 @@
# Owner(s): ["module: tests"]
import torch
import numpy as np
@ -475,6 +474,8 @@ class TestViewOps(TestCase):
v[0] = 0
self.assertEqual(t[2, 0], v[0])
# Lazy hasn't implemented unbind yet.
@onlyNativeDeviceTypes
def test_unbind_view(self, device) -> None:
t = torch.zeros((5, 5), device=device)
tup = torch.unbind(t)
@ -505,6 +506,9 @@ class TestViewOps(TestCase):
stacked = torch.randn(3, 10, 10, dtype=torch.double, requires_grad=True)
gradcheck(lambda x: x.unbind(), (stacked,), check_forward_ad=True)
# TODO: Fix this test for LTC. There is an interaction with dynamic shapes here that is broken,
# causing asserts to trigger.
@onlyNativeDeviceTypes
def test_expand_view(self, device) -> None:
t = torch.ones((5, 1), device=device)
v = t.expand(5, 5)
@ -718,6 +722,8 @@ class TestViewOps(TestCase):
self.assertTrue(s is t)
@skipMeta
# self.is_view_of reports false positives for lazy
@onlyNativeDeviceTypes
def test_contiguous_nonview(self, device):
t = torch.ones(5, 5, device=device)
nv = t.t().contiguous()
@ -744,6 +750,8 @@ class TestViewOps(TestCase):
self.assertEqual(t[1, 1], v[6])
@skipMeta
# self.is_view_of reports false positives for lazy
@onlyNativeDeviceTypes
def test_reshape_nonview(self, device):
t = torch.ones(5, 5, device=device)
nv = torch.reshape(t.t(), (25,))
@ -752,6 +760,9 @@ class TestViewOps(TestCase):
nv[6] = 0
self.assertNotEqual(t[1, 1], nv[6])
# This test uses as_strided to construct a tensor with overlapping memory,
# which is not handled by the functionalization pass.
@onlyNativeDeviceTypes
def test_flatten_view(self, device):
def test_writes_propagate(t, v):
idx_t = (0,) * t.ndim
@ -1820,7 +1831,7 @@ class TestOldViewOps(TestCase):
t.crow_indices()
t.col_indices()
instantiate_device_type_tests(TestViewOps, globals())
instantiate_device_type_tests(TestViewOps, globals(), include_lazy=True)
instantiate_device_type_tests(TestOldViewOps, globals())
if __name__ == '__main__':

View File

@ -50,10 +50,12 @@
#include <torch/csrc/lazy/core/shape_inference.h>
#include <ATen/AccumulateType.h>
#include <ATen/CompositeExplicitAutogradFunctions.h>
#include <ATen/Dispatch.h>
#include <ATen/ExpandUtils.h>
#include <ATen/Functions.h>
#include <ATen/InferSize.h>
#include <ATen/NativeFunctions.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/native/ConvUtils.h>
#include <ATen/native/ReduceOpsUtils.h>
@ -599,6 +601,17 @@ std::vector<Shape> compute_shape_mean(
return {Shape(self.scalar_type(), {})};
}
std::vector<Shape> compute_shape_new_empty_strided(
const at::Tensor& self,
at::IntArrayRef size,
at::IntArrayRef stride,
c10::optional<at::ScalarType> dtype,
c10::optional<at::Layout> layout,
c10::optional<at::Device> device,
c10::optional<bool> pin_memory) {
return {Shape(dtype.has_value() ? *dtype : self.scalar_type(), size.vec())};
}
std::vector<Shape> compute_shape_mv(
const at::Tensor& self,
const at::Tensor& vec) {
@ -949,6 +962,12 @@ std::vector<Shape> compute_shape__to_copy(
return {Shape(self.scalar_type(), self.sizes().vec())};
}
TORCH_API std::vector<Shape> compute_shape_clone(
const at::Tensor& self,
c10::optional<at::MemoryFormat> memory_format) {
return {Shape(self.scalar_type(), self.sizes().vec())};
}
std::vector<Shape> compute_shape_stack(at::TensorList tensors, int64_t dim) {
TORCH_CHECK(tensors.size() > 0, "stack expects a non-empty TensorList");
auto wrapped_dim = at::maybe_wrap_dim(dim, tensors[0].ndimension() + 1);
@ -1108,6 +1127,106 @@ std::vector<Shape> compute_shape_unsqueeze(
BuildUnsqueezedDimensions(input_shape.sizes(), dim))};
}
std::vector<Shape> compute_shape_select_scatter(
const at::Tensor& self,
const at::Tensor& src,
int64_t dim,
int64_t index) {
auto self_meta = at::native::empty_strided_meta(
self.sizes(),
self.strides(),
/*dtype=*/c10::make_optional(self.scalar_type()),
/*layout=*/c10::make_optional(self.layout()),
/*device=*/c10::make_optional(c10::Device(c10::kMeta)),
/*pin_memory=*/c10::nullopt);
auto src_meta = at::native::empty_strided_meta(
src.sizes(),
src.strides(),
/*dtype=*/c10::make_optional(src.scalar_type()),
/*layout=*/c10::make_optional(src.layout()),
/*device=*/c10::make_optional(c10::Device(c10::kMeta)),
/*pin_memory=*/c10::nullopt);
auto out_meta = at::compositeexplicitautograd::select_scatter(
self_meta, src_meta, dim, index);
return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
}
std::vector<Shape> compute_shape_diagonal_scatter(
const at::Tensor& self,
const at::Tensor& src,
int64_t offset,
int64_t dim1,
int64_t dim2) {
auto self_meta = at::native::empty_strided_meta(
self.sizes(),
self.strides(),
/*dtype=*/c10::make_optional(self.scalar_type()),
/*layout=*/c10::make_optional(self.layout()),
/*device=*/c10::make_optional(c10::Device(c10::kMeta)),
/*pin_memory=*/c10::nullopt);
auto src_meta = at::native::empty_strided_meta(
src.sizes(),
src.strides(),
/*dtype=*/c10::make_optional(src.scalar_type()),
/*layout=*/c10::make_optional(src.layout()),
/*device=*/c10::make_optional(c10::Device(c10::kMeta)),
/*pin_memory=*/c10::nullopt);
auto out_meta = at::compositeexplicitautograd::diagonal_scatter(
self_meta, src_meta, offset, dim1, dim2);
return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
}
std::vector<Shape> compute_shape_slice_scatter(
const at::Tensor& self,
const at::Tensor& src,
int64_t dim,
c10::optional<int64_t> start,
c10::optional<int64_t> end,
int64_t step) {
auto self_meta = at::native::empty_strided_meta(
self.sizes(),
self.strides(),
/*dtype=*/c10::make_optional(self.scalar_type()),
/*layout=*/c10::make_optional(self.layout()),
/*device=*/c10::make_optional(c10::Device(c10::kMeta)),
/*pin_memory=*/c10::nullopt);
auto src_meta = at::native::empty_strided_meta(
src.sizes(),
src.strides(),
/*dtype=*/c10::make_optional(src.scalar_type()),
/*layout=*/c10::make_optional(src.layout()),
/*device=*/c10::make_optional(c10::Device(c10::kMeta)),
/*pin_memory=*/c10::nullopt);
auto out_meta = at::compositeexplicitautograd::slice_scatter(
self_meta, src_meta, dim, start, end, step);
return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
}
std::vector<Shape> compute_shape_as_strided_scatter(
const at::Tensor& self,
const at::Tensor& src,
at::IntArrayRef size,
at::IntArrayRef stride,
c10::optional<int64_t> storage_offset) {
auto self_meta = at::native::empty_strided_meta(
self.sizes(),
self.strides(),
/*dtype=*/c10::make_optional(self.scalar_type()),
/*layout=*/c10::make_optional(self.layout()),
/*device=*/c10::make_optional(c10::Device(c10::kMeta)),
/*pin_memory=*/c10::nullopt);
auto src_meta = at::native::empty_strided_meta(
src.sizes(),
src.strides(),
/*dtype=*/c10::make_optional(src.scalar_type()),
/*layout=*/c10::make_optional(src.layout()),
/*device=*/c10::make_optional(c10::Device(c10::kMeta)),
/*pin_memory=*/c10::nullopt);
auto out_meta = at::compositeexplicitautograd::as_strided_scatter(
self_meta, src_meta, size, stride, storage_offset);
return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
}
// Restore unused-parameters warnings
#pragma GCC diagnostic pop

View File

@ -24,6 +24,7 @@ TORCH_API std::vector<torch::lazy::Shape> compute_shape_binary_cross_entropy(con
TORCH_API std::vector<torch::lazy::Shape> compute_shape_binary_cross_entropy_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight, int64_t reduction);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_cat(at::TensorList tensors, int64_t dim);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_clamp_min(const at::Tensor & self, const at::Scalar & min);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_clone(const at::Tensor & self, c10::optional<at::MemoryFormat> memory_format);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_constant_pad_nd(const at::Tensor & self, at::IntArrayRef pad, const at::Scalar & value);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_convolution(const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_convolution_backward(const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & weight, at::OptionalIntArrayRef bias_sizes, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups, ::std::array<bool,3> output_mask);
@ -56,6 +57,7 @@ TORCH_API std::vector<torch::lazy::Shape> compute_shape_native_dropout(const at:
TORCH_API std::vector<torch::lazy::Shape> compute_shape_native_dropout_backward(const at::Tensor & grad_output, const at::Tensor & mask, double scale);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_native_layer_norm(const at::Tensor & input, at::IntArrayRef normalized_shape, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, double eps);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_native_layer_norm_backward(const at::Tensor & grad_out, const at::Tensor & input, at::IntArrayRef normalized_shape, const at::Tensor & mean, const at::Tensor & rstd, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, ::std::array<bool,3> output_mask);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_new_empty_strided(const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_nll_loss2d_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight, int64_t reduction, int64_t ignore_index, const at::Tensor & total_weight);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_nll_loss2d_forward(const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight, int64_t reduction, int64_t ignore_index);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_nonzero(const at::Tensor & self);
@ -84,6 +86,7 @@ TORCH_API std::vector<Shape> compute_shape_view(const Output& input0, const std:
TORCH_API std::vector<Shape> compute_shape_cast(const Output& input0, const at::ScalarType& dtype, const c10::optional<at::ScalarType>& stype);
// View Ops
// (Now that the functionalization pass is used, we should kill these in a later PR)
TORCH_API std::vector<Shape> compute_shape_as_strided_view_update(const Output& target, const Output& input, const std::vector<int64_t>& size, const std::vector<int64_t>& stride, const int64_t& storage_offset);
TORCH_API std::vector<Shape> compute_shape_as_strided(const Output& input, const std::vector<int64_t>& size, const std::vector<int64_t>& stride, const int64_t& storage_offset);
TORCH_API std::vector<Shape> compute_shape_diagonal_view_update(const Output& target, const Output& input, const int64_t& offset, const int64_t& dim1, const int64_t& dim2);
@ -96,6 +99,12 @@ TORCH_API std::vector<Shape> compute_shape_select_view_update(const Output& targ
TORCH_API std::vector<Shape> compute_shape_select(const Output& input, const int64_t& dim, const int64_t& start, const int64_t& end, const int64_t& stride);
TORCH_API std::vector<Shape> compute_shape_squeeze(const Output& input, const int& dim);
TORCH_API std::vector<Shape> compute_shape_unsqueeze(const Output& input, const int& dim);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_select_scatter(const at::Tensor & self, const at::Tensor & src, int64_t dim, int64_t index);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_diagonal_scatter(const at::Tensor & self, const at::Tensor & src, int64_t offset, int64_t dim1, int64_t dim2);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_slice_scatter(const at::Tensor & self, const at::Tensor & src, int64_t dim, c10::optional<int64_t> start, c10::optional<int64_t> end, int64_t step);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_as_strided_scatter(const at::Tensor & self, const at::Tensor & src, at::IntArrayRef size, at::IntArrayRef stride, c10::optional<int64_t> storage_offset);
// clang-format on
} // namespace lazy
} // namespace torch

View File

@ -10,6 +10,8 @@
#include <torch/csrc/lazy/core/tensor_impl.h>
#include <torch/csrc/lazy/core/tensor_util.h>
#include <ATen/FunctionalTensorWrapper.h>
namespace torch {
namespace lazy {
namespace {
@ -482,7 +484,8 @@ torch::lazy::Value GetTensorList(c10::ArrayRef<at::Tensor> tensors) {
}
LazyTensorPtr TryGetLtcTensor(const at::Tensor& tensor) {
auto* impl = dynamic_cast<LTCTensorImpl*>(tensor.unsafeGetTensorImpl());
auto* impl = dynamic_cast<LTCTensorImpl*>(
maybe_unwrap_functional(tensor).unsafeGetTensorImpl());
if (impl == nullptr) {
// return c10::make_intrusive<LazyTensor>();
return LazyTensorPtr();
@ -532,5 +535,27 @@ at::Tensor CreateAtenFromLtcTensor(LazyTensor&& ltc_tensor) {
return at::Tensor(c10::make_intrusive<LTCTensorImpl>(std::move(ltc_tensor)));
}
at::Tensor to_lazy_tensor(
const at::Tensor& self,
const c10::TensorOptions& options,
at::Device device,
bool non_blocking,
bool functionalize_output) {
TORCH_INTERNAL_ASSERT(self.device().type() != c10::kLazy);
TORCH_INTERNAL_ASSERT(device.type() == c10::kLazy);
auto eager_tensor =
self.to(options, /*non_blocking=*/non_blocking, /*copy=*/true);
auto lazy_self = torch::lazy::GetOrCreateLtcTensor(
eager_tensor, torch::lazy::atenDeviceToBackendDevice(device));
auto out = torch::lazy::CreateAtenFromLtcTensor(lazy_self);
if (functionalize_output) {
// See Note [Lazy Tensor Functionalization]
return at::functionalization::impl::to_functional_tensor(out);
} else {
return out;
}
}
} // namespace lazy
} // namespace torch

View File

@ -242,6 +242,46 @@ TORCH_API LazyTensorPtr GetLtcTensorOrCreateForWrappedNumber(
TORCH_API at::Tensor CreateAtenFromLtcTensor(const LazyTensorPtr& ltc_tensor);
TORCH_API at::Tensor CreateAtenFromLtcTensor(LazyTensor&& ltc_tensor);
// Note [Lazy Tensor Functionalization]
// The functionalization pass is implemented by wrapping all TensorImpl
// objects in C++ with an extra FunctionalTensorWrapper object,
// that knows how to perform functionalization
//
// Certain functions in the aten API serve as entry/exit points for
// functionalization, where we need to perform the wrapping/unwrapping:
// - aten::to.device
// - aten::empty
// Given a non-lazy tensor, this function creates a lazy tensor on the specified
// (lazy) device. The functionalize_output argument determines whether or not we should
// wrap the output in a "functional wrapper".
//
// How do you know whether to pass true/false for functionalize_output?
//
// Case 1: nonlazy -> lazy
// If you're implementing a function that takes in nonlazy tensors and returns
// lazy tensors, then you should think of that function as an "entrypoint" to
// functionalization, and use functionalize_output=true. Examples include:
// - factory functions (the LTC kernel for at::empty)
// - CPU -> Lazy device conversions (the LTC kernel for at::to_device)
//
// Case 2: lazy -> lazy
// If you're implementing a function that takes in lazy tensors and returns
// lazy tensors,
// **but** requires creating lazy tensors internally,
// then you can assume that the current function is running inside of some
// outer context where functionalization is already running, that will take
// care of doing the wrapping for you, and use functionalize_output=false.
// Examples include:
// - CPU fallback (takes in lazy tensors, converts them to cpu, calls the kernel,
// and converts the returns back to lazy tensors).
TORCH_API at::Tensor to_lazy_tensor(
const at::Tensor& self,
const c10::TensorOptions& options,
at::Device device,
bool non_blocking,
bool functionalize_output);
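// Illustrative sketch (hypothetical caller; tensor/device names are placeholders): an
// eager->lazy entrypoint (Case 1 above) would do something like
//
//   at::Tensor lazy_out = torch::lazy::to_lazy_tensor(
//       eager_tensor, eager_tensor.options(), lazy_device,
//       /*non_blocking=*/false,
//       /*functionalize_output=*/true);
//
// whereas a lazy->lazy helper that already runs under an outer functionalization
// context (Case 2) would pass functionalize_output=false.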
template <size_t... Indices>
auto TupleAtenFromLtcTensorsImpl(
const std::vector<LazyTensorPtr>& tensors,

View File

@ -3,6 +3,8 @@
#include <torch/csrc/lazy/backend/backend_interface.h>
#include <torch/csrc/lazy/core/shape.h>
#include <ATen/FunctionalTensorWrapper.h>
#include <string>
#include <vector>
@ -62,5 +64,15 @@ at::Scalar MakeIntScalar(T value) {
// API returns true.
TORCH_API bool IsSpecialScalar(const at::Scalar& value);
// Note: returns a reference instead of a fresh tensor to avoid refcount bumps.
inline const at::Tensor& maybe_unwrap_functional(const at::Tensor& tensor) {
if (at::functionalization::impl::isFunctionalTensor(tensor)) {
return at::functionalization::impl::unsafeGetFunctionalWrapper(tensor)
->value();
} else {
return tensor;
}
}
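// Illustrative usage (sketch): callers that might be handed either a plain lazy
// tensor or its FunctionalTensorWrapper can peek through the wrapper before
// inspecting the TensorImpl, e.g.
//
//   auto* impl = dynamic_cast<LTCTensorImpl*>(
//       maybe_unwrap_functional(tensor).unsafeGetTensorImpl());
//
// (This mirrors the TryGetLtcTensor change elsewhere in this PR.)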
} // namespace lazy
} // namespace torch

View File

@ -1,5 +1,6 @@
#include <torch/csrc/lazy/python/init.h>
#include <ATen/FunctionalTensorWrapper.h>
#include <c10/core/Device.h>
#include <torch/csrc/jit/python/pybind.h>
#include <torch/csrc/lazy/backend/backend_device.h>
@ -46,8 +47,9 @@ std::string GetTensorsDump(
std::vector<torch::lazy::Node*> nodes;
std::vector<torch::lazy::Value> values;
for (auto& tensor : tensors) {
auto inner = at::functionalization::impl::from_functional_tensor(tensor);
torch::lazy::LazyTensorPtr lazy_tensor =
torch::lazy::TryGetLtcTensor(tensor);
torch::lazy::TryGetLtcTensor(inner);
values.push_back(lazy_tensor->GetIrValue());
nodes.push_back(values.back().node.get());
}

View File

@ -337,7 +337,11 @@ void ts_eager_fallback(
} else {
dev_str << "<none>";
}
TORCH_WARN(
// We should never hit this for a view op,
// because LazyTensor should provide a lowering for the
// corresponding view_copy operator. The functionalization pass will
// take care of calling the view_copy operator instead of the view.
TORCH_CHECK(
false,
"The operator ",
op.schema().operator_name(),

View File

@ -1,5 +1,7 @@
#include <ATen/FunctionalTensorWrapper.h>
#include <ATen/Functions.h>
#include <ATen/MetaFunctions.h>
#include <ATen/NativeFunctions.h>
#include <ATen/Operators.h>
#include <ATen/native/BinaryOps.h>
#include <ATen/native/CPUFallback.h>
@ -19,6 +21,8 @@
#include <torch/csrc/lazy/ts_backend/ts_eager_fallback.h>
#include <torch/library.h>
using at::Tensor;
namespace torch {
namespace lazy {
namespace {
@ -46,48 +50,8 @@ c10::optional<torch::lazy::BackendDevice> GetLtcDevice(
} // namespace
at::Tensor LazyNativeFunctions::alias(const at::Tensor& self) {
TORCH_LAZY_FN_COUNTER("lazy::");
return self;
}
at::Tensor LazyNativeFunctions::as_strided(
const at::Tensor& self,
at::IntArrayRef size,
at::IntArrayRef stride,
c10::optional<int64_t> storage_offset) {
TORCH_LAZY_FN_COUNTER("lazy::");
torch::lazy::LazyTensorPtr self_tensor = torch::lazy::TryGetLtcTensor(self);
auto xsize = torch::lazy::ToI64Vector(size);
auto xstride = torch::lazy::ToI64Vector(stride);
if (!torch::lazy::StrideIsSupported(xstride)) {
return at::native::
call_fallback_fn<&ltc_eager_fallback, ATEN_OP(as_strided)>::call(
self, size, stride, storage_offset);
}
return torch::lazy::CreateAtenFromLtcTensor(torch::lazy::as_strided(
self_tensor, std::move(xsize), std::move(xstride), storage_offset));
}
const at::Tensor& LazyNativeFunctions::as_strided_(
const at::Tensor& self,
at::IntArrayRef size,
at::IntArrayRef stride,
c10::optional<int64_t> storage_offset) {
TORCH_LAZY_FN_COUNTER("lazy::");
auto self_tensor = torch::lazy::TryGetLtcTensor(self);
auto xsize = torch::lazy::ToI64Vector(size);
auto xstride = torch::lazy::ToI64Vector(stride);
if (!torch::lazy::StrideIsSupported(xstride)) {
return at::native::
call_fallback_fn<&ltc_eager_fallback, ATEN_OP(as_strided_)>::call(
self, size, stride, storage_offset);
}
torch::lazy::as_strided_(
self_tensor, std::move(xsize), std::move(xstride), storage_offset);
return self;
}
// clone is special in LT because we make it a no-op.
// This should be safe to do, because every operator in the LT is functional.
at::Tensor LazyNativeFunctions::clone(
const at::Tensor& self,
c10::optional<at::MemoryFormat> memory_format) {
@ -211,12 +175,19 @@ at::Tensor LazyNativeFunctions::_to_copy(
auto lazy_self = torch::lazy::TryGetLtcTensor(self);
if (!lazy_self && device && device->type() == c10::kLazy) {
// Case 1: eager->lazy (we create a new lazy tensor)
auto eager_tensor =
self.to(options, /*non_blocking=*/non_blocking, /*copy=*/true);
lazy_self = torch::lazy::GetOrCreateLtcTensor(
eager_tensor, torch::lazy::atenDeviceToBackendDevice(*device));
return torch::lazy::CreateAtenFromLtcTensor(lazy_self);
// See Note [Lazy Tensor Functionalization]
// Invariant: if the functionalization key is in the exclude set, then we're
// expected to return an ordinary tensor, which will be "lifted" into a
// functional wrapper later.
bool functionalize_output =
!c10::impl::tls_local_dispatch_key_set().excluded_.has(
c10::DispatchKey::Functionalize);
return torch::lazy::to_lazy_tensor(
self,
options,
*device,
/*non_blocking=*/non_blocking,
/*functionalize_output=*/functionalize_output);
} else if (device && device->type() != c10::kLazy) {
// Case 2: lazy->eager (forces a graph break since we are materializing a
// tensor)
@ -298,24 +269,6 @@ at::Tensor LazyNativeFunctions::_to_copy(
}
};
at::Tensor LazyNativeFunctions::diagonal(
const at::Tensor& self,
int64_t offset,
int64_t dim1,
int64_t dim2) {
TORCH_LAZY_FN_COUNTER("lazy::");
auto input = GetLtcTensor(self);
auto input_shape = input->shape();
dim1 = at::maybe_wrap_dim(dim1, self);
dim2 = at::maybe_wrap_dim(dim2, self);
auto diagonal_info = DiagonalInfo{offset, dim1, dim2};
auto view_info =
ViewInfo(ViewInfo::Type::kDiagonal, input_shape, diagonal_info);
return CreateAtenFromLtcTensor(input->CreateViewTensor(std::move(view_info)));
}
at::Tensor LazyNativeFunctions::empty(
at::IntArrayRef size,
c10::optional<at::ScalarType> dtype,
@ -330,7 +283,18 @@ at::Tensor LazyNativeFunctions::empty(
.pinned_memory(pin_memory)
.dtype(dtype);
auto x_result = at::empty(size, options, memory_format);
return CreateLtcTensor(x_result, GetLtcDevice(device));
auto tensor = CreateLtcTensor(x_result, GetLtcDevice(device));
// See Note [Lazy Tensor Functionalization]
if (c10::impl::tls_local_dispatch_key_set().excluded_.has(
c10::DispatchKey::Functionalize)) {
// Invariant: if the functionalization key is in the exclude set, then we're
// expected to return an ordinary tensor, which will be "lifted" into a
// functional wrapper later.
return tensor;
} else {
auto wrapped = at::functionalization::impl::to_functional_tensor(tensor);
return wrapped;
}
}
at::Tensor LazyNativeFunctions::empty_strided(
@ -342,16 +306,7 @@ at::Tensor LazyNativeFunctions::empty_strided(
c10::optional<bool> pin_memory) {
TORCH_LAZY_FN_COUNTER("lazy::");
at::Tensor t = empty(size, dtype, layout, device, pin_memory, c10::nullopt);
return LazyNativeFunctions::as_strided(t, size, stride, /*storage_offset=*/0);
}
at::Tensor LazyNativeFunctions::expand(
const at::Tensor& self,
at::IntArrayRef size,
bool implicit) {
TORCH_LAZY_FN_COUNTER("lazy::");
return torch::lazy::CreateAtenFromLtcTensor(
torch::lazy::expand(torch::lazy::TryGetLtcTensor(self), size.vec()));
return t.as_strided(size, stride, /*storage_offset=*/0);
}
at::Tensor& LazyNativeFunctions::fill_(
@ -412,17 +367,6 @@ at::Tensor LazyNativeFunctions::max_pool3d_with_indices_backward(
indices);
}
at::Tensor LazyNativeFunctions::narrow(
const at::Tensor& self,
int64_t dim,
int64_t start,
int64_t length) {
TORCH_LAZY_FN_COUNTER("lazy::");
auto self_tensor = torch::lazy::TryGetLtcTensor(self);
return torch::lazy::CreateAtenFromLtcTensor(
torch::lazy::narrow(self_tensor, dim, start, length));
}
at::Tensor& LazyNativeFunctions::normal_(
at::Tensor& self,
double mean,
@ -453,124 +397,160 @@ at::Tensor& LazyNativeFunctions::normal_(
// std::move(shapes)); lazy_self.SetInPlaceIrValue(node); return self;
};
at::Tensor LazyNativeFunctions::permute(
const at::Tensor& self,
at::IntArrayRef dims) {
TORCH_LAZY_FN_COUNTER("lazy::");
auto self_tensor = torch::lazy::TryGetLtcTensor(self);
return torch::lazy::CreateAtenFromLtcTensor(
torch::lazy::permute(self_tensor, torch::lazy::ToI64Vector(dims)));
}
at::Tensor LazyNativeFunctions::select(
const at::Tensor& self,
int64_t dim,
int64_t index) {
TORCH_LAZY_FN_COUNTER("lazy::");
return torch::lazy::CreateAtenFromLtcTensor(
torch::lazy::select(torch::lazy::TryGetLtcTensor(self), dim, index));
}
at::Tensor LazyNativeFunctions::slice(
const at::Tensor& self,
int64_t dim,
c10::optional<int64_t> start,
c10::optional<int64_t> end,
int64_t step) {
int64_t start_val = start.has_value() ? start.value() : 0;
int64_t end_val = end.has_value() ? end.value() : INT64_MAX;
TORCH_LAZY_FN_COUNTER("lazy::");
return torch::lazy::CreateAtenFromLtcTensor(torch::lazy::slice(
torch::lazy::TryGetLtcTensor(self), dim, start_val, end_val, step));
}
at::Tensor LazyNativeFunctions::squeeze(const at::Tensor& self) {
TORCH_LAZY_FN_COUNTER("lazy::");
return torch::lazy::CreateAtenFromLtcTensor(
torch::lazy::squeeze(torch::lazy::TryGetLtcTensor(self)));
}
at::Tensor LazyNativeFunctions::squeeze(const at::Tensor& self, int64_t dim) {
TORCH_LAZY_FN_COUNTER("lazy::");
return torch::lazy::CreateAtenFromLtcTensor(
torch::lazy::squeeze(torch::lazy::TryGetLtcTensor(self), dim));
}
at::Tensor& LazyNativeFunctions::squeeze_(at::Tensor& self) {
TORCH_LAZY_FN_COUNTER("lazy::");
auto self_tensor = torch::lazy::TryGetLtcTensor(self);
torch::lazy::squeeze_(self_tensor);
return self;
}
at::Tensor& LazyNativeFunctions::squeeze_(at::Tensor& self, int64_t dim) {
TORCH_LAZY_FN_COUNTER("lazy::");
auto self_tensor = torch::lazy::TryGetLtcTensor(self);
torch::lazy::squeeze_(self_tensor, dim);
return self;
}
at::Tensor LazyNativeFunctions::t(const at::Tensor& self) {
TORCH_LAZY_FN_COUNTER("lazy::");
return torch::lazy::CreateAtenFromLtcTensor(
torch::lazy::transpose(torch::lazy::TryGetLtcTensor(self), 0, 1));
}
at::Tensor& LazyNativeFunctions::t_(at::Tensor& self) {
TORCH_LAZY_FN_COUNTER("lazy::");
auto self_tensor = torch::lazy::TryGetLtcTensor(self);
torch::lazy::transpose_(self_tensor, 0, 1);
return self;
}
at::Tensor LazyNativeFunctions::transpose(
const at::Tensor& self,
int64_t dim0,
int64_t dim1) {
TORCH_LAZY_FN_COUNTER("lazy::");
return torch::lazy::CreateAtenFromLtcTensor(
torch::lazy::transpose(torch::lazy::TryGetLtcTensor(self), dim0, dim1));
}
at::Tensor& LazyNativeFunctions::transpose_(
at::Tensor& self,
int64_t dim0,
int64_t dim1) {
TORCH_LAZY_FN_COUNTER("lazy::");
auto self_tensor = torch::lazy::TryGetLtcTensor(self);
torch::lazy::transpose_(self_tensor, dim0, dim1);
return self;
}
at::Tensor LazyNativeFunctions::unsqueeze(const at::Tensor& self, int64_t dim) {
TORCH_LAZY_FN_COUNTER("lazy::");
return torch::lazy::CreateAtenFromLtcTensor(
torch::lazy::unsqueeze(torch::lazy::TryGetLtcTensor(self), dim));
}
at::Tensor& LazyNativeFunctions::unsqueeze_(at::Tensor& self, int64_t dim) {
TORCH_LAZY_FN_COUNTER("lazy::");
auto self_tensor = torch::lazy::TryGetLtcTensor(self);
torch::lazy::unsqueeze_(self_tensor, dim);
return self;
}
at::Tensor LazyNativeFunctions::view(
const at::Tensor& self,
at::IntArrayRef size) {
TORCH_LAZY_FN_COUNTER("lazy::");
auto self_tensor = torch::lazy::TryGetLtcTensor(self);
return torch::lazy::CreateAtenFromLtcTensor(
torch::lazy::view(self_tensor, torch::lazy::ToI64Vector(size)));
}
at::Tensor LazyNativeFunctions::_unsafe_view(
const at::Tensor& self,
at::IntArrayRef size) {
TORCH_LAZY_FN_COUNTER("lazy::");
auto self_tensor = torch::lazy::TryGetLtcTensor(self);
return torch::lazy::CreateAtenFromLtcTensor(
torch::lazy::view(self_tensor, torch::lazy::ToI64Vector(size)));
return LazyNativeFunctions::view_copy(self, size);
}
// This is needed by the torch.tensor constructor.
// LazyTensor always opts into functionalization.
// "lifting" a tensor for functionalization means wrapping it in a
// FunctionalTensorWrapper object.
at::Tensor LazyNativeFunctions::lift(const at::Tensor& tensor) {
TORCH_INTERNAL_ASSERT(
!at::functionalization::impl::isFunctionalTensor(tensor));
return at::functionalization::impl::to_functional_tensor(tensor);
}
// All of the below ops correspond to CompositeExplicitAutograd kernels from
// core that call into view operators internally. These are all composite ops
// that LTC can technically re-use / get for free, but we need to
// "functionalize" them to remove the view ops before we can use them.
at::Tensor LazyNativeFunctions::block_diag(at::TensorList tensors) {
return at::functionalization::functionalize_aten_op<ATEN_OP(
block_diag)>::call(tensors);
}
at::Tensor LazyNativeFunctions::new_empty_strided(
const at::Tensor& self,
at::IntArrayRef size,
at::IntArrayRef stride,
c10::optional<at::ScalarType> dtype,
c10::optional<at::Layout> layout,
c10::optional<at::Device> device,
c10::optional<bool> pin_memory) {
return at::functionalization::
functionalize_aten_op<ATEN_OP(new_empty_strided)>::call(
self, size, stride, dtype, layout, device, pin_memory);
}
at::Tensor LazyNativeFunctions::narrow_copy(
const at::Tensor& self,
int64_t dim,
int64_t start,
int64_t length) {
return at::functionalization::functionalize_aten_op<ATEN_OP(
narrow_copy)>::call(self, dim, start, length);
}
at::Tensor LazyNativeFunctions::pixel_shuffle(
const at::Tensor& self,
int64_t upscale_factor) {
return at::functionalization::functionalize_aten_op<ATEN_OP(
pixel_shuffle)>::call(self, upscale_factor);
}
at::Tensor LazyNativeFunctions::pixel_unshuffle(
const at::Tensor& self,
int64_t downscale_factor) {
return at::functionalization::functionalize_aten_op<ATEN_OP(
pixel_unshuffle)>::call(self, downscale_factor);
}
at::Tensor LazyNativeFunctions::select_backward(
const at::Tensor& grad_output,
at::IntArrayRef input_sizes,
int64_t dim,
int64_t index) {
return at::functionalization::functionalize_aten_op<ATEN_OP(
select_backward)>::call(grad_output, input_sizes, dim, index);
}
at::Tensor LazyNativeFunctions::_trilinear(
const at::Tensor& i1,
const at::Tensor& i2,
const at::Tensor& i3,
at::IntArrayRef expand1,
at::IntArrayRef expand2,
at::IntArrayRef expand3,
at::IntArrayRef sumdim,
int64_t unroll_dim) {
return at::functionalization::functionalize_aten_op<ATEN_OP(_trilinear)>::
call(i1, i2, i3, expand1, expand2, expand3, sumdim, unroll_dim);
}
::std::tuple<at::Tensor, at::Tensor> LazyNativeFunctions::linalg_inv_ex(
const at::Tensor& self,
bool check_errors) {
return at::functionalization::functionalize_aten_op<ATEN_OP(
linalg_inv_ex)>::call(self, check_errors);
}
at::Tensor LazyNativeFunctions::linalg_pinv(
const at::Tensor& self,
const c10::optional<at::Tensor>& atol,
const c10::optional<at::Tensor>& rtol,
bool hermitian) {
return at::functionalization::functionalize_aten_op<ATEN_OP2(
linalg_pinv, atol_rtol_tensor)>::call(self, atol, rtol, hermitian);
}
// functionalize_aten_op can't handle out= ops directly.
// Instead, we can call the composite kernel from core, and copy any mutations
// back to the inputs.
at::Tensor& LazyNativeFunctions::logsumexp_out(
const at::Tensor& self,
at::IntArrayRef dim,
bool keepdim,
at::Tensor& out) {
auto self_wrapped = at::functionalization::impl::to_functional_tensor(self);
auto out_wrapped = at::functionalization::impl::to_functional_tensor(out);
// directly call the composite kernel from core.
// Make sure to re-enable functionalization first.
auto curr_tls = c10::impl::tls_local_dispatch_key_set();
auto tls_reenable_functionalize = c10::impl::PODLocalDispatchKeySet();
tls_reenable_functionalize.set_included(curr_tls.included_);
tls_reenable_functionalize.set_excluded(
curr_tls.excluded_.remove(c10::DispatchKey::Functionalize));
c10::impl::ForceDispatchKeyGuard guard_(tls_reenable_functionalize);
at::native::logsumexp_out(self_wrapped, dim, keepdim, out_wrapped);
auto out_unwrapped =
at::functionalization::impl::from_functional_tensor(out_wrapped);
// propagate mutations back to the inputs (including resizing)
out.resize_(out_unwrapped.sizes());
out.copy_(out_unwrapped);
return out;
}
at::Tensor LazyNativeFunctions::diagonal_backward(
const at::Tensor& grad_output,
at::IntArrayRef input_sizes,
int64_t offset,
int64_t dim1,
int64_t dim2) {
return at::functionalization::functionalize_aten_op<ATEN_OP(
diagonal_backward)>::call(grad_output, input_sizes, offset, dim1, dim2);
}
at::Tensor LazyNativeFunctions::slice_backward(
const at::Tensor& grad_output,
at::IntArrayRef input_sizes,
int64_t dim,
int64_t start,
int64_t end,
int64_t step) {
return at::functionalization::functionalize_aten_op<ATEN_OP(
slice_backward)>::call(grad_output, input_sizes, dim, start, end, step);
}
// re-use the composite kernel from core, that way we don't need to provide a
// backwards formula for native_group_norm
std::tuple<Tensor, Tensor, Tensor> LazyNativeFunctions::native_group_norm(
const at::Tensor& input,
const c10::optional<at::Tensor>& weight,
const c10::optional<at::Tensor>& bias,
int64_t N,
int64_t C,
int64_t HxW,
int64_t group,
double eps) {
return at::native::math_group_norm(
input, weight, bias, N, C, HxW, group, eps);
}
void InitializeAtenBindings() {}

View File

@ -486,6 +486,25 @@ class CUDATestBase(DeviceTypeTestBase):
# Acquires the current device as the primary (test) device
cls.primary_device = 'cuda:{0}'.format(torch.cuda.current_device())
# See Note [Lazy Tensor tests in device agnostic testing]
lazy_ts_backend_init = False
class LazyTestBase(DeviceTypeTestBase):
device_type = 'lazy'
def _should_stop_test_suite(self):
return False
@classmethod
def setUpClass(cls):
import torch._lazy
import torch._lazy.metrics
import torch._lazy.ts_backend
global lazy_ts_backend_init
if not lazy_ts_backend_init:
# Need to connect the TS backend to lazy key before running tests
torch._lazy.ts_backend.init()
lazy_ts_backend_init = True
class MPSTestBase(DeviceTypeTestBase):
device_type = 'mps'
@ -570,7 +589,7 @@ PYTORCH_TESTING_DEVICE_EXCEPT_FOR_KEY = 'PYTORCH_TESTING_DEVICE_EXCEPT_FOR'
# The tests in these test cases are derived from the generic tests in
# generic_test_class.
# See note "Generic Device Type Testing."
def instantiate_device_type_tests(generic_test_class, scope, except_for=None, only_for=None, include_lazy=False):
# Removes the generic test class from its enclosing scope so its tests
# are not discoverable.
del scope[generic_test_class.__name__]
@ -592,6 +611,14 @@ def instantiate_device_type_tests(generic_test_class, scope, except_for=None, on
# Filter out the device types based on user inputs
desired_device_type_test_bases = filter_desired_device_types(device_type_test_bases,
except_for, only_for)
if include_lazy:
# Note [Lazy Tensor tests in device agnostic testing]
# Right now, test_view_ops.py runs with LazyTensor.
# We don't want to opt every device-agnostic test into using the lazy device,
# because many of them will fail.
# So instead, the only way to opt a specific device-agnostic test file into
# lazy tensor testing is with include_lazy=True
desired_device_type_test_bases.append(LazyTestBase)
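# For example (illustrative; "TestViewOps" is just a placeholder name for a
# generic device-type test class), a test file opts into lazy testing with:
#   instantiate_device_type_tests(TestViewOps, globals(), include_lazy=True)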
def split_if_not_empty(x: str):
return x.split(",") if len(x) != 0 else []

View File

@ -34,11 +34,12 @@ from torchgen.api.types import (
tensorT,
voidT,
longT,
SymIntT,
symIntArrayRefT,
BaseTypeToCppMapping,
intArrayRefT,
optionalIntArrayRefT,
tensorOptionsT,
)
from torchgen import local
from torchgen.utils import assert_never
@ -155,6 +156,11 @@ def argumenttype_type(
return NamedCType(binds, VectorCType(BaseCType(longT)))
else:
return NamedCType(binds, BaseCType(intArrayRefT))
if str(t.elem) == "SymInt":
if remove_non_owning_ref_types:
return NamedCType(binds, VectorCType(BaseCType(SymIntT)))
else:
return NamedCType(binds, BaseCType(symIntArrayRefT))
elif str(t.elem) == "Tensor":
return NamedCType(binds, BaseCType(tensorListT))
elif str(t.elem) == "Scalar":

View File

@ -14,6 +14,8 @@ from torchgen.api.types import (
memoryFormatT,
tensorOptionsT,
scalarTypeT,
SymIntT,
symIntArrayRefT,
boolT,
deviceT,
layoutT,
@ -63,6 +65,7 @@ options_ctype = NamedCType("options", ConstRefCType(BaseCType(tensorOptionsT)))
out_tensor_ctype = NamedCType("out", ConstRefCType(BaseCType(tensorT)))
longVec_ctype = VectorCType(BaseCType(longT))
longSymVec_ctype = VectorCType(BaseCType(SymIntT))
optionalLongVec_ctype = OptionalCType(VectorCType(BaseCType(longT)))
optionalScalar_ctype = OptionalCType(BaseCType(scalarT))
optionalTensor_ctype = OptionalCType(BaseCType(tensorT))
@ -324,7 +327,19 @@ Check this module for more information.
# We can always do translations from value types to reference types, like vector<int> -> IntArrayRef
elif goal.type == BaseCType(intArrayRefT):
try:
return direct_solve(NamedCType(goal.name, longVec_ctype))
except UnsatError:
# We can also go SymIntArrayRef -> IntArrayRef
symIntArrayRef_type = direct_solve(
NamedCType(goal.name, BaseCType(symIntArrayRefT))
)
return f"c10::asIntArrayRefSlow({symIntArrayRef_type})"
elif goal.type == BaseCType(symIntArrayRefT):
return direct_solve(NamedCType(goal.name, longSymVec_ctype))
elif goal.type == BaseCType(longT):
symInt_type = direct_solve(NamedCType(goal.name, BaseCType(SymIntT)))
return f"{symInt_type}.expectInt()"
elif goal.type == BaseCType(optionalIntArrayRefT):
return direct_solve(NamedCType(goal.name, optionalLongVec_ctype))
elif goal.type == BaseCType(optionalScalarRefT):
@ -345,6 +360,10 @@ Check this module for more information.
intArrayRef_ctype = NamedCType(goal.name, BaseCType(intArrayRefT))
argname = direct_solve(intArrayRef_ctype)
return f"{argname}.vec()"
if goal.type == VectorCType(BaseCType(SymIntT)):
symIntArrayRef_ctype = NamedCType(goal.name, BaseCType(symIntArrayRefT))
argname = direct_solve(symIntArrayRef_ctype)
return f"{argname}.vec()"
elif goal.type == OptionalCType(VectorCType(BaseCType(longT))):
optionalIntArrayRef_ctype = NamedCType(
goal.name, BaseCType(optionalIntArrayRefT)

View File

@ -1,21 +1,26 @@
from abc import ABC
import itertools
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union, Tuple
from torchgen.context import method_with_native_function
from torchgen.model import (
FunctionSchema,
Argument,
BackendIndex,
NativeFunction,
NativeFunctionsGroup,
)
from torchgen.api.types import (
BaseCType,
Binding,
DispatcherSignature,
OptionalCType,
VectorCType,
kernel_signature,
deviceT,
)
import torchgen.api.dispatcher as dispatcher
from torchgen.api.translate import translate
from torchgen.api.lazy import (
LazyIrProperties,
LazyIrSchema,
@ -121,6 +126,25 @@ def aten_symbol(schema: LazyIrSchema) -> str:
return schema.aten_name
# converts all tensor-like arguments to meta tensors. Returns:
# (1) a string containing all of the logic that does the conversions.
# (2) a context, to be used by translate(), with all of the relevant bindings.
def convert_to_meta_tensors(sig: DispatcherSignature) -> Tuple[str, List[Binding]]:
context: List[Binding] = []
unwrapped_tensor_args: List[str] = []
for arg in sig.arguments():
if isinstance(arg.argument, Argument) and arg.argument.type.is_tensor_like():
unwrapped_name = f"{arg.name}_meta"
unwrapped_tensor_args.append(
f"auto {unwrapped_name} = to_meta({arg.name});"
)
context.append(arg.with_name(unwrapped_name))
else:
context.append(arg)
unwrap_tensor_args_str = "\n ".join(unwrapped_tensor_args)
return unwrap_tensor_args_str, context
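# For example (illustrative): for a schema like "add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1)",
# the returned string contains the lines
#   auto self_meta = to_meta(self);
#   auto other_meta = to_meta(other);
# and the returned context rebinds "self" -> "self_meta" and "other" -> "other_meta",
# while non-tensor arguments like "alpha" pass through unchanged.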
@dataclass(frozen=True)
class GenLazyIR(ABC):
backend_index: BackendIndex
@ -206,7 +230,13 @@ class GenLazyIR(ABC):
node_ctor_args = ", ".join(ctor_args)
scalar_initializers = ",\n ".join(
f"{a.name}({a.name})" for a in scalar_args
[
# This code is just special casing the mapping from string_view -> strings
f"{a.name}({a.name}.has_value() ? c10::make_optional(std::string(*{a.name})) : c10::nullopt)"
if a.lazy_type.cpp_type() == "c10::optional<c10::string_view>"
else f"{a.name}({a.name})"
for a in scalar_args
]
)
if len(scalar_initializers):
scalar_initializers = f",\n {scalar_initializers}"
@ -214,6 +244,8 @@ class GenLazyIR(ABC):
[
f"std::string {a.name};"
if a.lazy_type.cpp_type() == "c10::string_view"
else f"c10::optional<std::string> {a.name};"
if a.lazy_type.cpp_type() == "c10::optional<c10::string_view>"
else f"{a.lazy_type.cpp_type()} {a.name};"
for a in scalar_args
]
@ -314,19 +346,20 @@ class GenTSLazyIR(GenLazyIR):
elif not schema.properties.CanBeReused:
return ""
value_comparison = []
for arg in itertools.chain(schema.positional_values, schema.keyword_values):
if isinstance(arg.lazy_type, OptionalCType):
value_comparison.append(
f"nullable_operand(i++) == {arg.name}.value_or(kNullValue)"
)
else:
value_comparison.append(f"operand(i++) == {arg.name}")
for arg in itertools.chain(schema.positional_scalars, schema.keyword_scalars):
if isinstance(arg.lazy_type, OptionalCType):
value_comparison.append(
f"((!this->{arg.name}&&!{arg.name}) || (this->{arg.name}&&{arg.name} && *(this->{arg.name}) == *{arg.name}))"
)
else:
value_comparison.append(f"this->{arg.name} == {arg.name}")
value_comparison_str = " &&\n ".join(value_comparison)
return f"""{signature} {{
@ -428,9 +461,20 @@ class GenLazyNativeFuncDefinition:
all_args = schema.filtered_args()
returns_length = len(schema.returns)
# call the meta kernel if it exists, to compute output shape/dtype for our IR
# Note [Generated LTC Shape Functions]
# LTC uses meta tensors from core to do shape inference when possible, and otherwise
# we generate a shape function declaration that needs to be manually implemented.
# How do we detect which ops are eligible to use meta tensors?
# In general we should be able to use meta tensors not just on structured operators,
# but also on composite operators that are implemented in terms of structured kernels.
# We don't currently have a way of knowing at codegen time which ops are implemented that way.
# This is the case for all view and view_copy operators however, so we're going to
# use them specifically for all of the view_copy ops (instead of manually writing shape rules for all of them).
is_view_copy_op = "view_copy" in func.tags
is_structured = func.structured or func.structured_delegate is not None
if is_structured or is_view_copy_op:
meta_out = """
std::vector<torch::lazy::Shape> shapes{torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};"""
if returns_length > 1:
def this_shape(i: int) -> str:
@ -439,8 +483,28 @@ class GenLazyNativeFuncDefinition:
shapes_str = ",".join([this_shape(i) for i in range(returns_length)])
meta_out = "std::vector<torch::lazy::Shape> shapes{" + shapes_str + "};"
shape_str = f"""auto out_meta = at::meta::{schema.aten_name}({', '.join(str(a.name) for a in all_args)});
{meta_out}"""
# Convert tensor args to the meta device and call it.
# (We can't pass in the input tensors directly, because they are "functional wrappers".
# If any of the meta kernels call a tensor op and redispatch, we don't want to hit the functionalize kernels.)
# Even at::meta:: functions might redispatch, e.g. if they call into view ops.
dispatcher_sig = DispatcherSignature.from_schema(func.func)
meta_conversion_str, meta_call_ctx = convert_to_meta_tensors(dispatcher_sig)
meta_call_args = [
e.expr
for e in translate(
meta_call_ctx, dispatcher_sig.arguments(), method=False
)
]
if is_view_copy_op:
# view_copy ops always have a CompositeExplicitAutogradNonFunctional kernel
assert func.has_composite_explicit_autograd_non_functional_kernel
dispatch_ns = "compositeexplicitautogradnonfunctional"
else:
dispatch_ns = "meta"
shape_str = f"""\
{meta_conversion_str}
auto out_meta = at::{dispatch_ns}::{schema.aten_name}({', '.join(meta_call_args)});
{meta_out}"""
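# As a rough example (illustrative): for a structured op such as add.Tensor,
# shape_str would expand to something like
#   auto self_meta = to_meta(self);
#   auto other_meta = to_meta(other);
#   auto out_meta = at::meta::add(self_meta, other_meta, alpha);
#   std::vector<torch::lazy::Shape> shapes{torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};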
else:
shape_sig = ComputeShapeSignature(metadata.kernel, func)
shape_str = f"""
@ -571,13 +635,14 @@ class GenLazyShapeInferenceDefinition:
metadata = self.backend_index.get_kernel(f)
assert metadata is not None
# See Note [Generated LTC Shape Functions]
is_view_copy_op = "view_copy" in f.tags
is_structured = f.structured or f.structured_delegate is not None
if is_structured or is_view_copy_op:
return []
else:
shape_sig = ComputeShapeSignature(metadata.kernel, f)
return ["\n".join([f"{shape_sig.shape_decl};"])]
def generate_non_native_lazy_ir_nodes(

View File

@ -77,6 +77,7 @@ from torchgen.gen_functionalization_type import (
gen_functionalization_registration,
gen_functionalization_view_inverse_declaration,
gen_composite_view_copy_kernel,
gen_symint_view_copy_kernel,
)
T = TypeVar("T")
@ -2281,6 +2282,24 @@ TORCH_LIBRARY({custom_namespace}, m) {{
)
},
)
view_copy_with_symint_pairs: List[Tuple[NativeFunction, NativeFunction]] = []
for g1 in view_groups:
for g2 in view_groups:
if g1.view_copy is None or g2.view_copy is None:
continue
# TODO: make this more first class in the data model
same_base_op = str(g1.view_copy.func.name.name) == str(
g2.view_copy.func.name.name
)
op1_not_symint = "SymInt" not in str(g1.view_copy.func.name.overload_name)
op2_symint = "SymInt" in str(g2.view_copy.func.name.overload_name)
if same_base_op and op1_not_symint and op2_symint:
view_copy_with_symint_pairs.append(
(
g1.view_copy,
g2.view_copy,
)
)
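# For example (illustrative): this pairs up "view_copy(Tensor self, int[] size)"
# with its "view_copy.SymInt(Tensor self, SymInt[] size)" overload, so the SymInt
# variant's kernel can be generated in terms of the non-SymInt one.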
# Note [view_copy NativeFunctions]
# Every view operator in native_functions.yaml that is not CompositeImplicitAutograd
@ -2321,6 +2340,12 @@ TORCH_LIBRARY({custom_namespace}, m) {{
"CompositeViewCopyKernel_Definitions": list(
mapMaybe(gen_composite_view_copy_kernel, view_groups)
),
"SymIntViewCopyKernel_Definitions": list(
mapMaybe(
lambda pair: gen_symint_view_copy_kernel(pair[0], pair[1]),
view_copy_with_symint_pairs,
)
),
"GeneratedCompositeFunctional_Definitions": list(
mapMaybe(
gen_composite_functional_kernel,

View File

@ -265,7 +265,10 @@ def error_on_missing_kernels(
native_f
)
kernel_defn_regex = rf"{class_name}::([\w\d]*)\([^\)]*\)\s*{{"
# This just looks for lines containing "foo(", and assumes that the kernel foo has been implemented.
# It might cause false negatives (we won't catch all cases), but that's ok - if we catch a missing kernel
# here, then we get a nicer error message. If we miss it, you get a linker error.
kernel_defn_regex = rf"{class_name}::\s*([\w\d]*)\("
actual_backend_kernel_name_counts = Counter(
re.findall(kernel_defn_regex, backend_defns)
)
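# For example (illustrative): with class_name == "LazyNativeFunctions", a line like
#   at::Tensor LazyNativeFunctions::add(
# in the backend .cpp file is counted as an implementation of the "add" kernel.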

View File

@ -75,8 +75,30 @@ def gen_composite_view_copy_kernel(g: NativeFunctionsViewGroup) -> Optional[str]
if g.view_copy is None:
return None
# For view_copy.SymInt overloads,
# See gen_symint_view_copy_kernel.
if g.view_copy.func.name.overload_name == "SymInt":
return None
# We can make view_copy work in more cases by using reshape()
# when a normal view call would ordinarily fail.
# This also makes LTC more efficient, because it doesn't need to include
# clone() calls in its graph (which is normally needed by reshape).
if str(g.view_copy.func.name) == "view_copy":
return """\
at::Tensor view_copy(const at::Tensor & self, at::IntArrayRef size) {
if (!at::detail::computeStride(self.sizes(), self.strides(), size).has_value()) {
return self.reshape(size);
} else {
auto output = at::_ops::view::call(self, size);
return output.clone();
}
}
"""
# view_copy is a native signature, since we're generating an at::native:: kernel
view_copy_sig = NativeSignature(g.view_copy.func)
# view is a dispatcher signature, since we're calling into the at::_ops API
view_sig = DispatcherSignature(g.view.func)
@ -113,6 +135,34 @@ def gen_composite_view_copy_kernel(g: NativeFunctionsViewGroup) -> Optional[str]
"""
# For symint view copy kernels, we want to generate them to call into
# their concrete view_copy counterparts.
@with_native_function_and
def gen_symint_view_copy_kernel(
view_copy: NativeFunction, view_copy_symint: NativeFunction
) -> str:
# view_copy.symint is a native signature, since we're generating an at::native:: kernel
view_copy_symint_sig = NativeSignature(view_copy_symint.func)
# view_copy is a dispatcher signature, since we're calling into the at::_ops API
view_copy_sig = DispatcherSignature(view_copy.func)
exprs = ", ".join(
[
e.expr
for e in translate(
view_copy_symint_sig.arguments(), view_copy_sig.arguments()
)
]
)
return f"""
{view_copy_symint_sig.defn()} {{
return at::_ops::{view_copy.func.name.unambiguous_name()}::call({exprs});
}}
"""
def return_str(rets: Tuple[Return, ...], names: List[str]) -> str:
assert len(rets) == len(names)
if len(rets) == 0:

View File

@ -162,6 +162,39 @@ and implement it in the corresponding shape_inference.cpp file.\n
)
# Some helper functions for the codegen.
def get_ltc_helper_fns() -> str:
return """\
at::Tensor to_meta(const at::Tensor& tensor) {
// undefined tensors can't be converted to the meta device, since they don't have sizes/strides
if (!tensor.defined()) return tensor;
auto out = at::native::empty_strided_meta(tensor.sizes(), tensor.strides(), \
/*dtype=*/c10::make_optional(tensor.scalar_type()), /*layout=*/c10::make_optional(tensor.layout()), \
/*device=*/c10::make_optional(c10::Device(c10::kMeta)), /*pin_memory=*/c10::nullopt);
// needs to handle wrapped numbers, so dtype promotion works properly.
if (tensor.unsafeGetTensorImpl()->is_wrapped_number()) {
out.unsafeGetTensorImpl()->set_wrapped_number(true);
}
return out;
}
c10::optional<at::Tensor> to_meta(const c10::optional<at::Tensor>& tensor) {
if (tensor.has_value()) {
return to_meta(*tensor);
}
return c10::nullopt;
}
std::vector<at::Tensor> to_meta(const at::TensorList& t_list) {
std::vector<at::Tensor> outs;
outs.reserve(t_list.size());
for (const auto& i : c10::irange(t_list.size())) {
outs.push_back(to_meta(t_list[i]));
}
return outs;
}
"""
class default_args:
node_base: str = "Node"
node_base_hdr: Optional[str] = None
@ -436,6 +469,9 @@ def run_gen_lazy_tensor(
tensor_class_hdr,
shape_inference_hdr,
"ATen/Functions.h",
"ATen/native/TensorConversions.h",
"ATen/NativeFunctions.h",
"ATen/CompositeExplicitAutogradNonFunctionalFunctions.h",
"ATen/MetaFunctions.h",
"ATen/Operators.h",
"ATen/native/CPUFallback.h",
@ -452,6 +488,7 @@ def run_gen_lazy_tensor(
else []
)
],
"helper_fns": get_ltc_helper_fns(),
"native_functions_include": "",
"namespace_prologue": ns_helper.prologue,
"namespace_epilogue": ns_helper.epilogue,