Update on "[inductor] fx pass to split matmuls"

cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
This commit is contained in:
IvanKobzarev
2025-11-13 00:59:07 -08:00
25 changed files with 399 additions and 310 deletions

View File

@ -94,11 +94,6 @@ TORCH_API inline void resetPeakStats(c10::DeviceIndex device_index) {
at::getDeviceAllocator(device_type)->resetPeakStats(device_index);
}
TORCH_API inline std::pair<size_t, size_t> getMemoryInfo(
c10::DeviceIndex device_index) {
const auto device_type = getAccelerator(true).value();
return at::getDeviceAllocator(device_type)->getMemoryInfo(device_index);
}
} // namespace at::accelerator
namespace at {

View File

@ -4389,7 +4389,7 @@
variants: function, method
dispatch:
CompositeExplicitAutograd: mv
SparseCPU, SparseCUDA, SparseMPS: mv_sparse
SparseCPU, SparseCUDA: mv_sparse
- func: mv.out(Tensor self, Tensor vec, *, Tensor(a!) out) -> Tensor(a!)
dispatch:

View File

@ -1,191 +1,3 @@
#pragma once
#include <ATen/xpu/XPUContext.h>
#include <optional>
namespace at::xpu {
/*
* XPUEvent is a movable but not copyable wrapper around a SYCL event. An
* XPUEvent is constructed lazily when first recorded. It has a device, which
* is acquired from the first recording stream; later streams that record the
* event must match that device.
*
* Currently, XPUEvent does NOT support exporting an event handle to another
* process via inter-process communication (IPC), so event handles cannot be
* shared between processes. This could impact applications that rely on
* cross-process synchronization and communication.
*/
struct TORCH_XPU_API XPUEvent {
// Constructors
XPUEvent(bool enable_timing = false) noexcept
: enable_timing_{enable_timing} {}
~XPUEvent() {
if (isCreated()) {
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_event_deletion(
at::kXPU, reinterpret_cast<uintptr_t>(event_.get()));
}
}
}
XPUEvent(const XPUEvent&) = delete;
XPUEvent& operator=(const XPUEvent&) = delete;
XPUEvent(XPUEvent&& other) = default;
XPUEvent& operator=(XPUEvent&& other) = default;
operator sycl::event&() const {
return event();
}
std::optional<at::Device> device() const {
if (isCreated()) {
return at::Device(at::kXPU, device_index_);
} else {
return std::nullopt;
}
}
inline bool isCreated() const {
return (event_.get() != nullptr);
}
DeviceIndex device_index() const {
return device_index_;
}
sycl::event& event() const {
return *event_;
}
bool query() const {
using namespace sycl::info;
if (!isCreated()) {
return true;
}
return event().get_info<event::command_execution_status>() ==
event_command_status::complete;
}
void record() {
record(getCurrentXPUStream());
}
void recordOnce(const XPUStream& stream) {
if (!isCreated()) {
record(stream);
}
}
void record(const XPUStream& stream) {
if (!isCreated()) {
device_index_ = stream.device_index();
assignEvent(stream.queue());
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_event_creation(
at::kXPU, reinterpret_cast<uintptr_t>(event_.get()));
}
} else {
TORCH_CHECK(
device_index_ == stream.device_index(),
"Event device ",
device_index_,
" does not match recording stream's device ",
stream.device_index(),
".");
reassignEvent(stream.queue());
}
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_event_record(
at::kXPU,
reinterpret_cast<uintptr_t>(event_.get()),
reinterpret_cast<uintptr_t>(&stream.queue()));
}
}
void block(const XPUStream& stream) {
if (isCreated()) {
std::vector<sycl::event> event_list{event()};
// Make this stream wait until event_ is completed.
stream.queue().ext_oneapi_submit_barrier(event_list);
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_event_wait(
at::kXPU,
reinterpret_cast<uintptr_t>(event_.get()),
reinterpret_cast<uintptr_t>(&stream.queue()));
}
}
}
double elapsed_time(const XPUEvent& other) const {
TORCH_CHECK(
isCreated() && other.isCreated(),
"Both events must be recorded before calculating elapsed time.");
TORCH_CHECK(
query() && other.query(),
"Both events must be completed before calculating elapsed time.");
TORCH_CHECK(
enable_timing_ && other.enable_timing_,
"Both events must be created with argument 'enable_timing=True'.");
#if SYCL_COMPILER_VERSION < 20250000
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"elapsed_time of XPUEvent requires PyTorch to be built with SYCL compiler version 2025.0.0 or newer.");
#endif
using namespace sycl::info::event_profiling;
// Block until both of the recorded events are completed.
uint64_t end_time_ns = other.event().get_profiling_info<command_end>();
uint64_t start_time_ns = event().get_profiling_info<command_end>();
// Return the elapsed time in milliseconds.
return 1e-6 *
(static_cast<double>(end_time_ns) - static_cast<double>(start_time_ns));
}
void synchronize() const {
if (isCreated()) {
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_event_synchronization(
at::kXPU, reinterpret_cast<uintptr_t>(event_.get()));
}
event().wait_and_throw();
}
}
private:
void assignEvent(sycl::queue& queue) {
#if SYCL_COMPILER_VERSION >= 20250000
if (enable_timing_) {
event_ = std::make_unique<sycl::event>(
sycl::ext::oneapi::experimental::submit_profiling_tag(queue));
} else {
event_ = std::make_unique<sycl::event>(queue.ext_oneapi_submit_barrier());
}
#else
event_ = std::make_unique<sycl::event>(queue.ext_oneapi_submit_barrier());
#endif
}
void reassignEvent(sycl::queue& queue) {
event_.reset();
assignEvent(queue);
}
bool enable_timing_ = false;
DeviceIndex device_index_ = -1;
// Only need to track the last event, as events in an in-order queue are
// executed sequentially.
std::unique_ptr<sycl::event> event_;
};
} // namespace at::xpu
#include <c10/xpu/XPUEvent.h>

View File

@ -10,7 +10,7 @@ beit_base_patch16_224,pass,7
convnextv2_nano.fcmae_ft_in22k_in1k,pass,7
convnextv2_nano.fcmae_ft_in22k_in1k,fail_accuracy,7
@ -66,7 +66,7 @@ visformer_small,pass,7
vit_base_patch14_dinov2.lvd142m,pass,7
vit_base_patch14_dinov2.lvd142m,fail_accuracy,7

1 name accuracy graph_breaks
10 mobilenetv2_100 pass 7
11 mobilenetv3_large_100 pass 7
12 mobilevit_s pass 6
13 nfnet_l0 pass 7
14 repvgg_a2 pass 7
15 swin_base_patch4_window7_224 pass 7
16 tf_efficientnet_b0 pass 6

View File

@ -96,10 +96,6 @@ struct C10_API DeviceAllocator : public c10::Allocator {
// Resets peak memory usage statistics for the specified device
virtual void resetPeakStats(c10::DeviceIndex device) = 0;
// Return the free memory size and total memory size in bytes for the
// specified device.
virtual std::pair<size_t, size_t> getMemoryInfo(c10::DeviceIndex device) = 0;
};
// This function is used to get the DeviceAllocator for a specific device type

View File

@ -345,13 +345,6 @@ class CUDAAllocator : public DeviceAllocator {
c10::DeviceIndex device,
std::shared_ptr<AllocatorState> pps) = 0;
virtual std::string name() = 0;
std::pair<size_t, size_t> getMemoryInfo(c10::DeviceIndex device) override {
c10::DeviceGuard device_guard({at::kCUDA, device});
size_t free = 0;
size_t total = 0;
C10_CUDA_CHECK(cudaMemGetInfo(&free, &total));
return {free, total};
}
};
// Allocator object, statically initialized

View File

@ -24,6 +24,7 @@ set(C10_XPU_HEADERS
XPUCachingAllocator.h
XPUDeviceProp.h
XPUException.h
XPUEvent.h
XPUFunctions.h
XPUMacros.h
XPUStream.h

View File

@ -926,14 +926,15 @@ class DeviceCachingAllocator {
(release_cached_blocks() && alloc_block(params, true));
}
if (!block_found) {
const auto& raw_device = c10::xpu::get_raw_device(device);
const auto device_total =
raw_device.get_info<sycl::info::device::global_mem_size>();
c10::xpu::DeviceProp device_prop;
c10::xpu::get_device_properties(&device_prop, device);
auto device_total = device_prop.global_mem_size;
// Estimate the available device memory when the SYCL runtime does not
// support the corresponding aspect (ext_intel_free_memory).
size_t device_free = device_total -
size_t device_free = device_prop.global_mem_size -
stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)]
.current;
auto& raw_device = c10::xpu::get_raw_device(device);
// TODO: Remove the aspect check once the SYCL runtime bug is fixed on
// affected devices.
if (raw_device.has(sycl::aspect::ext_intel_free_memory)) {
@ -1051,37 +1052,21 @@ class DeviceCachingAllocator {
}
}
std::pair<size_t, size_t> getMemoryInfo() {
const auto& device = c10::xpu::get_raw_device(device_index);
const size_t total = device.get_info<sycl::info::device::global_mem_size>();
TORCH_CHECK(
device.has(sycl::aspect::ext_intel_free_memory),
"The device (",
device.get_info<sycl::info::device::name>(),
") doesn't support querying the available free memory. ",
"You can file an issue at https://github.com/pytorch/pytorch/issues ",
"to help us prioritize its implementation.");
const size_t free =
device.get_info<sycl::ext::intel::info::device::free_memory>();
return {free, total};
}
double getMemoryFraction() {
if (!set_fraction) {
return 1.0;
}
const auto device_total =
xpu::get_raw_device(device_index)
.get_info<sycl::info::device::global_mem_size>();
c10::xpu::DeviceProp device_prop;
c10::xpu::get_device_properties(&device_prop, device_index);
return static_cast<double>(allowed_memory_maximum) /
static_cast<double>(device_total);
static_cast<double>(device_prop.global_mem_size);
}
void setMemoryFraction(double fraction) {
const auto device_total =
xpu::get_raw_device(device_index)
.get_info<sycl::info::device::global_mem_size>();
c10::xpu::DeviceProp device_prop;
c10::xpu::get_device_properties(&device_prop, device_index);
auto device_total = device_prop.global_mem_size;
allowed_memory_maximum = static_cast<size_t>(fraction * device_total);
set_fraction = true;
}
@ -1255,11 +1240,6 @@ class XPUAllocator : public DeviceAllocator {
c10::xpu::get_raw_device(dev_to_access));
}
std::pair<size_t, size_t> getMemoryInfo(DeviceIndex device) override {
assertValidDevice(device);
return device_allocators[device]->getMemoryInfo();
}
double getMemoryFraction(DeviceIndex device) {
assertValidDevice(device);
return device_allocators[device]->getMemoryFraction();

c10/xpu/XPUEvent.h Normal file
View File

@ -0,0 +1,178 @@
#pragma once
#include <c10/xpu/XPUStream.h>
namespace c10::xpu {
/*
* XPUEvent is a movable but not copyable wrapper around a SYCL event. An
* XPUEvent is constructed lazily when first recorded. It has a device, which
* is acquired from the first recording stream; later streams that record the
* event must match that device.
*
* Currently, XPUEvent does NOT support exporting an event handle to another
* process via inter-process communication (IPC), so event handles cannot be
* shared between processes. This could impact applications that rely on
* cross-process synchronization and communication.
*/
struct XPUEvent {
// Constructors
XPUEvent(bool enable_timing = false) noexcept
: enable_timing_{enable_timing} {}
~XPUEvent() {
if (isCreated()) {
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_event_deletion(
c10::kXPU, reinterpret_cast<uintptr_t>(event_.get()));
}
}
}
C10_DISABLE_COPY_AND_ASSIGN(XPUEvent);
XPUEvent(XPUEvent&& other) = default;
XPUEvent& operator=(XPUEvent&& other) = default;
operator sycl::event&() const {
return event();
}
std::optional<c10::Device> device() const {
if (isCreated()) {
return c10::Device(c10::kXPU, device_index_);
} else {
return std::nullopt;
}
}
inline bool isCreated() const {
return (event_.get() != nullptr);
}
DeviceIndex device_index() const {
return device_index_;
}
sycl::event& event() const {
return *event_;
}
bool query() const {
using namespace sycl::info;
if (!isCreated()) {
return true;
}
return event().get_info<event::command_execution_status>() ==
event_command_status::complete;
}
void record() {
record(getCurrentXPUStream());
}
void recordOnce(const XPUStream& stream) {
if (!isCreated()) {
record(stream);
}
}
void record(const XPUStream& stream) {
if (!isCreated()) {
device_index_ = stream.device_index();
assignEvent(stream.queue());
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_event_creation(
c10::kXPU, reinterpret_cast<uintptr_t>(event_.get()));
}
} else {
TORCH_CHECK(
device_index_ == stream.device_index(),
"Event device ",
device_index_,
" does not match recording stream's device ",
stream.device_index(),
".");
reassignEvent(stream.queue());
}
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_event_record(
c10::kXPU,
reinterpret_cast<uintptr_t>(event_.get()),
reinterpret_cast<uintptr_t>(&stream.queue()));
}
}
void block(const XPUStream& stream) {
if (isCreated()) {
std::vector<sycl::event> event_list{event()};
// Make this stream wait until event_ is completed.
stream.queue().ext_oneapi_submit_barrier(event_list);
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_event_wait(
c10::kXPU,
reinterpret_cast<uintptr_t>(event_.get()),
reinterpret_cast<uintptr_t>(&stream.queue()));
}
}
}
double elapsed_time(const XPUEvent& other) const {
TORCH_CHECK(
isCreated() && other.isCreated(),
"Both events must be recorded before calculating elapsed time.");
TORCH_CHECK(
query() && other.query(),
"Both events must be completed before calculating elapsed time.");
TORCH_CHECK(
enable_timing_ && other.enable_timing_,
"Both events must be created with argument 'enable_timing=True'.");
using namespace sycl::info::event_profiling;
// Block until both of the recorded events are completed.
uint64_t end_time_ns = other.event().get_profiling_info<command_end>();
uint64_t start_time_ns = event().get_profiling_info<command_end>();
// Return the elapsed time in milliseconds.
return 1e-6 *
(static_cast<double>(end_time_ns) - static_cast<double>(start_time_ns));
}
void synchronize() const {
if (isCreated()) {
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_event_synchronization(
c10::kXPU, reinterpret_cast<uintptr_t>(event_.get()));
}
event().wait_and_throw();
}
}
private:
void assignEvent(sycl::queue& queue) {
if (enable_timing_) {
event_ = std::make_unique<sycl::event>(
sycl::ext::oneapi::experimental::submit_profiling_tag(queue));
} else {
event_ = std::make_unique<sycl::event>(queue.ext_oneapi_submit_barrier());
}
}
void reassignEvent(sycl::queue& queue) {
event_.reset();
assignEvent(queue);
}
bool enable_timing_ = false;
c10::DeviceIndex device_index_ = -1;
// Only need to track the last event, as events in an in-order queue are
// executed sequentially.
std::unique_ptr<sycl::event> event_;
};
} // namespace c10::xpu
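
The timing contract above (lazy creation on first record, enable_timing required, both events complete before elapsed_time) is what the Python torch.xpu.Event binding exposes. A minimal sketch, assuming an available XPU device and that the binding mirrors these C++ semantics:

import torch

# Sketch only: assumes torch.xpu.Event mirrors the C++ XPUEvent above
# (lazily created on first record; timing needs enable_timing=True and,
# per the version guard above, SYCL >= 2025.0).
start = torch.xpu.Event(enable_timing=True)
end = torch.xpu.Event(enable_timing=True)

x = torch.randn(1024, 1024, device="xpu")
start.record()                   # event object is created here, lazily
y = x @ x
end.record()
end.synchronize()                # both events must complete before timing
print(f"matmul took {start.elapsed_time(end):.3f} ms")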

View File

@ -40,7 +40,6 @@
:nosignatures:
empty_cache
get_memory_info
max_memory_allocated
max_memory_reserved
memory_allocated

View File

@ -570,14 +570,22 @@ class TestOverlapPreservingBucketing(InductorTestCase):
def test_split_mm(self):
def func(a, b):
a = a * 2
b = b * 3
mm = torch.mm(a, b)
mm = mm * 2
return mm
# Trace with make_fx
ref_a = torch.randn(16, 8)
ref_b = torch.randn(8, 4)
gm = make_fx(func)(ref_a, ref_b)
ref_out = func(ref_a, ref_b)
def _inps():
return torch.randn(16, 8, device=self.device), torch.randn(
8, 4, device=self.device
)
inps = _inps()
ref_out = func(*inps)
gm = make_fx(func, tracing_mode="fake")(*inps)
from torch._inductor.fx_passes.decompose_mm import split_mms
split_mms(gm, 16, 4)
@ -587,15 +595,25 @@ class TestOverlapPreservingBucketing(InductorTestCase):
4,
exactly=True,
).run(graph_str)
out = gm(ref_a, ref_b)
out = gm(*inps)
self.assertTrue(same(out, ref_out))
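
The rewrite is only sound because of the row-block identity for matmul: splitting the left operand along M and concatenating the partial products reproduces mm(a, b) exactly. A standalone sketch of that identity, independent of the actual split_mms implementation (the chunk size is illustrative):

import torch

a = torch.randn(16, 8)
b = torch.randn(8, 4)

# Row-block identity: mm(a, b) == cat([mm(a_i, b) for each row block a_i]).
# Splitting M=16 into blocks of 4 yields four smaller mms, matching the
# FileCheck count of 4 in the test above.
chunks = a.split(4, dim=0)
out = torch.cat([c @ b for c in chunks], dim=0)
assert torch.allclose(out, a @ b)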
def test_split_mm_noncont(self):
# Non-contiguous matmuls are not split
def func(a, b):
return torch.mm(a, b)
def _inps():
return torch.empty_strided((16, 8), (1, 8)), torch.randn(8, 4)
inps = _inps()
gm = make_fx(func, tracing_mode="fake")(*inps)
from torch._inductor.fx_passes.decompose_mm import split_mms
ref_a2 = torch.empty_strided((16, 8), (1, 8))
ref_b2 = torch.randn(8, 4)
gm = make_fx(func)(ref_a2, ref_b2)
split_mms(gm, 16, 4)
graph_str = str(gm.graph)
FileCheck().check_count(
"torch.ops.aten.mm",
1,

View File

@ -19,6 +19,7 @@ import operator
import os
import pickle
import random
import re
import sys
import tempfile
import threading
@ -5635,6 +5636,115 @@ utils_device.CURRENT_DEVICE == None""".split("\n"):
self.assertTrue(same(res11, res12))
self.assertTrue(same(res21, res22))
def test_replay_side_effects_config(self):
# Test that replay_side_effects config controls mutation replay
def fn(x, lst):
lst.append(x + 1)
return x * 2
x = torch.tensor([5.0])
# Test with replay enabled (default)
lst_with_replay = []
opt_fn_with_replay = torch.compile(fn, backend="eager")
result1 = opt_fn_with_replay(x, lst_with_replay)
self.assertEqual(len(lst_with_replay), 1) # Mutation should be replayed
self.assertTrue(same(result1, x * 2))
torch._dynamo.reset()
# Test with replay disabled
lst_without_replay = []
with torch._dynamo.config.patch(
replay_side_effects=False, side_effect_replay_policy="warn"
):
opt_fn_without_replay = torch.compile(fn, backend="eager")
result2 = opt_fn_without_replay(x, lst_without_replay)
self.assertEqual(
len(lst_without_replay), 0
) # Mutation should NOT be replayed
self.assertTrue(same(result2, x * 2))
torch._dynamo.reset()
lst_without_replay = []
with torch._dynamo.config.patch(
replay_side_effects=False, side_effect_replay_policy="error"
):
opt_fn_without_replay = torch.compile(fn, backend="eager")
with self.assertRaisesRegex(
RuntimeError,
re.escape(
"While compiling, we found certain side effects happened in the model.forward. Here are the list of potential sources you can double check: [\"L['lst']\"]"
),
):
_ = opt_fn_without_replay(x, lst_without_replay)
def test_replay_side_effects_model_attr(self):
class Bar(torch.nn.Module):
def __init__(self):
super().__init__()
self.const = 4
def forward(self, x):
return x.cos()
class Foo(torch.nn.Module):
def __init__(self):
super().__init__()
self.const = 4
self.tensor = None
self.bar = Bar()
def forward(self, x):
self.const = 5
self.tensor = x.sin()
res = self.bar(x)
return x.cos() + res.sum() + self.tensor
with torch._dynamo.config.patch(
replay_side_effects=False, side_effect_replay_policy="error"
):
foo = Foo()
with self.assertRaisesRegex(
RuntimeError,
re.escape(
"While compiling, we found certain side effects happened in the model.forward. Here are the list of potential sources you can double check: [\"L['self']\"]"
),
):
torch.compile(foo, fullgraph=True)(torch.randn(4, 4))
with torch._dynamo.config.patch(
replay_side_effects=False, side_effect_replay_policy="silent"
):
foo_v2_compile = Foo()
foo_v2_eager = Foo()
inp = torch.randn(4, 4)
res = torch.compile(foo_v2_compile, fullgraph=True)(inp)
self.assertEqual(foo_v2_compile.tensor, None)
self.assertEqual(foo_v2_compile.const, 4)
self.assertEqual(foo_v2_compile.bar.const, 4)
self.assertTrue(same(res, foo_v2_eager(inp)))
def test_replay_side_effects_input_mut(self):
class Foo(torch.nn.Module):
def __init__(self):
super().__init__()
self.const = 4
self.tensor = None
def forward(self, x):
x.add_(5)
return x.cos()
# This is ok because we actually capture the graph which
# has mutation. In export, we never retrace the actual
# gm so we won't see any mutation applied to inputs
with torch._dynamo.config.patch(
replay_side_effects=False, side_effect_replay_policy="error"
):
foo = Foo()
torch.compile(foo, fullgraph=True)(torch.randn(4, 4))
def test_list_append_return_none(self):
def fn(x):
alist = []

View File

@ -349,6 +349,18 @@ def forward(self, x):
res2 = p.generate(input_tensor=inp, input_tensor2=inp2)
self.assertTrue(torch.allclose(res, res2))
def test_side_effect(self):
global_env = []
class Foo(torch.nn.Module):
def forward(self, x):
global_env.append(x)
return x.sin()
with torch._dynamo.config.patch(replay_side_effects=False):
_ = dynamo_graph_capture_for_export(Foo())(torch.randn(4, 4))
self.assertEqual(len(global_env), 0)
def test_export_add_in_out_info(self):
class Foo(torch.nn.Module):
def forward(self, dct, lst, bleh):

View File

@ -239,12 +239,6 @@ class TestAccelerator(TestCase):
self.assertEqual(torch.accelerator.max_memory_allocated(), prev_max_allocated)
self.assertEqual(torch.accelerator.max_memory_reserved(), prev_max_reserved)
@unittest.skipIf(TEST_MPS, "MPS doesn't support torch.accelerator memory API!")
def test_get_memory_info(self):
free_bytes, total_bytes = torch.accelerator.get_memory_info()
self.assertGreaterEqual(free_bytes, 0)
self.assertGreaterEqual(total_bytes, 0)
if __name__ == "__main__":
run_tests()

View File

@ -2674,6 +2674,7 @@ class TestSparse(TestSparseBase):
self._test_asin_arcsin(input_uncoalesced, coalesced)
@coalescedonoff
@expectedFailureMPS
@dtypes(torch.double)
@dtypesIfMPS(torch.float32)
def test_mv(self, device, dtype, coalesced):

View File

@ -2501,7 +2501,6 @@ def _accelerator_emptyCache() -> None: ...
def _accelerator_getDeviceStats(device_index: _int) -> dict[str, Any]: ...
def _accelerator_resetAccumulatedStats(device_index: _int) -> None: ...
def _accelerator_resetPeakStats(device_index: _int) -> None: ...
def _accelerator_getMemoryInfo(device_index: _int) -> tuple[_int, _int]: ...
def _accelerator_setAllocatorSettings(env: str) -> None: ...
# Defined in torch/csrc/jit/python/python_tracer.cpp

View File

@ -44,6 +44,19 @@ minimum_call_count = 1
# turn on/off DCE pass (deprecated: always true)
dead_code_elimination = True
# Enable or disable side effect replay after graph execution.
# When False, mutations to Python objects (lists, dicts, attributes) won't be
# replayed after the compiled graph runs. This can cause correctness issues
# if your code depends on these mutations being visible. This should probably
# never be False by default. At the moment, only export will need it.
replay_side_effects = True
# Configure the side effect replay policy
# If `silent`, we silently allow side effects
# If `warn`, we warn on side effects
# If `error`, we error on side effects
side_effect_replay_policy = "silent"
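
A minimal sketch of how the two knobs combine, mirroring the dynamo test above (the eager backend and shapes are illustrative):

import torch

def fn(x, log):
    log.append(x + 1)  # Python-level side effect recorded during tracing
    return x * 2

log = []
with torch._dynamo.config.patch(
    replay_side_effects=False, side_effect_replay_policy="warn"
):
    out = torch.compile(fn, backend="eager")(torch.ones(2), log)

assert len(log) == 0  # the append was traced but never replayed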
# disable (for a function) when cache reaches this size
# controls the maximum number of cache entries with a guard on same ID_MATCH'd

View File

@ -1845,7 +1845,7 @@ class OutputGraph(OutputGraphCommon):
[create_instruction("DELETE_FAST", argval=graph_output_var)]
)
if self.export:
if torch._dynamo.config.side_effect_replay_policy in ["warn", "error"]:
from torch.export._trace import _ExportModuleSpecTrackerDict
potential_side_effects = []
@ -1881,10 +1881,16 @@ class OutputGraph(OutputGraphCommon):
]
if side_effect_refs:
warnings.warn(
f"While exporting, we found certain side effects happened in the model.forward. "
f"Here are the list of potential sources you can double check: {side_effect_refs}"
)
if torch._dynamo.config.side_effect_replay_policy == "warn":
warnings.warn(
f"While compiling, we found certain side effects happened in the model.forward. "
f"Here are the list of potential sources you can double check: {side_effect_refs}"
)
else:
raise RuntimeError(
f"While compiling, we found certain side effects happened in the model.forward. "
f"Here are the list of potential sources you can double check: {side_effect_refs}"
)
return all_stack_locals_metas
@ -1930,7 +1936,8 @@ class OutputGraph(OutputGraphCommon):
assert self.backward_state_var is not None
cg.append_output(cg.create_load(self.backward_state_var))
cg.store_attr(name)
self.side_effects.codegen_hooks(cg)
if config.replay_side_effects:
self.side_effects.codegen_hooks(cg)
# TODO get debug_locals working for nested graph breaks
# Return variables used for logging at the end
@ -1945,7 +1952,8 @@ class OutputGraph(OutputGraphCommon):
self.codegen_cells(tx, cg)
cg.restore_stack(stack_values, value_from_source=not tx.export)
self.side_effects.codegen_update_mutated(cg)
if config.replay_side_effects:
self.side_effects.codegen_update_mutated(cg)
def cleanup_graph(self) -> None:
"""

View File

@ -359,6 +359,13 @@ class GraphArg:
# stash a strong reference too.
example_strong_ref: Optional[torch.Tensor] = None
def __setattr__(self, name, value):
# Use object.__setattr__ to bypass Dynamo's STORE_ATTR interception.
# This is needed because when PYTORCH_TEST_WITH_DYNAMO=1, even internal
# GraphArg creation can be traced, and with replay_side_effects=False,
# normal STORE_ATTR bytecode only records mutations without applying them.
object.__setattr__(self, name, value)
@property
def example(self):
if isinstance(self._example, TensorWeakRef):

View File

@ -10,7 +10,6 @@ import torch
from ._utils import _device_t, _get_device_index
from .memory import (
empty_cache,
get_memory_info,
max_memory_allocated,
max_memory_reserved,
memory_allocated,
@ -26,10 +25,9 @@ __all__ = [
"current_device_idx", # deprecated
"current_device_index",
"current_stream",
"empty_cache",
"device_count",
"device_index",
"empty_cache",
"get_memory_info",
"is_available",
"max_memory_allocated",
"max_memory_reserved",

View File

@ -8,7 +8,6 @@ from ._utils import _device_t, _get_device_index
__all__ = [
"empty_cache",
"get_memory_info",
"max_memory_allocated",
"max_memory_reserved",
"memory_allocated",
@ -88,9 +87,6 @@ def memory_stats(device_index: _device_t = None, /) -> OrderedDict[str, Any]:
If not given, use :func:`torch.accelerator.current_device_index` by default.
If a :class:`torch.device` or str is provided, its type must match the current
:ref:`accelerator<accelerators>` device type.
Returns:
OrderedDict[str, Any]: an ordered dictionary mapping statistic names to their values.
"""
if not torch._C._accelerator_isAllocatorInitialized():
return OrderedDict()
@ -121,9 +117,6 @@ def memory_allocated(device_index: _device_t = None, /) -> int:
If not given, use :func:`torch.accelerator.current_device_index` by default.
If a :class:`torch.device` or str is provided, its type must match the current
:ref:`accelerator<accelerators>` device type.
Returns:
int: the current memory occupied by live tensors (in bytes) within the current process.
"""
return memory_stats(device_index).get("allocated_bytes.all.current", 0)
@ -141,9 +134,6 @@ def max_memory_allocated(device_index: _device_t = None, /) -> int:
If not given, use :func:`torch.accelerator.current_device_index` by default.
If a :class:`torch.device` or str is provided, its type must match the current
:ref:`accelerator<accelerators>` device type.
Returns:
int: the peak memory occupied by live tensors (in bytes) within the current process.
"""
return memory_stats(device_index).get("allocated_bytes.all.peak", 0)
@ -157,9 +147,6 @@ def memory_reserved(device_index: _device_t = None, /) -> int:
If not given, use :func:`torch.accelerator.current_device_index` by default.
If a :class:`torch.device` or str is provided, its type must match the current
:ref:`accelerator<accelerators>` device type.
Returns:
int: the current memory reserved by PyTorch (in bytes) within the current process.
"""
return memory_stats(device_index).get("reserved_bytes.all.current", 0)
@ -177,9 +164,6 @@ def max_memory_reserved(device_index: _device_t = None, /) -> int:
If not given, use :func:`torch.accelerator.current_device_index` by default.
If a :class:`torch.device` or str is provided, its type must match the current
:ref:`accelerator<accelerators>` device type.
Returns:
int: the peak memory reserved by PyTorch (in bytes) within the current process.
"""
return memory_stats(device_index).get("reserved_bytes.all.peak", 0)
@ -216,21 +200,3 @@ def reset_peak_memory_stats(device_index: _device_t = None, /) -> None:
"""
device_index = _get_device_index(device_index, optional=True)
return torch._C._accelerator_resetPeakStats(device_index)
def get_memory_info(device_index: _device_t = None, /) -> tuple[int, int]:
r"""Return the current device memory information for a given device index.
Args:
device_index (:class:`torch.device`, str, int, optional): the index of the device to target.
If not given, use :func:`torch.accelerator.current_device_index` by default.
If a :class:`torch.device` or str is provided, its type must match the current
:ref:`accelerator<accelerators>` device type.
Returns:
tuple[int, int]: a tuple of two integers (free_memory, total_memory) in bytes.
The first value is the free memory on the device (available across all processes and applications),
The second value is the device's total hardware memory capacity.
"""
device_index = _get_device_index(device_index, optional=True)
return torch._C._accelerator_getMemoryInfo(device_index)

View File

@ -138,13 +138,6 @@ void initModule(PyObject* module) {
at::accelerator::resetPeakStats(device_index);
});
m.def("_accelerator_getMemoryInfo", [](c10::DeviceIndex device_index) {
const auto device_type = at::accelerator::getAccelerator(true).value();
torch::utils::maybe_initialize_device(device_type);
py::gil_scoped_release no_gil;
return at::accelerator::getMemoryInfo(device_index);
});
m.def("_accelerator_setAllocatorSettings", [](std::string env) {
c10::CachingAllocator::setAllocatorSettings(env);
});

View File

@ -386,8 +386,23 @@ static void bindGetDeviceProperties(PyObject* module) {
static void initXpuMethodBindings(PyObject* module) {
auto m = py::handle(module).cast<py::module>();
m.def("_xpu_getMemoryInfo", [](c10::DeviceIndex device_index) {
py::gil_scoped_release no_gil;
return at::getDeviceAllocator(at::kXPU)->getMemoryInfo(device_index);
#if SYCL_COMPILER_VERSION >= 20250000
auto total = at::xpu::getDeviceProperties(device_index)->global_mem_size;
auto& device = c10::xpu::get_raw_device(device_index);
TORCH_CHECK(
device.has(sycl::aspect::ext_intel_free_memory),
"The device (",
at::xpu::getDeviceProperties(device_index)->name,
") doesn't support querying the available free memory. ",
"You can file an issue at https://github.com/pytorch/pytorch/issues ",
"to help us prioritize its implementation.");
auto free = device.get_info<sycl::ext::intel::info::device::free_memory>();
return std::make_tuple(free, total);
#else
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"torch.xpu.mem_get_info requires PyTorch to be built with SYCL compiler version 2025.0.0 or newer.");
#endif
});
m.def(
"_xpu_getStreamFromExternal",

View File

@ -140,6 +140,8 @@ class ExportDynamoConfig:
capture_dynamic_output_shape_ops: bool = True
capture_scalar_outputs: bool = True
prefer_deferred_runtime_asserts_over_guards: bool = False
replay_side_effects: bool = False
side_effect_replay_policy: str = "warn"
@dataclasses.dataclass

View File

@ -190,7 +190,6 @@ def mem_get_info(device: _device_t = None) -> tuple[int, int]:
int: the memory available on the device in units of bytes.
int: the total memory on the device in units of bytes.
"""
_lazy_init()
device = _get_device_index(device, optional=True)
return torch._C._xpu_getMemoryInfo(device)
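
For reference, a minimal usage sketch of the Python API this binds to (assumes an available XPU device; both returned values are in bytes):

import torch

if torch.xpu.is_available():
    free, total = torch.xpu.mem_get_info()
    print(f"free: {free / 2**30:.2f} GiB of {total / 2**30:.2f} GiB total")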