Update gm.print_readable to include Annotation (#165397)

Sample output
```
[rank0]:        # Annotation: {'compile_with_inductor': 'flex_attention'} File: /data/users/bahuang/pytorch/torch/nn/attention/flex_attention.py:1490 in flex_attention, code: out, lse, max_scores = flex_attention_hop(
[rank0]:        score_mod_2 = self.score_mod_2
[rank0]:        mask_fn_2 = self.mask_fn_2
[rank0]:        flex_attention_1 = torch.ops.higher_order.flex_attention(xq_5, xk_5, xv_3, score_mod_2, (2048, 2048, g____import_torchtitan_dot_models_dot_attention___flex_attention_block_masks___block_causal___none___kv_num_blocks, g____import_torchtitan_dot_models_dot_attention___flex_attention_block_masks___block_causal___none___kv_indices, g____import_torchtitan_dot_models_dot_attention___flex_attention_block_masks___block_causal___none___full_kv_num_blocks, g____import_torchtitan_dot_models_dot_attention___flex_attention_block_masks___block_causal___none___full_kv_indices, g____import_torchtitan_dot_models_dot_attention___flex_attention_block_masks___block_causal___none___q_num_blocks, g____import_torchtitan_dot_models_dot_attention___flex_attention_block_masks___block_causal___none___q_indices, g____import_torchtitan_dot_models_dot_attention___flex_attention_block_masks___block_causal___none___full_q_num_blocks, g____import_torchtitan_dot_models_dot_attention___flex_attention_block_masks___block_causal___none___full_q_indices, 128, 128, mask_fn_2), 0.25, {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (g____import_torchtitan_dot_models_dot_attention___flex_attention_block_masks___block_causal___none___mask_mod___closure___0_cell_contents,));  xq_5 = xk_5 = xv_3 = score_mod_2 = mask_fn_2 = None
[rank0]:        out_2: "bf16[8, 4, 2048, 16]" = flex_attention_1[0];  flex_attention_1 = None
```
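
The `# Annotation: {...}` prefix is taken from each node's `node.meta["custom"]` dictionary and is now rendered ahead of the stack-trace summary. A minimal sketch of how such metadata surfaces in the output (the module and the metadata value below are made up for illustration; `node.meta["custom"]` and `print_readable` are the actual hooks):

```python
import torch
import torch.fx as fx


class M(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1


gm = fx.symbolic_trace(M())

# print_readable reads node.meta["custom"] when it builds the per-node header comment.
for node in gm.graph.nodes:
    if node.op == "call_function":
        node.meta["custom"] = {"compile_with_inductor": "flex_attention"}

gm.print_readable()
# Annotated nodes are preceded by a dim header comment along the lines of:
#   # Annotation: {'compile_with_inductor': 'flex_attention'} No stacktrace found for following nodes
# (the stack-trace summary replaces the fallback text when node.stack_trace is populated)
```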

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165397
Approved by: https://github.com/yushangdi, https://github.com/anijain2305
Authored by Sherlock Huang on 2025-10-16 20:37:07 -07:00; committed by PyTorch MergeBot
parent e4454947e2
commit 7a65770013
7 changed files with 30 additions and 63 deletions

View File

@@ -3802,7 +3802,6 @@ class GraphModule(torch.nn.Module):
dual: "f32[4, 3, 4, 3]" = _unpack_dual[1]; _unpack_dual = None
primals_out_unflatten: "f32[4, 3, 4, 3]" = torch._C._functorch._unwrap_for_grad(primal, 2); primal = primals_out_unflatten = None
tangents_out_unflatten: "f32[4, 3, 4, 3]" = torch._C._functorch._unwrap_for_grad(dual, 2); dual = None
_exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None
@@ -3933,7 +3932,6 @@ class GraphModule(torch.nn.Module):
tangent: "f32[4, 3, 3, 4]" = torch.zeros_like(primal)
child_8: "f32[4, 3, 3, 4]" = torch._C._functorch._unwrap_for_grad(primal, 2); primal = child_8 = None
child_9: "f32[4, 3, 3, 4]" = torch._C._functorch._unwrap_for_grad(tangent, 2); tangent = None
_exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None
@@ -4146,7 +4144,6 @@ class GraphModule(torch.nn.Module):
primals_out: "f32[3, 4]" = diff_primals.sin()
aux_1: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(aux, 1); aux = None
results: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(primals_out, 1)
_grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None
@@ -4381,7 +4378,6 @@ class GraphModule(torch.nn.Module):
primals_out: "f32[]" = sin.sum(); sin = None
aux: "f32[5]" = torch._C._functorch._unwrap_for_grad(child, 1); child = aux = None
results: "f32[]" = torch._C._functorch._unwrap_for_grad(primals_out, 1); primals_out = None
_grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None
@@ -4571,7 +4567,6 @@ class GraphModule(torch.nn.Module):
grad_input: "f32[3, 3, 3]" = _autograd_grad[0]; _autograd_grad = None
grad_input_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None
output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = output_1 = None
_grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None
@@ -4639,7 +4634,6 @@ class GraphModule(torch.nn.Module):
grad_input: "f32[3, 3, 3]" = _autograd_grad[0]; _autograd_grad = None
grad_input_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None
output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = output_1 = None
_grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None
@@ -4696,7 +4690,6 @@ class GraphModule(torch.nn.Module):
grad_input: "f32[3, 3, 3]" = _autograd_grad[0]; _autograd_grad = None
grad_input_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None
output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = output_1 = None
_grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None
@@ -4753,7 +4746,6 @@ class GraphModule(torch.nn.Module):
grad_input: "f32[3, 3, 3]" = _autograd_grad[0]; _autograd_grad = None
grad_input_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None
output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = output_1 = None
_grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None
@@ -4808,9 +4800,7 @@ class GraphModule(torch.nn.Module):
grad_input: "f32[3, 3, 3]" = _autograd_grad[0]; _autograd_grad = None
grad_input_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None
output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = output_1 = None
aux_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(aux, 1); aux = None
_grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None
@@ -4866,9 +4856,7 @@ class GraphModule(torch.nn.Module):
grad_input: "f32[3, 3, 3]" = _autograd_grad[0]; _autograd_grad = None
grad_input_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None
output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = output_1 = None
aux_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(aux, 1); aux = None
_grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None
@@ -4942,9 +4930,7 @@ class GraphModule(torch.nn.Module):
_unwrap_for_grad: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(child_2, 1); child_2 = None
_unwrap_for_grad_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(child_3, 1); child_3 = None
output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = output_1 = None
aux_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(aux, 1); aux = None
_grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None
@@ -4988,9 +4974,7 @@ class GraphModule(torch.nn.Module):
_unwrap_for_grad: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(child_2, 1); child_2 = None
_unwrap_for_grad_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(child_3, 1); child_3 = None
output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = output_1 = None
aux_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(aux, 1); aux = None
_grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None
@@ -5050,7 +5034,6 @@ class GraphModule(torch.nn.Module):
grad_input: "f32[]" = _autograd_grad[0]; _autograd_grad = None
grad_input_1: "f32[]" = torch._C._functorch._unwrap_for_grad(grad_input, 2); grad_input = None
output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 2); output = output_1 = None
_grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None
@@ -5060,7 +5043,6 @@ class GraphModule(torch.nn.Module):
grad_input_2: "f32[]" = _autograd_grad_1[0]; _autograd_grad_1 = None
grad_input_3: "f32[]" = torch._C._functorch._unwrap_for_grad(grad_input_2, 1); grad_input_2 = None
output_2: "f32[]" = torch._C._functorch._unwrap_for_grad(grad_input_1, 1); grad_input_1 = output_2 = None
_grad_decrement_nesting_1 = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting_1 = None
@@ -5166,7 +5148,6 @@ class GraphModule(torch.nn.Module):
grad_input: "f32[3, 3, 3]" = _autograd_grad[0]; _autograd_grad = None
grad_input_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None
output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = output_1 = None
_grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None
@@ -5245,7 +5226,6 @@ class GraphModule(torch.nn.Module):
dual: "f32[4, 3]" = _unpack_dual[1]; _unpack_dual = None
primals_out_unflatten: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(primal, 2); primal = primals_out_unflatten = None
tangents_out_unflatten: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(dual, 2); dual = None
_exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None
@@ -5327,7 +5307,6 @@ class GraphModule(torch.nn.Module):
dual: "f32[3, 4]" = _unpack_dual[1]; _unpack_dual = None
primals_out_unflatten: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(primal, 2); primal = primals_out_unflatten = None
tangents_out_unflatten: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(dual, 2); dual = None
_exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None
@@ -5411,7 +5390,6 @@ class GraphModule(torch.nn.Module):
dual: "f32[3, 4]" = _unpack_dual[1]; _unpack_dual = None
primals_out_unflatten: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(primal, 2); primal = primals_out_unflatten = None
tangents_out_unflatten: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(dual, 2); dual = None
_exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None
@@ -5502,7 +5480,6 @@ class GraphModule(torch.nn.Module):
child_4: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(primal, 2); primal = child_4 = None
child_5: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(primal_1, 2); primal_1 = child_5 = None
child_6: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(tangent, 2); tangent = None
child_7: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(dual, 2); dual = None
@@ -5572,7 +5549,6 @@ class GraphModule(torch.nn.Module):
dual: "f32[]" = _unpack_dual[1]; _unpack_dual = None
primals_out_unflatten: "f32[]" = torch._C._functorch._unwrap_for_grad(primal, 1); primal = None
tangents_out_unflatten: "f32[]" = torch._C._functorch._unwrap_for_grad(dual, 1); dual = None
_exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None
@@ -5626,7 +5602,6 @@ class GraphModule(torch.nn.Module):
dual: "f32[]" = _unpack_dual[1]; _unpack_dual = None
primals_out_unflatten: "f32[]" = torch._C._functorch._unwrap_for_grad(primal, 1); primal = None
tangents_out_unflatten: "f32[]" = torch._C._functorch._unwrap_for_grad(dual, 1); dual = None
_exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None
@@ -5688,7 +5663,6 @@ class GraphModule(torch.nn.Module):
dual: "f32[3, 3]" = _unpack_dual[1]; _unpack_dual = None
primals_out_unflatten: "f32[3, 3]" = torch._C._functorch._unwrap_for_grad(primal, 1); primal = None
tangents_out_unflatten: "f32[3, 3]" = torch._C._functorch._unwrap_for_grad(dual, 1); dual = None
_exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None
@@ -5742,7 +5716,6 @@ class GraphModule(torch.nn.Module):
dual: "f32[]" = _unpack_dual[1]; _unpack_dual = None
primals_out_unflatten: "f32[]" = torch._C._functorch._unwrap_for_grad(primal, 1); primal = None
tangents_out_unflatten: "f32[]" = torch._C._functorch._unwrap_for_grad(dual, 1); dual = None
_exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None
@@ -5810,7 +5783,6 @@ class GraphModule(torch.nn.Module):
dual: "f32[]" = _unpack_dual[1]; _unpack_dual = None
primals_out_unflatten: "f32[]" = torch._C._functorch._unwrap_for_grad(primal, 1); primal = None
tangents_out_unflatten: "f32[]" = torch._C._functorch._unwrap_for_grad(dual, 1); dual = None
_exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None
@@ -5887,7 +5859,6 @@ class GraphModule(torch.nn.Module):
dual: "f32[3, 3, 3]" = _unpack_dual[1]; _unpack_dual = None
primals_out_unflatten: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(primal, 2); primal = None
tangents_out_unflatten: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(dual, 2); dual = None
_set_fwd_grad_enabled_2 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_2 = None
@@ -5902,7 +5873,6 @@ class GraphModule(torch.nn.Module):
_unwrap_for_grad_2: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(primal_1, 1); primal_1 = None
_unwrap_for_grad_3: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(primal_2, 1); primal_2 = None
_unwrap_for_grad_4: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(dual_1, 1); dual_1 = None
_unwrap_for_grad_5: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(dual_2, 1); dual_2 = None

View File

@@ -3166,7 +3166,6 @@ class GraphModule(torch.nn.Module):
):
slice_1: "f64[s64, s55]" = torch.ops.aten.slice.Tensor(tangents_1, 1, 0, primals_10)
slice_2: "f64[s64, s55]" = torch.ops.aten.slice.Tensor(tangents_1, 1, primals_10, add_2); tangents_1 = add_2 = None
add_4: "f64[s64, s55]" = torch.ops.aten.add.Tensor(slice_1, slice_2); slice_1 = slice_2 = None
return (
None, # None

View File

@@ -16061,6 +16061,7 @@ class GraphModule(torch.nn.Module):
add: "f32[2, 4]" = torch.ops.aten.add.Tensor(relu, arg1_1); relu = arg1_1 = None
return (add,)
""",
ignore_empty_lines=True,
)
ep = export(M(), (x, y), strict=strict).run_decompositions({})
@@ -16093,6 +16094,7 @@ class GraphModule(torch.nn.Module):
add: "f32[2, 4]" = torch.ops.aten.add.Tensor(relu, arg1_1); relu = arg1_1 = None
return (add,)
""",
ignore_empty_lines=True,
)
@testing.expectedFailureStrict # test_hop doesn't have a dynamo implementation

View File

@@ -8104,7 +8104,6 @@ class GraphModule(torch.nn.Module):
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
_guards_fn = self._guards_fn(x); _guards_fn = None
sym_size_int_1: "Sym(s77)" = torch.ops.aten.sym_size.int(x, 0)
while_loop_cond_graph_0 = self.while_loop_cond_graph_0
@@ -8404,7 +8403,6 @@ class GraphModule(torch.nn.Module):
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
_guards_fn = self._guards_fn(x); _guards_fn = None
sym_size_int_1: "Sym(s6)" = torch.ops.aten.sym_size.int(x, 0)
sin: "f32[s6, 3]" = torch.ops.aten.sin.default(x); x = None
@@ -8691,10 +8689,8 @@ class GraphModule(torch.nn.Module):
t_4: "f32[3, 3]" = torch.ops.aten.t.default(t_3); t_3 = None
mul_4: "f32[3, 3]" = torch.ops.aten.mul.Tensor(arg1_1, select)
mul_5: "f32[3, 3]" = torch.ops.aten.mul.Tensor(arg1_1, select); arg1_1 = select = None
add_7: "f32[3, 3]" = torch.ops.aten.add.Tensor(mm, mul_5); mm = mul_5 = None
add_8: "f32[3, 3]" = torch.ops.aten.add.Tensor(add_7, mul_4); add_7 = mul_4 = None
add_9: "i64[]" = torch.ops.aten.add.Tensor(arg0_1, 1); arg0_1 = None
add_10: "f32[3]" = torch.ops.aten.add.Tensor(view, arg2_1); view = arg2_1 = None
add_11: "f32[3, 3]" = torch.ops.aten.add.Tensor(t_4, arg3_1); t_4 = arg3_1 = None
@@ -8909,7 +8905,6 @@ class GraphModule(torch.nn.Module):
x, y, z, = fx_pytree.tree_flatten_spec(([x, y, z], {}), self._in_spec)
_guards_fn = self._guards_fn(x, y, z); _guards_fn = None
sym_size_int_4: "Sym(s17)" = torch.ops.aten.sym_size.int(y, 0); y = None
sym_size_int_5: "Sym(s68)" = torch.ops.aten.sym_size.int(z, 0)

View File

@@ -17,6 +17,7 @@ from functorch.compile import aot_function, nop
from torch._dynamo.testing import (
AotEagerAndRecordGraphs,
EagerAndRecordGraphs,
empty_line_normalizer,
InductorAndRecordGraphs,
normalize_gm,
)
@@ -351,10 +352,8 @@ class GraphModule(torch.nn.Module):
getitem_14: "f32[8]" = invoke_subgraph_6[2]
getitem_13: "f32[8]" = invoke_subgraph_6[1]
getitem_1: "f32[8]" = invoke_subgraph_6[0]; invoke_subgraph_6 = None
add: "f32[8]" = torch.ops.aten.add.Tensor(getitem, getitem_1); getitem = getitem_1 = None
return (add, getitem_12, getitem_11, getitem_10, getitem_15, getitem_14, getitem_13)
class partitioned_fw_subgraph_0_0(torch.nn.Module):
def forward(self, primals_0: "f32[8]", primals_1: "f32[8]", primals_2: "f32[8]"):
mul: "f32[8]" = torch.ops.aten.mul.Tensor(primals_0, primals_1)
@@ -363,6 +362,7 @@ class GraphModule(torch.nn.Module):
mul_2: "f32[8]" = torch.ops.aten.mul.Tensor(mul_1, primals_2); mul_1 = None
return (mul_2, primals_0, primals_1, primals_2)
""",
ignore_empty_lines=True,
)
self.assertExpectedInline(
normalize_gm(backend.bw_graphs[0].print_readable(print_output=False)),
@@ -377,7 +377,6 @@ class GraphModule(torch.nn.Module):
invoke_subgraph_5 = torch.ops.higher_order.invoke_subgraph(partitioned_bw_subgraph_0_0, 'partitioned_bw_subgraph_0_0', getitem_10, getitem_11, getitem_12, tangents_1); partitioned_bw_subgraph_0_0 = getitem_10 = getitem_11 = getitem_12 = tangents_1 = None
getitem_6: "f32[8]" = invoke_subgraph_5[0]
getitem_7: "f32[8]" = invoke_subgraph_5[1]; invoke_subgraph_5 = None
add_1: "f32[8]" = torch.ops.aten.add.Tensor(getitem_2, getitem_6); getitem_2 = getitem_6 = None
add_2: "f32[8]" = torch.ops.aten.add.Tensor(getitem_3, getitem_7); getitem_3 = getitem_7 = None
return (add_1, add_2, None)
@@ -393,6 +392,7 @@ class GraphModule(torch.nn.Module):
mul_7: "f32[8]" = torch.ops.aten.mul.Tensor(mul_5, primals_1); mul_5 = primals_1 = None
return (mul_7, mul_6, None)
""",
ignore_empty_lines=True,
)
def test_buffer_mutation_works_under_no_grad(self):
@@ -681,6 +681,7 @@ class GraphModule(torch.nn.Module):
sin: "f32[8]" = torch.ops.aten.sin.default(primals_0)
return (sin, primals_0)
""",
ignore_empty_lines=True,
)
@inductor_config.patch("fx_graph_cache", False)
@@ -722,6 +723,7 @@ class <lambda>(torch.nn.Module):
mul_1: "f32[8]" = torch.ops.aten.mul.Tensor(mul, 2.0); mul = None
return (mul_1,)
""",
ignore_empty_lines=True,
)
def test_dedupe(self):
@@ -770,7 +772,6 @@ class GraphModule(torch.nn.Module):
subgraph_0 = self.subgraph_0
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_); subgraph_0 = l_x_ = None
a: "f32[8]" = invoke_subgraph[0]; invoke_subgraph = None
subgraph_1 = self.subgraph_0
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_1, 'subgraph_0', a, l_y_); subgraph_1 = a = l_y_ = None
getitem_1: "f32[8]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
@@ -806,6 +807,7 @@ class GraphModule(torch.nn.Module):
mul: "f32[8]" = torch.ops.aten.mul.Tensor(primals_0, primals_1)
return (mul, primals_0, primals_1)
""",
ignore_empty_lines=True,
)
def test_dce(self):
@@ -889,7 +891,6 @@ class GraphModule(torch.nn.Module):
subgraph_0 = self.subgraph_0
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_); subgraph_0 = l_x_ = None
a: "f32[8]" = invoke_subgraph[0]; invoke_subgraph = None
subgraph_1 = self.subgraph_1
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_1, 'subgraph_1', a, l_y_); subgraph_1 = a = l_y_ = None
getitem_1: "f32[8]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
@@ -1535,7 +1536,6 @@ class GraphModule(torch.nn.Module):
def forward(self, tangents_0: "f32[8, 8]", tangents_1: "f32[8, 8]"):
mul_2: "f32[8, 8]" = torch.ops.aten.mul.Tensor(tangents_1, 3)
mul_3: "f32[8, 8]" = torch.ops.aten.mul.Tensor(tangents_1, 2); tangents_1 = None
add: "f32[8, 8]" = torch.ops.aten.add.Tensor(mul_2, mul_3); mul_2 = mul_3 = None
return (add,)
""",
@@ -2145,7 +2145,6 @@ class GraphModule(torch.nn.Module):
subgraph_0 = self.subgraph_0
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', x, y); subgraph_0 = x = None
z: "f32[5]" = invoke_subgraph[0]; invoke_subgraph = None
subgraph_1 = self.subgraph_1
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_1, 'subgraph_1', z, y); subgraph_1 = z = y = None
getitem_1: "f32[5]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
@@ -2283,6 +2282,7 @@ class GraphModule(torch.nn.Module):
cos: "f32[s77, 16]" = torch.ops.aten.cos.default(primals_1)
return (cos, primals_1, primals_0)
""",
ignore_empty_lines=True,
)
self.assertExpectedInline(
normalize_gm(backend.bw_graphs[0].print_readable(print_output=False)),
@@ -2294,7 +2294,6 @@ class GraphModule(torch.nn.Module):
partitioned_bw_subgraph_0_0 = self.partitioned_bw_subgraph_0_0
invoke_subgraph_15 = torch.ops.higher_order.invoke_subgraph(partitioned_bw_subgraph_0_0, 'partitioned_bw_subgraph_0_0', getitem_23, getitem_22, expand); partitioned_bw_subgraph_0_0 = getitem_23 = getitem_22 = None
getitem_5: "f32[s77, 16]" = invoke_subgraph_15[1]; invoke_subgraph_15 = None
add_16: "f32[s77, 16]" = torch.ops.aten.add.Tensor(expand, getitem_5); expand = getitem_5 = None
partitioned_bw_subgraph_0_3 = self.partitioned_bw_subgraph_0_1
@@ -2326,6 +2325,7 @@ class GraphModule(torch.nn.Module):
mul_10: "f32[s77, 16]" = torch.ops.aten.mul.Tensor(tangents_0, neg); tangents_0 = neg = None
return (None, mul_10)
""",
ignore_empty_lines=True,
)
def test_div(self):
@@ -2535,19 +2535,19 @@ class TestInvokeSubgraphExport(TestCase):
self.assertEqual(len(list(ep.graph_module.named_modules())), 2)
self.assertExpectedInline(
-normalize_gm(ep.graph_module.print_readable(print_output=False)),
+empty_line_normalizer(
+    normalize_gm(ep.graph_module.print_readable(print_output=False))
+),
"""\
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[8]", y: "f32[8]"):
repeated_subgraph0 = self.repeated_subgraph0
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', x, y); repeated_subgraph0 = x = None
getitem: "f32[8]" = invoke_subgraph[0]; invoke_subgraph = None
repeated_subgraph0_1 = self.repeated_subgraph0
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, 'subgraph_0', getitem, y); repeated_subgraph0_1 = getitem = y = None
getitem_1: "f32[8]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
return (getitem_1,)
class repeated_subgraph0(torch.nn.Module):
def forward(self, arg0_1: "f32[8]", arg1_1: "f32[8]"):
mul: "f32[8]" = torch.ops.aten.mul.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None

View File

@@ -3621,7 +3621,6 @@ class CompiledAutograd0(torch.nn.Module):
aot0_mul_2 = torch.ops.aten.mul.Tensor(aot0_tangents_1, aot0_primals_1); aot0_tangents_1 = aot0_primals_1 = None
aot0_mul_3 = torch.ops.aten.mul.Tensor(aot0_tangents_2, aot0_primals_2); aot0_tangents_2 = aot0_primals_2 = None
aot0_add_2 = torch.ops.aten.add.Tensor(aot0_mul_2, aot0_mul_2); aot0_mul_2 = None
aot0_add_3 = torch.ops.aten.add.Tensor(aot0_mul_3, aot0_mul_3); aot0_mul_3 = None

View File

@@ -606,29 +606,31 @@ class CodeGen:
             else:
                 body.append("\n")

-        prev_stacktrace = None
+        prev_summary_str = None

         def append_stacktrace_summary(node: Node):
             """
             Append a summary of the stacktrace to the generated code. This is
             useful for debugging.
             """
-            nonlocal prev_stacktrace
+            nonlocal prev_summary_str

             if node.op not in {"placeholder", "output"}:
-                stack_trace = node.stack_trace
-                if stack_trace:
-                    if stack_trace != prev_stacktrace:
-                        prev_stacktrace = stack_trace
-                        if parsed_stack_trace := _parse_stack_trace(stack_trace):
-                            summary_str = parsed_stack_trace.get_summary_str()
-                        else:
-                            summary_str = ""
-                        body.append(f"\n {dim(f'# {summary_str}')}\n")
-                elif prev_stacktrace != "":
-                    prev_stacktrace = ""
-                    no_stacktrace_msg = "# No stacktrace found for following nodes"
-                    body.append(f"\n{dim(no_stacktrace_msg)}\n")
+                annotation_str = ""
+                annotation = node.meta.get("custom", {})
+                if annotation:
+                    annotation_str = f" Annotation: {annotation}"
+
+                stack_trace_str = "No stacktrace found for following nodes"
+                if stack_trace := node.stack_trace:
+                    if parsed_stack_trace := _parse_stack_trace(stack_trace):
+                        stack_trace_str = parsed_stack_trace.get_summary_str()
+
+                summary_str = f"\n{dim(f'#{annotation_str} {stack_trace_str}')}\n"
+
+                if summary_str != prev_summary_str:
+                    prev_summary_str = summary_str
+                    body.append(summary_str)

         def stringify_shape(shape: Iterable) -> str:
             return f"[{', '.join([str(x) for x in shape])}]"