#include #include #include #include #include #include #include namespace torch::mtia { constexpr c10::DeviceType kMTIADeviceType = c10::DeviceType::MTIA; constexpr c10::DeviceIndex kMTIADeviceCount = 2; static thread_local c10::DeviceIndex current_device = 0; static thread_local std::array current_streams = {c10::Stream::unpack3(0, 0, c10::DeviceType::MTIA), c10::Stream::unpack3(0, 1, c10::DeviceType::MTIA)}; static int64_t stream_id_gen = 1; static int64_t event_id_gen = 1; static std::array default_streams = { c10::Stream::unpack3(0, 0, c10::DeviceType::MTIA), c10::Stream::unpack3(0, 1, c10::DeviceType::MTIA)}; struct MTIAGuardImpl final : public c10::impl::DeviceGuardImplInterface { MTIAGuardImpl() = default; explicit MTIAGuardImpl(c10::DeviceType t) { TORCH_INTERNAL_ASSERT(t == kMTIADeviceType); } c10::DeviceType type() const override { return kMTIADeviceType; } c10::Device exchangeDevice(c10::Device d) const override { c10::Device old_device = getDevice(); if (old_device.index() != d.index()) { setDevice(d); } return old_device; } c10::Device getDevice() const override { return c10::Device(kMTIADeviceType, current_device); } void setDevice(c10::Device d) const override { if (getDevice().index() != d.index()) { current_device = d.index(); } } void uncheckedSetDevice(c10::Device d) const noexcept override { (void)d; } c10::Stream getStream(c10::Device d) const noexcept override { return current_streams[d.index()]; } c10::Stream getNewStream(c10::Device d, int priority = 0) const override { (void)priority; return c10::Stream::unpack3(stream_id_gen++, d.index(), d.type()); } c10::Stream getDefaultStream(c10::Device d) const override { return default_streams[d.index()]; } c10::Stream getStreamFromGlobalPool( c10::Device d, bool isHighPriority = false) const override { return c10::Stream::unpack3(stream_id_gen++, d.index(), d.type()); } // NB: These do NOT set the current device c10::Stream exchangeStream(c10::Stream s) const noexcept override { c10::Stream old_stream = getStream(s.device()); return old_stream; } c10::DeviceIndex deviceCount() const noexcept override { return kMTIADeviceCount; } void destroyEvent(void* event, const c10::DeviceIndex device_index) const noexcept override { (void)device_index; } void record( void** event, const c10::Stream& stream, const c10::DeviceIndex device_index, const c10::EventFlag flag) const override { TORCH_CHECK( device_index == -1 || device_index == stream.device_index(), "Event device index ", device_index, " does not match recording stream's device index ", stream.device_index(), "."); const auto orig_device = getDevice(); setDevice(stream.device()); if (*event == nullptr) { *event = reinterpret_cast(event_id_gen++); } setDevice(orig_device); } void block(void* event, const c10::Stream& stream) const override { (void)event; (void)stream; } // May be called from any device bool queryEvent(void* event) const override { (void)event; return true; } // Stream-related functions bool queryStream(const c10::Stream& stream) const override { (void)stream; return true; } void synchronizeStream(const c10::Stream& stream) const override { (void)stream; } void recordDataPtrOnStream( const c10::DataPtr& data_ptr, const c10::Stream& stream) const override { (void)data_ptr; (void)stream; } double elapsedTime(void* event1, void* event2, const c10::DeviceIndex device_index) const override { (void)device_index; uint64_t elapsed_time = 1e6; return (double)(elapsed_time / 1e6); } void synchronizeEvent(void* event) const override { (void)event; } }; struct MTIAHooks : public at::MTIAHooksInterface { explicit MTIAHooks(at::MTIAHooksArgs) {} void init() const override {} bool hasMTIA() const override { return true; } c10::DeviceIndex deviceCount() const override { torch::utils::device_lazy_init(at::kMTIA); return c10::DeviceIndex(2); } void deviceSynchronize(c10::DeviceIndex device_index) const override { torch::utils::device_lazy_init(at::kMTIA); (void)device_index; } std::string showConfig() const override { return "None config"; } c10::DeviceIndex exchangeDevice(c10::DeviceIndex device) const override { torch::utils::device_lazy_init(at::kMTIA); auto orig_device = current_device; if (current_device != device) { current_device = device; } return orig_device; } c10::DeviceIndex maybeExchangeDevice(c10::DeviceIndex device) const override { torch::utils::device_lazy_init(at::kMTIA); auto orig_device = current_device; if (current_device != device) { current_device = device; } return orig_device; } c10::Stream getDefaultStream(c10::DeviceIndex device) const override { torch::utils::device_lazy_init(at::kMTIA); return default_streams[device]; } c10::Stream getCurrentStream(c10::DeviceIndex device) const override { torch::utils::device_lazy_init(at::kMTIA); return current_streams[device]; } void setCurrentStream(const c10::Stream& stream) const override { torch::utils::device_lazy_init(at::kMTIA); current_streams[stream.device_index()] = stream; } c10::DeviceIndex getCurrentDevice() const override { torch::utils::device_lazy_init(at::kMTIA); return current_device; } void setCurrentDevice(c10::DeviceIndex device) const override { torch::utils::device_lazy_init(at::kMTIA); if (current_device != device) { current_device = device; } } }; using at::MTIAHooksRegistry; using at::RegistererMTIAHooksRegistry; REGISTER_MTIA_HOOKS(MTIAHooks); C10_REGISTER_GUARD_IMPL(MTIA, MTIAGuardImpl); } // namespace torch::mtia