Compare commits

...

9 Commits

Author SHA1 Message Date
b0d7899bc9 Automated submodule update: kineto 2025-11-12 16:33:53 -08:00
485f2b607a ProxyTorchDispatchMode: Decomposing missing sympy.SymExpr should handle constant literals (#167585)
The previous work to decompose missing sympy.SymExpr (#164717) handled combinations of sub-nodes (like `s1*s2`) but I forgot to handle explicit literals (like `2*s2`).

Added a unit test based on the report.

Fixes T244632748

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167585
Approved by: https://github.com/bobrenjc93
2025-11-13 00:27:10 +00:00
0c5d5c7e9a [dynamo][invoke_subgraph] Do not restore side effects on invoke_subgraph (#167446)
Add a test that checks non-proxyable outputs. Also add a currently failing
test, to be fixed later.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167446
Approved by: https://github.com/zou3519
ghstack dependencies: #167438, #167442
2025-11-13 00:16:40 +00:00
5f98a0363a [dynamo] Make HintsWrapperHigherOrderVariable follow wrap semantics (#167442)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167442
Approved by: https://github.com/zou3519
ghstack dependencies: #167438
2025-11-13 00:16:40 +00:00
2d739001d3 [dynamo] speculate_subgraph_with_auto_output_flattening (#167438)
Summary

  This PR refactors the wrap higher-order operator infrastructure in PyTorch's Dynamo to introduce automatic output flattening for subgraph speculation. The key change is the addition of
  speculate_subgraph_with_auto_output_flattening() which separates the output variable trackers (VTs) that Dynamo continues tracing with from the actual FX graph outputs.

  Key Changes

  New speculate_subgraph_with_auto_output_flattening() function

  - Introduces a new approach for handling HOPs (Higher-Order Operators) that are just "subgraph placeholders", i.e. the HOP essentially just runs the subgraph with inputs (e.g., invoke_subgraph, activation checkpointing,
   autograd.Function)
  - Disentangles output VTs from graph outputs: Allows the subgraph to return complex Python objects (like custom user-defined objects containing tensors) while only registering tensor/symint VTs as actual FX
  graph outputs
  - Mirrors typical Dynamo processing where VTs can "run ahead" for continued tracing while the graph is a side data structure

  Benefits

  1. Handles non-proxyable outputs: Supports HOPs that return custom Python objects containing tensors
  2. Cleaner separation of concerns: Output VTs for continued tracing vs. graph outputs for FX representation
  3. More flexible: Returns graph_output_vts instead of treespec, giving more control over what becomes a graph output

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167438
Approved by: https://github.com/zou3519
2025-11-13 00:16:40 +00:00
273babeec3 [precompile] Integrate AOTI as a backend. (#167338)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167338
Approved by: https://github.com/jamesjwu
2025-11-13 00:02:26 +00:00
a76dd6b7c6 [MPS] SparseMps mv op (#166708)
Should be merged after #166561
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166708
Approved by: https://github.com/malfet
2025-11-12 22:44:29 +00:00
2fa18d1545 [export] Codemod more tests to use dynamo_graph_capture_for_export (#167663)
Summary:
as title.

Test Plan:
CI

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167663
Approved by: https://github.com/tugsbayasgalan
2025-11-12 22:44:18 +00:00
537167aa1e Fix thread safety in getCurrentCUDABlasHandle and getCUDABlasLtWorkspace (#167248)
Summary:
getCurrentCUDABlasHandle() and getCUDABlasLtWorkspace() use static mutable maps that are not protected against concurrent reads and writes. This leads to crashes.

This diff adds mutexes to synchronize access to the static maps.

Test Plan:
Use a GPU OD, run multi-threaded tests with TSAN:
```
buck test fbcode//mode/dev-tsan fbcode//caffe2:cuda_cublas_handle_pool_test  -- --stress-runs 100
```
https://www.internalfb.com/intern/testinfra/testrun/14355223937501118

TSAN: P2026731804

Differential Revision: D86316117

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167248
Approved by: https://github.com/Skylion007, https://github.com/malfet
2025-11-12 22:43:56 +00:00
28 changed files with 1256 additions and 455 deletions

View File

@ -3,6 +3,7 @@
#include <cstdint>
#include <map>
#include <shared_mutex>
#include <cuda_runtime_api.h>
#include <cusparse.h>
@ -88,8 +89,13 @@ TORCH_CUDA_CPP_API cublasHandle_t getCurrentCUDABlasHandle();
TORCH_CUDA_CPP_API cublasLtHandle_t getCurrentCUDABlasLtHandle();
TORCH_CUDA_CPP_API void clearCublasWorkspaces();
TORCH_CUDA_CPP_API std::map<std::tuple<void *, void *>, at::DataPtr>& cublas_handle_stream_to_workspace();
TORCH_CUDA_CPP_API std::map<std::tuple<void *, void *>, at::DataPtr>& cublaslt_handle_stream_to_workspace();
// Bundles a workspace map with the mutex that guards it.
// Keys are (handle, stream) pairs cast to void*; values are the at::DataPtr
// workspace buffers. Lookups take a shared lock on `mutex`; code that inserts
// into or clears `map` takes a unique lock.
struct WorkspaceMapWithMutex {
// Workspace buffers, keyed by (handle, stream).
std::map<std::tuple<void*, void*>, at::DataPtr> map;
// Guards every access to `map`.
std::shared_mutex mutex;
};
TORCH_CUDA_CPP_API WorkspaceMapWithMutex& cublas_handle_stream_to_workspace();
TORCH_CUDA_CPP_API WorkspaceMapWithMutex& cublaslt_handle_stream_to_workspace();
TORCH_CUDA_CPP_API size_t getChosenWorkspaceSize();
TORCH_CUDA_CPP_API size_t getCUDABlasLtWorkspaceSize();
TORCH_CUDA_CPP_API void* getCUDABlasLtWorkspace();

View File

@ -99,7 +99,7 @@ void destroyCublasHandle(cublasHandle_t handle) {
// - Comments of @soumith copied from cuDNN handle pool implementation
#ifdef NO_CUDNN_DESTROY_HANDLE
#else
cublasDestroy(handle);
cublasDestroy(handle);
#endif
}
@ -107,19 +107,27 @@ using CuBlasPoolType = DeviceThreadHandlePool<cublasHandle_t, createCublasHandle
} // namespace
std::map<std::tuple<void *, void *>, at::DataPtr>& cublas_handle_stream_to_workspace() {
static auto& instance = *new std::map<std::tuple<void *, void *>, at::DataPtr>;
WorkspaceMapWithMutex& cublas_handle_stream_to_workspace() {
static auto& instance = *new WorkspaceMapWithMutex;
return instance;
}
std::map<std::tuple<void *, void *>, at::DataPtr>& cublaslt_handle_stream_to_workspace() {
static auto& instance = *new std::map<std::tuple<void *, void *>, at::DataPtr>;
WorkspaceMapWithMutex& cublaslt_handle_stream_to_workspace() {
static auto& instance = *new WorkspaceMapWithMutex;
return instance;
}
void clearCublasWorkspaces() {
cublas_handle_stream_to_workspace().clear();
cublaslt_handle_stream_to_workspace().clear();
{
auto& workspace = cublas_handle_stream_to_workspace();
std::unique_lock<std::shared_mutex> lock(workspace.mutex);
workspace.map.clear();
}
{
auto& workspace = cublaslt_handle_stream_to_workspace();
std::unique_lock<std::shared_mutex> lock(workspace.mutex);
workspace.map.clear();
}
}
size_t parseChosenWorkspaceSize() {
@ -241,8 +249,10 @@ void* getCUDABlasLtWorkspace() {
auto stream = c10::cuda::getCurrentCUDAStream();
cudaStream_t _stream = stream;
auto key = std::make_tuple(static_cast<void *>(handle), static_cast<void *>(_stream));
auto workspace_it = at::cuda::cublas_handle_stream_to_workspace().find(key);
TORCH_INTERNAL_ASSERT(workspace_it != at::cuda::cublas_handle_stream_to_workspace().end());
auto& workspace = at::cuda::cublas_handle_stream_to_workspace();
std::shared_lock<std::shared_mutex> lock(workspace.mutex);
auto workspace_it = workspace.map.find(key);
TORCH_INTERNAL_ASSERT(workspace_it != workspace.map.end());
return workspace_it->second.mutable_get();
}
#endif
@ -250,11 +260,34 @@ void* getCUDABlasLtWorkspace() {
auto stream = c10::cuda::getCurrentCUDAStream();
cudaStream_t _stream = stream;
auto key = std::make_tuple(static_cast<void *>(handle), static_cast<void *>(_stream));
auto workspace_it = cublaslt_handle_stream_to_workspace().find(key);
if (workspace_it == cublaslt_handle_stream_to_workspace().end()) {
workspace_it = cublaslt_handle_stream_to_workspace().insert(workspace_it, {key, getNewCUDABlasLtWorkspace()});
auto& workspace = cublaslt_handle_stream_to_workspace();
// Fast path: check if workspace already exists
{
std::shared_lock<std::shared_mutex> lock(workspace.mutex);
auto workspace_it = workspace.map.find(key);
if (workspace_it != workspace.map.end()) {
return workspace_it->second.mutable_get();
}
}
// Slow path: allocate workspace outside the lock
auto new_workspace = getNewCUDABlasLtWorkspace();
// Insert with lock (double-check in case another thread inserted while we
// were allocating)
{
std::unique_lock<std::shared_mutex> lock(workspace.mutex);
auto workspace_it = workspace.map.find(key);
if (workspace_it == workspace.map.end()) {
workspace_it =
workspace.map.emplace(key, std::move(new_workspace)).first;
}
// else: another thread inserted it, our new_workspace will be automatically
// freed
return workspace_it->second.mutable_get();
}
return workspace_it->second.mutable_get();
}
cublasHandle_t getCurrentCUDABlasHandle() {
@ -300,11 +333,39 @@ cublasHandle_t getCurrentCUDABlasHandle() {
// all the memory and cublas's cudaMallocAsync will return OOM
cudaStream_t _stream = stream;
auto key = std::make_tuple(static_cast<void *>(handle), static_cast<void *>(_stream));
auto workspace_it = cublas_handle_stream_to_workspace().find(key);
if (workspace_it == cublas_handle_stream_to_workspace().end()) {
workspace_it = cublas_handle_stream_to_workspace().insert(workspace_it, {key, getNewWorkspace()});
auto& workspace = cublas_handle_stream_to_workspace();
size_t workspace_size = getChosenWorkspaceSize();
// Fast path: check if workspace already exists
{
std::shared_lock<std::shared_mutex> lock(workspace.mutex);
auto workspace_it = workspace.map.find(key);
if (workspace_it != workspace.map.end()) {
TORCH_CUDABLAS_CHECK(cublasSetWorkspace(
handle, workspace_it->second.get(), workspace_size));
return handle;
}
}
// Slow path: allocate workspace outside the lock
auto new_workspace = getNewWorkspace();
// Insert with lock (double-check in case another thread inserted while we
// were allocating)
{
std::unique_lock<std::shared_mutex> lock(workspace.mutex);
auto workspace_it = workspace.map.find(key);
if (workspace_it == workspace.map.end()) {
workspace_it =
workspace.map.emplace(key, std::move(new_workspace)).first;
}
// else: another thread inserted it, our new_workspace will be automatically
// freed
TORCH_CUDABLAS_CHECK(
cublasSetWorkspace(handle, workspace_it->second.get(), workspace_size));
}
TORCH_CUDABLAS_CHECK(cublasSetWorkspace(handle, workspace_it->second.get(), getChosenWorkspaceSize()));
#if !defined(USE_ROCM)
// On CUDA >= 11, and architecture >= Ampere, cuBLAS can use TF32 to speedup
// FP32 data type calculations based on the value of the allow_tf32 flag.

View File

@ -4389,7 +4389,7 @@
variants: function, method
dispatch:
CompositeExplicitAutograd: mv
SparseCPU, SparseCUDA: mv_sparse
SparseCPU, SparseCUDA, SparseMPS: mv_sparse
- func: mv.out(Tensor self, Tensor vec, *, Tensor(a!) out) -> Tensor(a!)
dispatch:

View File

@ -61,6 +61,7 @@ list(APPEND ATen_CUDA_TEST_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/cuda_complex_math_test.cu
${CMAKE_CURRENT_SOURCE_DIR}/cuda_complex_test.cu
${CMAKE_CURRENT_SOURCE_DIR}/cuda_cub_test.cu
${CMAKE_CURRENT_SOURCE_DIR}/cuda_cublas_handle_pool_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cuda_device_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cuda_distributions_test.cu
${CMAKE_CURRENT_SOURCE_DIR}/cuda_dlconvertor_test.cpp

View File

@ -0,0 +1,77 @@
#include <gtest/gtest.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAGuard.h>
#include <atomic>
#include <chrono>
#include <thread>
#include <vector>
// Stress test for concurrent access to getCurrentCUDABlasHandle() and
// getCUDABlasLtWorkspace() while other threads call clearCublasWorkspaces().
// It passes when no call throws and no accessor observes a null handle or
// workspace pointer; run it under TSAN to detect data races.
TEST(CUDABlasHandlePoolTest, ConcurrentGetAndClearWorkspaces) {
  if (!at::cuda::is_available()) {
    // Report the test as skipped instead of letting it pass vacuously.
    GTEST_SKIP() << "CUDA is not available";
  }

  constexpr int num_accessor_threads = 15;
  constexpr int num_clear_threads = 5;
  constexpr int iterations_per_thread = 50;

  std::atomic<bool> stop{false};
  std::atomic<int> error_count{0};
  std::vector<std::thread> threads;
  threads.reserve(num_accessor_threads + num_clear_threads);

  // Accessor threads: fetch the handle and the Lt workspace until told to
  // stop, counting any null result as an error.
  for (int i = 0; i < num_accessor_threads; ++i) {
    threads.emplace_back([&stop, &error_count]() {
      try {
        at::cuda::CUDAGuard device_guard(0);
        while (!stop.load(std::memory_order_relaxed)) {
          const auto handle = at::cuda::getCurrentCUDABlasHandle();
          const auto workspace = at::cuda::getCUDABlasLtWorkspace();
          if (handle == nullptr || workspace == nullptr) {
            error_count++;
          }
        }
      } catch (...) {
        // Count every failure: an exception escaping a std::thread would
        // call std::terminate and abort the whole test binary.
        error_count++;
      }
    });
  }

  // Clearing threads: repeatedly drop all cached workspaces so that accessors
  // race against concurrent clears.
  for (int i = 0; i < num_clear_threads; ++i) {
    threads.emplace_back([&error_count]() {
      try {
        for (int j = 0; j < iterations_per_thread; ++j) {
          at::cuda::clearCublasWorkspaces();
          std::this_thread::yield();
        }
      } catch (...) {
        error_count++;
      }
    });
  }

  // Let the threads contend for a short while, then ask accessors to stop.
  std::this_thread::sleep_for(std::chrono::milliseconds(100));
  stop.store(true, std::memory_order_relaxed);

  for (auto& thread : threads) {
    thread.join();
  }

  EXPECT_EQ(error_count.load(), 0);
}
// Test binary entry point: initializes GoogleTest from the command line,
// initializes the CUDA caching allocator, then runs all registered tests and
// returns their result as the process exit code.
int main(int argc, char* argv[]) {
::testing::InitGoogleTest(&argc, argv);
// NOTE(review): the argument is presumably the device count — confirm.
c10::cuda::CUDACachingAllocator::init(1);
return RUN_ALL_TESTS();
}

View File

@ -6,10 +6,7 @@ import unittest
import torch
import torch.distributed as dist
import torch.fx.traceback as fx_traceback
from torch._dynamo.functional_export import (
_dynamo_graph_capture_for_export,
dynamo_graph_capture_for_export,
)
from torch._dynamo.functional_export import dynamo_graph_capture_for_export
from torch._functorch.aot_autograd import aot_export_joint_with_descriptors
from torch._functorch.partitioners import min_cut_rematerialization_partition
from torch._guards import tracing, TracingContext
@ -153,17 +150,6 @@ def graph_capture_and_aot_export_joint_with_descriptors_v2(model, args, kwargs=N
return aot_export_joint_with_descriptors_alone(gm, args, kwargs)
def graph_capture_and_aot_export_joint_with_descriptors(model, args, kwargs=None):
if kwargs is None:
kwargs = {}
with torch._dynamo.config.patch(install_free_tensors=True):
# TODO: switch to use the official graph_capture API once it is ready
gm = _dynamo_graph_capture_for_export(model)(*args, **kwargs)
fake_mode = gm.meta.get("fake_mode", None)
with tracing(TracingContext(fake_mode)):
return aot_export_joint_with_descriptors_alone(gm, args, kwargs)
def aot_export_joint_with_descriptors_alone(model, args, kwargs=None):
if kwargs is None:
kwargs = {}
@ -360,7 +346,6 @@ class DTensorExportTest(TestCase):
"export_fn",
[
graph_capture_and_aot_export_joint_with_descriptors_v2,
graph_capture_and_aot_export_joint_with_descriptors,
aot_export_joint_with_descriptors_alone,
],
)
@ -386,10 +371,6 @@ class DTensorExportTest(TestCase):
graph_capture_and_aot_export_joint_with_descriptors_v2,
"[[4, 10], [4], [10, 4], [10], [4, 10], [4], [10, 4], [10], [s64, 10], [s64, 10]]",
),
(
graph_capture_and_aot_export_joint_with_descriptors,
"[[4, 10], [4], [10, 4], [10], [s22, 10], [s22, 10]]",
),
],
)
def test_dynamic_shapes(self, export_fn_with_answer):
@ -434,7 +415,6 @@ class DTensorExportTest(TestCase):
"export_fn",
[
dynamo_graph_capture_for_export,
_dynamo_graph_capture_for_export,
],
)
def test_einsum_dtensor_export(self, export_fn):
@ -456,11 +436,7 @@ class DTensorExportTest(TestCase):
# Run model to verify it works
output = model(*inputs)
with torch._dynamo.config.patch(
install_free_tensors=(export_fn is _dynamo_graph_capture_for_export)
):
# TODO: switch to use the official graph_capture API once it is ready
gm = export_fn(model)(*inputs)
gm = export_fn(model)(*inputs)
output_gm = gm(*inputs)
self.assertEqual(output, output_gm)
@ -468,7 +444,6 @@ class DTensorExportTest(TestCase):
"export_fn",
[
graph_capture_and_aot_export_joint_with_descriptors_v2,
graph_capture_and_aot_export_joint_with_descriptors,
],
)
def test_flex_attention_dtensor_export(self, export_fn):
@ -531,7 +506,7 @@ class DTensorExportTest(TestCase):
return nest_fn(leaf) + 1
z = torch.randn(16, 16)
gm = graph_capture_and_aot_export_joint_with_descriptors(fn, (z,))
gm = graph_capture_and_aot_export_joint_with_descriptors_v2(fn, (z,))
self.assertEqual(fn(z), gm(z)[0])
@ -546,7 +521,7 @@ class DTensorExportTest(TestCase):
y = torch.randint(1, (10,)).bool()
x_dt = distribute_tensor(x, device_mesh, placements=[Replicate()])
y_dt = distribute_tensor(y, device_mesh, placements=[Replicate()])
_dynamo_graph_capture_for_export(Foo())(x_dt, y_dt)
dynamo_graph_capture_for_export(Foo())(x_dt, y_dt)
class Bar(torch.nn.Module):
def forward(self, x):
@ -556,25 +531,25 @@ class DTensorExportTest(TestCase):
x = torch.randint(1000, (4, 64, 16))
x_dt = distribute_tensor(x, device_mesh, placements=[Replicate()])
gm = _dynamo_graph_capture_for_export(Bar())(x_dt)
gm = dynamo_graph_capture_for_export(Bar())(x_dt)
self.assertExpectedInline(
str(gm.graph).strip(),
"""\
graph():
%l_flat_args_0_ : [num_users=2] = placeholder[target=arg_0]
%max_1 : [num_users=1] = call_method[target=max](args = (%l_flat_args_0_,), kwargs = {})
%l_x_ : torch.distributed.tensor.DTensor [num_users=2] = placeholder[target=L_x_]
%max_1 : [num_users=1] = call_method[target=max](args = (%l_x_,), kwargs = {})
%clamp : [num_users=1] = call_function[target=torch.clamp](args = (%max_1,), kwargs = {min: 1})
%item : [num_users=2] = call_method[target=item](args = (%clamp,), kwargs = {})
%ge_1 : [num_users=1] = call_function[target=operator.ge](args = (%item, 1), kwargs = {})
%_assert_scalar_default : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge_1, Runtime assertion failed for expression u0 >= 1 on node 'ge_1'), kwargs = {})
%res : [num_users=2] = call_function[target=operator.getitem](args = (%l_flat_args_0_, slice(None, item, None)), kwargs = {})
%getattr_1 : [num_users=1] = call_function[target=builtins.getattr](args = (%res, _local_tensor), kwargs = {})
%getitem : [num_users=2] = call_function[target=operator.getitem](args = (%l_x_, slice(None, item, None)), kwargs = {})
%getattr_1 : [num_users=1] = call_function[target=builtins.getattr](args = (%getitem, _local_tensor), kwargs = {})
%sym_size_int : [num_users=2] = call_function[target=torch.ops.aten.sym_size.int](args = (%getattr_1, 0), kwargs = {})
%ge_2 : [num_users=1] = call_function[target=operator.ge](args = (%sym_size_int, 0), kwargs = {})
%_assert_scalar_default_1 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge_2, Runtime assertion failed for expression u2 >= 0 on node 'ge_2'), kwargs = {})
%le : [num_users=1] = call_function[target=operator.le](args = (%sym_size_int, 4), kwargs = {})
%_assert_scalar_default_2 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%le, Runtime assertion failed for expression u2 <= 4 on node 'le'), kwargs = {})
return (res,)""", # noqa: B950
str(gm.graph).strip(),
return (getitem,)""", # noqa: B950
)

View File

@ -1681,14 +1681,13 @@ class GraphModule(torch.nn.Module):
wrap_body_0 = self.wrap_body_0
tag_activation_checkpoint = torch.ops.higher_order.tag_activation_checkpoint(wrap_body_0, l_x_, use_reentrant = True); wrap_body_0 = l_x_ = None
getitem: "f32[4, 4]" = tag_activation_checkpoint[0]
getitem_1: "f32[4, 4]" = tag_activation_checkpoint[1]; tag_activation_checkpoint = None
return (getitem, getitem_1)
getitem: "f32[4, 4]" = tag_activation_checkpoint[0]; tag_activation_checkpoint = None
return (getitem,)
class wrap_body_0(torch.nn.Module):
def forward(self, l_x_: "f32[4, 4]"):
y: "f32[4, 4]" = torch.sin(l_x_); l_x_ = None
return (y, y)
return (y,)
""",
)
@ -1798,9 +1797,9 @@ class GraphModule(torch.nn.Module):
out: "f32[4, 4]" = l_x_.sin()
sin_1: "f32[4, 4]" = torch.sin(o)
child: "f32[4, 4]" = torch.cos(sin_1)
child_1: "f32[4, 4]" = torch.sin(l_x_); l_x_ = None
return (child, child_1, matmul, o, out, sin_1)
cos: "f32[4, 4]" = torch.cos(sin_1)
sin_2: "f32[4, 4]" = torch.sin(l_x_); l_x_ = None
return (cos, sin_2, matmul, o, out, sin_1)
""",
)

View File

@ -1,9 +1,11 @@
# Owner(s): ["module: dynamo"]
import copy
import functools
import inspect
import os
import pickle
import unittest
from contextlib import contextmanager
from unittest.mock import patch
@ -13,13 +15,16 @@ import torch._inductor.config
import torch._inductor.test_case
import torch.onnx.operators
import torch.utils.cpp_extension
from torch._dynamo.aot_compile import ModelInput, SerializableCallable
from torch._dynamo.aot_compile import AOTCompiledModel, ModelInput, SerializableCallable
from torch._dynamo.exc import PackageError, Unsupported
from torch._dynamo.package import DynamoCache
from torch._dynamo.precompile_context import PrecompileContext
from torch._inductor.runtime.runtime_utils import cache_dir
from torch.fx._graph_pickler import GraphPickler
from torch.testing._internal.common_utils import instantiate_parametrized_tests
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
TEST_CUDA,
)
MY_LAMBDA = lambda x: x + 1 # noqa: E731
@ -599,6 +604,92 @@ from user code:
actual = compiled_fn(*inputs)
self.assertEqual(expected, actual)
@unittest.skipIf(not TEST_CUDA, "requires cuda")
def test_aot_compile_with_aoti(self):
with torch.device("cuda"):
from torch._dynamo.hooks import Hooks
def fn(x, y):
return x + y
def make_inputs():
return (torch.randn(3, 4), torch.randn(3, 4))
compiled_fn = torch._dynamo.aot_compile.aot_compile_fullgraph(
fn,
(make_inputs(), {}),
Hooks(),
torch._TorchCompileAOTInductorWrapper(None, None, None),
)
test_inputs = make_inputs()
expected = fn(*test_inputs)
actual = compiled_fn(*test_inputs)
self.assertEqual(expected, actual)
compiled_fn.save_compiled_function(self.path())
with open(self.path(), "rb") as f:
compiled_fn = torch.compiler.load_compiled_function(f)
actual = compiled_fn(*test_inputs)
self.assertEqual(expected, actual)
@unittest.skipIf(not TEST_CUDA, "requires cuda")
def test_aot_compile_with_aoti_module(self):
with torch.device("cuda"):
from torch._dynamo.hooks import Hooks
mod = SimpleLinearModule()
def make_inputs():
return (torch.randn(4, 3),)
compiled_mod = torch._dynamo.aot_compile.aot_compile_module(
mod,
[ModelInput(make_inputs(), {}, [])],
Hooks(),
torch._TorchCompileAOTInductorWrapper(None, None, None),
)
def get_grads(m: torch.nn.Module):
return {name: p.grad for name, p in m.named_parameters()}
original_mod = copy.deepcopy(mod)
test_inputs = make_inputs()
expected = mod(*test_inputs)
expected.sum().backward()
expected_grads = get_grads(mod)
actual = compiled_mod(*test_inputs)
self.assertEqual(expected, actual)
serialized = compiled_mod.serialize()
compiled_fn = AOTCompiledModel.deserialize(original_mod, serialized)
actual = compiled_fn(*test_inputs)
actual.sum().backward()
self.assertEqual(get_grads(original_mod), expected_grads)
@unittest.skipIf(not TEST_CUDA, "requires cuda")
def test_aot_compile_with_aoti_torch_compile(self):
with torch.device("cuda"):
def fn(x, y):
return x + y
def make_inputs():
return (torch.randn(3, 4), torch.randn(3, 4))
compiled_fn = torch.compile(
fn, fullgraph=True, options={"use_aoti": True}
).aot_compile((make_inputs(), {}))
test_inputs = make_inputs()
expected = fn(*test_inputs)
actual = compiled_fn(*test_inputs)
self.assertEqual(expected, actual)
compiled_fn.save_compiled_function(self.path())
with open(self.path(), "rb") as f:
compiled_fn = torch.compiler.load_compiled_function(f)
actual = compiled_fn(*test_inputs)
self.assertEqual(compiled_fn._artifacts.backend_name, "aotinductor")
self.assertEqual(expected, actual)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@ -222,13 +222,13 @@ class GraphModule(torch.nn.Module):
matmul: "f32[3, 3]" = l_x_ @ l_y_
sin: "f32[3, 3]" = matmul.sin(); matmul = None
child: "f32[3, 3]" = sin.cos(); sin = None
cos: "f32[3, 3]" = sin.cos(); sin = None
child_1: "f32[3, 3]" = l_x_ + l_y_
child_2: "f32[3, 3]" = l_x_ - l_y_
add: "f32[3, 3]" = l_x_ + l_y_
sub: "f32[3, 3]" = l_x_ - l_y_
child_3: "f32[3, 3]" = l_x_ @ l_y_; l_x_ = l_y_ = None
return (child, child_1, child_2, child_3)
matmul_1: "f32[3, 3]" = l_x_ @ l_y_; l_x_ = l_y_ = None
return (cos, add, sub, matmul_1)
""", # noqa: B950
)
self.assertExpectedInline(

View File

@ -962,7 +962,7 @@ class ExceptionTests(torch._dynamo.test_case.TestCase):
x = (torch.randn(4, 16, requires_grad=True),)
with self.assertRaisesRegex(Exception, "weight = self.linear.w"):
torch._dynamo.functional_export._dynamo_graph_capture_for_export(Model())(x)
torch._dynamo.functional_export.dynamo_graph_capture_for_export(Model())(x)
instantiate_parametrized_tests(ExceptionTests)

View File

@ -249,7 +249,7 @@ class HigherOrderOpTests(torch._dynamo.test_case.TestCase):
# when testing with dynamic shape, symbols are lifted as input
arg_count = ifdynstaticdefault(2, 3)
self._test_wrap_simple(fn, default_args_generator((x,)), arg_count)
self._test_wrap_simple(fn, default_args_generator((x,)), arg_count, 1)
def test_return_captured_vars(self):
freevar1 = torch.randn(3)
@ -267,7 +267,7 @@ class HigherOrderOpTests(torch._dynamo.test_case.TestCase):
# be the input.
# when testing with dynamic shape, a symbol is lifted as input
arg_count = ifdynstaticdefault(3, 4)
self._test_wrap_simple(fn, default_args_generator((x,)), arg_count, 4)
self._test_wrap_simple(fn, default_args_generator((x,)), arg_count, 1)
def test_return_captured_var_used_multiple_times(self):
freevar = torch.randn(3)
@ -282,7 +282,7 @@ class HigherOrderOpTests(torch._dynamo.test_case.TestCase):
x = torch.randn(3)
# when testing with dynamic shape, a symbol is lifted as input
arg_count = ifdynstaticdefault(3, 4)
self._test_wrap_simple(fn, default_args_generator((x,)), arg_count, 3)
self._test_wrap_simple(fn, default_args_generator((x,)), arg_count, 2)
def test_capture_untracked_global(self):
def f(x):
@ -762,15 +762,15 @@ class GraphModule(torch.nn.Module):
def forward(self, s77: "Sym(s77)", l_x_: "f32[s77]", u0: "Sym(u0)", c: "i64[u0, 1]"):
wrap_body_0 = self.wrap_body_0
wrap = torch.ops.higher_order.wrap(wrap_body_0, s77, l_x_, u0, c); wrap_body_0 = s77 = l_x_ = u0 = c = None
child: "f32[s77]" = wrap[0]
child_1: "f32[u0, 1]" = wrap[1]; wrap = None
return (child, child_1)
getitem: "f32[s77]" = wrap[0]
getitem_1: "f32[u0, 1]" = wrap[1]; wrap = None
return (getitem, getitem_1)
class wrap_body_0(torch.nn.Module):
def forward(self, s77: "Sym(s77)", l_x_: "f32[s77]", u0: "Sym(u0)", c: "i64[u0, 1]"):
child: "f32[s77]" = l_x_.sin(); l_x_ = None
child_1: "f32[u0, 1]" = c.sin(); c = None
return (child, child_1)
sin: "f32[s77]" = l_x_.sin(); l_x_ = None
sin_1: "f32[u0, 1]" = c.sin(); c = None
return (sin, sin_1)
""",
)
else:
@ -801,15 +801,15 @@ class GraphModule(torch.nn.Module):
def forward(self, l_x_: "f32[3]", u0: "Sym(u0)", c: "i64[u0, 1]"):
wrap_body_0 = self.wrap_body_0
wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_, u0, c); wrap_body_0 = l_x_ = u0 = c = None
child: "f32[3]" = wrap[0]
child_1: "f32[u0, 1]" = wrap[1]; wrap = None
return (child, child_1)
getitem: "f32[3]" = wrap[0]
getitem_1: "f32[u0, 1]" = wrap[1]; wrap = None
return (getitem, getitem_1)
class wrap_body_0(torch.nn.Module):
def forward(self, l_x_: "f32[3]", u0: "Sym(u0)", c: "i64[u0, 1]"):
child: "f32[3]" = l_x_.sin(); l_x_ = None
child_1: "f32[u0, 1]" = c.sin(); c = None
return (child, child_1)
sin: "f32[3]" = l_x_.sin(); l_x_ = None
sin_1: "f32[u0, 1]" = c.sin(); c = None
return (sin, sin_1)
""",
)
@ -922,16 +922,16 @@ class GraphModule(torch.nn.Module):
def forward(self, l_x_: "f32[3]", size: "Sym(u0)", c: "i64[u0, 1]"):
wrap_body_0 = self.wrap_body_0
wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_, size, c); wrap_body_0 = l_x_ = size = c = None
child: "f32[3]" = wrap[0]
child_1: "f32[u0, 1]" = wrap[1]; wrap = None
return (child, child_1)
getitem: "f32[3]" = wrap[0]
getitem_1: "f32[u0, 1]" = wrap[1]; wrap = None
return (getitem, getitem_1)
class wrap_body_0(torch.nn.Module):
def forward(self, l_x_: "f32[3]", size: "Sym(u0)", c: "i64[u0, 1]"):
sin: "f32[3]" = l_x_.sin(); l_x_ = None
child: "f32[3]" = sin + size; sin = size = None
child_1: "f32[u0, 1]" = c.sin(); c = None
return (child, child_1)
add: "f32[3]" = sin + size; sin = size = None
sin_1: "f32[u0, 1]" = c.sin(); c = None
return (add, sin_1)
""",
)
@ -2458,10 +2458,10 @@ class GraphModule(torch.nn.Module):
class wrap_body_0(torch.nn.Module):
def forward(self, l_arg1_0_: "f32[3]", l_arg2_0_: "f32[3]"):
child: "f32[3]" = l_arg1_0_ + 1; l_arg1_0_ = None
add: "f32[3]" = l_arg1_0_ + 1; l_arg1_0_ = None
child_1: "f32[3]" = l_arg2_0_ + 1; l_arg2_0_ = None
return (child, child_1)
add_1: "f32[3]" = l_arg2_0_ + 1; l_arg2_0_ = None
return (add, add_1)
""",
)
@ -2655,9 +2655,9 @@ class GraphModule(torch.nn.Module):
class wrap_body_0(torch.nn.Module):
def forward(self, l_x_: "f32[2, 3]"):
child: "f32[2, 3]" = l_x_.sin()
child_1: "f32[2, 3]" = l_x_.cos(); l_x_ = None
return (child, child_1)
sin: "f32[2, 3]" = l_x_.sin()
cos: "f32[2, 3]" = l_x_.cos(); l_x_ = None
return (sin, cos)
""",
)
@ -2687,13 +2687,13 @@ class GraphModule(torch.nn.Module):
wrap_body_0 = self.wrap_body_0
wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_); wrap_body_0 = l_x_ = None
value: "f32[3]" = wrap[0]; wrap = None
return (value,)
getitem: "f32[3]" = wrap[0]; wrap = None
return (getitem,)
class wrap_body_0(torch.nn.Module):
def forward(self, l_x_: "f32[3]"):
child: "f32[3]" = -l_x_; l_x_ = None
return (child,)
neg: "f32[3]" = -l_x_; l_x_ = None
return (neg,)
""",
)
@ -3318,17 +3318,17 @@ class GraphModule(torch.nn.Module):
hints_wrapper_body_1 = self.hints_wrapper_body_1
hints_wrapper = torch.ops.higher_order.hints_wrapper(hints_wrapper_body_1, (x, l_y_), {}, hints = {'outer_body': True}); hints_wrapper_body_1 = x = l_y_ = None
res: "f32[2, 4]" = hints_wrapper[0]; hints_wrapper = None
return (res,)
getitem: "f32[2, 4]" = hints_wrapper[0]; hints_wrapper = None
return (getitem,)
class hints_wrapper_body_1(torch.nn.Module):
def forward(self, x: "f32[2, 4]", l_y_: "f32[4]"):
hints_wrapper_body_0 = self.hints_wrapper_body_0
hints_wrapper = torch.ops.higher_order.hints_wrapper(hints_wrapper_body_0, (x, l_y_), {}, hints = {'inner_body': True}); hints_wrapper_body_0 = x = l_y_ = None
x_1: "f32[2, 4]" = hints_wrapper[0]; hints_wrapper = None
getitem: "f32[2, 4]" = hints_wrapper[0]; hints_wrapper = None
x_2: "f32[2, 4]" = torch.abs(x_1); x_1 = None
return (x_2,)
x_1: "f32[2, 4]" = torch.abs(getitem); getitem = None
return (x_1,)
class hints_wrapper_body_0(torch.nn.Module):
def forward(self, x: "f32[2, 4]", l_y_: "f32[4]"):

View File

@ -8146,7 +8146,6 @@ class ReproTestsDevice(torch._dynamo.test_case.TestCase):
unsafe_grad(y) # should not warn
self.assertEqual(len(w), 1)
@torch._dynamo.config.patch(install_free_tensors=True)
def test_partial_export(self):
class Foo(torch.nn.Module):
def __init__(self):
@ -8166,14 +8165,14 @@ class ReproTestsDevice(torch._dynamo.test_case.TestCase):
def forward(self, a, b):
return a + b
from torch._dynamo.functional_export import _dynamo_graph_capture_for_export
from torch._dynamo.functional_export import dynamo_graph_capture_for_export
foo = Foo()
foo.parallelize()
x = torch.randn(4, 4, dtype=torch.float32)
y = torch.randn(4, 4, dtype=torch.float32)
ref = foo(x, y)
gm = _dynamo_graph_capture_for_export(foo)(x, y)
gm = dynamo_graph_capture_for_export(foo)(x, y)
res = gm(x, y)
self.assertEqual(res, ref)

View File

@ -387,9 +387,9 @@ def forward(self, x):
export_inputs = ((dct, lst, 56), {})
eager_inputs = copy.deepcopy(export_inputs)
from torch._dynamo.functional_export import _dynamo_graph_capture_for_export
from torch._dynamo.functional_export import dynamo_graph_capture_for_export
graph_module = _dynamo_graph_capture_for_export(Foo())(
graph_module = dynamo_graph_capture_for_export(Foo())(
*export_inputs[0], **export_inputs[1]
)
@ -406,9 +406,9 @@ def forward(self, x):
export_inputs = ((torch.randn(4, 4),), {})
eager_inputs = copy.deepcopy(export_inputs)
from torch._dynamo.functional_export import _dynamo_graph_capture_for_export
from torch._dynamo.functional_export import dynamo_graph_capture_for_export
graph_module = _dynamo_graph_capture_for_export(Foo())(
graph_module = dynamo_graph_capture_for_export(Foo())(
*export_inputs[0], **export_inputs[1]
)

View File

@ -899,14 +899,14 @@ class GraphModule(torch.nn.Module):
class subgraph_0(torch.nn.Module):
def forward(self, l_x_: "f32[8]", l_y_: "f32[8]"):
mul: "f32[8]" = torch.mul(l_x_, l_y_); l_x_ = l_y_ = None
child: "f32[8]" = mul * 2; mul = None
return (child,)
mul_1: "f32[8]" = mul * 2; mul = None
return (mul_1,)
class subgraph_1(torch.nn.Module):
def forward(self, a: "f32[8]", l_y_: "f32[8]"):
mul: "f32[8]" = torch.mul(a, l_y_); a = l_y_ = None
child: "f32[8]" = mul * 3; mul = None
return (child,)
mul_1: "f32[8]" = mul * 3; mul = None
return (mul_1,)
""",
)
@ -983,20 +983,20 @@ class GraphModule(torch.nn.Module):
subgraph_0 = self.subgraph_0
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_); subgraph_0 = l_x_ = None
x: "f32[8]" = invoke_subgraph[0]; invoke_subgraph = None
getitem: "f32[8]" = invoke_subgraph[0]; invoke_subgraph = None
subgraph_1 = self.subgraph_0
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_1, 'subgraph_0', x, l_y_); subgraph_1 = x = None
x_1: "f32[8]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_1, 'subgraph_0', getitem, l_y_); subgraph_1 = getitem = None
getitem_1: "f32[8]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
subgraph_2 = self.subgraph_0
invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(subgraph_2, 'subgraph_0', x_1, l_y_); subgraph_2 = x_1 = None
x_2: "f32[8]" = invoke_subgraph_2[0]; invoke_subgraph_2 = None
invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(subgraph_2, 'subgraph_0', getitem_1, l_y_); subgraph_2 = getitem_1 = None
getitem_2: "f32[8]" = invoke_subgraph_2[0]; invoke_subgraph_2 = None
subgraph_3 = self.subgraph_0
invoke_subgraph_3 = torch.ops.higher_order.invoke_subgraph(subgraph_3, 'subgraph_0', x_2, l_y_); subgraph_3 = x_2 = None
x_3: "f32[8]" = invoke_subgraph_3[0]; invoke_subgraph_3 = None
invoke_subgraph_3 = torch.ops.higher_order.invoke_subgraph(subgraph_3, 'subgraph_0', getitem_2, l_y_); subgraph_3 = getitem_2 = None
getitem_3: "f32[8]" = invoke_subgraph_3[0]; invoke_subgraph_3 = None
subgraph_4 = self.subgraph_0
invoke_subgraph_4 = torch.ops.higher_order.invoke_subgraph(subgraph_4, 'subgraph_0', x_3, l_y_); subgraph_4 = x_3 = l_y_ = None
x_4: "f32[8]" = invoke_subgraph_4[0]; invoke_subgraph_4 = None
return (x_4,)
invoke_subgraph_4 = torch.ops.higher_order.invoke_subgraph(subgraph_4, 'subgraph_0', getitem_3, l_y_); subgraph_4 = getitem_3 = l_y_ = None
getitem_4: "f32[8]" = invoke_subgraph_4[0]; invoke_subgraph_4 = None
return (getitem_4,)
class subgraph_0(torch.nn.Module):
def forward(self, l_x_: "f32[8]", l_y_: "f32[8]"):
@ -1495,9 +1495,9 @@ class GraphModule(torch.nn.Module):
class subgraph_0(torch.nn.Module):
def forward(self, l_x_: "f32[8, 8]"):
child: "f32[8, 8]" = l_x_ * 2
child_1: "f32[8, 8]" = l_x_ * 3; l_x_ = None
return (child, child_1)
mul: "f32[8, 8]" = l_x_ * 2
mul_1: "f32[8, 8]" = l_x_ * 3; l_x_ = None
return (mul, mul_1)
""",
)
@ -2504,6 +2504,107 @@ class GraphModule(torch.nn.Module):
self.assertEqual(f(x, other), f_compile(x, other))
self.assertTrue(called)
# Checks that a nested_compile_region function can return a user-defined
# object (Foo) whose attributes are tensors, and that the caller can keep
# reading those attributes while compiling with fullgraph=True.
def test_udf_output(self):
# Plain Python container for the two tensors produced by gn.
class Foo:
def __init__(self, a, b):
self.a = a
self.b = b
@nested_compile_region
def gn(x, y):
a = torch.sin(x)
b = torch.cos(y)
return Foo(a, b)
# gn is invoked twice; the second call consumes an attribute of the first
# result. Only foo1.b and foo2.a contribute to the output (foo2.b is left
# out on purpose, see the trailing comment on the return line).
def fn(x, y):
foo1 = gn(x, y)
foo2 = gn(foo1.a, y)
return foo1.b + foo2.a # + foo2.b
backend = AotEagerAndRecordGraphs()
opt_fn = torch.compile(fn, backend=backend, fullgraph=True)
# Independent leaf copies so eager and compiled runs accumulate separate grads.
x = torch.randn(8, 8, requires_grad=True)
y = torch.randn(8, 8, requires_grad=True)
x_clone = x.detach().clone().requires_grad_(True)
y_clone = y.detach().clone().requires_grad_(True)
ref = fn(x, y)
res = opt_fn(x_clone, y_clone)
ref.sum().backward()
res.sum().backward()
self.assertEqual(ref, res)
# NOTE(review): only the gradient of x is compared; y.grad vs y_clone.grad
# is not checked here - confirm whether that is intentional.
self.assertEqual(x.grad, x_clone.grad)
# The expected graph below shows both calls sharing 'subgraph_0', the
# subgraph returning the flat tuple (a, b), and only index 0 being read
# from the second invoke_subgraph call.
if not TEST_WITH_CROSSREF:
self.assertExpectedInline(
normalize_gm(backend.graphs[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_x_: "f32[8, 8]", L_y_: "f32[8, 8]"):
l_x_ = L_x_
l_y_ = L_y_
subgraph_0 = self.subgraph_0
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_); subgraph_0 = l_x_ = None
getitem: "f32[8, 8]" = invoke_subgraph[0]
getitem_1: "f32[8, 8]" = invoke_subgraph[1]; invoke_subgraph = None
subgraph_1 = self.subgraph_0
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_1, 'subgraph_0', getitem, l_y_); subgraph_1 = getitem = l_y_ = None
getitem_2: "f32[8, 8]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
add: "f32[8, 8]" = getitem_1 + getitem_2; getitem_1 = getitem_2 = None
return (add,)
class subgraph_0(torch.nn.Module):
def forward(self, l_x_: "f32[8, 8]", l_y_: "f32[8, 8]"):
a: "f32[8, 8]" = torch.sin(l_x_); l_x_ = None
b: "f32[8, 8]" = torch.cos(l_y_); l_y_ = None
return (a, b)
""",
)
# High priority - grads are wrong
@unittest.expectedFailure
def test_grad_accuracy_check(self):
    """Compare eager and compiled gradients when a region is reused on its own output.

    ``gn`` is a ``nested_compile_region`` returning ``(sin(x), cos(x))``. ``fn``
    calls it twice, feeding the first element of the first result into the
    second call, and sums one output of the first call with both outputs of
    the second. The test is marked ``expectedFailure``; the checks below are
    what should hold once the compiled path is fixed.
    """

    @nested_compile_region
    def gn(x):
        a = torch.sin(x)
        b = torch.cos(x)
        return (a, b)

    def fn(x):
        foo1 = gn(x)
        foo2 = gn(foo1[0])
        return foo1[1] + foo2[0] + foo2[1]

    backend = AotEagerAndRecordGraphs()
    opt_fn = torch.compile(fn, backend=backend, fullgraph=True)

    # Fresh leaves start with ``grad is None``, so no explicit reset is needed.
    x = torch.randn(8, 8, requires_grad=True)
    x_clone = x.detach().clone().requires_grad_(True)

    ref = fn(x)
    res = opt_fn(x_clone)

    ref.sum().backward()
    res.sum().backward()

    self.assertEqual(ref, res)
    self.assertEqual(x.grad, x_clone.grad)
@skipIfTorchDynamo("Not a torch._dynamo test")
@parameterized_class(

View File

@ -286,47 +286,31 @@ class GraphModule(torch.nn.Module):
l_self_modules_wo_parameters_weight_ = L_self_modules_wo_parameters_weight_
l_self_modules_w1_parameters_weight_ = L_self_modules_w1_parameters_weight_
l_self_modules_w2_parameters_weight_ = L_self_modules_w2_parameters_weight_
q: "f32[8, 16, 96]" = torch._C._nn.linear(l_x_, l_self_modules_wq_parameters_weight_, None); l_self_modules_wq_parameters_weight_ = None
k: "f32[8, 16, 96]" = torch._C._nn.linear(l_x_, l_self_modules_wk_parameters_weight_, None); l_self_modules_wk_parameters_weight_ = None
v: "f32[8, 16, 96]" = torch._C._nn.linear(l_x_, l_self_modules_wv_parameters_weight_, None); l_self_modules_wv_parameters_weight_ = None
unflatten: "f32[8, 16, 16, 6]" = q.unflatten(-1, (16, -1)); q = None
q_1: "f32[8, 16, 16, 6]" = unflatten.permute(0, 2, 1, 3); unflatten = None
unflatten_1: "f32[8, 16, 16, 6]" = k.unflatten(-1, (16, -1)); k = None
k_1: "f32[8, 16, 16, 6]" = unflatten_1.permute(0, 2, 1, 3); unflatten_1 = None
unflatten_2: "f32[8, 16, 16, 6]" = v.unflatten(-1, (16, -1)); v = None
v_1: "f32[8, 16, 16, 6]" = unflatten_2.permute(0, 2, 1, 3); unflatten_2 = None
subgraph_0 = self.subgraph_0
local_map_hop = torch.ops.higher_order.local_map_hop(subgraph_0, q_1, k_1, v_1); subgraph_0 = q_1 = k_1 = v_1 = None
o: "f32[8, 16, 16, 6]" = local_map_hop[0]; local_map_hop = None
permute_3: "f32[8, 16, 16, 6]" = o.permute(0, 2, 1, 3); o = None
o_1: "f32[8, 16, 96]" = permute_3.flatten(-2); permute_3 = None
o_2: "f32[8, 16, 96]" = torch._C._nn.linear(o_1, l_self_modules_wo_parameters_weight_, None); o_1 = l_self_modules_wo_parameters_weight_ = None
o0: "f32[8, 16, 96]" = o_2 + l_x_; o_2 = l_x_ = None
o_3: "f32[8, 16, 384]" = torch._C._nn.linear(o0, l_self_modules_w1_parameters_weight_, None); l_self_modules_w1_parameters_weight_ = None
o_4: "f32[8, 16, 384]" = torch.nn.functional.relu(o_3); o_3 = None
o_5: "f32[8, 16, 96]" = torch._C._nn.linear(o_4, l_self_modules_w2_parameters_weight_, None); o_4 = l_self_modules_w2_parameters_weight_ = None
o_6: "f32[8, 16, 96]" = o0 + o_5; o0 = o_5 = None
return (o_6,)
getitem: "f32[8, 16, 16, 6]" = local_map_hop[0]; local_map_hop = None
permute_3: "f32[8, 16, 16, 6]" = getitem.permute(0, 2, 1, 3); getitem = None
o: "f32[8, 16, 96]" = permute_3.flatten(-2); permute_3 = None
o_1: "f32[8, 16, 96]" = torch._C._nn.linear(o, l_self_modules_wo_parameters_weight_, None); o = l_self_modules_wo_parameters_weight_ = None
o0: "f32[8, 16, 96]" = o_1 + l_x_; o_1 = l_x_ = None
o_2: "f32[8, 16, 384]" = torch._C._nn.linear(o0, l_self_modules_w1_parameters_weight_, None); l_self_modules_w1_parameters_weight_ = None
o_3: "f32[8, 16, 384]" = torch.nn.functional.relu(o_2); o_2 = None
o_4: "f32[8, 16, 96]" = torch._C._nn.linear(o_3, l_self_modules_w2_parameters_weight_, None); o_3 = l_self_modules_w2_parameters_weight_ = None
o_5: "f32[8, 16, 96]" = o0 + o_4; o0 = o_4 = None
return (o_5,)
class subgraph_0(torch.nn.Module):
def forward(self, q_1: "f32[1, 2, 4, 6]", k_1: "f32[1, 2, 16, 6]", v_1: "f32[1, 2, 16, 6]"):
out: "f32[1, 2, 4, 6]" = torch._C._nn.scaled_dot_product_attention(query = q_1, key = k_1, value = v_1, is_causal = False); q_1 = k_1 = v_1 = None
return (out,)
""",
return (out,)""",
ignore_empty_lines=True,
)

View File

@ -796,6 +796,27 @@ def forward(self, x_1):
self._test(f, [torch.randn(1, 10), torch.zeros(1, dtype=torch.long)])
@unittest.skipIf(not HAS_CUDA, 'CUDA-only test')
def test_T244632748(self):
    """Regression test: a symbolic size multiplied by a literal must compile.

    Exports a module computing ``x + x.shape[0] * 2`` with a dynamic first
    dimension, replaces every use of the graph node at index 3 with the value
    stored in its ``meta["val"]``, and then runs inductor AOT compilation on
    the rewritten module. The test passes if no step raises.
    """

    class TestModule(torch.nn.Module):
        def forward(self, x):
            return x + (x.shape[0] * 2)

    module = TestModule()
    sample = torch.randn((5, 5)).to("cuda")
    dynamic_shapes = {
        "x": (torch.export.Dim.DYNAMIC(max=100), torch.export.Dim.STATIC)
    }
    exported = torch.export.export(
        module, (sample,), dynamic_shapes=dynamic_shapes
    )
    gm = exported.module()

    # Swap the node for the value recorded in its metadata
    # (presumably a SymInt - the original local was named ``symint``).
    symint = list(gm.graph.nodes)[3].meta["val"]
    list(gm.graph.nodes)[3].replace_all_uses_with(symint)
    gm.graph.eliminate_dead_code()

    # Only successful compilation matters; the artifact itself is not inspected.
    compile_options = {"fx_wrapper": True, "compile_threads": 1}
    _ = torch._inductor.aot_compile(gm, (sample,), options=compile_options)
class TestGenericProxyTensorReal(TestGenericProxyTensor):
tracing_mode = "real"

View File

@ -2674,7 +2674,6 @@ class TestSparse(TestSparseBase):
self._test_asin_arcsin(input_uncoalesced, coalesced)
@coalescedonoff
@expectedFailureMPS
@dtypes(torch.double)
@dtypesIfMPS(torch.float32)
def test_mv(self, device, dtype, coalesced):

View File

@ -2439,6 +2439,35 @@ class _TorchCompileInductorWrapper:
reset_cudagraph_trees()
class _TorchCompileAOTInductorWrapper(_TorchCompileInductorWrapper):
    """Inductor backend wrapper that compiles under AOT-compilation settings.

    On construction it forces the ``cpp_wrapper`` and ``aot_inductor.package``
    options on. Each call delegates to the parent wrapper while AOT
    compilation is flagged on ``V`` and ``enable_autograd_for_aot`` is patched
    to ``True`` in the inductor config.
    """

    compiler_name = "aotinductor"

    def __init__(self, mode, options, dynamic):
        super().__init__(mode, options, dynamic)
        # Applied after the user's options so these two always end up enabled.
        self.apply_options({"cpp_wrapper": True})
        self.apply_options({"aot_inductor.package": True})

    def __call__(self, model_, inputs_):
        from contextlib import nullcontext
        from unittest import mock

        from torch._guards import detect_fake_mode
        from torch._inductor.virtualized import V

        # When a fake mode is detected, temporarily let it accept non-fake inputs.
        fake_mode = detect_fake_mode(inputs_)
        if fake_mode:
            non_fake_inputs_ctx = mock.patch.object(
                fake_mode, "allow_non_fake_inputs", True
            )
        else:
            non_fake_inputs_ctx = nullcontext()

        parent_call = super().__call__
        with V.set_aot_compilation(True):
            with non_fake_inputs_ctx:
                with torch._inductor.config.patch("enable_autograd_for_aot", True):
                    return parent_call(model_, inputs_)
class _TorchCompileWrapper:
def __init__(self, backend, mode, options, dynamic):
from torch._dynamo.backends.registry import lookup_backend
@ -2672,8 +2701,10 @@ def compile(
backend = bisect_backend
guard_filter_fn = None
use_aoti = False
if options and isinstance(options, dict):
guard_filter_fn = options.pop("guard_filter_fn", None)
use_aoti = options.pop("use_aoti", False)
if torch.compiler.is_exporting():
warnings.warn(
@ -2700,7 +2731,10 @@ def compile(
return export_wrapped_fn
if backend == "inductor":
backend = _TorchCompileInductorWrapper(mode, options, dynamic)
if use_aoti:
backend = _TorchCompileAOTInductorWrapper(mode, options, dynamic)
else:
backend = _TorchCompileInductorWrapper(mode, options, dynamic)
else:
backend = _TorchCompileWrapper(backend, mode, options, dynamic)

View File

@ -53,6 +53,7 @@ class CompileArtifacts:
argdefs: Optional[tuple[Any, ...]]
source_info: "SourceInfo"
device_type: str
backend_name: str
system_info: SystemInfo = dataclasses.field(default_factory=SystemInfo.current)
def check_compatibility(self) -> None:
@ -273,6 +274,7 @@ def aot_compile_fullgraph(
argdefs=fn.__defaults__,
source_info=source_info,
device_type=device_type,
backend_name=getattr(backend, "compiler_name", "unknown"),
)
aot_compiled_fn = AOTCompiledFunction(_artifacts=artifacts)

View File

@ -11,6 +11,7 @@ import sympy
import torch
import torch.fx
import torch.utils._pytree as pytree
from torch._dispatch.python import enable_python_dispatcher
from torch._dynamo.convert_frame import CaptureOutput, fullgraph_capture, get_traced_fn
from torch._dynamo.eval_frame import argument_names, check_user_input_output
from torch._dynamo.exc import UserErrorType
@ -579,9 +580,10 @@ def pytreeify(
fake_mode = torch._dynamo.utils.detect_fake_mode(flat_out_shuffle_args)
if fake_mode and fake_mode.shape_env is None:
fake_mode.shape_env = ShapeEnv()
out_shuffle_graph = make_fx(
out_shuffle, tracing_mode="symbolic", proxy_module_inputs=True
)(*flat_out_shuffle_args)
with enable_python_dispatcher():
out_shuffle_graph = make_fx(
out_shuffle, tracing_mode="real", proxy_module_inputs=True
)(*flat_out_shuffle_args)
_normalize_shuffle_graph(out_shuffle_graph)
assert out_shuffle.out_spec is not None

File diff suppressed because it is too large Load Diff

View File

@ -511,6 +511,7 @@ class GenericAOTAutogradResult(Generic[TForward, TBackward]):
).post_compile(
compiled_fw_func, aot_config, runtime_metadata=self.runtime_metadata
)
compiled_fw_func._boxed_call = True
disable_amp = torch._C._is_any_autocast_enabled()
if needs_autograd:

View File

@ -1639,7 +1639,9 @@ class _InProcessFxCompile(FxCompile):
# pyrefly: ignore [unbound-name]
(str, list, torch.fx.GraphModule),
), type(compiled_fn)
return CompiledAOTI(compiled_fn)
return CompiledAOTI(
filename=compiled_fn, device_type=graph.device_type
)
# TODO: Hoist this above V.aot_compilation
# pyrefly: ignore [unbound-name]
@ -2712,7 +2714,7 @@ def _compile_fx_main(
or torch._guards.TracingContext(fake_mode)
)
if V.aot_compilation:
if V.aot_compilation and not config.enable_autograd_for_aot:
from .utils import is_valid_aoti_model_name
is_valid_aoti_model_name()

View File

@ -1193,6 +1193,8 @@ autotune_lookup_table: dict[str, dict[str, Any]] = {}
file_lock_timeout: int = int(os.environ.get("TORCHINDUCTOR_FILE_LOCK_TIMEOUT", "600"))
enable_autograd_for_aot: bool = False
def get_worker_log_path() -> Optional[str]:
log_loc = None

View File

@ -773,9 +773,83 @@ class CompiledAOTI(OutputCode):
"""
filename: Union[str, list[Union[str, Weights]], torch.fx.GraphModule]
device_type: str
current_callable: Optional[Callable[..., Any]] = None
_cached_files: dict[str, bytes] = dataclasses.field(default_factory=dict)
# Load the compiled artifact named by self.filename into self.current_callable.
# Returns early, without touching current_callable, when any guard below applies.
def __post_init__(self):
# Guard 1: loading is only attempted when aot_inductor.link_libtorch is set.
if not config.aot_inductor.link_libtorch:
return
# Guard 2: skipped when the cpp_builder host flags report macOS or Windows.
if (
torch._inductor.cpp_builder._IS_MACOS
or torch._inductor.cpp_builder._IS_WINDOWS
):
return
# Guard 3: skipped when cross-compiling for windows.
if config.aot_inductor.cross_target_platform == "windows":
return
# Guard 4: skipped when package_cpp_only is set.
if config.aot_inductor.package_cpp_only:
return
# For a list, pick the first str entry ending in ".so". next() is called
# without a default, so StopIteration propagates if no entry matches.
if isinstance(self.filename, list):
current_callable = next(
fn for fn in self.filename if isinstance(fn, str) and fn.endswith(".so")
)
else:
current_callable = self.filename
# A GraphModule is stored as the callable directly; this path returns
# before _boxed_call is set and before cached files are restored.
if isinstance(current_callable, torch.fx.GraphModule):
self.current_callable = current_callable
return
# Otherwise wrap the path in a device-specific runner and keep its bound
# ``run`` method. Any device type other than "cuda*" or "cpu" is an error.
if self.device_type.startswith("cuda"):
current_callable = (
torch._C._aoti.AOTIModelContainerRunnerCuda( # type: ignore[call-arg]
current_callable,
1,
self.device_type,
"",
True,
).run # type: ignore[attr-defined]
) # type: ignore[attr-defined]
elif self.device_type == "cpu":
current_callable = (
torch._C._aoti.AOTIModelContainerRunnerCpu( # type: ignore[call-arg]
current_callable, 1
).run # type: ignore[attr-defined]
) # type: ignore[attr-defined]
else:
raise RuntimeError(f"unsupported device type {self.device_type}")
self.current_callable = current_callable
self._boxed_call = True
# Write back any cached file contents whose path no longer exists on disk.
# NOTE(review): this happens after the runner above was constructed from
# the path - confirm whether the files should be restored before loading.
for file in self._cached_files:
if not os.path.exists(file):
with open(file, "wb") as f:
f.write(self._cached_files[file])
def __call__(self, inputs: Sequence[Any]) -> Any:
raise NotImplementedError("NYI")
if self.current_callable is None:
raise RuntimeError("AOTInductor compiled so is not loaded")
return self.current_callable(inputs)
def prepare_for_serialization(self) -> None:
    """Drop the loaded callable and snapshot the artifact files as bytes.

    After this call ``current_callable`` is ``None`` and ``_cached_files``
    maps every entry of ``filename`` (a single path, or each element of a
    list) to that file's contents. Any other ``filename`` type leaves the
    cache empty.
    """
    self.current_callable = None
    self._cached_files = {}

    if isinstance(self.filename, list):
        paths: list[str] = self.filename  # type: ignore[assignment]
    elif isinstance(self.filename, str):
        paths = [self.filename]
    else:
        paths = []

    for path in paths:
        with open(path, "rb") as handle:
            self._cached_files[path] = handle.read()
def __getstate__(self):
state = self.__dict__.copy()
state["current_callable"] = None
return state
def post_compile(
self,
@ -783,10 +857,8 @@ class CompiledAOTI(OutputCode):
constants: CompiledFxGraphConstants,
graph_kwargs: _CompileFxKwargs,
) -> None:
pass
def prepare_for_serialization(self) -> None:
pass
if self.current_callable is None:
self.__post_init__()
def set_triton_bundle(self, triton_bundle: Any) -> None:
pass

View File

@ -66,6 +66,12 @@ void initAOTIRunnerBindings(PyObject* module) {
int,
const std::string&,
const std::string&>())
.def(py::init<
const std::string&,
int,
const std::string&,
const std::string&,
const bool>())
.def(
"run",
&AOTIModelContainerRunnerCuda::run,

View File

@ -84,7 +84,7 @@ if TYPE_CHECKING:
from torch._ops import OpOverload
from torch.fx._symbolic_trace import PHBase
from torch.types import IntLikeType
from torch.types import BoolLikeType, FloatLikeType, IntLikeType
__all__ = [
"PythonKeyTracer",
@ -458,7 +458,7 @@ def _sympy_handlers() -> dict[type[sympy.Expr], Callable[..., Any]]:
def _build_proxy_for_sym_expr(
tracer: _ProxyTracer, expr: sympy.Expr, out: PySymType | None = None
) -> PySymType | None:
) -> IntLikeType | FloatLikeType | BoolLikeType | None:
"""
Decompose `expr` and look for the pieces as inputs. If `out` is provided
then that will be the resulting SymNode (and `out.expr` must be the same as
@ -532,6 +532,13 @@ def _build_proxy_for_sym_expr(
assert not out
return value.value
if isinstance(expr, (int, float, bool)):
return expr
if expr.is_Integer:
return int(expr)
if expr.is_Float:
return float(expr)
args = []
for arg in expr.args:
if (arg_value := _build_proxy_for_sym_expr(tracer, arg)) is None: