Provide a tensor overload to mul_out_sparse_scalar. (#10828)

Summary:
This is a small part of the effort to remove Tensor as a tagged member of Scalar; carrying a Tensor inside Scalar is inconsistent with how we normally do overloads.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/10828

Differential Revision: D9485049

Pulled By: gchanan

fbshipit-source-id: 103f5cc03bb7775cd2d3a0a5c0c5924838055f03
Authored by Gregory Chanan on 2018-08-24 09:36:29 -07:00; committed by Facebook Github Bot
parent e146518e46
commit 474bd60bad
4 changed files with 32 additions and 10 deletions
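
All of the hunks below apply the same C++ overloading pattern: add an overload that takes the zero-dim tensor directly (guarded by AT_ASSERT(value.dim() == 0)), and shrink the existing Scalar entry point to a one-line forwarder through the file-local scalar_tensor() helper. A minimal, self-contained sketch of that shape, using toy stand-in types rather than ATen types (the real signatures appear in the hunks below):

#include <cassert>
#include <iostream>

// Toy stand-ins for at::Tensor / at::Scalar, only to illustrate the overload shape.
struct Tensor { int ndim; double v; int dim() const { return ndim; } };
struct Scalar { double v; };

// Plays the role of the file-local scalar_tensor() helper in the diff:
// wrap a Scalar into a zero-dim Tensor.
Tensor scalar_tensor(Scalar s) { return Tensor{0, s.v}; }

// New overload: consumes the zero-dim Tensor directly, no Scalar rewrapping.
Tensor& mul_out_sparse_zerodim(Tensor& r, const Tensor& t, const Tensor& value) {
  assert(value.dim() == 0);            // same guard as AT_ASSERT(value.dim() == 0)
  r = Tensor{t.ndim, t.v * value.v};   // stand-in for the real sparse multiply
  return r;
}

// The old Scalar entry point becomes a thin wrapper over the Tensor overload.
Tensor& mul_out_sparse_scalar(Tensor& r, const Tensor& t, Scalar value) {
  return mul_out_sparse_zerodim(r, t, scalar_tensor(value));
}

int main() {
  Tensor r{0, 0}, t{2, 3.0};
  std::cout << mul_out_sparse_scalar(r, t, Scalar{2.0}).v << "\n";  // prints 6
}

The real functions operate on sparse tensors and their _values(), but this overload-and-forward structure is what the sparse math hunks below introduce.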

@@ -50,7 +50,7 @@ Tensor& div_out(Tensor& result, const Tensor& self, const Tensor& other) {
       AT_ERROR("div(): sparse division only supports division by a scalar ",
           "(got shape ", other.sizes(), " for argument 'other')");
     }
-    return at::_sparse_div_out(result, self, Scalar(other));
+    return at::_sparse_div_zerodim_out(result, self, other);
   }
   auto iter = TensorIterator::binary_op(result, self, other);
   div_stub(iter->device_type(), *iter);

@@ -1505,7 +1505,13 @@
     CPU: add_out_dense_sparse_cpu
     CUDA: add_out_dense_sparse_cuda

-- func: _sparse_div_out(Tensor result, Tensor self, Scalar other) -> Tensor
+- func: _sparse_div_zerodim_out(Tensor result, Tensor self, Tensor other) -> Tensor
+  variants: function
+  dispatch:
+    SparseCPU: div_out_sparse_zerodim
+    SparseCUDA: div_out_sparse_zerodim
+
+- func: _sparse_div_scalar_out(Tensor result, Tensor self, Scalar other) -> Tensor
   variants: function
   dispatch:
     SparseCPU: div_out_sparse_scalar
@@ -1517,6 +1523,12 @@
     SparseCPU: mul_out_sparse_cpu
     SparseCUDA: mul_out_sparse_cuda

+- func: _sparse_mul_zerodim_out(Tensor result, Tensor self, Tensor other) -> Tensor
+  variants: function
+  dispatch:
+    SparseCPU: mul_out_sparse_zerodim
+    SparseCUDA: mul_out_sparse_zerodim
+
 - func: _sparse_mul_scalar_out(Tensor result, Tensor self, Scalar other) -> Tensor
   variants: function
   dispatch:
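
For orientation: each "- func:" entry above declares a function-variant ATen entry point, and its "dispatch:" block routes the call per backend to the kernel named on the right. The two mul entries correspond roughly to declarations of the following shape; this is a hand-written approximation for illustration, not the actual generated code:

#include <ATen/ATen.h>

namespace at {
// Tensor overload: other is expected to be a zero-dim tensor; dispatched to
// mul_out_sparse_zerodim on SparseCPU/SparseCUDA per the YAML above.
Tensor& _sparse_mul_zerodim_out(Tensor& result, const Tensor& self, const Tensor& other);
// Scalar overload: dispatched to mul_out_sparse_scalar, which after this change
// forwards to the zero-dim kernel.
Tensor& _sparse_mul_scalar_out(Tensor& result, const Tensor& self, Scalar other);
} // namespace at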

@@ -59,9 +59,10 @@ static Tensor scalar_tensor(Scalar s) {
   return tensor;
 }

-SparseTensor& mul_out_sparse_scalar(SparseTensor& r, const SparseTensor& t, Scalar value) {
+SparseTensor& mul_out_sparse_zerodim(SparseTensor& r, const SparseTensor& t, const Tensor& value) {
   AT_ASSERT(r.is_sparse());
   AT_ASSERT(t.is_sparse());
+  AT_ASSERT(value.dim() == 0);

   if (isSameTensor(r, t)) {
     r._values().mul_(value);
@@ -70,13 +71,17 @@ SparseTensor& mul_out_sparse_scalar(SparseTensor& r, const SparseTensor& t, Scal
     r._indices().resize_as_(t._indices());
     r._indices().copy_(t._indices());
     Tensor r_values = r._values(); // Sigh... needed because mul_out takes Tensor&
-    at::mul_out(r_values, t._values(), scalar_tensor(value));
+    at::mul_out(r_values, t._values(), value);
     _get_sparse_impl(r)->set_nnz_and_narrow(t._nnz());
     _get_sparse_impl(r)->set_coalesced(t.is_coalesced());
   }
   return r;
 }

+SparseTensor& mul_out_sparse_scalar(SparseTensor& r, const SparseTensor& t, Scalar value) {
+  return mul_out_sparse_zerodim(r, t, scalar_tensor(value));
+}
+
 // --------------------------------------------------------------------
 // log1p(SparseTensor)
 // --------------------------------------------------------------------
@@ -139,9 +144,10 @@ SparseTensor pow_sparse_scalar(const SparseTensor& t, Scalar value) {
 // div(SparseTensor, Scalar)
 // --------------------------------------------------------------------

-SparseTensor& div_out_sparse_scalar(SparseTensor& r, const SparseTensor& t, Scalar value) {
+SparseTensor& div_out_sparse_zerodim(SparseTensor& r, const SparseTensor& t, const Tensor& value) {
   AT_ASSERT(r.is_sparse());
   AT_ASSERT(t.is_sparse());
+  AT_ASSERT(value.dim() == 0);

   if (isSameTensor(r, t)) {
     r._values().div_(value);
@@ -150,13 +156,17 @@ SparseTensor& div_out_sparse_scalar(SparseTensor& r, const SparseTensor& t, Scal
     r._indices().resize_as_(t._indices());
     r._indices().copy_(t._indices());
     Tensor r_values = r._values(); // Sigh... needed because div_out takes Tensor&
-    at::div_out(r_values, t._values(), scalar_tensor(value));
+    at::div_out(r_values, t._values(), value);
     _get_sparse_impl(r)->set_nnz_and_narrow(t._nnz());
     _get_sparse_impl(r)->set_coalesced(t.is_coalesced());
   }
   return r;
 }

+SparseTensor& div_out_sparse_scalar(SparseTensor& r, const SparseTensor& t, Scalar value) {
+  return div_out_sparse_zerodim(r, t, scalar_tensor(value));
+}
+
 // --------------------------------------------------------------------
 // norm(SparseTensor, Scalar)
 // --------------------------------------------------------------------
@@ -345,9 +355,9 @@ Tensor& add_out_dense_sparse_cpu(Tensor& r, const Tensor& dense, SparseTensorRef

 SparseTensor& mul_out_sparse_cpu(SparseTensor& r, const Tensor& t_, const Tensor& src_) {
   if (src_.dim() == 0) {
-    return mul_out_sparse_scalar(r, t_, Scalar(src_));
+    return mul_out_sparse_zerodim(r, t_, src_);
   } else if (t_.dim() == 0) {
-    return mul_out_sparse_scalar(r, src_, Scalar(t_));
+    return mul_out_sparse_zerodim(r, src_, t_);
   }

   AT_CHECK(t_.sizes().equals(src_.sizes()), "mul operands have incompatible sizes");

@@ -408,9 +408,9 @@ SparseTensor& add_out_sparse_cuda(SparseTensor& r_, const SparseTensor& t, const
 SparseTensor& mul_out_sparse_cuda(SparseTensor& r_, const SparseTensor& t_, const SparseTensor& src_) {
 #ifndef __HIP_PLATFORM_HCC__
   if (src_.dim() == 0) {
-    return mul_out_sparse_scalar(r_, t_, Scalar(src_));
+    return mul_out_sparse_zerodim(r_, t_, src_);
   } else if (t_.dim() == 0) {
-    return mul_out_sparse_scalar(r_, src_, Scalar(t_));
+    return mul_out_sparse_zerodim(r_, src_, t_);
   }

   AT_ASSERT(t_.is_cuda()); // dispatch argument
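
Finally, a hedged usage sketch of the two user-visible paths. It requires a libtorch build; at::sparse_coo_tensor, at::scalar_tensor, and at::arange are assumed from a present-day ATen API rather than taken from this commit, and the routing noted in the comments is as of this change:

#include <ATen/ATen.h>
#include <iostream>

int main() {
  // A small sparse COO matrix: ones on the first three diagonal entries of a 4x4.
  at::Tensor indices = at::arange(3, at::kLong).unsqueeze(0).repeat({2, 1}); // shape 2 x 3
  at::Tensor values  = at::ones({3});
  at::Tensor sp = at::sparse_coo_tensor(indices, values, {4, 4});

  // A zero-dim dense multiplier is what the dim() == 0 checks above handle:
  // mul_out_sparse_cpu / mul_out_sparse_cuda route it to mul_out_sparse_zerodim.
  at::Tensor a = sp * at::scalar_tensor(2.0);

  // Multiplying by a plain Scalar is also supported; in this tree the Scalar
  // entry point (mul_out_sparse_scalar) is now a thin forwarder to the
  // zero-dim implementation.
  at::Tensor b = sp * 2.0;

  std::cout << a.to_dense() << "\n" << b.to_dense() << std::endl;
  return 0;
}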