Update on "[inductor] fx pass to split matmuls"
cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben [ghstack-poisoned]
@@ -94,11 +94,6 @@ TORCH_API inline void resetPeakStats(c10::DeviceIndex device_index) {
  at::getDeviceAllocator(device_type)->resetPeakStats(device_index);
}

TORCH_API inline std::pair<size_t, size_t> getMemoryInfo(
    c10::DeviceIndex device_index) {
  const auto device_type = getAccelerator(true).value();
  return at::getDeviceAllocator(device_type)->getMemoryInfo(device_index);
}

} // namespace at::accelerator

namespace at {
@@ -4389,7 +4389,7 @@
  variants: function, method
  dispatch:
    CompositeExplicitAutograd: mv
    SparseCPU, SparseCUDA, SparseMPS: mv_sparse
    SparseCPU, SparseCUDA: mv_sparse

- func: mv.out(Tensor self, Tensor vec, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
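The hunk above only adjusts which backend keys route `mv` to `mv_sparse` (adding SparseMPS). As a rough sanity check of what that kernel covers (not part of this diff; CPU shown for portability), `torch.mv` already works on a sparse COO matrix with a dense vector:

import torch

# Sparse matrix-vector product routed through mv_sparse; the SparseMPS
# entry in the hunk concerns the same operation on the MPS backend key.
indices = torch.tensor([[0, 1, 1], [2, 0, 2]])
values = torch.tensor([3.0, 4.0, 5.0])
mat = torch.sparse_coo_tensor(indices, values, (2, 3))
vec = torch.arange(3, dtype=torch.float32)

print(torch.mv(mat, vec))             # sparse path
print(torch.mv(mat.to_dense(), vec))  # dense reference, same result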
@@ -1,191 +1,3 @@
#pragma once
#include <ATen/xpu/XPUContext.h>

#include <optional>

namespace at::xpu {

/*
 * XPUEvent is a movable but not copyable wrapper around a SYCL event. It is
 * constructed lazily when first recorded. It has a device, and this device is
 * acquired from the first recording stream. Later streams that record the
 * event must be on the same device.
 *
 * Currently, XPUEvent does NOT support exporting an event to another process
 * via inter-process communication (IPC), so event handles cannot be shared
 * between processes. This could impact applications that rely on
 * cross-process synchronization and communication.
 */
struct TORCH_XPU_API XPUEvent {
  // Constructors
  XPUEvent(bool enable_timing = false) noexcept
      : enable_timing_{enable_timing} {}

  ~XPUEvent() {
    if (isCreated()) {
      const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
      if (C10_UNLIKELY(interp)) {
        (*interp)->trace_gpu_event_deletion(
            at::kXPU, reinterpret_cast<uintptr_t>(event_.get()));
      }
    }
  }

  XPUEvent(const XPUEvent&) = delete;
  XPUEvent& operator=(const XPUEvent&) = delete;

  XPUEvent(XPUEvent&& other) = default;
  XPUEvent& operator=(XPUEvent&& other) = default;

  operator sycl::event&() const {
    return event();
  }

  std::optional<at::Device> device() const {
    if (isCreated()) {
      return at::Device(at::kXPU, device_index_);
    } else {
      return std::nullopt;
    }
  }

  inline bool isCreated() const {
    return (event_.get() != nullptr);
  }

  DeviceIndex device_index() const {
    return device_index_;
  }

  sycl::event& event() const {
    return *event_;
  }

  bool query() const {
    using namespace sycl::info;
    if (!isCreated()) {
      return true;
    }

    return event().get_info<event::command_execution_status>() ==
        event_command_status::complete;
  }

  void record() {
    record(getCurrentXPUStream());
  }

  void recordOnce(const XPUStream& stream) {
    if (!isCreated()) {
      record(stream);
    }
  }

  void record(const XPUStream& stream) {
    if (!isCreated()) {
      device_index_ = stream.device_index();
      assignEvent(stream.queue());
      const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
      if (C10_UNLIKELY(interp)) {
        (*interp)->trace_gpu_event_creation(
            at::kXPU, reinterpret_cast<uintptr_t>(event_.get()));
      }
    } else {
      TORCH_CHECK(
          device_index_ == stream.device_index(),
          "Event device ",
          device_index_,
          " does not match recording stream's device ",
          stream.device_index(),
          ".");
      reassignEvent(stream.queue());
    }
    const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
    if (C10_UNLIKELY(interp)) {
      (*interp)->trace_gpu_event_record(
          at::kXPU,
          reinterpret_cast<uintptr_t>(event_.get()),
          reinterpret_cast<uintptr_t>(&stream.queue()));
    }
  }

  void block(const XPUStream& stream) {
    if (isCreated()) {
      std::vector<sycl::event> event_list{event()};
      // Make this stream wait until event_ is completed.
      stream.queue().ext_oneapi_submit_barrier(event_list);
      const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
      if (C10_UNLIKELY(interp)) {
        (*interp)->trace_gpu_event_wait(
            at::kXPU,
            reinterpret_cast<uintptr_t>(event_.get()),
            reinterpret_cast<uintptr_t>(&stream.queue()));
      }
    }
  }

  double elapsed_time(const XPUEvent& other) const {
    TORCH_CHECK(
        isCreated() && other.isCreated(),
        "Both events must be recorded before calculating elapsed time.");
    TORCH_CHECK(
        query() && other.query(),
        "Both events must be completed before calculating elapsed time.");
    TORCH_CHECK(
        enable_timing_ && other.enable_timing_,
        "Both events must be created with argument 'enable_timing=True'.");

#if SYCL_COMPILER_VERSION < 20250000
    TORCH_CHECK_NOT_IMPLEMENTED(
        false,
        "elapsed_time of XPUEvent requires PyTorch to be built with SYCL compiler version 2025.0.0 or newer.");
#endif

    using namespace sycl::info::event_profiling;
    // Block until both of the recorded events are completed.
    uint64_t end_time_ns = other.event().get_profiling_info<command_end>();
    uint64_t start_time_ns = event().get_profiling_info<command_end>();
    // Return the elapsed time in milliseconds.
    return 1e-6 *
        (static_cast<double>(end_time_ns) - static_cast<double>(start_time_ns));
  }

  void synchronize() const {
    if (isCreated()) {
      const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
      if (C10_UNLIKELY(interp)) {
        (*interp)->trace_gpu_event_synchronization(
            at::kXPU, reinterpret_cast<uintptr_t>(event_.get()));
      }
      event().wait_and_throw();
    }
  }

 private:
  void assignEvent(sycl::queue& queue) {
#if SYCL_COMPILER_VERSION >= 20250000
    if (enable_timing_) {
      event_ = std::make_unique<sycl::event>(
          sycl::ext::oneapi::experimental::submit_profiling_tag(queue));
    } else {
      event_ = std::make_unique<sycl::event>(queue.ext_oneapi_submit_barrier());
    }
#else
    event_ = std::make_unique<sycl::event>(queue.ext_oneapi_submit_barrier());
#endif
  }

  void reassignEvent(sycl::queue& queue) {
    event_.reset();
    assignEvent(queue);
  }

  bool enable_timing_ = false;
  DeviceIndex device_index_ = -1;
  // Only need to track the last event, as events in an in-order queue are
  // executed sequentially.
  std::unique_ptr<sycl::event> event_;
};

} // namespace at::xpu
#include <c10/xpu/XPUEvent.h>
@@ -10,7 +10,7 @@ beit_base_patch16_224,pass,7

convnextv2_nano.fcmae_ft_in22k_in1k,pass,7
convnextv2_nano.fcmae_ft_in22k_in1k,fail_accuracy,7

@@ -66,7 +66,7 @@ visformer_small,pass,7

vit_base_patch14_dinov2.lvd142m,pass,7
vit_base_patch14_dinov2.lvd142m,fail_accuracy,7
@@ -96,10 +96,6 @@ struct C10_API DeviceAllocator : public c10::Allocator {

  // Resets peak memory usage statistics for the specified device
  virtual void resetPeakStats(c10::DeviceIndex device) = 0;

  // Return the free memory size and total memory size in bytes for the
  // specified device.
  virtual std::pair<size_t, size_t> getMemoryInfo(c10::DeviceIndex device) = 0;
};

// This function is used to get the DeviceAllocator for a specific device type
@@ -345,13 +345,6 @@ class CUDAAllocator : public DeviceAllocator {
      c10::DeviceIndex device,
      std::shared_ptr<AllocatorState> pps) = 0;
  virtual std::string name() = 0;
  std::pair<size_t, size_t> getMemoryInfo(c10::DeviceIndex device) override {
    c10::DeviceGuard device_guard({at::kCUDA, device});
    size_t free = 0;
    size_t total = 0;
    C10_CUDA_CHECK(cudaMemGetInfo(&free, &total));
    return {free, total};
  }
};

// Allocator object, statically initialized
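From Python, the same free/total pair this override obtains via cudaMemGetInfo is reachable through torch.cuda.mem_get_info; a minimal sketch, not part of the diff, assuming a CUDA build:

import torch

if torch.cuda.is_available():
    # (free, total) in bytes for the current CUDA device, backed by
    # cudaMemGetInfo just like the allocator override above.
    free, total = torch.cuda.mem_get_info()
    print(f"free {free / 2**30:.2f} GiB of {total / 2**30:.2f} GiB")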
@@ -24,6 +24,7 @@ set(C10_XPU_HEADERS
    XPUCachingAllocator.h
    XPUDeviceProp.h
    XPUException.h
    XPUEvent.h
    XPUFunctions.h
    XPUMacros.h
    XPUStream.h
@@ -926,14 +926,15 @@ class DeviceCachingAllocator {
          (release_cached_blocks() && alloc_block(params, true));
    }
    if (!block_found) {
      const auto& raw_device = c10::xpu::get_raw_device(device);
      const auto device_total =
          raw_device.get_info<sycl::info::device::global_mem_size>();
      c10::xpu::DeviceProp device_prop;
      c10::xpu::get_device_properties(&device_prop, device);
      auto device_total = device_prop.global_mem_size;
      // Estimate the available device memory when the SYCL runtime does not
      // support the corresponding aspect (ext_intel_free_memory).
      size_t device_free = device_total -
      size_t device_free = device_prop.global_mem_size -
          stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)]
              .current;
      auto& raw_device = c10::xpu::get_raw_device(device);
      // TODO: Remove the aspect check once the SYCL runtime bug is fixed on
      // affected devices.
      if (raw_device.has(sycl::aspect::ext_intel_free_memory)) {
@@ -1051,37 +1052,21 @@ class DeviceCachingAllocator {
    }
  }

  std::pair<size_t, size_t> getMemoryInfo() {
    const auto& device = c10::xpu::get_raw_device(device_index);
    const size_t total = device.get_info<sycl::info::device::global_mem_size>();
    TORCH_CHECK(
        device.has(sycl::aspect::ext_intel_free_memory),
        "The device (",
        device.get_info<sycl::info::device::name>(),
        ") doesn't support querying the available free memory. ",
        "You can file an issue at https://github.com/pytorch/pytorch/issues ",
        "to help us prioritize its implementation.");
    const size_t free =
        device.get_info<sycl::ext::intel::info::device::free_memory>();
    return {free, total};
  }

  double getMemoryFraction() {
    if (!set_fraction) {
      return 1.0;
    }

    const auto device_total =
        xpu::get_raw_device(device_index)
            .get_info<sycl::info::device::global_mem_size>();
    c10::xpu::DeviceProp device_prop;
    c10::xpu::get_device_properties(&device_prop, device_index);
    return static_cast<double>(allowed_memory_maximum) /
        static_cast<double>(device_total);
        static_cast<double>(device_prop.global_mem_size);
  }

  void setMemoryFraction(double fraction) {
    const auto device_total =
        xpu::get_raw_device(device_index)
            .get_info<sycl::info::device::global_mem_size>();
    c10::xpu::DeviceProp device_prop;
    c10::xpu::get_device_properties(&device_prop, device_index);
    auto device_total = device_prop.global_mem_size;
    allowed_memory_maximum = static_cast<size_t>(fraction * device_total);
    set_fraction = true;
  }
@@ -1255,11 +1240,6 @@ class XPUAllocator : public DeviceAllocator {
        c10::xpu::get_raw_device(dev_to_access));
  }

  std::pair<size_t, size_t> getMemoryInfo(DeviceIndex device) override {
    assertValidDevice(device);
    return device_allocators[device]->getMemoryInfo();
  }

  double getMemoryFraction(DeviceIndex device) {
    assertValidDevice(device);
    return device_allocators[device]->getMemoryFraction();
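The fraction bookkeeping in setMemoryFraction/getMemoryFraction above is simple arithmetic over the device's total memory; a small standalone sketch of that relationship (the 16 GiB figure is made up, not from the diff):

# setMemoryFraction stores allowed_memory_maximum = fraction * device_total;
# getMemoryFraction recovers the fraction by dividing by the same total.
device_total = 16 * 2**30   # hypothetical 16 GiB of global memory
fraction = 0.5

allowed_memory_maximum = int(fraction * device_total)
recovered_fraction = allowed_memory_maximum / device_total
assert abs(recovered_fraction - fraction) < 1e-9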
c10/xpu/XPUEvent.h (new file, 178 lines)
@@ -0,0 +1,178 @@
#pragma once
#include <c10/xpu/XPUStream.h>

namespace c10::xpu {

/*
 * XPUEvent is a movable but not copyable wrapper around a SYCL event. It is
 * constructed lazily when first recorded. It has a device, and this device is
 * acquired from the first recording stream. Later streams that record the
 * event must be on the same device.
 *
 * Currently, XPUEvent does NOT support exporting an event to another process
 * via inter-process communication (IPC), so event handles cannot be shared
 * between processes. This could impact applications that rely on
 * cross-process synchronization and communication.
 */
struct XPUEvent {
  // Constructors
  XPUEvent(bool enable_timing = false) noexcept
      : enable_timing_{enable_timing} {}

  ~XPUEvent() {
    if (isCreated()) {
      const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
      if (C10_UNLIKELY(interp)) {
        (*interp)->trace_gpu_event_deletion(
            c10::kXPU, reinterpret_cast<uintptr_t>(event_.get()));
      }
    }
  }

  C10_DISABLE_COPY_AND_ASSIGN(XPUEvent);

  XPUEvent(XPUEvent&& other) = default;
  XPUEvent& operator=(XPUEvent&& other) = default;

  operator sycl::event&() const {
    return event();
  }

  std::optional<c10::Device> device() const {
    if (isCreated()) {
      return c10::Device(c10::kXPU, device_index_);
    } else {
      return std::nullopt;
    }
  }

  inline bool isCreated() const {
    return (event_.get() != nullptr);
  }

  DeviceIndex device_index() const {
    return device_index_;
  }

  sycl::event& event() const {
    return *event_;
  }

  bool query() const {
    using namespace sycl::info;
    if (!isCreated()) {
      return true;
    }

    return event().get_info<event::command_execution_status>() ==
        event_command_status::complete;
  }

  void record() {
    record(getCurrentXPUStream());
  }

  void recordOnce(const XPUStream& stream) {
    if (!isCreated()) {
      record(stream);
    }
  }

  void record(const XPUStream& stream) {
    if (!isCreated()) {
      device_index_ = stream.device_index();
      assignEvent(stream.queue());
      const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
      if (C10_UNLIKELY(interp)) {
        (*interp)->trace_gpu_event_creation(
            c10::kXPU, reinterpret_cast<uintptr_t>(event_.get()));
      }
    } else {
      TORCH_CHECK(
          device_index_ == stream.device_index(),
          "Event device ",
          device_index_,
          " does not match recording stream's device ",
          stream.device_index(),
          ".");
      reassignEvent(stream.queue());
    }
    const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
    if (C10_UNLIKELY(interp)) {
      (*interp)->trace_gpu_event_record(
          c10::kXPU,
          reinterpret_cast<uintptr_t>(event_.get()),
          reinterpret_cast<uintptr_t>(&stream.queue()));
    }
  }

  void block(const XPUStream& stream) {
    if (isCreated()) {
      std::vector<sycl::event> event_list{event()};
      // Make this stream wait until event_ is completed.
      stream.queue().ext_oneapi_submit_barrier(event_list);
      const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
      if (C10_UNLIKELY(interp)) {
        (*interp)->trace_gpu_event_wait(
            c10::kXPU,
            reinterpret_cast<uintptr_t>(event_.get()),
            reinterpret_cast<uintptr_t>(&stream.queue()));
      }
    }
  }

  double elapsed_time(const XPUEvent& other) const {
    TORCH_CHECK(
        isCreated() && other.isCreated(),
        "Both events must be recorded before calculating elapsed time.");
    TORCH_CHECK(
        query() && other.query(),
        "Both events must be completed before calculating elapsed time.");
    TORCH_CHECK(
        enable_timing_ && other.enable_timing_,
        "Both events must be created with argument 'enable_timing=True'.");

    using namespace sycl::info::event_profiling;
    // Block until both of the recorded events are completed.
    uint64_t end_time_ns = other.event().get_profiling_info<command_end>();
    uint64_t start_time_ns = event().get_profiling_info<command_end>();
    // Return the elapsed time in milliseconds.
    return 1e-6 *
        (static_cast<double>(end_time_ns) - static_cast<double>(start_time_ns));
  }

  void synchronize() const {
    if (isCreated()) {
      const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
      if (C10_UNLIKELY(interp)) {
        (*interp)->trace_gpu_event_synchronization(
            c10::kXPU, reinterpret_cast<uintptr_t>(event_.get()));
      }
      event().wait_and_throw();
    }
  }

 private:
  void assignEvent(sycl::queue& queue) {
    if (enable_timing_) {
      event_ = std::make_unique<sycl::event>(
          sycl::ext::oneapi::experimental::submit_profiling_tag(queue));
    } else {
      event_ = std::make_unique<sycl::event>(queue.ext_oneapi_submit_barrier());
    }
  }

  void reassignEvent(sycl::queue& queue) {
    event_.reset();
    assignEvent(queue);
  }

  bool enable_timing_ = false;
  c10::DeviceIndex device_index_ = -1;
  // Only need to track the last event, as events in an in-order queue are
  // executed sequentially.
  std::unique_ptr<sycl::event> event_;
};

} // namespace c10::xpu
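At the Python level this event type surfaces roughly as torch.xpu.Event (the bindings are not part of this diff, so the exact mapping is an assumption). A hedged usage sketch of the record/synchronize/elapsed_time flow described by the comments above, assuming an XPU build with timing support:

import torch

if torch.xpu.is_available():
    # enable_timing must be True on both events for elapsed_time() to be
    # meaningful, mirroring the enable_timing_ check in the C++ class.
    start = torch.xpu.Event(enable_timing=True)
    end = torch.xpu.Event(enable_timing=True)

    a = torch.randn(2048, 2048, device="xpu")
    start.record()
    b = a @ a
    end.record()
    end.synchronize()  # wait_and_throw() under the hood
    print(f"matmul took {start.elapsed_time(end):.3f} ms")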
@@ -40,7 +40,6 @@
    :nosignatures:

    empty_cache
    get_memory_info
    max_memory_allocated
    max_memory_reserved
    memory_allocated
@@ -570,14 +570,22 @@ class TestOverlapPreservingBucketing(InductorTestCase):

    def test_split_mm(self):
        def func(a, b):
            a = a * 2
            b = b * 3
            mm = torch.mm(a, b)
            mm = mm * 2
            return mm

        # Trace with make_fx
        ref_a = torch.randn(16, 8)
        ref_b = torch.randn(8, 4)
        gm = make_fx(func)(ref_a, ref_b)
        ref_out = func(ref_a, ref_b)
        def _inps():
            return torch.randn(16, 8, device=self.device), torch.randn(
                8, 4, device=self.device
            )

        inps = _inps()
        ref_out = func(*inps)

        gm = make_fx(func, tracing_mode="fake")(*inps)

        from torch._inductor.fx_passes.decompose_mm import split_mms

        split_mms(gm, 16, 4)
@@ -587,15 +595,25 @@ class TestOverlapPreservingBucketing(InductorTestCase):
            4,
            exactly=True,
        ).run(graph_str)
        out = gm(ref_a, ref_b)
        out = gm(*inps)

        self.assertTrue(same(out, ref_out))

    def test_split_mm_noncont(self):
        # Non contiguous matmuls are not split
        def func(a, b):
            return torch.mm(a, b)

        def _inps():
            return torch.empty_strided((16, 8), (1, 8)), torch.randn(8, 4)

        inps = _inps()

        gm = make_fx(func, tracing_mode="fake")(*inps)
        from torch._inductor.fx_passes.decompose_mm import split_mms

        ref_a2 = torch.empty_strided((16, 8), (1, 8))
        ref_b2 = torch.randn(8, 4)
        gm = make_fx(func)(ref_a2, ref_b2)
        split_mms(gm, 16, 4)
        graph_str = str(gm.graph)
        FileCheck().check_count(
            "torch.ops.aten.mm",
            1,
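For reference, the algebraic identity the new pass relies on is just block-wise matrix multiplication. The sketch below (not part of the diff) reproduces torch.mm by chunking the left operand along its row dimension; the actual splitting strategy and the meaning of the (16, 4) arguments to split_mms are not spelled out in this excerpt:

import torch

def chunked_mm(a, b, num_chunks):
    # Split the matmul along the M (row) dimension of `a` and concatenate
    # the partial products; numerically equivalent to torch.mm(a, b).
    parts = [torch.mm(a_chunk, b) for a_chunk in a.chunk(num_chunks, dim=0)]
    return torch.cat(parts, dim=0)

a, b = torch.randn(16, 8), torch.randn(8, 4)
torch.testing.assert_close(chunked_mm(a, b, 4), torch.mm(a, b))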
@@ -19,6 +19,7 @@ import operator
import os
import pickle
import random
import re
import sys
import tempfile
import threading
@@ -5635,6 +5636,115 @@ utils_device.CURRENT_DEVICE == None""".split("\n"):
        self.assertTrue(same(res11, res12))
        self.assertTrue(same(res21, res22))

    def test_replay_side_effects_config(self):
        # Test that replay_side_effects config controls mutation replay
        def fn(x, lst):
            lst.append(x + 1)
            return x * 2

        x = torch.tensor([5.0])

        # Test with replay enabled (default)
        lst_with_replay = []
        opt_fn_with_replay = torch.compile(fn, backend="eager")
        result1 = opt_fn_with_replay(x, lst_with_replay)
        self.assertEqual(len(lst_with_replay), 1)  # Mutation should be replayed
        self.assertTrue(same(result1, x * 2))

        torch._dynamo.reset()

        # Test with replay disabled
        lst_without_replay = []
        with torch._dynamo.config.patch(
            replay_side_effects=False, side_effect_replay_policy="warn"
        ):
            opt_fn_without_replay = torch.compile(fn, backend="eager")
            result2 = opt_fn_without_replay(x, lst_without_replay)
            self.assertEqual(
                len(lst_without_replay), 0
            )  # Mutation should NOT be replayed
            self.assertTrue(same(result2, x * 2))

        torch._dynamo.reset()
        lst_without_replay = []
        with torch._dynamo.config.patch(
            replay_side_effects=False, side_effect_replay_policy="error"
        ):
            opt_fn_without_replay = torch.compile(fn, backend="eager")
            with self.assertRaisesRegex(
                RuntimeError,
                re.escape(
                    "While compiling, we found certain side effects happened in the model.forward. Here are the list of potential sources you can double check: [\"L['lst']\"]"
                ),
            ):
                _ = opt_fn_without_replay(x, lst_without_replay)

    def test_replay_side_effects_model_attr(self):
        class Bar(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.const = 4

            def forward(self, x):
                return x.cos()

        class Foo(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.const = 4
                self.tensor = None
                self.bar = Bar()

            def forward(self, x):
                self.const = 5
                self.tensor = x.sin()
                res = self.bar(x)
                return x.cos() + res.sum() + self.tensor

        with torch._dynamo.config.patch(
            replay_side_effects=False, side_effect_replay_policy="error"
        ):
            foo = Foo()
            with self.assertRaisesRegex(
                RuntimeError,
                re.escape(
                    "While compiling, we found certain side effects happened in the model.forward. Here are the list of potential sources you can double check: [\"L['self']\"]"
                ),
            ):
                torch.compile(foo, fullgraph=True)(torch.randn(4, 4))

        with torch._dynamo.config.patch(
            replay_side_effects=False, side_effect_replay_policy="silent"
        ):
            foo_v2_compile = Foo()
            foo_v2_eager = Foo()
            inp = torch.randn(4, 4)
            res = torch.compile(foo_v2_compile, fullgraph=True)(torch.randn(4, 4))
            self.assertEqual(foo_v2_compile.tensor, None)
            self.assertEqual(foo_v2_compile.const, 4)
            self.assertEqual(foo_v2_compile.bar.const, 4)
            same(res, foo_v2_eager(inp))

    def test_replay_side_effects_input_mut(self):
        class Foo(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.const = 4
                self.tensor = None

            def forward(self, x):
                x.add_(5)
                return x.cos()

        # This is ok because we actually capture the graph which
        # has mutation. In export, we never retrace the actual
        # gm so we won't see any mutation applied to inputs
        with torch._dynamo.config.patch(
            replay_side_effects=False, side_effect_replay_policy="error"
        ):
            foo = Foo()
            torch.compile(foo, fullgraph=True)(torch.randn(4, 4))

    def test_list_append_return_none(self):
        def fn(x):
            alist = []
@@ -349,6 +349,18 @@ def forward(self, x):
        res2 = p.generate(input_tensor=inp, input_tensor2=inp2)
        self.assertTrue(torch.allclose(res, res2))

    def test_side_effect(self):
        global_env = []

        class Foo(torch.nn.Module):
            def forward(self, x):
                global_env.append(x)
                return x.sin()

        with torch._dynamo.config.patch(replay_side_effects=False):
            _ = dynamo_graph_capture_for_export(Foo())(torch.randn(4, 4))
        self.assertEqual(len(global_env), 0)

    def test_export_add_in_out_info(self):
        class Foo(torch.nn.Module):
            def forward(self, dct, lst, bleh):
@@ -239,12 +239,6 @@ class TestAccelerator(TestCase):
        self.assertEqual(torch.accelerator.max_memory_allocated(), prev_max_allocated)
        self.assertEqual(torch.accelerator.max_memory_reserved(), prev_max_reserved)

    @unittest.skipIf(TEST_MPS, "MPS doesn't support torch.accelerator memory API!")
    def test_get_memory_info(self):
        free_bytes, total_bytes = torch.accelerator.get_memory_info()
        self.assertGreaterEqual(free_bytes, 0)
        self.assertGreaterEqual(total_bytes, 0)


if __name__ == "__main__":
    run_tests()
@@ -2674,6 +2674,7 @@ class TestSparse(TestSparseBase):
        self._test_asin_arcsin(input_uncoalesced, coalesced)

    @coalescedonoff
    @expectedFailureMPS
    @dtypes(torch.double)
    @dtypesIfMPS(torch.float32)
    def test_mv(self, device, dtype, coalesced):
@@ -2501,7 +2501,6 @@ def _accelerator_emptyCache() -> None: ...
def _accelerator_getDeviceStats(device_index: _int) -> dict[str, Any]: ...
def _accelerator_resetAccumulatedStats(device_index: _int) -> None: ...
def _accelerator_resetPeakStats(device_index: _int) -> None: ...
def _accelerator_getMemoryInfo(device_index: _int) -> tuple[_int, _int]: ...
def _accelerator_setAllocatorSettings(env: str) -> None: ...

# Defined in torch/csrc/jit/python/python_tracer.cpp
@@ -44,6 +44,19 @@ minimum_call_count = 1
# turn on/off DCE pass (deprecated: always true)
dead_code_elimination = True

# Enable or disable side effect replay after graph execution.
# When False, mutations to Python objects (lists, dicts, attributes) won't be
# replayed after the compiled graph runs. This can cause correctness issues
# if your code depends on these mutations being visible. This should probably
# never be False by default; at the moment, only export needs it.
replay_side_effects = True

# Configure the side effect warning level:
# If `silent`, we silently allow side effects
# If `warn`, we warn on side effects
# If `error`, we error on side effects
side_effect_replay_policy = "silent"

# disable (for a function) when cache reaches this size

# controls the maximum number of cache entries with a guard on same ID_MATCH'd
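A minimal sketch of how these two knobs interact, following the test added earlier in this commit (eager backend, list mutation as the side effect):

import torch

def fn(x, log):
    log.append(x + 1)   # Python-side mutation that Dynamo tracks as a side effect
    return x * 2

log = []
# With replay disabled and the policy set to "warn", the compiled function
# still returns the right value, but the list mutation is not re-applied.
with torch._dynamo.config.patch(
    replay_side_effects=False, side_effect_replay_policy="warn"
):
    out = torch.compile(fn, backend="eager")(torch.tensor([5.0]), log)

assert len(log) == 0   # mutation was recorded but not replayed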
@@ -1845,7 +1845,7 @@ class OutputGraph(OutputGraphCommon):
                [create_instruction("DELETE_FAST", argval=graph_output_var)]
            )

        if self.export:
            if torch._dynamo.config.side_effect_replay_policy in ["warn", "error"]:
                from torch.export._trace import _ExportModuleSpecTrackerDict

                potential_side_effects = []
@@ -1881,10 +1881,16 @@ class OutputGraph(OutputGraphCommon):
                ]

                if side_effect_refs:
                    warnings.warn(
                        f"While exporting, we found certain side effects happened in the model.forward. "
                        f"Here are the list of potential sources you can double check: {side_effect_refs}"
                    )
                    if torch._dynamo.config.side_effect_replay_policy == "warn":
                        warnings.warn(
                            f"While compiling, we found certain side effects happened in the model.forward. "
                            f"Here are the list of potential sources you can double check: {side_effect_refs}"
                        )
                    else:
                        raise RuntimeError(
                            f"While compiling, we found certain side effects happened in the model.forward. "
                            f"Here are the list of potential sources you can double check: {side_effect_refs}"
                        )

        return all_stack_locals_metas

@@ -1930,7 +1936,8 @@ class OutputGraph(OutputGraphCommon):
            assert self.backward_state_var is not None
            cg.append_output(cg.create_load(self.backward_state_var))
            cg.store_attr(name)
        self.side_effects.codegen_hooks(cg)
        if config.replay_side_effects:
            self.side_effects.codegen_hooks(cg)

        # TODO get debug_locals working for nested graph breaks
        # Return variables used for logging at the end
@@ -1945,7 +1952,8 @@ class OutputGraph(OutputGraphCommon):
        self.codegen_cells(tx, cg)

        cg.restore_stack(stack_values, value_from_source=not tx.export)
        self.side_effects.codegen_update_mutated(cg)
        if config.replay_side_effects:
            self.side_effects.codegen_update_mutated(cg)

    def cleanup_graph(self) -> None:
        """
@@ -359,6 +359,13 @@ class GraphArg:
    # stash a strong reference too.
    example_strong_ref: Optional[torch.Tensor] = None

    def __setattr__(self, name, value):
        # Use object.__setattr__ to bypass Dynamo's STORE_ATTR interception.
        # This is needed because when PYTORCH_TEST_WITH_DYNAMO=1, even internal
        # GraphArg creation can be traced, and with replay_side_effects=False,
        # normal STORE_ATTR bytecode only records mutations without applying them.
        object.__setattr__(self, name, value)

    @property
    def example(self):
        if isinstance(self._example, TensorWeakRef):
@@ -10,7 +10,6 @@ import torch
from ._utils import _device_t, _get_device_index
from .memory import (
    empty_cache,
    get_memory_info,
    max_memory_allocated,
    max_memory_reserved,
    memory_allocated,
@@ -26,10 +25,9 @@ __all__ = [
    "current_device_idx",  # deprecated
    "current_device_index",
    "current_stream",
    "empty_cache",
    "device_count",
    "device_index",
    "empty_cache",
    "get_memory_info",
    "is_available",
    "max_memory_allocated",
    "max_memory_reserved",
@@ -8,7 +8,6 @@ from ._utils import _device_t, _get_device_index

__all__ = [
    "empty_cache",
    "get_memory_info",
    "max_memory_allocated",
    "max_memory_reserved",
    "memory_allocated",
@@ -88,9 +87,6 @@ def memory_stats(device_index: _device_t = None, /) -> OrderedDict[str, Any]:
            If not given, use :func:`torch.accelerator.current_device_index` by default.
            If a :class:`torch.device` or str is provided, its type must match the current
            :ref:`accelerator<accelerators>` device type.

    Returns:
        OrderedDict[str, Any]: an ordered dictionary mapping statistic names to their values.
    """
    if not torch._C._accelerator_isAllocatorInitialized():
        return OrderedDict()
@@ -121,9 +117,6 @@ def memory_allocated(device_index: _device_t = None, /) -> int:
            If not given, use :func:`torch.accelerator.current_device_index` by default.
            If a :class:`torch.device` or str is provided, its type must match the current
            :ref:`accelerator<accelerators>` device type.

    Returns:
        int: the current memory occupied by live tensors (in bytes) within the current process.
    """
    return memory_stats(device_index).get("allocated_bytes.all.current", 0)

@@ -141,9 +134,6 @@ def max_memory_allocated(device_index: _device_t = None, /) -> int:
            If not given, use :func:`torch.accelerator.current_device_index` by default.
            If a :class:`torch.device` or str is provided, its type must match the current
            :ref:`accelerator<accelerators>` device type.

    Returns:
        int: the peak memory occupied by live tensors (in bytes) within the current process.
    """
    return memory_stats(device_index).get("allocated_bytes.all.peak", 0)

@@ -157,9 +147,6 @@ def memory_reserved(device_index: _device_t = None, /) -> int:
            If not given, use :func:`torch.accelerator.current_device_index` by default.
            If a :class:`torch.device` or str is provided, its type must match the current
            :ref:`accelerator<accelerators>` device type.

    Returns:
        int: the current memory reserved by PyTorch (in bytes) within the current process.
    """
    return memory_stats(device_index).get("reserved_bytes.all.current", 0)

@@ -177,9 +164,6 @@ def max_memory_reserved(device_index: _device_t = None, /) -> int:
            If not given, use :func:`torch.accelerator.current_device_index` by default.
            If a :class:`torch.device` or str is provided, its type must match the current
            :ref:`accelerator<accelerators>` device type.

    Returns:
        int: the peak memory reserved by PyTorch (in bytes) within the current process.
    """
    return memory_stats(device_index).get("reserved_bytes.all.peak", 0)

@@ -216,21 +200,3 @@ def reset_peak_memory_stats(device_index: _device_t = None, /) -> None:
    """
    device_index = _get_device_index(device_index, optional=True)
    return torch._C._accelerator_resetPeakStats(device_index)


def get_memory_info(device_index: _device_t = None, /) -> tuple[int, int]:
    r"""Return the current device memory information for a given device index.

    Args:
        device_index (:class:`torch.device`, str, int, optional): the index of the device to target.
            If not given, use :func:`torch.accelerator.current_device_index` by default.
            If a :class:`torch.device` or str is provided, its type must match the current
            :ref:`accelerator<accelerators>` device type.

    Returns:
        tuple[int, int]: a tuple of two integers (free_memory, total_memory) in bytes.
            The first value is the free memory on the device (available across all processes
            and applications); the second value is the device's total hardware memory capacity.
    """
    device_index = _get_device_index(device_index, optional=True)
    return torch._C._accelerator_getMemoryInfo(device_index)
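A short usage sketch of the accelerator-level query defined above (which backend answers depends on the active accelerator):

import torch

if torch.accelerator.is_available():
    # (free, total) in bytes for the current accelerator device.
    free_bytes, total_bytes = torch.accelerator.get_memory_info()
    print(f"free: {free_bytes / 2**30:.2f} GiB / total: {total_bytes / 2**30:.2f} GiB")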
@@ -138,13 +138,6 @@ void initModule(PyObject* module) {
    at::accelerator::resetPeakStats(device_index);
  });

  m.def("_accelerator_getMemoryInfo", [](c10::DeviceIndex device_index) {
    const auto device_type = at::accelerator::getAccelerator(true).value();
    torch::utils::maybe_initialize_device(device_type);
    py::gil_scoped_release no_gil;
    return at::accelerator::getMemoryInfo(device_index);
  });

  m.def("_accelerator_setAllocatorSettings", [](std::string env) {
    c10::CachingAllocator::setAllocatorSettings(env);
  });
@@ -386,8 +386,23 @@ static void bindGetDeviceProperties(PyObject* module) {
static void initXpuMethodBindings(PyObject* module) {
  auto m = py::handle(module).cast<py::module>();
  m.def("_xpu_getMemoryInfo", [](c10::DeviceIndex device_index) {
    py::gil_scoped_release no_gil;
    return at::getDeviceAllocator(at::kXPU)->getMemoryInfo(device_index);
#if SYCL_COMPILER_VERSION >= 20250000
    auto total = at::xpu::getDeviceProperties(device_index)->global_mem_size;
    auto& device = c10::xpu::get_raw_device(device_index);
    TORCH_CHECK(
        device.has(sycl::aspect::ext_intel_free_memory),
        "The device (",
        at::xpu::getDeviceProperties(device_index)->name,
        ") doesn't support querying the available free memory. ",
        "You can file an issue at https://github.com/pytorch/pytorch/issues ",
        "to help us prioritize its implementation.");
    auto free = device.get_info<sycl::ext::intel::info::device::free_memory>();
    return std::make_tuple(free, total);
#else
    TORCH_CHECK_NOT_IMPLEMENTED(
        false,
        "torch.xpu.mem_get_info requires PyTorch to be built with SYCL compiler version 2025.0.0 or newer.");
#endif
  });
  m.def(
      "_xpu_getStreamFromExternal",
@@ -140,6 +140,8 @@ class ExportDynamoConfig:
    capture_dynamic_output_shape_ops: bool = True
    capture_scalar_outputs: bool = True
    prefer_deferred_runtime_asserts_over_guards: bool = False
    replay_side_effects: bool = False
    side_effect_replay_policy: str = "warn"


@dataclasses.dataclass
@@ -190,7 +190,6 @@ def mem_get_info(device: _device_t = None) -> tuple[int, int]:
        int: the memory available on the device in units of bytes.
        int: the total memory on the device in units of bytes.
    """
    _lazy_init()
    device = _get_device_index(device, optional=True)
    return torch._C._xpu_getMemoryInfo(device)
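And the XPU-specific entry point, which after this change reports runtime-level free memory rather than allocator statistics; a sketch assuming an XPU build with a SYCL compiler new enough for the free-memory query:

import torch

if torch.xpu.is_available():
    free, total = torch.xpu.mem_get_info()   # bytes, per the docstring above
    print(free, total)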