mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
[AOTInductor] Add Python interface for user managed buffer. (#151141)
Summary: Add pybind for user managed buffer in `update_constant_buffer`.

Test Plan: Included in commit.
```
python test/inductor/test_aot_inductor.py -k user_managed
```

Differential Revision: D72892310

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151141
Approved by: https://github.com/henrylhtsang, https://github.com/desertfire
Committed by: PyTorch MergeBot
Parent: bd9c436c99
Commit: 70e7b76707
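At a high level, the new binding lets the caller keep ownership of the tensors that back the compiled model's constants: with `user_managed=True`, `update_constant_buffer` aliases the caller's tensors instead of copying them, so later in-place updates on the Python side are visible to the runner without another device allocation. A minimal sketch of that flow, assuming the `AOTIRunnerUtil` helper from `test/inductor/test_aot_inductor_utils.py` and the positional argument order `(constants, use_inactive, validate_full_update, user_managed)` inferred from the test below:

```python
# Sketch only: helper names, constant naming, and argument order are taken
# from the test below, not from a stable public API.
import torch
from torch._inductor import config

from test_aot_inductor_utils import AOTIRunnerUtil  # assumes CWD is test/inductor/


class Linear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.randn(4, 4, device="cuda")

    def forward(self, x):
        return x @ self.weight.t()


model = Linear()
x = torch.randn(2, 4, device="cuda")
with torch.no_grad(), config.patch({"always_keep_tensor_constants": True}):
    so_path = AOTIRunnerUtil.legacy_compile(model=model, example_inputs=(x,))
runner = AOTIRunnerUtil.legacy_load_runner("cuda", so_path)

new_weight = torch.randn(4, 4, device="cuda")
# user_managed=True: the runner aliases new_weight rather than copying it.
# The "L__self___weight" key follows the constant naming used in the test.
runner.update_constant_buffer({"L__self___weight": new_weight}, True, False, True)
runner.swap_constant_buffer()  # activate the updated constants

new_weight.add_(1)  # visible to the compiled model with no further API calls
out = runner.run([x])
```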
test/inductor/test_aot_inductor.py
@@ -4735,6 +4735,110 @@ class AOTInductorTestsTemplate:
```python
        runner.free_inactive_constant_buffer()

    def test_update_user_managed_buffer(self):
        if self.device != "cuda":
            raise unittest.SkipTest("requires CUDA")

        class Model(torch.nn.Module):
            def __init__(self, n, k, device):
                super().__init__()
                self.weight = torch.randn(n, k, device=device)
                self.bias = torch.randn(n, device=device)

            def forward(self, a):
                return torch.nn.functional.linear(a, self.weight, self.bias)

        M, N, K = 1024, 4096, 4096
        model = Model(N, K, self.device)
        a = torch.randn(M, K, device=self.device)
        example_inputs = (a,)
        # Attribute naming has changed in the new export API, so still use the legacy API here.
        with torch.no_grad(), config.patch({"always_keep_tensor_constants": True}):
            so_path = AOTIRunnerUtil.legacy_compile(
                model=model,
                example_inputs=example_inputs,
            )

        runner = AOTIRunnerUtil.legacy_load_runner(self.device, so_path)

        def runner_call(*args, **kwargs):
            import torch.fx._pytree as fx_pytree

            call_spec = runner.get_call_spec()
            in_spec = pytree.treespec_loads(call_spec[0])
            out_spec = pytree.treespec_loads(call_spec[1])
            flat_inputs = fx_pytree.tree_flatten_spec((args, kwargs), in_spec)
            flat_inputs = [x for x in flat_inputs if isinstance(x, torch.Tensor)]
            flat_outputs = runner.run(flat_inputs)
            return pytree.tree_unflatten(flat_outputs, out_spec)

        test_inputs = torch.randn(M, K, device=self.device)
        expected = model(test_inputs)
        output = runner_call(test_inputs)
        self.assertEqual(expected, output)

        new_weights = {
            "L__self___weight": torch.randn(N, K, device=self.device),
            "L__self___bias": torch.randn(N, device=self.device),
        }
        mem_before, _ = torch.cuda.mem_get_info(self.device)
        # Without a user-managed buffer the update copies the tensors,
        # so there should be less free memory afterwards.
        runner.update_constant_buffer(new_weights, True, False, False)
        mem_after, _ = torch.cuda.mem_get_info(self.device)
        self.assertGreater(mem_before, mem_after)

        runner.swap_constant_buffer()
        new_output = runner_call(test_inputs)
        new_expected = torch.nn.functional.linear(
            test_inputs, new_weights["L__self___weight"], new_weights["L__self___bias"]
        )
        self.assertEqual(new_expected, new_output)

        # Substitute the tensors in place; without a user-managed buffer the
        # runner keeps its own copy, so the result should not change.
        new_weights["L__self___weight"].add_(1)
        new_weights["L__self___bias"].add_(1)

        new_output = runner_call(test_inputs)
        # Same as the previous result
        self.assertEqual(new_expected, new_output)
        new_expected = torch.nn.functional.linear(
            test_inputs, new_weights["L__self___weight"], new_weights["L__self___bias"]
        )
        # Differs from the freshly computed expectation
        self.assertNotEqual(new_expected, new_output)

        # Clear out all buffers
        runner.free_inactive_constant_buffer()
        runner.swap_constant_buffer()
        runner.free_inactive_constant_buffer()

        new_weights = {
            "L__self___weight": torch.randn(N, K, device=self.device),
            "L__self___bias": torch.randn(N, device=self.device),
        }
        mem_before, _ = torch.cuda.mem_get_info(self.device)
        # With a user-managed buffer the runner aliases the tensors,
        # so free memory should stay the same.
        runner.update_constant_buffer(new_weights, True, False, True)
        mem_after, _ = torch.cuda.mem_get_info(self.device)
        self.assertEqual(mem_before, mem_after)

        runner.swap_constant_buffer()
        new_output = runner_call(test_inputs)
        new_expected = torch.nn.functional.linear(
            test_inputs, new_weights["L__self___weight"], new_weights["L__self___bias"]
        )
        self.assertEqual(new_expected, new_output)

        # Substitute the tensors in place; with a user-managed buffer the
        # runner shares the storage, so the result should track the mutation.
        new_weights["L__self___weight"].add_(1)
        new_weights["L__self___bias"].add_(1)

        new_output = runner_call(test_inputs)
        new_expected = torch.nn.functional.linear(
            test_inputs, new_weights["L__self___weight"], new_weights["L__self___bias"]
        )
        self.assertEqual(new_expected, new_output)

    def test_cond_share_predicte(self):
        class Model(torch.nn.Module):
            def forward(self, predicate, x):
```
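The test leans on the runner's double-buffered constant storage; the lifecycle below is pieced together from the calls it makes (a sketch of the observed protocol, not an authoritative API contract), continuing the earlier sketch where `runner` and `new_weight` are defined:

```python
# Double-buffer lifecycle as exercised by the test above.
weights = {"L__self___weight": new_weight}
runner.update_constant_buffer(weights, True, False, True)  # stage into the inactive buffer
runner.swap_constant_buffer()             # flip: staged constants become the active set
runner.free_inactive_constant_buffer()    # release the previously active constants
```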
@@ -169,7 +169,7 @@ class MaybeOwningAtenTensorHandle {
```diff
   // If user_managed is true, we do not steal the ownership.
   MaybeOwningAtenTensorHandle(AtenTensorHandle handle, bool user_managed) {
     if (user_managed) {
-      handle_ = handle;
+      aoti_torch_new_tensor_handle(handle, &handle_);
     } else {
       raii_handle_ = RAIIAtenTensorHandle(handle);
       handle_ = raii_handle_.get();
```
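The one-line C++ change is what makes the user-managed mode safe: rather than storing the borrowed handle directly, the runtime now creates its own handle on the same underlying tensor via `aoti_torch_new_tensor_handle`, bumping the reference count while still sharing storage. One inferred consequence, continuing the earlier sketch (hedged: this follows from the reference-count bump, not from documented behavior):

```python
# Continuing the sketch above. Because the runtime holds its own handle to the
# shared tensor, dropping the caller's reference should not dangle the constants.
import gc

del new_weight
gc.collect()
out = runner.run([x])  # still valid: the runtime's handle keeps the storage alive
```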