From 0b7d9711d4bf14080b9ef2283b8f4c490c2d647c Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Thu, 14 Mar 2024 18:55:31 -0700 Subject: [PATCH] [dynamo] Add support for nn.Parameter constructor (part 2) (#120965) This handles the case where the tensor isn't an input. The changes to dynamo tests are cases where we would previously fall back to eager. Pull Request resolved: https://github.com/pytorch/pytorch/pull/120965 Approved by: https://github.com/yanboliang ghstack dependencies: #121735 --- ...> TestEmbeddingNN.test_embedding_max_norm} | 0 ...stEmbeddingNN.test_embedding_sparse_basic} | 0 ...dingNN.test_embedding_sparse_empty_tensor} | 0 ...ngNN.test_embeddingbag_include_last_offset | 0 .../TestJitGeneratedModule.test_nn_Bilinear | 0 .../TestJitGeneratedModule.test_nn_Embedding | 0 ...dModule.test_nn_EmbeddingBag_discontiguous | 0 ...itGeneratedModule.test_nn_EmbeddingBag_max | 0 ...odule.test_nn_EmbeddingBag_max_padding_idx | 0 ...tGeneratedModule.test_nn_EmbeddingBag_mean | 0 ...dule.test_nn_EmbeddingBag_mean_padding_idx | 0 ...eneratedModule.test_nn_EmbeddingBag_sparse | 0 ...itGeneratedModule.test_nn_EmbeddingBag_sum | 0 ...odule.test_nn_EmbeddingBag_sum_padding_idx | 0 ...atedModule.test_nn_Embedding_discontiguous | 0 ...itGeneratedModule.test_nn_Embedding_sparse | 0 .../TestJitGeneratedModule.test_nn_Linear | 0 ...eneratedModule.test_nn_Linear_no_batch_dim | 0 ...GeneratedModule.test_nn_PReLU_no_batch_dim | 0 .../TestNN.test_ParameterList | 0 .../TestNN.test_bilinear_broadcasting | 0 ...st_layer_norm_grads_with_create_graph_flag | 0 ..._linear_autograd_device_cpu_bias_weightCOO | 0 ...inear_autograd_device_cpu_nobias_weightCOO | 0 .../TestNN.test_linear_broadcasting | 0 .../TestNN.test_module_apply_inplace_op | 0 ...rors_unparametrized_tensor_parametrization | 0 .../TestPruningNN.test_identity_pruning | 0 .../TestPruningNN.test_random_pruning_0perc | 0 ...volutionNN.test_Conv1d_module_same_padding | 0 ...stConvolutionNN.test_Conv2d_backward_twice | 0 ...volutionNN.test_Conv2d_module_same_padding | 0 ...volutionNN.test_Conv3d_module_same_padding | 0 ...N.test_ConvTranspose3d_correct_output_size | 0 .../TestJitGeneratedModule.test_nn_Conv1d | 0 ...odule.test_nn_Conv1d_circular_stride2_pad2 | 0 ...tJitGeneratedModule.test_nn_Conv1d_dilated | 0 ...stJitGeneratedModule.test_nn_Conv1d_groups | 0 ...TestJitGeneratedModule.test_nn_Conv1d_pad1 | 0 ...itGeneratedModule.test_nn_Conv1d_pad1size1 | 0 ...TestJitGeneratedModule.test_nn_Conv1d_pad2 | 0 ...itGeneratedModule.test_nn_Conv1d_pad2size1 | 0 ...JitGeneratedModule.test_nn_Conv1d_pad_same | 0 ...itGeneratedModule.test_nn_Conv1d_pad_same2 | 0 ...atedModule.test_nn_Conv1d_pad_same_dilated | 0 ...itGeneratedModule.test_nn_Conv1d_pad_valid | 0 ...Module.test_nn_Conv1d_reflect_stride2_pad2 | 0 ...dule.test_nn_Conv1d_replicate_stride2_pad2 | 0 ...stJitGeneratedModule.test_nn_Conv1d_stride | 0 ...tGeneratedModule.test_nn_Conv1d_zero_batch | 0 ...edModule.test_nn_Conv1d_zeros_stride2_pad2 | 0 .../TestJitGeneratedModule.test_nn_Conv2d | 0 ...odule.test_nn_Conv2d_circular_stride2_pad2 | 0 ...itGeneratedModule.test_nn_Conv2d_depthwise | 0 ...tedModule.test_nn_Conv2d_depthwise_dilated | 0 ...atedModule.test_nn_Conv2d_depthwise_padded | 0 ...tedModule.test_nn_Conv2d_depthwise_strided | 0 ...e.test_nn_Conv2d_depthwise_with_multiplier | 0 ...tJitGeneratedModule.test_nn_Conv2d_dilated | 0 ...stJitGeneratedModule.test_nn_Conv2d_groups | 0 ...GeneratedModule.test_nn_Conv2d_groups_thnn | 0 ...JitGeneratedModule.test_nn_Conv2d_pad_same | 0 ...atedModule.test_nn_Conv2d_pad_same_dilated | 0 ...itGeneratedModule.test_nn_Conv2d_pad_valid | 0 ...tJitGeneratedModule.test_nn_Conv2d_padding | 0 ...Module.test_nn_Conv2d_reflect_stride2_pad2 | 0 ...dule.test_nn_Conv2d_replicate_stride2_pad2 | 0 ...tJitGeneratedModule.test_nn_Conv2d_strided | 0 ...tGeneratedModule.test_nn_Conv2d_zero_batch | 0 ...edModule.test_nn_Conv2d_zeros_stride2_pad2 | 0 .../TestJitGeneratedModule.test_nn_Conv3d | 0 ...odule.test_nn_Conv3d_circular_stride2_pad2 | 0 ...tJitGeneratedModule.test_nn_Conv3d_dilated | 0 ...ratedModule.test_nn_Conv3d_dilated_strided | 0 ...stJitGeneratedModule.test_nn_Conv3d_groups | 0 ...JitGeneratedModule.test_nn_Conv3d_pad_same | 0 ...atedModule.test_nn_Conv3d_pad_same_dilated | 0 ...itGeneratedModule.test_nn_Conv3d_pad_valid | 0 ...dule.test_nn_Conv3d_replicate_stride2_pad2 | 0 ...stJitGeneratedModule.test_nn_Conv3d_stride | 0 ...eratedModule.test_nn_Conv3d_stride_padding | 0 ...tGeneratedModule.test_nn_Conv3d_zero_batch | 0 ...edModule.test_nn_Conv3d_zeros_stride2_pad2 | 0 ...JitGeneratedModule.test_nn_ConvTranspose1d | 0 ...atedModule.test_nn_ConvTranspose1d_dilated | 0 ...ratedModule.test_nn_ConvTranspose1d_groups | 0 ...JitGeneratedModule.test_nn_ConvTranspose2d | 0 ...ratedModule.test_nn_ConvTranspose2d_groups | 0 ...JitGeneratedModule.test_nn_ConvTranspose3d | 0 ...atedModule.test_nn_ConvTranspose3d_dilated | 0 test/dynamo_skips/TestNN.test_padding_list | 0 .../TestNN.test_vector_to_parameters | 0 test/inductor/test_distributed_patterns.py | 30 +++++++++++ torch/_dynamo/create_parameter_op.py | 50 +++++++++++++++++++ torch/_dynamo/output_graph.py | 23 +++++++++ torch/_dynamo/variables/torch.py | 31 +++++++++++- .../jit_compile_runtime_wrappers.py | 22 +++++++- torch/_inductor/ir.py | 23 +++++++++ torch/_inductor/lowering.py | 7 +++ 99 files changed, 184 insertions(+), 2 deletions(-) rename test/dynamo_expected_failures/{TestTorchTidyProfiler.test_optimizer => TestEmbeddingNN.test_embedding_max_norm} (100%) rename test/dynamo_expected_failures/{TestTorchTidyProfiler.test_optimizer_parameters_adam => TestEmbeddingNN.test_embedding_sparse_basic} (100%) rename test/dynamo_expected_failures/{TestTorchTidyProfiler.test_optimizer_parameters_sgd => TestEmbeddingNN.test_embedding_sparse_empty_tensor} (100%) create mode 100644 test/dynamo_expected_failures/TestEmbeddingNN.test_embeddingbag_include_last_offset create mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Bilinear create mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Embedding create mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_discontiguous create mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_max create mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_max_padding_idx create mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_mean create mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_mean_padding_idx create mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_sparse create mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_sum create mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_sum_padding_idx create mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Embedding_discontiguous create mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Embedding_sparse create mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Linear create mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Linear_no_batch_dim create mode 100644 test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_PReLU_no_batch_dim create mode 100644 test/dynamo_expected_failures/TestNN.test_ParameterList create mode 100644 test/dynamo_expected_failures/TestNN.test_bilinear_broadcasting create mode 100644 test/dynamo_expected_failures/TestNN.test_layer_norm_grads_with_create_graph_flag create mode 100644 test/dynamo_expected_failures/TestNN.test_linear_autograd_device_cpu_bias_weightCOO create mode 100644 test/dynamo_expected_failures/TestNN.test_linear_autograd_device_cpu_nobias_weightCOO create mode 100644 test/dynamo_expected_failures/TestNN.test_linear_broadcasting create mode 100644 test/dynamo_expected_failures/TestNN.test_module_apply_inplace_op create mode 100644 test/dynamo_expected_failures/TestNNParametrization.test_errors_unparametrized_tensor_parametrization create mode 100644 test/dynamo_expected_failures/TestPruningNN.test_identity_pruning create mode 100644 test/dynamo_expected_failures/TestPruningNN.test_random_pruning_0perc create mode 100644 test/dynamo_skips/TestConvolutionNN.test_Conv1d_module_same_padding create mode 100644 test/dynamo_skips/TestConvolutionNN.test_Conv2d_backward_twice create mode 100644 test/dynamo_skips/TestConvolutionNN.test_Conv2d_module_same_padding create mode 100644 test/dynamo_skips/TestConvolutionNN.test_Conv3d_module_same_padding create mode 100644 test/dynamo_skips/TestConvolutionNN.test_ConvTranspose3d_correct_output_size create mode 100644 test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv1d create mode 100644 test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv1d_circular_stride2_pad2 create mode 100644 test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv1d_dilated create mode 100644 test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv1d_groups create mode 100644 test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv1d_pad1 create mode 100644 test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv1d_pad1size1 create mode 100644 test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv1d_pad2 create mode 100644 test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv1d_pad2size1 create mode 100644 test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv1d_pad_same create mode 100644 test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv1d_pad_same2 create mode 100644 test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv1d_pad_same_dilated create mode 100644 test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv1d_pad_valid create mode 100644 test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv1d_reflect_stride2_pad2 create mode 100644 test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv1d_replicate_stride2_pad2 create mode 100644 test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv1d_stride create mode 100644 test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv1d_zero_batch create mode 100644 test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv1d_zeros_stride2_pad2 create mode 100644 test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d create mode 100644 test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d_circular_stride2_pad2 create mode 100644 test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d_depthwise create mode 100644 test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d_depthwise_dilated create mode 100644 test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d_depthwise_padded create mode 100644 test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d_depthwise_strided create mode 100644 test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d_depthwise_with_multiplier create mode 100644 test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d_dilated create mode 100644 test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d_groups create mode 100644 test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d_groups_thnn create mode 100644 test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d_pad_same create mode 100644 test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d_pad_same_dilated create mode 100644 test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d_pad_valid create mode 100644 test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d_padding create mode 100644 test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d_reflect_stride2_pad2 create mode 100644 test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d_replicate_stride2_pad2 create mode 100644 test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d_strided create mode 100644 test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d_zero_batch create mode 100644 test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d_zeros_stride2_pad2 create mode 100644 test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv3d create mode 100644 test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv3d_circular_stride2_pad2 create mode 100644 test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv3d_dilated create mode 100644 test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv3d_dilated_strided create mode 100644 test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv3d_groups create mode 100644 test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv3d_pad_same create mode 100644 test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv3d_pad_same_dilated create mode 100644 test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv3d_pad_valid create mode 100644 test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv3d_replicate_stride2_pad2 create mode 100644 test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv3d_stride create mode 100644 test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv3d_stride_padding create mode 100644 test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv3d_zero_batch create mode 100644 test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv3d_zeros_stride2_pad2 create mode 100644 test/dynamo_skips/TestJitGeneratedModule.test_nn_ConvTranspose1d create mode 100644 test/dynamo_skips/TestJitGeneratedModule.test_nn_ConvTranspose1d_dilated create mode 100644 test/dynamo_skips/TestJitGeneratedModule.test_nn_ConvTranspose1d_groups create mode 100644 test/dynamo_skips/TestJitGeneratedModule.test_nn_ConvTranspose2d create mode 100644 test/dynamo_skips/TestJitGeneratedModule.test_nn_ConvTranspose2d_groups create mode 100644 test/dynamo_skips/TestJitGeneratedModule.test_nn_ConvTranspose3d create mode 100644 test/dynamo_skips/TestJitGeneratedModule.test_nn_ConvTranspose3d_dilated create mode 100644 test/dynamo_skips/TestNN.test_padding_list create mode 100644 test/dynamo_skips/TestNN.test_vector_to_parameters create mode 100644 torch/_dynamo/create_parameter_op.py diff --git a/test/dynamo_expected_failures/TestTorchTidyProfiler.test_optimizer b/test/dynamo_expected_failures/TestEmbeddingNN.test_embedding_max_norm similarity index 100% rename from test/dynamo_expected_failures/TestTorchTidyProfiler.test_optimizer rename to test/dynamo_expected_failures/TestEmbeddingNN.test_embedding_max_norm diff --git a/test/dynamo_expected_failures/TestTorchTidyProfiler.test_optimizer_parameters_adam b/test/dynamo_expected_failures/TestEmbeddingNN.test_embedding_sparse_basic similarity index 100% rename from test/dynamo_expected_failures/TestTorchTidyProfiler.test_optimizer_parameters_adam rename to test/dynamo_expected_failures/TestEmbeddingNN.test_embedding_sparse_basic diff --git a/test/dynamo_expected_failures/TestTorchTidyProfiler.test_optimizer_parameters_sgd b/test/dynamo_expected_failures/TestEmbeddingNN.test_embedding_sparse_empty_tensor similarity index 100% rename from test/dynamo_expected_failures/TestTorchTidyProfiler.test_optimizer_parameters_sgd rename to test/dynamo_expected_failures/TestEmbeddingNN.test_embedding_sparse_empty_tensor diff --git a/test/dynamo_expected_failures/TestEmbeddingNN.test_embeddingbag_include_last_offset b/test/dynamo_expected_failures/TestEmbeddingNN.test_embeddingbag_include_last_offset new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Bilinear b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Bilinear new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Embedding b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Embedding new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_discontiguous b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_discontiguous new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_max b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_max new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_max_padding_idx b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_max_padding_idx new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_mean b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_mean new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_mean_padding_idx b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_mean_padding_idx new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_sparse b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_sparse new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_sum b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_sum new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_sum_padding_idx b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_sum_padding_idx new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Embedding_discontiguous b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Embedding_discontiguous new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Embedding_sparse b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Embedding_sparse new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Linear b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Linear new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Linear_no_batch_dim b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Linear_no_batch_dim new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_PReLU_no_batch_dim b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_PReLU_no_batch_dim new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestNN.test_ParameterList b/test/dynamo_expected_failures/TestNN.test_ParameterList new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestNN.test_bilinear_broadcasting b/test/dynamo_expected_failures/TestNN.test_bilinear_broadcasting new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestNN.test_layer_norm_grads_with_create_graph_flag b/test/dynamo_expected_failures/TestNN.test_layer_norm_grads_with_create_graph_flag new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestNN.test_linear_autograd_device_cpu_bias_weightCOO b/test/dynamo_expected_failures/TestNN.test_linear_autograd_device_cpu_bias_weightCOO new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestNN.test_linear_autograd_device_cpu_nobias_weightCOO b/test/dynamo_expected_failures/TestNN.test_linear_autograd_device_cpu_nobias_weightCOO new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestNN.test_linear_broadcasting b/test/dynamo_expected_failures/TestNN.test_linear_broadcasting new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestNN.test_module_apply_inplace_op b/test/dynamo_expected_failures/TestNN.test_module_apply_inplace_op new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestNNParametrization.test_errors_unparametrized_tensor_parametrization b/test/dynamo_expected_failures/TestNNParametrization.test_errors_unparametrized_tensor_parametrization new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestPruningNN.test_identity_pruning b/test/dynamo_expected_failures/TestPruningNN.test_identity_pruning new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestPruningNN.test_random_pruning_0perc b/test/dynamo_expected_failures/TestPruningNN.test_random_pruning_0perc new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_skips/TestConvolutionNN.test_Conv1d_module_same_padding b/test/dynamo_skips/TestConvolutionNN.test_Conv1d_module_same_padding new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_skips/TestConvolutionNN.test_Conv2d_backward_twice b/test/dynamo_skips/TestConvolutionNN.test_Conv2d_backward_twice new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_skips/TestConvolutionNN.test_Conv2d_module_same_padding b/test/dynamo_skips/TestConvolutionNN.test_Conv2d_module_same_padding new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_skips/TestConvolutionNN.test_Conv3d_module_same_padding b/test/dynamo_skips/TestConvolutionNN.test_Conv3d_module_same_padding new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_skips/TestConvolutionNN.test_ConvTranspose3d_correct_output_size b/test/dynamo_skips/TestConvolutionNN.test_ConvTranspose3d_correct_output_size new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv1d b/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv1d new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv1d_circular_stride2_pad2 b/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv1d_circular_stride2_pad2 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv1d_dilated b/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv1d_dilated new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv1d_groups b/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv1d_groups new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv1d_pad1 b/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv1d_pad1 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv1d_pad1size1 b/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv1d_pad1size1 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv1d_pad2 b/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv1d_pad2 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv1d_pad2size1 b/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv1d_pad2size1 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv1d_pad_same b/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv1d_pad_same new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv1d_pad_same2 b/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv1d_pad_same2 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv1d_pad_same_dilated b/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv1d_pad_same_dilated new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv1d_pad_valid b/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv1d_pad_valid new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv1d_reflect_stride2_pad2 b/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv1d_reflect_stride2_pad2 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv1d_replicate_stride2_pad2 b/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv1d_replicate_stride2_pad2 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv1d_stride b/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv1d_stride new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv1d_zero_batch b/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv1d_zero_batch new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv1d_zeros_stride2_pad2 b/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv1d_zeros_stride2_pad2 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d b/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d_circular_stride2_pad2 b/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d_circular_stride2_pad2 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d_depthwise b/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d_depthwise new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d_depthwise_dilated b/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d_depthwise_dilated new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d_depthwise_padded b/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d_depthwise_padded new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d_depthwise_strided b/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d_depthwise_strided new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d_depthwise_with_multiplier b/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d_depthwise_with_multiplier new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d_dilated b/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d_dilated new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d_groups b/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d_groups new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d_groups_thnn b/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d_groups_thnn new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d_pad_same b/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d_pad_same new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d_pad_same_dilated b/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d_pad_same_dilated new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d_pad_valid b/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d_pad_valid new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d_padding b/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d_padding new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d_reflect_stride2_pad2 b/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d_reflect_stride2_pad2 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d_replicate_stride2_pad2 b/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d_replicate_stride2_pad2 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d_strided b/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d_strided new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d_zero_batch b/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d_zero_batch new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d_zeros_stride2_pad2 b/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d_zeros_stride2_pad2 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv3d b/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv3d new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv3d_circular_stride2_pad2 b/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv3d_circular_stride2_pad2 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv3d_dilated b/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv3d_dilated new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv3d_dilated_strided b/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv3d_dilated_strided new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv3d_groups b/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv3d_groups new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv3d_pad_same b/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv3d_pad_same new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv3d_pad_same_dilated b/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv3d_pad_same_dilated new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv3d_pad_valid b/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv3d_pad_valid new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv3d_replicate_stride2_pad2 b/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv3d_replicate_stride2_pad2 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv3d_stride b/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv3d_stride new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv3d_stride_padding b/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv3d_stride_padding new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv3d_zero_batch b/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv3d_zero_batch new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv3d_zeros_stride2_pad2 b/test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv3d_zeros_stride2_pad2 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_skips/TestJitGeneratedModule.test_nn_ConvTranspose1d b/test/dynamo_skips/TestJitGeneratedModule.test_nn_ConvTranspose1d new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_skips/TestJitGeneratedModule.test_nn_ConvTranspose1d_dilated b/test/dynamo_skips/TestJitGeneratedModule.test_nn_ConvTranspose1d_dilated new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_skips/TestJitGeneratedModule.test_nn_ConvTranspose1d_groups b/test/dynamo_skips/TestJitGeneratedModule.test_nn_ConvTranspose1d_groups new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_skips/TestJitGeneratedModule.test_nn_ConvTranspose2d b/test/dynamo_skips/TestJitGeneratedModule.test_nn_ConvTranspose2d new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_skips/TestJitGeneratedModule.test_nn_ConvTranspose2d_groups b/test/dynamo_skips/TestJitGeneratedModule.test_nn_ConvTranspose2d_groups new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_skips/TestJitGeneratedModule.test_nn_ConvTranspose3d b/test/dynamo_skips/TestJitGeneratedModule.test_nn_ConvTranspose3d new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_skips/TestJitGeneratedModule.test_nn_ConvTranspose3d_dilated b/test/dynamo_skips/TestJitGeneratedModule.test_nn_ConvTranspose3d_dilated new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_skips/TestNN.test_padding_list b/test/dynamo_skips/TestNN.test_padding_list new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_skips/TestNN.test_vector_to_parameters b/test/dynamo_skips/TestNN.test_vector_to_parameters new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/inductor/test_distributed_patterns.py b/test/inductor/test_distributed_patterns.py index c43df62d439a..68c2b729a76a 100644 --- a/test/inductor/test_distributed_patterns.py +++ b/test/inductor/test_distributed_patterns.py @@ -301,6 +301,36 @@ class DistributedPatternTests(TestCase): self._assert_same_grad(r1, r2) self._assert_same_grad(p1, p2) + def test_nn_param_return3(self): + def fn(x): + p = torch.nn.Parameter(x + 123) + return p, p.sin() + + opt = torch.compile(fn, fullgraph=True) + x1 = torch.randn(16) + x2 = x1.clone() + + p1, r1 = fn(x1) + r1.sum().backward() + p2, r2 = opt(x2) + r2.sum().backward() + self._assert_same_grad(r1, r2) + self._assert_same_grad(p1, p2) + + def test_nn_param_return4(self): + def fn(x): + p = torch.nn.Parameter(x + 123, requires_grad=False) + return p, x + 1 + + opt = torch.compile(fn, fullgraph=True) + x1 = torch.randn(16) + x2 = x1.clone() + + p1, r1 = fn(x1) + p2, r2 = opt(x2) + self._assert_same_grad(r1, r2) + self._assert_same_grad(p1, p2) + if __name__ == "__main__": if HAS_CPU and not IS_MACOS: diff --git a/torch/_dynamo/create_parameter_op.py b/torch/_dynamo/create_parameter_op.py new file mode 100644 index 000000000000..25ba68b9be76 --- /dev/null +++ b/torch/_dynamo/create_parameter_op.py @@ -0,0 +1,50 @@ +import torch +from torch._prims import _make_prim, RETURN_TYPE +from torch._prims_common import clone_preserve_strides + +doc = """ +This is used when dynamo traces torch.nn.Parameter, which normally would not trace properly +with AOTAutograd. We instead create a placeholder torch.nn.Parameter before the graph, which +becomes a graph arg and has no storage backing it. At the point in the graph where the parameter +actually should be created we mutate this sacrificial placeholder into it. This allows gradients +to flow into the parameter as if it were an input to the graph (which is the only thing we are +allowed to compute gradients on). +""".strip() + +_bind_nn_parameter = _make_prim( + schema="_bind_nn_parameter(Tensor self, Tensor placeholder) -> Tensor", + return_type=RETURN_TYPE.NEW, + meta=lambda self, placeholder: torch.nn.Parameter( + clone_preserve_strides(self), placeholder.requires_grad + ), + impl_aten=lambda self, placeholder: placeholder.set_(self), + doc=doc, +) +torch.fx.node.has_side_effect(_bind_nn_parameter) + + +class TracableCreateParameter(torch.autograd.Function): + @staticmethod + def forward(ctx, tensor, placeholder): + assert not tensor.requires_grad + return _bind_nn_parameter(tensor, placeholder) + + @staticmethod + def backward(ctx, grad): + return None, grad # grad flows to placeholder + + +def tracable_create_parameter(tensor, placeholder): + with torch.set_grad_enabled(placeholder.requires_grad): + return TracableCreateParameter.apply(tensor, placeholder) + + +def new_parameter_placeholder(size, dtype, device, requires_grad): + """Create a placeholder to be passed to the above functions""" + result = torch.nn.Parameter( + torch.empty(size, dtype=dtype, device=device), requires_grad=requires_grad + ) + # TODO(jansel): alloc followed by free is inefficient, need a way to allocate an unbacked tensor. + # Allocating a zero tensor would causes assert failures in autograd. + result.untyped_storage().resize_(0) + return result diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index ac5ae44f0c0b..42fa15a14a92 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -69,6 +69,7 @@ from .source import ( LocalSource, ParamBufferSource, ShapeEnvSource, + SyntheticLocalSource, TensorProperty, TensorPropertySource, ) @@ -472,6 +473,28 @@ class OutputGraph(Checkpointable[OutputGraphState]): self.guards.add(GlobalStateSource().make_guard(GuardBuilder.BACKEND_MATCH)) + def synthetic_graph_input(self, fn, args): + """ + call fn(*args) before the graph runs and turn the result into a fake input. + """ + example_value = fn(*args) + varname = self.new_var() + cg = PyCodegen(self.root_tx) + cg.load_import_from( + fn.__module__, + fn.__name__, + ) + cg.foreach(map(variables.ConstantVariable.create, args)) + cg.call_function(len(args), True) + cg.store(varname) + self.pregraph_bytecode.extend(cg.get_instructions()) + source = SyntheticLocalSource(varname) + result = VariableBuilder(self.root_tx, source)(example_value) + TracingContext.get().guards_context.dynamo_guards.remove_guards_with_source( + source + ) + return result + def add_cleanup_hook(self, fn: Callable[[], Any]): self.cleanup_hooks.append(fn) diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index d6835421eb83..243191902091 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -17,6 +17,7 @@ from torch._streambase import _StreamBase from ..._guards import TracingContext from .. import config, polyfill, variables from ..codegen import PyCodegen +from ..create_parameter_op import new_parameter_placeholder, tracable_create_parameter from ..device_interface import get_registered_device_interfaces from ..exc import unimplemented from ..guards import GuardBuilder, install_guard @@ -840,7 +841,35 @@ Either create the tensor outside the compiled region, or do not set the tensor t if data.source: return cls._nn_param_via_prefix_insert(tx, data, requires_grad) - unimplemented("Parameter() on non-input") + try: + shape = tuple(data.var_getattr(tx, "shape").as_python_constant()) + dtype = data.var_getattr(tx, "dtype").as_python_constant() + device = data.var_getattr(tx, "device").as_python_constant() + except NotImplementedError as e: + unimplemented(f"Parameter not python_constant: {e}") + + placeholder = tx.output.synthetic_graph_input( + new_parameter_placeholder, [shape, dtype, device, requires_grad] + ) + if data.requires_grad: + data = data.call_method(tx, "detach", [], {}) + + from .builder import wrap_fx_proxy + + result = wrap_fx_proxy( + tx, + tx.output.create_proxy( + "call_function", + tracable_create_parameter, + (data.as_proxy(), placeholder.as_proxy()), + {}, + ), + ) + assert isinstance(result, variables.TensorVariable) + result.class_type = torch.nn.Parameter + # In reconstruct() should use the original parameter. The one returned by the graph will be an alias. + result.source = placeholder.source + return result @staticmethod def _nn_param_via_prefix_insert(tx, data, requires_grad): diff --git a/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py b/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py index 05147e19af48..ce087d44e257 100644 --- a/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py +++ b/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py @@ -181,6 +181,26 @@ def _output_node(gm: torch.fx.GraphModule) -> torch.fx.Node: return next(n for n in reversed(gm.graph.nodes) if n.op == "output") +def _input_node(gm: torch.fx.GraphModule, i: int) -> torch.fx.Node: + """Fetch the i-th placeholder in the graph""" + seen = 0 + for n in gm.graph.nodes: + if n.op == "placeholder": + if seen == i: + return n + seen += 1 + raise IndexError(f"input {i} does not exist, only {seen} inputs in graph") + + +def _can_detach(node: torch.fx.Node): + """ + Avoid calling .detach() on inputs passed to _bind_nn_parameter() + """ + from torch._dynamo.create_parameter_op import _bind_nn_parameter + + return all(n.target is not _bind_nn_parameter for n in node.users) + + def aot_dispatch_autograd( flat_fn, flat_args: List[Any], @@ -317,7 +337,7 @@ def aot_dispatch_autograd( == len(fw_metadata.input_info) + inner_meta.num_outputs_rng_offset ) for i, (bw_out) in enumerate(bw_outs): - if bw_out is None: + if bw_out is None and _can_detach(_input_node(fx_g, i)): _indices_of_inps_to_detach.append(i) if aot_config.enable_log: diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index bcd168f4d028..95cffbe26a4e 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -4496,6 +4496,29 @@ class ResizeStorageBytes(MutatingFirstArgExternKernel): mark_node_as_mutating(self, variable) +class BindNNParameter(ExternKernelAlloc): + def __init__(self, variable, placeholder): + variable.freeze_layout() + super().__init__( + variable.get_layout(), + [variable, placeholder], + python_kernel_name="torch.ops.prims._bind_nn_parameter", + ) + V.graph.never_reuse_buffers.add(variable.data.get_name()) + V.graph.never_reuse_buffers.add(placeholder.get_name()) + V.graph.never_reuse_buffers.add(self.get_name()) + mark_node_as_mutating(self, variable, placeholder) + + def get_alias_names(self): + return [self.inputs[0].get_name(), self.inputs[1].get_name()] + + def get_mutation_names(self): + return [self.inputs[1].get_name()] + + def has_side_effects(self): + return True + + class ScatterFallback(ExternKernel): """ This needs to be a custom class to handle mutation properly. diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 6264035ef5f9..47e1457f1232 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -13,6 +13,7 @@ import torch import torch.ao.quantization.fx._decomposed import torch.fx import torch.utils._pytree as pytree +from torch._dynamo.create_parameter_op import _bind_nn_parameter from torch._higher_order_ops.triton_kernel_wrap import ( triton_kernel_wrapper_functional, triton_kernel_wrapper_mutation, @@ -5924,6 +5925,12 @@ def resize_storage_bytes_(variable, new_size): return variable +@register_lowering(_bind_nn_parameter) +def create_nn_parameter(self, placeholder): + self.realize() + return TensorBox.create(ir.BindNNParameter(self, placeholder)) + + from torch._higher_order_ops.auto_functionalize import auto_functionalized make_fallback(auto_functionalized)