mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
ONNX Export Scatter
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18543 Differential Revision: D14658639 Pulled By: houseroad fbshipit-source-id: 5d7821b54d2fc93f71120155adf328897d13aff6
This commit is contained in:
committed by
Facebook Github Bot
parent
fea4a56af3
commit
8d7a025703
@ -51,6 +51,7 @@ REGISTER_CPU_OPERATOR(
|
||||
ScatterWeightedSum,
|
||||
ScatterWeightedSumOp<float, CPUContext>);
|
||||
// CPU operator registrations: each line binds an operator name to the
// CPUContext instantiation of its implementing class.
REGISTER_CPU_OPERATOR(ScatterAssign, ScatterAssignOp<CPUContext>);
|
||||
// Scatter is registered for CPU only; its schema notes that it currently only
// works on CPU because of access to INDICES.
REGISTER_CPU_OPERATOR(Scatter, ScatterOp<CPUContext>);
|
||||
|
||||
REGISTER_CPU_OPERATOR(LengthsToShape, LengthsToShapeOp<CPUContext>);
|
||||
REGISTER_CPU_OPERATOR(HasElements, HasElementsOp<CPUContext>);
|
||||
@ -369,6 +370,38 @@ Currently only works on CPU because of access to INDICES.
|
||||
"Update slices, with shape len(INDICES) + shape(X_0)[1:]")
|
||||
.Output(0, "DATA", "Has to be exactly the same tensor as the input 0");
|
||||
|
||||
// Schema for the Scatter operator: three inputs (DATA, INDICES, UPDATES), one
// output that may alias DATA (in-place), and an integer `axis` argument.
// INDICES and UPDATES are element-wise paired, as shown by the formulas in the
// doc string below, so they are described as same-shaped tensors rather than
// as a 1-D index list plus slices.
OPERATOR_SCHEMA(Scatter)
    .NumInputs(3)
    .NumOutputs(1)
    .AllowInplace({{0, 0}})
    .SetDoc(R"DOC(
Update values of the tensor by overriding current value specified by indices.

Writes all values from the tensor UPDATES into DATA at the indices specified in the INDICES tensor.
For each value in UPDATES, its output index is specified by its index in UPDATES and by the corresponding value in INDICES for the specified axis.

For a 3-D tensor, DATA is updated as:

DATA[INDICES[i][j][k]][j][k] = UPDATES[i][j][k]  # if axis == 0
DATA[i][INDICES[i][j][k]][k] = UPDATES[i][j][k]  # if axis == 1
DATA[i][j][INDICES[i][j][k]] = UPDATES[i][j][k]  # if axis == 2

Currently only works on CPU because of access to INDICES.
)DOC")
    .Input(0, "DATA", "Tensor to be updated.")
    .Input(
        1,
        "INDICES",
        "Tensor of indices with the same rank and shape as UPDATES; each "
        "value is the position along axis in DATA that receives the "
        "corresponding element of UPDATES.")
    .Input(
        2,
        "UPDATES",
        "Tensor of values to write into DATA, with the same rank and shape "
        "as INDICES.")
    .Output(0, "OUTPUT", "The updated output.")
    .Arg(
        "axis",
        "*(type: int; default: 1)* Which dimension to scatter on.");
