Codegen Non-Native IR Nodes (#76535)
Add codegen infrastructure to generate IR nodes for non-native ops. The proposed change is to add a `non_native` key to the `{backend}_native_functions.yaml` file that contains schema definitions similar to what is found in `native_functions.yaml`, e.g.:

```
non_native:
  ...
  - func: expand(Tensor input, int[] size, bool is_scalar_expand) -> Tensor
  ...
```

These definitions are parsed into a `LazyIrSchema` that can be used for generating IR nodes using `GenLazyIR`.

Fixes #74628

CC: @wconstab @desertfire @henrytwo

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76535
Approved by: https://github.com/wconstab
committed by: PyTorch MergeBot
parent: 13dcba8c07
commit: 02c4d877b4
@@ -1865,6 +1865,7 @@ test_suite(
    "aten/src/ATen/templates/DispatchKeyNativeFunctions.cpp",
    "aten/src/ATen/templates/DispatchKeyNativeFunctions.h",
    "aten/src/ATen/templates/LazyIr.h",
    "aten/src/ATen/templates/LazyNonNativeIr.h",
    "aten/src/ATen/templates/RegisterDispatchKey.cpp",
    "aten/src/ATen/native/native_functions.yaml",
    "aten/src/ATen/native/tags.yaml",
@@ -178,3 +178,41 @@ supported:
  - _unsafe_view

autograd:
  - max_pool3d

# Ops that don't have native schema definitions and are dispatched within Lazy Tensor Core
non_native:
  - func: scalar(Scalar value, ScalarType type) -> Tensor
    opkind: at::prim::Constant
    properties:
      - ShapeCompute
      - TreatScalarsAsConstants
  - func: expand(Tensor input, int[] size, bool is_scalar_expand) -> Tensor
  - func: view(Tensor input, int[] output_size) -> Tensor
    properties:
      - ShapeCompute
  - func: cast(Tensor input, ScalarType dtype, ScalarType? stype) -> Tensor
    opkind: ltc_cast
    properties:
      - ShapeCompute

  # View ops only required until proper functionalization pass is introduced into LTC
  - func: as_strided_view_update(Tensor target, Tensor input, int[] size, int[] stride, int storage_offset) -> Tensor
    opkind: ltc_as_strided_view_update
  - func: as_strided(Tensor input, int[] size, int[] stride, int storage_offset) -> Tensor
  - func: diagonal_view_update(Tensor target, Tensor input, int offset, int dim1, int dim2) -> Tensor
    opkind: ltc_diagonal_view_update
    properties:
      - ShapeCompute
  - func: diagonal(Tensor input, int offset, int dim1, int dim2) -> Tensor
  - func: narrow_view_update(Tensor input, Tensor source, int[] base_indices) -> Tensor
    opkind: ltc_narrow_view_update
  - func: narrow(Tensor input, int[] base_indices, int[] sizes) -> Tensor
  - func: permute(Tensor input, int[] dims) -> Tensor
  - func: resize(Tensor input, int[] size) -> Tensor
  - func: select_view_update(Tensor target, Tensor source, int dim, int start, int end, int stride) -> Tensor
    opkind: ltc_select_view_update
    properties:
      - ShapeCompute
  - func: select(Tensor input, int dim, int start, int end, int stride) -> Tensor
  - func: squeeze(Tensor input, int dim) -> Tensor
  - func: unsqueeze(Tensor input, int dim) -> Tensor
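For orientation, the sketch below shows roughly the kind of C++ node `GenLazyIR` emits for the `permute` schema above. It is illustrative only, modeled on the handwritten `torch::lazy::Permute` node that this PR deletes further down, not copied from the generator's actual output. Two details are grounded in this diff: generated nodes expose their attributes as public fields (which is why the lowering code below switches from `node->dims()` to `node->dims`), and shape inference is routed through the `compute_shape_*` helpers this PR adds; the `ShapeCompute` property appears to control whether that happens eagerly at construction or lazily via `addComputedShape`.

```cpp
// Illustrative sketch of a codegenned non-native IR node; the real GenLazyIR
// output may differ in naming and layout.
#include <sstream>

#include <c10/util/StringUtil.h>
#include <torch/csrc/lazy/core/shape_inference.h>
#include <torch/csrc/lazy/ts_backend/ts_node.h>

namespace torch {
namespace lazy {

class Permute : public TsNode {
 public:
  static OpKind ClassOpKind() {
    return OpKind(at::aten::permute);
  }

  Permute(const Value& input, std::vector<int64_t> dims)
      : TsNode(ClassOpKind(), {input}, /*num_outputs=*/1, MHash(dims)),
        dims(std::move(dims)) {
    // permute carries no ShapeCompute property, so the shape is deferred to
    // the compute_shape_permute helper added by this PR (shape_inference.h).
    addComputedShape(
        [&]() { return compute_shape_permute(operand(0), this->dims).at(0); });
  }

  std::string ToString() const override {
    std::stringstream ss;
    ss << TsNode::ToString() << ", dims=(" << c10::Join(", ", dims) << ")";
    return ss.str();
  }

  // Generated nodes expose attributes as public fields rather than accessor
  // methods, hence node->dims instead of node->dims() in the lowering below.
  const std::vector<int64_t> dims;
};

} // namespace lazy
} // namespace torch
```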
aten/src/ATen/templates/LazyNonNativeIr.h (new file, 11 lines)
@@ -0,0 +1,11 @@
#pragma once

${lazy_non_native_ir_inc}

// This file contains autogenerated LazyTensor Non Native IR nodes

${namespace_prologue}

${non_native_ir_nodes}

${namespace_epilogue}
@@ -28,6 +28,7 @@ def define_targets(rules):
        ":DispatchKeyNativeFunctions.cpp",
        ":DispatchKeyNativeFunctions.h",
        ":LazyIr.h",
        ":LazyNonNativeIr.h",
        ":RegisterDispatchKey.cpp",
        ":native_functions.yaml",
        ":shape_inference.h",
@@ -88,6 +89,7 @@ GENERATED_TESTING_PY = [

GENERATED_LAZY_H = [
    "torch/csrc/lazy/generated/LazyIr.h",
    "torch/csrc/lazy/generated/LazyNonNativeIr.h",
    "torch/csrc/lazy/generated/LazyNativeFunctions.h",
]
@@ -380,6 +380,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
  list(APPEND GENERATED_H_TORCH
    "${TORCH_SRC_DIR}/csrc/autograd/generated/VariableType.h"
    "${TORCH_SRC_DIR}/csrc/lazy/generated/LazyIr.h"
    "${TORCH_SRC_DIR}/csrc/lazy/generated/LazyNonNativeIr.h"
    "${TORCH_SRC_DIR}/csrc/lazy/generated/LazyNativeFunctions.h"
  )
endif()
@@ -444,6 +445,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
    "${TORCH_ROOT}/aten/src/ATen/templates/DispatchKeyNativeFunctions.h"
    "${TORCH_ROOT}/aten/src/ATen/templates/DispatchKeyNativeFunctions.cpp"
    "${TORCH_ROOT}/aten/src/ATen/templates/LazyIr.h"
    "${TORCH_ROOT}/aten/src/ATen/templates/LazyNonNativeIr.h"
    "${TORCH_ROOT}/aten/src/ATen/templates/RegisterDispatchKey.cpp"
    "${TOOLS_PATH}/autograd/templates/VariableType.h"
    "${TOOLS_PATH}/autograd/templates/VariableType.cpp"
@@ -417,35 +417,19 @@ lazy_tensor_core_sources = [
# We can't build all of the ts backend under certain build configurations, e.g. mobile,
# since it depends on things like autograd, meta functions, which may be disabled
lazy_tensor_ts_sources = [
    "torch/csrc/lazy/ts_backend/config.cpp",
    "torch/csrc/lazy/ts_backend/dynamic_ir.cpp",
    "torch/csrc/lazy/ts_backend/config.cpp",
    "torch/csrc/lazy/ts_backend/ops/batch_norm_ops.cpp",
    "torch/csrc/lazy/ts_backend/ops/random_ops.cpp",
    "torch/csrc/lazy/ts_backend/ops/cast.cpp",
    "torch/csrc/lazy/ts_backend/ops/device_data.cpp",
    "torch/csrc/lazy/ts_backend/ops/expand.cpp",
    "torch/csrc/lazy/ts_backend/ops/random_ops.cpp",
    "torch/csrc/lazy/ts_backend/ops/generic.cpp",
    "torch/csrc/lazy/ts_backend/ops/scalar.cpp",
    "torch/csrc/lazy/ts_backend/view_ops/as_strided.cpp",
    "torch/csrc/lazy/ts_backend/view_ops/as_strided_view_update.cpp",
    "torch/csrc/lazy/ts_backend/view_ops/diagonal.cpp",
    "torch/csrc/lazy/ts_backend/view_ops/diagonal_view_update.cpp",
    "torch/csrc/lazy/ts_backend/view_ops/narrow.cpp",
    "torch/csrc/lazy/ts_backend/view_ops/narrow_view_update.cpp",
    "torch/csrc/lazy/ts_backend/view_ops/permute.cpp",
    "torch/csrc/lazy/ts_backend/view_ops/resize.cpp",
    "torch/csrc/lazy/ts_backend/view_ops/select.cpp",
    "torch/csrc/lazy/ts_backend/view_ops/squeeze.cpp",
    "torch/csrc/lazy/ts_backend/view_ops/unsqueeze.cpp",
    "torch/csrc/lazy/ts_backend/view_ops/select_view_update.cpp",
    "torch/csrc/lazy/ts_backend/view_ops/view.cpp",
    "torch/csrc/lazy/ts_backend/ts_node.cpp",
    "torch/csrc/lazy/ts_backend/tensor_aten_ops.cpp",
    "torch/csrc/lazy/ts_backend/ts_autograd_functions.cpp",
    "torch/csrc/lazy/ts_backend/ts_backend_impl.cpp",
    "torch/csrc/lazy/ts_backend/ts_eager_fallback.cpp",
    "torch/csrc/lazy/ts_backend/ts_lowering_context.cpp",
    "torch/csrc/lazy/ts_backend/ts_native_functions.cpp",
    "torch/csrc/lazy/ts_backend/ts_node.cpp",
    "torch/csrc/lazy/ts_backend/ts_node_lowering.cpp",
]
@@ -237,7 +237,7 @@ invalid_key: invalid_val"""
        output_error = self.get_errors_from_gen_backend_stubs(yaml_str)
        self.assertExpectedInline(
            output_error,
            """ contains unexpected keys: invalid_key. Only the following keys are supported: backend, class_name, cpp_namespace, extra_headers, supported, autograd, full_codegen""",  # noqa: B950
            """ contains unexpected keys: invalid_key. Only the following keys are supported: backend, class_name, cpp_namespace, extra_headers, supported, autograd, full_codegen, non_native""",  # noqa: B950
        )

    # if use_out_as_primary is provided, it must be a bool
@@ -19,6 +19,10 @@ hash_t Output::hash() const {
  return HashCombine(node->hash(), Hash(index));
}

hash_t Output::shapeHash() const {
  return HashCombine(node->shapeHash(), Hash(index));
}

std::string Output::ToString() const {
  std::stringstream ss;
  ss << node->ToString() << ", index=" << index;
@@ -144,7 +148,7 @@ std::string Node::ToString() const {

void Node::AddOperand(NodePtr node, size_t index) {
  CHECK_LT(index, node->num_outputs());
  operands_.push_back(std::move(node));
  operands_.push_back(node);
  operands_as_outputs_.emplace_back(operands_.back().get(), index);
}
@@ -214,6 +214,7 @@ struct TORCH_API Output {
      : node(node), index(index) {}

  hash_t hash() const;
  hash_t shapeHash() const;

  bool operator==(const Output& rhs) const {
    return node == rhs.node && index == rhs.index;
@@ -6,33 +6,33 @@
namespace torch {
namespace lazy {

bool StrideIsSupported(c10::ArrayRef<int64_t> stride);
TORCH_API bool StrideIsSupported(c10::ArrayRef<int64_t> stride);

std::vector<int64_t> GetArrayStridePermutation(c10::ArrayRef<int64_t> stride);
TORCH_API std::vector<int64_t> GetArrayStridePermutation(c10::ArrayRef<int64_t> stride);

Shape MakeDiagonalShape(
TORCH_API Shape MakeDiagonalShape(
    const Shape& shape,
    int64_t offset,
    int64_t dim1,
    int64_t dim2);

Shape MakePermuteShape(
TORCH_API Shape MakePermuteShape(
    const Shape& source_shape,
    c10::ArrayRef<int64_t> permutation);

Shape MakeSelectShape(
TORCH_API Shape MakeSelectShape(
    const Shape& shape,
    int64_t dim,
    int64_t start,
    int64_t end,
    int64_t stride);

int64_t GetStride(int64_t start, int64_t end, int64_t stride);
TORCH_API int64_t GetStride(int64_t start, int64_t end, int64_t stride);

std::vector<int64_t> BuildSqueezedDimensions(c10::ArrayRef<int64_t> dimensions,
TORCH_API std::vector<int64_t> BuildSqueezedDimensions(c10::ArrayRef<int64_t> dimensions,
    int64_t squeeze_dim);

std::vector<int64_t> BuildUnsqueezedDimensions(
TORCH_API std::vector<int64_t> BuildUnsqueezedDimensions(
    c10::ArrayRef<int64_t> dimensions,
    int64_t squeeze_dim);
@@ -23,8 +23,11 @@ std::vector<typename Container::value_type> PermuteDimensions(
    const Container& dimensions) {
  using T = typename Container::value_type;
  TORCH_CHECK(
      dimensions.size() == permutation.size() && IsPermutation(permutation),
      "Invalid permutation specified");
      dimensions.size() == permutation.size(),
      "Invalid permutation specified. dimensions.size() != permutation.size() (", dimensions.size(), " vs. ", permutation.size(), ")");
  TORCH_CHECK(
      IsPermutation(permutation),
      "Invalid permutation specified. Permutation is not permutation");
  std::vector<T> output(dimensions.size());
  for (const auto i : c10::irange(permutation.size())) {
    output[i] = dimensions[permutation[i]];
@@ -45,10 +45,12 @@

#include <torch/csrc/lazy/core/shape_inference.h>

#include <torch/csrc/lazy/core/ops/utils.h>
#include <torch/csrc/lazy/core/shape.h>
#include <ATen/native/ConvUtils.h>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/InferSize.h>
#include <ATen/WrapDimUtils.h>
#include <aten/src/ATen/native/ReduceOpsUtils.h>
#include <c10/core/ScalarType.h>
@@ -629,6 +631,67 @@ std::vector<Shape> compute_shape_narrow_copy(const at::Tensor & self, int64_t di
  return {Shape(self.scalar_type(), self.sizes().vec())};
}

// Non-Native Ops
std::vector<Shape> compute_shape_scalar(const at::Scalar& value, const at::ScalarType& type) {
  return { Shape(type, {}) };
}
std::vector<Shape> compute_shape_expand(const Output& input, const std::vector<int64_t>& size, const bool& is_scalar_expand) {
  return { Shape(input.shape().scalar_type(), size) };
}
std::vector<Shape> compute_shape_view(const Output& input, const std::vector<int64_t>& output_sizes) {
  const Shape& input_shape = input.shape();
  const auto complete_output_sizes =
      at::infer_size(output_sizes, input_shape.numel());
  return { Shape(input_shape.scalar_type(), complete_output_sizes) };
}
std::vector<Shape> compute_shape_cast(const Output& input, const at::ScalarType& dtype, const c10::optional<at::ScalarType>& stype) {
  Shape shape = input.shape();
  shape.set_scalar_type(dtype);
  return { shape };
}

// View Ops
std::vector<Shape> compute_shape_as_strided_view_update(const Output& target, const Output& input, const std::vector<int64_t>& size, const std::vector<int64_t>& stride, const int64_t& storage_offset) {
  return { Shape(target.shape().scalar_type(), size) };
}
std::vector<Shape> compute_shape_as_strided(const Output& input, const std::vector<int64_t>& size, const std::vector<int64_t>& stride, const int64_t& storage_offset) {
  return { Shape(input.shape().scalar_type(), size) };
}
std::vector<Shape> compute_shape_diagonal_view_update(const Output& target, const Output& input, const int64_t& offset, const int64_t& dim1, const int64_t& dim2) {
  return { target.shape() };
}
std::vector<Shape> compute_shape_diagonal(const Output& input, const int64_t& offset, const int64_t& dim1, const int64_t& dim2) {
  return { MakeDiagonalShape(input.shape(), offset, dim1, dim2) };
}
std::vector<Shape> compute_shape_narrow_view_update(const Output& input, const Output& source, const std::vector<int64_t>& base_indices) {
  return { input.shape() };
}
std::vector<Shape> compute_shape_narrow(const Output& input, const std::vector<int64_t>& base_indices, const std::vector<int64_t>& sizes) {
  return { Shape(input.shape().scalar_type(), sizes) };
}
std::vector<Shape> compute_shape_permute(const Output& input, const std::vector<int64_t>& dims) {
  return { MakePermuteShape(input.shape(), dims) };
}
std::vector<Shape> compute_shape_resize(const Output& input, const std::vector<int64_t>& size) {
  return { Shape(input.shape().scalar_type(), size) };
}
std::vector<Shape> compute_shape_select_view_update(const Output& target, const Output& source, const int64_t& dim, const int64_t& start, const int64_t& end, const int64_t& stride) {
  return { target.shape() };
}
std::vector<Shape> compute_shape_select(const Output& input, const int64_t& dim, const int64_t& start, const int64_t& end, const int64_t& stride) {
  return { MakeSelectShape(input.shape(), dim, start, end, stride) };
}
std::vector<Shape> compute_shape_squeeze(const Output& input, const int& dim) {
  const auto& input_shape = input.shape();
  return { torch::lazy::Shape(input_shape.scalar_type(), BuildSqueezedDimensions(input_shape.sizes(), dim)) };
}
std::vector<Shape> compute_shape_unsqueeze(const Output& input, const int& dim) {
  const auto& input_shape = input.shape();
  return { torch::lazy::Shape(input_shape.scalar_type(), BuildUnsqueezedDimensions(input_shape.sizes(), dim)) };
}

// Restore unused-parameters warnings
#pragma GCC diagnostic pop
@@ -4,6 +4,7 @@
#include <c10/core/ScalarType.h>
#include <c10/macros/Export.h>
#include <c10/util/Optional.h>
#include <torch/csrc/lazy/backend/backend_data.h>
#include <torch/csrc/lazy/core/ir.h>
#include <torch/csrc/lazy/core/shape.h>
#include <vector>
@@ -68,5 +69,26 @@ TORCH_API std::vector<torch::lazy::Shape> compute_shape__to_copy(const at::Tenso
TORCH_API std::vector<torch::lazy::Shape> compute_shape_trace(const at::Tensor & self);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_zero_functional(const at::Tensor & self);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_narrow_copy(const at::Tensor & self, int64_t dim, int64_t start, c10::SymInt length);

// Non-Native ops
TORCH_API std::vector<Shape> compute_shape_scalar(const at::Scalar& value, const at::ScalarType& type);
TORCH_API std::vector<Shape> compute_shape_expand(const Output& input0, const std::vector<int64_t>& size, const bool& is_scalar_expand);
TORCH_API std::vector<Shape> compute_shape_view(const Output& input0, const std::vector<int64_t>& output_sizes);
TORCH_API std::vector<Shape> compute_shape_cast(const Output& input0, const at::ScalarType& dtype, const c10::optional<at::ScalarType>& stype);

// View Ops
TORCH_API std::vector<Shape> compute_shape_as_strided_view_update(const Output& target, const Output& input, const std::vector<int64_t>& size, const std::vector<int64_t>& stride, const int64_t& storage_offset);
TORCH_API std::vector<Shape> compute_shape_as_strided(const Output& input, const std::vector<int64_t>& size, const std::vector<int64_t>& stride, const int64_t& storage_offset);
TORCH_API std::vector<Shape> compute_shape_diagonal_view_update(const Output& target, const Output& input, const int64_t& offset, const int64_t& dim1, const int64_t& dim2);
TORCH_API std::vector<Shape> compute_shape_diagonal(const Output& input, const int64_t& offset, const int64_t& dim1, const int64_t& dim2);
TORCH_API std::vector<Shape> compute_shape_narrow_view_update(const Output& input, const Output& source, const std::vector<int64_t>& base_indices);
TORCH_API std::vector<Shape> compute_shape_narrow(const Output& input, const std::vector<int64_t>& base_indices, const std::vector<int64_t>& sizes);
TORCH_API std::vector<Shape> compute_shape_permute(const Output& input, const std::vector<int64_t>& dims);
TORCH_API std::vector<Shape> compute_shape_resize(const Output& input, const std::vector<int64_t>& size);
TORCH_API std::vector<Shape> compute_shape_select_view_update(const Output& target, const Output& source, const int64_t& dim, const int64_t& start, const int64_t& end, const int64_t& stride);
TORCH_API std::vector<Shape> compute_shape_select(const Output& input, const int64_t& dim, const int64_t& start, const int64_t& end, const int64_t& stride);
TORCH_API std::vector<Shape> compute_shape_squeeze(const Output& input, const int& dim);
TORCH_API std::vector<Shape> compute_shape_unsqueeze(const Output& input, const int& dim);

} // namespace lazy
} // namespace torch
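These declarations show the contract each `non_native` schema establishes: a `compute_shape_<name>` overload whose parameters mirror the schema, with `Tensor` arguments arriving as `Output`. As a usage sketch only (the `flip` op below is hypothetical and not part of this PR), a new schema entry would bring along a matching overload:

```cpp
// Hypothetical example, not in this PR: shape inference for a non-native op
// declared as "flip(Tensor input, int[] dims) -> Tensor". Flipping reorders
// elements without changing dimensions, so the output shape is the input
// shape. Declared in shape_inference.h alongside the entries above:
TORCH_API std::vector<Shape> compute_shape_flip(const Output& input, const std::vector<int64_t>& dims);

// ...and defined in shape_inference.cpp:
std::vector<Shape> compute_shape_flip(const Output& input, const std::vector<int64_t>& dims) {
  return { input.shape() };
}
```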
@@ -4,31 +4,11 @@
#include <torch/csrc/lazy/core/ir_builder.h>
#include <torch/csrc/lazy/core/internal_ops/ltc_ops.h>
#include <torch/csrc/lazy/core/shape_inference.h>
#include <torch/csrc/lazy/generated/LazyNonNativeIr.h>
#include <torch/csrc/lazy/ts_backend/ts_node.h>
#include <torch/csrc/lazy/ts_backend/dynamic_ir.h>
#include <torch/csrc/lazy/ts_backend/view_ops/narrow.h>
#include <torch/csrc/lazy/ts_backend/view_ops/select_view_update.h>
#include <torch/csrc/lazy/ts_backend/view_ops/as_strided_view_update.h>
#include <torch/csrc/lazy/ts_backend/view_ops/permute.h>
#include <torch/csrc/lazy/ts_backend/view_ops/diagonal_view_update.h>
#include <torch/csrc/lazy/ts_backend/view_ops/resize.h>
#include <torch/csrc/lazy/ts_backend/view_ops/squeeze.h>
#include <torch/csrc/lazy/ts_backend/view_ops/diagonal.h>
#include <torch/csrc/lazy/ts_backend/view_ops/narrow_view_update.h>
#include <torch/csrc/lazy/ts_backend/view_ops/as_strided.h>
#include <torch/csrc/lazy/ts_backend/view_ops/unsqueeze.h>
#include <torch/csrc/lazy/ts_backend/view_ops/select.h>
#include <torch/csrc/lazy/ts_backend/view_ops/view.h>
#include <torch/csrc/lazy/ts_backend/ops/cast.h>
#include <torch/csrc/lazy/ts_backend/ops/device_data.h>
#include <torch/csrc/lazy/ts_backend/ops/generic.h>
#include <torch/csrc/lazy/ts_backend/ops/batch_norm_ops.h>
#include <torch/csrc/lazy/ts_backend/ops/to_copy.h>
#include <torch/csrc/lazy/ts_backend/ops/scalar.h>
#include <torch/csrc/lazy/ts_backend/ops/random_ops.h>
#include <torch/csrc/lazy/ts_backend/ops/expand.h>

// This file contains the TorchScript IrBuilder
#include <torch/csrc/lazy/ts_backend/ops/device_data.h>

namespace torch {
namespace lazy {
@@ -1,42 +0,0 @@
#include <torch/csrc/lazy/ts_backend/ops/cast.h>

#include <torch/csrc/lazy/core/tensor_util.h>

namespace torch {
namespace lazy {

namespace {

Shape NodeOutputShape(const Value& input, c10::ScalarType type) {
  Shape shape = input.shape();
  shape.set_scalar_type(type);
  return shape;
}

} // namespace

Cast::Cast(
    const Value& input,
    at::ScalarType dtype,
    c10::optional<at::ScalarType> stype)
    : TsNode(
          ClassOpKind(),
          {input},
          {NodeOutputShape(input, dtype)},
          /*num_outputs=*/1,
          MHash(101, static_cast<int>(dtype), OptionalOr<int>(stype, -1))),
      dtype_(dtype),
      stype_(stype) {}

std::string Cast::ToString() const {
  std::stringstream ss;
  ss << TsNode::ToString();
  ss << ", dtype=" << dtype_;
  if (stype_) {
    ss << ", stype=" << *stype_;
  }
  return ss.str();
}

} // namespace lazy
} // namespace torch
@@ -1,38 +0,0 @@
#pragma once

#include <c10/core/ScalarType.h>
#include <c10/util/Optional.h>
#include <torch/csrc/lazy/core/internal_ops/ltc_ops.h>
#include <torch/csrc/lazy/ts_backend/ts_node.h>

namespace torch {
namespace lazy {

class TORCH_API Cast : public TsNode {
 public:
  static OpKind ClassOpKind() {
    return ltc_cast;
  }

  Cast(
      const Value& input,
      at::ScalarType dtype,
      c10::optional<at::ScalarType> stype = c10::nullopt);

  std::string ToString() const override;

  at::ScalarType dtype() const {
    return dtype_;
  }

  const c10::optional<at::ScalarType>& stype() const {
    return stype_;
  }

 private:
  at::ScalarType dtype_;
  c10::optional<at::ScalarType> stype_;
};

} // namespace lazy
} // namespace torch
@@ -1,29 +0,0 @@
#include <torch/csrc/lazy/ts_backend/ops/expand.h>

namespace torch {
namespace lazy {

Expand::Expand(
    const Value& input,
    std::vector<int64_t> size,
    bool is_scalar_expand)
    : TsNode(
          ClassOpKind(),
          {input},
          /*num_outputs=*/1,
          MHash(size, is_scalar_expand)),
      size_(std::move(size)),
      is_scalar_expand_(is_scalar_expand) {
  addComputedShape(
      [&]() { return Shape(input.shape().scalar_type(), size_); });
}

std::string Expand::ToString() const {
  std::stringstream ss;
  ss << TsNode::ToString() << ", size=(" << c10::Join(", ", size_)
     << "), is_scalar_expand=" << is_scalar_expand_;
  return ss.str();
}

} // namespace lazy
} // namespace torch
@@ -1,37 +0,0 @@
#pragma once

#include <torch/csrc/lazy/ts_backend/ts_node.h>

#include <vector>

namespace torch {
namespace lazy {

class TORCH_API Expand : public TsNode {
 public:
  static OpKind ClassOpKind() {
    return OpKind(at::aten::expand);
  }

  Expand(const Value& input, std::vector<int64_t> size, bool is_scalar_expand);

  std::string ToString() const override;

  const std::vector<int64_t>& size() const {
    return size_;
  }

  bool is_scalar_expand() const {
    return is_scalar_expand_;
  }

 private:
  std::vector<int64_t> size_;
  // True iff the input was a scalar and this was generated internally by a
  // lowering and not by user action. For some backends, this difference can be
  // material (for example setting strides according to eager semantics).
  bool is_scalar_expand_;
};

} // namespace lazy
} // namespace torch
@@ -1,40 +0,0 @@
#include <torch/csrc/lazy/ts_backend/ops/scalar.h>

#include <functional>
#include <sstream>

#include <ATen/core/Formatting.h>

namespace torch {
namespace lazy {

using at::operator<<;

Scalar::Scalar(const at::Scalar& value, Shape shape)
    : TsNode(
          ClassOpKind(),
          std::move(shape),
          /*num_outputs=*/1,
          ScalarHash(value)),
      value_(value) {}

Scalar::Scalar(const at::Scalar& value, c10::ScalarType type)
    : TsNode(
          ClassOpKind(),
          {Shape(type, {})},
          /*num_outputs=*/1,
          ScalarHash(value)),
      value_(value) {}

std::string Scalar::ToString() const {
  std::stringstream ss;
  ss << TsNode::ToString() << ", value=" << value_;
  return ss.str();
}

hash_t ScalarHash(const at::Scalar& s) {
  return s.isFloatingPoint() ? Hash(s.toDouble()) : Hash(s.toLong());
}

} // namespace lazy
} // namespace torch
@@ -1,35 +0,0 @@
#pragma once

#include <c10/core/Scalar.h>
#include <torch/csrc/lazy/ts_backend/ts_node.h>

namespace torch {
namespace lazy {

// Differently from Constant, this is a scalar value broadcasted to a shape.
// Even though a Constant could have been used, for simple scalars broadcasted
// to big shapes, the Constant leads to big literals expanded within the
// computation graph.
class TORCH_API Scalar : public TsNode {
 public:
  static OpKind ClassOpKind() {
    return OpKind(at::prim::Constant);
  }

  Scalar(const at::Scalar& value, Shape shape);
  Scalar(const at::Scalar& value, c10::ScalarType type);

  std::string ToString() const override;

  const at::Scalar& value() const {
    return value_;
  }

 private:
  at::Scalar value_;
};

TORCH_API hash_t ScalarHash(const at::Scalar& s);

} // namespace lazy
} // namespace torch
@@ -17,29 +17,29 @@ namespace torch {
namespace lazy {

hash_t OperandHashes(const OpList& operands, const hash_t& seed, bool bakeInSizes) {
hash_t OperandHashes(const OpList& operands,
                     const c10::ArrayRef<Shape>& shapes,
                     const hash_t& seed, bool bakeInSizes) {
  hash_t hash = seed;
  for (auto& operand : operands) {
    if (!operand) {
      hash = HashCombine(hash, static_cast<uint64_t>(kNullOpt));
      continue;
    }
    auto operand_hash = operand.hash();
    auto operand_hash = bakeInSizes ? operand.shapeHash() : operand.hash();
    hash = HashCombine(hash, operand_hash);
  }
  for (auto& shape : shapes) {
    hash = HashCombine(hash, shape.hash(bakeInSizes));
  }
  return hash;
}

hash_t GetOpHash(OpKind op, const Shape& shape, hash_t hash_seed, bool bakeInSizes) {
  hash_t h = HashCombine(op.hash(), shape.hash(bakeInSizes));
  return HashCombine(h, hash_seed);
}

TsNode::TsNode(OpKind op, OpList operands, std::vector<Shape>&& shapes, size_t num_outputs, hash_t hash_seed)
    : Node(op, operands, std::move(shapes), num_outputs) {
  hash_seed = HashCombine(op.hash(), hash_seed);
  shape_hash_ = OperandHashes(operands, hash_seed, true);
  dag_hash_ = (enableDynamicShape() ? OperandHashes(operands, hash_seed, false) : shape_hash_);
  shape_hash_ = OperandHashes(operands, this->shapes(), hash_seed, true);
  dag_hash_ = (enableDynamicShape() ? OperandHashes(operands, this->shapes(), hash_seed, false) : shape_hash_);
}

@@ -53,11 +53,7 @@ TsNode::TsNode(OpKind op, OpList operands, size_t num_outputs, hash_t hash_seed)
    : TsNode(op, operands, std::vector<Shape>{}, num_outputs, hash_seed) {}

TsNode::TsNode(OpKind op, Shape shape, size_t num_outputs, hash_t hash_seed)
    : Node(op, num_outputs),
      shape_hash_(GetOpHash(op, shape, hash_seed, true)),
      dag_hash_(enableDynamicShape() ? GetOpHash(op, shape, hash_seed, false) : shape_hash_) {
  shapes_.push_back(std::move(shape));
}
    : TsNode(op, {}, {std::move(shape)}, num_outputs, hash_seed) {}

hash_t TsNode::hash() const { return dag_hash_; }

@@ -81,8 +77,8 @@ TensorList::TensorList(OpList values)
    : TsNode(/*op=*/ClassOpKind(),
             /*operands=*/values,
             /*shapes=*/std::vector<Shape>(),
             /*num_outputs=*/1,
             /*hash_seed=*/OperandHashes(values, /*seed=*/kHashSeed, enableDynamicShape())) {}
             /*num_outputs=*/1,
             /*hash_seed=*/kHashSeed) {}

TSOpVector TensorList::Lower(std::shared_ptr<torch::jit::GraphFunction> function,
                             TSLoweringContext* loctx) const {
@@ -97,5 +93,7 @@ TSOpVector TensorList::Lower(std::shared_ptr<torch::jit::GraphFunction> function
  return {listnode->output()};
}

} // namespace lazy
} // namespace torch
@@ -146,9 +146,9 @@ class TSNodeLowering : public TSNodeLoweringInterface {
  TSOpVector LowerAsStrided(const torch::lazy::AsStrided* node) {
    std::vector<torch::jit::NamedValue> arguments;
    arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
    arguments.emplace_back(node->size());
    arguments.emplace_back(node->stride());
    arguments.emplace_back(node->storage_offset());
    arguments.emplace_back(node->size);
    arguments.emplace_back(node->stride);
    arguments.emplace_back(node->storage_offset);
    TSOpVector as_strided_out = LowerBuiltin(node, arguments);
    CHECK_EQ(as_strided_out.size(), 1);
    return {GenerateClone(as_strided_out.front())};
@@ -165,8 +165,9 @@ class TSNodeLowering : public TSNodeLoweringInterface {
    dest_arguments.emplace_back(destination);
    dest_arguments.emplace_back(
        std::vector<int64_t>(input_dimensions.begin(), input_dimensions.end()));
    dest_arguments.emplace_back(node->stride());
    dest_arguments.emplace_back(node->storage_offset());
    dest_arguments.emplace_back(node->stride);
    dest_arguments.emplace_back(node->storage_offset);
    TSOpVector as_strided_out =
        LowerBuiltin(at::aten::as_strided, dest_arguments);
    CHECK_EQ(as_strided_out.size(), 1);
@@ -209,16 +210,16 @@ class TSNodeLowering : public TSNodeLoweringInterface {
  TSOpVector LowerCast(const torch::lazy::Cast* node) {
    std::vector<torch::jit::NamedValue> arguments;
    arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
    arguments.emplace_back(node->dtype());
    arguments.emplace_back(node->dtype);
    return LowerBuiltin(at::aten::to, arguments);
  }

  TSOpVector LowerExpand(const torch::lazy::Expand* node) {
    std::vector<torch::jit::NamedValue> arguments;
    arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
    arguments.emplace_back(node->size());
    arguments.emplace_back(node->size);
    auto expand_out = LowerBuiltin(node, arguments);
    if (node->is_scalar_expand()) {
    if (node->is_scalar_expand) {
      // The aten::expand operations sets all strides to 0 when the original is
      // of rank 0. This leads to false positives when checking for internal
      // memory overlap, because at::has_internal_overlap returns
@@ -232,8 +233,8 @@ class TSNodeLowering : public TSNodeLoweringInterface {
  TSOpVector LowerNarrow(const torch::lazy::Narrow* node) {
    const torch::lazy::Output& input = node->operand(0);
    torch::jit::Value* base = loctx()->GetOutputOp(input);
    const auto& base_indices = node->base_indices();
    const auto& sizes = node->sizes();
    const auto& base_indices = node->base_indices;
    const auto& sizes = node->sizes;
    const torch::lazy::Shape& input_shape = input.shape();
    CHECK_EQ(sizes.size(), base_indices.size());
    CHECK_EQ(input_shape.dim(), base_indices.size());
@@ -248,12 +249,12 @@ class TSNodeLowering : public TSNodeLoweringInterface {
  TSOpVector LowerPermute(const torch::lazy::Permute* node) {
    std::vector<torch::jit::NamedValue> arguments;
    arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
    arguments.emplace_back(node->dims());
    arguments.emplace_back(node->dims);
    return LowerBuiltin(node, arguments);
  }

  TSOpVector LowerScalar(const torch::lazy::Scalar* node) {
    const at::Scalar& value = node->value();
    const at::Scalar& value = node->value;
    const torch::lazy::Shape& shape = node->shape();
    auto options =
        at::TensorOptions()
@@ -264,19 +265,19 @@ class TSNodeLowering : public TSNodeLoweringInterface {
  }

  TSOpVector LowerSelect(const torch::lazy::Select* node) {
    int64_t step = torch::lazy::Select::GetStride(node->start(), node->end(),
                                                  node->stride());
    int64_t step = torch::lazy::GetStride(node->start, node->end,
                                          node->stride);
    torch::jit::Value* base = loctx()->GetOutputOp(node->operand(0));
    return {GenerateSlice(/*base=*/base, /*dim=*/node->dim(),
                          /*start=*/node->start(), /*end=*/node->end(),
    return {GenerateSlice(/*base=*/base, /*dim=*/node->dim,
                          /*start=*/node->start, /*end=*/node->end,
                          /*step=*/step)};
  }

  TSOpVector LowerSqueeze(const Squeeze* node) {
    std::vector<torch::jit::NamedValue> arguments;
    arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
    if (node->dim() != -1) {
      arguments.emplace_back(node->dim());
    if (node->dim != -1) {
      arguments.emplace_back(node->dim);
    }
    return LowerBuiltin(node, arguments);
  }
@@ -284,11 +285,11 @@ class TSNodeLowering : public TSNodeLoweringInterface {
  TSOpVector LowerSelectViewUpdate(const torch::lazy::SelectViewUpdate* node) {
    torch::jit::Value* dest =
        GenerateClone(loctx()->GetOutputOp(node->operand(0)));
    int64_t step = torch::lazy::Select::GetStride(node->start(), node->end(),
                                                  node->stride());
    int64_t step = torch::lazy::GetStride(node->start, node->end,
                                          node->stride);
    torch::jit::Value* selected = GenerateSlice(
        /*base=*/dest, /*dim=*/node->dim(), /*start=*/node->start(),
        /*end=*/node->end(), /*step=*/step);
        /*base=*/dest, /*dim=*/node->dim, /*start=*/node->start,
        /*end=*/node->end, /*step=*/step);
    GenerateCopy(selected, loctx()->GetOutputOp(node->operand(1)));
    return {dest};
  }
@@ -296,7 +297,7 @@ class TSNodeLowering : public TSNodeLoweringInterface {
  TSOpVector LowerNarrowViewUpdate(const torch::lazy::NarrowViewUpdate* node) {
    torch::jit::Value* dest =
        GenerateClone(loctx()->GetOutputOp(node->operand(0)));
    const auto& base_indices = node->base_indices();
    const auto& base_indices = node->base_indices;
    const torch::lazy::Output& source_argument = node->operand(1);
    const torch::lazy::Shape& source_shape = source_argument.shape();
    CHECK_EQ(source_shape.dim(), base_indices.size());
@@ -314,23 +315,23 @@ class TSNodeLowering : public TSNodeLoweringInterface {
  TSOpVector LowerUnsqueeze(const Unsqueeze* node) {
    std::vector<torch::jit::NamedValue> arguments;
    arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
    arguments.emplace_back(node->dim());
    arguments.emplace_back(node->dim);
    return LowerBuiltin(node, arguments);
  }

  TSOpVector LowerView(const torch::lazy::View* node) {
    std::vector<torch::jit::NamedValue> arguments;
    arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
    arguments.emplace_back(node->output_size());
    arguments.emplace_back(node->output_size);
    return LowerBuiltin(at::aten::reshape, arguments);
  }

  TSOpVector LowerDiagonal(const Diagonal* node) {
    std::vector<torch::jit::NamedValue> arguments;
    arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
    arguments.emplace_back(node->offset());
    arguments.emplace_back(node->dim1());
    arguments.emplace_back(node->dim2());
    arguments.emplace_back(node->offset);
    arguments.emplace_back(node->dim1);
    arguments.emplace_back(node->dim2);
    return LowerBuiltin(node, arguments);
  }

@@ -346,9 +347,9 @@ class TSNodeLowering : public TSNodeLoweringInterface {
    // Replay the diagonal.
    std::vector<torch::jit::NamedValue> arguments;
    arguments.emplace_back(destination);
    arguments.emplace_back(node->offset());
    arguments.emplace_back(node->dim1());
    arguments.emplace_back(node->dim2());
    arguments.emplace_back(node->offset);
    arguments.emplace_back(node->dim1);
    arguments.emplace_back(node->dim2);
    auto diag = LowerBuiltin(at::aten::diagonal, arguments);

    // Update the replayed diagonal view with the input.
@@ -1,52 +0,0 @@
#include <torch/csrc/lazy/ts_backend/view_ops/as_strided.h>

#include <algorithm>

#include <torch/csrc/lazy/core/tensor_util.h>
#include <torch/csrc/lazy/core/util.h>

namespace torch {
namespace lazy {

AsStrided::AsStrided(
    const Value& input,
    std::vector<int64_t> size,
    std::vector<int64_t> stride,
    int64_t storage_offset)
    : TsNode(
          ClassOpKind(),
          {input},
          [&]() {
            return Shape(input.shape().scalar_type(), size);
          },
          /*num_outputs=*/1,
          MHash(size, stride, storage_offset)),
      size_(std::move(size)),
      stride_(std::move(stride)),
      storage_offset_(storage_offset) {}

std::string AsStrided::ToString() const {
  std::stringstream ss;
  ss << TsNode::ToString() << ", size=(" << c10::Join(", ", size_)
     << "), stride=(" << c10::Join(", ", stride_)
     << "), storage_offset=" << storage_offset_;
  return ss.str();
}

bool AsStrided::StrideIsSupported(c10::ArrayRef<int64_t> stride) {
  std::vector<int64_t> sorted_stride(stride.begin(), stride.end());
  std::sort(sorted_stride.begin(), sorted_stride.end());
  return stride.empty() || sorted_stride.front() == 1;
}

std::vector<int64_t> AsStrided::GetArrayStridePermutation(
    c10::ArrayRef<int64_t> stride) {
  std::vector<int64_t> permutation = Iota<int64_t>(stride.size());
  std::sort(permutation.begin(), permutation.end(), [&](int64_t a, int64_t b) {
    return stride[a] > stride[b];
  });
  return permutation;
}

} // namespace lazy
} // namespace torch
@@ -1,48 +0,0 @@
#pragma once

#include <torch/csrc/lazy/ts_backend/ts_node.h>

#include <vector>

namespace torch {
namespace lazy {

class TORCH_API AsStrided : public TsNode {
 public:
  static OpKind ClassOpKind() {
    return OpKind(at::aten::as_strided);
  }

  AsStrided(
      const Value& input,
      std::vector<int64_t> size,
      std::vector<int64_t> stride,
      int64_t storage_offset);

  std::string ToString() const override;

  const std::vector<int64_t>& size() const {
    return size_;
  }

  const std::vector<int64_t>& stride() const {
    return stride_;
  }

  int64_t storage_offset() const {
    return storage_offset_;
  }

  static bool StrideIsSupported(c10::ArrayRef<int64_t> stride);

  static std::vector<int64_t> GetArrayStridePermutation(
      c10::ArrayRef<int64_t> stride);

 private:
  std::vector<int64_t> size_;
  std::vector<int64_t> stride_;
  int64_t storage_offset_;
};

} // namespace lazy
} // namespace torch
@@ -1,37 +0,0 @@
#include <torch/csrc/lazy/ts_backend/view_ops/as_strided_view_update.h>

#include <torch/csrc/lazy/core/internal_ops/ltc_ops.h>
#include <torch/csrc/lazy/core/tensor_util.h>
#include <torch/csrc/lazy/ts_backend/view_ops/as_strided.h>

namespace torch {
namespace lazy {

AsStridedViewUpdate::AsStridedViewUpdate(
    const Value& target,
    const Value& input,
    std::vector<int64_t> size,
    std::vector<int64_t> stride,
    int64_t storage_offset)
    : TsNode(
          ltc_as_strided_view_update,
          {target, input},
          [&]() {
            return Shape(target.shape().scalar_type(), size);
          },
          /*num_outputs=*/1,
          MHash(size, stride, storage_offset)),
      size_(std::move(size)),
      stride_(std::move(stride)),
      storage_offset_(storage_offset) {}

std::string AsStridedViewUpdate::ToString() const {
  std::stringstream ss;
  ss << TsNode::ToString() << ", size=(" << c10::Join(", ", size_)
     << "), stride=(" << c10::Join(", ", stride_)
     << "), storage_offset=" << storage_offset_;
  return ss.str();
}

} // namespace lazy
} // namespace torch
@@ -1,45 +0,0 @@
#pragma once

#include <torch/csrc/lazy/ts_backend/ts_node.h>

#include <vector>
#include <lazy/core/internal_ops/ltc_ops.h>

namespace torch {
namespace lazy {

class TORCH_API AsStridedViewUpdate : public TsNode {
 public:
  static OpKind ClassOpKind() {
    return ltc_as_strided_view_update;
  }

  AsStridedViewUpdate(
      const Value& target,
      const Value& input,
      std::vector<int64_t> size,
      std::vector<int64_t> stride,
      int64_t storage_offset);

  std::string ToString() const override;

  const std::vector<int64_t>& size() const {
    return size_;
  }

  const std::vector<int64_t>& stride() const {
    return stride_;
  }

  int64_t storage_offset() const {
    return storage_offset_;
  }

 private:
  std::vector<int64_t> size_;
  std::vector<int64_t> stride_;
  int64_t storage_offset_;
};

} // namespace lazy
} // namespace torch
@@ -1,35 +0,0 @@
#include <c10/util/irange.h>
#include <torch/csrc/lazy/core/ops/utils.h>
#include <torch/csrc/lazy/ts_backend/view_ops/diagonal.h>

#include <cmath>

namespace torch {
namespace lazy {

Diagonal::Diagonal(
    const Value& input,
    int64_t offset,
    int64_t dim1,
    int64_t dim2)
    : TsNode(
          ClassOpKind(),
          {input},
          [&]() {
            return MakeDiagonalShape(input.shape(), offset, dim1, dim2);
          },
          /*num_outputs=*/1,
          MHash(offset, dim1, dim2)),
      offset_(offset),
      dim1_(dim1),
      dim2_(dim2) {}

std::string Diagonal::ToString() const {
  std::stringstream ss;
  ss << TsNode::ToString() << ", offset=" << offset_ << ", dim1=" << dim1_
     << ", dim2=" << dim2_;
  return ss.str();
}

} // namespace lazy
} // namespace torch
@@ -1,37 +0,0 @@
#pragma once

#include <torch/csrc/lazy/ts_backend/ts_node.h>

namespace torch {
namespace lazy {

class TORCH_API Diagonal : public TsNode {
 public:
  static OpKind ClassOpKind() {
    return OpKind(at::aten::diagonal);
  }

  Diagonal(const Value& input, int64_t offset, int64_t dim1, int64_t dim2);

  std::string ToString() const override;

  int64_t offset() const {
    return offset_;
  }

  int64_t dim1() const {
    return dim1_;
  }

  int64_t dim2() const {
    return dim2_;
  }

 private:
  int64_t offset_;
  int64_t dim1_;
  int64_t dim2_;
};

} // namespace lazy
} // namespace torch
@@ -1,32 +0,0 @@
#include <torch/csrc/lazy/ts_backend/view_ops/diagonal_view_update.h>

#include <torch/csrc/lazy/core/internal_ops/ltc_ops.h>

namespace torch {
namespace lazy {

DiagonalViewUpdate::DiagonalViewUpdate(
    const Value& target,
    const Value& input,
    int64_t offset,
    int64_t dim1,
    int64_t dim2)
    : TsNode(
          ltc_diagonal_view_update,
          {target, input},
          {target.shape()},
          /*num_outputs=*/1,
          MHash(offset, dim1, dim2)),
      offset_(offset),
      dim1_(dim1),
      dim2_(dim2) {}

std::string DiagonalViewUpdate::ToString() const {
  std::stringstream ss;
  ss << TsNode::ToString() << ", offset=" << offset_ << ", dim1=" << dim1_
     << ", dim2=" << dim2_;
  return ss.str();
}

} // namespace lazy
} // namespace torch
@@ -1,43 +0,0 @@
#pragma once

#include <torch/csrc/lazy/ts_backend/ts_node.h>
#include <lazy/core/internal_ops/ltc_ops.h>

namespace torch {
namespace lazy {

class TORCH_API DiagonalViewUpdate : public TsNode {
 public:
  static OpKind ClassOpKind() {
    return ltc_diagonal_view_update;
  }

  DiagonalViewUpdate(
      const Value& target,
      const Value& input,
      int64_t offset,
      int64_t dim1,
      int64_t dim2);

  std::string ToString() const override;

  int64_t offset() const {
    return offset_;
  }

  int64_t dim1() const {
    return dim1_;
  }

  int64_t dim2() const {
    return dim2_;
  }

 private:
  int64_t offset_;
  int64_t dim1_;
  int64_t dim2_;
};

} // namespace lazy
} // namespace torch
@@ -1,33 +0,0 @@
#include <torch/csrc/lazy/ts_backend/view_ops/narrow.h>

#include <torch/csrc/lazy/core/internal_ops/ltc_ops.h>

namespace torch {
namespace lazy {

Narrow::Narrow(
    const Value& input,
    c10::ArrayRef<int64_t> base_indices,
    c10::ArrayRef<int64_t> sizes)
    : TsNode(
          ClassOpKind(),
          {input},
          /*num_outputs=*/1,
          MHash(base_indices, sizes)),
      base_indices_(base_indices.begin(), base_indices.end()),
      sizes_(sizes.begin(), sizes.end()) {
  addComputedShape([&]() {
    return Shape(operand(0).shape().scalar_type(), sizes);
  });
}

std::string Narrow::ToString() const {
  std::stringstream ss;
  ss << TsNode::ToString() << ", base_indices=("
     << c10::Join(", ", base_indices_) << "), sizes=("
     << c10::Join(", ", sizes_) << ")";
  return ss.str();
}

} // namespace lazy
} // namespace torch
@@ -1,35 +0,0 @@
#pragma once

#include <torch/csrc/lazy/ts_backend/ts_node.h>

namespace torch {
namespace lazy {

class TORCH_API Narrow : public TsNode {
 public:
  static OpKind ClassOpKind() {
    return OpKind(at::aten::narrow);
  }

  Narrow(
      const Value& input,
      c10::ArrayRef<int64_t> base_indices,
      c10::ArrayRef<int64_t> sizes);

  std::string ToString() const override;

  const std::vector<int64_t>& base_indices() const {
    return base_indices_;
  }

  const std::vector<int64_t>& sizes() const {
    return sizes_;
  }

 private:
  std::vector<int64_t> base_indices_;
  std::vector<int64_t> sizes_;
};

} // namespace lazy
} // namespace torch
@@ -1,29 +0,0 @@
#include <torch/csrc/lazy/ts_backend/view_ops/narrow_view_update.h>

#include <torch/csrc/lazy/core/internal_ops/ltc_ops.h>

namespace torch {
namespace lazy {

NarrowViewUpdate::NarrowViewUpdate(
    const Value& input,
    const Value& source,
    c10::ArrayRef<int64_t> base_indices)
    : TsNode(
          ltc_narrow_view_update,
          {input, source},
          /*num_outputs=*/1,
          MHash(base_indices)),
      base_indices_(base_indices.begin(), base_indices.end()) {
  addComputedShape([&]() { return operand(0).shape(); });
}

std::string NarrowViewUpdate::ToString() const {
  std::stringstream ss;
  ss << TsNode::ToString() << ", base_indices=("
     << c10::Join(", ", base_indices_) << ")";
  return ss.str();
}

} // namespace lazy
} // namespace torch
@@ -1,31 +0,0 @@
#pragma once

#include <torch/csrc/lazy/ts_backend/ts_node.h>
#include <lazy/core/internal_ops/ltc_ops.h>

namespace torch {
namespace lazy {

class TORCH_API NarrowViewUpdate : public TsNode {
 public:
  static OpKind ClassOpKind() {
    return ltc_narrow_view_update;
  }

  NarrowViewUpdate(
      const Value& input,
      const Value& source,
      c10::ArrayRef<int64_t> base_indices);

  std::string ToString() const override;

  const std::vector<int64_t>& base_indices() const {
    return base_indices_;
  }

 private:
  std::vector<int64_t> base_indices_;
};

} // namespace lazy
} // namespace torch
@@ -1,28 +0,0 @@
#include <torch/csrc/lazy/core/ops/utils.h>
#include <torch/csrc/lazy/ts_backend/view_ops/permute.h>

#include <torch/csrc/lazy/core/helpers.h>

namespace torch {
namespace lazy {

Permute::Permute(const Value& input, std::vector<int64_t> dims)
    : TsNode(
          ClassOpKind(),
          {input},
          /*num_outputs=*/1,
          MHash(dims)),
      dims_(std::move(dims)) {
  addComputedShape([&]() {
    return MakePermuteShape(operand(0).shape(), dims_);
  });
}

std::string Permute::ToString() const {
  std::stringstream ss;
  ss << TsNode::ToString() << ", dims=(" << c10::Join(", ", dims_) << ")";
  return ss.str();
}

} // namespace lazy
} // namespace torch
@@ -1,28 +0,0 @@
#pragma once

#include <torch/csrc/lazy/ts_backend/ts_node.h>

namespace torch {
namespace lazy {

class TORCH_API Permute : public TsNode {
 public:
  static OpKind ClassOpKind() {
    return OpKind(at::aten::permute);
  }

  Permute(const Value& input, std::vector<int64_t> dims);

  std::string ToString() const override;

  const std::vector<int64_t>& dims() const {
    return dims_;
  }

 private:
  // The permutation of dimensions.
  std::vector<int64_t> dims_;
};

} // namespace lazy
} // namespace torch
@@ -1,29 +0,0 @@
#include <torch/csrc/lazy/ts_backend/view_ops/resize.h>

namespace torch {
namespace lazy {

namespace {
Shape NodeOutputShape(const Value& input, c10::ArrayRef<int64_t> size) {
  return Shape(input.shape().scalar_type(), size);
}

} // namespace

Resize::Resize(const Value& input, std::vector<int64_t> size)
    : TsNode(
          ClassOpKind(),
          {input},
          [&]() { return NodeOutputShape(input, size); },
          /*num_outputs=*/1,
          MHash(size)),
      size_(std::move(size)) {}

std::string Resize::ToString() const {
  std::stringstream ss;
  ss << TsNode::ToString() << ", size=(" << c10::Join(", ", size_) << ")";
  return ss.str();
}

} // namespace lazy
} // namespace torch
@@ -1,27 +0,0 @@
#pragma once

#include <torch/csrc/lazy/ts_backend/ts_node.h>

namespace torch {
namespace lazy {

class TORCH_API Resize : public TsNode {
 public:
  static OpKind ClassOpKind() {
    return OpKind(at::aten::resize);
  }

  Resize(const Value& input, std::vector<int64_t> size);

  std::string ToString() const override;

  const std::vector<int64_t>& size() const {
    return size_;
  }

 private:
  std::vector<int64_t> size_;
};

} // namespace lazy
} // namespace torch
@@ -1,44 +0,0 @@
#include <torch/csrc/lazy/core/ops/utils.h>
#include <torch/csrc/lazy/ts_backend/view_ops/select.h>

#include <torch/csrc/lazy/core/internal_ops/ltc_ops.h>

namespace torch {
namespace lazy {

Select::Select(
    const Value& input,
    int64_t dim,
    int64_t start,
    int64_t end,
    int64_t stride)
    : TsNode(
          ClassOpKind(),
          {input},
          [&]() {
            return MakeSelectShape(input.shape(), dim, start, end, stride);
          },
          /*num_outputs=*/1,
          MHash(dim, start, end, stride)),
      dim_(dim),
      start_(start),
      end_(end),
      stride_(stride) {}

std::string Select::ToString() const {
  std::stringstream ss;
  ss << TsNode::ToString() << ", dim=" << dim_ << ", start=" << start_
     << ", end=" << end_ << ", stride=" << stride_;
  return ss.str();
}

int64_t Select::GetStride(int64_t start, int64_t end, int64_t stride) {
  if (stride == 0) {
    CHECK_EQ(start, end);
    stride = 1;
  }
  return stride;
}

} // namespace lazy
} // namespace torch
@@ -1,49 +0,0 @@
#pragma once

#include <torch/csrc/lazy/ts_backend/ts_node.h>

namespace torch {
namespace lazy {

class TORCH_API Select : public TsNode {
 public:
  static OpKind ClassOpKind() {
    return OpKind(at::aten::select);
  }

  Select(
      const Value& input,
      int64_t dim,
      int64_t start,
      int64_t end,
      int64_t stride);

  std::string ToString() const override;

  int64_t dim() const {
    return dim_;
  }

  int64_t start() const {
    return start_;
  }

  int64_t end() const {
    return end_;
  }

  int64_t stride() const {
    return stride_;
  }

  static int64_t GetStride(int64_t start, int64_t end, int64_t stride);

 private:
  int64_t dim_;
  int64_t start_;
  int64_t end_;
  int64_t stride_;
};

} // namespace lazy
} // namespace torch
@@ -1,36 +0,0 @@
#include <torch/csrc/lazy/ts_backend/view_ops/select_view_update.h>

#include <torch/csrc/lazy/core/internal_ops/ltc_ops.h>
#include <torch/csrc/lazy/core/tensor_util.h>
#include <torch/csrc/lazy/ts_backend/view_ops/select.h>

namespace torch {
namespace lazy {

SelectViewUpdate::SelectViewUpdate(
    const Value& target,
    const Value& source,
    int64_t dim,
    int64_t start,
    int64_t end,
    int64_t stride)
    : TsNode(
          ltc_select_view_update,
          {target, source},
          {target.shape()},
          /*num_outputs=*/1,
          MHash(dim, start, end, stride)),
      dim_(dim),
      start_(start),
      end_(end),
      stride_(stride) {}

std::string SelectViewUpdate::ToString() const {
  std::stringstream ss;
  ss << TsNode::ToString() << ", dim=" << dim_ << ", start=" << start_
     << ", end=" << end_ << ", stride=" << stride_;
  return ss.str();
}

} // namespace lazy
} // namespace torch
@ -1,49 +0,0 @@
#pragma once

#include <torch/csrc/lazy/ts_backend/ts_node.h>
#include <torch/csrc/lazy/core/internal_ops/ltc_ops.h>

namespace torch {
namespace lazy {

class TORCH_API SelectViewUpdate : public TsNode {
 public:
  static OpKind ClassOpKind() {
    return ltc_select_view_update;
  }

  SelectViewUpdate(
      const Value& target,
      const Value& source,
      int64_t dim,
      int64_t start,
      int64_t end,
      int64_t stride);

  std::string ToString() const override;

  int64_t dim() const {
    return dim_;
  }

  int64_t start() const {
    return start_;
  }

  int64_t end() const {
    return end_;
  }

  int64_t stride() const {
    return stride_;
  }

 private:
  int64_t dim_;
  int64_t start_;
  int64_t end_;
  int64_t stride_;
};

} // namespace lazy
} // namespace torch
@ -1,28 +0,0 @@
#include <c10/util/irange.h>
#include <torch/csrc/lazy/core/ops/utils.h>
#include <torch/csrc/lazy/ts_backend/view_ops/squeeze.h>
#include <torch/csrc/lazy/ts_backend/ts_lowering_context.h>

namespace torch {
namespace lazy {

Squeeze::Squeeze(const torch::lazy::Value& input, int dim)
    : torch::lazy::TsNode(ClassOpKind(), {input},
                          /*num_outputs=*/1, torch::lazy::MHash(dim)),
      dim_(dim) {
  addComputedShape(
      [&]() {
        const auto& input_shape = input.shape();
        return torch::lazy::Shape(input_shape.scalar_type(),
            BuildSqueezedDimensions(input_shape.sizes(), dim));
      });
}

std::string Squeeze::ToString() const {
  std::stringstream ss;
  ss << torch::lazy::TsNode::ToString() << ", dim=" << dim_;
  return ss.str();
}

} // namespace lazy
} // namespace torch
@ -1,26 +0,0 @@
#pragma once

#include <torch/csrc/lazy/ts_backend/ts_node.h>

namespace torch {
namespace lazy {

class TORCH_API Squeeze : public TsNode {
 public:
  static OpKind ClassOpKind() {
    return OpKind(at::aten::squeeze);
  }

  // Squeeze out the specified dimension index, -1 for all trivial dimensions.
  Squeeze(const torch::lazy::Value& input, int dim);

  std::string ToString() const override;

  int dim() const { return dim_; }

 private:
  int dim_;
};

} // namespace lazy
} // namespace torch
@ -1,30 +0,0 @@
#include <torch/csrc/lazy/core/ops/utils.h>
#include <torch/csrc/lazy/ts_backend/view_ops/unsqueeze.h>
#include <torch/csrc/lazy/ts_backend/ts_lowering_context.h>

namespace torch {
namespace lazy {

Unsqueeze::Unsqueeze(const torch::lazy::Value& input, int dim)
    : torch::lazy::TsNode(
          ClassOpKind(),
          {input},
          /*num_outputs=*/1,
          torch::lazy::MHash(dim)),
      dim_(dim) {
  addComputedShape([&]() {
    const auto& input_shape = input.shape();
    return torch::lazy::Shape(
        input_shape.scalar_type(),
        BuildUnsqueezedDimensions(input_shape.sizes(), dim));
  });
}

std::string Unsqueeze::ToString() const {
  std::stringstream ss;
  ss << torch::lazy::TsNode::ToString() << ", dim=" << dim_;
  return ss.str();
}

} // namespace lazy
} // namespace torch
@ -1,27 +0,0 @@
#pragma once

#include <torch/csrc/lazy/ts_backend/ts_node.h>

namespace torch {
namespace lazy {

class TORCH_API Unsqueeze : public TsNode {
 public:
  static OpKind ClassOpKind() {
    return OpKind(at::aten::unsqueeze);
  }

  Unsqueeze(const torch::lazy::Value& input, int dim);

  std::string ToString() const override;

  int dim() const {
    return dim_;
  }

 private:
  int dim_;
};

} // namespace lazy
} // namespace torch
@ -1,35 +0,0 @@
#include <torch/csrc/lazy/ts_backend/view_ops/view.h>

#include <ATen/InferSize.h>

namespace torch {
namespace lazy {

namespace {
Shape NodeOutputShape(const Value& input, c10::ArrayRef<int64_t> output_sizes) {
  const Shape& input_shape = input.shape();
  const auto complete_output_sizes =
      at::infer_size(output_sizes, input_shape.numel());
  return Shape(input_shape.scalar_type(), complete_output_sizes);
}

} // namespace

View::View(const Value& input, std::vector<int64_t> output_size)
    : TsNode(
          ClassOpKind(),
          {input},
          {NodeOutputShape(input, output_size)},
          /*num_outputs=*/1,
          MHash(output_size)),
      output_size_(std::move(output_size)) {}

std::string View::ToString() const {
  std::stringstream ss;
  ss << TsNode::ToString() << ", output_size=(" << c10::Join(", ", output_size_)
     << ")";
  return ss.str();
}

} // namespace lazy
} // namespace torch
@ -1,29 +0,0 @@
#pragma once

#include <torch/csrc/lazy/ts_backend/ts_node.h>

#include <vector>

namespace torch {
namespace lazy {

class TORCH_API View : public TsNode {
 public:
  static OpKind ClassOpKind() {
    return OpKind(at::aten::view);
  }

  View(const Value& input, std::vector<int64_t> output_size);

  std::string ToString() const override;

  const std::vector<int64_t>& output_size() const {
    return output_size_;
  }

 private:
  std::vector<int64_t> output_size_;
};

} // namespace lazy
} // namespace torch
@ -1,4 +1,5 @@
from typing import List, Union, Tuple, Optional
from typing import Any, Dict, List, Union, Tuple, Optional

from torchgen.model import (
    Type,
    BaseTy,
@ -56,7 +57,7 @@ tensorListValueT = BaseCppType("torch::lazy", "Value")


def process_ir_type(
    typ: Type,
    typ: Type, properties: "LazyIrProperties"
) -> Union[BaseCType, VectorCType, OptionalCType, ListCType]:
    """
    This function takes a type from NativeFunctions and converts it for use with
@ -77,6 +78,8 @@ def process_ir_type(
    if typ.name == BaseTy.Tensor:
        return BaseCType(getValueT())
    elif typ.name == BaseTy.Scalar:
        if properties.TreatScalarsAsConstants:
            return BaseCType(scalarT)
        # at::scalar has special handling,
        # and is wrapped in an lazy::Value just like at::tensor
        return BaseCType(getValueT())
@ -101,7 +104,7 @@ def process_ir_type(
        else:
            raise AssertionError(f"TODO add support for type {repr(typ)}")
    elif isinstance(typ, OptionalType):
        return OptionalCType(process_ir_type(typ.elem))
        return OptionalCType(process_ir_type(typ.elem, properties))
    elif isinstance(typ, ListType):
        if str(typ.elem) == "Tensor?":
            # TODO(whc) is this actually correct? or should it use a Vector like above
@ -110,12 +113,12 @@ def process_ir_type(
            # this is a TensorList which comes in from GetTensorList as a Value
            return BaseCType(tensorListValueT)
        else:
            return VectorCType(process_ir_type(typ.elem))
            return VectorCType(process_ir_type(typ.elem, properties))
    else:
        raise AssertionError(f"unrecognized type {repr(typ)}")


def isValueType(typ: CType) -> bool:
def isValueType(typ: CType, properties: "Optional[LazyIrProperties]" = None) -> bool:
    """
    Given a type, determine if it is a Value-like type. This is equivalent to
    being Tensor-like, but assumes the type has already been transformed.
@ -123,9 +126,14 @@ def isValueType(typ: CType) -> bool:
    if isinstance(typ, BaseCType):
        # I am regretting my naming conventions, but now we are wrapping at::scalar in
        # lazy value, while preserving other 'scalar' types as scalars in the IR
        return typ.type == getValueT() or typ.type == scalarT or typ.type == SymIntT
        treat_scalars_as_constants = properties and properties.TreatScalarsAsConstants
        return (
            typ.type == getValueT()
            or (typ.type == scalarT and not treat_scalars_as_constants)
            or typ.type == SymIntT
        )
    elif isinstance(typ, (OptionalCType, ListCType, VectorCType)):
        return isValueType(typ.elem)
        return isValueType(typ.elem, properties)
    return False


@ -167,7 +175,7 @@ class LazyArgument:
    # true if this argument is or contains a lazy IR value
    is_lazy_value: bool

    def __init__(self, arg: Argument):
    def __init__(self, arg: Argument, properties: "LazyIrProperties"):
        self.name = arg.name
        self.orig_type = arg.type
        self.is_optional = isinstance(arg.type, OptionalType)
@ -181,11 +189,13 @@ class LazyArgument:
            # its null and safe to exclude from lazy IR
            self.lazy_type_ = None
        else:
            self.lazy_type_ = process_ir_type(arg.type)
            self.lazy_type_ = process_ir_type(arg.type, properties)
        self.is_wrapped_scalar = isWrappedScalarType(arg.type)
        self.is_symint_or_list = isSymIntType(arg.type)

        self.is_lazy_value = not self.is_generator and isValueType(self.lazy_type)
        self.is_lazy_value = not self.is_generator and isValueType(
            self.lazy_type, properties
        )

    @property
    def lazy_type(self) -> CType:
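The net effect of threading `properties` through `process_ir_type` and `isValueType` is easiest to see on a `Scalar` argument. A minimal self-contained sketch of that branching (stand-in types instead of the real torchgen `CType` hierarchy; all names here are illustrative, not the actual API):

```
from dataclasses import dataclass

@dataclass(frozen=True)
class BaseCType:
    type: str  # stand-in for a BaseCppType

valueT = "torch::lazy::Value"   # what getValueT() returns
scalarT = "at::Scalar"

def is_value_type(typ, properties=None):
    treat_scalars_as_constants = bool(properties and properties.get("TreatScalarsAsConstants"))
    return typ.type == valueT or (typ.type == scalarT and not treat_scalars_as_constants)

# By default a Scalar is wrapped in a lazy Value (value-like); with
# TreatScalarsAsConstants it stays a plain at::Scalar field on the node.
assert is_value_type(BaseCType(scalarT))
assert not is_value_type(BaseCType(scalarT), {"TreatScalarsAsConstants": True})
assert is_value_type(BaseCType(valueT), {"TreatScalarsAsConstants": True})
```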
@ -195,6 +205,64 @@ class LazyArgument:
        return self.lazy_type_


class LazyIrProperties:
    """Collection of properties for an IR node

    The property groups are listed below. Each group is mutually
    exclusive, meaning that only one property from each group can be True
    at any one time. The properties can be accessed as if they were normal
    attributes. The mutual exclusivity is automatically handled.
    """

    Properties: Tuple[Tuple[str, ...], ...] = (
        (
            "ShapePrecompute",  # Assume shape has been precomputed
            "ShapeCompute",  # Need to compute the shape on construction
            "ShapeCache",  # Utilize the shape cache to defer computation
        ),
        (
            "Lower",  # Codegen full lower function
            "LowerDeclOnly",  # Codegen only lower function declaration
        ),
        (
            "CanBeReused",  # Codegen full reuse function
            "CanBeReusedDeclOnly",  # Codegen only reuse function declaration
        ),
        (
            "CreateFn",  # Codegen full create function
            "CreateFnDeclOnly",  # Codegen only create function declaration
        ),
        (
            "TreatScalarsAsConstants",  # Treat Scalars as constants instead of handling like values
        ),
    )

    def __init__(self, *default_properties: str):
        properties: Dict[Tuple[str, ...], Optional[str]] = {
            p: None for p in LazyIrProperties.Properties
        }
        self.__dict__["properties"] = properties
        for p in default_properties:
            setattr(self, p, True)

    def __getattr__(self, key: str) -> Any:
        properties = self.__dict__["properties"]
        for values in LazyIrProperties.Properties:
            if key in values:
                return properties[values] == key

        return self.__getattribute__(key)

    def __setattr__(self, key: str, value: Any) -> Any:
        properties = self.__dict__["properties"]
        for values in LazyIrProperties.Properties:
            if key in values:
                properties[values] = key if value else None
                return value

        raise KeyError(f"Invalid property: {key}")


# Inspired by a FunctionSchema object, a LazyIrSchema holds the schema of a Lazy IR node.
# Unlike a FunctionSchema, it has no round-trippable string form (relating to the YAML),
# but carries type information from a native FunctionSchema modified for use with IR nodes,
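A short usage sketch of the mutual exclusivity the docstring describes, assuming the class above is importable from `torchgen.api.lazy` as shown:

```
from torchgen.api.lazy import LazyIrProperties

props = LazyIrProperties("ShapePrecompute", "Lower")
assert props.ShapePrecompute       # default set by the constructor
assert not props.ShapeCache        # same group, so currently unset

props.ShapeCache = True            # setting one member of a group...
assert not props.ShapePrecompute   # ...automatically clears the active one
assert props.ShapeCache

props.Lower = False                # assigning False clears the whole group
assert not props.Lower and not props.LowerDeclOnly
```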
@ -213,20 +281,33 @@ class LazyIrSchema:
    # build a LazyArgument since lazy IR doesn't support it
    generator_arg: Optional[NamedCType] = None

    def __init__(self, func: FunctionSchema):
    properties: LazyIrProperties = LazyIrProperties(
        # default properties
        "ShapePrecompute",
        "Lower",
        "CanBeReused",
    )
    opkind: Optional[str] = None

        positional_args = []
    def __init__(
        self, func: FunctionSchema, properties: Optional[LazyIrProperties] = None
    ):
        if properties:
            self.properties = properties

        positional_args: List[LazyArgument] = []
        for arg_field in ["pre_self_positional", "self_arg", "post_self_positional"]:
            if arg_field == "self_arg" and func.arguments.self_arg is not None:
                arg = getattr(func.arguments, "self_arg").argument
                positional_args.append(LazyArgument(arg))
                positional_args.append(LazyArgument(arg, self.properties))
            elif getattr(func.arguments, arg_field) is not None:
                positional_args.extend(
                    [LazyArgument(arg) for arg in getattr(func.arguments, arg_field)]
                    LazyArgument(arg, self.properties)
                    for arg in getattr(func.arguments, arg_field)
                )
        self.positional_args = tuple(positional_args)

        keyword_args = []
        keyword_args: List[LazyArgument] = []
        for arg_field in [
            "pre_tensor_options_kwarg_only",
            "tensor_options",
@ -243,7 +324,9 @@ class LazyIrSchema:
                self.generator_arg is None
            ), "We expect there is only one generator arg"
            self.generator_arg = NamedCType(arg.name, arg.type)
            keyword_args.extend([LazyArgument(arg) for arg in curr_args])
            keyword_args.extend(
                LazyArgument(arg, self.properties) for arg in curr_args
            )
        self.keyword_args = tuple(keyword_args)
        self.name = func.name
        self.returns = func.returns
@ -262,7 +345,7 @@ class LazyIrSchema:

    @property
    def aten_name(self) -> str:
        return f"{self.name.name}"
        return str(self.name.name)

    @property
    def base_name(self) -> str:
@ -1,6 +1,9 @@
from .lazy_ir import GenLazyIR as GenLazyIR
from .lazy_ir import GenLazyShapeInferenceDefinition as GenLazyShapeInferenceDefinition
from .lazy_ir import GenLazyNativeFuncDefinition as GenLazyNativeFuncDefinition
from .lazy_ir import (
    generate_non_native_lazy_ir_nodes as generate_non_native_lazy_ir_nodes,
)
from .register_dispatch_key import (
    RegisterDispatchKey as RegisterDispatchKey,
    gen_registration_helpers as gen_registration_helpers,
@ -1,8 +1,13 @@
from abc import ABC
from typing import List, Optional, Union
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union
from torchgen.context import method_with_native_function
from torchgen.model import BackendIndex, NativeFunction, NativeFunctionsGroup
from torchgen.model import (
    BackendIndex,
    NativeFunction,
    NativeFunctionsGroup,
    FunctionSchema,
)
from torchgen.api.types import (
    BaseCType,
    OptionalCType,
@ -12,6 +17,7 @@ from torchgen.api.types import (
)
import torchgen.api.dispatcher as dispatcher
from torchgen.api.lazy import (
    LazyIrProperties,
    LazyIrSchema,
    LazyArgument,
    getValueT,
@ -108,36 +114,44 @@ def aten_symbol(schema: LazyIrSchema) -> str:
    }
    if schema.aten_name in missing_interned_strings:
        return f'c10::Symbol::fromQualString("aten::{schema.aten_name}")'
    return f"at::aten::{schema.aten_name}"

    if not schema.aten_name.startswith("at::"):
        return f"at::aten::{schema.aten_name}"
    else:
        return schema.aten_name


@dataclass(frozen=True)
class GenLazyIR(ABC):
    backend_index: BackendIndex
    backend_name: str
    node_base: str

    @method_with_native_function
    def __call__(self, f: Union[NativeFunctionsGroup, NativeFunction]) -> List[str]:
        func = f.functional.func if isinstance(f, NativeFunctionsGroup) else f.func
        return self.gen(f)
        schema = LazyIrSchema(func)
        return self.gen(schema)

    # there is no lowering functionality generated unless this IR base class is subclassed and
    # implemented as a backend-specific node
    def lowering_function(self, f: Union[NativeFunctionsGroup, NativeFunction]) -> str:
    def lowering_function(self, schema: LazyIrSchema) -> str:
        return ""

    def can_be_reused_function(
        self, f: Union[NativeFunctionsGroup, NativeFunction], node_ctor_args: str
    ) -> str:
    def create_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str:
        return ""

    def can_be_reused_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str:
        return f"""bool CanBeReused({node_ctor_args}) const {{
    return false;
  }}"""

    def node_base_ctor_call(self, schema: LazyIrSchema) -> str:
        value_args = schema.filtered_args(values=True, scalars=False)
        # backends can customize the way the node base class constructor is called,
        # as long as all of its arguments can be generated from information available from the schema
        base_ctor_value_args_list = []
        for arg in schema.filtered_args(values=True, scalars=False):
        for arg in value_args:
            if isinstance(arg.lazy_type, BaseCType) or isinstance(
                arg.lazy_type, VectorCType
            ):
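The new fall-through branch lets a name that is already fully qualified pass through unchanged instead of being double-prefixed. A self-contained restatement of the logic above (the interned-strings set here is an illustrative subset, not the real contents):

```
def aten_symbol_sketch(aten_name: str) -> str:
    # Names missing from interned strings go through fromQualString;
    # plain ATen names get the at::aten:: prefix; qualified names pass through.
    missing_interned_strings = {"sigmoid_backward"}  # illustrative subset
    if aten_name in missing_interned_strings:
        return f'c10::Symbol::fromQualString("aten::{aten_name}")'
    if not aten_name.startswith("at::"):
        return f"at::aten::{aten_name}"
    return aten_name

assert aten_symbol_sketch("squeeze") == "at::aten::squeeze"
# Already-qualified names pass through unchanged:
assert aten_symbol_sketch("at::prim::Constant") == "at::prim::Constant"
```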
@ -151,29 +165,51 @@ class GenLazyIR(ABC):
        base_ctor_value_args = ", ".join(base_ctor_value_args_list)

        scalar_args = schema.filtered_args(values=False, scalars=True)
        scalar_hashes = ", ".join([f"{a.name}" for a in scalar_args])

        return f"""{self.node_base}(torch::lazy::OpKind({aten_symbol(schema)}),
      {{{base_ctor_value_args}}}, std::move(shapes),
        # Shape construction.
        # Conditionally build shape depending on specified shape property
        if schema.properties.ShapePrecompute:
            shape_ctor_arg = "std::move(shapes),"
        elif schema.properties.ShapeCompute:
            shape_args = [a.name for a in value_args]
            shape_args.extend(a.name for a in scalar_args)
            shape_ctor_arg = f"compute_shape_{schema.name}({', '.join(shape_args)}),"
        elif schema.properties.ShapeCache:
            shape_args = [f"operand({i})" for i in range(len(value_args))]
            shape_args.extend(a.name for a in scalar_args)
            shape_ctor_arg = f"[&](){{ return compute_shape_{schema.name}({', '.join(shape_args)})[0]; }},"
        else:
            shape_ctor_arg = ""

        scalar_hashes = ", ".join(f"{a.name}" for a in scalar_args)

        return f"""{self.node_base}(
              {schema.node_name}::ClassOpKind(),
              OpList{{{base_ctor_value_args}}},
              {shape_ctor_arg}
              /* num_outputs */ {len(schema.returns)},
              torch::lazy::MHash({scalar_hashes}))"""

    def gen(self, f: Union[NativeFunctionsGroup, NativeFunction]) -> List[str]:
    def gen(self, schema: LazyIrSchema) -> List[str]:
        opkind = schema.opkind or aten_symbol(schema)

        # for now, we just want one IR class decl and soon after also the method defs
        # and we use the functional version not out/inplace.
        func = f.functional.func if isinstance(f, NativeFunctionsGroup) else f.func
        schema = LazyIrSchema(func)
        all_args = schema.filtered_args()
        value_args = schema.filtered_args(values=True, scalars=False)
        scalar_args = schema.filtered_args(values=False, scalars=True)

        node_ctor_args = ", ".join(
            [f"const {i.lazy_type.cpp_type()}& {i.name}" for i in all_args]
        )
        ctor_args = [f"const {i.lazy_type.cpp_type()}& {i.name}" for i in all_args]
        reuse_ctor_args = ", ".join(ctor_args)
        if schema.properties.ShapePrecompute:
            ctor_args.append("std::vector<torch::lazy::Shape>&& shapes")
        node_ctor_args = ", ".join(ctor_args)

        scalar_initializers = ",\n ".join(
            [f"{a.name}({a.name})" for a in scalar_args]
            f"{a.name}({a.name})" for a in scalar_args
        )
        comma_if_scalar_initializers = ",\n" if len(scalar_initializers) else ""
        if len(scalar_initializers):
            scalar_initializers = f",\n {scalar_initializers}"
        scalar_decls = "\n ".join(
            [
                f"std::string {a.name};"
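For a hypothetical schema `my_op(Tensor input, int dim) -> Tensor` (value args: `input`; scalar args: `dim`), the three shape properties produce the following constructor arguments. A sketch that mirrors the string-building above; `my_op` is a made-up name:

```
value_args = ["input"]
scalar_args = ["dim"]
name = "my_op"

precompute = "std::move(shapes),"  # ShapePrecompute: caller passes shapes in
compute = f"compute_shape_{name}(" + ", ".join(value_args + scalar_args) + "),"
cache = (
    "[&](){ return " + f"compute_shape_{name}("
    + ", ".join([f"operand({i})" for i in range(len(value_args))] + scalar_args)
    + ")[0]; },"
)

assert compute == "compute_shape_my_op(input, dim),"
assert cache == "[&](){ return compute_shape_my_op(operand(0), dim)[0]; },"
```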
@ -212,14 +248,11 @@ class GenLazyIR(ABC):
class {schema.node_name} : public {self.node_base} {{
 public:
  static torch::lazy::OpKind ClassOpKind() {{
    return torch::lazy::OpKind({aten_symbol(schema)});
    return torch::lazy::OpKind({opkind});
  }}

  {schema.node_name}({node_ctor_args}, std::vector<torch::lazy::Shape>&& shapes)

      : {self.node_base_ctor_call(schema)}{comma_if_scalar_initializers}
        {scalar_initializers}

  {schema.node_name}({node_ctor_args})
      : {self.node_base_ctor_call(schema)}{scalar_initializers}
  {{
    {has_optional_defs}
  }}
@ -231,9 +264,11 @@ class {schema.node_name} : public {self.node_base} {{
    return ss.str();
  }}

  {self.can_be_reused_function(f, node_ctor_args)}
  {self.create_function(schema, reuse_ctor_args)}

  {self.lowering_function(f)}
  {self.can_be_reused_function(schema, reuse_ctor_args)}

  {self.lowering_function(schema)}

  {scalar_decls}
  {has_optional_decls}
@ -246,37 +281,57 @@ class {schema.node_name} : public {self.node_base} {{

@dataclass(frozen=True)
class GenTSLazyIR(GenLazyIR):
    def lowering_function(self, f: Union[NativeFunctionsGroup, NativeFunction]) -> str:
        return f"""torch::lazy::TSOpVector Lower(std::shared_ptr<torch::jit::GraphFunction> function,
    torch::lazy::TSLoweringContext* loctx) const override {{
    {ts_lowering_body(f)}
    def lowering_function(self, schema: LazyIrSchema) -> str:
        signature = """
  torch::lazy::TSOpVector Lower(
      std::shared_ptr<torch::jit::GraphFunction> function,
      torch::lazy::TSLoweringContext* loctx) const override"""

        if schema.properties.LowerDeclOnly:
            return f"{signature};"
        elif schema.properties.Lower:
            return f"""{signature} {{
    {ts_lowering_body(schema)}
  }}
  """
        else:
            return ""

    def create_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str:
        signature = f"static NodePtr Create({node_ctor_args})"
        if schema.properties.CreateFnDeclOnly:
            return f"{signature};"
        elif not schema.properties.CreateFn:
            return ""
        return f"""{signature} {{
    return ReuseOrMakeNode<{schema.node_name}>(data);
  }}"""

    def can_be_reused_function(
        self, f: Union[NativeFunctionsGroup, NativeFunction], node_ctor_args: str
    ) -> str:
        func = f.functional.func if isinstance(f, NativeFunctionsGroup) else f.func
        schema = LazyIrSchema(func)

        value_comparsion = []
    def can_be_reused_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str:
        signature = f"bool CanBeReused({node_ctor_args}) const"
        if schema.properties.CanBeReusedDeclOnly:
            return f"{signature};"
        elif not schema.properties.CanBeReused:
            return ""
        value_comparison = []
        for arg in schema.positional_values:
            if isinstance(arg.lazy_type, OptionalCType):
                value_comparsion.append(
                value_comparison.append(
                    f"operand(i++) == {arg.name}.value_or(kNullValue)"
                )
            else:
                value_comparsion.append(f"operand(i++) == {arg.name}")
                value_comparison.append(f"operand(i++) == {arg.name}")
        for arg in schema.positional_scalars:
            value_comparsion.append(f"this->{arg.name} == {arg.name}")
            value_comparison.append(f"this->{arg.name} == {arg.name}")
        for arg in schema.keyword_values:
            value_comparsion.append(f"operand(i++) == {arg.name}")
            value_comparison.append(f"operand(i++) == {arg.name}")
        for arg in schema.keyword_scalars:
            value_comparsion.append(f"this->{arg.name} == {arg.name}")
            value_comparison.append(f"this->{arg.name} == {arg.name}")
        value_comparsion_str = " &&\n ".join(value_comparsion)
        value_comparison_str = " &&\n ".join(value_comparison)

        return f"""bool CanBeReused({node_ctor_args}) const {{
        return f"""{signature} {{
    size_t i = 0;
    return ({value_comparsion_str});
    return ({value_comparison_str});
  }}"""
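All three `GenTSLazyIR` emitters share the same full / declaration-only / omitted gating driven by the property groups. Reduced to its skeleton (a sketch of the shared pattern, not the real methods):

```
def emit(signature: str, full: bool, decl_only: bool, body: str) -> str:
    if decl_only:
        return f"{signature};"   # declaration for the backend to implement
    if not full:
        return ""                # property cleared: emit nothing at all
    return f"{signature} {{\n{body}\n}}"

assert emit("bool CanBeReused(...) const", False, True, "") == "bool CanBeReused(...) const;"
assert emit("bool CanBeReused(...) const", True, False, "  return false;").endswith("}")
assert emit("bool CanBeReused(...) const", False, False, "") == ""
```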
@ -399,7 +454,7 @@ class GenLazyNativeFuncDefinition:
            shape_str += f"""
            if(torch::lazy::symbolicShapeEnabled()){{
                std::vector<torch::jit::IValue> inputs = {{ {', '.join(str(a.name) for a in all_args)} }};
                char* schema_str = "{func_schema_str}";
                const char* schema_str = "{func_schema_str}";
                applySymbolicShapesOnLT(schema_str, inputs, shapes);
            }}
        """
@ -523,3 +578,21 @@ class GenLazyShapeInferenceDefinition:
            return ["\n".join([f"{shape_sig.shape_decl};"])]
        else:
            return []


def generate_non_native_lazy_ir_nodes(
    non_native: List[Dict[str, Any]], gen_lazy_ir: GenLazyIR
) -> List[str]:
    """Generate the non-native lazy IR node classes"""
    nodes = []
    for op in non_native:
        # Set default properties for Non-Native IRs
        properties = LazyIrProperties("ShapeCache")
        for p in op.get("properties", []):
            setattr(properties, p, True)

        schema = LazyIrSchema(FunctionSchema.parse(op["func"]), properties)
        schema.opkind = op.get("opkind")
        nodes.append(gen_lazy_ir.gen(schema)[0])

    return nodes
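A sketch of how one entry flows through that loop; `my_cast` and `ltc_my_cast` are made-up names for illustration, not entries from the real YAML:

```
op = {
    "func": "my_cast(Tensor input, ScalarType dtype) -> Tensor",
    "opkind": "ltc_my_cast",          # optional; omitted -> at::aten::<name>
    "properties": ["ShapeCompute"],   # optional; overrides the ShapeCache default
}

# Mirrors the property handling above: start from the ShapeCache default,
# then apply per-op overrides (ShapeCompute displaces ShapeCache -- same group).
active = {"ShapeCache"}
for p in op.get("properties", []):
    if p in ("ShapePrecompute", "ShapeCompute", "ShapeCache"):
        active -= {"ShapePrecompute", "ShapeCompute", "ShapeCache"}
    active.add(p)
assert active == {"ShapeCompute"}
```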
@ -1,15 +1,10 @@
from typing import Union
from torchgen.model import NativeFunction, NativeFunctionsGroup
from torchgen.api.lazy import LazyIrSchema
from torchgen.api.types import OptionalCType


def ts_lowering_body(f: Union[NativeFunctionsGroup, NativeFunction]) -> str:
def ts_lowering_body(schema: LazyIrSchema) -> str:
    # for now, we just want one IR class decl and soon after also the method defs
    # and we use the functional version not out/inplace.
    func = f.functional.func if isinstance(f, NativeFunctionsGroup) else f.func
    schema = LazyIrSchema(func)

    emplace_arguments = []
    for arg in schema.positional_args:
        if arg.is_lazy_value:
@ -47,7 +42,7 @@ def ts_lowering_body(f: Union[NativeFunctionsGroup, NativeFunction]) -> str:
    {emplace_arguments_str}
    {emplace_kwarguments}
    torch::lazy::TSOpVector {schema.aten_name}_out = torch::lazy::LowerTSBuiltin(function, op().op, arguments, kwarguments);
    CHECK_EQ({schema.aten_name}_out.size(), {len(func.returns)});
    CHECK_EQ({schema.aten_name}_out.size(), {len(schema.returns)});

    return {schema.aten_name}_out;
"""
@ -61,6 +61,7 @@ def parse_backend_yaml(
        "supported",
        "autograd",
        "full_codegen",
        "non_native",
    ]

    backend = yaml_values.pop("backend", None)
@ -98,6 +99,9 @@ def parse_backend_yaml(
    full_codegen = yaml_values.pop("full_codegen", [])
    supported.extend(full_codegen)

    # non_native is ignored by parse_backend_yaml, and re-parsed in gen_lazy_tensor.py
    non_native = yaml_values.pop("non_native", {})

    assert (
        len(yaml_values.keys()) == 0
    ), f'{backend_yaml_path} contains unexpected keys: {", ".join(yaml_values.keys())}. \
@ -5,8 +5,10 @@ import re
import yaml
from collections import namedtuple, Counter
from typing import (
    Any,
    List,
    Dict,
    Tuple,
    Union,
    Sequence,
    Optional,
@ -106,10 +108,10 @@ ParsedExternalYaml = namedtuple(
)


def parse_full_codegen_ops(
def parse_native_functions_keys(
    backend_yaml_path: str,
    grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
) -> List[OperatorName]:
) -> Tuple[List[OperatorName], List[Any]]:

    native_functions_map: Dict[OperatorName, NativeFunction] = {
        f.func.name: f
@ -124,12 +126,10 @@ def parse_full_codegen_ops(
    assert isinstance(yaml_values, dict)

    full_codegen = yaml_values.pop("full_codegen", [])
    assert isinstance(
        full_codegen, list
    ), f'expected "full_codegen" to be a list, but got: {full_codegen}'
    full_codegen = [OperatorName.parse(name) for name in full_codegen]

    return full_codegen
    non_native = yaml_values.pop("non_native", [])
    assert isinstance(full_codegen, list)
    assert isinstance(non_native, list)
    return [OperatorName.parse(name) for name in full_codegen], non_native


def validate_shape_inference_header(
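A sketch of the key-popping the renamed function now performs, on a hypothetical backend YAML fragment (`my_op` and `my_cast` are made-up names):

```
import yaml  # pyyaml, already a torchgen dependency

yaml_values = yaml.safe_load("""
full_codegen:
- my_op
non_native:
- func: my_cast(Tensor input) -> Tensor
""")

full_codegen = yaml_values.pop("full_codegen", [])
non_native = yaml_values.pop("non_native", [])
assert isinstance(full_codegen, list) and isinstance(non_native, list)
# parse_native_functions_keys now returns the pair
# ([OperatorName.parse(n) for n in full_codegen], non_native)
assert non_native == [{"func": "my_cast(Tensor input) -> Tensor"}]
```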
@ -150,13 +150,16 @@ def validate_shape_inference_header(
    )
    # TODO(whc) add a check for shape inference functions that have meta kernels implement and should be retired.

    for decl in expected_shape_infr_decls:
        assert (
            decl in shape_infr_decl_lines
        ), f"""Missing shape inference function.\n
    missing_decls = [
        decl for decl in expected_shape_infr_decls if decl not in shape_infr_decl_lines
    ]
    if missing_decls:
        raise Exception(
            f"""Missing shape inference function.\n
Please add declare this function in {shape_inference_hdr}:\n
and implement it in the the corresponding shape_inference.cpp file.\n
{decl}"""
{os.linesep.join(missing_decls)}"""
        )


class default_args:
@ -324,7 +327,9 @@ def run_gen_lazy_tensor(
    autograd_key = parsed_backend_yaml.autograd_key
    cpp_namespace = parsed_backend_yaml.cpp_namespace
    backend_indices = parsed_backend_yaml.backend_indices
    full_codegen = parse_full_codegen_ops(source_yaml, grouped_native_functions)
    full_codegen, non_native = parse_native_functions_keys(
        source_yaml, grouped_native_functions
    )

    def concat_map_codegen(
        func: Callable[[NativeFunction], Sequence[str]],
@ -476,6 +481,10 @@ def run_gen_lazy_tensor(
        },
    )
    # Generate IR node classes
    lazy_ir_obj = lazy_ir_generator(
        backend_indices[backend_key], backend_name, node_base
    )

    fm.write_with_template(
        "LazyIr.h",
        "LazyIr.h",
@ -492,16 +501,35 @@ def run_gen_lazy_tensor(
                    "vector",
                ]
            ],
            "lazy_ir_inc": [
                f'#include "{path}"'
                for path in [node_base_hdr if node_base_hdr is not None else None]
                if path is not None
            ],
            "lazy_ir_inc": [f'#include "{node_base_hdr}"']
            if node_base_hdr is not None
            else [],
            "ir_declarations": list(
                concat_map_codegen(
                    lazy_ir_generator(backend_indices[backend_key], node_base),
                    grouped_native_functions,
                )
                concat_map_codegen(lazy_ir_obj, grouped_native_functions)
            ),
            "namespace_prologue": ns_helper.prologue,
            "namespace_epilogue": ns_helper.epilogue,
        },
    )

    # Generate Non Native IR Node classes
    fm.write_with_template(
        "LazyNonNativeIr.h",
        "LazyNonNativeIr.h",
        lambda: {
            "lazy_non_native_ir_inc": [
                f"#include <{path}>"
                for path in [
                    "torch/csrc/lazy/core/ir.h",
                    "torch/csrc/lazy/core/ir_builder.h",
                    "torch/csrc/lazy/core/internal_ops/ltc_ops.h",
                    "torch/csrc/lazy/core/shape_inference.h",
                ]
                + ([node_base_hdr] if node_base_hdr else [])
                if path
            ],
            "non_native_ir_nodes": dest.generate_non_native_lazy_ir_nodes(
                non_native, lazy_ir_obj
            ),
            "namespace_prologue": ns_helper.prologue,
            "namespace_epilogue": ns_helper.epilogue,
@ -1072,18 +1072,14 @@ class FunctionSchema:
            self.arguments.out,
        )

    decl_re = re.compile(r"(?P<name>[^\(]+)\((?P<args>.*)\) -> (?P<returns>.*)")

    @staticmethod
    def parse(func: str) -> "FunctionSchema":
        # We should probably get a proper parser here
        assert (
            " -> " in func
        ), "function schema missing return type (spaces are mandatory)"
        last_index = func.rfind(" -> ")
        func_decl = func[:last_index]
        return_decl = func[last_index + len(" -> ") :]
        ops, args = func_decl.split("(", 1)
        assert args[-1] == ")", "Expecting closing )"
        args = args[:-1]
        decls = FunctionSchema.decl_re.findall(func)
        assert len(decls) == 1, f"Invalid function schema: {func}"
        ops, args, return_decl = decls[0]
        name = OperatorName.parse(ops)
        arguments = Arguments.parse(args)
        returns = parse_returns(return_decl)
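The compiled pattern makes the schema split declarative. A quick demonstration on a hypothetical schema string (`my_op.overload` is a made-up name):

```
import re

decl_re = re.compile(r"(?P<name>[^\(]+)\((?P<args>.*)\) -> (?P<returns>.*)")

decls = decl_re.findall("my_op.overload(Tensor self, int dim) -> Tensor")
assert decls == [("my_op.overload", "Tensor self, int dim", "Tensor")]

# Unlike the old rfind/split logic, a string without " -> " simply
# yields no match instead of tripping a chain of asserts:
assert decl_re.findall("my_op(Tensor self)") == []
```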