ONNX Export Scatter

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18543

Differential Revision: D14658639

Pulled By: houseroad

fbshipit-source-id: 5d7821b54d2fc93f71120155adf328897d13aff6
This commit is contained in:
Lara
2019-05-22 13:25:45 -07:00
committed by Facebook Github Bot
parent fea4a56af3
commit 8d7a025703
4 changed files with 163 additions and 0 deletions

View File

@ -51,6 +51,7 @@ REGISTER_CPU_OPERATOR(
ScatterWeightedSum,
ScatterWeightedSumOp<float, CPUContext>);
REGISTER_CPU_OPERATOR(ScatterAssign, ScatterAssignOp<CPUContext>);
REGISTER_CPU_OPERATOR(Scatter, ScatterOp<CPUContext>);
REGISTER_CPU_OPERATOR(LengthsToShape, LengthsToShapeOp<CPUContext>);
REGISTER_CPU_OPERATOR(HasElements, HasElementsOp<CPUContext>);
@ -369,6 +370,38 @@ Currently only works on CPU because of access to INDICES.
"Update slices, with shape len(INDICES) + shape(X_0)[1:]")
.Output(0, "DATA", "Has to be exactly the same tensor as the input 0");
OPERATOR_SCHEMA(Scatter)
.NumInputs(3)
.NumOutputs(1)
.AllowInplace({{0, 0}})
.SetDoc(R"DOC(
Update values of the tensor by overriding current value specified by indices.
Writes all values from the tensor UPDATES into DATA at the indices specified in the INDICES tensor.
For each value in DATA, its output index is specified by its index in UPDATES and by the corresponding value in INDICES for the specified axis.
For a 3-D tensor, DATA is updated as:
DATA[INDICES[i][j][k]][j][k] = UPDATES[i][j][k] # if axis == 0
DATA[i][INDICES[i][j][k]][k] = UPDATES[i][j][k] # if axis == 1
DATA[i][j][INDICES[i][j][k]] = UPDATES[i][j][k] # if axis == 2
Currently only works on CPU because of access to INDICES.
)DOC")
.Input(0, "DATA", "Tensor to be updated.")
.Input(
1,
"INDICES",
"1-D list of indices on the first dimension"
"of X_0 that need to be updated")
.Input(
2,
"UPDATES",
"Update slices, with shape len(INDICES) + shape(X_0)[1:]")
.Output(0, "OUTPUT", "The updated output.")
.Arg(
"axis",
"*(type: int; default: 1)* Which dimension to scatter on.");
OPERATOR_SCHEMA(HasElements)
.NumInputs(1)
@ -739,6 +772,7 @@ REGISTER_GRADIENT(Sum, GetSumGradient);
SHOULD_NOT_DO_GRADIENT(ScatterWeightedSum);
SHOULD_NOT_DO_GRADIENT(ScatterAssign);
SHOULD_NOT_DO_GRADIENT(Scatter);
class GetWeightedSumGradient : public GradientMakerBase {
using GradientMakerBase::GradientMakerBase;

View File

@ -738,6 +738,106 @@ class ScatterAssignOp : public Operator<Context> {
INPUT_TAGS(DATA, INDICES, SLICES);
};
template <class Context>
class ScatterOp : public Operator<CPUContext> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
template <class... Args>
explicit ScatterOp(Args&&... args)
: Operator<CPUContext>(std::forward<Args>(args)...),
OP_SINGLE_ARG(int, "axis", axis_, 1) {
}
virtual ~ScatterOp() noexcept override {}
bool RunOnDevice() override {
TORCH_CHECK(Context::GetDeviceType() == kCPU, "ScatterOp currently only supports CPU.")
return DispatchHelper<TensorTypes<int32_t, int64_t>>::call(
this, this->template Input<Tensor>(INDICES, CPU));
}
template <typename IndexType>
bool DoRunWithType() {
const Tensor& data = Input(DATA);
const Tensor& indices = Input(INDICES);
const Tensor& updates = Input(UPDATES);
const TypeMeta dataType = data.dtype();
size_t item_bytesize = dataType.itemsize();
// ONNX allows negative axis to index from the back, valid range: [-r, r].
axis_ = data.canonical_axis_index(axis_);
CAFFE_ENFORCE_GE(data.dim(), axis_ + 1, "DATA should be at least [axis+1]-D");
CAFFE_ENFORCE_GE(axis_, 0, "Axis should be non-negative");
CAFFE_ENFORCE_LT(axis_, data.dim(), "Axis out of range");
Tensor* output = Output(0, data.sizes().vec(), at::dtype(dataType));
output->CopyFrom(data);
char* out = static_cast<char*>(output->raw_mutable_data(dataType));
// Succeed if size of output is zero, which can happen for empty batch which
// would have data dimension size of 0.
// This *must* be done AFTER output->raw_mutable_data() above as that has
// important allocation side effect that we must see.
if (output->numel() == 0) {
return true;
}
const IndexType* idxs = indices.template data<IndexType>();
const char* src_base = static_cast<const char*>(updates.raw_data());
const int64_t outer_dims_product = updates.size_to_dim(axis_);
const int64_t block_size = updates.size_from_dim(axis_ + 1);
const int64_t block_bytesize = block_size * item_bytesize;
const int64_t src_indexing_axis_dim = updates.size(axis_);
const int64_t src_batch_bytesize = updates.size_from_dim(axis_) * item_bytesize;
const int64_t dst_batch_size = data.size_from_dim(axis_) * item_bytesize;
const int64_t N = indices.size(axis_);
check_indexarray_range<IndexType>(idxs, N, src_indexing_axis_dim);
int64_t i = 0;
for (int64_t batch = 0; batch < outer_dims_product; ++batch) {
int64_t i_max = i + N;
for (; i < i_max && i < indices.numel(); ++i) {
auto idx = idxs[i];
auto src = src_base + batch * src_batch_bytesize + idx * block_bytesize;
auto dst = out + batch * dst_batch_size + (i - i_max + N) * block_bytesize;
context_.CopyItemsSameDevice(dataType, block_size, src, dst);
}
}
return true;
}
INPUT_TAGS(DATA, INDICES, UPDATES);
// Check that indices fall within dimension array size with CAFFE_ENFORCE.
template <typename IndexType>
static void check_indexarray_range(
const IndexType* indices,
int64_t n,
IndexType indexing_axis_dim) {
for (auto i = 0; i < n; ++i) {
auto idx = indices[i];
CAFFE_ENFORCE(
0 <= idx && idx < indexing_axis_dim,
"INDICES element is out of DATA bounds, id=",
idx,
" axis_dim=",
indexing_axis_dim);
}
}
protected:
int axis_;
};
template <class Context>
class LengthsToSegmentIdsOp : public Operator<Context> {
public:

View File

@ -1249,6 +1249,17 @@ class TestCaffe2Backend(unittest.TestCase):
x = torch.tensor([1.0, float('nan'), 2.0])
self.run_model_test(IsNaNModel(), train=False, input=x, batch_size=BATCH_SIZE, use_gpu=False)
def test_scatter(self):
class ScatterModel(torch.nn.Module):
def forward(self, input, indices, values):
return input.scatter(1, indices, values)
input = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
indices = torch.tensor([[1, 0], [0, 1], [0, 1]], dtype=torch.int64)
values = torch.tensor([[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]])
self.run_model_test(ScatterModel(), train=False, input=(input, indices, values),
batch_size=BATCH_SIZE, use_gpu=False)
def test_flatten(self):
class FlattenModel(torch.nn.Module):
def forward(self, input):

View File

@ -1542,6 +1542,24 @@ def argmin(g, input, dim, keepdim):
return g.op('ArgMin', input, axis_i=dim, keepdims_i=keepdim)
@parse_args('v', 'i', 'v', 'v')
def scatter(g, self, dim, index, src):
return g.op("Scatter", self, index, src, axis_i=dim)
@parse_args('v', 'i', 'v', 'v')
def scatter_add(g, self, dim, index, src):
if self.type().kind() != "CompleteTensorType":
return _unimplemented("scatter_add", "input size not accesible")
dtype = self.type().scalarType()
dtype = sym_help.scalar_type_to_onnx.index(sym_help.cast_pytorch_to_onnx[dtype])
dims = self.type().sizes()
to_add = torch.zeros(dims)
to_add = g.op("Constant", value_t=to_add)
to_add = scatter(g, to_add, dim, index, src)
return add(g, self, to_add)
def log2(g, self):
_ln2 = 0.693147180559945309
return g.op('Div', log(g, self), g.op('Constant', value_t=torch.Tensor([_ln2])))