Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-13 21:59:56 +08:00)

Compare commits: 9 commits
ciflow/tru… ... update_sub…

| SHA1 |
|---|
| b0d7899bc9 |
| 485f2b607a |
| 0c5d5c7e9a |
| 5f98a0363a |
| 2d739001d3 |
| 273babeec3 |
| a76dd6b7c6 |
| 2fa18d1545 |
| 537167aa1e |
@@ -3,6 +3,7 @@
#include <cstdint>
#include <map>
#include <shared_mutex>

#include <cuda_runtime_api.h>
#include <cusparse.h>

@@ -88,8 +89,13 @@ TORCH_CUDA_CPP_API cublasHandle_t getCurrentCUDABlasHandle();
TORCH_CUDA_CPP_API cublasLtHandle_t getCurrentCUDABlasLtHandle();

TORCH_CUDA_CPP_API void clearCublasWorkspaces();
TORCH_CUDA_CPP_API std::map<std::tuple<void *, void *>, at::DataPtr>& cublas_handle_stream_to_workspace();
TORCH_CUDA_CPP_API std::map<std::tuple<void *, void *>, at::DataPtr>& cublaslt_handle_stream_to_workspace();
struct WorkspaceMapWithMutex {
  std::map<std::tuple<void*, void*>, at::DataPtr> map;
  std::shared_mutex mutex;
};

TORCH_CUDA_CPP_API WorkspaceMapWithMutex& cublas_handle_stream_to_workspace();
TORCH_CUDA_CPP_API WorkspaceMapWithMutex& cublaslt_handle_stream_to_workspace();
TORCH_CUDA_CPP_API size_t getChosenWorkspaceSize();
TORCH_CUDA_CPP_API size_t getCUDABlasLtWorkspaceSize();
TORCH_CUDA_CPP_API void* getCUDABlasLtWorkspace();
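The hunk above replaces the bare handle/stream-to-workspace maps with WorkspaceMapWithMutex, which pairs the map with a std::shared_mutex: lookups can hold a shared lock concurrently, while inserting or clearing takes an exclusive lock. A minimal, self-contained C++17 sketch of that intended access pattern follows; the Workspace alias and the helper names find_workspace/clear_workspaces are illustrative stand-ins, not the actual ATen API.

#include <iostream>
#include <map>
#include <shared_mutex>
#include <tuple>

// Illustrative stand-in for the real map value type (at::DataPtr in ATen).
using Workspace = int;

struct WorkspaceMapWithMutex {
  std::map<std::tuple<void*, void*>, Workspace> map;
  std::shared_mutex mutex;
};

// Readers take a shared lock, so many threads can look up workspaces at once.
bool find_workspace(WorkspaceMapWithMutex& ws, void* handle, void* stream, Workspace& out) {
  std::shared_lock<std::shared_mutex> lock(ws.mutex);
  auto it = ws.map.find(std::make_tuple(handle, stream));
  if (it == ws.map.end()) {
    return false;
  }
  out = it->second;
  return true;
}

// Writers take a unique lock, excluding all readers while the map changes.
void clear_workspaces(WorkspaceMapWithMutex& ws) {
  std::unique_lock<std::shared_mutex> lock(ws.mutex);
  ws.map.clear();
}

int main() {
  WorkspaceMapWithMutex ws;
  int dummy_handle = 0;
  int dummy_stream = 0;
  {
    // Insertion also needs the exclusive lock.
    std::unique_lock<std::shared_mutex> lock(ws.mutex);
    ws.map[std::make_tuple(static_cast<void*>(&dummy_handle),
                           static_cast<void*>(&dummy_stream))] = 42;
  }
  Workspace found = 0;
  std::cout << find_workspace(ws, &dummy_handle, &dummy_stream, found)
            << " " << found << "\n";  // prints: 1 42
  clear_workspaces(ws);
  return 0;
}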
@@ -99,7 +99,7 @@ void destroyCublasHandle(cublasHandle_t handle) {
// - Comments of @soumith copied from cuDNN handle pool implementation
#ifdef NO_CUDNN_DESTROY_HANDLE
#else
  cublasDestroy(handle);
  cublasDestroy(handle);
#endif
}

@@ -107,19 +107,27 @@ using CuBlasPoolType = DeviceThreadHandlePool<cublasHandle_t, createCublasHandle

} // namespace

std::map<std::tuple<void *, void *>, at::DataPtr>& cublas_handle_stream_to_workspace() {
  static auto& instance = *new std::map<std::tuple<void *, void *>, at::DataPtr>;
WorkspaceMapWithMutex& cublas_handle_stream_to_workspace() {
  static auto& instance = *new WorkspaceMapWithMutex;
  return instance;
}

std::map<std::tuple<void *, void *>, at::DataPtr>& cublaslt_handle_stream_to_workspace() {
  static auto& instance = *new std::map<std::tuple<void *, void *>, at::DataPtr>;
WorkspaceMapWithMutex& cublaslt_handle_stream_to_workspace() {
  static auto& instance = *new WorkspaceMapWithMutex;
  return instance;
}

void clearCublasWorkspaces() {
  cublas_handle_stream_to_workspace().clear();
  cublaslt_handle_stream_to_workspace().clear();
  {
    auto& workspace = cublas_handle_stream_to_workspace();
    std::unique_lock<std::shared_mutex> lock(workspace.mutex);
    workspace.map.clear();
  }
  {
    auto& workspace = cublaslt_handle_stream_to_workspace();
    std::unique_lock<std::shared_mutex> lock(workspace.mutex);
    workspace.map.clear();
  }
}

size_t parseChosenWorkspaceSize() {

@@ -241,8 +249,10 @@ void* getCUDABlasLtWorkspace() {
  auto stream = c10::cuda::getCurrentCUDAStream();
  cudaStream_t _stream = stream;
  auto key = std::make_tuple(static_cast<void *>(handle), static_cast<void *>(_stream));
  auto workspace_it = at::cuda::cublas_handle_stream_to_workspace().find(key);
  TORCH_INTERNAL_ASSERT(workspace_it != at::cuda::cublas_handle_stream_to_workspace().end());
  auto& workspace = at::cuda::cublas_handle_stream_to_workspace();
  std::shared_lock<std::shared_mutex> lock(workspace.mutex);
  auto workspace_it = workspace.map.find(key);
  TORCH_INTERNAL_ASSERT(workspace_it != workspace.map.end());
  return workspace_it->second.mutable_get();
}
#endif

@@ -250,11 +260,34 @@ void* getCUDABlasLtWorkspace() {
  auto stream = c10::cuda::getCurrentCUDAStream();
  cudaStream_t _stream = stream;
  auto key = std::make_tuple(static_cast<void *>(handle), static_cast<void *>(_stream));
  auto workspace_it = cublaslt_handle_stream_to_workspace().find(key);
  if (workspace_it == cublaslt_handle_stream_to_workspace().end()) {
    workspace_it = cublaslt_handle_stream_to_workspace().insert(workspace_it, {key, getNewCUDABlasLtWorkspace()});

  auto& workspace = cublaslt_handle_stream_to_workspace();

  // Fast path: check if workspace already exists
  {
    std::shared_lock<std::shared_mutex> lock(workspace.mutex);
    auto workspace_it = workspace.map.find(key);
    if (workspace_it != workspace.map.end()) {
      return workspace_it->second.mutable_get();
    }
  }

  // Slow path: allocate workspace outside the lock
  auto new_workspace = getNewCUDABlasLtWorkspace();

  // Insert with lock (double-check in case another thread inserted while we
  // were allocating)
  {
    std::unique_lock<std::shared_mutex> lock(workspace.mutex);
    auto workspace_it = workspace.map.find(key);
    if (workspace_it == workspace.map.end()) {
      workspace_it =
          workspace.map.emplace(key, std::move(new_workspace)).first;
    }
    // else: another thread inserted it, our new_workspace will be automatically
    // freed
    return workspace_it->second.mutable_get();
  }
  return workspace_it->second.mutable_get();
}

cublasHandle_t getCurrentCUDABlasHandle() {

@@ -300,11 +333,39 @@ cublasHandle_t getCurrentCUDABlasHandle() {
  // all the memory and cublas's cudaMallocAsync will return OOM
  cudaStream_t _stream = stream;
  auto key = std::make_tuple(static_cast<void *>(handle), static_cast<void *>(_stream));
  auto workspace_it = cublas_handle_stream_to_workspace().find(key);
  if (workspace_it == cublas_handle_stream_to_workspace().end()) {
    workspace_it = cublas_handle_stream_to_workspace().insert(workspace_it, {key, getNewWorkspace()});

  auto& workspace = cublas_handle_stream_to_workspace();

  size_t workspace_size = getChosenWorkspaceSize();

  // Fast path: check if workspace already exists
  {
    std::shared_lock<std::shared_mutex> lock(workspace.mutex);
    auto workspace_it = workspace.map.find(key);
    if (workspace_it != workspace.map.end()) {
      TORCH_CUDABLAS_CHECK(cublasSetWorkspace(
          handle, workspace_it->second.get(), workspace_size));
      return handle;
    }
  }

  // Slow path: allocate workspace outside the lock
  auto new_workspace = getNewWorkspace();

  // Insert with lock (double-check in case another thread inserted while we
  // were allocating)
  {
    std::unique_lock<std::shared_mutex> lock(workspace.mutex);
    auto workspace_it = workspace.map.find(key);
    if (workspace_it == workspace.map.end()) {
      workspace_it =
          workspace.map.emplace(key, std::move(new_workspace)).first;
    }
    // else: another thread inserted it, our new_workspace will be automatically
    // freed
    TORCH_CUDABLAS_CHECK(
        cublasSetWorkspace(handle, workspace_it->second.get(), workspace_size));
  }
  TORCH_CUDABLAS_CHECK(cublasSetWorkspace(handle, workspace_it->second.get(), getChosenWorkspaceSize()));
#if !defined(USE_ROCM)
  // On CUDA >= 11, and architecture >= Ampere, cuBLAS can use TF32 to speedup
  // FP32 data type calculations based on the value of the allow_tf32 flag.
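Both accessors now follow the same double-checked pattern: look up under a shared lock, allocate the new workspace outside any lock, then re-check under a unique lock before inserting so that a concurrently inserted entry wins and the losing allocation is simply dropped. A self-contained C++17 sketch of that get-or-create pattern is below; Cache, expensive_allocate and get_or_create are generic placeholders, not the ATen functions.

#include <cstdio>
#include <map>
#include <memory>
#include <mutex>
#include <shared_mutex>
#include <thread>
#include <vector>

struct Cache {
  std::map<int, std::shared_ptr<int>> map;
  std::shared_mutex mutex;
};

// Stands in for getNewWorkspace()/getNewCUDABlasLtWorkspace().
std::shared_ptr<int> expensive_allocate(int key) {
  return std::make_shared<int>(key * 10);
}

std::shared_ptr<int> get_or_create(Cache& cache, int key) {
  // Fast path: shared lock, concurrent readers do not block each other.
  {
    std::shared_lock<std::shared_mutex> lock(cache.mutex);
    auto it = cache.map.find(key);
    if (it != cache.map.end()) {
      return it->second;
    }
  }
  // Slow path: allocate outside the lock to keep the critical section short.
  auto fresh = expensive_allocate(key);
  // Double-check under the exclusive lock; another thread may have won the race.
  std::unique_lock<std::shared_mutex> lock(cache.mutex);
  auto it = cache.map.find(key);
  if (it == cache.map.end()) {
    it = cache.map.emplace(key, std::move(fresh)).first;
  }  // else: our allocation is discarded and the winner's entry is used.
  return it->second;
}

int main() {
  Cache cache;
  std::vector<std::thread> threads;
  for (int t = 0; t < 8; ++t) {
    threads.emplace_back([&cache] {
      for (int k = 0; k < 100; ++k) {
        get_or_create(cache, k % 4);
      }
    });
  }
  for (auto& th : threads) {
    th.join();
  }
  std::printf("entries: %zu\n", cache.map.size());  // expected: 4
  return 0;
}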
@@ -4389,7 +4389,7 @@
  variants: function, method
  dispatch:
    CompositeExplicitAutograd: mv
    SparseCPU, SparseCUDA: mv_sparse
    SparseCPU, SparseCUDA, SparseMPS: mv_sparse

- func: mv.out(Tensor self, Tensor vec, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
@@ -61,6 +61,7 @@ list(APPEND ATen_CUDA_TEST_SRCS
  ${CMAKE_CURRENT_SOURCE_DIR}/cuda_complex_math_test.cu
  ${CMAKE_CURRENT_SOURCE_DIR}/cuda_complex_test.cu
  ${CMAKE_CURRENT_SOURCE_DIR}/cuda_cub_test.cu
  ${CMAKE_CURRENT_SOURCE_DIR}/cuda_cublas_handle_pool_test.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/cuda_device_test.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/cuda_distributions_test.cu
  ${CMAKE_CURRENT_SOURCE_DIR}/cuda_dlconvertor_test.cpp
aten/src/ATen/test/cuda_cublas_handle_pool_test.cpp (new file, 77 lines)
@@ -0,0 +1,77 @@
#include <gtest/gtest.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAGuard.h>

#include <atomic>
#include <thread>
#include <vector>

// Test concurrent access to getCurrentCUDABlasHandle and getCUDABlasLtWorkspace
// to verify that the data race fix is working correctly

TEST(CUDABlasHandlePoolTest, ConcurrentGetAndClearWorkspaces) {
  if (!at::cuda::is_available()) {
    return;
  }

  constexpr int num_accessor_threads = 15;
  constexpr int num_clear_threads = 5;
  constexpr int iterations_per_thread = 50;

  std::atomic<bool> stop{false};
  std::atomic<int> error_count{0};
  std::vector<std::thread> threads;
  threads.reserve(num_accessor_threads + num_clear_threads);

  // Launch accessor threads
  for (int i = 0; i < num_accessor_threads; ++i) {
    threads.emplace_back([&stop, &error_count]() {
      try {
        at::cuda::CUDAGuard device_guard(0);

        while (!stop.load(std::memory_order_relaxed)) {
          const auto handle = at::cuda::getCurrentCUDABlasHandle();
          const auto workspace = at::cuda::getCUDABlasLtWorkspace();

          if (handle == nullptr || workspace == nullptr) {
            error_count++;
          }
        }
      } catch (const std::exception& e) {
        error_count++;
      }
    });
  }

  // Launch threads that clear workspaces
  for (int i = 0; i < num_clear_threads; ++i) {
    threads.emplace_back([&error_count]() {
      try {
        for (int j = 0; j < iterations_per_thread; ++j) {
          at::cuda::clearCublasWorkspaces();
          std::this_thread::yield();
        }
      } catch (const std::exception& e) {
        error_count++;
      }
    });
  }

  // Let them run for a bit
  std::this_thread::sleep_for(std::chrono::milliseconds(100));
  stop.store(true, std::memory_order_relaxed);

  for (auto& thread : threads) {
    thread.join();
  }

  EXPECT_EQ(error_count.load(), 0);
}

int main(int argc, char* argv[]) {
  ::testing::InitGoogleTest(&argc, argv);
  c10::cuda::CUDACachingAllocator::init(1);
  return RUN_ALL_TESTS();
}
@@ -6,10 +6,7 @@ import unittest
import torch
import torch.distributed as dist
import torch.fx.traceback as fx_traceback
from torch._dynamo.functional_export import (
    _dynamo_graph_capture_for_export,
    dynamo_graph_capture_for_export,
)
from torch._dynamo.functional_export import dynamo_graph_capture_for_export
from torch._functorch.aot_autograd import aot_export_joint_with_descriptors
from torch._functorch.partitioners import min_cut_rematerialization_partition
from torch._guards import tracing, TracingContext

@@ -153,17 +150,6 @@ def graph_capture_and_aot_export_joint_with_descriptors_v2(model, args, kwargs=N
    return aot_export_joint_with_descriptors_alone(gm, args, kwargs)


def graph_capture_and_aot_export_joint_with_descriptors(model, args, kwargs=None):
    if kwargs is None:
        kwargs = {}
    with torch._dynamo.config.patch(install_free_tensors=True):
        # TODO: switch to use the official graph_capture API once it is ready
        gm = _dynamo_graph_capture_for_export(model)(*args, **kwargs)
    fake_mode = gm.meta.get("fake_mode", None)
    with tracing(TracingContext(fake_mode)):
        return aot_export_joint_with_descriptors_alone(gm, args, kwargs)


def aot_export_joint_with_descriptors_alone(model, args, kwargs=None):
    if kwargs is None:
        kwargs = {}

@@ -360,7 +346,6 @@ class DTensorExportTest(TestCase):
        "export_fn",
        [
            graph_capture_and_aot_export_joint_with_descriptors_v2,
            graph_capture_and_aot_export_joint_with_descriptors,
            aot_export_joint_with_descriptors_alone,
        ],
    )

@@ -386,10 +371,6 @@ class DTensorExportTest(TestCase):
                graph_capture_and_aot_export_joint_with_descriptors_v2,
                "[[4, 10], [4], [10, 4], [10], [4, 10], [4], [10, 4], [10], [s64, 10], [s64, 10]]",
            ),
            (
                graph_capture_and_aot_export_joint_with_descriptors,
                "[[4, 10], [4], [10, 4], [10], [s22, 10], [s22, 10]]",
            ),
        ],
    )
    def test_dynamic_shapes(self, export_fn_with_answer):

@@ -434,7 +415,6 @@ class DTensorExportTest(TestCase):
        "export_fn",
        [
            dynamo_graph_capture_for_export,
            _dynamo_graph_capture_for_export,
        ],
    )
    def test_einsum_dtensor_export(self, export_fn):

@@ -456,11 +436,7 @@ class DTensorExportTest(TestCase):

        # Run model to verify it works
        output = model(*inputs)
        with torch._dynamo.config.patch(
            install_free_tensors=(export_fn is _dynamo_graph_capture_for_export)
        ):
            # TODO: switch to use the official graph_capture API once it is ready
            gm = export_fn(model)(*inputs)
        gm = export_fn(model)(*inputs)
        output_gm = gm(*inputs)
        self.assertEqual(output, output_gm)

@@ -468,7 +444,6 @@ class DTensorExportTest(TestCase):
        "export_fn",
        [
            graph_capture_and_aot_export_joint_with_descriptors_v2,
            graph_capture_and_aot_export_joint_with_descriptors,
        ],
    )
    def test_flex_attention_dtensor_export(self, export_fn):

@@ -531,7 +506,7 @@ class DTensorExportTest(TestCase):
            return nest_fn(leaf) + 1

        z = torch.randn(16, 16)
        gm = graph_capture_and_aot_export_joint_with_descriptors(fn, (z,))
        gm = graph_capture_and_aot_export_joint_with_descriptors_v2(fn, (z,))

        self.assertEqual(fn(z), gm(z)[0])

@@ -546,7 +521,7 @@ class DTensorExportTest(TestCase):
        y = torch.randint(1, (10,)).bool()
        x_dt = distribute_tensor(x, device_mesh, placements=[Replicate()])
        y_dt = distribute_tensor(y, device_mesh, placements=[Replicate()])
        _dynamo_graph_capture_for_export(Foo())(x_dt, y_dt)
        dynamo_graph_capture_for_export(Foo())(x_dt, y_dt)

        class Bar(torch.nn.Module):
            def forward(self, x):

@@ -556,25 +531,25 @@ class DTensorExportTest(TestCase):

        x = torch.randint(1000, (4, 64, 16))
        x_dt = distribute_tensor(x, device_mesh, placements=[Replicate()])
        gm = _dynamo_graph_capture_for_export(Bar())(x_dt)
        gm = dynamo_graph_capture_for_export(Bar())(x_dt)
        self.assertExpectedInline(
            str(gm.graph).strip(),
            """\
graph():
    %l_flat_args_0_ : [num_users=2] = placeholder[target=arg_0]
    %max_1 : [num_users=1] = call_method[target=max](args = (%l_flat_args_0_,), kwargs = {})
    %l_x_ : torch.distributed.tensor.DTensor [num_users=2] = placeholder[target=L_x_]
    %max_1 : [num_users=1] = call_method[target=max](args = (%l_x_,), kwargs = {})
    %clamp : [num_users=1] = call_function[target=torch.clamp](args = (%max_1,), kwargs = {min: 1})
    %item : [num_users=2] = call_method[target=item](args = (%clamp,), kwargs = {})
    %ge_1 : [num_users=1] = call_function[target=operator.ge](args = (%item, 1), kwargs = {})
    %_assert_scalar_default : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge_1, Runtime assertion failed for expression u0 >= 1 on node 'ge_1'), kwargs = {})
    %res : [num_users=2] = call_function[target=operator.getitem](args = (%l_flat_args_0_, slice(None, item, None)), kwargs = {})
    %getattr_1 : [num_users=1] = call_function[target=builtins.getattr](args = (%res, _local_tensor), kwargs = {})
    %getitem : [num_users=2] = call_function[target=operator.getitem](args = (%l_x_, slice(None, item, None)), kwargs = {})
    %getattr_1 : [num_users=1] = call_function[target=builtins.getattr](args = (%getitem, _local_tensor), kwargs = {})
    %sym_size_int : [num_users=2] = call_function[target=torch.ops.aten.sym_size.int](args = (%getattr_1, 0), kwargs = {})
    %ge_2 : [num_users=1] = call_function[target=operator.ge](args = (%sym_size_int, 0), kwargs = {})
    %_assert_scalar_default_1 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge_2, Runtime assertion failed for expression u2 >= 0 on node 'ge_2'), kwargs = {})
    %le : [num_users=1] = call_function[target=operator.le](args = (%sym_size_int, 4), kwargs = {})
    %_assert_scalar_default_2 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%le, Runtime assertion failed for expression u2 <= 4 on node 'le'), kwargs = {})
    return (res,)""", # noqa: B950
            str(gm.graph).strip(),
    return (getitem,)""", # noqa: B950
        )
@@ -1681,14 +1681,13 @@ class GraphModule(torch.nn.Module):

        wrap_body_0 = self.wrap_body_0
        tag_activation_checkpoint = torch.ops.higher_order.tag_activation_checkpoint(wrap_body_0, l_x_, use_reentrant = True); wrap_body_0 = l_x_ = None
        getitem: "f32[4, 4]" = tag_activation_checkpoint[0]
        getitem_1: "f32[4, 4]" = tag_activation_checkpoint[1]; tag_activation_checkpoint = None
        return (getitem, getitem_1)
        getitem: "f32[4, 4]" = tag_activation_checkpoint[0]; tag_activation_checkpoint = None
        return (getitem,)

    class wrap_body_0(torch.nn.Module):
        def forward(self, l_x_: "f32[4, 4]"):
            y: "f32[4, 4]" = torch.sin(l_x_); l_x_ = None
            return (y, y)
            return (y,)
""",
        )

@@ -1798,9 +1797,9 @@ class GraphModule(torch.nn.Module):
        out: "f32[4, 4]" = l_x_.sin()

        sin_1: "f32[4, 4]" = torch.sin(o)
        child: "f32[4, 4]" = torch.cos(sin_1)
        child_1: "f32[4, 4]" = torch.sin(l_x_); l_x_ = None
        return (child, child_1, matmul, o, out, sin_1)
        cos: "f32[4, 4]" = torch.cos(sin_1)
        sin_2: "f32[4, 4]" = torch.sin(l_x_); l_x_ = None
        return (cos, sin_2, matmul, o, out, sin_1)
""",
        )
@@ -1,9 +1,11 @@
# Owner(s): ["module: dynamo"]

import copy
import functools
import inspect
import os
import pickle
import unittest
from contextlib import contextmanager
from unittest.mock import patch

@@ -13,13 +15,16 @@ import torch._inductor.config
import torch._inductor.test_case
import torch.onnx.operators
import torch.utils.cpp_extension
from torch._dynamo.aot_compile import ModelInput, SerializableCallable
from torch._dynamo.aot_compile import AOTCompiledModel, ModelInput, SerializableCallable
from torch._dynamo.exc import PackageError, Unsupported
from torch._dynamo.package import DynamoCache
from torch._dynamo.precompile_context import PrecompileContext
from torch._inductor.runtime.runtime_utils import cache_dir
from torch.fx._graph_pickler import GraphPickler
from torch.testing._internal.common_utils import instantiate_parametrized_tests
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    TEST_CUDA,
)


MY_LAMBDA = lambda x: x + 1 # noqa: E731
@@ -599,6 +604,92 @@ from user code:
        actual = compiled_fn(*inputs)
        self.assertEqual(expected, actual)

    @unittest.skipIf(not TEST_CUDA, "requires cuda")
    def test_aot_compile_with_aoti(self):
        with torch.device("cuda"):
            from torch._dynamo.hooks import Hooks

            def fn(x, y):
                return x + y

            def make_inputs():
                return (torch.randn(3, 4), torch.randn(3, 4))

            compiled_fn = torch._dynamo.aot_compile.aot_compile_fullgraph(
                fn,
                (make_inputs(), {}),
                Hooks(),
                torch._TorchCompileAOTInductorWrapper(None, None, None),
            )

            test_inputs = make_inputs()
            expected = fn(*test_inputs)
            actual = compiled_fn(*test_inputs)
            self.assertEqual(expected, actual)
            compiled_fn.save_compiled_function(self.path())
            with open(self.path(), "rb") as f:
                compiled_fn = torch.compiler.load_compiled_function(f)
            actual = compiled_fn(*test_inputs)
            self.assertEqual(expected, actual)

    @unittest.skipIf(not TEST_CUDA, "requires cuda")
    def test_aot_compile_with_aoti_module(self):
        with torch.device("cuda"):
            from torch._dynamo.hooks import Hooks

            mod = SimpleLinearModule()

            def make_inputs():
                return (torch.randn(4, 3),)

            compiled_mod = torch._dynamo.aot_compile.aot_compile_module(
                mod,
                [ModelInput(make_inputs(), {}, [])],
                Hooks(),
                torch._TorchCompileAOTInductorWrapper(None, None, None),
            )

            def get_grads(m: torch.nn.Module):
                return {name: p.grad for name, p in m.named_parameters()}

            original_mod = copy.deepcopy(mod)
            test_inputs = make_inputs()
            expected = mod(*test_inputs)
            expected.sum().backward()
            expected_grads = get_grads(mod)

            actual = compiled_mod(*test_inputs)
            self.assertEqual(expected, actual)
            serialized = compiled_mod.serialize()
            compiled_fn = AOTCompiledModel.deserialize(original_mod, serialized)
            actual = compiled_fn(*test_inputs)
            actual.sum().backward()
            self.assertEqual(get_grads(original_mod), expected_grads)

    @unittest.skipIf(not TEST_CUDA, "requires cuda")
    def test_aot_compile_with_aoti_torch_compile(self):
        with torch.device("cuda"):

            def fn(x, y):
                return x + y

            def make_inputs():
                return (torch.randn(3, 4), torch.randn(3, 4))

            compiled_fn = torch.compile(
                fn, fullgraph=True, options={"use_aoti": True}
            ).aot_compile((make_inputs(), {}))
            test_inputs = make_inputs()
            expected = fn(*test_inputs)
            actual = compiled_fn(*test_inputs)
            self.assertEqual(expected, actual)
            compiled_fn.save_compiled_function(self.path())
            with open(self.path(), "rb") as f:
                compiled_fn = torch.compiler.load_compiled_function(f)
            actual = compiled_fn(*test_inputs)
            self.assertEqual(compiled_fn._artifacts.backend_name, "aotinductor")
            self.assertEqual(expected, actual)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests
@@ -222,13 +222,13 @@ class GraphModule(torch.nn.Module):

        matmul: "f32[3, 3]" = l_x_ @ l_y_
        sin: "f32[3, 3]" = matmul.sin(); matmul = None
        child: "f32[3, 3]" = sin.cos(); sin = None
        cos: "f32[3, 3]" = sin.cos(); sin = None

        child_1: "f32[3, 3]" = l_x_ + l_y_
        child_2: "f32[3, 3]" = l_x_ - l_y_
        add: "f32[3, 3]" = l_x_ + l_y_
        sub: "f32[3, 3]" = l_x_ - l_y_

        child_3: "f32[3, 3]" = l_x_ @ l_y_; l_x_ = l_y_ = None
        return (child, child_1, child_2, child_3)
        matmul_1: "f32[3, 3]" = l_x_ @ l_y_; l_x_ = l_y_ = None
        return (cos, add, sub, matmul_1)
""", # noqa: B950
        )
        self.assertExpectedInline(
@@ -962,7 +962,7 @@ class ExceptionTests(torch._dynamo.test_case.TestCase):
        x = (torch.randn(4, 16, requires_grad=True),)

        with self.assertRaisesRegex(Exception, "weight = self.linear.w"):
            torch._dynamo.functional_export._dynamo_graph_capture_for_export(Model())(x)
            torch._dynamo.functional_export.dynamo_graph_capture_for_export(Model())(x)


instantiate_parametrized_tests(ExceptionTests)
@@ -249,7 +249,7 @@ class HigherOrderOpTests(torch._dynamo.test_case.TestCase):

        # when testing with dynamic shape, symbols are lifted as input
        arg_count = ifdynstaticdefault(2, 3)
        self._test_wrap_simple(fn, default_args_generator((x,)), arg_count)
        self._test_wrap_simple(fn, default_args_generator((x,)), arg_count, 1)

    def test_return_captured_vars(self):
        freevar1 = torch.randn(3)

@@ -267,7 +267,7 @@ class HigherOrderOpTests(torch._dynamo.test_case.TestCase):
        # be the input.
        # when testing with dynamic shape, a symbol is lifted as input
        arg_count = ifdynstaticdefault(3, 4)
        self._test_wrap_simple(fn, default_args_generator((x,)), arg_count, 4)
        self._test_wrap_simple(fn, default_args_generator((x,)), arg_count, 1)

    def test_return_captured_var_used_multiple_times(self):
        freevar = torch.randn(3)

@@ -282,7 +282,7 @@ class HigherOrderOpTests(torch._dynamo.test_case.TestCase):
        x = torch.randn(3)
        # when testing with dynamic shape, a symbol is lifted as input
        arg_count = ifdynstaticdefault(3, 4)
        self._test_wrap_simple(fn, default_args_generator((x,)), arg_count, 3)
        self._test_wrap_simple(fn, default_args_generator((x,)), arg_count, 2)

    def test_capture_untracked_global(self):
        def f(x):

@@ -762,15 +762,15 @@ class GraphModule(torch.nn.Module):
    def forward(self, s77: "Sym(s77)", l_x_: "f32[s77]", u0: "Sym(u0)", c: "i64[u0, 1]"):
        wrap_body_0 = self.wrap_body_0
        wrap = torch.ops.higher_order.wrap(wrap_body_0, s77, l_x_, u0, c); wrap_body_0 = s77 = l_x_ = u0 = c = None
        child: "f32[s77]" = wrap[0]
        child_1: "f32[u0, 1]" = wrap[1]; wrap = None
        return (child, child_1)
        getitem: "f32[s77]" = wrap[0]
        getitem_1: "f32[u0, 1]" = wrap[1]; wrap = None
        return (getitem, getitem_1)

    class wrap_body_0(torch.nn.Module):
        def forward(self, s77: "Sym(s77)", l_x_: "f32[s77]", u0: "Sym(u0)", c: "i64[u0, 1]"):
            child: "f32[s77]" = l_x_.sin(); l_x_ = None
            child_1: "f32[u0, 1]" = c.sin(); c = None
            return (child, child_1)
            sin: "f32[s77]" = l_x_.sin(); l_x_ = None
            sin_1: "f32[u0, 1]" = c.sin(); c = None
            return (sin, sin_1)
""",
        )
    else:

@@ -801,15 +801,15 @@ class GraphModule(torch.nn.Module):
    def forward(self, l_x_: "f32[3]", u0: "Sym(u0)", c: "i64[u0, 1]"):
        wrap_body_0 = self.wrap_body_0
        wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_, u0, c); wrap_body_0 = l_x_ = u0 = c = None
        child: "f32[3]" = wrap[0]
        child_1: "f32[u0, 1]" = wrap[1]; wrap = None
        return (child, child_1)
        getitem: "f32[3]" = wrap[0]
        getitem_1: "f32[u0, 1]" = wrap[1]; wrap = None
        return (getitem, getitem_1)

    class wrap_body_0(torch.nn.Module):
        def forward(self, l_x_: "f32[3]", u0: "Sym(u0)", c: "i64[u0, 1]"):
            child: "f32[3]" = l_x_.sin(); l_x_ = None
            child_1: "f32[u0, 1]" = c.sin(); c = None
            return (child, child_1)
            sin: "f32[3]" = l_x_.sin(); l_x_ = None
            sin_1: "f32[u0, 1]" = c.sin(); c = None
            return (sin, sin_1)
""",
        )

@@ -922,16 +922,16 @@ class GraphModule(torch.nn.Module):
    def forward(self, l_x_: "f32[3]", size: "Sym(u0)", c: "i64[u0, 1]"):
        wrap_body_0 = self.wrap_body_0
        wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_, size, c); wrap_body_0 = l_x_ = size = c = None
        child: "f32[3]" = wrap[0]
        child_1: "f32[u0, 1]" = wrap[1]; wrap = None
        return (child, child_1)
        getitem: "f32[3]" = wrap[0]
        getitem_1: "f32[u0, 1]" = wrap[1]; wrap = None
        return (getitem, getitem_1)

    class wrap_body_0(torch.nn.Module):
        def forward(self, l_x_: "f32[3]", size: "Sym(u0)", c: "i64[u0, 1]"):
            sin: "f32[3]" = l_x_.sin(); l_x_ = None
            child: "f32[3]" = sin + size; sin = size = None
            child_1: "f32[u0, 1]" = c.sin(); c = None
            return (child, child_1)
            add: "f32[3]" = sin + size; sin = size = None
            sin_1: "f32[u0, 1]" = c.sin(); c = None
            return (add, sin_1)
""",
        )
@@ -2458,10 +2458,10 @@ class GraphModule(torch.nn.Module):

    class wrap_body_0(torch.nn.Module):
        def forward(self, l_arg1_0_: "f32[3]", l_arg2_0_: "f32[3]"):
            child: "f32[3]" = l_arg1_0_ + 1; l_arg1_0_ = None
            add: "f32[3]" = l_arg1_0_ + 1; l_arg1_0_ = None

            child_1: "f32[3]" = l_arg2_0_ + 1; l_arg2_0_ = None
            return (child, child_1)
            add_1: "f32[3]" = l_arg2_0_ + 1; l_arg2_0_ = None
            return (add, add_1)
""",
        )

@@ -2655,9 +2655,9 @@ class GraphModule(torch.nn.Module):

    class wrap_body_0(torch.nn.Module):
        def forward(self, l_x_: "f32[2, 3]"):
            child: "f32[2, 3]" = l_x_.sin()
            child_1: "f32[2, 3]" = l_x_.cos(); l_x_ = None
            return (child, child_1)
            sin: "f32[2, 3]" = l_x_.sin()
            cos: "f32[2, 3]" = l_x_.cos(); l_x_ = None
            return (sin, cos)
""",
        )

@@ -2687,13 +2687,13 @@ class GraphModule(torch.nn.Module):

        wrap_body_0 = self.wrap_body_0
        wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_); wrap_body_0 = l_x_ = None
        value: "f32[3]" = wrap[0]; wrap = None
        return (value,)
        getitem: "f32[3]" = wrap[0]; wrap = None
        return (getitem,)

    class wrap_body_0(torch.nn.Module):
        def forward(self, l_x_: "f32[3]"):
            child: "f32[3]" = -l_x_; l_x_ = None
            return (child,)
            neg: "f32[3]" = -l_x_; l_x_ = None
            return (neg,)
""",
        )
@@ -3318,17 +3318,17 @@ class GraphModule(torch.nn.Module):

        hints_wrapper_body_1 = self.hints_wrapper_body_1
        hints_wrapper = torch.ops.higher_order.hints_wrapper(hints_wrapper_body_1, (x, l_y_), {}, hints = {'outer_body': True}); hints_wrapper_body_1 = x = l_y_ = None
        res: "f32[2, 4]" = hints_wrapper[0]; hints_wrapper = None
        return (res,)
        getitem: "f32[2, 4]" = hints_wrapper[0]; hints_wrapper = None
        return (getitem,)

    class hints_wrapper_body_1(torch.nn.Module):
        def forward(self, x: "f32[2, 4]", l_y_: "f32[4]"):
            hints_wrapper_body_0 = self.hints_wrapper_body_0
            hints_wrapper = torch.ops.higher_order.hints_wrapper(hints_wrapper_body_0, (x, l_y_), {}, hints = {'inner_body': True}); hints_wrapper_body_0 = x = l_y_ = None
            x_1: "f32[2, 4]" = hints_wrapper[0]; hints_wrapper = None
            getitem: "f32[2, 4]" = hints_wrapper[0]; hints_wrapper = None

            x_2: "f32[2, 4]" = torch.abs(x_1); x_1 = None
            return (x_2,)
            x_1: "f32[2, 4]" = torch.abs(getitem); getitem = None
            return (x_1,)

    class hints_wrapper_body_0(torch.nn.Module):
        def forward(self, x: "f32[2, 4]", l_y_: "f32[4]"):
@@ -8146,7 +8146,6 @@ class ReproTestsDevice(torch._dynamo.test_case.TestCase):
        unsafe_grad(y) # should not warn
        self.assertEqual(len(w), 1)

    @torch._dynamo.config.patch(install_free_tensors=True)
    def test_partial_export(self):
        class Foo(torch.nn.Module):
            def __init__(self):

@@ -8166,14 +8165,14 @@ class ReproTestsDevice(torch._dynamo.test_case.TestCase):
            def forward(self, a, b):
                return a + b

        from torch._dynamo.functional_export import _dynamo_graph_capture_for_export
        from torch._dynamo.functional_export import dynamo_graph_capture_for_export

        foo = Foo()
        foo.parallelize()
        x = torch.randn(4, 4, dtype=torch.float32)
        y = torch.randn(4, 4, dtype=torch.float32)
        ref = foo(x, y)
        gm = _dynamo_graph_capture_for_export(foo)(x, y)
        gm = dynamo_graph_capture_for_export(foo)(x, y)
        res = gm(x, y)
        self.assertEqual(res, ref)
@@ -387,9 +387,9 @@ def forward(self, x):
        export_inputs = ((dct, lst, 56), {})
        eager_inputs = copy.deepcopy(export_inputs)

        from torch._dynamo.functional_export import _dynamo_graph_capture_for_export
        from torch._dynamo.functional_export import dynamo_graph_capture_for_export

        graph_module = _dynamo_graph_capture_for_export(Foo())(
        graph_module = dynamo_graph_capture_for_export(Foo())(
            *export_inputs[0], **export_inputs[1]
        )

@@ -406,9 +406,9 @@ def forward(self, x):
        export_inputs = ((torch.randn(4, 4),), {})
        eager_inputs = copy.deepcopy(export_inputs)

        from torch._dynamo.functional_export import _dynamo_graph_capture_for_export
        from torch._dynamo.functional_export import dynamo_graph_capture_for_export

        graph_module = _dynamo_graph_capture_for_export(Foo())(
        graph_module = dynamo_graph_capture_for_export(Foo())(
            *export_inputs[0], **export_inputs[1]
        )
@@ -899,14 +899,14 @@ class GraphModule(torch.nn.Module):
    class subgraph_0(torch.nn.Module):
        def forward(self, l_x_: "f32[8]", l_y_: "f32[8]"):
            mul: "f32[8]" = torch.mul(l_x_, l_y_); l_x_ = l_y_ = None
            child: "f32[8]" = mul * 2; mul = None
            return (child,)
            mul_1: "f32[8]" = mul * 2; mul = None
            return (mul_1,)

    class subgraph_1(torch.nn.Module):
        def forward(self, a: "f32[8]", l_y_: "f32[8]"):
            mul: "f32[8]" = torch.mul(a, l_y_); a = l_y_ = None
            child: "f32[8]" = mul * 3; mul = None
            return (child,)
            mul_1: "f32[8]" = mul * 3; mul = None
            return (mul_1,)
""",
        )

@@ -983,20 +983,20 @@ class GraphModule(torch.nn.Module):

        subgraph_0 = self.subgraph_0
        invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_); subgraph_0 = l_x_ = None
        x: "f32[8]" = invoke_subgraph[0]; invoke_subgraph = None
        getitem: "f32[8]" = invoke_subgraph[0]; invoke_subgraph = None
        subgraph_1 = self.subgraph_0
        invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_1, 'subgraph_0', x, l_y_); subgraph_1 = x = None
        x_1: "f32[8]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
        invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_1, 'subgraph_0', getitem, l_y_); subgraph_1 = getitem = None
        getitem_1: "f32[8]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
        subgraph_2 = self.subgraph_0
        invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(subgraph_2, 'subgraph_0', x_1, l_y_); subgraph_2 = x_1 = None
        x_2: "f32[8]" = invoke_subgraph_2[0]; invoke_subgraph_2 = None
        invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(subgraph_2, 'subgraph_0', getitem_1, l_y_); subgraph_2 = getitem_1 = None
        getitem_2: "f32[8]" = invoke_subgraph_2[0]; invoke_subgraph_2 = None
        subgraph_3 = self.subgraph_0
        invoke_subgraph_3 = torch.ops.higher_order.invoke_subgraph(subgraph_3, 'subgraph_0', x_2, l_y_); subgraph_3 = x_2 = None
        x_3: "f32[8]" = invoke_subgraph_3[0]; invoke_subgraph_3 = None
        invoke_subgraph_3 = torch.ops.higher_order.invoke_subgraph(subgraph_3, 'subgraph_0', getitem_2, l_y_); subgraph_3 = getitem_2 = None
        getitem_3: "f32[8]" = invoke_subgraph_3[0]; invoke_subgraph_3 = None
        subgraph_4 = self.subgraph_0
        invoke_subgraph_4 = torch.ops.higher_order.invoke_subgraph(subgraph_4, 'subgraph_0', x_3, l_y_); subgraph_4 = x_3 = l_y_ = None
        x_4: "f32[8]" = invoke_subgraph_4[0]; invoke_subgraph_4 = None
        return (x_4,)
        invoke_subgraph_4 = torch.ops.higher_order.invoke_subgraph(subgraph_4, 'subgraph_0', getitem_3, l_y_); subgraph_4 = getitem_3 = l_y_ = None
        getitem_4: "f32[8]" = invoke_subgraph_4[0]; invoke_subgraph_4 = None
        return (getitem_4,)

    class subgraph_0(torch.nn.Module):
        def forward(self, l_x_: "f32[8]", l_y_: "f32[8]"):

@@ -1495,9 +1495,9 @@ class GraphModule(torch.nn.Module):

    class subgraph_0(torch.nn.Module):
        def forward(self, l_x_: "f32[8, 8]"):
            child: "f32[8, 8]" = l_x_ * 2
            child_1: "f32[8, 8]" = l_x_ * 3; l_x_ = None
            return (child, child_1)
            mul: "f32[8, 8]" = l_x_ * 2
            mul_1: "f32[8, 8]" = l_x_ * 3; l_x_ = None
            return (mul, mul_1)
""",
        )
@@ -2504,6 +2504,107 @@ class GraphModule(torch.nn.Module):
        self.assertEqual(f(x, other), f_compile(x, other))
        self.assertTrue(called)

    def test_udf_output(self):
        class Foo:
            def __init__(self, a, b):
                self.a = a
                self.b = b

        @nested_compile_region
        def gn(x, y):
            a = torch.sin(x)
            b = torch.cos(y)
            return Foo(a, b)

        def fn(x, y):
            foo1 = gn(x, y)
            foo2 = gn(foo1.a, y)
            return foo1.b + foo2.a  # + foo2.b

        backend = AotEagerAndRecordGraphs()

        opt_fn = torch.compile(fn, backend=backend, fullgraph=True)

        x = torch.randn(8, 8, requires_grad=True)
        y = torch.randn(8, 8, requires_grad=True)
        x_clone = x.detach().clone().requires_grad_(True)
        y_clone = y.detach().clone().requires_grad_(True)

        ref = fn(x, y)
        res = opt_fn(x_clone, y_clone)

        ref.sum().backward()
        res.sum().backward()

        self.assertEqual(ref, res)
        self.assertEqual(x.grad, x_clone.grad)

        if not TEST_WITH_CROSSREF:
            self.assertExpectedInline(
                normalize_gm(backend.graphs[0].print_readable(print_output=False)),
                """\
class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[8, 8]", L_y_: "f32[8, 8]"):
        l_x_ = L_x_
        l_y_ = L_y_

        subgraph_0 = self.subgraph_0
        invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_); subgraph_0 = l_x_ = None
        getitem: "f32[8, 8]" = invoke_subgraph[0]
        getitem_1: "f32[8, 8]" = invoke_subgraph[1]; invoke_subgraph = None
        subgraph_1 = self.subgraph_0
        invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_1, 'subgraph_0', getitem, l_y_); subgraph_1 = getitem = l_y_ = None
        getitem_2: "f32[8, 8]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None

        add: "f32[8, 8]" = getitem_1 + getitem_2; getitem_1 = getitem_2 = None
        return (add,)

    class subgraph_0(torch.nn.Module):
        def forward(self, l_x_: "f32[8, 8]", l_y_: "f32[8, 8]"):
            a: "f32[8, 8]" = torch.sin(l_x_); l_x_ = None

            b: "f32[8, 8]" = torch.cos(l_y_); l_y_ = None
            return (a, b)
""",
            )

    # High priority - grads are wrong
    @unittest.expectedFailure
    def test_grad_accuracy_check(self):
        class Foo:
            def __init__(self, a, b):
                self.a = a
                self.b = b

        @nested_compile_region
        def gn(x):
            a = torch.sin(x)
            b = torch.cos(x)
            return (a, b)

        def fn(x):
            foo1 = gn(x)
            foo2 = gn(foo1[0])
            return foo1[1] + foo2[0] + foo2[1]

        backend = AotEagerAndRecordGraphs()

        opt_fn = torch.compile(fn, backend=backend, fullgraph=True)

        x = torch.randn(8, 8, requires_grad=True)
        x_clone = x.detach().clone().requires_grad_(True)
        x.grad = None
        x_clone.grad = None

        ref = fn(x)
        res = opt_fn(x_clone)

        ref.sum().backward()
        res.sum().backward()

        self.assertEqual(ref, res)
        self.assertEqual(x.grad, x_clone.grad)


@skipIfTorchDynamo("Not a torch._dynamo test")
@parameterized_class(
@@ -286,47 +286,31 @@ class GraphModule(torch.nn.Module):
        l_self_modules_wo_parameters_weight_ = L_self_modules_wo_parameters_weight_
        l_self_modules_w1_parameters_weight_ = L_self_modules_w1_parameters_weight_
        l_self_modules_w2_parameters_weight_ = L_self_modules_w2_parameters_weight_

        q: "f32[8, 16, 96]" = torch._C._nn.linear(l_x_, l_self_modules_wq_parameters_weight_, None); l_self_modules_wq_parameters_weight_ = None

        k: "f32[8, 16, 96]" = torch._C._nn.linear(l_x_, l_self_modules_wk_parameters_weight_, None); l_self_modules_wk_parameters_weight_ = None

        v: "f32[8, 16, 96]" = torch._C._nn.linear(l_x_, l_self_modules_wv_parameters_weight_, None); l_self_modules_wv_parameters_weight_ = None

        unflatten: "f32[8, 16, 16, 6]" = q.unflatten(-1, (16, -1)); q = None
        q_1: "f32[8, 16, 16, 6]" = unflatten.permute(0, 2, 1, 3); unflatten = None

        unflatten_1: "f32[8, 16, 16, 6]" = k.unflatten(-1, (16, -1)); k = None
        k_1: "f32[8, 16, 16, 6]" = unflatten_1.permute(0, 2, 1, 3); unflatten_1 = None

        unflatten_2: "f32[8, 16, 16, 6]" = v.unflatten(-1, (16, -1)); v = None
        v_1: "f32[8, 16, 16, 6]" = unflatten_2.permute(0, 2, 1, 3); unflatten_2 = None

        subgraph_0 = self.subgraph_0
        local_map_hop = torch.ops.higher_order.local_map_hop(subgraph_0, q_1, k_1, v_1); subgraph_0 = q_1 = k_1 = v_1 = None
        o: "f32[8, 16, 16, 6]" = local_map_hop[0]; local_map_hop = None

        permute_3: "f32[8, 16, 16, 6]" = o.permute(0, 2, 1, 3); o = None
        o_1: "f32[8, 16, 96]" = permute_3.flatten(-2); permute_3 = None

        o_2: "f32[8, 16, 96]" = torch._C._nn.linear(o_1, l_self_modules_wo_parameters_weight_, None); o_1 = l_self_modules_wo_parameters_weight_ = None

        o0: "f32[8, 16, 96]" = o_2 + l_x_; o_2 = l_x_ = None

        o_3: "f32[8, 16, 384]" = torch._C._nn.linear(o0, l_self_modules_w1_parameters_weight_, None); l_self_modules_w1_parameters_weight_ = None

        o_4: "f32[8, 16, 384]" = torch.nn.functional.relu(o_3); o_3 = None

        o_5: "f32[8, 16, 96]" = torch._C._nn.linear(o_4, l_self_modules_w2_parameters_weight_, None); o_4 = l_self_modules_w2_parameters_weight_ = None

        o_6: "f32[8, 16, 96]" = o0 + o_5; o0 = o_5 = None
        return (o_6,)

        getitem: "f32[8, 16, 16, 6]" = local_map_hop[0]; local_map_hop = None
        permute_3: "f32[8, 16, 16, 6]" = getitem.permute(0, 2, 1, 3); getitem = None
        o: "f32[8, 16, 96]" = permute_3.flatten(-2); permute_3 = None
        o_1: "f32[8, 16, 96]" = torch._C._nn.linear(o, l_self_modules_wo_parameters_weight_, None); o = l_self_modules_wo_parameters_weight_ = None
        o0: "f32[8, 16, 96]" = o_1 + l_x_; o_1 = l_x_ = None
        o_2: "f32[8, 16, 384]" = torch._C._nn.linear(o0, l_self_modules_w1_parameters_weight_, None); l_self_modules_w1_parameters_weight_ = None
        o_3: "f32[8, 16, 384]" = torch.nn.functional.relu(o_2); o_2 = None
        o_4: "f32[8, 16, 96]" = torch._C._nn.linear(o_3, l_self_modules_w2_parameters_weight_, None); o_3 = l_self_modules_w2_parameters_weight_ = None
        o_5: "f32[8, 16, 96]" = o0 + o_4; o0 = o_4 = None
        return (o_5,)
    class subgraph_0(torch.nn.Module):
        def forward(self, q_1: "f32[1, 2, 4, 6]", k_1: "f32[1, 2, 16, 6]", v_1: "f32[1, 2, 16, 6]"):
            out: "f32[1, 2, 4, 6]" = torch._C._nn.scaled_dot_product_attention(query = q_1, key = k_1, value = v_1, is_causal = False); q_1 = k_1 = v_1 = None
            return (out,)
""",
            return (out,)""",
            ignore_empty_lines=True,
        )
@@ -796,6 +796,27 @@ def forward(self, x_1):

        self._test(f, [torch.randn(1, 10), torch.zeros(1, dtype=torch.long)])

    @unittest.skipIf(not HAS_CUDA, 'CUDA-only test')
    def test_T244632748(self):
        class TestModule(torch.nn.Module):
            def forward(self, x):
                return x + (x.shape[0] * 2)

        mod = TestModule()
        sample = torch.randn((5, 5)).to("cuda")
        dim0 = torch.export.Dim.DYNAMIC(max=100)
        dynamic_shapes = {"x": (dim0, torch.export.Dim.STATIC)}
        ep = torch.export.export(mod, (sample,), dynamic_shapes=dynamic_shapes)
        gm = ep.module()
        symint = list(gm.graph.nodes)[3].meta["val"]
        list(gm.graph.nodes)[3].replace_all_uses_with(symint)
        gm.graph.eliminate_dead_code()

        inductor_fx = torch._inductor.aot_compile(
            gm, (sample,), options={"fx_wrapper": True, "compile_threads": 1}
        )


class TestGenericProxyTensorReal(TestGenericProxyTensor):
    tracing_mode = "real"
@@ -2674,7 +2674,6 @@ class TestSparse(TestSparseBase):
        self._test_asin_arcsin(input_uncoalesced, coalesced)

    @coalescedonoff
    @expectedFailureMPS
    @dtypes(torch.double)
    @dtypesIfMPS(torch.float32)
    def test_mv(self, device, dtype, coalesced):
third_party/kineto (vendored submodule)
Submodule third_party/kineto updated: 6fcbc53d33...1e30d37905
@@ -2439,6 +2439,35 @@ class _TorchCompileInductorWrapper:
        reset_cudagraph_trees()


class _TorchCompileAOTInductorWrapper(_TorchCompileInductorWrapper):
    compiler_name = "aotinductor"

    def __init__(self, mode, options, dynamic):
        super().__init__(mode, options, dynamic)
        self.apply_options({"cpp_wrapper": True})
        self.apply_options({"aot_inductor.package": True})

    def __call__(self, model_, inputs_):
        from contextlib import nullcontext
        from unittest import mock

        from torch._guards import detect_fake_mode
        from torch._inductor.virtualized import V

        fake_mode = detect_fake_mode(inputs_)
        ctx = (
            mock.patch.object(fake_mode, "allow_non_fake_inputs", True)
            if fake_mode
            else nullcontext()
        )
        with (
            V.set_aot_compilation(True),
            ctx,
            torch._inductor.config.patch("enable_autograd_for_aot", True),
        ):
            return super().__call__(model_, inputs_)


class _TorchCompileWrapper:
    def __init__(self, backend, mode, options, dynamic):
        from torch._dynamo.backends.registry import lookup_backend

@@ -2672,8 +2701,10 @@ def compile(
        backend = bisect_backend

    guard_filter_fn = None
    use_aoti = False
    if options and isinstance(options, dict):
        guard_filter_fn = options.pop("guard_filter_fn", None)
        use_aoti = options.pop("use_aoti", False)

    if torch.compiler.is_exporting():
        warnings.warn(

@@ -2700,7 +2731,10 @@ def compile(
        return export_wrapped_fn

    if backend == "inductor":
        backend = _TorchCompileInductorWrapper(mode, options, dynamic)
        if use_aoti:
            backend = _TorchCompileAOTInductorWrapper(mode, options, dynamic)
        else:
            backend = _TorchCompileInductorWrapper(mode, options, dynamic)
    else:
        backend = _TorchCompileWrapper(backend, mode, options, dynamic)
@@ -53,6 +53,7 @@ class CompileArtifacts:
    argdefs: Optional[tuple[Any, ...]]
    source_info: "SourceInfo"
    device_type: str
    backend_name: str
    system_info: SystemInfo = dataclasses.field(default_factory=SystemInfo.current)

    def check_compatibility(self) -> None:

@@ -273,6 +274,7 @@ def aot_compile_fullgraph(
        argdefs=fn.__defaults__,
        source_info=source_info,
        device_type=device_type,
        backend_name=getattr(backend, "compiler_name", "unknown"),
    )
    aot_compiled_fn = AOTCompiledFunction(_artifacts=artifacts)
@@ -11,6 +11,7 @@ import sympy
import torch
import torch.fx
import torch.utils._pytree as pytree
from torch._dispatch.python import enable_python_dispatcher
from torch._dynamo.convert_frame import CaptureOutput, fullgraph_capture, get_traced_fn
from torch._dynamo.eval_frame import argument_names, check_user_input_output
from torch._dynamo.exc import UserErrorType

@@ -579,9 +580,10 @@ def pytreeify(
    fake_mode = torch._dynamo.utils.detect_fake_mode(flat_out_shuffle_args)
    if fake_mode and fake_mode.shape_env is None:
        fake_mode.shape_env = ShapeEnv()
    out_shuffle_graph = make_fx(
        out_shuffle, tracing_mode="symbolic", proxy_module_inputs=True
    )(*flat_out_shuffle_args)
    with enable_python_dispatcher():
        out_shuffle_graph = make_fx(
            out_shuffle, tracing_mode="real", proxy_module_inputs=True
        )(*flat_out_shuffle_args)
    _normalize_shuffle_graph(out_shuffle_graph)

    assert out_shuffle.out_spec is not None
File diff suppressed because it is too large
@@ -511,6 +511,7 @@ class GenericAOTAutogradResult(Generic[TForward, TBackward]):
        ).post_compile(
            compiled_fw_func, aot_config, runtime_metadata=self.runtime_metadata
        )
        compiled_fw_func._boxed_call = True
        disable_amp = torch._C._is_any_autocast_enabled()

        if needs_autograd:
@@ -1639,7 +1639,9 @@ class _InProcessFxCompile(FxCompile):
                # pyrefly: ignore [unbound-name]
                (str, list, torch.fx.GraphModule),
            ), type(compiled_fn)
            return CompiledAOTI(compiled_fn)
            return CompiledAOTI(
                filename=compiled_fn, device_type=graph.device_type
            )

        # TODO: Hoist this above V.aot_compilation
        # pyrefly: ignore [unbound-name]

@@ -2712,7 +2714,7 @@ def _compile_fx_main(
        or torch._guards.TracingContext(fake_mode)
    )

    if V.aot_compilation:
    if V.aot_compilation and not config.enable_autograd_for_aot:
        from .utils import is_valid_aoti_model_name

        is_valid_aoti_model_name()
@@ -1193,6 +1193,8 @@ autotune_lookup_table: dict[str, dict[str, Any]] = {}

file_lock_timeout: int = int(os.environ.get("TORCHINDUCTOR_FILE_LOCK_TIMEOUT", "600"))

enable_autograd_for_aot: bool = False


def get_worker_log_path() -> Optional[str]:
    log_loc = None
@@ -773,9 +773,83 @@ class CompiledAOTI(OutputCode):
    """

    filename: Union[str, list[Union[str, Weights]], torch.fx.GraphModule]
    device_type: str
    current_callable: Optional[Callable[..., Any]] = None
    _cached_files: dict[str, bytes] = dataclasses.field(default_factory=dict)

    def __post_init__(self):
        if not config.aot_inductor.link_libtorch:
            return

        if (
            torch._inductor.cpp_builder._IS_MACOS
            or torch._inductor.cpp_builder._IS_WINDOWS
        ):
            return

        if config.aot_inductor.cross_target_platform == "windows":
            return

        if config.aot_inductor.package_cpp_only:
            return

        if isinstance(self.filename, list):
            current_callable = next(
                fn for fn in self.filename if isinstance(fn, str) and fn.endswith(".so")
            )
        else:
            current_callable = self.filename

        if isinstance(current_callable, torch.fx.GraphModule):
            self.current_callable = current_callable
            return

        if self.device_type.startswith("cuda"):
            current_callable = (
                torch._C._aoti.AOTIModelContainerRunnerCuda(  # type: ignore[call-arg]
                    current_callable,
                    1,
                    self.device_type,
                    "",
                    True,
                ).run  # type: ignore[attr-defined]
            )  # type: ignore[attr-defined]
        elif self.device_type == "cpu":
            current_callable = (
                torch._C._aoti.AOTIModelContainerRunnerCpu(  # type: ignore[call-arg]
                    current_callable, 1
                ).run  # type: ignore[attr-defined]
            )  # type: ignore[attr-defined]
        else:
            raise RuntimeError(f"unsupported device type {self.device_type}")
        self.current_callable = current_callable
        self._boxed_call = True
        for file in self._cached_files:
            if not os.path.exists(file):
                with open(file, "wb") as f:
                    f.write(self._cached_files[file])

    def __call__(self, inputs: Sequence[Any]) -> Any:
        raise NotImplementedError("NYI")
        if self.current_callable is None:
            raise RuntimeError("AOTInductor compiled so is not loaded")
        return self.current_callable(inputs)

    def prepare_for_serialization(self) -> None:
        self.current_callable = None
        self._cached_files = {}
        filenames: list[str] = []
        if isinstance(self.filename, list):
            filenames = self.filename  # type: ignore[assignment]
        elif isinstance(self.filename, str):
            filenames = [self.filename]
        for name in filenames:
            with open(name, "rb") as f:
                self._cached_files[name] = f.read()

    def __getstate__(self):
        state = self.__dict__.copy()
        state["current_callable"] = None
        return state

    def post_compile(
        self,

@@ -783,10 +857,8 @@ class CompiledAOTI(OutputCode):
        constants: CompiledFxGraphConstants,
        graph_kwargs: _CompileFxKwargs,
    ) -> None:
        pass

    def prepare_for_serialization(self) -> None:
        pass
        if self.current_callable is None:
            self.__post_init__()

    def set_triton_bundle(self, triton_bundle: Any) -> None:
        pass
@@ -66,6 +66,12 @@ void initAOTIRunnerBindings(PyObject* module) {
          int,
          const std::string&,
          const std::string&>())
      .def(py::init<
          const std::string&,
          int,
          const std::string&,
          const std::string&,
          const bool>())
      .def(
          "run",
          &AOTIModelContainerRunnerCuda::run,
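This binding adds a five-argument py::init overload (model path, model count, device string, cubin directory, plus a trailing bool) matching the extra True argument that CompiledAOTI.__post_init__ now passes from Python; the exact meaning of the new boolean is not spelled out in this diff. For context, a minimal C++ sketch of driving the same runner class directly, following the standard AOTInductor runner API, is shown below; the "model.so" path is an assumption and the new flag is left at its default.

#include <iostream>
#include <vector>

#include <torch/torch.h>
#include <torch/csrc/inductor/aoti_runner/model_container_runner_cuda.h>

int main() {
  c10::InferenceMode mode;

  // Load a shared library previously produced by AOT Inductor
  // (e.g. via torch._inductor.aot_compile); the path is assumed.
  torch::inductor::AOTIModelContainerRunnerCuda runner("model.so");

  std::vector<torch::Tensor> inputs = {torch::randn({8, 10}, at::kCUDA)};
  std::vector<torch::Tensor> outputs = runner.run(inputs);

  std::cout << outputs[0] << std::endl;
  return 0;
}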
@@ -84,7 +84,7 @@ if TYPE_CHECKING:

    from torch._ops import OpOverload
    from torch.fx._symbolic_trace import PHBase
    from torch.types import IntLikeType
    from torch.types import BoolLikeType, FloatLikeType, IntLikeType

__all__ = [
    "PythonKeyTracer",

@@ -458,7 +458,7 @@ def _sympy_handlers() -> dict[type[sympy.Expr], Callable[..., Any]]:

def _build_proxy_for_sym_expr(
    tracer: _ProxyTracer, expr: sympy.Expr, out: PySymType | None = None
) -> PySymType | None:
) -> IntLikeType | FloatLikeType | BoolLikeType | None:
    """
    Decompose `expr` and look for the pieces as inputs. If `out` is provided
    then that will be the resulting SymNode (and `out.expr` must be the same as

@@ -532,6 +532,13 @@ def _build_proxy_for_sym_expr(
        assert not out
        return value.value

    if isinstance(expr, (int, float, bool)):
        return expr
    if expr.is_Integer:
        return int(expr)
    if expr.is_Float:
        return float(expr)

    args = []
    for arg in expr.args:
        if (arg_value := _build_proxy_for_sym_expr(tracer, arg)) is None: