[AOTInductor] Add Python interface for user managed buffer. (#151141)

Summary: Add a pybind interface for passing user-managed buffers to update_constant_buffer.
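For orientation, a minimal sketch of the new flag as the test below exercises it. The runner class and argument order are taken from the test utilities; treat the exact names here as illustrative rather than a stable public API:

```python
import torch

# Assumes "model.so" is an AOTInductor-compiled model (path is hypothetical).
runner = torch._C._aoti.AOTIModelContainerRunnerCuda("model.so", 1, "cuda")

new_weights = {
    "L__self___weight": torch.randn(4096, 4096, device="cuda"),
    "L__self___bias": torch.randn(4096, device="cuda"),
}
# Positional flags are (use_inactive, validate_full_update, user_managed).
# With user_managed=True the runner aliases the caller's tensors instead of
# copying them into its own constant buffer, so no extra device memory is
# allocated and later in-place edits to new_weights remain visible.
runner.update_constant_buffer(new_weights, True, False, True)
runner.swap_constant_buffer()  # activate the freshly updated buffer
```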

Test Plan:
Included in commit.
```
python test/inductor/test_aot_inductor.py -k user_managed
```

Differential Revision: D72892310

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151141
Approved by: https://github.com/henrylhtsang, https://github.com/desertfire
Author: Mu-Chu Lee
Date: 2025-04-15 09:36:30 +00:00
Committed by: PyTorch MergeBot
Parent: bd9c436c99
Commit: 70e7b76707
2 changed files with 105 additions and 1 deletion

test/inductor/test_aot_inductor.py

@@ -4735,6 +4735,110 @@ class AOTInductorTestsTemplate:
         runner.free_inactive_constant_buffer()
 
+    def test_update_user_managed_buffer(self):
+        if self.device != "cuda":
+            raise unittest.SkipTest("requires CUDA")
+
+        class Model(torch.nn.Module):
+            def __init__(self, n, k, device):
+                super().__init__()
+                self.weight = torch.randn(n, k, device=device)
+                self.bias = torch.randn(n, device=device)
+
+            def forward(self, a):
+                return torch.nn.functional.linear(a, self.weight, self.bias)
+
+        M, N, K = 1024, 4096, 4096
+        model = Model(N, K, self.device)
+        a = torch.randn(M, K, device=self.device)
+        example_inputs = (a,)
+        # Attribute naming has changed in the new export API, so we still use
+        # the legacy API here.
+        with torch.no_grad(), config.patch({"always_keep_tensor_constants": True}):
+            so_path = AOTIRunnerUtil.legacy_compile(
+                model=model,
+                example_inputs=example_inputs,
+            )
+        runner = AOTIRunnerUtil.legacy_load_runner(self.device, so_path)
+
+        def runner_call(*args, **kwargs):
+            import torch.fx._pytree as fx_pytree
+
+            call_spec = runner.get_call_spec()
+            in_spec = pytree.treespec_loads(call_spec[0])
+            out_spec = pytree.treespec_loads(call_spec[1])
+            flat_inputs = fx_pytree.tree_flatten_spec((args, kwargs), in_spec)
+            flat_inputs = [x for x in flat_inputs if isinstance(x, torch.Tensor)]
+            flat_outputs = runner.run(flat_inputs)
+            return pytree.tree_unflatten(flat_outputs, out_spec)
+
+        test_inputs = torch.randn(M, K, device=self.device)
+        expected = model(test_inputs)
+        output = runner_call(test_inputs)
+        self.assertEqual(expected, output)
+
+        new_weights = {
+            "L__self___weight": torch.randn(N, K, device=self.device),
+            "L__self___bias": torch.randn(N, device=self.device),
+        }
+        mem_before, _ = torch.cuda.mem_get_info(self.device)
+        # Without a user-managed buffer, the runner copies the weights, so
+        # free memory should shrink.
+        runner.update_constant_buffer(new_weights, True, False, False)
+        mem_after, _ = torch.cuda.mem_get_info(self.device)
+        self.assertGreater(mem_before, mem_after)
+
+        runner.swap_constant_buffer()
+        new_output = runner_call(test_inputs)
+        new_expected = torch.nn.functional.linear(
+            test_inputs, new_weights["L__self___weight"], new_weights["L__self___bias"]
+        )
+        self.assertEqual(new_expected, new_output)
+
+        # Substitute the tensors in place. Without a user-managed buffer, the
+        # runner's copy is unaffected, so the result should not change.
+        new_weights["L__self___weight"].add_(1)
+        new_weights["L__self___bias"].add_(1)
+        new_output = runner_call(test_inputs)
+        # Same as the previous result.
+        self.assertEqual(new_expected, new_output)
+        new_expected = torch.nn.functional.linear(
+            test_inputs, new_weights["L__self___weight"], new_weights["L__self___bias"]
+        )
+        # Differs from a reference computed with the updated weights.
+        self.assertNotEqual(new_expected, new_output)
+
+        # Clear out all buffers.
+        runner.free_inactive_constant_buffer()
+        runner.swap_constant_buffer()
+        runner.free_inactive_constant_buffer()
+
+        new_weights = {
+            "L__self___weight": torch.randn(N, K, device=self.device),
+            "L__self___bias": torch.randn(N, device=self.device),
+        }
+        mem_before, _ = torch.cuda.mem_get_info(self.device)
+        # With a user-managed buffer, the runner shares the caller's storage,
+        # so free memory should be unchanged.
+        runner.update_constant_buffer(new_weights, True, False, True)
+        mem_after, _ = torch.cuda.mem_get_info(self.device)
+        self.assertEqual(mem_before, mem_after)
+
+        runner.swap_constant_buffer()
+        new_output = runner_call(test_inputs)
+        new_expected = torch.nn.functional.linear(
+            test_inputs, new_weights["L__self___weight"], new_weights["L__self___bias"]
+        )
+        self.assertEqual(new_expected, new_output)
+
+        # Substitute the tensors in place. With a user-managed buffer, the
+        # runner sees the same storage, so the results should match.
+        new_weights["L__self___weight"].add_(1)
+        new_weights["L__self___bias"].add_(1)
+        new_output = runner_call(test_inputs)
+        new_expected = torch.nn.functional.linear(
+            test_inputs, new_weights["L__self___weight"], new_weights["L__self___bias"]
+        )
+        self.assertEqual(new_expected, new_output)
+
     def test_cond_share_predicte(self):
         class Model(torch.nn.Module):
             def forward(self, predicate, x):
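The test also relies on the container's double-buffered constants: an update written with use_inactive=True lands in the inactive buffer and only takes effect once it is swapped in. A sketch of that lifecycle, under the same assumptions as the snippet above:

```python
# Write new constants into the inactive buffer, swap it in, then release the
# previously active buffer, which has now become inactive.
runner.update_constant_buffer(new_weights, True, False, False)
runner.swap_constant_buffer()
runner.free_inactive_constant_buffer()
```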

torch/csrc/inductor/aoti_runtime/utils.h

@@ -169,7 +169,7 @@ class MaybeOwningAtenTensorHandle {
   // If user_managed is true, we do not steal the ownership.
   MaybeOwningAtenTensorHandle(AtenTensorHandle handle, bool user_managed) {
     if (user_managed) {
-      handle_ = handle;
+      aoti_torch_new_tensor_handle(handle, &handle_);
     } else {
       raii_handle_ = RAIIAtenTensorHandle(handle);
       handle_ = raii_handle_.get();
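This one-line change is what makes the Python binding safe. Previously, the user-managed path stored the raw AtenTensorHandle without taking a reference, so the runner could be left with a dangling handle once the caller's handle was released. aoti_torch_new_tensor_handle creates a fresh handle that shares the same underlying tensor, so the runner keeps the tensor alive while its storage stays shared with the caller. In Python terms, roughly (a sketch restating what the test above checks; runner_call is the pytree wrapper defined in the test):

```python
x = torch.randn(1024, 4096, device="cuda")
weights = {
    "L__self___weight": torch.randn(4096, 4096, device="cuda"),
    "L__self___bias": torch.randn(4096, device="cuda"),
}
runner.update_constant_buffer(weights, True, False, True)  # user_managed=True
runner.swap_constant_buffer()

before = runner_call(x)
weights["L__self___weight"].add_(1)  # in-place edit of the caller's tensor
after = runner_call(x)               # reflects the edit: storage is shared
assert not torch.equal(before, after)

# The runner now holds its own handle to the shared tensor, so dropping the
# caller's reference does not invalidate the constant buffer.
del weights
```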