diff --git a/docs/source/export.ir_spec.rst b/docs/source/export.ir_spec.rst index fb43ea847c86..dadbd8d0c6ea 100644 --- a/docs/source/export.ir_spec.rst +++ b/docs/source/export.ir_spec.rst @@ -103,15 +103,25 @@ of the Graph of GraphModule. Example:: - from torch import nn + import torch + from torch import nn - class MyModule(nn.Module): + class MyModule(nn.Module): - def forward(self, x, y): - return x + y + def forward(self, x, y): + return x + y - mod = torch.export.export(MyModule()) - print(mod.graph) + example_args = (torch.randn(1), torch.randn(1)) + mod = torch.export.export(MyModule(), example_args) + print(mod.graph) + +.. code-block:: python + + graph(): + %x : [num_users=1] = placeholder[target=x] + %y : [num_users=1] = placeholder[target=y] + %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %y), kwargs = {}) + return (add,) The above is the textual representation of a Graph, with each line being a node. diff --git a/docs/source/export.rst b/docs/source/export.rst index 6d6784c97c52..9654d789b8bc 100644 --- a/docs/source/export.rst +++ b/docs/source/export.rst @@ -39,28 +39,41 @@ serialized. 
ExportedProgram: class GraphModule(torch.nn.Module): - def forward(self, arg0_1: f32[10, 10], arg1_1: f32[10, 10]): + def forward(self, x: "f32[10, 10]", y: "f32[10, 10]"): # code: a = torch.sin(x) - sin: f32[10, 10] = torch.ops.aten.sin.default(arg0_1); + sin: "f32[10, 10]" = torch.ops.aten.sin.default(x) # code: b = torch.cos(y) - cos: f32[10, 10] = torch.ops.aten.cos.default(arg1_1); + cos: "f32[10, 10]" = torch.ops.aten.cos.default(y) # code: return a + b - add: f32[10, 10] = torch.ops.aten.add.Tensor(sin, cos); + add: "f32[10, 10]" = torch.ops.aten.add.Tensor(sin, cos) return (add,) - Graph signature: ExportGraphSignature( - parameters=[], - buffers=[], - user_inputs=['arg0_1', 'arg1_1'], - user_outputs=['add'], - inputs_to_parameters={}, - inputs_to_buffers={}, - buffers_to_mutate={}, - backward_signature=None, - assertion_dep_token=None, - ) + Graph signature: + ExportGraphSignature( + input_specs=[ + InputSpec( + kind=, + arg=TensorArgument(name='x'), + target=None, + persistent=None + ), + InputSpec( + kind=, + arg=TensorArgument(name='y'), + target=None, + persistent=None + ) + ], + output_specs=[ + OutputSpec( + kind=, + arg=TensorArgument(name='add'), + target=None + ) + ] + ) Range constraints: {} ``torch.export`` produces a clean intermediate representation (IR) with the @@ -183,39 +196,55 @@ example: ExportedProgram: class GraphModule(torch.nn.Module): - def forward(self, arg0_1: f32[16, 3, 3, 3], arg1_1: f32[16], arg2_1: f32[1, 3, 256, 256], arg3_1: f32[1, 16, 256, 256]): - + def forward(self, p_conv_weight: "f32[16, 3, 3, 3]", p_conv_bias: "f32[16]", x: "f32[1, 3, 256, 256]", constant: "f32[1, 16, 256, 256]"): # code: a = self.conv(x) - convolution: f32[1, 16, 256, 256] = torch.ops.aten.convolution.default( - arg2_1, arg0_1, arg1_1, [1, 1], [1, 1], [1, 1], False, [0, 0], 1 - ); + conv2d: "f32[1, 16, 256, 256]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias, [1, 1], [1, 1]) # code: a.add_(constant) - add: f32[1, 16, 256, 256] = 
torch.ops.aten.add.Tensor(convolution, arg3_1); + add_: "f32[1, 16, 256, 256]" = torch.ops.aten.add_.Tensor(conv2d, constant) # code: return self.maxpool(self.relu(a)) - relu: f32[1, 16, 256, 256] = torch.ops.aten.relu.default(add); - max_pool2d_with_indices = torch.ops.aten.max_pool2d_with_indices.default( - relu, [3, 3], [3, 3] - ); - getitem: f32[1, 16, 85, 85] = max_pool2d_with_indices[0]; - return (getitem,) + relu: "f32[1, 16, 256, 256]" = torch.ops.aten.relu.default(add_) + max_pool2d: "f32[1, 16, 85, 85]" = torch.ops.aten.max_pool2d.default(relu, [3, 3], [3, 3]) + return (max_pool2d,) - Graph signature: ExportGraphSignature( - parameters=['L__self___conv.weight', 'L__self___conv.bias'], - buffers=[], - user_inputs=['arg2_1', 'arg3_1'], - user_outputs=['getitem'], - inputs_to_parameters={ - 'arg0_1': 'L__self___conv.weight', - 'arg1_1': 'L__self___conv.bias', - }, - inputs_to_buffers={}, - buffers_to_mutate={}, - backward_signature=None, - assertion_dep_token=None, + Graph signature: + ExportGraphSignature( + input_specs=[ + InputSpec( + kind=, + arg=TensorArgument(name='p_conv_weight'), + target='conv.weight', + persistent=None + ), + InputSpec( + kind=, + arg=TensorArgument(name='p_conv_bias'), + target='conv.bias', + persistent=None + ), + InputSpec( + kind=, + arg=TensorArgument(name='x'), + target=None, + persistent=None + ), + InputSpec( + kind=, + arg=TensorArgument(name='constant'), + target=None, + persistent=None + ) + ], + output_specs=[ + OutputSpec( + kind=, + arg=TensorArgument(name='max_pool2d'), + target=None + ) + ] ) - Range constraints: {} + Range constraints: {} Inspecting the ``ExportedProgram``, we can note the following: @@ -336,25 +365,69 @@ To show some examples: ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 
1, 3, 3]"): - conv2d: "f32[1, 3, 3, 3]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias); x = p_conv_weight = p_conv_bias = None - add_: "i64[]" = torch.ops.aten.add_.Tensor(b_bn_num_batches_tracked, 1); b_bn_num_batches_tracked = add_ = None - batch_norm: "f32[1, 3, 3, 3]" = torch.ops.aten.batch_norm.default(conv2d, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05, True); conv2d = p_bn_weight = p_bn_bias = b_bn_running_mean = b_bn_running_var = None + conv2d: "f32[1, 3, 3, 3]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias) + add_: "i64[]" = torch.ops.aten.add_.Tensor(b_bn_num_batches_tracked, 1) + batch_norm: "f32[1, 3, 3, 3]" = torch.ops.aten.batch_norm.default(conv2d, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05, True) return (batch_norm,) Graph signature: ExportGraphSignature( input_specs=[ - InputSpec(kind=, arg=TensorArgument(name='p_conv_weight'), target='conv.weight', persistent=None), - InputSpec(kind=, arg=TensorArgument(name='p_conv_bias'), target='conv.bias', persistent=None), - InputSpec(kind=, arg=TensorArgument(name='p_bn_weight'), target='bn.weight', persistent=None), - InputSpec(kind=, arg=TensorArgument(name='p_bn_bias'), target='bn.bias', persistent=None), - InputSpec(kind=, arg=TensorArgument(name='b_bn_running_mean'), target='bn.running_mean', persistent=True), - InputSpec(kind=, arg=TensorArgument(name='b_bn_running_var'), target='bn.running_var', persistent=True), - InputSpec(kind=, arg=TensorArgument(name='b_bn_num_batches_tracked'), target='bn.num_batches_tracked', persistent=True), - InputSpec(kind=, arg=TensorArgument(name='x'), target=None, persistent=None) + InputSpec( + kind=, + arg=TensorArgument(name='p_conv_weight'), + target='conv.weight', + persistent=None + ), + InputSpec( + kind=, + arg=TensorArgument(name='p_conv_bias'), + target='conv.bias', + persistent=None + ), + InputSpec( + kind=, + arg=TensorArgument(name='p_bn_weight'), + 
target='bn.weight', + persistent=None + ), + InputSpec( + kind=, + arg=TensorArgument(name='p_bn_bias'), + target='bn.bias', + persistent=None + ), + InputSpec( + kind=, + arg=TensorArgument(name='b_bn_running_mean'), + target='bn.running_mean', + persistent=True + ), + InputSpec( + kind=, + arg=TensorArgument(name='b_bn_running_var'), + target='bn.running_var', + persistent=True + ), + InputSpec( + kind=, + arg=TensorArgument(name='b_bn_num_batches_tracked'), + target='bn.num_batches_tracked', + persistent=True + ), + InputSpec( + kind=, + arg=TensorArgument(name='x'), + target=None, + persistent=None + ) ], output_specs=[ - OutputSpec(kind=, arg=TensorArgument(name='batch_norm'), target=None) + OutputSpec( + kind=, + arg=TensorArgument(name='batch_norm'), + target=None + ) ] ) Range constraints: {} @@ -380,36 +453,93 @@ You can also go from this IR to an inference IR via :func:`run_decompositions` w ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"): - conv2d: "f32[1, 3, 3, 3]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias); x = p_conv_weight = p_conv_bias = None - add: "i64[]" = torch.ops.aten.add.Tensor(b_bn_num_batches_tracked, 1); b_bn_num_batches_tracked = None - _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(conv2d, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05); conv2d = p_bn_weight = p_bn_bias = b_bn_running_mean = b_bn_running_var = None + conv2d: "f32[1, 3, 3, 3]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias) + add: "i64[]" = torch.ops.aten.add.Tensor(b_bn_num_batches_tracked, 1) + _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(conv2d, p_bn_weight, p_bn_bias, 
b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05) getitem: "f32[1, 3, 3, 3]" = _native_batch_norm_legit_functional[0] getitem_3: "f32[3]" = _native_batch_norm_legit_functional[3] - getitem_4: "f32[3]" = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None + getitem_4: "f32[3]" = _native_batch_norm_legit_functional[4] return (getitem_3, getitem_4, add, getitem) - Graph signature: ExportGraphSignature( - input_specs=[ - InputSpec(kind=, arg=TensorArgument(name='p_conv_weight'), target='conv.weight', persistent=None), - InputSpec(kind=, arg=TensorArgument(name='p_conv_bias'), target='conv.bias', persistent=None), - InputSpec(kind=, arg=TensorArgument(name='p_bn_weight'), target='bn.weight', persistent=None), - InputSpec(kind=, arg=TensorArgument(name='p_bn_bias'), target='bn.bias', persistent=None), - InputSpec(kind=, arg=TensorArgument(name='b_bn_running_mean'), target='bn.running_mean', persistent=True), - InputSpec(kind=, arg=TensorArgument(name='b_bn_running_var'), target='bn.running_var', persistent=True), - InputSpec(kind=, arg=TensorArgument(name='b_bn_num_batches_tracked'), target='bn.num_batches_tracked', persistent=True), - InputSpec(kind=, arg=TensorArgument(name='x'), target=None, persistent=None) - ], - output_specs=[ - OutputSpec(kind=, arg=TensorArgument(name='getitem_3'), target='bn.running_mean'), - OutputSpec(kind=, arg=TensorArgument(name='getitem_4'), target='bn.running_var'), - OutputSpec(kind=, arg=TensorArgument(name='add'), target='bn.num_batches_tracked'), - OutputSpec(kind=, arg=TensorArgument(name='getitem'), target=None) - ] - ) + Graph signature: + ExportGraphSignature( + input_specs=[ + InputSpec( + kind=, + arg=TensorArgument(name='p_conv_weight'), + target='conv.weight', + persistent=None + ), + InputSpec( + kind=, + arg=TensorArgument(name='p_conv_bias'), + target='conv.bias', + persistent=None + ), + InputSpec( + kind=, + arg=TensorArgument(name='p_bn_weight'), + target='bn.weight', + persistent=None 
+ ), + InputSpec( + kind=, + arg=TensorArgument(name='p_bn_bias'), + target='bn.bias', + persistent=None + ), + InputSpec( + kind=, + arg=TensorArgument(name='b_bn_running_mean'), + target='bn.running_mean', + persistent=True + ), + InputSpec( + kind=, + arg=TensorArgument(name='b_bn_running_var'), + target='bn.running_var', + persistent=True + ), + InputSpec( + kind=, + arg=TensorArgument(name='b_bn_num_batches_tracked'), + target='bn.num_batches_tracked', + persistent=True + ), + InputSpec( + kind=, + arg=TensorArgument(name='x'), + target=None, + persistent=None + ) + ], + output_specs=[ + OutputSpec( + kind=, + arg=TensorArgument(name='getitem_3'), + target='bn.running_mean' + ), + OutputSpec( + kind=, + arg=TensorArgument(name='getitem_4'), + target='bn.running_var' + ), + OutputSpec( + kind=, + arg=TensorArgument(name='add'), + target='bn.num_batches_tracked' + ), + OutputSpec( + kind=, + arg=TensorArgument(name='getitem'), + target=None + ) + ] + ) Range constraints: {} -Here you can see that we kept `conv2d` op in the IR while decomposing the rest. Now the IR is a functional IR -containing core aten operators except for `conv2d`. +Here you can see that we kept ``conv2d`` op in the IR while decomposing the rest. Now the IR is a functional IR +containing core aten operators except for ``conv2d``. You can do even more customization by directly registering your chosen decomposition behaviors. 
@@ -433,32 +563,89 @@ You can do even more customizations by directly registering custom decomp behavi ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"): - convolution: "f32[1, 3, 3, 3]" = torch.ops.aten.convolution.default(x, p_conv_weight, p_conv_bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1); x = p_conv_weight = p_conv_bias = None - mul: "f32[1, 3, 3, 3]" = torch.ops.aten.mul.Tensor(convolution, 2); convolution = None - add: "i64[]" = torch.ops.aten.add.Tensor(b_bn_num_batches_tracked, 1); b_bn_num_batches_tracked = None - _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(mul, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05); mul = p_bn_weight = p_bn_bias = b_bn_running_mean = b_bn_running_var = None + convolution: "f32[1, 3, 3, 3]" = torch.ops.aten.convolution.default(x, p_conv_weight, p_conv_bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1) + mul: "f32[1, 3, 3, 3]" = torch.ops.aten.mul.Tensor(convolution, 2) + add: "i64[]" = torch.ops.aten.add.Tensor(b_bn_num_batches_tracked, 1) + _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(mul, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05) getitem: "f32[1, 3, 3, 3]" = _native_batch_norm_legit_functional[0] getitem_3: "f32[3]" = _native_batch_norm_legit_functional[3] - getitem_4: "f32[3]" = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None + getitem_4: "f32[3]" = _native_batch_norm_legit_functional[4] return (getitem_3, getitem_4, add, getitem) - Graph signature: ExportGraphSignature( - input_specs=[ - InputSpec(kind=, arg=TensorArgument(name='p_conv_weight'), target='conv.weight', persistent=None), - 
InputSpec(kind=, arg=TensorArgument(name='p_conv_bias'), target='conv.bias', persistent=None), - InputSpec(kind=, arg=TensorArgument(name='p_bn_weight'), target='bn.weight', persistent=None), - InputSpec(kind=, arg=TensorArgument(name='p_bn_bias'), target='bn.bias', persistent=None), - InputSpec(kind=, arg=TensorArgument(name='b_bn_running_mean'), target='bn.running_mean', persistent=True), - InputSpec(kind=, arg=TensorArgument(name='b_bn_running_var'), target='bn.running_var', persistent=True), - InputSpec(kind=, arg=TensorArgument(name='b_bn_num_batches_tracked'), target='bn.num_batches_tracked', persistent=True), - InputSpec(kind=, arg=TensorArgument(name='x'), target=None, persistent=None) - ], - output_specs=[ - OutputSpec(kind=, arg=TensorArgument(name='getitem_3'), target='bn.running_mean'), - OutputSpec(kind=, arg=TensorArgument(name='getitem_4'), target='bn.running_var'), - OutputSpec(kind=, arg=TensorArgument(name='add'), target='bn.num_batches_tracked'), - OutputSpec(kind=, arg=TensorArgument(name='getitem'), target=None) - ] + Graph signature: + ExportGraphSignature( + input_specs=[ + InputSpec( + kind=, + arg=TensorArgument(name='p_conv_weight'), + target='conv.weight', + persistent=None + ), + InputSpec( + kind=, + arg=TensorArgument(name='p_conv_bias'), + target='conv.bias', + persistent=None + ), + InputSpec( + kind=, + arg=TensorArgument(name='p_bn_weight'), + target='bn.weight', + persistent=None + ), + InputSpec( + kind=, + arg=TensorArgument(name='p_bn_bias'), + target='bn.bias', + persistent=None + ), + InputSpec( + kind=, + arg=TensorArgument(name='b_bn_running_mean'), + target='bn.running_mean', + persistent=True + ), + InputSpec( + kind=, + arg=TensorArgument(name='b_bn_running_var'), + target='bn.running_var', + persistent=True + ), + InputSpec( + kind=, + arg=TensorArgument(name='b_bn_num_batches_tracked'), + target='bn.num_batches_tracked', + persistent=True + ), + InputSpec( + kind=, + arg=TensorArgument(name='x'), + target=None, + 
persistent=None + ) + ], + output_specs=[ + OutputSpec( + kind=, + arg=TensorArgument(name='getitem_3'), + target='bn.running_mean' + ), + OutputSpec( + kind=, + arg=TensorArgument(name='getitem_4'), + target='bn.running_var' + ), + OutputSpec( + kind=, + arg=TensorArgument(name='add'), + target='bn.num_batches_tracked' + ), + OutputSpec( + kind=, + arg=TensorArgument(name='getitem'), + target=None + ) + ] ) Range constraints: {} @@ -510,58 +697,94 @@ run. Such dimensions must be specified by using the .. code-block:: ExportedProgram: - class GraphModule(torch.nn.Module): - def forward(self, arg0_1: f32[32, 64], arg1_1: f32[32], arg2_1: f32[64, 128], arg3_1: f32[64], arg4_1: f32[32], arg5_1: f32[s0, 64], arg6_1: f32[s0, 128]): + class GraphModule(torch.nn.Module): + def forward(self, p_branch1_0_weight: "f32[32, 64]", p_branch1_0_bias: "f32[32]", p_branch2_0_weight: "f32[64, 128]", p_branch2_0_bias: "f32[64]", c_buffer: "f32[32]", x1: "f32[s0, 64]", x2: "f32[s0, 128]"): - # code: out1 = self.branch1(x1) - permute: f32[64, 32] = torch.ops.aten.permute.default(arg0_1, [1, 0]); - addmm: f32[s0, 32] = torch.ops.aten.addmm.default(arg1_1, arg5_1, permute); - relu: f32[s0, 32] = torch.ops.aten.relu.default(addmm); + # code: out1 = self.branch1(x1) + linear: "f32[s0, 32]" = torch.ops.aten.linear.default(x1, p_branch1_0_weight, p_branch1_0_bias) + relu: "f32[s0, 32]" = torch.ops.aten.relu.default(linear) - # code: out2 = self.branch2(x2) - permute_1: f32[128, 64] = torch.ops.aten.permute.default(arg2_1, [1, 0]); - addmm_1: f32[s0, 64] = torch.ops.aten.addmm.default(arg3_1, arg6_1, permute_1); - relu_1: f32[s0, 64] = torch.ops.aten.relu.default(addmm_1); addmm_1 = None + # code: out2 = self.branch2(x2) + linear_1: "f32[s0, 64]" = torch.ops.aten.linear.default(x2, p_branch2_0_weight, p_branch2_0_bias) + relu_1: "f32[s0, 64]" = torch.ops.aten.relu.default(linear_1) - # code: return (out1 + self.buffer, out2) - add: f32[s0, 32] = torch.ops.aten.add.Tensor(relu, arg4_1); - 
return (add, relu_1) + # code: return (out1 + self.buffer, out2) + add: "f32[s0, 32]" = torch.ops.aten.add.Tensor(relu, c_buffer) + return (add, relu_1) - Graph signature: ExportGraphSignature( - parameters=[ - 'branch1.0.weight', - 'branch1.0.bias', - 'branch2.0.weight', - 'branch2.0.bias', + Graph signature: + ExportGraphSignature( + input_specs=[ + InputSpec( + kind=, + arg=TensorArgument(name='p_branch1_0_weight'), + target='branch1.0.weight', + persistent=None + ), + InputSpec( + kind=, + arg=TensorArgument(name='p_branch1_0_bias'), + target='branch1.0.bias', + persistent=None + ), + InputSpec( + kind=, + arg=TensorArgument(name='p_branch2_0_weight'), + target='branch2.0.weight', + persistent=None + ), + InputSpec( + kind=, + arg=TensorArgument(name='p_branch2_0_bias'), + target='branch2.0.bias', + persistent=None + ), + InputSpec( + kind=, + arg=TensorArgument(name='c_buffer'), + target='buffer', + persistent=True + ), + InputSpec( + kind=, + arg=TensorArgument(name='x1'), + target=None, + persistent=None + ), + InputSpec( + kind=, + arg=TensorArgument(name='x2'), + target=None, + persistent=None + ) ], - buffers=['L__self___buffer'], - user_inputs=['arg5_1', 'arg6_1'], - user_outputs=['add', 'relu_1'], - inputs_to_parameters={ - 'arg0_1': 'branch1.0.weight', - 'arg1_1': 'branch1.0.bias', - 'arg2_1': 'branch2.0.weight', - 'arg3_1': 'branch2.0.bias', - }, - inputs_to_buffers={'arg4_1': 'L__self___buffer'}, - buffers_to_mutate={}, - backward_signature=None, - assertion_dep_token=None, + output_specs=[ + OutputSpec( + kind=, + arg=TensorArgument(name='add'), + target=None + ), + OutputSpec( + kind=, + arg=TensorArgument(name='relu_1'), + target=None + ) + ] ) - Range constraints: {s0: RangeConstraint(min_val=2, max_val=9223372036854775806)} + Range constraints: {s0: VR[0, int_oo]} Some additional things to note: * Through the :func:`torch.export.Dim` API and the ``dynamic_shapes`` argument, we specified the first - dimension of each input to be dynamic. 
Looking at the inputs ``arg5_1`` and - ``arg6_1``, they have a symbolic shape of (s0, 64) and (s0, 128), instead of + dimension of each input to be dynamic. Looking at the inputs ``x1`` and + ``x2``, they have a symbolic shape of (s0, 64) and (s0, 128), instead of the (32, 64) and (32, 128) shaped tensors that we passed in as example inputs. ``s0`` is a symbol representing that this dimension can be a range of values. * ``exported_program.range_constraints`` describes the ranges of each symbol appearing in the graph. In this case, we see that ``s0`` has the range - [2, inf]. For technical reasons that are difficult to explain here, they are + [0, int_oo]. For technical reasons that are difficult to explain here, they are assumed to be not 0 or 1. This is not a bug, and does not necessarily mean that the exported program will not work for dimensions 0 or 1. See `The 0/1 Specialization Problem `_ @@ -591,21 +814,37 @@ another, or a shape is even. An example: ExportedProgram: class GraphModule(torch.nn.Module): - def forward(self, arg0_1: "f32[s0]", arg1_1: "f32[s0 + 1]"): + def forward(self, x: "f32[s0]", y: "f32[s0 + 1]"): # code: return x + y[1:] - slice_1: "f32[s0]" = torch.ops.aten.slice.Tensor(arg1_1, 0, 1, 9223372036854775807); arg1_1 = None - add: "f32[s0]" = torch.ops.aten.add.Tensor(arg0_1, slice_1); arg0_1 = slice_1 = None + slice_1: "f32[s0]" = torch.ops.aten.slice.Tensor(y, 0, 1, 9223372036854775807) + add: "f32[s0]" = torch.ops.aten.add.Tensor(x, slice_1) return (add,) - Graph signature: ExportGraphSignature( - input_specs=[ - InputSpec(kind=, arg=TensorArgument(name='arg0_1'), target=None, persistent=None), - InputSpec(kind=, arg=TensorArgument(name='arg1_1'), target=None, persistent=None) - ], - output_specs=[ - OutputSpec(kind=, arg=TensorArgument(name='add'), target=None)] - ) - Range constraints: {s0: ValueRanges(lower=3, upper=6, is_bool=False), s0 + 1: ValueRanges(lower=4, upper=7, is_bool=False)} + Graph signature: + ExportGraphSignature( + 
input_specs=[ + InputSpec( + kind=, + arg=TensorArgument(name='x'), + target=None, + persistent=None + ), + InputSpec( + kind=, + arg=TensorArgument(name='y'), + target=None, + persistent=None + ) + ], + output_specs=[ + OutputSpec( + kind=, + arg=TensorArgument(name='add'), + target=None + ) + ] + ) + Range constraints: {s0: VR[3, 6], s0 + 1: VR[4, 7]} Some things to note: @@ -613,8 +852,8 @@ Some things to note: shape of the first input is now dynamic, being ``[s0]``. And now by specifying ``{0: dimy}`` for the second input, we see that the resulting shape of the second input is also dynamic. However, because we expressed ``dimy = dimx + 1``, - instead of ``arg1_1``'s shape containing a new symbol, we see that it is - now being represented with the same symbol used in ``arg0_1``, ``s0``. We can + instead of ``y``'s shape containing a new symbol, we see that it is + now being represented with the same symbol used in ``x``, ``s0``. We can see that relationship of ``dimy = dimx + 1`` is being shown through ``s0 + 1``. * Looking at the range constraints, we see that ``s0`` has the range [3, 6], @@ -700,10 +939,11 @@ that is being taken with the given sample inputs. For example: .. 
code-block:: ExportedProgram: - class GraphModule(torch.nn.Module): - def forward(self, arg0_1: f32[10, 2]): - add: f32[10, 2] = torch.ops.aten.add.Tensor(arg0_1, 1); - return (add,) + class GraphModule(torch.nn.Module): + def forward(self, x: "f32[10, 2]"): + # code: return x + 1 + add: "f32[10, 2]" = torch.ops.aten.add.Tensor(x, 1) + return (add,) The conditional of (``x.shape[0] > 5``) does not appear in the ``ExportedProgram`` because the example inputs have the static @@ -745,19 +985,20 @@ For example: ExportedProgram: class GraphModule(torch.nn.Module): - def forward(self, arg0_1: f32[2, 2], arg1_1, arg2_1): - add: f32[2, 2] = torch.ops.aten.add.Tensor(arg0_1, 1); - add_1: f32[2, 2] = torch.ops.aten.add.Tensor(add, 1); - add_2: f32[2, 2] = torch.ops.aten.add.Tensor(add_1, 1); + def forward(self, x: "f32[2, 2]", const, times): + # code: x = x + const + add: "f32[2, 2]" = torch.ops.aten.add.Tensor(x, 1) + add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(add, 1) + add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, 1) return (add_2,) Because integers are specialized, the ``torch.ops.aten.add.Tensor`` operations -are all computed with the hard-coded constant ``1``, rather than ``arg1_1``. If -a user passes a different value for ``arg1_1`` at runtime, like 2, than the one used +are all computed with the hard-coded constant ``1``, rather than ``const``. If +a user passes a different value for ``const`` at runtime, like 2, than the one used during export time, 1, this will result in an error. Additionally, the ``times`` iterator used in the ``for`` loop is also "inlined" in the graph through the 3 repeated ``torch.ops.aten.add.Tensor`` calls, and the -input ``arg2_1`` is never used. +input ``times`` is never used. 
Python Containers ~~~~~~~~~~~~~~~~~ diff --git a/torch/export/__init__.py b/torch/export/__init__.py index dbe4f2b72ed2..fc890c78ecc7 100644 --- a/torch/export/__init__.py +++ b/torch/export/__init__.py @@ -589,22 +589,27 @@ def register_dataclass( Example:: + import torch + from dataclasses import dataclass + @dataclass class InputDataClass: feature: torch.Tensor bias: int + @dataclass class OutputDataClass: res: torch.Tensor torch.export.register_dataclass(InputDataClass) torch.export.register_dataclass(OutputDataClass) - def fn(o: InputDataClass) -> torch.Tensor: - res = res=o.feature + o.bias - return OutputDataClass(res=res) + class Mod(torch.nn.Module): + def forward(self, x: InputDataClass) -> OutputDataClass: + res = x.feature + x.bias + return OutputDataClass(res=res) - ep = torch.export.export(fn, (InputDataClass(torch.ones(2, 2), 1), )) + ep = torch.export.export(Mod(), (InputDataClass(torch.ones(2, 2), 1), )) print(ep) """