Codegen Non-Native IR Nodes (#76535)

Add codegen infrastructure to generate IR nodes for non-native ops.

The proposed change adds a `non_native` key to the `{backend}_native_functions.yaml` file, containing schema definitions similar to those found in `native_functions.yaml`, e.g.
```
non_native:
    ...
    - func: expand(Tensor input, int[] size, bool is_scalar_expand) -> Tensor
    ...
```
These definitions are parsed into a `LazyIrSchema`, which `GenLazyIR` then uses to generate the IR nodes.
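
As a rough illustration (not the actual codegen driver), the sketch below shows how one of these entries could be mapped onto the `torchgen` types touched by this PR. The real entry point is `generate_non_native_lazy_ir_nodes` in `torchgen/dest/lazy_ir.py`; the YAML loading and printing here are simplified stand-ins.
```
# Illustrative sketch only: map a single `non_native` entry onto
# LazyIrProperties / LazyIrSchema. The driver code here is hypothetical;
# only the imported types come from the PR itself.
import yaml

from torchgen.model import FunctionSchema
from torchgen.api.lazy import LazyIrProperties, LazyIrSchema

entry = yaml.safe_load("""
func: expand(Tensor input, int[] size, bool is_scalar_expand) -> Tensor
properties:
  - ShapeCompute
""")

# Per-op `properties` override the schema defaults when present.
props = entry.get("properties", [])
properties = LazyIrProperties(*props) if props else None

schema = LazyIrSchema(FunctionSchema.parse(entry["func"]), properties)
print(schema.aten_name)                 # expand
print(schema.properties.ShapeCompute)   # True: shape computed on construction
```
If `properties` is omitted, the schema keeps the defaults (`ShapePrecompute`, `Lower`, `CanBeReused`) defined in `torchgen/api/lazy.py`.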

Fixes #74628

CC: @wconstab @desertfire @henrytwo

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76535
Approved by: https://github.com/wconstab
Author: Antonio Kim (2022-05-24 19:29:23 +00:00)
Committed by: PyTorch MergeBot
Parent: 13dcba8c07
Commit: 02c4d877b4
55 changed files with 496 additions and 1347 deletions


@ -1865,6 +1865,7 @@ test_suite(
"aten/src/ATen/templates/DispatchKeyNativeFunctions.cpp",
"aten/src/ATen/templates/DispatchKeyNativeFunctions.h",
"aten/src/ATen/templates/LazyIr.h",
"aten/src/ATen/templates/LazyNonNativeIr.h",
"aten/src/ATen/templates/RegisterDispatchKey.cpp",
"aten/src/ATen/native/native_functions.yaml",
"aten/src/ATen/native/tags.yaml",


@ -178,3 +178,41 @@ supported:
- _unsafe_view
autograd:
- max_pool3d
# Ops that don't have a native schema definition and are dispatched within Lazy Tensor Core
non_native:
- func: scalar(Scalar value, ScalarType type) -> Tensor
opkind: at::prim::Constant
properties:
- ShapeCompute
- TreatScalarsAsConstants
- func: expand(Tensor input, int[] size, bool is_scalar_expand) -> Tensor
- func: view(Tensor input, int[] output_size) -> Tensor
properties:
- ShapeCompute
- func: cast(Tensor input, ScalarType dtype, ScalarType? stype) -> Tensor
opkind: ltc_cast
properties:
- ShapeCompute
# View ops are only required until a proper functionalization pass is introduced into LTC
- func: as_strided_view_update(Tensor target, Tensor input, int[] size, int[] stride, int storage_offset) -> Tensor
opkind: ltc_as_strided_view_update
- func: as_strided(Tensor input, int[] size, int[] stride, int storage_offset) -> Tensor
- func: diagonal_view_update(Tensor target, Tensor input, int offset, int dim1, int dim2) -> Tensor
opkind: ltc_diagonal_view_update
properties:
- ShapeCompute
- func: diagonal(Tensor input, int offset, int dim1, int dim2) -> Tensor
- func: narrow_view_update(Tensor input, Tensor source, int[] base_indices) -> Tensor
opkind: ltc_narrow_view_update
- func: narrow(Tensor input, int[] base_indices, int[] sizes) -> Tensor
- func: permute(Tensor input, int[] dims) -> Tensor
- func: resize(Tensor input, int[] size) -> Tensor
- func: select_view_update(Tensor target, Tensor source, int dim, int start, int end, int stride) -> Tensor
opkind: ltc_select_view_update
properties:
- ShapeCompute
- func: select(Tensor input, int dim, int start, int end, int stride) -> Tensor
- func: squeeze(Tensor input, int dim) -> Tensor
- func: unsqueeze(Tensor input, int dim) -> Tensor


@ -0,0 +1,11 @@
#pragma once
${lazy_non_native_ir_inc}
// This file contains autogenerated LazyTensor Non Native IR nodes
${namespace_prologue}
${non_native_ir_nodes}
${namespace_epilogue}


@ -28,6 +28,7 @@ def define_targets(rules):
":DispatchKeyNativeFunctions.cpp",
":DispatchKeyNativeFunctions.h",
":LazyIr.h",
":LazyNonNativeIr.h",
":RegisterDispatchKey.cpp",
":native_functions.yaml",
":shape_inference.h",
@ -88,6 +89,7 @@ GENERATED_TESTING_PY = [
GENERATED_LAZY_H = [
"torch/csrc/lazy/generated/LazyIr.h",
"torch/csrc/lazy/generated/LazyNonNativeIr.h",
"torch/csrc/lazy/generated/LazyNativeFunctions.h",
]


@ -380,6 +380,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
list(APPEND GENERATED_H_TORCH
"${TORCH_SRC_DIR}/csrc/autograd/generated/VariableType.h"
"${TORCH_SRC_DIR}/csrc/lazy/generated/LazyIr.h"
"${TORCH_SRC_DIR}/csrc/lazy/generated/LazyNonNativeIr.h"
"${TORCH_SRC_DIR}/csrc/lazy/generated/LazyNativeFunctions.h"
)
endif()
@ -444,6 +445,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
"${TORCH_ROOT}/aten/src/ATen/templates/DispatchKeyNativeFunctions.h"
"${TORCH_ROOT}/aten/src/ATen/templates/DispatchKeyNativeFunctions.cpp"
"${TORCH_ROOT}/aten/src/ATen/templates/LazyIr.h"
"${TORCH_ROOT}/aten/src/ATen/templates/LazyNonNativeIr.h"
"${TORCH_ROOT}/aten/src/ATen/templates/RegisterDispatchKey.cpp"
"${TOOLS_PATH}/autograd/templates/VariableType.h"
"${TOOLS_PATH}/autograd/templates/VariableType.cpp"


@ -417,35 +417,19 @@ lazy_tensor_core_sources = [
# We can't build all of the ts backend under certain build configurations, e.g. mobile,
# since it depends on things like autograd, meta functions, which may be disabled
lazy_tensor_ts_sources = [
"torch/csrc/lazy/ts_backend/config.cpp",
"torch/csrc/lazy/ts_backend/dynamic_ir.cpp",
"torch/csrc/lazy/ts_backend/config.cpp",
"torch/csrc/lazy/ts_backend/ops/batch_norm_ops.cpp",
"torch/csrc/lazy/ts_backend/ops/random_ops.cpp",
"torch/csrc/lazy/ts_backend/ops/cast.cpp",
"torch/csrc/lazy/ts_backend/ops/device_data.cpp",
"torch/csrc/lazy/ts_backend/ops/expand.cpp",
"torch/csrc/lazy/ts_backend/ops/random_ops.cpp",
"torch/csrc/lazy/ts_backend/ops/generic.cpp",
"torch/csrc/lazy/ts_backend/ops/scalar.cpp",
"torch/csrc/lazy/ts_backend/view_ops/as_strided.cpp",
"torch/csrc/lazy/ts_backend/view_ops/as_strided_view_update.cpp",
"torch/csrc/lazy/ts_backend/view_ops/diagonal.cpp",
"torch/csrc/lazy/ts_backend/view_ops/diagonal_view_update.cpp",
"torch/csrc/lazy/ts_backend/view_ops/narrow.cpp",
"torch/csrc/lazy/ts_backend/view_ops/narrow_view_update.cpp",
"torch/csrc/lazy/ts_backend/view_ops/permute.cpp",
"torch/csrc/lazy/ts_backend/view_ops/resize.cpp",
"torch/csrc/lazy/ts_backend/view_ops/select.cpp",
"torch/csrc/lazy/ts_backend/view_ops/squeeze.cpp",
"torch/csrc/lazy/ts_backend/view_ops/unsqueeze.cpp",
"torch/csrc/lazy/ts_backend/view_ops/select_view_update.cpp",
"torch/csrc/lazy/ts_backend/view_ops/view.cpp",
"torch/csrc/lazy/ts_backend/ts_node.cpp",
"torch/csrc/lazy/ts_backend/tensor_aten_ops.cpp",
"torch/csrc/lazy/ts_backend/ts_autograd_functions.cpp",
"torch/csrc/lazy/ts_backend/ts_backend_impl.cpp",
"torch/csrc/lazy/ts_backend/ts_eager_fallback.cpp",
"torch/csrc/lazy/ts_backend/ts_lowering_context.cpp",
"torch/csrc/lazy/ts_backend/ts_native_functions.cpp",
"torch/csrc/lazy/ts_backend/ts_node.cpp",
"torch/csrc/lazy/ts_backend/ts_node_lowering.cpp",
]


@ -237,7 +237,7 @@ invalid_key: invalid_val"""
output_error = self.get_errors_from_gen_backend_stubs(yaml_str)
self.assertExpectedInline(
output_error,
""" contains unexpected keys: invalid_key. Only the following keys are supported: backend, class_name, cpp_namespace, extra_headers, supported, autograd, full_codegen""", # noqa: B950
""" contains unexpected keys: invalid_key. Only the following keys are supported: backend, class_name, cpp_namespace, extra_headers, supported, autograd, full_codegen, non_native""", # noqa: B950
)
# if use_out_as_primary is provided, it must be a bool


@ -19,6 +19,10 @@ hash_t Output::hash() const {
return HashCombine(node->hash(), Hash(index));
}
hash_t Output::shapeHash() const {
return HashCombine(node->shapeHash(), Hash(index));
}
std::string Output::ToString() const {
std::stringstream ss;
ss << node->ToString() << ", index=" << index;
@ -144,7 +148,7 @@ std::string Node::ToString() const {
void Node::AddOperand(NodePtr node, size_t index) {
CHECK_LT(index, node->num_outputs());
operands_.push_back(std::move(node));
operands_.push_back(node);
operands_as_outputs_.emplace_back(operands_.back().get(), index);
}


@ -214,6 +214,7 @@ struct TORCH_API Output {
: node(node), index(index) {}
hash_t hash() const;
hash_t shapeHash() const;
bool operator==(const Output& rhs) const {
return node == rhs.node && index == rhs.index;


@ -6,33 +6,33 @@
namespace torch {
namespace lazy {
bool StrideIsSupported(c10::ArrayRef<int64_t> stride);
TORCH_API bool StrideIsSupported(c10::ArrayRef<int64_t> stride);
std::vector<int64_t> GetArrayStridePermutation(c10::ArrayRef<int64_t> stride);
TORCH_API std::vector<int64_t> GetArrayStridePermutation(c10::ArrayRef<int64_t> stride);
Shape MakeDiagonalShape(
TORCH_API Shape MakeDiagonalShape(
const Shape& shape,
int64_t offset,
int64_t dim1,
int64_t dim2);
Shape MakePermuteShape(
TORCH_API Shape MakePermuteShape(
const Shape& source_shape,
c10::ArrayRef<int64_t> permutation);
Shape MakeSelectShape(
TORCH_API Shape MakeSelectShape(
const Shape& shape,
int64_t dim,
int64_t start,
int64_t end,
int64_t stride);
int64_t GetStride(int64_t start, int64_t end, int64_t stride);
TORCH_API int64_t GetStride(int64_t start, int64_t end, int64_t stride);
std::vector<int64_t> BuildSqueezedDimensions(c10::ArrayRef<int64_t> dimensions,
TORCH_API std::vector<int64_t> BuildSqueezedDimensions(c10::ArrayRef<int64_t> dimensions,
int64_t squeeze_dim);
std::vector<int64_t> BuildUnsqueezedDimensions(
TORCH_API std::vector<int64_t> BuildUnsqueezedDimensions(
c10::ArrayRef<int64_t> dimensions,
int64_t squeeze_dim);


@ -23,8 +23,11 @@ std::vector<typename Container::value_type> PermuteDimensions(
const Container& dimensions) {
using T = typename Container::value_type;
TORCH_CHECK(
dimensions.size() == permutation.size() && IsPermutation(permutation),
"Invalid permutation specified");
dimensions.size() == permutation.size(),
"Invalid permutation specified. dimensions.size() != permutation.size() (", dimensions.size(), " vs. ", permutation.size(), ")");
TORCH_CHECK(
IsPermutation(permutation),
"Invalid permutation specified. Permutation is not permutation");
std::vector<T> output(dimensions.size());
for (const auto i : c10::irange(permutation.size())) {
output[i] = dimensions[permutation[i]];


@ -45,10 +45,12 @@
#include <torch/csrc/lazy/core/shape_inference.h>
#include <torch/csrc/lazy/core/ops/utils.h>
#include <torch/csrc/lazy/core/shape.h>
#include <ATen/native/ConvUtils.h>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/InferSize.h>
#include <ATen/WrapDimUtils.h>
#include <aten/src/ATen/native/ReduceOpsUtils.h>
#include <c10/core/ScalarType.h>
@ -629,6 +631,67 @@ std::vector<Shape> compute_shape_narrow_copy(const at::Tensor & self, int64_t di
return {Shape(self.scalar_type(), self.sizes().vec())};
}
// Non-Native Ops
std::vector<Shape> compute_shape_scalar(const at::Scalar& value, const at::ScalarType& type) {
return { Shape(type, {}) };
}
std::vector<Shape> compute_shape_expand(const Output& input, const std::vector<int64_t>& size, const bool& is_scalar_expand) {
return { Shape(input.shape().scalar_type(), size) };
}
std::vector<Shape> compute_shape_view(const Output& input, const std::vector<int64_t>& output_sizes) {
const Shape& input_shape = input.shape();
const auto complete_output_sizes =
at::infer_size(output_sizes, input_shape.numel());
return { Shape(input_shape.scalar_type(), complete_output_sizes) };
}
std::vector<Shape> compute_shape_cast(const Output& input, const at::ScalarType& dtype, const c10::optional<at::ScalarType>& stype) {
Shape shape = input.shape();
shape.set_scalar_type(dtype);
return { shape };
}
// View Ops
std::vector<Shape> compute_shape_as_strided_view_update(const Output& target, const Output& input, const std::vector<int64_t>& size, const std::vector<int64_t>& stride, const int64_t& storage_offset) {
return { Shape(target.shape().scalar_type(), size) };
}
std::vector<Shape> compute_shape_as_strided(const Output& input, const std::vector<int64_t>& size, const std::vector<int64_t>& stride, const int64_t& storage_offset) {
return { Shape(input.shape().scalar_type(), size) };
}
std::vector<Shape> compute_shape_diagonal_view_update(const Output& target, const Output& input, const int64_t& offset, const int64_t& dim1, const int64_t& dim2) {
return { target.shape() };
}
std::vector<Shape> compute_shape_diagonal(const Output& input, const int64_t& offset, const int64_t& dim1, const int64_t& dim2) {
return { MakeDiagonalShape(input.shape(), offset, dim1, dim2) };
}
std::vector<Shape> compute_shape_narrow_view_update(const Output& input, const Output& source, const std::vector<int64_t>& base_indices) {
return { input.shape() };
}
std::vector<Shape> compute_shape_narrow(const Output& input, const std::vector<int64_t>& base_indices, const std::vector<int64_t>& sizes) {
return { Shape(input.shape().scalar_type(), sizes) };
}
std::vector<Shape> compute_shape_permute(const Output& input, const std::vector<int64_t>& dims) {
return { MakePermuteShape(input.shape(), dims) };
}
std::vector<Shape> compute_shape_resize(const Output& input, const std::vector<int64_t>& size) {
return { Shape(input.shape().scalar_type(), size) };
}
std::vector<Shape> compute_shape_select_view_update(const Output& target, const Output& source, const int64_t& dim, const int64_t& start, const int64_t& end, const int64_t& stride) {
return { target.shape() };
}
std::vector<Shape> compute_shape_select(const Output& input, const int64_t& dim, const int64_t& start, const int64_t& end, const int64_t& stride) {
return { MakeSelectShape(input.shape(), dim, start, end, stride) };
}
std::vector<Shape> compute_shape_squeeze(const Output& input, const int& dim) {
const auto& input_shape = input.shape();
return { torch::lazy::Shape(input_shape.scalar_type(), BuildSqueezedDimensions(input_shape.sizes(), dim)) };
}
std::vector<Shape> compute_shape_unsqueeze(const Output& input, const int& dim) {
const auto& input_shape = input.shape();
return { torch::lazy::Shape(input_shape.scalar_type(), BuildUnsqueezedDimensions(input_shape.sizes(), dim)) };
}
// Restore unused-parameters warnings
#pragma GCC diagnostic pop


@ -4,6 +4,7 @@
#include <c10/core/ScalarType.h>
#include <c10/macros/Export.h>
#include <c10/util/Optional.h>
#include <torch/csrc/lazy/backend/backend_data.h>
#include <torch/csrc/lazy/core/ir.h>
#include <torch/csrc/lazy/core/shape.h>
#include <vector>
@ -68,5 +69,26 @@ TORCH_API std::vector<torch::lazy::Shape> compute_shape__to_copy(const at::Tenso
TORCH_API std::vector<torch::lazy::Shape> compute_shape_trace(const at::Tensor & self);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_zero_functional(const at::Tensor & self);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_narrow_copy(const at::Tensor & self, int64_t dim, int64_t start, c10::SymInt length);
// Non-Native ops
TORCH_API std::vector<Shape> compute_shape_scalar(const at::Scalar& value, const at::ScalarType& type);
TORCH_API std::vector<Shape> compute_shape_expand(const Output& input0, const std::vector<int64_t>& size, const bool& is_scalar_expand);
TORCH_API std::vector<Shape> compute_shape_view(const Output& input0, const std::vector<int64_t>& output_sizes);
TORCH_API std::vector<Shape> compute_shape_cast(const Output& input0, const at::ScalarType& dtype, const c10::optional<at::ScalarType>& stype);
// View Ops
TORCH_API std::vector<Shape> compute_shape_as_strided_view_update(const Output& target, const Output& input, const std::vector<int64_t>& size, const std::vector<int64_t>& stride, const int64_t& storage_offset);
TORCH_API std::vector<Shape> compute_shape_as_strided(const Output& input, const std::vector<int64_t>& size, const std::vector<int64_t>& stride, const int64_t& storage_offset);
TORCH_API std::vector<Shape> compute_shape_diagonal_view_update(const Output& target, const Output& input, const int64_t& offset, const int64_t& dim1, const int64_t& dim2);
TORCH_API std::vector<Shape> compute_shape_diagonal(const Output& input, const int64_t& offset, const int64_t& dim1, const int64_t& dim2);
TORCH_API std::vector<Shape> compute_shape_narrow_view_update(const Output& input, const Output& source, const std::vector<int64_t>& base_indices);
TORCH_API std::vector<Shape> compute_shape_narrow(const Output& input, const std::vector<int64_t>& base_indices, const std::vector<int64_t>& sizes);
TORCH_API std::vector<Shape> compute_shape_permute(const Output& input, const std::vector<int64_t>& dims);
TORCH_API std::vector<Shape> compute_shape_resize(const Output& input, const std::vector<int64_t>& size);
TORCH_API std::vector<Shape> compute_shape_select_view_update(const Output& target, const Output& source, const int64_t& dim, const int64_t& start, const int64_t& end, const int64_t& stride);
TORCH_API std::vector<Shape> compute_shape_select(const Output& input, const int64_t& dim, const int64_t& start, const int64_t& end, const int64_t& stride);
TORCH_API std::vector<Shape> compute_shape_squeeze(const Output& input, const int& dim);
TORCH_API std::vector<Shape> compute_shape_unsqueeze(const Output& input, const int& dim);
} // namespace lazy
} // namespace torch


@ -4,31 +4,11 @@
#include <torch/csrc/lazy/core/ir_builder.h>
#include <torch/csrc/lazy/core/internal_ops/ltc_ops.h>
#include <torch/csrc/lazy/core/shape_inference.h>
#include <torch/csrc/lazy/generated/LazyNonNativeIr.h>
#include <torch/csrc/lazy/ts_backend/ts_node.h>
#include <torch/csrc/lazy/ts_backend/dynamic_ir.h>
#include <torch/csrc/lazy/ts_backend/view_ops/narrow.h>
#include <torch/csrc/lazy/ts_backend/view_ops/select_view_update.h>
#include <torch/csrc/lazy/ts_backend/view_ops/as_strided_view_update.h>
#include <torch/csrc/lazy/ts_backend/view_ops/permute.h>
#include <torch/csrc/lazy/ts_backend/view_ops/diagonal_view_update.h>
#include <torch/csrc/lazy/ts_backend/view_ops/resize.h>
#include <torch/csrc/lazy/ts_backend/view_ops/squeeze.h>
#include <torch/csrc/lazy/ts_backend/view_ops/diagonal.h>
#include <torch/csrc/lazy/ts_backend/view_ops/narrow_view_update.h>
#include <torch/csrc/lazy/ts_backend/view_ops/as_strided.h>
#include <torch/csrc/lazy/ts_backend/view_ops/unsqueeze.h>
#include <torch/csrc/lazy/ts_backend/view_ops/select.h>
#include <torch/csrc/lazy/ts_backend/view_ops/view.h>
#include <torch/csrc/lazy/ts_backend/ops/cast.h>
#include <torch/csrc/lazy/ts_backend/ops/device_data.h>
#include <torch/csrc/lazy/ts_backend/ops/generic.h>
#include <torch/csrc/lazy/ts_backend/ops/batch_norm_ops.h>
#include <torch/csrc/lazy/ts_backend/ops/to_copy.h>
#include <torch/csrc/lazy/ts_backend/ops/scalar.h>
#include <torch/csrc/lazy/ts_backend/ops/random_ops.h>
#include <torch/csrc/lazy/ts_backend/ops/expand.h>
// This file contains the TorchScript IrBuilder
#include <torch/csrc/lazy/ts_backend/ops/device_data.h>
namespace torch {
namespace lazy {


@ -1,42 +0,0 @@
#include <torch/csrc/lazy/ts_backend/ops/cast.h>
#include <torch/csrc/lazy/core/tensor_util.h>
namespace torch {
namespace lazy {
namespace {
Shape NodeOutputShape(const Value& input, c10::ScalarType type) {
Shape shape = input.shape();
shape.set_scalar_type(type);
return shape;
}
} // namespace
Cast::Cast(
const Value& input,
at::ScalarType dtype,
c10::optional<at::ScalarType> stype)
: TsNode(
ClassOpKind(),
{input},
{NodeOutputShape(input, dtype)},
/*num_outputs=*/1,
MHash(101, static_cast<int>(dtype), OptionalOr<int>(stype, -1))),
dtype_(dtype),
stype_(stype) {}
std::string Cast::ToString() const {
std::stringstream ss;
ss << TsNode::ToString();
ss << ", dtype=" << dtype_;
if (stype_) {
ss << ", stype=" << *stype_;
}
return ss.str();
}
} // namespace lazy
} // namespace torch


@ -1,38 +0,0 @@
#pragma once
#include <c10/core/ScalarType.h>
#include <c10/util/Optional.h>
#include <torch/csrc/lazy/core/internal_ops/ltc_ops.h>
#include <torch/csrc/lazy/ts_backend/ts_node.h>
namespace torch {
namespace lazy {
class TORCH_API Cast : public TsNode {
public:
static OpKind ClassOpKind() {
return ltc_cast;
}
Cast(
const Value& input,
at::ScalarType dtype,
c10::optional<at::ScalarType> stype = c10::nullopt);
std::string ToString() const override;
at::ScalarType dtype() const {
return dtype_;
}
const c10::optional<at::ScalarType>& stype() const {
return stype_;
}
private:
at::ScalarType dtype_;
c10::optional<at::ScalarType> stype_;
};
} // namespace lazy
} // namespace torch


@ -1,29 +0,0 @@
#include <torch/csrc/lazy/ts_backend/ops/expand.h>
namespace torch {
namespace lazy {
Expand::Expand(
const Value& input,
std::vector<int64_t> size,
bool is_scalar_expand)
: TsNode(
ClassOpKind(),
{input},
/*num_outputs=*/1,
MHash(size, is_scalar_expand)),
size_(std::move(size)),
is_scalar_expand_(is_scalar_expand) {
addComputedShape(
[&]() { return Shape(input.shape().scalar_type(), size_); });
}
std::string Expand::ToString() const {
std::stringstream ss;
ss << TsNode::ToString() << ", size=(" << c10::Join(", ", size_)
<< "), is_scalar_expand=" << is_scalar_expand_;
return ss.str();
}
} // namespace lazy
} // namespace torch


@ -1,37 +0,0 @@
#pragma once
#include <torch/csrc/lazy/ts_backend/ts_node.h>
#include <vector>
namespace torch {
namespace lazy {
class TORCH_API Expand : public TsNode {
public:
static OpKind ClassOpKind() {
return OpKind(at::aten::expand);
}
Expand(const Value& input, std::vector<int64_t> size, bool is_scalar_expand);
std::string ToString() const override;
const std::vector<int64_t>& size() const {
return size_;
}
bool is_scalar_expand() const {
return is_scalar_expand_;
}
private:
std::vector<int64_t> size_;
// True iff the input was a scalar and this was generated internally by a
// lowering and not by user action. For some backends, this difference can be
// material (for example setting strides according to eager semantics).
bool is_scalar_expand_;
};
} // namespace lazy
} // namespace torch


@ -1,40 +0,0 @@
#include <torch/csrc/lazy/ts_backend/ops/scalar.h>
#include <functional>
#include <sstream>
#include <ATen/core/Formatting.h>
namespace torch {
namespace lazy {
using at::operator<<;
Scalar::Scalar(const at::Scalar& value, Shape shape)
: TsNode(
ClassOpKind(),
std::move(shape),
/*num_outputs=*/1,
ScalarHash(value)),
value_(value) {}
Scalar::Scalar(const at::Scalar& value, c10::ScalarType type)
: TsNode(
ClassOpKind(),
{Shape(type, {})},
/*num_outputs=*/1,
ScalarHash(value)),
value_(value) {}
std::string Scalar::ToString() const {
std::stringstream ss;
ss << TsNode::ToString() << ", value=" << value_;
return ss.str();
}
hash_t ScalarHash(const at::Scalar& s) {
return s.isFloatingPoint() ? Hash(s.toDouble()) : Hash(s.toLong());
}
} // namespace lazy
} // namespace torch


@ -1,35 +0,0 @@
#pragma once
#include <c10/core/Scalar.h>
#include <torch/csrc/lazy/ts_backend/ts_node.h>
namespace torch {
namespace lazy {
// Differently from Constant, this is a scalar value broadcasted to a shape.
// Even though a Constant could have been used, for simple scalars broadcasted
// to big shapes, the Constant leads to big literals expanded within the
// computation graph.
class TORCH_API Scalar : public TsNode {
public:
static OpKind ClassOpKind() {
return OpKind(at::prim::Constant);
}
Scalar(const at::Scalar& value, Shape shape);
Scalar(const at::Scalar& value, c10::ScalarType type);
std::string ToString() const override;
const at::Scalar& value() const {
return value_;
}
private:
at::Scalar value_;
};
TORCH_API hash_t ScalarHash(const at::Scalar& s);
} // namespace lazy
} // namespace torch


@ -17,29 +17,29 @@ namespace torch {
namespace lazy {
hash_t OperandHashes(const OpList& operands, const hash_t& seed, bool bakeInSizes) {
hash_t OperandHashes(const OpList& operands,
const c10::ArrayRef<Shape>& shapes,
const hash_t& seed, bool bakeInSizes) {
hash_t hash = seed;
for (auto& operand : operands) {
if (!operand) {
hash = HashCombine(hash, static_cast<uint64_t>(kNullOpt));
continue;
}
auto operand_hash = operand.hash();
auto operand_hash = bakeInSizes ? operand.shapeHash() : operand.hash();
hash = HashCombine(hash, operand_hash);
}
for (auto& shape : shapes) {
hash = HashCombine(hash, shape.hash(bakeInSizes));
}
return hash;
}
hash_t GetOpHash(OpKind op, const Shape& shape, hash_t hash_seed, bool bakeInSizes) {
hash_t h = HashCombine(op.hash(), shape.hash(bakeInSizes));
return HashCombine(h, hash_seed);
}
TsNode::TsNode(OpKind op, OpList operands, std::vector<Shape>&& shapes, size_t num_outputs, hash_t hash_seed)
: Node(op, operands, std::move(shapes), num_outputs) {
hash_seed = HashCombine(op.hash(), hash_seed);
shape_hash_ = OperandHashes(operands, hash_seed, true);
dag_hash_ = (enableDynamicShape() ? OperandHashes(operands, hash_seed, false) : shape_hash_);
shape_hash_ = OperandHashes(operands, this->shapes(), hash_seed, true);
dag_hash_ = (enableDynamicShape() ? OperandHashes(operands, this->shapes(), hash_seed, false) : shape_hash_);
}
@ -53,11 +53,7 @@ TsNode::TsNode(OpKind op, OpList operands, size_t num_outputs, hash_t hash_seed)
: TsNode(op, operands, std::vector<Shape>{}, num_outputs, hash_seed) {}
TsNode::TsNode(OpKind op, Shape shape, size_t num_outputs, hash_t hash_seed)
: Node(op, num_outputs),
shape_hash_(GetOpHash(op, shape, hash_seed, true)),
dag_hash_(enableDynamicShape() ? GetOpHash(op, shape, hash_seed, false) : shape_hash_) {
shapes_.push_back(std::move(shape));
}
: TsNode(op, {}, {std::move(shape)}, num_outputs, hash_seed) {}
hash_t TsNode::hash() const { return dag_hash_; }
@ -81,8 +77,8 @@ TensorList::TensorList(OpList values)
: TsNode(/*op=*/ClassOpKind(),
/*operands=*/values,
/*shapes=*/std::vector<Shape>(),
/*num_outputs=*/1,
/*hash_seed=*/OperandHashes(values, /*seed=*/kHashSeed, enableDynamicShape())) {}
/*num_outputs=*/1,
/*hash_seed=*/kHashSeed) {}
TSOpVector TensorList::Lower(std::shared_ptr<torch::jit::GraphFunction> function,
TSLoweringContext* loctx) const {
@ -97,5 +93,7 @@ TSOpVector TensorList::Lower(std::shared_ptr<torch::jit::GraphFunction> function
return {listnode->output()};
}
} // namespace lazy
} // namespace torch


@ -146,9 +146,9 @@ class TSNodeLowering : public TSNodeLoweringInterface {
TSOpVector LowerAsStrided(const torch::lazy::AsStrided* node) {
std::vector<torch::jit::NamedValue> arguments;
arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
arguments.emplace_back(node->size());
arguments.emplace_back(node->stride());
arguments.emplace_back(node->storage_offset());
arguments.emplace_back(node->size);
arguments.emplace_back(node->stride);
arguments.emplace_back(node->storage_offset);
TSOpVector as_strided_out = LowerBuiltin(node, arguments);
CHECK_EQ(as_strided_out.size(), 1);
return {GenerateClone(as_strided_out.front())};
@ -165,8 +165,9 @@ class TSNodeLowering : public TSNodeLoweringInterface {
dest_arguments.emplace_back(destination);
dest_arguments.emplace_back(
std::vector<int64_t>(input_dimensions.begin(), input_dimensions.end()));
dest_arguments.emplace_back(node->stride());
dest_arguments.emplace_back(node->storage_offset());
dest_arguments.emplace_back(node->stride);
dest_arguments.emplace_back(node->storage_offset);
TSOpVector as_strided_out =
LowerBuiltin(at::aten::as_strided, dest_arguments);
CHECK_EQ(as_strided_out.size(), 1);
@ -209,16 +210,16 @@ class TSNodeLowering : public TSNodeLoweringInterface {
TSOpVector LowerCast(const torch::lazy::Cast* node) {
std::vector<torch::jit::NamedValue> arguments;
arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
arguments.emplace_back(node->dtype());
arguments.emplace_back(node->dtype);
return LowerBuiltin(at::aten::to, arguments);
}
TSOpVector LowerExpand(const torch::lazy::Expand* node) {
std::vector<torch::jit::NamedValue> arguments;
arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
arguments.emplace_back(node->size());
arguments.emplace_back(node->size);
auto expand_out = LowerBuiltin(node, arguments);
if (node->is_scalar_expand()) {
if (node->is_scalar_expand) {
// The aten::expand operations sets all strides to 0 when the original is
// of rank 0. This leads to false positives when checking for internal
// memory overlap, because at::has_internal_overlap returns
@ -232,8 +233,8 @@ class TSNodeLowering : public TSNodeLoweringInterface {
TSOpVector LowerNarrow(const torch::lazy::Narrow* node) {
const torch::lazy::Output& input = node->operand(0);
torch::jit::Value* base = loctx()->GetOutputOp(input);
const auto& base_indices = node->base_indices();
const auto& sizes = node->sizes();
const auto& base_indices = node->base_indices;
const auto& sizes = node->sizes;
const torch::lazy::Shape& input_shape = input.shape();
CHECK_EQ(sizes.size(), base_indices.size());
CHECK_EQ(input_shape.dim(), base_indices.size());
@ -248,12 +249,12 @@ class TSNodeLowering : public TSNodeLoweringInterface {
TSOpVector LowerPermute(const torch::lazy::Permute* node) {
std::vector<torch::jit::NamedValue> arguments;
arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
arguments.emplace_back(node->dims());
arguments.emplace_back(node->dims);
return LowerBuiltin(node, arguments);
}
TSOpVector LowerScalar(const torch::lazy::Scalar* node) {
const at::Scalar& value = node->value();
const at::Scalar& value = node->value;
const torch::lazy::Shape& shape = node->shape();
auto options =
at::TensorOptions()
@ -264,19 +265,19 @@ class TSNodeLowering : public TSNodeLoweringInterface {
}
TSOpVector LowerSelect(const torch::lazy::Select* node) {
int64_t step = torch::lazy::Select::GetStride(node->start(), node->end(),
node->stride());
int64_t step = torch::lazy::GetStride(node->start, node->end,
node->stride);
torch::jit::Value* base = loctx()->GetOutputOp(node->operand(0));
return {GenerateSlice(/*base=*/base, /*dim=*/node->dim(),
/*start=*/node->start(), /*end=*/node->end(),
return {GenerateSlice(/*base=*/base, /*dim=*/node->dim,
/*start=*/node->start, /*end=*/node->end,
/*step=*/step)};
}
TSOpVector LowerSqueeze(const Squeeze* node) {
std::vector<torch::jit::NamedValue> arguments;
arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
if (node->dim() != -1) {
arguments.emplace_back(node->dim());
if (node->dim != -1) {
arguments.emplace_back(node->dim);
}
return LowerBuiltin(node, arguments);
}
@ -284,11 +285,11 @@ class TSNodeLowering : public TSNodeLoweringInterface {
TSOpVector LowerSelectViewUpdate(const torch::lazy::SelectViewUpdate* node) {
torch::jit::Value* dest =
GenerateClone(loctx()->GetOutputOp(node->operand(0)));
int64_t step = torch::lazy::Select::GetStride(node->start(), node->end(),
node->stride());
int64_t step = torch::lazy::GetStride(node->start, node->end,
node->stride);
torch::jit::Value* selected = GenerateSlice(
/*base=*/dest, /*dim=*/node->dim(), /*start=*/node->start(),
/*end=*/node->end(), /*step=*/step);
/*base=*/dest, /*dim=*/node->dim, /*start=*/node->start,
/*end=*/node->end, /*step=*/step);
GenerateCopy(selected, loctx()->GetOutputOp(node->operand(1)));
return {dest};
}
@ -296,7 +297,7 @@ class TSNodeLowering : public TSNodeLoweringInterface {
TSOpVector LowerNarrowViewUpdate(const torch::lazy::NarrowViewUpdate* node) {
torch::jit::Value* dest =
GenerateClone(loctx()->GetOutputOp(node->operand(0)));
const auto& base_indices = node->base_indices();
const auto& base_indices = node->base_indices;
const torch::lazy::Output& source_argument = node->operand(1);
const torch::lazy::Shape& source_shape = source_argument.shape();
CHECK_EQ(source_shape.dim(), base_indices.size());
@ -314,23 +315,23 @@ class TSNodeLowering : public TSNodeLoweringInterface {
TSOpVector LowerUnsqueeze(const Unsqueeze* node) {
std::vector<torch::jit::NamedValue> arguments;
arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
arguments.emplace_back(node->dim());
arguments.emplace_back(node->dim);
return LowerBuiltin(node, arguments);
}
TSOpVector LowerView(const torch::lazy::View* node) {
std::vector<torch::jit::NamedValue> arguments;
arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
arguments.emplace_back(node->output_size());
arguments.emplace_back(node->output_size);
return LowerBuiltin(at::aten::reshape, arguments);
}
TSOpVector LowerDiagonal(const Diagonal* node) {
std::vector<torch::jit::NamedValue> arguments;
arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
arguments.emplace_back(node->offset());
arguments.emplace_back(node->dim1());
arguments.emplace_back(node->dim2());
arguments.emplace_back(node->offset);
arguments.emplace_back(node->dim1);
arguments.emplace_back(node->dim2);
return LowerBuiltin(node, arguments);
}
@ -346,9 +347,9 @@ class TSNodeLowering : public TSNodeLoweringInterface {
// Replay the diagonal.
std::vector<torch::jit::NamedValue> arguments;
arguments.emplace_back(destination);
arguments.emplace_back(node->offset());
arguments.emplace_back(node->dim1());
arguments.emplace_back(node->dim2());
arguments.emplace_back(node->offset);
arguments.emplace_back(node->dim1);
arguments.emplace_back(node->dim2);
auto diag = LowerBuiltin(at::aten::diagonal, arguments);
// Update the replayed diagonal view with the input.


@ -1,52 +0,0 @@
#include <torch/csrc/lazy/ts_backend/view_ops/as_strided.h>
#include <algorithm>
#include <torch/csrc/lazy/core/tensor_util.h>
#include <torch/csrc/lazy/core/util.h>
namespace torch {
namespace lazy {
AsStrided::AsStrided(
const Value& input,
std::vector<int64_t> size,
std::vector<int64_t> stride,
int64_t storage_offset)
: TsNode(
ClassOpKind(),
{input},
[&]() {
return Shape(input.shape().scalar_type(), size);
},
/*num_outputs=*/1,
MHash(size, stride, storage_offset)),
size_(std::move(size)),
stride_(std::move(stride)),
storage_offset_(storage_offset) {}
std::string AsStrided::ToString() const {
std::stringstream ss;
ss << TsNode::ToString() << ", size=(" << c10::Join(", ", size_)
<< "), stride=(" << c10::Join(", ", stride_)
<< "), storage_offset=" << storage_offset_;
return ss.str();
}
bool AsStrided::StrideIsSupported(c10::ArrayRef<int64_t> stride) {
std::vector<int64_t> sorted_stride(stride.begin(), stride.end());
std::sort(sorted_stride.begin(), sorted_stride.end());
return stride.empty() || sorted_stride.front() == 1;
}
std::vector<int64_t> AsStrided::GetArrayStridePermutation(
c10::ArrayRef<int64_t> stride) {
std::vector<int64_t> permutation = Iota<int64_t>(stride.size());
std::sort(permutation.begin(), permutation.end(), [&](int64_t a, int64_t b) {
return stride[a] > stride[b];
});
return permutation;
}
} // namespace lazy
} // namespace torch


@ -1,48 +0,0 @@
#pragma once
#include <torch/csrc/lazy/ts_backend/ts_node.h>
#include <vector>
namespace torch {
namespace lazy {
class TORCH_API AsStrided : public TsNode {
public:
static OpKind ClassOpKind() {
return OpKind(at::aten::as_strided);
}
AsStrided(
const Value& input,
std::vector<int64_t> size,
std::vector<int64_t> stride,
int64_t storage_offset);
std::string ToString() const override;
const std::vector<int64_t>& size() const {
return size_;
}
const std::vector<int64_t>& stride() const {
return stride_;
}
int64_t storage_offset() const {
return storage_offset_;
}
static bool StrideIsSupported(c10::ArrayRef<int64_t> stride);
static std::vector<int64_t> GetArrayStridePermutation(
c10::ArrayRef<int64_t> stride);
private:
std::vector<int64_t> size_;
std::vector<int64_t> stride_;
int64_t storage_offset_;
};
} // namespace lazy
} // namespace torch


@ -1,37 +0,0 @@
#include <torch/csrc/lazy/ts_backend/view_ops/as_strided_view_update.h>
#include <torch/csrc/lazy/core/internal_ops/ltc_ops.h>
#include <torch/csrc/lazy/core/tensor_util.h>
#include <torch/csrc/lazy/ts_backend/view_ops/as_strided.h>
namespace torch {
namespace lazy {
AsStridedViewUpdate::AsStridedViewUpdate(
const Value& target,
const Value& input,
std::vector<int64_t> size,
std::vector<int64_t> stride,
int64_t storage_offset)
: TsNode(
ltc_as_strided_view_update,
{target, input},
[&]() {
return Shape(target.shape().scalar_type(), size);
},
/*num_outputs=*/1,
MHash(size, stride, storage_offset)),
size_(std::move(size)),
stride_(std::move(stride)),
storage_offset_(storage_offset) {}
std::string AsStridedViewUpdate::ToString() const {
std::stringstream ss;
ss << TsNode::ToString() << ", size=(" << c10::Join(", ", size_)
<< "), stride=(" << c10::Join(", ", stride_)
<< "), storage_offset=" << storage_offset_;
return ss.str();
}
} // namespace lazy
} // namespace torch


@ -1,45 +0,0 @@
#pragma once
#include <torch/csrc/lazy/ts_backend/ts_node.h>
#include <vector>
#include <lazy/core/internal_ops/ltc_ops.h>
namespace torch {
namespace lazy {
class TORCH_API AsStridedViewUpdate : public TsNode {
public:
static OpKind ClassOpKind() {
return ltc_as_strided_view_update;
}
AsStridedViewUpdate(
const Value& target,
const Value& input,
std::vector<int64_t> size,
std::vector<int64_t> stride,
int64_t storage_offset);
std::string ToString() const override;
const std::vector<int64_t>& size() const {
return size_;
}
const std::vector<int64_t>& stride() const {
return stride_;
}
int64_t storage_offset() const {
return storage_offset_;
}
private:
std::vector<int64_t> size_;
std::vector<int64_t> stride_;
int64_t storage_offset_;
};
} // namespace lazy
} // namespace torch


@ -1,35 +0,0 @@
#include <c10/util/irange.h>
#include <torch/csrc/lazy/core/ops/utils.h>
#include <torch/csrc/lazy/ts_backend/view_ops/diagonal.h>
#include <cmath>
namespace torch {
namespace lazy {
Diagonal::Diagonal(
const Value& input,
int64_t offset,
int64_t dim1,
int64_t dim2)
: TsNode(
ClassOpKind(),
{input},
[&]() {
return MakeDiagonalShape(input.shape(), offset, dim1, dim2);
},
/*num_outputs=*/1,
MHash(offset, dim1, dim2)),
offset_(offset),
dim1_(dim1),
dim2_(dim2) {}
std::string Diagonal::ToString() const {
std::stringstream ss;
ss << TsNode::ToString() << ", offset=" << offset_ << ", dim1=" << dim1_
<< ", dim2=" << dim2_;
return ss.str();
}
} // namespace lazy
} // namespace torch


@ -1,37 +0,0 @@
#pragma once
#include <torch/csrc/lazy/ts_backend/ts_node.h>
namespace torch {
namespace lazy {
class TORCH_API Diagonal : public TsNode {
public:
static OpKind ClassOpKind() {
return OpKind(at::aten::diagonal);
}
Diagonal(const Value& input, int64_t offset, int64_t dim1, int64_t dim2);
std::string ToString() const override;
int64_t offset() const {
return offset_;
}
int64_t dim1() const {
return dim1_;
}
int64_t dim2() const {
return dim2_;
}
private:
int64_t offset_;
int64_t dim1_;
int64_t dim2_;
};
} // namespace lazy
} // namespace torch


@ -1,32 +0,0 @@
#include <torch/csrc/lazy/ts_backend/view_ops/diagonal_view_update.h>
#include <torch/csrc/lazy/core/internal_ops/ltc_ops.h>
namespace torch {
namespace lazy {
DiagonalViewUpdate::DiagonalViewUpdate(
const Value& target,
const Value& input,
int64_t offset,
int64_t dim1,
int64_t dim2)
: TsNode(
ltc_diagonal_view_update,
{target, input},
{target.shape()},
/*num_outputs=*/1,
MHash(offset, dim1, dim2)),
offset_(offset),
dim1_(dim1),
dim2_(dim2) {}
std::string DiagonalViewUpdate::ToString() const {
std::stringstream ss;
ss << TsNode::ToString() << ", offset=" << offset_ << ", dim1=" << dim1_
<< ", dim2=" << dim2_;
return ss.str();
}
} // namespace lazy
} // namespace torch


@ -1,43 +0,0 @@
#pragma once
#include <torch/csrc/lazy/ts_backend/ts_node.h>
#include <lazy/core/internal_ops/ltc_ops.h>
namespace torch {
namespace lazy {
class TORCH_API DiagonalViewUpdate : public TsNode {
public:
static OpKind ClassOpKind() {
return ltc_diagonal_view_update;
}
DiagonalViewUpdate(
const Value& target,
const Value& input,
int64_t offset,
int64_t dim1,
int64_t dim2);
std::string ToString() const override;
int64_t offset() const {
return offset_;
}
int64_t dim1() const {
return dim1_;
}
int64_t dim2() const {
return dim2_;
}
private:
int64_t offset_;
int64_t dim1_;
int64_t dim2_;
};
} // namespace lazy
} // namespace torch


@ -1,33 +0,0 @@
#include <torch/csrc/lazy/ts_backend/view_ops/narrow.h>
#include <torch/csrc/lazy/core/internal_ops/ltc_ops.h>
namespace torch {
namespace lazy {
Narrow::Narrow(
const Value& input,
c10::ArrayRef<int64_t> base_indices,
c10::ArrayRef<int64_t> sizes)
: TsNode(
ClassOpKind(),
{input},
/*num_outputs=*/1,
MHash(base_indices, sizes)),
base_indices_(base_indices.begin(), base_indices.end()),
sizes_(sizes.begin(), sizes.end()) {
addComputedShape([&]() {
return Shape(operand(0).shape().scalar_type(), sizes);
});
}
std::string Narrow::ToString() const {
std::stringstream ss;
ss << TsNode::ToString() << ", base_indices=("
<< c10::Join(", ", base_indices_) << "), sizes=("
<< c10::Join(", ", sizes_) << ")";
return ss.str();
}
} // namespace lazy
} // namespace torch


@ -1,35 +0,0 @@
#pragma once
#include <torch/csrc/lazy/ts_backend/ts_node.h>
namespace torch {
namespace lazy {
class TORCH_API Narrow : public TsNode {
public:
static OpKind ClassOpKind() {
return OpKind(at::aten::narrow);
}
Narrow(
const Value& input,
c10::ArrayRef<int64_t> base_indices,
c10::ArrayRef<int64_t> sizes);
std::string ToString() const override;
const std::vector<int64_t>& base_indices() const {
return base_indices_;
}
const std::vector<int64_t>& sizes() const {
return sizes_;
}
private:
std::vector<int64_t> base_indices_;
std::vector<int64_t> sizes_;
};
} // namespace lazy
} // namespace torch


@ -1,29 +0,0 @@
#include <torch/csrc/lazy/ts_backend/view_ops/narrow_view_update.h>
#include <torch/csrc/lazy/core/internal_ops/ltc_ops.h>
namespace torch {
namespace lazy {
NarrowViewUpdate::NarrowViewUpdate(
const Value& input,
const Value& source,
c10::ArrayRef<int64_t> base_indices)
: TsNode(
ltc_narrow_view_update,
{input, source},
/*num_outputs=*/1,
MHash(base_indices)),
base_indices_(base_indices.begin(), base_indices.end()) {
addComputedShape([&]() { return operand(0).shape(); });
}
std::string NarrowViewUpdate::ToString() const {
std::stringstream ss;
ss << TsNode::ToString() << ", base_indices=("
<< c10::Join(", ", base_indices_) << ")";
return ss.str();
}
} // namespace lazy
} // namespace torch


@ -1,31 +0,0 @@
#pragma once
#include <torch/csrc/lazy/ts_backend/ts_node.h>
#include <lazy/core/internal_ops/ltc_ops.h>
namespace torch {
namespace lazy {
class TORCH_API NarrowViewUpdate : public TsNode {
public:
static OpKind ClassOpKind() {
return ltc_narrow_view_update;
}
NarrowViewUpdate(
const Value& input,
const Value& source,
c10::ArrayRef<int64_t> base_indices);
std::string ToString() const override;
const std::vector<int64_t>& base_indices() const {
return base_indices_;
}
private:
std::vector<int64_t> base_indices_;
};
} // namespace lazy
} // namespace torch


@ -1,28 +0,0 @@
#include <torch/csrc/lazy/core/ops/utils.h>
#include <torch/csrc/lazy/ts_backend/view_ops/permute.h>
#include <torch/csrc/lazy/core/helpers.h>
namespace torch {
namespace lazy {
Permute::Permute(const Value& input, std::vector<int64_t> dims)
: TsNode(
ClassOpKind(),
{input},
/*num_outputs=*/1,
MHash(dims)),
dims_(std::move(dims)) {
addComputedShape([&]() {
return MakePermuteShape(operand(0).shape(), dims_);
});
}
std::string Permute::ToString() const {
std::stringstream ss;
ss << TsNode::ToString() << ", dims=(" << c10::Join(", ", dims_) << ")";
return ss.str();
}
} // namespace lazy
} // namespace torch


@ -1,28 +0,0 @@
#pragma once
#include <torch/csrc/lazy/ts_backend/ts_node.h>
namespace torch {
namespace lazy {
class TORCH_API Permute : public TsNode {
public:
static OpKind ClassOpKind() {
return OpKind(at::aten::permute);
}
Permute(const Value& input, std::vector<int64_t> dims);
std::string ToString() const override;
const std::vector<int64_t>& dims() const {
return dims_;
}
private:
// The permutation of dimensions.
std::vector<int64_t> dims_;
};
} // namespace lazy
} // namespace torch


@ -1,29 +0,0 @@
#include <torch/csrc/lazy/ts_backend/view_ops/resize.h>
namespace torch {
namespace lazy {
namespace {
Shape NodeOutputShape(const Value& input, c10::ArrayRef<int64_t> size) {
return Shape(input.shape().scalar_type(), size);
}
} // namespace
Resize::Resize(const Value& input, std::vector<int64_t> size)
: TsNode(
ClassOpKind(),
{input},
[&]() { return NodeOutputShape(input, size); },
/*num_outputs=*/1,
MHash(size)),
size_(std::move(size)) {}
std::string Resize::ToString() const {
std::stringstream ss;
ss << TsNode::ToString() << ", size=(" << c10::Join(", ", size_) << ")";
return ss.str();
}
} // namespace lazy
} // namespace torch


@ -1,27 +0,0 @@
#pragma once
#include <torch/csrc/lazy/ts_backend/ts_node.h>
namespace torch {
namespace lazy {
class TORCH_API Resize : public TsNode {
public:
static OpKind ClassOpKind() {
return OpKind(at::aten::resize);
}
Resize(const Value& input, std::vector<int64_t> size);
std::string ToString() const override;
const std::vector<int64_t>& size() const {
return size_;
}
private:
std::vector<int64_t> size_;
};
} // namespace lazy
} // namespace torch


@ -1,44 +0,0 @@
#include <torch/csrc/lazy/core/ops/utils.h>
#include <torch/csrc/lazy/ts_backend/view_ops/select.h>
#include <torch/csrc/lazy/core/internal_ops/ltc_ops.h>
namespace torch {
namespace lazy {
Select::Select(
const Value& input,
int64_t dim,
int64_t start,
int64_t end,
int64_t stride)
: TsNode(
ClassOpKind(),
{input},
[&]() {
return MakeSelectShape(input.shape(), dim, start, end, stride);
},
/*num_outputs=*/1,
MHash(dim, start, end, stride)),
dim_(dim),
start_(start),
end_(end),
stride_(stride) {}
std::string Select::ToString() const {
std::stringstream ss;
ss << TsNode::ToString() << ", dim=" << dim_ << ", start=" << start_
<< ", end=" << end_ << ", stride=" << stride_;
return ss.str();
}
int64_t Select::GetStride(int64_t start, int64_t end, int64_t stride) {
if (stride == 0) {
CHECK_EQ(start, end);
stride = 1;
}
return stride;
}
} // namespace lazy
} // namespace torch


@ -1,49 +0,0 @@
#pragma once
#include <torch/csrc/lazy/ts_backend/ts_node.h>
namespace torch {
namespace lazy {
class TORCH_API Select : public TsNode {
public:
static OpKind ClassOpKind() {
return OpKind(at::aten::select);
}
Select(
const Value& input,
int64_t dim,
int64_t start,
int64_t end,
int64_t stride);
std::string ToString() const override;
int64_t dim() const {
return dim_;
}
int64_t start() const {
return start_;
}
int64_t end() const {
return end_;
}
int64_t stride() const {
return stride_;
}
static int64_t GetStride(int64_t start, int64_t end, int64_t stride);
private:
int64_t dim_;
int64_t start_;
int64_t end_;
int64_t stride_;
};
} // namespace lazy
} // namespace torch


@ -1,36 +0,0 @@
#include <torch/csrc/lazy/ts_backend/view_ops/select_view_update.h>
#include <torch/csrc/lazy/core/internal_ops/ltc_ops.h>
#include <torch/csrc/lazy/core/tensor_util.h>
#include <torch/csrc/lazy/ts_backend/view_ops/select.h>
namespace torch {
namespace lazy {
SelectViewUpdate::SelectViewUpdate(
const Value& target,
const Value& source,
int64_t dim,
int64_t start,
int64_t end,
int64_t stride)
: TsNode(
ltc_select_view_update,
{target, source},
{target.shape()},
/*num_outputs=*/1,
MHash(dim, start, end, stride)),
dim_(dim),
start_(start),
end_(end),
stride_(stride) {}
std::string SelectViewUpdate::ToString() const {
std::stringstream ss;
ss << TsNode::ToString() << ", dim=" << dim_ << ", start=" << start_
<< ", end=" << end_ << ", stride=" << stride_;
return ss.str();
}
} // namespace lazy
} // namespace torch


@ -1,49 +0,0 @@
#pragma once
#include <torch/csrc/lazy/ts_backend/ts_node.h>
#include <lazy/core/internal_ops/ltc_ops.h>
namespace torch {
namespace lazy {
class TORCH_API SelectViewUpdate : public TsNode {
public:
static OpKind ClassOpKind() {
return ltc_select_view_update;
}
SelectViewUpdate(
const Value& target,
const Value& source,
int64_t dim,
int64_t start,
int64_t end,
int64_t stride);
std::string ToString() const override;
int64_t dim() const {
return dim_;
}
int64_t start() const {
return start_;
}
int64_t end() const {
return end_;
}
int64_t stride() const {
return stride_;
}
private:
int64_t dim_;
int64_t start_;
int64_t end_;
int64_t stride_;
};
} // namespace lazy
} // namespace torch


@ -1,28 +0,0 @@
#include <c10/util/irange.h>
#include <torch/csrc/lazy/core/ops/utils.h>
#include <torch/csrc/lazy/ts_backend/view_ops/squeeze.h>
#include <torch/csrc/lazy/ts_backend/ts_lowering_context.h>
namespace torch {
namespace lazy {
Squeeze::Squeeze(const torch::lazy::Value& input, int dim)
: torch::lazy::TsNode(ClassOpKind(), {input},
/*num_outputs=*/1, torch::lazy::MHash(dim)),
dim_(dim) {
addComputedShape(
[&]() {
const auto& input_shape = input.shape();
return torch::lazy::Shape(input_shape.scalar_type(),
BuildSqueezedDimensions(input_shape.sizes(), dim));
});
}
std::string Squeeze::ToString() const {
std::stringstream ss;
ss << torch::lazy::TsNode::ToString() << ", dim=" << dim_;
return ss.str();
}
} // namespace lazy
} // namespace torch


@ -1,26 +0,0 @@
#pragma once
#include <torch/csrc/lazy/ts_backend/ts_node.h>
namespace torch {
namespace lazy {
class TORCH_API Squeeze : public TsNode {
public:
static OpKind ClassOpKind() {
return OpKind(at::aten::squeeze);
}
// Squeeze out the specified dimension index, -1 for all trivial dimensions.
Squeeze(const torch::lazy::Value& input, int dim);
std::string ToString() const override;
int dim() const { return dim_; }
private:
int dim_;
};
} // namespace lazy
} // namespace torch


@ -1,30 +0,0 @@
#include <torch/csrc/lazy/core/ops/utils.h>
#include <torch/csrc/lazy/ts_backend/view_ops/unsqueeze.h>
#include <torch/csrc/lazy/ts_backend/ts_lowering_context.h>
namespace torch {
namespace lazy {
Unsqueeze::Unsqueeze(const torch::lazy::Value& input, int dim)
: torch::lazy::TsNode(
ClassOpKind(),
{input},
/*num_outputs=*/1,
torch::lazy::MHash(dim)),
dim_(dim) {
addComputedShape([&]() {
const auto& input_shape = input.shape();
return torch::lazy::Shape(
input_shape.scalar_type(),
BuildUnsqueezedDimensions(input_shape.sizes(), dim));
});
}
std::string Unsqueeze::ToString() const {
std::stringstream ss;
ss << torch::lazy::TsNode::ToString() << ", dim=" << dim_;
return ss.str();
}
} // namespace lazy
} // namespace torch


@ -1,27 +0,0 @@
#pragma once
#include <torch/csrc/lazy/ts_backend/ts_node.h>
namespace torch {
namespace lazy {
class TORCH_API Unsqueeze : public TsNode {
public:
static OpKind ClassOpKind() {
return OpKind(at::aten::unsqueeze);
}
Unsqueeze(const torch::lazy::Value& input, int dim);
std::string ToString() const override;
int dim() const {
return dim_;
}
private:
int dim_;
};
} // namespace lazy
} // namespace torch


@ -1,35 +0,0 @@
#include <torch/csrc/lazy/ts_backend/view_ops/view.h>
#include <ATen/InferSize.h>
namespace torch {
namespace lazy {
namespace {
Shape NodeOutputShape(const Value& input, c10::ArrayRef<int64_t> output_sizes) {
const Shape& input_shape = input.shape();
const auto complete_output_sizes =
at::infer_size(output_sizes, input_shape.numel());
return Shape(input_shape.scalar_type(), complete_output_sizes);
}
} // namespace
View::View(const Value& input, std::vector<int64_t> output_size)
: TsNode(
ClassOpKind(),
{input},
{NodeOutputShape(input, output_size)},
/*num_outputs=*/1,
MHash(output_size)),
output_size_(std::move(output_size)) {}
std::string View::ToString() const {
std::stringstream ss;
ss << TsNode::ToString() << ", output_size=(" << c10::Join(", ", output_size_)
<< ")";
return ss.str();
}
} // namespace lazy
} // namespace torch


@ -1,29 +0,0 @@
#pragma once
#include <torch/csrc/lazy/ts_backend/ts_node.h>
#include <vector>
namespace torch {
namespace lazy {
class TORCH_API View : public TsNode {
public:
static OpKind ClassOpKind() {
return OpKind(at::aten::view);
}
View(const Value& input, std::vector<int64_t> output_size);
std::string ToString() const override;
const std::vector<int64_t>& output_size() const {
return output_size_;
}
private:
std::vector<int64_t> output_size_;
};
} // namespace lazy
} // namespace torch


@ -1,4 +1,5 @@
from typing import List, Union, Tuple, Optional
from typing import Any, Dict, List, Union, Tuple, Optional
from torchgen.model import (
Type,
BaseTy,
@ -56,7 +57,7 @@ tensorListValueT = BaseCppType("torch::lazy", "Value")
def process_ir_type(
typ: Type,
typ: Type, properties: "LazyIrProperties"
) -> Union[BaseCType, VectorCType, OptionalCType, ListCType]:
"""
This function takes a type from NativeFunctions and converts it for use with
@ -77,6 +78,8 @@ def process_ir_type(
if typ.name == BaseTy.Tensor:
return BaseCType(getValueT())
elif typ.name == BaseTy.Scalar:
if properties.TreatScalarsAsConstants:
return BaseCType(scalarT)
# at::scalar has special handling,
# and is wrapped in an lazy::Value just like at::tensor
return BaseCType(getValueT())
@ -101,7 +104,7 @@ def process_ir_type(
else:
raise AssertionError(f"TODO add support for type {repr(typ)}")
elif isinstance(typ, OptionalType):
return OptionalCType(process_ir_type(typ.elem))
return OptionalCType(process_ir_type(typ.elem, properties))
elif isinstance(typ, ListType):
if str(typ.elem) == "Tensor?":
# TODO(whc) is this actually correct? or should it use a Vector like above
@ -110,12 +113,12 @@ def process_ir_type(
# this is a TensorList which comes in from GetTensorList as a Value
return BaseCType(tensorListValueT)
else:
return VectorCType(process_ir_type(typ.elem))
return VectorCType(process_ir_type(typ.elem, properties))
else:
raise AssertionError(f"unrecognized type {repr(typ)}")
def isValueType(typ: CType) -> bool:
def isValueType(typ: CType, properties: "Optional[LazyIrProperties]" = None) -> bool:
"""
Given a type, determine if it is a Value-like type. This is equivalent to
being Tensor-like, but assumes the type has already been transformed.
@ -123,9 +126,14 @@ def isValueType(typ: CType) -> bool:
if isinstance(typ, BaseCType):
# I am regretting my naming conventions, but now we are wrapping at::scalar in
# lazy value, while preserving other 'scalar' types as scalars in the IR
return typ.type == getValueT() or typ.type == scalarT or typ.type == SymIntT
treat_scalars_as_constants = properties and properties.TreatScalarsAsConstants
return (
typ.type == getValueT()
or (typ.type == scalarT and not treat_scalars_as_constants)
or typ.type == SymIntT
)
elif isinstance(typ, (OptionalCType, ListCType, VectorCType)):
return isValueType(typ.elem)
return isValueType(typ.elem, properties)
return False
@ -167,7 +175,7 @@ class LazyArgument:
# true if this argument is or contains a lazy IR value
is_lazy_value: bool
def __init__(self, arg: Argument):
def __init__(self, arg: Argument, properties: "LazyIrProperties"):
self.name = arg.name
self.orig_type = arg.type
self.is_optional = isinstance(arg.type, OptionalType)
@ -181,11 +189,13 @@ class LazyArgument:
# its null and safe to exclude from lazy IR
self.lazy_type_ = None
else:
self.lazy_type_ = process_ir_type(arg.type)
self.lazy_type_ = process_ir_type(arg.type, properties)
self.is_wrapped_scalar = isWrappedScalarType(arg.type)
self.is_symint_or_list = isSymIntType(arg.type)
self.is_lazy_value = not self.is_generator and isValueType(self.lazy_type)
self.is_lazy_value = not self.is_generator and isValueType(
self.lazy_type, properties
)
@property
def lazy_type(self) -> CType:
@ -195,6 +205,64 @@ class LazyArgument:
return self.lazy_type_
class LazyIrProperties:
"""Collection of properties for an IR node
The property groups are listed below. Each group is mutually
exclusive, meaning that only one property from each group can be True
at any one time. The properties can be accessed as if they were normal
attributes. The mutual exclusivity is automatically handled.
"""
Properties: Tuple[Tuple[str, ...], ...] = (
(
"ShapePrecompute", # Assume shape has been precomputed
"ShapeCompute", # Need to compute the shape on construction
"ShapeCache", # Utilize the shape cache to defer computation
),
(
"Lower", # Codegen full lower function
"LowerDeclOnly", # Codegen only lower function declaration
),
(
"CanBeReused", # Codegen full reuse function
"CanBeReusedDeclOnly", # Codegen only reuse function declaration
),
(
"CreateFn", # Codegen full create function
"CreateFnDeclOnly", # Codegen only create function declaration
),
(
"TreatScalarsAsConstants", # Treat Scalars as constants instead of handling like values
),
)
def __init__(self, *default_properties: str):
properties: Dict[Tuple[str, ...], Optional[str]] = {
p: None for p in LazyIrProperties.Properties
}
self.__dict__["properties"] = properties
for p in default_properties:
setattr(self, p, True)
def __getattr__(self, key: str) -> Any:
properties = self.__dict__["properties"]
for values in LazyIrProperties.Properties:
if key in values:
return properties[values] == key
return self.__getattribute__(key)
def __setattr__(self, key: str, value: Any) -> Any:
properties = self.__dict__["properties"]
for values in LazyIrProperties.Properties:
if key in values:
properties[values] = key if value else None
return value
raise KeyError(f"Invalid property: {key}")
# Inspired by a FunctionSchema object, a LazyIrSchema holds the schema of a Lazy IR node.
# Unlike a FunctionSchema, it has no round-trippable string form (relating to the YAML),
# but carries type information from a native FunctionSchema modified for use with IR nodes,
@ -213,20 +281,33 @@ class LazyIrSchema:
# build a LazyArgument since lazy IR doesn't support it
generator_arg: Optional[NamedCType] = None
def __init__(self, func: FunctionSchema):
properties: LazyIrProperties = LazyIrProperties(
# default properties
"ShapePrecompute",
"Lower",
"CanBeReused",
)
opkind: Optional[str] = None
positional_args = []
def __init__(
self, func: FunctionSchema, properties: Optional[LazyIrProperties] = None
):
if properties:
self.properties = properties
positional_args: List[LazyArgument] = []
for arg_field in ["pre_self_positional", "self_arg", "post_self_positional"]:
if arg_field == "self_arg" and func.arguments.self_arg is not None:
arg = getattr(func.arguments, "self_arg").argument
positional_args.append(LazyArgument(arg))
positional_args.append(LazyArgument(arg, self.properties))
elif getattr(func.arguments, arg_field) is not None:
positional_args.extend(
[LazyArgument(arg) for arg in getattr(func.arguments, arg_field)]
LazyArgument(arg, self.properties)
for arg in getattr(func.arguments, arg_field)
)
self.positional_args = tuple(positional_args)
keyword_args = []
keyword_args: List[LazyArgument] = []
for arg_field in [
"pre_tensor_options_kwarg_only",
"tensor_options",
@ -243,7 +324,9 @@ class LazyIrSchema:
self.generator_arg is None
), "We expect there is only one generator arg"
self.generator_arg = NamedCType(arg.name, arg.type)
keyword_args.extend([LazyArgument(arg) for arg in curr_args])
keyword_args.extend(
LazyArgument(arg, self.properties) for arg in curr_args
)
self.keyword_args = tuple(keyword_args)
self.name = func.name
self.returns = func.returns
@ -262,7 +345,7 @@ class LazyIrSchema:
@property
def aten_name(self) -> str:
return f"{self.name.name}"
return str(self.name.name)
@property
def base_name(self) -> str:

View File

@ -1,6 +1,9 @@
from .lazy_ir import GenLazyIR as GenLazyIR
from .lazy_ir import GenLazyShapeInferenceDefinition as GenLazyShapeInferenceDefinition
from .lazy_ir import GenLazyNativeFuncDefinition as GenLazyNativeFuncDefinition
from .lazy_ir import (
generate_non_native_lazy_ir_nodes as generate_non_native_lazy_ir_nodes,
)
from .register_dispatch_key import (
RegisterDispatchKey as RegisterDispatchKey,
gen_registration_helpers as gen_registration_helpers,

View File

@ -1,8 +1,13 @@
from abc import ABC
from typing import List, Optional, Union
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union
from torchgen.context import method_with_native_function
from torchgen.model import BackendIndex, NativeFunction, NativeFunctionsGroup
from torchgen.model import (
BackendIndex,
NativeFunction,
NativeFunctionsGroup,
FunctionSchema,
)
from torchgen.api.types import (
BaseCType,
OptionalCType,
@ -12,6 +17,7 @@ from torchgen.api.types import (
)
import torchgen.api.dispatcher as dispatcher
from torchgen.api.lazy import (
LazyIrProperties,
LazyIrSchema,
LazyArgument,
getValueT,
@ -108,36 +114,44 @@ def aten_symbol(schema: LazyIrSchema) -> str:
}
if schema.aten_name in missing_interned_strings:
return f'c10::Symbol::fromQualString("aten::{schema.aten_name}")'
return f"at::aten::{schema.aten_name}"
if not schema.aten_name.startswith("at::"):
return f"at::aten::{schema.aten_name}"
else:
return schema.aten_name
@dataclass(frozen=True)
class GenLazyIR(ABC):
backend_index: BackendIndex
backend_name: str
node_base: str
@method_with_native_function
def __call__(self, f: Union[NativeFunctionsGroup, NativeFunction]) -> List[str]:
func = f.functional.func if isinstance(f, NativeFunctionsGroup) else f.func
return self.gen(f)
schema = LazyIrSchema(func)
return self.gen(schema)
# there is no lowering functionality generated unless this IR base class is subclassed and
# implemented as a backend-specific node
def lowering_function(self, f: Union[NativeFunctionsGroup, NativeFunction]) -> str:
def lowering_function(self, schema: LazyIrSchema) -> str:
return ""
def can_be_reused_function(
self, f: Union[NativeFunctionsGroup, NativeFunction], node_ctor_args: str
) -> str:
def create_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str:
return ""
def can_be_reused_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str:
return f"""bool CanBeReused({node_ctor_args}) const {{
return false;
}}"""
def node_base_ctor_call(self, schema: LazyIrSchema) -> str:
value_args = schema.filtered_args(values=True, scalars=False)
# backends can customize the way the node base class constructor is called,
# as long as all of its arguments can be generated from information available from the schema
base_ctor_value_args_list = []
for arg in schema.filtered_args(values=True, scalars=False):
for arg in value_args:
if isinstance(arg.lazy_type, BaseCType) or isinstance(
arg.lazy_type, VectorCType
):
@ -151,29 +165,51 @@ class GenLazyIR(ABC):
base_ctor_value_args = ", ".join(base_ctor_value_args_list)
scalar_args = schema.filtered_args(values=False, scalars=True)
scalar_hashes = ", ".join([f"{a.name}" for a in scalar_args])
return f"""{self.node_base}(torch::lazy::OpKind({aten_symbol(schema)}),
{{{base_ctor_value_args}}}, std::move(shapes),
# Shape construction.
# Conditionally build the shape depending on the specified shape property.
if schema.properties.ShapePrecompute:
shape_ctor_arg = "std::move(shapes),"
elif schema.properties.ShapeCompute:
shape_args = [a.name for a in value_args]
shape_args.extend(a.name for a in scalar_args)
shape_ctor_arg = f"compute_shape_{schema.name}({', '.join(shape_args)}),"
elif schema.properties.ShapeCache:
shape_args = [f"operand({i})" for i in range(len(value_args))]
shape_args.extend(a.name for a in scalar_args)
shape_ctor_arg = f"[&](){{ return compute_shape_{schema.name}({', '.join(shape_args)})[0]; }},"
else:
shape_ctor_arg = ""
scalar_hashes = ", ".join(f"{a.name}" for a in scalar_args)
return f"""{self.node_base}(
{schema.node_name}::ClassOpKind(),
OpList{{{base_ctor_value_args}}},
{shape_ctor_arg}
/* num_outputs */ {len(schema.returns)},
torch::lazy::MHash({scalar_hashes}))"""
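A hedged sketch of what the three shape modes expand to, assuming a hypothetical op `my_op(Tensor input, int dim) -> Tensor` (name and arguments invented for illustration):
```
# Approximate shape_ctor_arg emitted for each (mutually exclusive) property:
shape_ctor_examples = {
    "ShapePrecompute": "std::move(shapes),",
    "ShapeCompute": "compute_shape_my_op(input, dim),",
    "ShapeCache": "[&](){ return compute_shape_my_op(operand(0), dim)[0]; },",
}
```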
def gen(self, f: Union[NativeFunctionsGroup, NativeFunction]) -> List[str]:
def gen(self, schema: LazyIrSchema) -> List[str]:
opkind = schema.opkind or aten_symbol(schema)
# for now, we just want one IR class decl and soon after also the method defs
# and we use the functional version not out/inplace.
func = f.functional.func if isinstance(f, NativeFunctionsGroup) else f.func
schema = LazyIrSchema(func)
all_args = schema.filtered_args()
value_args = schema.filtered_args(values=True, scalars=False)
scalar_args = schema.filtered_args(values=False, scalars=True)
node_ctor_args = ", ".join(
[f"const {i.lazy_type.cpp_type()}& {i.name}" for i in all_args]
)
ctor_args = [f"const {i.lazy_type.cpp_type()}& {i.name}" for i in all_args]
reuse_ctor_args = ", ".join(ctor_args)
if schema.properties.ShapePrecompute:
ctor_args.append("std::vector<torch::lazy::Shape>&& shapes")
node_ctor_args = ", ".join(ctor_args)
scalar_initializers = ",\n ".join(
[f"{a.name}({a.name})" for a in scalar_args]
f"{a.name}({a.name})" for a in scalar_args
)
comma_if_scalar_initializers = ",\n" if len(scalar_initializers) else ""
if len(scalar_initializers):
scalar_initializers = f",\n {scalar_initializers}"
scalar_decls = "\n ".join(
[
f"std::string {a.name};"
@ -212,14 +248,11 @@ class GenLazyIR(ABC):
class {schema.node_name} : public {self.node_base} {{
public:
static torch::lazy::OpKind ClassOpKind() {{
return torch::lazy::OpKind({aten_symbol(schema)});
return torch::lazy::OpKind({opkind});
}}
{schema.node_name}({node_ctor_args}, std::vector<torch::lazy::Shape>&& shapes)
: {self.node_base_ctor_call(schema)}{comma_if_scalar_initializers}
{scalar_initializers}
{schema.node_name}({node_ctor_args})
: {self.node_base_ctor_call(schema)}{scalar_initializers}
{{
{has_optional_defs}
}}
@ -231,9 +264,11 @@ class {schema.node_name} : public {self.node_base} {{
return ss.str();
}}
{self.can_be_reused_function(f, node_ctor_args)}
{self.create_function(schema, reuse_ctor_args)}
{self.lowering_function(f)}
{self.can_be_reused_function(schema, reuse_ctor_args)}
{self.lowering_function(schema)}
{scalar_decls}
{has_optional_decls}
@ -246,37 +281,57 @@ class {schema.node_name} : public {self.node_base} {{
@dataclass(frozen=True)
class GenTSLazyIR(GenLazyIR):
def lowering_function(self, f: Union[NativeFunctionsGroup, NativeFunction]) -> str:
return f"""torch::lazy::TSOpVector Lower(std::shared_ptr<torch::jit::GraphFunction> function,
torch::lazy::TSLoweringContext* loctx) const override {{
{ts_lowering_body(f)}
def lowering_function(self, schema: LazyIrSchema) -> str:
signature = """
torch::lazy::TSOpVector Lower(
std::shared_ptr<torch::jit::GraphFunction> function,
torch::lazy::TSLoweringContext* loctx) const override"""
if schema.properties.LowerDeclOnly:
return f"{signature};"
elif schema.properties.Lower:
return f"""{signature} {{
{ts_lowering_body(schema)}
}}
"""
else:
return ""
def create_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str:
signature = f"static NodePtr Create({node_ctor_args})"
if schema.properties.CreateFnDeclOnly:
return f"{signature};"
elif not schema.properties.CreateFn:
return ""
return f"""{signature} {{
return ReuseOrMakeNode<{schema.node_name}>(data);
}}"""
def can_be_reused_function(
self, f: Union[NativeFunctionsGroup, NativeFunction], node_ctor_args: str
) -> str:
func = f.functional.func if isinstance(f, NativeFunctionsGroup) else f.func
schema = LazyIrSchema(func)
value_comparsion = []
def can_be_reused_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str:
signature = f"bool CanBeReused({node_ctor_args}) const"
if schema.properties.CanBeReusedDeclOnly:
return f"{signature};"
elif not schema.properties.CanBeReused:
return ""
value_comparison = []
for arg in schema.positional_values:
if isinstance(arg.lazy_type, OptionalCType):
value_comparsion.append(
value_comparison.append(
f"operand(i++) == {arg.name}.value_or(kNullValue)"
)
else:
value_comparsion.append(f"operand(i++) == {arg.name}")
value_comparison.append(f"operand(i++) == {arg.name}")
for arg in schema.positional_scalars:
value_comparsion.append(f"this->{arg.name} == {arg.name}")
value_comparison.append(f"this->{arg.name} == {arg.name}")
for arg in schema.keyword_values:
value_comparsion.append(f"operand(i++) == {arg.name}")
value_comparison.append(f"operand(i++) == {arg.name}")
for arg in schema.keyword_scalars:
value_comparsion.append(f"this->{arg.name} == {arg.name}")
value_comparsion_str = " &&\n ".join(value_comparsion)
value_comparison.append(f"this->{arg.name} == {arg.name}")
value_comparison_str = " &&\n ".join(value_comparison)
return f"""bool CanBeReused({node_ctor_args}) const {{
return f"""{signature} {{
size_t i = 0;
return ({value_comparsion_str});
return ({value_comparison_str});
}}"""
@ -399,7 +454,7 @@ class GenLazyNativeFuncDefinition:
shape_str += f"""
if(torch::lazy::symbolicShapeEnabled()){{
std::vector<torch::jit::IValue> inputs = {{ {', '.join(str(a.name) for a in all_args)} }};
char* schema_str = "{func_schema_str}";
const char* schema_str = "{func_schema_str}";
applySymbolicShapesOnLT(schema_str, inputs, shapes);
}}
"""
@ -523,3 +578,21 @@ class GenLazyShapeInferenceDefinition:
return ["\n".join([f"{shape_sig.shape_decl};"])]
else:
return []
def generate_non_native_lazy_ir_nodes(
non_native: List[Dict[str, Any]], gen_lazy_ir: GenLazyIR
) -> List[str]:
"""Generate the non-native lazy IR node classes"""
nodes = []
for op in non_native:
# Set default properties for Non-Native IRs
properties = LazyIrProperties("ShapeCache")
for p in op.get("properties", []):
setattr(properties, p, True)
schema = LazyIrSchema(FunctionSchema.parse(op["func"]), properties)
schema.opkind = op.get("opkind")
nodes.append(gen_lazy_ir.gen(schema)[0])
return nodes
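A minimal usage sketch; the op entry, opkind, and generator arguments below are hypothetical and only illustrate the expected shape of the inputs:
```
# One dict per entry under the backend yaml's `non_native` key.
op = {
    "func": "my_view_op(Tensor input, int[] size) -> Tensor",  # hypothetical schema
    "opkind": "ltc_my_view_op",                                # optional override
    "properties": ["ShapeCompute", "LowerDeclOnly"],
}
ir_gen = GenTSLazyIR(backend_index, "TorchScript", "torch::lazy::TsNode")  # assumed args
node_decls = generate_non_native_lazy_ir_nodes([op], ir_gen)
```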

View File

@ -1,15 +1,10 @@
from typing import Union
from torchgen.model import NativeFunction, NativeFunctionsGroup
from torchgen.api.lazy import LazyIrSchema
from torchgen.api.types import OptionalCType
def ts_lowering_body(f: Union[NativeFunctionsGroup, NativeFunction]) -> str:
def ts_lowering_body(schema: LazyIrSchema) -> str:
# for now, we just want one IR class decl and soon after also the method defs
# and we use the functional version not out/inplace.
func = f.functional.func if isinstance(f, NativeFunctionsGroup) else f.func
schema = LazyIrSchema(func)
emplace_arguments = []
for arg in schema.positional_args:
if arg.is_lazy_value:
@ -47,7 +42,7 @@ def ts_lowering_body(f: Union[NativeFunctionsGroup, NativeFunction]) -> str:
{emplace_arguments_str}
{emplace_kwarguments}
torch::lazy::TSOpVector {schema.aten_name}_out = torch::lazy::LowerTSBuiltin(function, op().op, arguments, kwarguments);
CHECK_EQ({schema.aten_name}_out.size(), {len(func.returns)});
CHECK_EQ({schema.aten_name}_out.size(), {len(schema.returns)});
return {schema.aten_name}_out;
"""

View File

@ -61,6 +61,7 @@ def parse_backend_yaml(
"supported",
"autograd",
"full_codegen",
"non_native",
]
backend = yaml_values.pop("backend", None)
@ -98,6 +99,9 @@ def parse_backend_yaml(
full_codegen = yaml_values.pop("full_codegen", [])
supported.extend(full_codegen)
# non_native is ignored by parse_backend_yaml, and re-parsed in gen_lazy_tensor.py
non_native = yaml_values.pop("non_native", {})
assert (
len(yaml_values.keys()) == 0
), f'{backend_yaml_path} contains unexpected keys: {", ".join(yaml_values.keys())}. \

View File

@ -5,8 +5,10 @@ import re
import yaml
from collections import namedtuple, Counter
from typing import (
Any,
List,
Dict,
Tuple,
Union,
Sequence,
Optional,
@ -106,10 +108,10 @@ ParsedExternalYaml = namedtuple(
)
def parse_full_codegen_ops(
def parse_native_functions_keys(
backend_yaml_path: str,
grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
) -> List[OperatorName]:
) -> Tuple[List[OperatorName], List[Any]]:
native_functions_map: Dict[OperatorName, NativeFunction] = {
f.func.name: f
@ -124,12 +126,10 @@ def parse_full_codegen_ops(
assert isinstance(yaml_values, dict)
full_codegen = yaml_values.pop("full_codegen", [])
assert isinstance(
full_codegen, list
), f'expected "full_codegen" to be a list, but got: {full_codegen}'
full_codegen = [OperatorName.parse(name) for name in full_codegen]
return full_codegen
non_native = yaml_values.pop("non_native", [])
assert isinstance(full_codegen, list)
assert isinstance(non_native, list)
return [OperatorName.parse(name) for name in full_codegen], non_native
def validate_shape_inference_header(
@ -150,13 +150,16 @@ def validate_shape_inference_header(
)
# TODO(whc) add a check for shape inference functions that have meta kernels implemented and should be retired.
for decl in expected_shape_infr_decls:
assert (
decl in shape_infr_decl_lines
), f"""Missing shape inference function.\n
missing_decls = [
decl for decl in expected_shape_infr_decls if decl not in shape_infr_decl_lines
]
if missing_decls:
raise Exception(
f"""Missing shape inference function.\n
Please declare this function in {shape_inference_hdr}:\n
and implement it in the corresponding shape_inference.cpp file.\n
{decl}"""
{os.linesep.join(missing_decls)}"""
)
class default_args:
@ -324,7 +327,9 @@ def run_gen_lazy_tensor(
autograd_key = parsed_backend_yaml.autograd_key
cpp_namespace = parsed_backend_yaml.cpp_namespace
backend_indices = parsed_backend_yaml.backend_indices
full_codegen = parse_full_codegen_ops(source_yaml, grouped_native_functions)
full_codegen, non_native = parse_native_functions_keys(
source_yaml, grouped_native_functions
)
def concat_map_codegen(
func: Callable[[NativeFunction], Sequence[str]],
@ -476,6 +481,10 @@ def run_gen_lazy_tensor(
},
)
# Generate IR node classes
lazy_ir_obj = lazy_ir_generator(
backend_indices[backend_key], backend_name, node_base
)
fm.write_with_template(
"LazyIr.h",
"LazyIr.h",
@ -492,16 +501,35 @@ def run_gen_lazy_tensor(
"vector",
]
],
"lazy_ir_inc": [
f'#include "{path}"'
for path in [node_base_hdr if node_base_hdr is not None else None]
if path is not None
],
"lazy_ir_inc": [f'#include "{node_base_hdr}"']
if node_base_hdr is not None
else [],
"ir_declarations": list(
concat_map_codegen(
lazy_ir_generator(backend_indices[backend_key], node_base),
grouped_native_functions,
)
concat_map_codegen(lazy_ir_obj, grouped_native_functions)
),
"namespace_prologue": ns_helper.prologue,
"namespace_epilogue": ns_helper.epilogue,
},
)
# Generate Non Native IR Node classes
fm.write_with_template(
"LazyNonNativeIr.h",
"LazyNonNativeIr.h",
lambda: {
"lazy_non_native_ir_inc": [
f"#include <{path}>"
for path in [
"torch/csrc/lazy/core/ir.h",
"torch/csrc/lazy/core/ir_builder.h",
"torch/csrc/lazy/core/internal_ops/ltc_ops.h",
"torch/csrc/lazy/core/shape_inference.h",
]
+ ([node_base_hdr] if node_base_hdr else [])
if path
],
"non_native_ir_nodes": dest.generate_non_native_lazy_ir_nodes(
non_native, lazy_ir_obj
),
"namespace_prologue": ns_helper.prologue,
"namespace_epilogue": ns_helper.epilogue,

View File

@ -1072,18 +1072,14 @@ class FunctionSchema:
self.arguments.out,
)
decl_re = re.compile(r"(?P<name>[^\(]+)\((?P<args>.*)\) -> (?P<returns>.*)")
@staticmethod
def parse(func: str) -> "FunctionSchema":
# We should probably get a proper parser here
assert (
" -> " in func
), "function schema missing return type (spaces are mandatory)"
last_index = func.rfind(" -> ")
func_decl = func[:last_index]
return_decl = func[last_index + len(" -> ") :]
ops, args = func_decl.split("(", 1)
assert args[-1] == ")", "Expecting closing )"
args = args[:-1]
decls = FunctionSchema.decl_re.findall(func)
assert len(decls) == 1, f"Invalid function schema: {func}"
ops, args, return_decl = decls[0]
name = OperatorName.parse(ops)
arguments = Arguments.parse(args)
returns = parse_returns(return_decl)
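A hedged sanity check of the regex-based parse above, using an invented schema string:
```
import re

decl_re = re.compile(r"(?P<name>[^\(]+)\((?P<args>.*)\) -> (?P<returns>.*)")
decls = decl_re.findall("my_op(Tensor input, int dim) -> Tensor")
# decls == [("my_op", "Tensor input, int dim", "Tensor")]
```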