Mirror of https://github.com/pytorch/pytorch.git
synced 2025-11-16 07:24:54 +08:00
Update on "[inductor] fx pass to split matmuls"
cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben [ghstack-poisoned]
This commit is contained in:
@@ -94,11 +94,6 @@ TORCH_API inline void resetPeakStats(c10::DeviceIndex device_index) {
   at::getDeviceAllocator(device_type)->resetPeakStats(device_index);
 }
 
-TORCH_API inline std::pair<size_t, size_t> getMemoryInfo(
-    c10::DeviceIndex device_index) {
-  const auto device_type = getAccelerator(true).value();
-  return at::getDeviceAllocator(device_type)->getMemoryInfo(device_index);
-}
 } // namespace at::accelerator
 
 namespace at {

@@ -4389,7 +4389,7 @@
   variants: function, method
   dispatch:
     CompositeExplicitAutograd: mv
-    SparseCPU, SparseCUDA, SparseMPS: mv_sparse
+    SparseCPU, SparseCUDA: mv_sparse
 
 - func: mv.out(Tensor self, Tensor vec, *, Tensor(a!) out) -> Tensor(a!)
   dispatch:

@@ -1,191 +1,3 @@
 #pragma once
 #include <ATen/xpu/XPUContext.h>
+#include <c10/xpu/XPUEvent.h>
-#include <optional>
-
-namespace at::xpu {
-
-/*
- * XPUEvent are movable not copyable wrappers around SYCL event. XPUEvent are
- * constructed lazily when first recorded. It has a device, and this device is
- * acquired from the first recording stream. Later streams that record the event
- * must match the same device.
- *
- * Currently, XPUEvent does NOT support to export an inter-process event from
- * another process via inter-process communication(IPC). So it means that
- * inter-process communication for event handles between different processes is
- * not available. This could impact some applications that rely on cross-process
- * synchronization and communication.
- */
-struct TORCH_XPU_API XPUEvent {
-  // Constructors
-  XPUEvent(bool enable_timing = false) noexcept
-      : enable_timing_{enable_timing} {}
-
-  ~XPUEvent() {
-    if (isCreated()) {
-      const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
-      if (C10_UNLIKELY(interp)) {
-        (*interp)->trace_gpu_event_deletion(
-            at::kXPU, reinterpret_cast<uintptr_t>(event_.get()));
-      }
-    }
-  }
-
-  XPUEvent(const XPUEvent&) = delete;
-  XPUEvent& operator=(const XPUEvent&) = delete;
-
-  XPUEvent(XPUEvent&& other) = default;
-  XPUEvent& operator=(XPUEvent&& other) = default;
-
-  operator sycl::event&() const {
-    return event();
-  }
-
-  std::optional<at::Device> device() const {
-    if (isCreated()) {
-      return at::Device(at::kXPU, device_index_);
-    } else {
-      return std::nullopt;
-    }
-  }
-
-  inline bool isCreated() const {
-    return (event_.get() != nullptr);
-  }
-
-  DeviceIndex device_index() const {
-    return device_index_;
-  }
-
-  sycl::event& event() const {
-    return *event_;
-  }
-
-  bool query() const {
-    using namespace sycl::info;
-    if (!isCreated()) {
-      return true;
-    }
-
-    return event().get_info<event::command_execution_status>() ==
-        event_command_status::complete;
-  }
-
-  void record() {
-    record(getCurrentXPUStream());
-  }
-
-  void recordOnce(const XPUStream& stream) {
-    if (!isCreated()) {
-      record(stream);
-    }
-  }
-
-  void record(const XPUStream& stream) {
-    if (!isCreated()) {
-      device_index_ = stream.device_index();
-      assignEvent(stream.queue());
-      const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
-      if (C10_UNLIKELY(interp)) {
-        (*interp)->trace_gpu_event_creation(
-            at::kXPU, reinterpret_cast<uintptr_t>(event_.get()));
-      }
-    } else {
-      TORCH_CHECK(
-          device_index_ == stream.device_index(),
-          "Event device ",
-          device_index_,
-          " does not match recording stream's device ",
-          stream.device_index(),
-          ".");
-      reassignEvent(stream.queue());
-    }
-    const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
-    if (C10_UNLIKELY(interp)) {
-      (*interp)->trace_gpu_event_record(
-          at::kXPU,
-          reinterpret_cast<uintptr_t>(event_.get()),
-          reinterpret_cast<uintptr_t>(&stream.queue()));
-    }
-  }
-
-  void block(const XPUStream& stream) {
-    if (isCreated()) {
-      std::vector<sycl::event> event_list{event()};
-      // Make this stream wait until event_ is completed.
-      stream.queue().ext_oneapi_submit_barrier(event_list);
-      const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
-      if (C10_UNLIKELY(interp)) {
-        (*interp)->trace_gpu_event_wait(
-            at::kXPU,
-            reinterpret_cast<uintptr_t>(event_.get()),
-            reinterpret_cast<uintptr_t>(&stream.queue()));
-      }
-    }
-  }
-
-  double elapsed_time(const XPUEvent& other) const {
-    TORCH_CHECK(
-        isCreated() && other.isCreated(),
-        "Both events must be recorded before calculating elapsed time.");
-    TORCH_CHECK(
-        query() && other.query(),
-        "Both events must be completed before calculating elapsed time.");
-    TORCH_CHECK(
-        enable_timing_ && other.enable_timing_,
-        "Both events must be created with argument 'enable_timing=True'.");
-
-#if SYCL_COMPILER_VERSION < 20250000
-    TORCH_CHECK_NOT_IMPLEMENTED(
-        false,
-        "elapsed_time of XPUEvent requires PyTorch to be built with SYCL compiler version 2025.0.0 or newer.");
-#endif
-
-    using namespace sycl::info::event_profiling;
-    // Block until both of the recorded events are completed.
-    uint64_t end_time_ns = other.event().get_profiling_info<command_end>();
-    uint64_t start_time_ns = event().get_profiling_info<command_end>();
-    // Return the eplased time in milliseconds.
-    return 1e-6 *
-        (static_cast<double>(end_time_ns) - static_cast<double>(start_time_ns));
-  }
-
-  void synchronize() const {
-    if (isCreated()) {
-      const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
-      if (C10_UNLIKELY(interp)) {
-        (*interp)->trace_gpu_event_synchronization(
-            at::kXPU, reinterpret_cast<uintptr_t>(event_.get()));
-      }
-      event().wait_and_throw();
-    }
-  }
-
- private:
-  void assignEvent(sycl::queue& queue) {
-#if SYCL_COMPILER_VERSION >= 20250000
-    if (enable_timing_) {
-      event_ = std::make_unique<sycl::event>(
-          sycl::ext::oneapi::experimental::submit_profiling_tag(queue));
-    } else {
-      event_ = std::make_unique<sycl::event>(queue.ext_oneapi_submit_barrier());
-    }
-#else
-    event_ = std::make_unique<sycl::event>(queue.ext_oneapi_submit_barrier());
-#endif
-  }
-
-  void reassignEvent(sycl::queue& queue) {
-    event_.reset();
-    assignEvent(queue);
-  }
-
-  bool enable_timing_ = false;
-  DeviceIndex device_index_ = -1;
-  // Only need to track the last event, as events in an in-order queue are
-  // executed sequentially.
-  std::unique_ptr<sycl::event> event_;
-};
-
-} // namespace at::xpu

@@ -10,7 +10,7 @@ beit_base_patch16_224,pass,7
 
 
 
-convnextv2_nano.fcmae_ft_in22k_in1k,pass,7
+convnextv2_nano.fcmae_ft_in22k_in1k,fail_accuracy,7
 
 
 
@@ -66,7 +66,7 @@ visformer_small,pass,7
 
 
 
-vit_base_patch14_dinov2.lvd142m,pass,7
+vit_base_patch14_dinov2.lvd142m,fail_accuracy,7
 
 
 

@@ -96,10 +96,6 @@ struct C10_API DeviceAllocator : public c10::Allocator {
   // Resets peak memory usage statistics for the specified device
   virtual void resetPeakStats(c10::DeviceIndex device) = 0;
 
-  // Return the free memory size and total memory size in bytes for the
-  // specified device.
-  virtual std::pair<size_t, size_t> getMemoryInfo(c10::DeviceIndex device) = 0;
 };
 
 // This function is used to get the DeviceAllocator for a specific device type

@@ -345,13 +345,6 @@ class CUDAAllocator : public DeviceAllocator {
       c10::DeviceIndex device,
       std::shared_ptr<AllocatorState> pps) = 0;
   virtual std::string name() = 0;
-  std::pair<size_t, size_t> getMemoryInfo(c10::DeviceIndex device) override {
-    c10::DeviceGuard device_guard({at::kCUDA, device});
-    size_t free = 0;
-    size_t total = 0;
-    C10_CUDA_CHECK(cudaMemGetInfo(&free, &total));
-    return {free, total};
-  }
 };
 
 // Allocator object, statically initialized
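The removed override was a thin wrapper around `cudaMemGetInfo`. For orientation only, the long-standing Python entry point that reports the same (free, total) pair is `torch.cuda.mem_get_info`; a minimal sketch, assuming a CUDA build with at least one visible device:

```python
import torch

if torch.cuda.is_available():
    # (free, total) in bytes for device 0, mirroring cudaMemGetInfo
    free, total = torch.cuda.mem_get_info(0)
    print(f"{free / 2**30:.2f} GiB free of {total / 2**30:.2f} GiB")
```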
@@ -24,6 +24,7 @@ set(C10_XPU_HEADERS
   XPUCachingAllocator.h
   XPUDeviceProp.h
   XPUException.h
+  XPUEvent.h
   XPUFunctions.h
   XPUMacros.h
   XPUStream.h

@@ -926,14 +926,15 @@ class DeviceCachingAllocator {
           (release_cached_blocks() && alloc_block(params, true));
     }
     if (!block_found) {
-      const auto& raw_device = c10::xpu::get_raw_device(device);
-      const auto device_total =
-          raw_device.get_info<sycl::info::device::global_mem_size>();
+      c10::xpu::DeviceProp device_prop;
+      c10::xpu::get_device_properties(&device_prop, device);
+      auto device_total = device_prop.global_mem_size;
       // Estimate the available device memory when the SYCL runtime does not
       // support the corresponding aspect (ext_intel_free_memory).
-      size_t device_free = device_total -
+      size_t device_free = device_prop.global_mem_size -
           stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)]
               .current;
+      auto& raw_device = c10::xpu::get_raw_device(device);
       // TODO: Remove the aspect check once the SYCL runtime bug is fixed on
       // affected devices.
       if (raw_device.has(sycl::aspect::ext_intel_free_memory)) {

@@ -1051,37 +1052,21 @@ class DeviceCachingAllocator {
     }
   }
 
-  std::pair<size_t, size_t> getMemoryInfo() {
-    const auto& device = c10::xpu::get_raw_device(device_index);
-    const size_t total = device.get_info<sycl::info::device::global_mem_size>();
-    TORCH_CHECK(
-        device.has(sycl::aspect::ext_intel_free_memory),
-        "The device (",
-        device.get_info<sycl::info::device::name>(),
-        ") doesn't support querying the available free memory. ",
-        "You can file an issue at https://github.com/pytorch/pytorch/issues ",
-        "to help us prioritize its implementation.");
-    const size_t free =
-        device.get_info<sycl::ext::intel::info::device::free_memory>();
-    return {free, total};
-  }
-
   double getMemoryFraction() {
     if (!set_fraction) {
       return 1.0;
     }
 
-    const auto device_total =
-        xpu::get_raw_device(device_index)
-            .get_info<sycl::info::device::global_mem_size>();
+    c10::xpu::DeviceProp device_prop;
+    c10::xpu::get_device_properties(&device_prop, device_index);
     return static_cast<double>(allowed_memory_maximum) /
-        static_cast<double>(device_total);
+        static_cast<double>(device_prop.global_mem_size);
   }
 
   void setMemoryFraction(double fraction) {
-    const auto device_total =
-        xpu::get_raw_device(device_index)
-            .get_info<sycl::info::device::global_mem_size>();
+    c10::xpu::DeviceProp device_prop;
+    c10::xpu::get_device_properties(&device_prop, device_index);
+    auto device_total = device_prop.global_mem_size;
     allowed_memory_maximum = static_cast<size_t>(fraction * device_total);
     set_fraction = true;
   }

@@ -1255,11 +1240,6 @@ class XPUAllocator : public DeviceAllocator {
         c10::xpu::get_raw_device(dev_to_access));
   }
 
-  std::pair<size_t, size_t> getMemoryInfo(DeviceIndex device) override {
-    assertValidDevice(device);
-    return device_allocators[device]->getMemoryInfo();
-  }
-
   double getMemoryFraction(DeviceIndex device) {
     assertValidDevice(device);
     return device_allocators[device]->getMemoryFraction();

c10/xpu/XPUEvent.h (new file, 178 lines)
@@ -0,0 +1,178 @@
+#pragma once
+#include <c10/xpu/XPUStream.h>
+
+namespace c10::xpu {
+
+/*
+ * XPUEvent are movable not copyable wrappers around SYCL event. XPUEvent are
+ * constructed lazily when first recorded. It has a device, and this device is
+ * acquired from the first recording stream. Later streams that record the event
+ * must match the same device.
+ *
+ * Currently, XPUEvent does NOT support to export an inter-process event from
+ * another process via inter-process communication(IPC). So it means that
+ * inter-process communication for event handles between different processes is
+ * not available. This could impact some applications that rely on cross-process
+ * synchronization and communication.
+ */
+struct XPUEvent {
+  // Constructors
+  XPUEvent(bool enable_timing = false) noexcept
+      : enable_timing_{enable_timing} {}
+
+  ~XPUEvent() {
+    if (isCreated()) {
+      const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
+      if (C10_UNLIKELY(interp)) {
+        (*interp)->trace_gpu_event_deletion(
+            c10::kXPU, reinterpret_cast<uintptr_t>(event_.get()));
+      }
+    }
+  }
+
+  C10_DISABLE_COPY_AND_ASSIGN(XPUEvent);
+
+  XPUEvent(XPUEvent&& other) = default;
+  XPUEvent& operator=(XPUEvent&& other) = default;
+
+  operator sycl::event&() const {
+    return event();
+  }
+
+  std::optional<c10::Device> device() const {
+    if (isCreated()) {
+      return c10::Device(c10::kXPU, device_index_);
+    } else {
+      return std::nullopt;
+    }
+  }
+
+  inline bool isCreated() const {
+    return (event_.get() != nullptr);
+  }
+
+  DeviceIndex device_index() const {
+    return device_index_;
+  }
+
+  sycl::event& event() const {
+    return *event_;
+  }
+
+  bool query() const {
+    using namespace sycl::info;
+    if (!isCreated()) {
+      return true;
+    }
+
+    return event().get_info<event::command_execution_status>() ==
+        event_command_status::complete;
+  }
+
+  void record() {
+    record(getCurrentXPUStream());
+  }
+
+  void recordOnce(const XPUStream& stream) {
+    if (!isCreated()) {
+      record(stream);
+    }
+  }
+
+  void record(const XPUStream& stream) {
+    if (!isCreated()) {
+      device_index_ = stream.device_index();
+      assignEvent(stream.queue());
+      const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
+      if (C10_UNLIKELY(interp)) {
+        (*interp)->trace_gpu_event_creation(
+            c10::kXPU, reinterpret_cast<uintptr_t>(event_.get()));
+      }
+    } else {
+      TORCH_CHECK(
+          device_index_ == stream.device_index(),
+          "Event device ",
+          device_index_,
+          " does not match recording stream's device ",
+          stream.device_index(),
+          ".");
+      reassignEvent(stream.queue());
+    }
+    const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
+    if (C10_UNLIKELY(interp)) {
+      (*interp)->trace_gpu_event_record(
+          c10::kXPU,
+          reinterpret_cast<uintptr_t>(event_.get()),
+          reinterpret_cast<uintptr_t>(&stream.queue()));
+    }
+  }
+
+  void block(const XPUStream& stream) {
+    if (isCreated()) {
+      std::vector<sycl::event> event_list{event()};
+      // Make this stream wait until event_ is completed.
+      stream.queue().ext_oneapi_submit_barrier(event_list);
+      const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
+      if (C10_UNLIKELY(interp)) {
+        (*interp)->trace_gpu_event_wait(
+            c10::kXPU,
+            reinterpret_cast<uintptr_t>(event_.get()),
+            reinterpret_cast<uintptr_t>(&stream.queue()));
+      }
+    }
+  }
+
+  double elapsed_time(const XPUEvent& other) const {
+    TORCH_CHECK(
+        isCreated() && other.isCreated(),
+        "Both events must be recorded before calculating elapsed time.");
+    TORCH_CHECK(
+        query() && other.query(),
+        "Both events must be completed before calculating elapsed time.");
+    TORCH_CHECK(
+        enable_timing_ && other.enable_timing_,
+        "Both events must be created with argument 'enable_timing=True'.");
+
+    using namespace sycl::info::event_profiling;
+    // Block until both of the recorded events are completed.
+    uint64_t end_time_ns = other.event().get_profiling_info<command_end>();
+    uint64_t start_time_ns = event().get_profiling_info<command_end>();
+    // Return the eplased time in milliseconds.
+    return 1e-6 *
+        (static_cast<double>(end_time_ns) - static_cast<double>(start_time_ns));
+  }
+
+  void synchronize() const {
+    if (isCreated()) {
+      const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
+      if (C10_UNLIKELY(interp)) {
+        (*interp)->trace_gpu_event_synchronization(
+            c10::kXPU, reinterpret_cast<uintptr_t>(event_.get()));
+      }
+      event().wait_and_throw();
+    }
+  }
+
+ private:
+  void assignEvent(sycl::queue& queue) {
+    if (enable_timing_) {
+      event_ = std::make_unique<sycl::event>(
+          sycl::ext::oneapi::experimental::submit_profiling_tag(queue));
+    } else {
+      event_ = std::make_unique<sycl::event>(queue.ext_oneapi_submit_barrier());
+    }
+  }
+
+  void reassignEvent(sycl::queue& queue) {
+    event_.reset();
+    assignEvent(queue);
+  }
+
+  bool enable_timing_ = false;
+  c10::DeviceIndex device_index_ = -1;
+  // Only need to track the last event, as events in an in-order queue are
+  // executed sequentially.
+  std::unique_ptr<sycl::event> event_;
+};
+
+} // namespace c10::xpu
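For orientation, the Python-level counterpart of this header is `torch.xpu.Event`. A hedged usage sketch of event-based timing on an XPU device (assumes an XPU build; per the checks above, `elapsed_time` needs a SYCL 2025.0+ toolchain and events created with `enable_timing=True`):

```python
import torch

if torch.xpu.is_available():
    start = torch.xpu.Event(enable_timing=True)
    end = torch.xpu.Event(enable_timing=True)

    x = torch.randn(1024, 1024, device="xpu")
    start.record()              # records on the current XPU stream
    y = x @ x
    end.record()
    torch.xpu.synchronize()     # both events are complete after this
    print(f"matmul took {start.elapsed_time(end):.3f} ms")
```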
@@ -40,7 +40,6 @@
     :nosignatures:
 
     empty_cache
-    get_memory_info
     max_memory_allocated
     max_memory_reserved
     memory_allocated

@@ -570,14 +570,22 @@ class TestOverlapPreservingBucketing(InductorTestCase):
 
     def test_split_mm(self):
         def func(a, b):
+            a = a * 2
+            b = b * 3
             mm = torch.mm(a, b)
+            mm = mm * 2
             return mm
 
-        # Trace with make_fx
-        ref_a = torch.randn(16, 8)
-        ref_b = torch.randn(8, 4)
-        gm = make_fx(func)(ref_a, ref_b)
-        ref_out = func(ref_a, ref_b)
+        def _inps():
+            return torch.randn(16, 8, device=self.device), torch.randn(
+                8, 4, device=self.device
+            )
+
+        inps = _inps()
+        ref_out = func(*inps)
+
+        gm = make_fx(func, tracing_mode="fake")(*inps)
 
         from torch._inductor.fx_passes.decompose_mm import split_mms
 
         split_mms(gm, 16, 4)

@@ -587,15 +595,25 @@ class TestOverlapPreservingBucketing(InductorTestCase):
             4,
             exactly=True,
         ).run(graph_str)
-        out = gm(ref_a, ref_b)
+        out = gm(*inps)
 
         self.assertTrue(same(out, ref_out))
 
+    def test_split_mm_noncont(self):
         # Non contiguous matmuls are not split
+        def func(a, b):
+            return torch.mm(a, b)
+
+        def _inps():
+            return torch.empty_strided((16, 8), (1, 8)), torch.randn(8, 4)
+
+        inps = _inps()
+
+        gm = make_fx(func, tracing_mode="fake")(*inps)
+        from torch._inductor.fx_passes.decompose_mm import split_mms
-        ref_a2 = torch.empty_strided((16, 8), (1, 8))
-        ref_b2 = torch.randn(8, 4)
-        gm = make_fx(func)(ref_a2, ref_b2)
 
         split_mms(gm, 16, 4)
+        graph_str = str(gm.graph)
         FileCheck().check_count(
             "torch.ops.aten.mm",
             1,
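The tests above expect `split_mms(gm, 16, 4)` to leave four `aten.mm` calls for a contiguous 16x8 @ 8x4 matmul and a single one when the lhs is non-contiguous. Numerically, the rewrite presumably relies on the row-block identity for matrix products; a standalone sketch of that identity (not the Inductor pass itself):

```python
import torch

a = torch.randn(16, 8)
b = torch.randn(8, 4)

full = torch.mm(a, b)
# Split the 16-row lhs into four 4-row blocks, multiply each block separately,
# then stack the partial results back together.
parts = [torch.mm(chunk, b) for chunk in torch.split(a, 4, dim=0)]
stitched = torch.cat(parts, dim=0)

assert torch.allclose(full, stitched)
```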
@@ -19,6 +19,7 @@ import operator
 import os
 import pickle
 import random
+import re
 import sys
 import tempfile
 import threading

@@ -5635,6 +5636,115 @@ utils_device.CURRENT_DEVICE == None""".split("\n"):
         self.assertTrue(same(res11, res12))
         self.assertTrue(same(res21, res22))
 
+    def test_replay_side_effects_config(self):
+        # Test that replay_side_effects config controls mutation replay
+        def fn(x, lst):
+            lst.append(x + 1)
+            return x * 2
+
+        x = torch.tensor([5.0])
+
+        # Test with replay enabled (default)
+        lst_with_replay = []
+        opt_fn_with_replay = torch.compile(fn, backend="eager")
+        result1 = opt_fn_with_replay(x, lst_with_replay)
+        self.assertEqual(len(lst_with_replay), 1)  # Mutation should be replayed
+        self.assertTrue(same(result1, x * 2))
+
+        torch._dynamo.reset()
+
+        # Test with replay disabled
+        lst_without_replay = []
+        with torch._dynamo.config.patch(
+            replay_side_effects=False, side_effect_replay_policy="warn"
+        ):
+            opt_fn_without_replay = torch.compile(fn, backend="eager")
+            result2 = opt_fn_without_replay(x, lst_without_replay)
+            self.assertEqual(
+                len(lst_without_replay), 0
+            )  # Mutation should NOT be replayed
+            self.assertTrue(same(result2, x * 2))
+
+        torch._dynamo.reset()
+        lst_without_replay = []
+        with torch._dynamo.config.patch(
+            replay_side_effects=False, side_effect_replay_policy="error"
+        ):
+            opt_fn_without_replay = torch.compile(fn, backend="eager")
+            with self.assertRaisesRegex(
+                RuntimeError,
+                re.escape(
+                    "While compiling, we found certain side effects happened in the model.forward. Here are the list of potential sources you can double check: [\"L['lst']\"]"
+                ),
+            ):
+                _ = opt_fn_without_replay(x, lst_without_replay)
+
+    def test_replay_side_effects_model_attr(self):
+        class Bar(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.const = 4
+
+            def forward(self, x):
+                return x.cos()
+
+        class Foo(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.const = 4
+                self.tensor = None
+                self.bar = Bar()
+
+            def forward(self, x):
+                self.const = 5
+                self.tensor = x.sin()
+                res = self.bar(x)
+                return x.cos() + res.sum() + self.tensor
+
+        with torch._dynamo.config.patch(
+            replay_side_effects=False, side_effect_replay_policy="error"
+        ):
+            foo = Foo()
+            with self.assertRaisesRegex(
+                RuntimeError,
+                re.escape(
+                    "While compiling, we found certain side effects happened in the model.forward. Here are the list of potential sources you can double check: [\"L['self']\"]"
+                ),
+            ):
+                torch.compile(foo, fullgraph=True)(torch.randn(4, 4))
+
+        with torch._dynamo.config.patch(
+            replay_side_effects=False, side_effect_replay_policy="silent"
+        ):
+            foo_v2_compile = Foo()
+            foo_v2_eager = Foo()
+            inp = torch.randn(4, 4)
+            res = torch.compile(foo_v2_compile, fullgraph=True)(torch.randn(4, 4))
+            self.assertEqual(foo_v2_compile.tensor, None)
+            self.assertEqual(foo_v2_compile.const, 4)
+            self.assertEqual(foo_v2_compile.bar.const, 4)
+            same(res, foo_v2_eager(inp))
+
+    def test_replay_side_effects_input_mut(self):
+        class Foo(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.const = 4
+                self.tensor = None
+
+            def forward(self, x):
+                x.add_(5)
+                return x.cos()
+
+        # This is ok because we actually capture the graph which
+        # has mutation. In export, we never retrace the actual
+        # gm so we won't see any mutation applied to inputs
+        with torch._dynamo.config.patch(
+            replay_side_effects=False, side_effect_replay_policy="error"
+        ):
+            foo = Foo()
+            torch.compile(foo, fullgraph=True)(torch.randn(4, 4))
+
     def test_list_append_return_none(self):
         def fn(x):
             alist = []

@@ -349,6 +349,18 @@ def forward(self, x):
         res2 = p.generate(input_tensor=inp, input_tensor2=inp2)
         self.assertTrue(torch.allclose(res, res2))
 
+    def test_side_effect(self):
+        global_env = []
+
+        class Foo(torch.nn.Module):
+            def forward(self, x):
+                global_env.append(x)
+                return x.sin()
+
+        with torch._dynamo.config.patch(replay_side_effects=False):
+            _ = dynamo_graph_capture_for_export(Foo())(torch.randn(4, 4))
+        self.assertEqual(len(global_env), 0)
+
     def test_export_add_in_out_info(self):
         class Foo(torch.nn.Module):
             def forward(self, dct, lst, bleh):

@@ -239,12 +239,6 @@ class TestAccelerator(TestCase):
         self.assertEqual(torch.accelerator.max_memory_allocated(), prev_max_allocated)
         self.assertEqual(torch.accelerator.max_memory_reserved(), prev_max_reserved)
 
-    @unittest.skipIf(TEST_MPS, "MPS doesn't support torch.accelerator memory API!")
-    def test_get_memory_info(self):
-        free_bytes, total_bytes = torch.accelerator.get_memory_info()
-        self.assertGreaterEqual(free_bytes, 0)
-        self.assertGreaterEqual(total_bytes, 0)
-
 
 if __name__ == "__main__":
     run_tests()

@@ -2674,6 +2674,7 @@ class TestSparse(TestSparseBase):
         self._test_asin_arcsin(input_uncoalesced, coalesced)
 
     @coalescedonoff
+    @expectedFailureMPS
     @dtypes(torch.double)
     @dtypesIfMPS(torch.float32)
     def test_mv(self, device, dtype, coalesced):
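For context, `mv_sparse` backs `torch.mv` when the matrix argument is sparse. A small sketch of the call this test exercises, under the assumption of a CPU tensor in COO layout:

```python
import torch

indices = torch.tensor([[0, 1, 1], [2, 0, 2]])
values = torch.tensor([3.0, 4.0, 5.0])
mat = torch.sparse_coo_tensor(indices, values, (2, 3)).coalesce()
vec = torch.randn(3)

out_sparse = torch.mv(mat, vec)            # dispatches to the sparse mv kernel
out_dense = torch.mv(mat.to_dense(), vec)  # dense reference
assert torch.allclose(out_sparse, out_dense)
```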
@@ -2501,7 +2501,6 @@ def _accelerator_emptyCache() -> None: ...
 def _accelerator_getDeviceStats(device_index: _int) -> dict[str, Any]: ...
 def _accelerator_resetAccumulatedStats(device_index: _int) -> None: ...
 def _accelerator_resetPeakStats(device_index: _int) -> None: ...
-def _accelerator_getMemoryInfo(device_index: _int) -> tuple[_int, _int]: ...
 def _accelerator_setAllocatorSettings(env: str) -> None: ...
 
 # Defined in torch/csrc/jit/python/python_tracer.cpp

@@ -44,6 +44,19 @@ minimum_call_count = 1
 # turn on/off DCE pass (deprecated: always true)
 dead_code_elimination = True
 
+# Enable or disable side effect replay after graph execution.
+# When False, mutations to Python objects (lists, dicts, attributes) won't be
+# replayed after the compiled graph runs. This can cause correctness issues
+# if your code depends on these mutations being visible. This should probably
+# never be False by default. At the moment, only export will need it.
+replay_side_effects = True
+
+# Configure side effect warning level
+# If `silent`, we silently allow side effects
+# If `warn`, we warn side effects
+# If `error`, we error on side effects
+side_effect_replay_policy = "silent"
+
 # disable (for a function) when cache reaches this size
 
 # controls the maximum number of cache entries with a guard on same ID_MATCH'd
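A condensed sketch of how these two knobs interact, mirroring the new tests earlier in this diff (eager backend, so no real compilation work is needed):

```python
import torch

def fn(x, log):
    log.append(x.sum())  # Python-level side effect observed during tracing
    return x * 2

log = []
with torch._dynamo.config.patch(
    replay_side_effects=False, side_effect_replay_policy="warn"
):
    out = torch.compile(fn, backend="eager")(torch.ones(3), log)

print(len(log))  # 0: the list mutation was traced but never replayed
```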
@@ -1845,7 +1845,7 @@ class OutputGraph(OutputGraphCommon):
                 [create_instruction("DELETE_FAST", argval=graph_output_var)]
             )
 
-        if self.export:
+        if torch._dynamo.config.side_effect_replay_policy in ["warn", "error"]:
             from torch.export._trace import _ExportModuleSpecTrackerDict
 
             potential_side_effects = []

@@ -1881,10 +1881,16 @@
             ]
 
             if side_effect_refs:
-                warnings.warn(
-                    f"While exporting, we found certain side effects happened in the model.forward. "
-                    f"Here are the list of potential sources you can double check: {side_effect_refs}"
-                )
+                if torch._dynamo.config.side_effect_replay_policy == "warn":
+                    warnings.warn(
+                        f"While compiling, we found certain side effects happened in the model.forward. "
+                        f"Here are the list of potential sources you can double check: {side_effect_refs}"
+                    )
+                else:
+                    raise RuntimeError(
+                        f"While compiling, we found certain side effects happened in the model.forward. "
+                        f"Here are the list of potential sources you can double check: {side_effect_refs}"
+                    )
 
         return all_stack_locals_metas
 

@@ -1930,7 +1936,8 @@ class OutputGraph(OutputGraphCommon):
             assert self.backward_state_var is not None
             cg.append_output(cg.create_load(self.backward_state_var))
             cg.store_attr(name)
-        self.side_effects.codegen_hooks(cg)
+        if config.replay_side_effects:
+            self.side_effects.codegen_hooks(cg)
 
         # TODO get debug_locals working for nested graph breaks
         # Return variables used for logging at the end

@@ -1945,7 +1952,8 @@ class OutputGraph(OutputGraphCommon):
         self.codegen_cells(tx, cg)
 
         cg.restore_stack(stack_values, value_from_source=not tx.export)
-        self.side_effects.codegen_update_mutated(cg)
+        if config.replay_side_effects:
+            self.side_effects.codegen_update_mutated(cg)
 
     def cleanup_graph(self) -> None:
         """

@@ -359,6 +359,13 @@ class GraphArg:
     # stash a strong reference too.
     example_strong_ref: Optional[torch.Tensor] = None
 
+    def __setattr__(self, name, value):
+        # Use object.__setattr__ to bypass Dynamo's STORE_ATTR interception.
+        # This is needed because when PYTORCH_TEST_WITH_DYNAMO=1, even internal
+        # GraphArg creation can be traced, and with replay_side_effects=False,
+        # normal STORE_ATTR bytecode only records mutations without applying them.
+        object.__setattr__(self, name, value)
+
     @property
     def example(self):
         if isinstance(self._example, TensorWeakRef):

@@ -10,7 +10,6 @@ import torch
 from ._utils import _device_t, _get_device_index
 from .memory import (
     empty_cache,
-    get_memory_info,
     max_memory_allocated,
     max_memory_reserved,
     memory_allocated,

@@ -26,10 +25,9 @@ __all__ = [
     "current_device_idx",  # deprecated
     "current_device_index",
     "current_stream",
+    "empty_cache",
     "device_count",
     "device_index",
-    "empty_cache",
-    "get_memory_info",
     "is_available",
     "max_memory_allocated",
     "max_memory_reserved",

@@ -8,7 +8,6 @@ from ._utils import _device_t, _get_device_index
 
 __all__ = [
     "empty_cache",
-    "get_memory_info",
     "max_memory_allocated",
     "max_memory_reserved",
     "memory_allocated",

@@ -88,9 +87,6 @@ def memory_stats(device_index: _device_t = None, /) -> OrderedDict[str, Any]:
             If not given, use :func:`torch.accelerator.current_device_index` by default.
             If a :class:`torch.device` or str is provided, its type must match the current
             :ref:`accelerator<accelerators>` device type.
-
-    Returns:
-        OrderedDict[str, Any]: an ordered dictionary mapping statistic names to their values.
     """
     if not torch._C._accelerator_isAllocatorInitialized():
         return OrderedDict()

@@ -121,9 +117,6 @@ def memory_allocated(device_index: _device_t = None, /) -> int:
             If not given, use :func:`torch.accelerator.current_device_index` by default.
             If a :class:`torch.device` or str is provided, its type must match the current
             :ref:`accelerator<accelerators>` device type.
-
-    Returns:
-        int: the current memory occupied by live tensors (in bytes) within the current process.
     """
     return memory_stats(device_index).get("allocated_bytes.all.current", 0)
 

@@ -141,9 +134,6 @@ def max_memory_allocated(device_index: _device_t = None, /) -> int:
             If not given, use :func:`torch.accelerator.current_device_index` by default.
             If a :class:`torch.device` or str is provided, its type must match the current
             :ref:`accelerator<accelerators>` device type.
-
-    Returns:
-        int: the peak memory occupied by live tensors (in bytes) within the current process.
     """
     return memory_stats(device_index).get("allocated_bytes.all.peak", 0)
 

@@ -157,9 +147,6 @@ def memory_reserved(device_index: _device_t = None, /) -> int:
             If not given, use :func:`torch.accelerator.current_device_index` by default.
             If a :class:`torch.device` or str is provided, its type must match the current
             :ref:`accelerator<accelerators>` device type.
-
-    Returns:
-        int: the current memory reserved by PyTorch (in bytes) within the current process.
     """
     return memory_stats(device_index).get("reserved_bytes.all.current", 0)
 

@@ -177,9 +164,6 @@ def max_memory_reserved(device_index: _device_t = None, /) -> int:
             If not given, use :func:`torch.accelerator.current_device_index` by default.
             If a :class:`torch.device` or str is provided, its type must match the current
             :ref:`accelerator<accelerators>` device type.
-
-    Returns:
-        int: the peak memory reserved by PyTorch (in bytes) within the current process.
     """
     return memory_stats(device_index).get("reserved_bytes.all.peak", 0)
 

@@ -216,21 +200,3 @@ def reset_peak_memory_stats(device_index: _device_t = None, /) -> None:
     """
     device_index = _get_device_index(device_index, optional=True)
     return torch._C._accelerator_resetPeakStats(device_index)
-
-
-def get_memory_info(device_index: _device_t = None, /) -> tuple[int, int]:
-    r"""Return the current device memory information for a given device index.
-
-    Args:
-        device_index (:class:`torch.device`, str, int, optional): the index of the device to target.
-            If not given, use :func:`torch.accelerator.current_device_index` by default.
-            If a :class:`torch.device` or str is provided, its type must match the current
-            :ref:`accelerator<accelerators>` device type.
-
-    Returns:
-        tuple[int, int]: a tuple of two integers (free_memory, total_memory) in bytes.
-            The first value is the free memory on the device (available across all processes and applications),
-            The second value is the device's total hardware memory capacity.
-    """
-    device_index = _get_device_index(device_index, optional=True)
-    return torch._C._accelerator_getMemoryInfo(device_index)
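With `get_memory_info` gone from this module, the remaining device-agnostic queries are the statistics-based ones. A short sketch of what still works after this change, assuming an available accelerator backend:

```python
import torch

if torch.accelerator.is_available():
    idx = torch.accelerator.current_device_index()
    print(torch.accelerator.memory_allocated(idx))     # live tensor bytes
    print(torch.accelerator.max_memory_reserved(idx))  # peak reserved bytes
    torch.accelerator.reset_peak_memory_stats(idx)
```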
@@ -138,13 +138,6 @@ void initModule(PyObject* module) {
     at::accelerator::resetPeakStats(device_index);
   });
 
-  m.def("_accelerator_getMemoryInfo", [](c10::DeviceIndex device_index) {
-    const auto device_type = at::accelerator::getAccelerator(true).value();
-    torch::utils::maybe_initialize_device(device_type);
-    py::gil_scoped_release no_gil;
-    return at::accelerator::getMemoryInfo(device_index);
-  });
-
   m.def("_accelerator_setAllocatorSettings", [](std::string env) {
     c10::CachingAllocator::setAllocatorSettings(env);
   });

@@ -386,8 +386,23 @@ static void bindGetDeviceProperties(PyObject* module) {
 static void initXpuMethodBindings(PyObject* module) {
   auto m = py::handle(module).cast<py::module>();
   m.def("_xpu_getMemoryInfo", [](c10::DeviceIndex device_index) {
-    py::gil_scoped_release no_gil;
-    return at::getDeviceAllocator(at::kXPU)->getMemoryInfo(device_index);
+#if SYCL_COMPILER_VERSION >= 20250000
+    auto total = at::xpu::getDeviceProperties(device_index)->global_mem_size;
+    auto& device = c10::xpu::get_raw_device(device_index);
+    TORCH_CHECK(
+        device.has(sycl::aspect::ext_intel_free_memory),
+        "The device (",
+        at::xpu::getDeviceProperties(device_index)->name,
+        ") doesn't support querying the available free memory. ",
+        "You can file an issue at https://github.com/pytorch/pytorch/issues ",
+        "to help us prioritize its implementation.");
+    auto free = device.get_info<sycl::ext::intel::info::device::free_memory>();
+    return std::make_tuple(free, total);
+#else
+    TORCH_CHECK_NOT_IMPLEMENTED(
+        false,
+        "torch.xpu.mem_get_info requires PyTorch to be built with SYCL compiler version 2025.0.0 or newer.");
+#endif
   });
   m.def(
       "_xpu_getStreamFromExternal",

@@ -140,6 +140,8 @@ class ExportDynamoConfig:
     capture_dynamic_output_shape_ops: bool = True
     capture_scalar_outputs: bool = True
     prefer_deferred_runtime_asserts_over_guards: bool = False
+    replay_side_effects: bool = False
+    side_effect_replay_policy: str = "warn"
 
 
 @dataclasses.dataclass

@@ -190,7 +190,6 @@ def mem_get_info(device: _device_t = None) -> tuple[int, int]:
         int: the memory available on the device in units of bytes.
         int: the total memory on the device in units of bytes
     """
-    _lazy_init()
     device = _get_device_index(device, optional=True)
     return torch._C._xpu_getMemoryInfo(device)
 
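Usage of the Python entry point is unchanged; only the lazy-init call is dropped, and the binding above now performs the free-memory query itself. A hedged sketch, assuming an XPU build with a SYCL 2025.0+ toolchain (per the `TORCH_CHECK_NOT_IMPLEMENTED` branch):

```python
import torch

if torch.xpu.is_available():
    free, total = torch.xpu.mem_get_info()  # defaults to the current device
    print(f"{free / 2**30:.2f} GiB free of {total / 2**30:.2f} GiB")
```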