mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add Accelerator device and shell hooks (#119329)
This adds a concept of Accelerator that points to one of our devices. See DeviceAccelerator.h in this PR for details https://github.com/pytorch/pytorch/pull/119329/files#diff-83cc748bed5df1a453c272cc5ecc7e572d4eb694c5125384d8fbd17a0b5f50c8 It also adds scaffolding for shared C++ API to allow generic feature implementation. This PR in particular updates the autograd engine to use this generic API. Pull Request resolved: https://github.com/pytorch/pytorch/pull/119329 Approved by: https://github.com/ezyang
This commit is contained in:
@@ -239,9 +239,13 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
    * elements are on different devices (across multiple GPUs, for example)
    * they may have different streams.
    */
-  c10::optional<c10::Stream> stream(const c10::DeviceType device_type) {
+  c10::optional<c10::Stream> stream() {
+    auto opt_device_type = at::getAccelerator();
+    if (!opt_device_type.has_value()) {
+      return c10::nullopt;
+    }
     for (const auto& metadata : input_metadata_) {
-      if (metadata.device().type() == device_type)
+      if (metadata.device().type() == opt_device_type.value())
         return metadata.stream();
     }
Reference in New Issue
Block a user