Compare commits

...

4 Commits

SHA1        Message                                             Date
ce18095961  no dynamic                                          2025-03-07 17:40:07 -08:00
1f71ac6c0e  bench                                               2025-03-07 17:38:51 -08:00
c5a69b2ad8  [ca] always do initial trace with dynamic shapes    2025-03-07 17:37:39 -08:00
            ghstack-source-id: 1b74c2a6738370ec8ebbff3183fb5d1533f181a0
            Pull Request resolved: https://github.com/pytorch/pytorch/pull/148801
60b25eb481  [ca] support for dynamic shapes CopySlices          2025-03-07 17:37:36 -08:00
            ghstack-source-id: 0d6a7e8c971e9a55600af59e32e5d70f874a1656
            Pull Request resolved: https://github.com/pytorch/pytorch/pull/148799
6 changed files with 142 additions and 113 deletions
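
Note on the stack: the two [ca] commits make compiled autograd's first trace use dynamic shapes (and add the CopySlices support that needs), so a backward seen with a new input size reuses the existing capture instead of recompiling; the bench / no dynamic commits on top look like scratch benchmarking tweaks. A minimal sketch of the behavior the updated tests assert, using only APIs that appear in this diff (compiled_autograd._enable, the counters bookkeeping); exact capture counts depend on the model, so treat it as illustrative:

    import torch
    from torch._dynamo import compiled_autograd
    from torch._dynamo.utils import counters

    def backward_once(n):
        x = torch.randn(n, n, requires_grad=True)
        (x * x).sum().backward()

    counters.clear()
    with compiled_autograd._enable(torch.compile(backend="eager")):
        backward_once(4)
        backward_once(8)  # new shape: previously a second capture, now a cache hit

    print(counters["compiled_autograd"]["captures"])  # expected to stay at 1 once the initial trace is dynamic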

View File

@ -3241,6 +3241,7 @@ def parse_args(args=None):
"--compiled-autograd",
action="store_true",
help="Enables compiled autograd on compiled benchmark",
default=True
)
parser.add_argument(
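
Aside on this hunk (presumably a temporary benchmarking change, in line with the bench / no dynamic commits): combining action="store_true" with default=True makes the flag effectively always on, since omitting it falls back to the default and passing it stores True. A self-contained illustration:

    import argparse

    p = argparse.ArgumentParser()
    p.add_argument(
        "--compiled-autograd",
        action="store_true",
        help="Enables compiled autograd on compiled benchmark",
        default=True,
    )
    print(p.parse_args([]).compiled_autograd)                       # True
    print(p.parse_args(["--compiled-autograd"]).compiled_autograd)  # True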

View File

@ -1164,7 +1164,7 @@ main()
model.zero_grad()
# TODO(jansel): we should be able to get this count to 1
self.check_output_and_recompiles(fn, count=2)
self.check_output_and_recompiles(fn)
def test_dynamic_shapes_eager_node(self):
# Here, we have no way of marking the symbolic sizes using in SumBackward as dynamic
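
The long run of edits below is one change viewed from many tests: cases that previously compiled the backward twice (a static capture for the first input size, then a dynamic recompile) now need only the single dynamic capture, so the explicit count=2 is dropped in favor of the helper's default, and the graph-break variants shrink in proportion (count=[2, 6] -> [1, 3], count=[2, 4] -> [1, 2]).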
@ -1190,7 +1190,7 @@ main()
yield model[2].bias.grad
model.zero_grad()
self.check_output_and_recompiles(fn, count=3)
self.check_output_and_recompiles(fn)
def test_torch_compile_api_dynamic_shapes(self):
# Here, we have no way of marking the symbolic sizes using in SumBackward as dynamic
@ -1395,7 +1395,7 @@ main()
loss.backward()
yield x.grad
self.check_output_and_recompiles(fn, count=2)
self.check_output_and_recompiles(fn)
def test_custom_fn_saved_multiple_tensors(self):
def fn():
@ -1418,7 +1418,7 @@ main()
loss.backward()
yield x.grad
self.check_output_and_recompiles(fn, count=2)
self.check_output_and_recompiles(fn)
def test_custom_fn_saved_multiple_tensors_dedup(self):
def fn():
@ -1440,7 +1440,7 @@ main()
loss.backward()
yield x.grad
self.check_output_and_recompiles(fn, count=2)
self.check_output_and_recompiles(fn)
def test_custom_fn_saved_shape_tensor(self):
def fn():
@ -1462,7 +1462,7 @@ main()
loss.backward()
yield x.grad
self.check_output_and_recompiles(fn, count=2)
self.check_output_and_recompiles(fn)
def test_custom_fn_saved_attr(self):
def fn():
@ -1485,7 +1485,7 @@ main()
yield x.grad
self.check_output_and_recompiles(
fn, count=2, compiler_fn=make_compiler_fn(fullgraph=False)
fn, compiler_fn=make_compiler_fn(fullgraph=False)
)
def test_custom_fn_multiple_grads(self):
@ -1508,7 +1508,7 @@ main()
yield x.grad
yield y.grad
self.check_output_and_recompiles(fn, count=2)
self.check_output_and_recompiles(fn)
def test_custom_fn_non_variable_input(self):
def fn():
@ -1532,7 +1532,7 @@ main()
yield y
yield z
self.check_output_and_recompiles(fn, count=2)
self.check_output_and_recompiles(fn)
@unittest.skipIf(not HAS_GPU, "requires gpu")
def test_logging_tensor_flaky(self) -> None:
@ -1699,7 +1699,7 @@ main()
yield x.grad
self.check_output_and_recompiles(
fn, count=[2, 6], compiler_fn=make_compiler_fn(fullgraph=False)
fn, count=[1, 3], compiler_fn=make_compiler_fn(fullgraph=False)
)
def test_custom_fn_compiled_fw_graph_break(self):
@ -1725,9 +1725,9 @@ main()
yield x.grad
self.check_output_and_recompiles(
fn, count=2, compiler_fn=make_compiler_fn(fullgraph=False)
fn, count=1, compiler_fn=make_compiler_fn(fullgraph=False)
)
self.assertEqual(counters["stats"]["unique_graphs"], 5) # 3 fw, 2 bw
self.assertEqual(counters["stats"]["unique_graphs"], 4) # 3 fw, 1 bw
def test_custom_fn_compiled_fw_bw_graph_break(self):
def fn():
@ -1753,9 +1753,9 @@ main()
yield x.grad
self.check_output_and_recompiles(
fn, count=[2, 6], compiler_fn=make_compiler_fn(fullgraph=False)
fn, count=[1, 3], compiler_fn=make_compiler_fn(fullgraph=False)
)
self.assertEqual(counters["stats"]["unique_graphs"], 9) # 3 fw, 6 bw
self.assertEqual(counters["stats"]["unique_graphs"], 6) # 3 fw, 3 bw
def test_mismatch_fake_tensor_mode(self, dynamic_shape=False):
"""
@ -2135,12 +2135,11 @@ TORCH_LIBRARY(test_autograd_cpp_node_basic_$is_traceable, m) {
yield x.grad
if is_traceable:
# compiles for 10 (static) and 100 (dynamic)
self.check_output_and_recompiles(fn, 2)
self.check_output_and_recompiles(fn, 1)
else:
# compiles for 10 (static) and 100 (dynamic), each with a graph break
self.check_output_and_recompiles(
fn, count=[2, 4], compiler_fn=make_compiler_fn(fullgraph=False)
fn, count=[1, 2], compiler_fn=make_compiler_fn(fullgraph=False)
)
@parametrize("is_traceable", (True, False))
@ -2332,10 +2331,10 @@ TORCH_LIBRARY(test_autograd_cpp_node_saved_basic_$is_traceable, m) {
yield x.grad
if is_traceable:
self.check_output_and_recompiles(fn, 2)
self.check_output_and_recompiles(fn, 1)
else:
self.check_output_and_recompiles(
fn, count=[2, 4], compiler_fn=make_compiler_fn(fullgraph=False)
fn, count=[1, 2], compiler_fn=make_compiler_fn(fullgraph=False)
)
@parametrize("is_traceable", (True, False))
@ -2397,10 +2396,10 @@ TORCH_LIBRARY(test_autograd_cpp_node_saved_dynamic_$is_traceable, m) {
# compiles for 10 (static) and 100 (dynamic)
if is_traceable:
self.check_output_and_recompiles(fn, 2)
self.check_output_and_recompiles(fn, 1)
else:
self.check_output_and_recompiles(
fn, count=[2, 4], compiler_fn=make_compiler_fn(fullgraph=False)
fn, count=[1, 2], compiler_fn=make_compiler_fn(fullgraph=False)
)
@parametrize("is_traceable", (True, False))
@ -3186,19 +3185,13 @@ TORCH_LIBRARY(test_cudagraphs_cpu_scalar_used_in_cpp_custom_op, m) {
torch._dynamo.compiled_autograd.__name__, "compiled_autograd_verbose"
)
with ctx():
self.check_output_and_recompiles(fn, count=2)
self.check_output_and_recompiles(fn)
patterns1 = [
r".*Cache miss due to new autograd node: torch::autograd::GraphRoot \(NodeCall 0\) with key size (\d+), "
r"previous key sizes=\[\]\n",
]
# recompile
patterns2 = [
r".*Cache miss due to 7 changed tensor shapes \(total of 7\): ",
r"sizes\[0\], sizes\[1\], sizes\[2\], sizes\[3\], sizes\[4\], sizes\[5\], sizes\[6\]\n",
]
all_logs = logs.getvalue()
pattern1 = r"".join(patterns1)
@ -3209,10 +3202,6 @@ TORCH_LIBRARY(test_cudagraphs_cpu_scalar_used_in_cpp_custom_op, m) {
) # for a single match: matches1=['match'], for multiple matches: matches1=[('match1', 'match2')]...
self.assertEqual(len(matches1), len(patterns1))
pattern2 = r"".join(patterns2)
matches2 = re.findall(pattern2, all_logs)
self.assertEqual(len(matches2), 1)
def test_verbose_logs_dynamic_shapes(self):
logs, ctx = logs_to_string(
torch._dynamo.compiled_autograd.__name__, "compiled_autograd_verbose"
@ -3233,19 +3222,11 @@ TORCH_LIBRARY(test_cudagraphs_cpu_scalar_used_in_cpp_custom_op, m) {
with ctx(), compiled_autograd._enable(torch.compile(backend="eager")):
result.backward()
self.assertEqual(counters["compiled_autograd"]["captures"], 3)
self.assertEqual(counters["compiled_autograd"]["captures"], 1)
actual_logs = logs.getvalue()
expected_logs = [
"Cache miss due to new autograd node: torch::autograd::GraphRoot (NodeCall 0) with key size 39, previous key sizes=[]",
(
"Cache miss due to 7 changed tensor shapes (total of 14): "
"sizes[0], sizes[1], sizes[2], sizes[3], sizes[4], sizes[5], sizes[6]"
),
(
"Cache miss due to 7 changed tensor shapes (total of 14): "
"sizes[0], sizes[1], sizes[2], sizes[3], sizes[4], sizes[5], sizes[6]"
),
]
for expected in expected_logs:
self.assertTrue(expected in actual_logs)
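
The removed patterns2 block and the two removed expected_logs entries both covered the recompile path: a second backward with new tensor shapes used to log "Cache miss due to 7 changed tensor shapes". With the first capture already traced with dynamic shapes there is no such miss, so only the initial GraphRoot cache-miss line remains and captures drops from 3 to 1. A hedged sketch of surfacing these verbose lines outside the test harness (the tests use the logs_to_string helper; compiled_autograd_verbose as a set_logs artifact is an assumption, equivalent to TORCH_LOGS="compiled_autograd_verbose"):

    import torch
    from torch._dynamo import compiled_autograd

    torch._logging.set_logs(compiled_autograd_verbose=True)  # assumed artifact name

    x = torch.randn(4, 4, requires_grad=True)
    with compiled_autograd._enable(torch.compile(backend="eager")):
        x.sum().backward()  # cache-miss lines (GraphRoot, key sizes, ...) go to the log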
@ -3356,17 +3337,37 @@ class CompiledAutograd0(torch.nn.Module):
getitem_2 = inputs[2]
getitem_3 = inputs[3]
getitem_4 = inputs[4]; inputs = None
getitem_5 = sizes[0]
getitem_6 = sizes[1]
getitem_7 = sizes[2]
getitem_8 = sizes[3]
getitem_9 = sizes[4]; getitem_9 = None
getitem_10 = sizes[5]; getitem_10 = None
getitem_11 = sizes[6]; getitem_11 = None
getitem_12 = sizes[7]; getitem_12 = None
getitem_13 = sizes[8]; getitem_13 = None
getitem_14 = sizes[9]; getitem_14 = None
getitem_15 = sizes[10]; getitem_15 = None
getitem_16 = sizes[11]; getitem_16 = None
getitem_17 = sizes[12]; getitem_17 = None
getitem_18 = sizes[13]; getitem_18 = None
getitem_19 = sizes[14]; getitem_19 = None
getitem_20 = sizes[15]; getitem_20 = None
getitem_21 = sizes[16]
getitem_22 = sizes[17]
getitem_23 = sizes[18]
getitem_24 = sizes[19]; sizes = None
validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [], True)]); getitem = None
getitem_5 = validate_outputs[0]; validate_outputs = None
getitem_25 = validate_outputs[0]; validate_outputs = None
sum_backward0 = torch__dynamo_compiled_autograd_ops_SumBackward0([getitem_5], [True], [4, 4]); getitem_5 = None
getitem_6 = sum_backward0[0]; sum_backward0 = None
validate_outputs_1 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_6], [((None, None, device(type='cpu'), 6, 0, None), [4, 4], True)]); getitem_6 = None
getitem_7 = validate_outputs_1[0]; validate_outputs_1 = None
sum_backward0 = torch__dynamo_compiled_autograd_ops_SumBackward0([getitem_25], [True], [getitem_5, getitem_6]); getitem_25 = getitem_5 = getitem_6 = None
getitem_26 = sum_backward0[0]; sum_backward0 = None
validate_outputs_1 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_26], [((None, None, device(type='cpu'), 6, 0, None), [getitem_7, getitem_8], True)]); getitem_26 = getitem_7 = getitem_8 = None
getitem_27 = validate_outputs_1[0]; validate_outputs_1 = None
getitem_8 = hooks[0]; getitem_8 = None
call_aot_bwd_prologue = torch__dynamo_compiled_autograd_call_aot_bwd_prologue((getitem_1, getitem_2), [], getitem_7); getitem_1 = getitem_2 = getitem_7 = None
getitem_28 = hooks[0]; getitem_28 = None
call_aot_bwd_prologue = torch__dynamo_compiled_autograd_call_aot_bwd_prologue((getitem_1, getitem_2), [], getitem_27); getitem_1 = getitem_2 = getitem_27 = None
aot0_primals_1 = call_aot_bwd_prologue[0]
aot0_primals_2 = call_aot_bwd_prologue[1]
aot0_tangents_1 = call_aot_bwd_prologue[2]
@ -3380,18 +3381,18 @@ class CompiledAutograd0(torch.nn.Module):
make_subclass = torch__dynamo_compiled_autograd_make_subclass(aot0_add_2, aot0_add_3); aot0_add_2 = aot0_add_3 = None
getitem_13 = hooks[1]; hooks = None
call_backward = torch__dynamo_external_utils_call_backward(getitem_13, (), make_subclass); getitem_13 = make_subclass = None
getitem_16 = call_backward[0]
getitem_17 = call_backward[1]; call_backward = None
validate_outputs_2 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_16, getitem_17], [((None, None, device(type='cpu'), 6, 0, None), [4, 4], False), ((None, None, device(type='cpu'), 6, 0, None), [4, 4], False)]); getitem_16 = getitem_17 = None
getitem_19 = validate_outputs_2[0]
getitem_33 = hooks[1]; hooks = None
call_backward = torch__dynamo_external_utils_call_backward(getitem_33, (), make_subclass); getitem_33 = make_subclass = None
getitem_36 = call_backward[0]
getitem_37 = call_backward[1]; call_backward = None
validate_outputs_2 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_36, getitem_37], [((None, None, device(type='cpu'), 6, 0, None), [getitem_21, getitem_22], False), ((None, None, device(type='cpu'), 6, 0, None), [getitem_23, getitem_24], False)]); getitem_36 = getitem_37 = getitem_21 = getitem_22 = getitem_23 = getitem_24 = None
getitem_39 = validate_outputs_2[0]
accumulate_grad__1 = torch.ops.inductor.accumulate_grad_.default(getitem_4, getitem_19); getitem_4 = getitem_19 = accumulate_grad__1 = None
accumulate_grad__1 = torch.ops.inductor.accumulate_grad_.default(getitem_4, getitem_39); getitem_4 = getitem_39 = accumulate_grad__1 = None
getitem_20 = validate_outputs_2[1]; validate_outputs_2 = None
getitem_40 = validate_outputs_2[1]; validate_outputs_2 = None
accumulate_grad_ = torch.ops.inductor.accumulate_grad_.default(getitem_3, getitem_20); getitem_3 = getitem_20 = accumulate_grad_ = None
accumulate_grad_ = torch.ops.inductor.accumulate_grad_.default(getitem_3, getitem_40); getitem_3 = getitem_40 = accumulate_grad_ = None
_exec_final_callbacks_stub = torch__dynamo_external_utils__exec_final_callbacks_stub(); _exec_final_callbacks_stub = None
return []
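
What changed in this expected graph: the sizes input is now unpacked into SymInt getitems up front and threaded through SumBackward0 and both validate_outputs calls in place of the literal [4, 4] shapes, which is what lets one capture serve multiple input sizes; slots the graph does not use are dropped immediately (e.g. getitem_9 = sizes[4]; getitem_9 = None), and the remaining getitem_* names shift to higher indices as a result.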
@ -3589,46 +3590,58 @@ class CompiledAutograd0(torch.nn.Module):
def forward(self, inputs, sizes, scalars, hooks, packed_data):
getitem = inputs[0]
getitem_1 = inputs[1]; inputs = None
getitem_2 = sizes[0]
getitem_3 = sizes[1]
getitem_4 = sizes[2]
getitem_5 = sizes[3]
getitem_6 = sizes[4]
getitem_7 = sizes[5]
getitem_8 = sizes[6]
getitem_9 = sizes[7]
getitem_10 = sizes[8]
getitem_11 = sizes[9]
getitem_12 = sizes[10]
getitem_13 = sizes[11]; sizes = None
validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [], False)]); getitem = None
getitem_2 = validate_outputs[0]; validate_outputs = None
getitem_14 = validate_outputs[0]; validate_outputs = None
sum_backward0 = torch__dynamo_compiled_autograd_ops_SumBackward0([getitem_2], [True], [10, 10]); getitem_2 = None
getitem_3 = sum_backward0[0]; sum_backward0 = None
validate_outputs_1 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_3], [((None, None, device(type='cpu'), 6, 0, None), [10, 10], False)]); getitem_3 = None
getitem_4 = validate_outputs_1[0]; validate_outputs_1 = None
sum_backward0 = torch__dynamo_compiled_autograd_ops_SumBackward0([getitem_14], [True], [getitem_2, getitem_3]); getitem_14 = getitem_2 = getitem_3 = None
getitem_15 = sum_backward0[0]; sum_backward0 = None
validate_outputs_1 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_15], [((None, None, device(type='cpu'), 6, 0, None), [getitem_4, getitem_5], False)]); getitem_15 = getitem_4 = getitem_5 = None
getitem_16 = validate_outputs_1[0]; validate_outputs_1 = None
getitem_5 = hooks[0]
getitem_6 = packed_data[0]
getitem_7 = hooks[1]
getitem_8 = packed_data[1]
call_hook = torch__dynamo_external_utils_call_hook(getitem_5, getitem_6, hook_type = 'unpack_hook'); getitem_5 = getitem_6 = None
call_hook_1 = torch__dynamo_external_utils_call_hook(getitem_7, getitem_8, hook_type = 'unpack_hook'); getitem_7 = getitem_8 = None
mul_backward0 = torch__dynamo_compiled_autograd_ops_MulBackward0([getitem_4], [True, True], call_hook, 6, call_hook_1, 6); getitem_4 = call_hook = call_hook_1 = None
getitem_9 = mul_backward0[0]
getitem_10 = mul_backward0[1]; mul_backward0 = None
validate_outputs_2 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_9, getitem_10], [((None, None, device(type='cpu'), 6, 0, None), [10, 10], False), ((None, None, device(type='cpu'), 6, 0, None), [10, 10], False)]); getitem_9 = getitem_10 = None
getitem_11 = validate_outputs_2[0]
getitem_12 = validate_outputs_2[1]; validate_outputs_2 = None
getitem_17 = hooks[0]
getitem_18 = packed_data[0]
getitem_19 = hooks[1]
getitem_20 = packed_data[1]
call_hook = torch__dynamo_external_utils_call_hook(getitem_17, getitem_18, hook_type = 'unpack_hook'); getitem_17 = getitem_18 = None
call_hook_1 = torch__dynamo_external_utils_call_hook(getitem_19, getitem_20, hook_type = 'unpack_hook'); getitem_19 = getitem_20 = None
mul_backward0 = torch__dynamo_compiled_autograd_ops_MulBackward0([getitem_16], [True, True], call_hook, 6, call_hook_1, 6); getitem_16 = call_hook = call_hook_1 = None
getitem_21 = mul_backward0[0]
getitem_22 = mul_backward0[1]; mul_backward0 = None
validate_outputs_2 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_21, getitem_22], [((None, None, device(type='cpu'), 6, 0, None), [getitem_6, getitem_7], False), ((None, None, device(type='cpu'), 6, 0, None), [getitem_8, getitem_9], False)]); getitem_21 = getitem_22 = getitem_6 = getitem_7 = getitem_8 = getitem_9 = None
getitem_23 = validate_outputs_2[0]
getitem_24 = validate_outputs_2[1]; validate_outputs_2 = None
getitem_13 = hooks[2]
getitem_14 = packed_data[2]
call_hook_2 = torch__dynamo_external_utils_call_hook(getitem_13, getitem_14, hook_type = 'unpack_hook'); getitem_13 = getitem_14 = None
cos_backward0 = torch__dynamo_compiled_autograd_ops_CosBackward0([getitem_12], [True], call_hook_2); getitem_12 = call_hook_2 = None
getitem_15 = cos_backward0[0]; cos_backward0 = None
validate_outputs_3 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_15], [((None, None, device(type='cpu'), 6, 0, None), [10, 10], False)]); getitem_15 = None
getitem_16 = validate_outputs_3[0]; validate_outputs_3 = None
add = torch.add(getitem_11, getitem_16); getitem_11 = getitem_16 = None
getitem_25 = hooks[2]
getitem_26 = packed_data[2]
call_hook_2 = torch__dynamo_external_utils_call_hook(getitem_25, getitem_26, hook_type = 'unpack_hook'); getitem_25 = getitem_26 = None
cos_backward0 = torch__dynamo_compiled_autograd_ops_CosBackward0([getitem_24], [True], call_hook_2); getitem_24 = call_hook_2 = None
getitem_27 = cos_backward0[0]; cos_backward0 = None
validate_outputs_3 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_27], [((None, None, device(type='cpu'), 6, 0, None), [getitem_10, getitem_11], False)]); getitem_27 = getitem_10 = getitem_11 = None
getitem_28 = validate_outputs_3[0]; validate_outputs_3 = None
add = torch.add(getitem_23, getitem_28); getitem_23 = getitem_28 = None
getitem_17 = hooks[3]; hooks = None
getitem_18 = packed_data[3]; packed_data = None
call_hook_3 = torch__dynamo_external_utils_call_hook(getitem_17, getitem_18, hook_type = 'unpack_hook'); getitem_17 = getitem_18 = None
getitem_29 = hooks[3]; hooks = None
getitem_30 = packed_data[3]; packed_data = None
call_hook_3 = torch__dynamo_external_utils_call_hook(getitem_29, getitem_30, hook_type = 'unpack_hook'); getitem_29 = getitem_30 = None
sin_backward0 = torch__dynamo_compiled_autograd_ops_SinBackward0([add], [True], call_hook_3); add = call_hook_3 = None
getitem_19 = sin_backward0[0]; sin_backward0 = None
validate_outputs_4 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_19], [((None, None, device(type='cpu'), 6, 0, None), [10, 10], False)]); getitem_19 = None
getitem_20 = validate_outputs_4[0]; validate_outputs_4 = None
getitem_31 = sin_backward0[0]; sin_backward0 = None
validate_outputs_4 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_31], [((None, None, device(type='cpu'), 6, 0, None), [getitem_12, getitem_13], False)]); getitem_31 = getitem_12 = getitem_13 = None
getitem_32 = validate_outputs_4[0]; validate_outputs_4 = None
accumulate_grad_ = torch.ops.inductor.accumulate_grad_.default(getitem_1, getitem_20); getitem_1 = getitem_20 = accumulate_grad_ = None
accumulate_grad_ = torch.ops.inductor.accumulate_grad_.default(getitem_1, getitem_32); getitem_1 = getitem_32 = accumulate_grad_ = None
_exec_final_callbacks_stub = torch__dynamo_external_utils__exec_final_callbacks_stub(); _exec_final_callbacks_stub = None
return []
""", # noqa: B950
@ -3715,7 +3728,7 @@ class CompiledAutograd1(torch.nn.Module):
)
self.check_output_and_recompiles(
fn, count=2, compiler_fn=make_compiler_fn(gm_hook=check)
fn, compiler_fn=make_compiler_fn(gm_hook=check)
)
@skipIfWindows(msg="temp dir not compatible")
@ -3797,7 +3810,7 @@ class CompiledAutograd1(torch.nn.Module):
# 1 graph break on torch.load -> 2 dynamo graphs
self.check_output_and_recompiles(
fn,
count=[2, 4],
count=[1, 2],
compiler_fn=make_compiler_fn(fullgraph=False, gm_hook=check),
)

View File

@ -502,15 +502,24 @@ class AutogradCompilerInstance:
self.bind_objects_to_proxies(grad_ins, proxies)
return tuple(grad_ins)
def call_copy_slices_prologue(self, inputs, base, view):
def call_copy_slices_prologue(
self,
inputs,
base_sizes,
base_strides,
base_storage_offset,
view_sizes,
view_strides,
view_storage_offset,
):
args = (
inputs,
base.sizes(),
base.strides(),
base.storage_offset(),
view.sizes(),
view.strides(),
view.storage_offset(),
self.to_proxy(base_sizes),
self.to_proxy(base_strides),
self.to_proxy(base_storage_offset),
self.to_proxy(view_sizes),
self.to_proxy(view_strides),
self.to_proxy(view_storage_offset),
)
return self.proxy_call(copy_slices_prologue, args, [None] * 3)
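
With this signature change the prologue receives the base and view geometry as separate sizes/strides/storage_offset arguments (SymInts on the C++ side, see sym_sizes() below), so each component can be routed through self.to_proxy and stay symbolic in the captured graph rather than baking in the concrete values held by a TensorGeometry. A small, hedged illustration of why that triple suffices to describe the view (plain ints here; under compiled autograd they would be SymInts):

    import torch

    base = torch.randn(4, 4)
    view = base[1:3, ::2]
    # A view is fully determined by (sizes, strides, storage_offset) over its base.
    rebuilt = base.as_strided(view.size(), view.stride(), view.storage_offset())
    assert torch.equal(rebuilt, view)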

View File

@ -270,9 +270,9 @@ variable_list CopySlices::apply_with_saved(
TORCH_INTERNAL_ASSERT(stuff.size() == 3);
// These variables are named the same as in CopySlices::apply_impl.
// Follow along there.
auto result = stuff[0];
auto grad_slice = stuff[1];
auto grad_slice_clone = stuff[2];
const auto& result = stuff[0];
const auto& grad_slice = stuff[1];
const auto& grad_slice_clone = stuff[2];
auto res = fn->apply_with_saved({grad_slice_clone}, saved);
results = interface->call_copy_slices_epilogue(
saved.get_py_compiler(), needs_input_grad, result, res, grad_slice);

View File

@ -39,7 +39,7 @@ struct TORCH_API PyCompilerInterface {
// NOLINTNEXTLINE(performance-unnecessary-value-param)
std::vector<at::TypePtr> packed_args_schema,
bool is_custom_function = false,
bool is_traceable = true) {
bool is_traceable = true) const {
TORCH_INTERNAL_ASSERT(false, "Needs to be overridden");
}
@ -51,14 +51,14 @@ struct TORCH_API PyCompilerInterface {
const std::string& fn_name,
const variable_list& inputs,
const ivalue_list& packed_args,
const c10::IValue& output_metadata) {
const c10::IValue& output_metadata) const {
TORCH_INTERNAL_ASSERT(false, "Needs to be overridden");
}
virtual variable_list call_copy_slices_prologue(
PyObject* py_compiler,
const variable_list& inputs,
const at::TensorGeometry& base,
const at::TensorGeometry& view) {
const at::TensorGeometry& view) const {
TORCH_INTERNAL_ASSERT(false, "Needs to be overridden");
}
virtual variable_list call_copy_slices_epilogue(
@ -66,13 +66,13 @@ struct TORCH_API PyCompilerInterface {
const std::vector<bool>& needs_input_grad,
const at::Tensor& result,
const variable_list& res,
const at::Tensor& grad_slice) {
const at::Tensor& grad_slice) const {
TORCH_INTERNAL_ASSERT(false, "Needs to be overridden");
}
virtual at::Tensor call_unpack(
PyObject* py_compiler,
std::optional<size_t> hook_id,
size_t hook_input_id) {
size_t hook_input_id) const {
TORCH_INTERNAL_ASSERT(false, "Needs to be overridden");
}
};

View File

@ -163,7 +163,7 @@ struct PyCompilerInterfaceImpl : PyCompilerInterface {
functional_apply_t fn,
std::vector<at::TypePtr> packed_args_schema,
bool is_custom_function = false,
bool is_traceable = true) override {
bool is_traceable = true) const override {
return torch::dynamo::autograd::bind_function(
py_compiler,
fn_name,
@ -178,7 +178,7 @@ struct PyCompilerInterfaceImpl : PyCompilerInterface {
const std::string& fn_name,
const variable_list& inputs,
const ivalue_list& packed_args,
const c10::IValue& output_metadata) override {
const c10::IValue& output_metadata) const override {
return torch::dynamo::autograd::call_function(
py_compiler,
method_name,
@ -191,10 +191,16 @@ struct PyCompilerInterfaceImpl : PyCompilerInterface {
PyObject* py_compiler,
const variable_list& inputs,
const at::TensorGeometry& base,
const at::TensorGeometry& view) override {
const at::TensorGeometry& view) const override {
py::handle handle(py_compiler);
py::object stuff =
handle.attr("call_copy_slices_prologue")(inputs, base, view);
py::object stuff = handle.attr("call_copy_slices_prologue")(
inputs,
base.sym_sizes(),
base.sym_strides(),
base.sym_storage_offset(),
view.sym_sizes(),
view.sym_strides(),
view.sym_storage_offset());
return py::cast<std::vector<at::Tensor>>(stuff);
}
variable_list call_copy_slices_epilogue(
@ -202,7 +208,7 @@ struct PyCompilerInterfaceImpl : PyCompilerInterface {
const std::vector<bool>& needs_input_grad,
const at::Tensor& result,
const variable_list& res,
const at::Tensor& grad_slice) override {
const at::Tensor& grad_slice) const override {
py::handle handle(py_compiler);
py::object stuff = handle.attr("call_copy_slices_epilogue")(
needs_input_grad, result, res, grad_slice);
@ -212,7 +218,7 @@ struct PyCompilerInterfaceImpl : PyCompilerInterface {
at::Tensor call_unpack(
PyObject* py_compiler,
std::optional<size_t> hook_id,
size_t hook_input_id) override {
size_t hook_input_id) const override {
py::handle handle(py_compiler);
py::object proxy = handle.attr("unpack_hook")(hook_id, hook_input_id);
auto tmp = py::cast<std::optional<at::Tensor>>(proxy);