mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add back at::_copy_from for use by XLA (#20783)
Summary: XLA needs a way to override CPUTensor.copy_(XLATensor), but we only dispatch on the "self" argument. This inverts the dispatch order when "src" is an unhandled type. Note that things like XLATensor.copy_(CPUTensor) never enter this implementation. cc dlibenzi Pull Request resolved: https://github.com/pytorch/pytorch/pull/20783 Differential Revision: D15443187 Pulled By: colesbury fbshipit-source-id: 4ee93ba598ef0fed2a99c0683aae30cb50a1f99c
This commit is contained in:
committed by
Facebook Github Bot
parent
80aed36fb6
commit
f039401bf2
@ -70,6 +70,13 @@ void copy_same_type_transpose_(Tensor& self, const Tensor& src) {
|
||||
});
|
||||
}
|
||||
|
||||
// Devices directly supported by this copy implementation. Other device types
|
||||
// (e.g. XLA) may be supported by overriding copy_ and _copy_from.
|
||||
bool is_supported_device(Device device) {
|
||||
DeviceType device_type = device.type();
|
||||
return device_type == kCPU || device_type == kCUDA || device_type == kHIP;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace at {
|
||||
@ -80,7 +87,6 @@ Tensor & copy_(Tensor & self, const Tensor & src, bool non_blocking) {
|
||||
TORCH_CHECK(self.defined(), "self is undefined");
|
||||
TORCH_CHECK(self.defined(), "src is undefined");
|
||||
|
||||
Tensor b_src;
|
||||
if (self.is_sparse() && src.is_sparse()) {
|
||||
return at::copy_sparse_to_sparse_(self, src, non_blocking);
|
||||
} else if (self.is_sparse() || src.is_sparse()) {
|
||||
@ -92,6 +98,15 @@ Tensor & copy_(Tensor & self, const Tensor & src, bool non_blocking) {
|
||||
return self;
|
||||
}
|
||||
|
||||
// Re-dispatch copies when src device not implemented here (e.g. XLA).
|
||||
// This includes: cpu_tensor.copy_(xla_tensor) which
|
||||
// calls xla_tensor._copy_from(cpu_tensor)
|
||||
if (!is_supported_device(src.device())) {
|
||||
TORCH_INTERNAL_ASSERT(is_supported_device(self.device()));
|
||||
at::_copy_from(src, self, non_blocking);
|
||||
return self;
|
||||
}
|
||||
|
||||
if (self.scalar_type() == kQUInt8) {
|
||||
return quantized_copy_(self, src);
|
||||
}
|
||||
|
@ -453,6 +453,9 @@
|
||||
variants: method
|
||||
device_guard: False
|
||||
|
||||
- func: _copy_from(Tensor self, Tensor dst, bool non_blocking=False) -> Tensor
|
||||
dispatch: {}
|
||||
|
||||
- func: cos(Tensor self) -> Tensor
|
||||
variants: function, method
|
||||
|
||||
|
Reference in New Issue
Block a user