mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
ONNX Export Scatter
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18543 Differential Revision: D14658639 Pulled By: houseroad fbshipit-source-id: 5d7821b54d2fc93f71120155adf328897d13aff6
This commit is contained in:
committed by
Facebook Github Bot
parent
fea4a56af3
commit
8d7a025703
@ -51,6 +51,7 @@ REGISTER_CPU_OPERATOR(
|
|||||||
ScatterWeightedSum,
|
ScatterWeightedSum,
|
||||||
ScatterWeightedSumOp<float, CPUContext>);
|
ScatterWeightedSumOp<float, CPUContext>);
|
||||||
REGISTER_CPU_OPERATOR(ScatterAssign, ScatterAssignOp<CPUContext>);
|
REGISTER_CPU_OPERATOR(ScatterAssign, ScatterAssignOp<CPUContext>);
|
||||||
|
REGISTER_CPU_OPERATOR(Scatter, ScatterOp<CPUContext>);
|
||||||
|
|
||||||
REGISTER_CPU_OPERATOR(LengthsToShape, LengthsToShapeOp<CPUContext>);
|
REGISTER_CPU_OPERATOR(LengthsToShape, LengthsToShapeOp<CPUContext>);
|
||||||
REGISTER_CPU_OPERATOR(HasElements, HasElementsOp<CPUContext>);
|
REGISTER_CPU_OPERATOR(HasElements, HasElementsOp<CPUContext>);
|
||||||
@ -369,6 +370,38 @@ Currently only works on CPU because of access to INDICES.
|
|||||||
"Update slices, with shape len(INDICES) + shape(X_0)[1:]")
|
"Update slices, with shape len(INDICES) + shape(X_0)[1:]")
|
||||||
.Output(0, "DATA", "Has to be exactly the same tensor as the input 0");
|
.Output(0, "DATA", "Has to be exactly the same tensor as the input 0");
|
||||||
|
|
||||||
|
OPERATOR_SCHEMA(Scatter)
|
||||||
|
.NumInputs(3)
|
||||||
|
.NumOutputs(1)
|
||||||
|
.AllowInplace({{0, 0}})
|
||||||
|
.SetDoc(R"DOC(
|
||||||
|
Update values of the tensor by overriding current value specified by indices.
|
||||||
|
|
||||||
|
Writes all values from the tensor UPDATES into DATA at the indices specified in the INDICES tensor.
|
||||||
|
For each value in DATA, its output index is specified by its index in UPDATES and by the corresponding value in INDICES for the specified axis.
|
||||||
|
|
||||||
|
For a 3-D tensor, DATA is updated as:
|
||||||
|
|
||||||
|
DATA[INDICES[i][j][k]][j][k] = UPDATES[i][j][k] # if axis == 0
|
||||||
|
DATA[i][INDICES[i][j][k]][k] = UPDATES[i][j][k] # if axis == 1
|
||||||
|
DATA[i][j][INDICES[i][j][k]] = UPDATES[i][j][k] # if axis == 2
|
||||||
|
|
||||||
|
Currently only works on CPU because of access to INDICES.
|
||||||
|
)DOC")
|
||||||
|
.Input(0, "DATA", "Tensor to be updated.")
|
||||||
|
.Input(
|
||||||
|
1,
|
||||||
|
"INDICES",
|
||||||
|
"1-D list of indices on the first dimension"
|
||||||
|
"of X_0 that need to be updated")
|
||||||
|
.Input(
|
||||||
|
2,
|
||||||
|
"UPDATES",
|
||||||
|
"Update slices, with shape len(INDICES) + shape(X_0)[1:]")
|
||||||
|
.Output(0, "OUTPUT", "The updated output.")
|
||||||
|
.Arg(
|
||||||
|
"axis",
|
||||||
|
"*(type: int; default: 1)* Which dimension to scatter on.");
|
||||||
|
|
||||||
OPERATOR_SCHEMA(HasElements)
|
OPERATOR_SCHEMA(HasElements)
|
||||||
.NumInputs(1)
|
.NumInputs(1)
|
||||||
@ -739,6 +772,7 @@ REGISTER_GRADIENT(Sum, GetSumGradient);
|
|||||||
|
|
||||||
SHOULD_NOT_DO_GRADIENT(ScatterWeightedSum);
|
SHOULD_NOT_DO_GRADIENT(ScatterWeightedSum);
|
||||||
SHOULD_NOT_DO_GRADIENT(ScatterAssign);
|
SHOULD_NOT_DO_GRADIENT(ScatterAssign);
|
||||||
|
SHOULD_NOT_DO_GRADIENT(Scatter);
|
||||||
|
|
||||||
class GetWeightedSumGradient : public GradientMakerBase {
|
class GetWeightedSumGradient : public GradientMakerBase {
|
||||||
using GradientMakerBase::GradientMakerBase;
|
using GradientMakerBase::GradientMakerBase;
|
||||||
|
@ -738,6 +738,106 @@ class ScatterAssignOp : public Operator<Context> {
|
|||||||
INPUT_TAGS(DATA, INDICES, SLICES);
|
INPUT_TAGS(DATA, INDICES, SLICES);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <class Context>
|
||||||
|
class ScatterOp : public Operator<CPUContext> {
|
||||||
|
public:
|
||||||
|
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||||
|
|
||||||
|
template <class... Args>
|
||||||
|
explicit ScatterOp(Args&&... args)
|
||||||
|
: Operator<CPUContext>(std::forward<Args>(args)...),
|
||||||
|
OP_SINGLE_ARG(int, "axis", axis_, 1) {
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual ~ScatterOp() noexcept override {}
|
||||||
|
|
||||||
|
bool RunOnDevice() override {
|
||||||
|
|
||||||
|
TORCH_CHECK(Context::GetDeviceType() == kCPU, "ScatterOp currently only supports CPU.")
|
||||||
|
|
||||||
|
return DispatchHelper<TensorTypes<int32_t, int64_t>>::call(
|
||||||
|
this, this->template Input<Tensor>(INDICES, CPU));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename IndexType>
|
||||||
|
bool DoRunWithType() {
|
||||||
|
const Tensor& data = Input(DATA);
|
||||||
|
const Tensor& indices = Input(INDICES);
|
||||||
|
const Tensor& updates = Input(UPDATES);
|
||||||
|
const TypeMeta dataType = data.dtype();
|
||||||
|
size_t item_bytesize = dataType.itemsize();
|
||||||
|
|
||||||
|
// ONNX allows negative axis to index from the back, valid range: [-r, r].
|
||||||
|
axis_ = data.canonical_axis_index(axis_);
|
||||||
|
|
||||||
|
CAFFE_ENFORCE_GE(data.dim(), axis_ + 1, "DATA should be at least [axis+1]-D");
|
||||||
|
CAFFE_ENFORCE_GE(axis_, 0, "Axis should be non-negative");
|
||||||
|
CAFFE_ENFORCE_LT(axis_, data.dim(), "Axis out of range");
|
||||||
|
|
||||||
|
Tensor* output = Output(0, data.sizes().vec(), at::dtype(dataType));
|
||||||
|
output->CopyFrom(data);
|
||||||
|
char* out = static_cast<char*>(output->raw_mutable_data(dataType));
|
||||||
|
|
||||||
|
// Succeed if size of output is zero, which can happen for empty batch which
|
||||||
|
// would have data dimension size of 0.
|
||||||
|
// This *must* be done AFTER output->raw_mutable_data() above as that has
|
||||||
|
// important allocation side effect that we must see.
|
||||||
|
if (output->numel() == 0) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
const IndexType* idxs = indices.template data<IndexType>();
|
||||||
|
const char* src_base = static_cast<const char*>(updates.raw_data());
|
||||||
|
|
||||||
|
const int64_t outer_dims_product = updates.size_to_dim(axis_);
|
||||||
|
const int64_t block_size = updates.size_from_dim(axis_ + 1);
|
||||||
|
const int64_t block_bytesize = block_size * item_bytesize;
|
||||||
|
|
||||||
|
const int64_t src_indexing_axis_dim = updates.size(axis_);
|
||||||
|
const int64_t src_batch_bytesize = updates.size_from_dim(axis_) * item_bytesize;
|
||||||
|
const int64_t dst_batch_size = data.size_from_dim(axis_) * item_bytesize;
|
||||||
|
|
||||||
|
const int64_t N = indices.size(axis_);
|
||||||
|
|
||||||
|
check_indexarray_range<IndexType>(idxs, N, src_indexing_axis_dim);
|
||||||
|
|
||||||
|
int64_t i = 0;
|
||||||
|
for (int64_t batch = 0; batch < outer_dims_product; ++batch) {
|
||||||
|
int64_t i_max = i + N;
|
||||||
|
for (; i < i_max && i < indices.numel(); ++i) {
|
||||||
|
auto idx = idxs[i];
|
||||||
|
|
||||||
|
auto src = src_base + batch * src_batch_bytesize + idx * block_bytesize;
|
||||||
|
auto dst = out + batch * dst_batch_size + (i - i_max + N) * block_bytesize;
|
||||||
|
context_.CopyItemsSameDevice(dataType, block_size, src, dst);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
INPUT_TAGS(DATA, INDICES, UPDATES);
|
||||||
|
|
||||||
|
// Check that indices fall within dimension array size with CAFFE_ENFORCE.
|
||||||
|
template <typename IndexType>
|
||||||
|
static void check_indexarray_range(
|
||||||
|
const IndexType* indices,
|
||||||
|
int64_t n,
|
||||||
|
IndexType indexing_axis_dim) {
|
||||||
|
for (auto i = 0; i < n; ++i) {
|
||||||
|
auto idx = indices[i];
|
||||||
|
CAFFE_ENFORCE(
|
||||||
|
0 <= idx && idx < indexing_axis_dim,
|
||||||
|
"INDICES element is out of DATA bounds, id=",
|
||||||
|
idx,
|
||||||
|
" axis_dim=",
|
||||||
|
indexing_axis_dim);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
int axis_;
|
||||||
|
};
|
||||||
|
|
||||||
template <class Context>
|
template <class Context>
|
||||||
class LengthsToSegmentIdsOp : public Operator<Context> {
|
class LengthsToSegmentIdsOp : public Operator<Context> {
|
||||||
public:
|
public:
|
||||||
|
@ -1249,6 +1249,17 @@ class TestCaffe2Backend(unittest.TestCase):
|
|||||||
x = torch.tensor([1.0, float('nan'), 2.0])
|
x = torch.tensor([1.0, float('nan'), 2.0])
|
||||||
self.run_model_test(IsNaNModel(), train=False, input=x, batch_size=BATCH_SIZE, use_gpu=False)
|
self.run_model_test(IsNaNModel(), train=False, input=x, batch_size=BATCH_SIZE, use_gpu=False)
|
||||||
|
|
||||||
|
def test_scatter(self):
|
||||||
|
class ScatterModel(torch.nn.Module):
|
||||||
|
def forward(self, input, indices, values):
|
||||||
|
return input.scatter(1, indices, values)
|
||||||
|
|
||||||
|
input = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
|
||||||
|
indices = torch.tensor([[1, 0], [0, 1], [0, 1]], dtype=torch.int64)
|
||||||
|
values = torch.tensor([[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]])
|
||||||
|
self.run_model_test(ScatterModel(), train=False, input=(input, indices, values),
|
||||||
|
batch_size=BATCH_SIZE, use_gpu=False)
|
||||||
|
|
||||||
def test_flatten(self):
|
def test_flatten(self):
|
||||||
class FlattenModel(torch.nn.Module):
|
class FlattenModel(torch.nn.Module):
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
|
@ -1542,6 +1542,24 @@ def argmin(g, input, dim, keepdim):
|
|||||||
return g.op('ArgMin', input, axis_i=dim, keepdims_i=keepdim)
|
return g.op('ArgMin', input, axis_i=dim, keepdims_i=keepdim)
|
||||||
|
|
||||||
|
|
||||||
|
@parse_args('v', 'i', 'v', 'v')
|
||||||
|
def scatter(g, self, dim, index, src):
|
||||||
|
return g.op("Scatter", self, index, src, axis_i=dim)
|
||||||
|
|
||||||
|
|
||||||
|
@parse_args('v', 'i', 'v', 'v')
|
||||||
|
def scatter_add(g, self, dim, index, src):
|
||||||
|
if self.type().kind() != "CompleteTensorType":
|
||||||
|
return _unimplemented("scatter_add", "input size not accesible")
|
||||||
|
dtype = self.type().scalarType()
|
||||||
|
dtype = sym_help.scalar_type_to_onnx.index(sym_help.cast_pytorch_to_onnx[dtype])
|
||||||
|
dims = self.type().sizes()
|
||||||
|
to_add = torch.zeros(dims)
|
||||||
|
to_add = g.op("Constant", value_t=to_add)
|
||||||
|
to_add = scatter(g, to_add, dim, index, src)
|
||||||
|
return add(g, self, to_add)
|
||||||
|
|
||||||
|
|
||||||
def log2(g, self):
|
def log2(g, self):
|
||||||
_ln2 = 0.693147180559945309
|
_ln2 = 0.693147180559945309
|
||||||
return g.op('Div', log(g, self), g.op('Constant', value_t=torch.Tensor([_ln2])))
|
return g.op('Div', log(g, self), g.op('Constant', value_t=torch.Tensor([_ln2])))
|
||||||
|
Reference in New Issue
Block a user