Autogen native_batch_norm and native_batch_norm_backward (#79637)

This PR makes the `native_batch_norm` and `native_batch_norm_backward` ops autogen, and implements their respective shape inference functions.

Previously, these two ops were manually implemented.

cc: @ke1337 @antoniojkim @wconstab @desertfire
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79637
Approved by: https://github.com/Gamrix, https://github.com/desertfire
This commit is contained in:
Henry Tu
2022-06-24 23:29:33 +00:00
committed by PyTorch MergeBot
parent fdd3e20935
commit 02093da36c
12 changed files with 78 additions and 516 deletions

View File

@ -85,6 +85,8 @@ full_codegen:
- mm
- mul.Tensor
- mv
- native_batch_norm
- native_batch_norm_backward
- native_dropout
- native_dropout_backward
- native_layer_norm
@ -153,8 +155,6 @@ supported:
- expand
- fill_.Scalar
- narrow
- native_batch_norm
- native_batch_norm_backward
- normal_
- max_pool3d_with_indices
- max_pool3d_with_indices_backward

View File

@ -420,7 +420,6 @@ lazy_tensor_core_sources = [
lazy_tensor_ts_sources = [
"torch/csrc/lazy/ts_backend/dynamic_ir.cpp",
"torch/csrc/lazy/ts_backend/config.cpp",
"torch/csrc/lazy/ts_backend/ops/batch_norm_ops.cpp",
"torch/csrc/lazy/ts_backend/ops/device_data.cpp",
"torch/csrc/lazy/ts_backend/ops/random_ops.cpp",
"torch/csrc/lazy/ts_backend/ops/generic.cpp",

View File

@ -121,7 +121,7 @@ class TestLazyReuseIr(TestCase):
torch._lazy.mark_step()
torch.testing.assert_close(z.cpu(), z_lazy.cpu())
assert metrics.counter_value("IrNodeReused_torch::lazy::TSNativeBatchNormForward") >= 7
assert metrics.counter_value("IrNodeReused_torch::lazy::NativeBatchNorm") >= 7
metrics.reset()
torch._lazy.ir_cache.reset()

View File

@ -140,9 +140,11 @@ Shape Node::computeShape(const std::function<Shape()>& shape_fn) {
const std::vector<Output>& Node::operands() const {
return operands_as_outputs_;
}
const Output& Node::operand(size_t i) const {
return operands_as_outputs_.at(i);
}
const Output& Node::nullable_operand(size_t i) const {
// We use kNullOutput instead of kNullValue here to avoid implicit casting,
// which would prevent this method from returning a reference.

View File

@ -470,6 +470,77 @@ std::vector<Shape> compute_shape_cat(at::TensorList tensors, int64_t dim) {
return {Shape(tensors[0].scalar_type(), out_shape)};
}
std::vector<torch::lazy::Shape> compute_shape_native_batch_norm(
const at::Tensor& input,
const c10::optional<at::Tensor>& weight,
const c10::optional<at::Tensor>& bias,
const c10::optional<at::Tensor>& running_mean,
const c10::optional<at::Tensor>& running_var,
bool training,
double momentum,
double eps) {
std::vector<torch::lazy::Shape> shapes;
shapes.reserve(3);
shapes.emplace_back(input.scalar_type(), input.sizes().vec());
// A separate mean and var needs to be kept for each channel.
TORCH_CHECK(
input.sizes().size() >= 2,
"Input tensor must have at least batch and channel dimensions!");
int64_t num_features = input.size(1);
if (running_mean.has_value()) {
shapes.emplace_back(
running_mean.value().scalar_type(), running_mean.value().sizes().vec());
} else {
shapes.emplace_back(
at::get_default_dtype_as_scalartype(),
std::vector<int64_t>{num_features});
}
if (running_var.has_value()) {
shapes.emplace_back(
running_var.value().scalar_type(), running_var.value().sizes().vec());
} else {
shapes.emplace_back(
at::get_default_dtype_as_scalartype(),
std::vector<int64_t>{num_features});
}
return shapes;
}
std::vector<torch::lazy::Shape> compute_shape_native_batch_norm_backward(
const at::Tensor& grad_out,
const at::Tensor& input,
const c10::optional<at::Tensor>& weight,
const c10::optional<at::Tensor>& running_mean,
const c10::optional<at::Tensor>& running_var,
const c10::optional<at::Tensor>& save_mean,
const c10::optional<at::Tensor>& save_invstd,
bool train,
double eps,
::std::array<bool, 3> output_mask) {
std::vector<torch::lazy::Shape> shapes;
shapes.reserve(3);
shapes.emplace_back(input.scalar_type(), input.sizes().vec());
// A separate mean and var needs to be kept for each channel.
TORCH_CHECK(
input.sizes().size() >= 2,
"Input tensor must have at least batch and channel dimensions!");
int64_t num_features = input.size(1);
// `weight` and `bias` are vectors of length C (number of channels)`
shapes.emplace_back(
at::get_default_dtype_as_scalartype(),
std::vector<int64_t>{num_features});
shapes.emplace_back(
at::get_default_dtype_as_scalartype(),
std::vector<int64_t>{num_features});
return shapes;
}
std::vector<Shape> compute_shape_native_layer_norm(
const at::Tensor& input,
at::IntArrayRef normalized_shape,

View File

@ -50,6 +50,8 @@ TORCH_API std::vector<torch::lazy::Shape> compute_shape_max(const at::Tensor & s
TORCH_API std::vector<torch::lazy::Shape> compute_shape_mean(const at::Tensor & self, c10::optional<at::ScalarType> dtype);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_min(const at::Tensor & self);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_mv(const at::Tensor & self, const at::Tensor & vec);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_native_batch_norm(const at::Tensor & input, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, const c10::optional<at::Tensor> & running_mean, const c10::optional<at::Tensor> & running_var, bool training, double momentum, double eps);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_native_batch_norm_backward(const at::Tensor & grad_out, const at::Tensor & input, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & running_mean, const c10::optional<at::Tensor> & running_var, const c10::optional<at::Tensor> & save_mean, const c10::optional<at::Tensor> & save_invstd, bool train, double eps, ::std::array<bool,3> output_mask);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_native_dropout(const at::Tensor & input, double p, c10::optional<bool> train);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_native_dropout_backward(const at::Tensor & grad_output, const at::Tensor & mask, double scale);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_native_layer_norm(const at::Tensor & input, at::IntArrayRef normalized_shape, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, double eps);

View File

@ -1,97 +0,0 @@
#include <torch/csrc/lazy/core/util.h>
#include <torch/csrc/lazy/ts_backend/ops/batch_norm_ops.h>
namespace torch {
namespace lazy {
TSNativeBatchNormBackward::TSNativeBatchNormBackward(
const torch::lazy::Value& grad_out,
const torch::lazy::Value& input,
const torch::lazy::Value& weight,
const torch::lazy::Value& running_mean,
const torch::lazy::Value& running_var,
const torch::lazy::Value& save_mean,
const torch::lazy::Value& save_invstd,
bool training,
double eps,
std::array<bool, 3> output_mask)
: torch::lazy::TsNode(
torch::lazy::OpKind(at::aten::native_batch_norm_backward),
{grad_out,
input,
weight,
running_mean,
running_var,
save_mean,
save_invstd},
{input.shape(), weight.shape(), weight.shape()},
/*num_outputs=*/3,
torch::lazy::MHash(
training,
eps,
output_mask[0],
output_mask[1],
output_mask[2])),
training_(training),
eps_(eps),
output_mask_(output_mask) {}
TSNativeBatchNormBackward::TSNativeBatchNormBackward(
const torch::lazy::Value& grad_out,
const torch::lazy::Value& input,
const torch::lazy::Value& weight,
const torch::lazy::Value& save_mean,
const torch::lazy::Value& save_invstd,
bool training,
double eps,
std::array<bool, 3> output_mask)
: torch::lazy::TsNode(
torch::lazy::OpKind(at::aten::native_batch_norm_backward),
{grad_out, input, weight, save_mean, save_invstd},
{input.shape(), weight.shape(), weight.shape()},
/*num_outputs=*/3,
torch::lazy::MHash(
training,
eps,
output_mask[0],
output_mask[1],
output_mask[2])),
training_(training),
eps_(eps),
output_mask_(output_mask) {}
std::string TSNativeBatchNormBackward::ToString() const {
std::stringstream ss;
ss << torch::lazy::TsNode::ToString() << ", training=" << training_
<< ", eps=" << eps_;
return ss.str();
}
TSNativeBatchNormForward::TSNativeBatchNormForward(
const torch::lazy::Value& input,
const torch::lazy::Value& weight,
const torch::lazy::Value& bias,
const torch::lazy::Value& running_mean,
const torch::lazy::Value& running_var,
bool training,
double momentum,
double eps)
: torch::lazy::TsNode(
torch::lazy::OpKind(at::aten::native_batch_norm),
{input, weight, bias, running_mean, running_var},
{input.shape(), running_mean.shape(), running_var.shape()},
/*num_outputs=*/3,
torch::lazy::MHash(training, momentum, eps)),
training_(training),
momentum_(momentum),
eps_(eps) {}
std::string TSNativeBatchNormForward::ToString() const {
std::stringstream ss;
ss << torch::lazy::TsNode::ToString() << ", training=" << training_
<< ", momentum=" << momentum_ << ", eps=" << eps_;
return ss.str();
}
} // namespace lazy
} // namespace torch

View File

@ -1,156 +0,0 @@
#pragma once
#include <torch/csrc/lazy/ts_backend/ts_node.h>
namespace torch {
namespace lazy {
// Node for the backward batch norm operator.
class TSNativeBatchNormBackward : public torch::lazy::TsNode {
public:
static OpKind ClassOpKind() {
return OpKind(at::aten::native_batch_norm_backward);
}
TSNativeBatchNormBackward(
const torch::lazy::Value& grad_out,
const torch::lazy::Value& input,
const torch::lazy::Value& weight,
const torch::lazy::Value& running_mean,
const torch::lazy::Value& running_var,
const torch::lazy::Value& save_mean,
const torch::lazy::Value& save_invstd,
bool training,
double eps,
std::array<bool, 3> output_mask);
TSNativeBatchNormBackward(
const torch::lazy::Value& grad_out,
const torch::lazy::Value& input,
const torch::lazy::Value& weight,
const torch::lazy::Value& save_mean,
const torch::lazy::Value& save_invstd,
bool training,
double eps,
std::array<bool, 3> output_mask);
bool CanBeReused(
const torch::lazy::Value& grad_out,
const torch::lazy::Value& input,
const torch::lazy::Value& weight,
const torch::lazy::Value& running_mean,
const torch::lazy::Value& running_var,
const torch::lazy::Value& save_mean,
const torch::lazy::Value& save_invstd,
bool training,
double eps,
std::array<bool, 3> output_mask) const {
size_t i = 0;
return (
operand(i++) == grad_out && operand(i++) == input &&
operand(i++) == weight && operand(i++) == running_mean &&
operand(i++) == running_var && operand(i++) == save_mean &&
operand(i++) == save_invstd && training_ == training && eps_ == eps &&
output_mask_ == output_mask);
}
bool CanBeReused(
const torch::lazy::Value& grad_out,
const torch::lazy::Value& input,
const torch::lazy::Value& weight,
const torch::lazy::Value& save_mean,
const torch::lazy::Value& save_invstd,
bool training,
double eps,
std::array<bool, 3> output_mask) const {
size_t i = 0;
return (
operand(i++) == grad_out && operand(i++) == input &&
operand(i++) == weight && operand(i++) == save_mean &&
operand(i++) == save_invstd && training_ == training && eps_ == eps &&
output_mask_ == output_mask);
}
std::string ToString() const override;
bool training() const {
return training_;
}
double eps() const {
return eps_;
}
const std::array<bool, 3>& output_mask() const {
return output_mask_;
}
TSOpVector Lower(
std::shared_ptr<torch::jit::GraphFunction> function,
TSLoweringContext* loctx) const override;
private:
bool training_;
double eps_;
std::array<bool, 3> output_mask_;
};
class TSNativeBatchNormForward : public torch::lazy::TsNode {
public:
static OpKind ClassOpKind() {
return OpKind(at::aten::native_batch_norm);
}
TSNativeBatchNormForward(
const torch::lazy::Value& input,
const torch::lazy::Value& weight,
const torch::lazy::Value& bias,
const torch::lazy::Value& running_mean,
const torch::lazy::Value& running_var,
bool training,
double momentum,
double eps);
bool CanBeReused(
const torch::lazy::Value& input,
const torch::lazy::Value& weight,
const torch::lazy::Value& bias,
const torch::lazy::Value& running_mean,
const torch::lazy::Value& running_var,
bool training,
double momentum,
double eps) const {
size_t i = 0;
return (
operand(i++) == input && operand(i++) == weight &&
operand(i++) == bias && operand(i++) == running_mean &&
operand(i++) == running_var && training_ == training &&
momentum_ == momentum && eps == eps_);
}
std::string ToString() const override;
bool training() const {
return training_;
}
double momentum() const {
return momentum_;
}
double eps() const {
return eps_;
}
TSOpVector Lower(
std::shared_ptr<torch::jit::GraphFunction> function,
TSLoweringContext* loctx) const override;
private:
bool training_;
double momentum_;
double eps_;
};
} // namespace lazy
} // namespace torch

View File

@ -13,7 +13,6 @@
#include <torch/csrc/lazy/core/tensor.h>
#include <torch/csrc/lazy/core/util.h>
#include <torch/csrc/lazy/generated/LazyIr.h>
#include <torch/csrc/lazy/ts_backend/ops/batch_norm_ops.h>
#include <torch/csrc/lazy/ts_backend/ops/random_ops.h>
#include <algorithm>
#include <functional>
@ -50,27 +49,6 @@ std::vector<int64_t> GetExpandDimensions(
return dimensions;
}
// Returns a 1-D shape for batch norm weight or bias based on the input shape.
torch::lazy::Shape BatchNormFeaturesShape(
const torch::lazy::LazyTensorPtr& input) {
CHECK(input);
auto input_shape = input->shape().Get();
return torch::lazy::Shape(input_shape.scalar_type(), input_shape.sizes()[1]);
}
// Returns the IR for the given input or the provided default value broadcasted
// to the default shape, if the input is undefined.
torch::lazy::Value GetIrValueOrDefault(
const torch::lazy::LazyTensorPtr& input,
const at::Scalar& default_value,
const torch::lazy::Shape& default_shape,
const torch::lazy::BackendDevice& device) {
return input
? input->GetIrValue()
: torch::lazy::LazyGraphExecutor::Get()->GetIrValueForExpandedScalar(
default_value, default_shape, device);
}
torch::lazy::ViewInfo CreateAsStridedViewInfo(
const torch::lazy::Shape& input_shape,
std::vector<int64_t> size,
@ -165,105 +143,6 @@ torch::lazy::LazyTensorPtr narrow(
return input->CreateViewTensor(std::move(view_info));
}
std::tuple<
torch::lazy::LazyTensorPtr,
torch::lazy::LazyTensorPtr,
torch::lazy::LazyTensorPtr>
ts_native_batch_norm(
const torch::lazy::LazyTensorPtr& input,
const torch::lazy::LazyTensorPtr& weight,
const torch::lazy::LazyTensorPtr& bias,
torch::lazy::LazyTensorPtr& running_mean,
torch::lazy::LazyTensorPtr& running_var,
bool training,
double momentum,
double eps) {
torch::lazy::Shape features_shape = BatchNormFeaturesShape(input);
torch::lazy::Value weight_value =
GetIrValueOrDefault(weight, 1, features_shape, input->GetDevice());
torch::lazy::Value bias_value =
GetIrValueOrDefault(bias, 0, features_shape, input->GetDevice());
torch::lazy::Value running_mean_value =
GetIrValueOrDefault(running_mean, 0, features_shape, input->GetDevice());
torch::lazy::Value running_var_value =
GetIrValueOrDefault(running_var, 0, features_shape, input->GetDevice());
torch::lazy::NodePtr node = ReuseOrMakeNode<TSNativeBatchNormForward>(
input->GetIrValue(),
weight_value,
bias_value,
running_mean_value,
running_var_value,
training,
momentum,
eps);
torch::lazy::LazyTensorPtr output = torch::lazy::LazyTensor::Create(
torch::lazy::Value(node, 0), input->GetDevice());
torch::lazy::LazyTensorPtr running_mean_output =
torch::lazy::LazyTensor::Create(
torch::lazy::Value(node, 1), input->GetDevice());
torch::lazy::LazyTensorPtr running_var_output =
torch::lazy::LazyTensor::Create(
torch::lazy::Value(node, 2), input->GetDevice());
return std::make_tuple(
std::move(output),
std::move(running_mean_output),
std::move(running_var_output));
}
std::tuple<
torch::lazy::LazyTensorPtr,
torch::lazy::LazyTensorPtr,
torch::lazy::LazyTensorPtr>
ts_native_batch_norm_backward(
const torch::lazy::LazyTensorPtr& grad_out,
const torch::lazy::LazyTensorPtr& input,
const torch::lazy::LazyTensorPtr& weight,
const torch::lazy::LazyTensorPtr& running_mean,
const torch::lazy::LazyTensorPtr& running_var,
const torch::lazy::LazyTensorPtr& save_mean,
const torch::lazy::LazyTensorPtr& save_invstd,
bool training,
double eps,
c10::ArrayRef<bool> output_mask) {
torch::lazy::Shape features_shape = BatchNormFeaturesShape(input);
torch::lazy::Value weight_value =
GetIrValueOrDefault(weight, 1, features_shape, input->GetDevice());
torch::lazy::NodePtr node;
if (!running_mean && !running_var) {
node = ReuseOrMakeNode<TSNativeBatchNormBackward>(
grad_out->GetIrValue(),
input->GetIrValue(),
weight_value,
save_mean->GetIrValue(),
save_invstd->GetIrValue(),
training,
eps,
std::array<bool, 3>{output_mask[0], output_mask[1], output_mask[2]});
} else {
CHECK(running_mean);
CHECK(running_var);
node = ReuseOrMakeNode<TSNativeBatchNormBackward>(
grad_out->GetIrValue(),
input->GetIrValue(),
weight_value,
running_mean->GetIrValue(),
running_var->GetIrValue(),
save_mean->GetIrValue(),
save_invstd->GetIrValue(),
training,
eps,
std::array<bool, 3>{output_mask[0], output_mask[1], output_mask[2]});
}
torch::lazy::LazyTensorPtr grad_input = torch::lazy::LazyTensor::Create(
torch::lazy::Value(node, 0), input->GetDevice());
torch::lazy::LazyTensorPtr grad_weight = torch::lazy::LazyTensor::Create(
torch::lazy::Value(node, 1), input->GetDevice());
torch::lazy::LazyTensorPtr grad_bias = torch::lazy::LazyTensor::Create(
torch::lazy::Value(node, 2), input->GetDevice());
return std::make_tuple(
std::move(grad_input), std::move(grad_weight), std::move(grad_bias));
}
torch::lazy::LazyTensorPtr permute(
const torch::lazy::LazyTensorPtr& input,
c10::ArrayRef<int64_t> dims) {

View File

@ -38,36 +38,6 @@ torch::lazy::LazyTensorPtr narrow(
int64_t start,
int64_t length);
std::tuple<
torch::lazy::LazyTensorPtr,
torch::lazy::LazyTensorPtr,
torch::lazy::LazyTensorPtr>
ts_native_batch_norm(
const torch::lazy::LazyTensorPtr& input,
const torch::lazy::LazyTensorPtr& weight,
const torch::lazy::LazyTensorPtr& bias,
torch::lazy::LazyTensorPtr& running_mean,
torch::lazy::LazyTensorPtr& running_var,
bool training,
double momentum,
double eps);
std::tuple<
torch::lazy::LazyTensorPtr,
torch::lazy::LazyTensorPtr,
torch::lazy::LazyTensorPtr>
ts_native_batch_norm_backward(
const torch::lazy::LazyTensorPtr& grad_out,
const torch::lazy::LazyTensorPtr& input,
const torch::lazy::LazyTensorPtr& weight,
const torch::lazy::LazyTensorPtr& running_mean,
const torch::lazy::LazyTensorPtr& running_var,
const torch::lazy::LazyTensorPtr& save_mean,
const torch::lazy::LazyTensorPtr& save_invstd,
bool training,
double eps,
c10::ArrayRef<bool> output_mask);
// Permute the dimensions of this tensor according to the given permutation.
torch::lazy::LazyTensorPtr permute(
const torch::lazy::LazyTensorPtr& input,

View File

@ -374,78 +374,6 @@ at::Tensor LazyNativeFunctions::max_pool3d(
self, kernel_size, stride, padding, dilation, ceil_mode);
}
std::tuple<at::Tensor, at::Tensor, at::Tensor> LazyNativeFunctions::
native_batch_norm(
const at::Tensor& input,
const c10::optional<at::Tensor>& weight,
const c10::optional<at::Tensor>& bias,
const c10::optional<at::Tensor>& running_mean,
const c10::optional<at::Tensor>& running_var,
bool training,
double momentum,
double eps) {
TORCH_LAZY_FN_COUNTER("lazy::");
auto input_tensor = torch::lazy::TryGetLtcTensor(input);
const torch::lazy::BackendDevice& device = input_tensor->GetDevice();
auto running_mean_tensor = GetOrCreateLtcTensor(running_mean, device);
auto running_var_tensor = GetOrCreateLtcTensor(running_var, device);
auto outputs = ts_native_batch_norm(
torch::lazy::TryGetLtcTensor(input),
GetOrCreateLtcTensor(weight, device),
GetOrCreateLtcTensor(bias, device),
running_mean_tensor,
running_var_tensor,
training,
momentum,
eps);
return std::make_tuple(
torch::lazy::CreateAtenFromLtcTensor(std::get<0>(outputs)),
torch::lazy::CreateAtenFromLtcTensor(std::get<1>(outputs)),
torch::lazy::CreateAtenFromLtcTensor(std::get<2>(outputs)));
}
std::tuple<at::Tensor, at::Tensor, at::Tensor> LazyNativeFunctions::
native_batch_norm_backward(
const at::Tensor& grad_out,
const at::Tensor& input,
const c10::optional<at::Tensor>& weight,
const c10::optional<at::Tensor>& running_mean,
const c10::optional<at::Tensor>& running_var,
const c10::optional<at::Tensor>& save_mean,
const c10::optional<at::Tensor>& save_invstd,
bool train,
double eps,
std::array<bool, 3> output_mask) {
TORCH_LAZY_FN_COUNTER("lazy::");
auto grad_out_tensor = torch::lazy::TryGetLtcTensor(grad_out);
const torch::lazy::BackendDevice& device = grad_out_tensor->GetDevice();
torch::lazy::LazyTensorPtr null_tensor;
bool running_stats = running_mean && running_mean->defined();
CHECK_EQ(running_var && running_var->defined(), running_stats);
auto gradients = ts_native_batch_norm_backward(
torch::lazy::TryGetLtcTensor(grad_out),
torch::lazy::TryGetLtcTensor(input),
GetOrCreateLtcTensor(weight, device),
running_stats ? GetOrCreateLtcTensor(running_mean, device) : null_tensor,
running_stats ? GetOrCreateLtcTensor(running_var, device) : null_tensor,
GetOrCreateLtcTensor(save_mean, device),
GetOrCreateLtcTensor(save_invstd, device),
train,
eps,
output_mask);
at::Tensor undefined;
return std::make_tuple(
output_mask[0]
? torch::lazy::CreateAtenFromLtcTensor(std::get<0>(gradients))
: undefined,
output_mask[1]
? torch::lazy::CreateAtenFromLtcTensor(std::get<1>(gradients))
: undefined,
output_mask[2]
? torch::lazy::CreateAtenFromLtcTensor(std::get<2>(gradients))
: undefined);
}
// We need to explicitly override max pooling operators and just call the
// fallback for them because we've customized the autograd function for them
// (backward needs saved indices from forward).

View File

@ -11,7 +11,6 @@
#include <torch/csrc/lazy/core/ops/utils.h>
#include <torch/csrc/lazy/core/permutation_util.h>
#include <torch/csrc/lazy/ts_backend/ir_builder.h>
#include <torch/csrc/lazy/ts_backend/ops/batch_norm_ops.h>
#include <torch/csrc/lazy/ts_backend/ts_lowering_context.h>
namespace torch {
@ -107,41 +106,6 @@ TSOpVector TsNode::Lower(
return LowerBuiltin(this, function, arguments);
}
// TS specific ops
TSOpVector TSNativeBatchNormForward::Lower(
std::shared_ptr<torch::jit::GraphFunction> function,
TSLoweringContext* loctx) const {
std::vector<torch::jit::NamedValue> arguments;
for (size_t i = 0; i < 5; ++i) {
arguments.emplace_back(loctx->GetOutputOp(operand(i)));
}
arguments.emplace_back(training_);
arguments.emplace_back(momentum_);
arguments.emplace_back(eps_);
return LowerBuiltin(this, function, arguments);
}
TSOpVector TSNativeBatchNormBackward::Lower(
std::shared_ptr<torch::jit::GraphFunction> function,
TSLoweringContext* loctx) const {
std::vector<torch::jit::NamedValue> arguments;
for (size_t i = 0; i < 3; ++i) {
arguments.emplace_back(loctx->GetOutputOp(operand(i)));
}
c10::optional<at::Tensor> null_arg;
if (operands().size() == 5) {
arguments.emplace_back(null_arg);
arguments.emplace_back(null_arg);
}
for (size_t i = 3; i < operands().size(); ++i) {
arguments.emplace_back(loctx->GetOutputOp(operand(i)));
}
arguments.emplace_back(training_);
arguments.emplace_back(eps_);
arguments.emplace_back(output_mask_);
return LowerBuiltin(this, function, arguments);
}
// Non-native ops
torch::lazy::TSOpVector Cast::Lower(
std::shared_ptr<torch::jit::GraphFunction> function,