Add back at::_copy_from for use by XLA (#20783)

Summary:
XLA needs a way to override CPUTensor.copy_(XLATensor), but we only
dispatch on the "self" argument. This inverts the dispatch order when
"src" is an unhandled type.

Note that things like XLATensor.copy_(CPUTensor) never enter this
implementation.

cc dlibenzi
Pull Request resolved: https://github.com/pytorch/pytorch/pull/20783

Differential Revision: D15443187

Pulled By: colesbury

fbshipit-source-id: 4ee93ba598ef0fed2a99c0683aae30cb50a1f99c
This commit is contained in:
Sam Gross
2019-05-23 08:43:41 -07:00
committed by Facebook GitHub Bot
parent 80aed36fb6
commit f039401bf2
2 changed files with 19 additions and 1 deletions

View File

@ -70,6 +70,13 @@ void copy_same_type_transpose_(Tensor& self, const Tensor& src) {
});
}
// Predicate for device types that this copy implementation handles natively.
// Anything else (e.g. XLA) is expected to provide its own copy_ / _copy_from
// overrides, which the caller re-dispatches to.
bool is_supported_device(Device device) {
  switch (device.type()) {
    case kCPU:
    case kCUDA:
    case kHIP:
      return true;
    default:
      return false;
  }
}
} // namespace
namespace at {
@ -80,7 +87,6 @@ Tensor & copy_(Tensor & self, const Tensor & src, bool non_blocking) {
TORCH_CHECK(self.defined(), "self is undefined");
TORCH_CHECK(src.defined(), "src is undefined");
Tensor b_src;
if (self.is_sparse() && src.is_sparse()) {
return at::copy_sparse_to_sparse_(self, src, non_blocking);
} else if (self.is_sparse() || src.is_sparse()) {
@ -92,6 +98,15 @@ Tensor & copy_(Tensor & self, const Tensor & src, bool non_blocking) {
return self;
}
// Re-dispatch copies when src device not implemented here (e.g. XLA).
// This includes: cpu_tensor.copy_(xla_tensor) which
// calls xla_tensor._copy_from(cpu_tensor)
if (!is_supported_device(src.device())) {
TORCH_INTERNAL_ASSERT(is_supported_device(self.device()));
at::_copy_from(src, self, non_blocking);
return self;
}
if (self.scalar_type() == kQUInt8) {
return quantized_copy_(self, src);
}

View File

@ -453,6 +453,9 @@
variants: method
device_guard: False
- func: _copy_from(Tensor self, Tensor dst, bool non_blocking=False) -> Tensor
dispatch: {}
- func: cos(Tensor self) -> Tensor
variants: function, method