Autogen native_batch_norm and native_batch_norm_backward (#79637)
This PR makes the `native_batch_norm` and `native_batch_norm_backward` ops autogen and implements their respective shape-inference functions. Previously, these two ops were implemented by hand.

cc: @ke1337 @antoniojkim @wconstab @desertfire

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79637
Approved by: https://github.com/Gamrix, https://github.com/desertfire
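For orientation, the shape contract encoded by the new compute_shape_native_batch_norm and compute_shape_native_batch_norm_backward functions (see the shape-inference hunk below) is: the first output keeps the input's shape, while the two per-channel outputs (saved mean / saved invstd in the forward, grad_weight / grad_bias in the backward) are 1-D vectors of length C = input.size(1). The following is a minimal standalone C++ sketch of that rule using plain integer dimension lists; it ignores dtype bookkeeping and the forward-path case where running stats supply their own shapes, and the helper name batch_norm_output_dims is made up for illustration only, not part of PyTorch.

    // Standalone sketch of the per-channel shape rule (illustration only).
    #include <cassert>
    #include <cstdint>
    #include <vector>

    using Dims = std::vector<int64_t>;

    // Given the input dimensions, return the three output dimension lists:
    // {input-shaped output, per-channel vector, per-channel vector}.
    std::vector<Dims> batch_norm_output_dims(const Dims& input) {
      // The real functions enforce the same requirement via TORCH_CHECK.
      assert(input.size() >= 2 && "Input tensor must have at least batch and channel dimensions!");
      const int64_t num_features = input[1];
      return {input, Dims{num_features}, Dims{num_features}};
    }

    int main() {
      // NCHW input: batch 8, 16 channels, 32x32 spatial.
      const auto shapes = batch_norm_output_dims({8, 16, 32, 32});
      assert((shapes[0] == Dims{8, 16, 32, 32})); // output / grad_input
      assert((shapes[1] == Dims{16}));            // saved_mean / grad_weight
      assert((shapes[2] == Dims{16}));            // saved_invstd / grad_bias
      return 0;
    }

This is the same per-channel rule the autogenerated IR nodes rely on in place of the hand-written TSNativeBatchNormForward/TSNativeBatchNormBackward nodes that this commit deletes.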
committed by: PyTorch MergeBot
parent: fdd3e20935
commit: 02093da36c
@@ -85,6 +85,8 @@ full_codegen:
 - mm
 - mul.Tensor
 - mv
+- native_batch_norm
+- native_batch_norm_backward
 - native_dropout
 - native_dropout_backward
 - native_layer_norm
@@ -153,8 +155,6 @@ supported:
 - expand
 - fill_.Scalar
 - narrow
-- native_batch_norm
-- native_batch_norm_backward
 - normal_
 - max_pool3d_with_indices
 - max_pool3d_with_indices_backward
@@ -420,7 +420,6 @@ lazy_tensor_core_sources = [
 lazy_tensor_ts_sources = [
     "torch/csrc/lazy/ts_backend/dynamic_ir.cpp",
     "torch/csrc/lazy/ts_backend/config.cpp",
-    "torch/csrc/lazy/ts_backend/ops/batch_norm_ops.cpp",
     "torch/csrc/lazy/ts_backend/ops/device_data.cpp",
     "torch/csrc/lazy/ts_backend/ops/random_ops.cpp",
     "torch/csrc/lazy/ts_backend/ops/generic.cpp",
@@ -121,7 +121,7 @@ class TestLazyReuseIr(TestCase):
         torch._lazy.mark_step()
 
         torch.testing.assert_close(z.cpu(), z_lazy.cpu())
-        assert metrics.counter_value("IrNodeReused_torch::lazy::TSNativeBatchNormForward") >= 7
+        assert metrics.counter_value("IrNodeReused_torch::lazy::NativeBatchNorm") >= 7
         metrics.reset()
         torch._lazy.ir_cache.reset()
@@ -140,9 +140,11 @@ Shape Node::computeShape(const std::function<Shape()>& shape_fn) {
const std::vector<Output>& Node::operands() const {
  return operands_as_outputs_;
}

const Output& Node::operand(size_t i) const {
  return operands_as_outputs_.at(i);
}

const Output& Node::nullable_operand(size_t i) const {
  // We use kNullOutput instead of kNullValue here to avoid implicit casting,
  // which would prevent this method from returning a reference.
@@ -470,6 +470,77 @@ std::vector<Shape> compute_shape_cat(at::TensorList tensors, int64_t dim) {
   return {Shape(tensors[0].scalar_type(), out_shape)};
 }
 
+std::vector<torch::lazy::Shape> compute_shape_native_batch_norm(
+    const at::Tensor& input,
+    const c10::optional<at::Tensor>& weight,
+    const c10::optional<at::Tensor>& bias,
+    const c10::optional<at::Tensor>& running_mean,
+    const c10::optional<at::Tensor>& running_var,
+    bool training,
+    double momentum,
+    double eps) {
+  std::vector<torch::lazy::Shape> shapes;
+  shapes.reserve(3);
+  shapes.emplace_back(input.scalar_type(), input.sizes().vec());
+
+  // A separate mean and var needs to be kept for each channel.
+  TORCH_CHECK(
+      input.sizes().size() >= 2,
+      "Input tensor must have at least batch and channel dimensions!");
+  int64_t num_features = input.size(1);
+
+  if (running_mean.has_value()) {
+    shapes.emplace_back(
+        running_mean.value().scalar_type(), running_mean.value().sizes().vec());
+  } else {
+    shapes.emplace_back(
+        at::get_default_dtype_as_scalartype(),
+        std::vector<int64_t>{num_features});
+  }
+
+  if (running_var.has_value()) {
+    shapes.emplace_back(
+        running_var.value().scalar_type(), running_var.value().sizes().vec());
+  } else {
+    shapes.emplace_back(
+        at::get_default_dtype_as_scalartype(),
+        std::vector<int64_t>{num_features});
+  }
+  return shapes;
+}
+
+std::vector<torch::lazy::Shape> compute_shape_native_batch_norm_backward(
+    const at::Tensor& grad_out,
+    const at::Tensor& input,
+    const c10::optional<at::Tensor>& weight,
+    const c10::optional<at::Tensor>& running_mean,
+    const c10::optional<at::Tensor>& running_var,
+    const c10::optional<at::Tensor>& save_mean,
+    const c10::optional<at::Tensor>& save_invstd,
+    bool train,
+    double eps,
+    ::std::array<bool, 3> output_mask) {
+  std::vector<torch::lazy::Shape> shapes;
+  shapes.reserve(3);
+  shapes.emplace_back(input.scalar_type(), input.sizes().vec());
+
+  // A separate mean and var needs to be kept for each channel.
+  TORCH_CHECK(
+      input.sizes().size() >= 2,
+      "Input tensor must have at least batch and channel dimensions!");
+  int64_t num_features = input.size(1);
+
+  // `weight` and `bias` are vectors of length C (number of channels)
+  shapes.emplace_back(
+      at::get_default_dtype_as_scalartype(),
+      std::vector<int64_t>{num_features});
+  shapes.emplace_back(
+      at::get_default_dtype_as_scalartype(),
+      std::vector<int64_t>{num_features});
+
+  return shapes;
+}
+
 std::vector<Shape> compute_shape_native_layer_norm(
     const at::Tensor& input,
     at::IntArrayRef normalized_shape,
@@ -50,6 +50,8 @@ TORCH_API std::vector<torch::lazy::Shape> compute_shape_max(const at::Tensor & s
 TORCH_API std::vector<torch::lazy::Shape> compute_shape_mean(const at::Tensor & self, c10::optional<at::ScalarType> dtype);
 TORCH_API std::vector<torch::lazy::Shape> compute_shape_min(const at::Tensor & self);
 TORCH_API std::vector<torch::lazy::Shape> compute_shape_mv(const at::Tensor & self, const at::Tensor & vec);
+TORCH_API std::vector<torch::lazy::Shape> compute_shape_native_batch_norm(const at::Tensor & input, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, const c10::optional<at::Tensor> & running_mean, const c10::optional<at::Tensor> & running_var, bool training, double momentum, double eps);
+TORCH_API std::vector<torch::lazy::Shape> compute_shape_native_batch_norm_backward(const at::Tensor & grad_out, const at::Tensor & input, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & running_mean, const c10::optional<at::Tensor> & running_var, const c10::optional<at::Tensor> & save_mean, const c10::optional<at::Tensor> & save_invstd, bool train, double eps, ::std::array<bool,3> output_mask);
 TORCH_API std::vector<torch::lazy::Shape> compute_shape_native_dropout(const at::Tensor & input, double p, c10::optional<bool> train);
 TORCH_API std::vector<torch::lazy::Shape> compute_shape_native_dropout_backward(const at::Tensor & grad_output, const at::Tensor & mask, double scale);
 TORCH_API std::vector<torch::lazy::Shape> compute_shape_native_layer_norm(const at::Tensor & input, at::IntArrayRef normalized_shape, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, double eps);
@@ -1,97 +0,0 @@
-#include <torch/csrc/lazy/core/util.h>
-#include <torch/csrc/lazy/ts_backend/ops/batch_norm_ops.h>
-
-namespace torch {
-namespace lazy {
-
-TSNativeBatchNormBackward::TSNativeBatchNormBackward(
-    const torch::lazy::Value& grad_out,
-    const torch::lazy::Value& input,
-    const torch::lazy::Value& weight,
-    const torch::lazy::Value& running_mean,
-    const torch::lazy::Value& running_var,
-    const torch::lazy::Value& save_mean,
-    const torch::lazy::Value& save_invstd,
-    bool training,
-    double eps,
-    std::array<bool, 3> output_mask)
-    : torch::lazy::TsNode(
-          torch::lazy::OpKind(at::aten::native_batch_norm_backward),
-          {grad_out,
-           input,
-           weight,
-           running_mean,
-           running_var,
-           save_mean,
-           save_invstd},
-          {input.shape(), weight.shape(), weight.shape()},
-          /*num_outputs=*/3,
-          torch::lazy::MHash(
-              training,
-              eps,
-              output_mask[0],
-              output_mask[1],
-              output_mask[2])),
-      training_(training),
-      eps_(eps),
-      output_mask_(output_mask) {}
-
-TSNativeBatchNormBackward::TSNativeBatchNormBackward(
-    const torch::lazy::Value& grad_out,
-    const torch::lazy::Value& input,
-    const torch::lazy::Value& weight,
-    const torch::lazy::Value& save_mean,
-    const torch::lazy::Value& save_invstd,
-    bool training,
-    double eps,
-    std::array<bool, 3> output_mask)
-    : torch::lazy::TsNode(
-          torch::lazy::OpKind(at::aten::native_batch_norm_backward),
-          {grad_out, input, weight, save_mean, save_invstd},
-          {input.shape(), weight.shape(), weight.shape()},
-          /*num_outputs=*/3,
-          torch::lazy::MHash(
-              training,
-              eps,
-              output_mask[0],
-              output_mask[1],
-              output_mask[2])),
-      training_(training),
-      eps_(eps),
-      output_mask_(output_mask) {}
-
-std::string TSNativeBatchNormBackward::ToString() const {
-  std::stringstream ss;
-  ss << torch::lazy::TsNode::ToString() << ", training=" << training_
-     << ", eps=" << eps_;
-  return ss.str();
-}
-
-TSNativeBatchNormForward::TSNativeBatchNormForward(
-    const torch::lazy::Value& input,
-    const torch::lazy::Value& weight,
-    const torch::lazy::Value& bias,
-    const torch::lazy::Value& running_mean,
-    const torch::lazy::Value& running_var,
-    bool training,
-    double momentum,
-    double eps)
-    : torch::lazy::TsNode(
-          torch::lazy::OpKind(at::aten::native_batch_norm),
-          {input, weight, bias, running_mean, running_var},
-          {input.shape(), running_mean.shape(), running_var.shape()},
-          /*num_outputs=*/3,
-          torch::lazy::MHash(training, momentum, eps)),
-      training_(training),
-      momentum_(momentum),
-      eps_(eps) {}
-
-std::string TSNativeBatchNormForward::ToString() const {
-  std::stringstream ss;
-  ss << torch::lazy::TsNode::ToString() << ", training=" << training_
-     << ", momentum=" << momentum_ << ", eps=" << eps_;
-  return ss.str();
-}
-
-} // namespace lazy
-} // namespace torch
@@ -1,156 +0,0 @@
-#pragma once
-
-#include <torch/csrc/lazy/ts_backend/ts_node.h>
-
-namespace torch {
-namespace lazy {
-
-// Node for the backward batch norm operator.
-class TSNativeBatchNormBackward : public torch::lazy::TsNode {
- public:
-  static OpKind ClassOpKind() {
-    return OpKind(at::aten::native_batch_norm_backward);
-  }
-
-  TSNativeBatchNormBackward(
-      const torch::lazy::Value& grad_out,
-      const torch::lazy::Value& input,
-      const torch::lazy::Value& weight,
-      const torch::lazy::Value& running_mean,
-      const torch::lazy::Value& running_var,
-      const torch::lazy::Value& save_mean,
-      const torch::lazy::Value& save_invstd,
-      bool training,
-      double eps,
-      std::array<bool, 3> output_mask);
-
-  TSNativeBatchNormBackward(
-      const torch::lazy::Value& grad_out,
-      const torch::lazy::Value& input,
-      const torch::lazy::Value& weight,
-      const torch::lazy::Value& save_mean,
-      const torch::lazy::Value& save_invstd,
-      bool training,
-      double eps,
-      std::array<bool, 3> output_mask);
-
-  bool CanBeReused(
-      const torch::lazy::Value& grad_out,
-      const torch::lazy::Value& input,
-      const torch::lazy::Value& weight,
-      const torch::lazy::Value& running_mean,
-      const torch::lazy::Value& running_var,
-      const torch::lazy::Value& save_mean,
-      const torch::lazy::Value& save_invstd,
-      bool training,
-      double eps,
-      std::array<bool, 3> output_mask) const {
-    size_t i = 0;
-    return (
-        operand(i++) == grad_out && operand(i++) == input &&
-        operand(i++) == weight && operand(i++) == running_mean &&
-        operand(i++) == running_var && operand(i++) == save_mean &&
-        operand(i++) == save_invstd && training_ == training && eps_ == eps &&
-        output_mask_ == output_mask);
-  }
-
-  bool CanBeReused(
-      const torch::lazy::Value& grad_out,
-      const torch::lazy::Value& input,
-      const torch::lazy::Value& weight,
-      const torch::lazy::Value& save_mean,
-      const torch::lazy::Value& save_invstd,
-      bool training,
-      double eps,
-      std::array<bool, 3> output_mask) const {
-    size_t i = 0;
-    return (
-        operand(i++) == grad_out && operand(i++) == input &&
-        operand(i++) == weight && operand(i++) == save_mean &&
-        operand(i++) == save_invstd && training_ == training && eps_ == eps &&
-        output_mask_ == output_mask);
-  }
-
-  std::string ToString() const override;
-
-  bool training() const {
-    return training_;
-  }
-
-  double eps() const {
-    return eps_;
-  }
-
-  const std::array<bool, 3>& output_mask() const {
-    return output_mask_;
-  }
-
-  TSOpVector Lower(
-      std::shared_ptr<torch::jit::GraphFunction> function,
-      TSLoweringContext* loctx) const override;
-
- private:
-  bool training_;
-  double eps_;
-  std::array<bool, 3> output_mask_;
-};
-
-class TSNativeBatchNormForward : public torch::lazy::TsNode {
- public:
-  static OpKind ClassOpKind() {
-    return OpKind(at::aten::native_batch_norm);
-  }
-
-  TSNativeBatchNormForward(
-      const torch::lazy::Value& input,
-      const torch::lazy::Value& weight,
-      const torch::lazy::Value& bias,
-      const torch::lazy::Value& running_mean,
-      const torch::lazy::Value& running_var,
-      bool training,
-      double momentum,
-      double eps);
-
-  bool CanBeReused(
-      const torch::lazy::Value& input,
-      const torch::lazy::Value& weight,
-      const torch::lazy::Value& bias,
-      const torch::lazy::Value& running_mean,
-      const torch::lazy::Value& running_var,
-      bool training,
-      double momentum,
-      double eps) const {
-    size_t i = 0;
-    return (
-        operand(i++) == input && operand(i++) == weight &&
-        operand(i++) == bias && operand(i++) == running_mean &&
-        operand(i++) == running_var && training_ == training &&
-        momentum_ == momentum && eps == eps_);
-  }
-
-  std::string ToString() const override;
-
-  bool training() const {
-    return training_;
-  }
-
-  double momentum() const {
-    return momentum_;
-  }
-
-  double eps() const {
-    return eps_;
-  }
-
-  TSOpVector Lower(
-      std::shared_ptr<torch::jit::GraphFunction> function,
-      TSLoweringContext* loctx) const override;
-
- private:
-  bool training_;
-  double momentum_;
-  double eps_;
-};
-
-} // namespace lazy
-} // namespace torch
@@ -13,7 +13,6 @@
 #include <torch/csrc/lazy/core/tensor.h>
 #include <torch/csrc/lazy/core/util.h>
 #include <torch/csrc/lazy/generated/LazyIr.h>
-#include <torch/csrc/lazy/ts_backend/ops/batch_norm_ops.h>
 #include <torch/csrc/lazy/ts_backend/ops/random_ops.h>
 #include <algorithm>
 #include <functional>
@@ -50,27 +49,6 @@ std::vector<int64_t> GetExpandDimensions(
   return dimensions;
 }
 
-// Returns a 1-D shape for batch norm weight or bias based on the input shape.
-torch::lazy::Shape BatchNormFeaturesShape(
-    const torch::lazy::LazyTensorPtr& input) {
-  CHECK(input);
-  auto input_shape = input->shape().Get();
-  return torch::lazy::Shape(input_shape.scalar_type(), input_shape.sizes()[1]);
-}
-
-// Returns the IR for the given input or the provided default value broadcasted
-// to the default shape, if the input is undefined.
-torch::lazy::Value GetIrValueOrDefault(
-    const torch::lazy::LazyTensorPtr& input,
-    const at::Scalar& default_value,
-    const torch::lazy::Shape& default_shape,
-    const torch::lazy::BackendDevice& device) {
-  return input
-      ? input->GetIrValue()
-      : torch::lazy::LazyGraphExecutor::Get()->GetIrValueForExpandedScalar(
-            default_value, default_shape, device);
-}
-
 torch::lazy::ViewInfo CreateAsStridedViewInfo(
     const torch::lazy::Shape& input_shape,
     std::vector<int64_t> size,
@@ -165,105 +143,6 @@ torch::lazy::LazyTensorPtr narrow(
   return input->CreateViewTensor(std::move(view_info));
 }
 
-std::tuple<
-    torch::lazy::LazyTensorPtr,
-    torch::lazy::LazyTensorPtr,
-    torch::lazy::LazyTensorPtr>
-ts_native_batch_norm(
-    const torch::lazy::LazyTensorPtr& input,
-    const torch::lazy::LazyTensorPtr& weight,
-    const torch::lazy::LazyTensorPtr& bias,
-    torch::lazy::LazyTensorPtr& running_mean,
-    torch::lazy::LazyTensorPtr& running_var,
-    bool training,
-    double momentum,
-    double eps) {
-  torch::lazy::Shape features_shape = BatchNormFeaturesShape(input);
-  torch::lazy::Value weight_value =
-      GetIrValueOrDefault(weight, 1, features_shape, input->GetDevice());
-  torch::lazy::Value bias_value =
-      GetIrValueOrDefault(bias, 0, features_shape, input->GetDevice());
-  torch::lazy::Value running_mean_value =
-      GetIrValueOrDefault(running_mean, 0, features_shape, input->GetDevice());
-  torch::lazy::Value running_var_value =
-      GetIrValueOrDefault(running_var, 0, features_shape, input->GetDevice());
-  torch::lazy::NodePtr node = ReuseOrMakeNode<TSNativeBatchNormForward>(
-      input->GetIrValue(),
-      weight_value,
-      bias_value,
-      running_mean_value,
-      running_var_value,
-      training,
-      momentum,
-      eps);
-  torch::lazy::LazyTensorPtr output = torch::lazy::LazyTensor::Create(
-      torch::lazy::Value(node, 0), input->GetDevice());
-  torch::lazy::LazyTensorPtr running_mean_output =
-      torch::lazy::LazyTensor::Create(
-          torch::lazy::Value(node, 1), input->GetDevice());
-  torch::lazy::LazyTensorPtr running_var_output =
-      torch::lazy::LazyTensor::Create(
-          torch::lazy::Value(node, 2), input->GetDevice());
-  return std::make_tuple(
-      std::move(output),
-      std::move(running_mean_output),
-      std::move(running_var_output));
-}
-
-std::tuple<
-    torch::lazy::LazyTensorPtr,
-    torch::lazy::LazyTensorPtr,
-    torch::lazy::LazyTensorPtr>
-ts_native_batch_norm_backward(
-    const torch::lazy::LazyTensorPtr& grad_out,
-    const torch::lazy::LazyTensorPtr& input,
-    const torch::lazy::LazyTensorPtr& weight,
-    const torch::lazy::LazyTensorPtr& running_mean,
-    const torch::lazy::LazyTensorPtr& running_var,
-    const torch::lazy::LazyTensorPtr& save_mean,
-    const torch::lazy::LazyTensorPtr& save_invstd,
-    bool training,
-    double eps,
-    c10::ArrayRef<bool> output_mask) {
-  torch::lazy::Shape features_shape = BatchNormFeaturesShape(input);
-  torch::lazy::Value weight_value =
-      GetIrValueOrDefault(weight, 1, features_shape, input->GetDevice());
-  torch::lazy::NodePtr node;
-  if (!running_mean && !running_var) {
-    node = ReuseOrMakeNode<TSNativeBatchNormBackward>(
-        grad_out->GetIrValue(),
-        input->GetIrValue(),
-        weight_value,
-        save_mean->GetIrValue(),
-        save_invstd->GetIrValue(),
-        training,
-        eps,
-        std::array<bool, 3>{output_mask[0], output_mask[1], output_mask[2]});
-  } else {
-    CHECK(running_mean);
-    CHECK(running_var);
-    node = ReuseOrMakeNode<TSNativeBatchNormBackward>(
-        grad_out->GetIrValue(),
-        input->GetIrValue(),
-        weight_value,
-        running_mean->GetIrValue(),
-        running_var->GetIrValue(),
-        save_mean->GetIrValue(),
-        save_invstd->GetIrValue(),
-        training,
-        eps,
-        std::array<bool, 3>{output_mask[0], output_mask[1], output_mask[2]});
-  }
-  torch::lazy::LazyTensorPtr grad_input = torch::lazy::LazyTensor::Create(
-      torch::lazy::Value(node, 0), input->GetDevice());
-  torch::lazy::LazyTensorPtr grad_weight = torch::lazy::LazyTensor::Create(
-      torch::lazy::Value(node, 1), input->GetDevice());
-  torch::lazy::LazyTensorPtr grad_bias = torch::lazy::LazyTensor::Create(
-      torch::lazy::Value(node, 2), input->GetDevice());
-  return std::make_tuple(
-      std::move(grad_input), std::move(grad_weight), std::move(grad_bias));
-}
-
 torch::lazy::LazyTensorPtr permute(
     const torch::lazy::LazyTensorPtr& input,
     c10::ArrayRef<int64_t> dims) {
@@ -38,36 +38,6 @@ torch::lazy::LazyTensorPtr narrow(
     int64_t start,
     int64_t length);
 
-std::tuple<
-    torch::lazy::LazyTensorPtr,
-    torch::lazy::LazyTensorPtr,
-    torch::lazy::LazyTensorPtr>
-ts_native_batch_norm(
-    const torch::lazy::LazyTensorPtr& input,
-    const torch::lazy::LazyTensorPtr& weight,
-    const torch::lazy::LazyTensorPtr& bias,
-    torch::lazy::LazyTensorPtr& running_mean,
-    torch::lazy::LazyTensorPtr& running_var,
-    bool training,
-    double momentum,
-    double eps);
-
-std::tuple<
-    torch::lazy::LazyTensorPtr,
-    torch::lazy::LazyTensorPtr,
-    torch::lazy::LazyTensorPtr>
-ts_native_batch_norm_backward(
-    const torch::lazy::LazyTensorPtr& grad_out,
-    const torch::lazy::LazyTensorPtr& input,
-    const torch::lazy::LazyTensorPtr& weight,
-    const torch::lazy::LazyTensorPtr& running_mean,
-    const torch::lazy::LazyTensorPtr& running_var,
-    const torch::lazy::LazyTensorPtr& save_mean,
-    const torch::lazy::LazyTensorPtr& save_invstd,
-    bool training,
-    double eps,
-    c10::ArrayRef<bool> output_mask);
-
 // Permute the dimensions of this tensor according to the given permutation.
 torch::lazy::LazyTensorPtr permute(
     const torch::lazy::LazyTensorPtr& input,
@@ -374,78 +374,6 @@ at::Tensor LazyNativeFunctions::max_pool3d(
       self, kernel_size, stride, padding, dilation, ceil_mode);
 }
 
-std::tuple<at::Tensor, at::Tensor, at::Tensor> LazyNativeFunctions::
-    native_batch_norm(
-        const at::Tensor& input,
-        const c10::optional<at::Tensor>& weight,
-        const c10::optional<at::Tensor>& bias,
-        const c10::optional<at::Tensor>& running_mean,
-        const c10::optional<at::Tensor>& running_var,
-        bool training,
-        double momentum,
-        double eps) {
-  TORCH_LAZY_FN_COUNTER("lazy::");
-  auto input_tensor = torch::lazy::TryGetLtcTensor(input);
-  const torch::lazy::BackendDevice& device = input_tensor->GetDevice();
-  auto running_mean_tensor = GetOrCreateLtcTensor(running_mean, device);
-  auto running_var_tensor = GetOrCreateLtcTensor(running_var, device);
-  auto outputs = ts_native_batch_norm(
-      torch::lazy::TryGetLtcTensor(input),
-      GetOrCreateLtcTensor(weight, device),
-      GetOrCreateLtcTensor(bias, device),
-      running_mean_tensor,
-      running_var_tensor,
-      training,
-      momentum,
-      eps);
-  return std::make_tuple(
-      torch::lazy::CreateAtenFromLtcTensor(std::get<0>(outputs)),
-      torch::lazy::CreateAtenFromLtcTensor(std::get<1>(outputs)),
-      torch::lazy::CreateAtenFromLtcTensor(std::get<2>(outputs)));
-}
-
-std::tuple<at::Tensor, at::Tensor, at::Tensor> LazyNativeFunctions::
-    native_batch_norm_backward(
-        const at::Tensor& grad_out,
-        const at::Tensor& input,
-        const c10::optional<at::Tensor>& weight,
-        const c10::optional<at::Tensor>& running_mean,
-        const c10::optional<at::Tensor>& running_var,
-        const c10::optional<at::Tensor>& save_mean,
-        const c10::optional<at::Tensor>& save_invstd,
-        bool train,
-        double eps,
-        std::array<bool, 3> output_mask) {
-  TORCH_LAZY_FN_COUNTER("lazy::");
-  auto grad_out_tensor = torch::lazy::TryGetLtcTensor(grad_out);
-  const torch::lazy::BackendDevice& device = grad_out_tensor->GetDevice();
-  torch::lazy::LazyTensorPtr null_tensor;
-  bool running_stats = running_mean && running_mean->defined();
-  CHECK_EQ(running_var && running_var->defined(), running_stats);
-  auto gradients = ts_native_batch_norm_backward(
-      torch::lazy::TryGetLtcTensor(grad_out),
-      torch::lazy::TryGetLtcTensor(input),
-      GetOrCreateLtcTensor(weight, device),
-      running_stats ? GetOrCreateLtcTensor(running_mean, device) : null_tensor,
-      running_stats ? GetOrCreateLtcTensor(running_var, device) : null_tensor,
-      GetOrCreateLtcTensor(save_mean, device),
-      GetOrCreateLtcTensor(save_invstd, device),
-      train,
-      eps,
-      output_mask);
-  at::Tensor undefined;
-  return std::make_tuple(
-      output_mask[0]
-          ? torch::lazy::CreateAtenFromLtcTensor(std::get<0>(gradients))
-          : undefined,
-      output_mask[1]
-          ? torch::lazy::CreateAtenFromLtcTensor(std::get<1>(gradients))
-          : undefined,
-      output_mask[2]
-          ? torch::lazy::CreateAtenFromLtcTensor(std::get<2>(gradients))
-          : undefined);
-}
-
 // We need to explicitly override max pooling operators and just call the
 // fallback for them because we've customized the autograd function for them
 // (backward needs saved indices from forward).
@@ -11,7 +11,6 @@
 #include <torch/csrc/lazy/core/ops/utils.h>
 #include <torch/csrc/lazy/core/permutation_util.h>
 #include <torch/csrc/lazy/ts_backend/ir_builder.h>
-#include <torch/csrc/lazy/ts_backend/ops/batch_norm_ops.h>
 #include <torch/csrc/lazy/ts_backend/ts_lowering_context.h>
 
 namespace torch {
@@ -107,41 +106,6 @@ TSOpVector TsNode::Lower(
   return LowerBuiltin(this, function, arguments);
 }
 
-// TS specific ops
-TSOpVector TSNativeBatchNormForward::Lower(
-    std::shared_ptr<torch::jit::GraphFunction> function,
-    TSLoweringContext* loctx) const {
-  std::vector<torch::jit::NamedValue> arguments;
-  for (size_t i = 0; i < 5; ++i) {
-    arguments.emplace_back(loctx->GetOutputOp(operand(i)));
-  }
-  arguments.emplace_back(training_);
-  arguments.emplace_back(momentum_);
-  arguments.emplace_back(eps_);
-  return LowerBuiltin(this, function, arguments);
-}
-
-TSOpVector TSNativeBatchNormBackward::Lower(
-    std::shared_ptr<torch::jit::GraphFunction> function,
-    TSLoweringContext* loctx) const {
-  std::vector<torch::jit::NamedValue> arguments;
-  for (size_t i = 0; i < 3; ++i) {
-    arguments.emplace_back(loctx->GetOutputOp(operand(i)));
-  }
-  c10::optional<at::Tensor> null_arg;
-  if (operands().size() == 5) {
-    arguments.emplace_back(null_arg);
-    arguments.emplace_back(null_arg);
-  }
-  for (size_t i = 3; i < operands().size(); ++i) {
-    arguments.emplace_back(loctx->GetOutputOp(operand(i)));
-  }
-  arguments.emplace_back(training_);
-  arguments.emplace_back(eps_);
-  arguments.emplace_back(output_mask_);
-  return LowerBuiltin(this, function, arguments);
-}
-
 // Non-native ops
 torch::lazy::TSOpVector Cast::Lower(
     std::shared_ptr<torch::jit::GraphFunction> function,