|
||||
|
||||
OPERATOR_SCHEMA(HasElements)
|
||||
.NumInputs(1)
|
||||
@ -739,6 +772,7 @@ REGISTER_GRADIENT(Sum, GetSumGradient);
|
||||
|
||||
// Going by the macro name, these scatter-style operators are declared as ones
// for which no gradient should be computed.
SHOULD_NOT_DO_GRADIENT(ScatterWeightedSum);
|
||||
SHOULD_NOT_DO_GRADIENT(ScatterAssign);
|
||||
SHOULD_NOT_DO_GRADIENT(Scatter);
|
||||
|
||||
class GetWeightedSumGradient : public GradientMakerBase {
|
||||
using GradientMakerBase::GradientMakerBase;
|
||||
|
@ -738,6 +738,106 @@ class ScatterAssignOp : public Operator<Context> {
|
||||
INPUT_TAGS(DATA, INDICES, SLICES);
|
||||
};
|
||||
|
||||
template <class Context>
|
||||
class ScatterOp : public Operator<CPUContext> {
|
||||
public:
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
|
||||
template <class... Args>
|
||||
explicit ScatterOp(Args&&... args)
|
||||
: Operator<CPUContext>(std::forward<Args>(args)...),
|
||||
OP_SINGLE_ARG(int, "axis", axis_, 1) {
|
||||
}
|
||||
|
||||
virtual ~ScatterOp() noexcept override {}
|
||||
|
||||
bool RunOnDevice() override {
|
||||
|
||||
TORCH_CHECK(Context::GetDeviceType() == kCPU, "ScatterOp currently only supports CPU.")
|
||||
|
||||
return DispatchHelper<TensorTypes<int32_t, int64_t>>::call(
|
||||
this, this->template Input<Tensor>(INDICES, CPU));
|
||||
}
|
||||
|
||||
template <typename IndexType>
|
||||
bool DoRunWithType() {
|
||||
const Tensor& data = Input(DATA);
|
||||
const Tensor& indices = Input(INDICES);
|
||||
const Tensor& updates = Input(UPDATES);
|
||||
const TypeMeta dataType = data.dtype();
|
||||
size_t item_bytesize = dataType.itemsize();
|
||||
|
||||
// ONNX allows negative axis to index from the back, valid range: [-r, r].
|
||||
axis_ = data.canonical_axis_index(axis_);
|
||||
|
||||
CAFFE_ENFORCE_GE(data.dim(), axis_ + 1, "DATA should be at least [axis+1]-D");
|
||||
CAFFE_ENFORCE_GE(axis_, 0, "Axis should be non-negative");
|
||||
CAFFE_ENFORCE_LT(axis_, data.dim(), "Axis out of range");
|
||||
|
||||
Tensor* output = Output(0, data.sizes().vec(), at::dtype(dataType));
|
||||
output->CopyFrom(data);
|
||||
char* out = static_cast<char*>(output->raw_mutable_data(dataType));
|
||||
|
||||
// Succeed if size of output is zero, which can happen for empty batch which
|
||||
// would have data dimension size of 0.
|
||||
// This *must* be done AFTER output->raw_mutable_data() above as that has
|
||||
// important allocation side effect that we must see.
|
||||
if (output->numel() == 0) {
|
||||
return true;
|
||||
}
|
||||
|
||||
const IndexType* idxs = indices.template data<IndexType>();
|
||||
const char* src_base = static_cast<const char*>(updates.raw_data());
|
||||
|
||||
const int64_t outer_dims_product = updates.size_to_dim(axis_);
|
||||
const int64_t block_size = updates.size_from_dim(axis_ + 1);
|
||||
const int64_t block_bytesize = block_size * item_bytesize;
|
||||
|
||||
const int64_t src_indexing_axis_dim = updates.size(axis_);
|
||||
const int64_t src_batch_bytesize = updates.size_from_dim(axis_) * item_bytesize;
|
||||
const int64_t dst_batch_size = data.size_from_dim(axis_) * item_bytesize;
|
||||
|
||||
const int64_t N = indices.size(axis_);
|
||||
|
||||
check_indexarray_range<IndexType>(idxs, N, src_indexing_axis_dim);
|
||||
|
||||
int64_t i = 0;
|
||||
for (int64_t batch = 0; batch < outer_dims_product; ++batch) {
|
||||
int64_t i_max = i + N;
|
||||
for (; i < i_max && i < indices.numel(); ++i) {
|
||||
auto idx = idxs[i];
|
||||
|
||||
auto src = src_base + batch * src_batch_bytesize + idx * block_bytesize;
|
||||
auto dst = out + batch * dst_batch_size + (i - i_max + N) * block_bytesize;
|
||||
context_.CopyItemsSameDevice(dataType, block_size, src, dst);
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
INPUT_TAGS(DATA, INDICES, UPDATES);
|
||||
|
||||
// Check that indices fall within dimension array size with CAFFE_ENFORCE.
|
||||
template <typename IndexType>
|
||||
static void check_indexarray_range(
|
||||
const IndexType* indices,
|
||||
int64_t n,
|
||||
IndexType indexing_axis_dim) {
|
||||
for (auto i = 0; i < n; ++i) {
|
||||
auto idx = indices[i];
|
||||
CAFFE_ENFORCE(
|
||||
0 <= idx && idx < indexing_axis_dim,
|
||||
"INDICES element is out of DATA bounds, id=",
|
||||
idx,
|
||||
" axis_dim=",
|
||||
indexing_axis_dim);
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
int axis_;
|
||||
};
|
||||
|
||||
template <class Context>
|
||||
class LengthsToSegmentIdsOp : public Operator<Context> {
|
||||
public:
|
||||
|
@ -1249,6 +1249,17 @@ class TestCaffe2Backend(unittest.TestCase):
|
||||
x = torch.tensor([1.0, float('nan'), 2.0])
|
||||
self.run_model_test(IsNaNModel(), train=False, input=x, batch_size=BATCH_SIZE, use_gpu=False)
|
||||
|
||||
def test_scatter(self):
    # Exports a module wrapping Tensor.scatter along dim 1 and hands it to
    # run_model_test on CPU in inference mode.
    class ScatterModel(torch.nn.Module):
        def forward(self, input, indices, values):
            return input.scatter(1, indices, values)

    # 3x3 zero destination; each row receives two values at the listed columns.
    destination = torch.zeros(3, 3)
    columns = torch.tensor([[1, 0], [0, 1], [0, 1]], dtype=torch.int64)
    payload = torch.tensor([[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]])

    model_inputs = (destination, columns, payload)
    self.run_model_test(
        ScatterModel(),
        train=False,
        input=model_inputs,
        batch_size=BATCH_SIZE,
        use_gpu=False,
    )
|
||||
|
||||
def test_flatten(self):
|
||||
class FlattenModel(torch.nn.Module):
|
||||
def forward(self, input):
|
||||
|
@ -1542,6 +1542,24 @@ def argmin(g, input, dim, keepdim):
|
||||
return g.op('ArgMin', input, axis_i=dim, keepdims_i=keepdim)
|
||||
|
||||
|
||||
@parse_args('v', 'i', 'v', 'v')
def scatter(g, self, dim, index, src):
    """Emit a ``Scatter`` node for ``self.scatter(dim, index, src)``.

    ``self``, ``index`` and ``src`` become the node's three inputs, in that
    order, and ``dim`` is attached as the integer attribute ``axis``.
    Returns the value produced by ``g.op``.
    """
    scatter_node = g.op("Scatter", self, index, src, axis_i=dim)
    return scatter_node
|
||||
|
||||
|
||||
@parse_args('v', 'i', 'v', 'v')
def scatter_add(g, self, dim, index, src):
    """Emit graph nodes for ``self.scatter_add(dim, index, src)``.

    ``src`` is scattered into a zero constant shaped like ``self`` and the
    result is added to ``self``. The zero constant is created with the same
    element type as ``self`` so that it matches ``src`` for non-float inputs.

    Returns the ``_unimplemented`` result when ``self`` is not a complete
    tensor type, because its sizes are then unavailable for building the
    constant.
    """
    if self.type().kind() != "CompleteTensorType":
        return _unimplemented("scatter_add", "input size not accessible")

    # Map the scalar type name of `self` to its position in the scalar-type
    # tables, then to the matching torch dtype.
    # NOTE(review): assumes sym_help exposes scalar_type_to_pytorch_type
    # alongside scalar_type_to_onnx - confirm against symbolic_helper.
    scalar_type = self.type().scalarType()
    type_index = sym_help.scalar_type_to_onnx.index(sym_help.cast_pytorch_to_onnx[scalar_type])
    torch_dtype = sym_help.scalar_type_to_pytorch_type[type_index]

    zeros = torch.zeros(self.type().sizes(), dtype=torch_dtype)
    to_add = g.op("Constant", value_t=zeros)
    # NOTE(review): Scatter overwrites rather than accumulates, so repeated
    # positions in `index` presumably contribute only one of their values.
    to_add = scatter(g, to_add, dim, index, src)
    return add(g, self, to_add)
|
||||
|
||||
|
||||
def log2(g, self):
    """Emit nodes computing the base-2 logarithm as ``log(self) / ln(2)``.

    The natural-log node is created first, then a one-element constant
    holding ln(2), and finally the ``Div`` node combining them.
    """
    ln_of_two = 0.693147180559945309
    natural_log = log(g, self)
    divisor = g.op('Constant', value_t=torch.Tensor([ln_of_two]))
    return g.op('Div', natural_log, divisor)
|
||||
|
Reference in New Issue
Block a user