Add OpOverload.decompose API (#83075)

This allows you to directly call into the CompositeImplicitAutograd
implementation of an operator, *without* changing any aspects of the
dispatcher state.  In particular, you can use this to recursively call
into a decomposition, dispatching back to your tensor subclass/mode
as desired.
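
As a minimal sketch of the pattern this enables (the mode class and the
op used below are illustrative, not part of this PR):

import torch
from torch.utils._python_dispatch import TorchDispatchMode

class DecomposeMode(TorchDispatchMode):
    # Illustrative mode: run CompositeImplicitAutograd ops via their
    # decompositions, and run everything else as usual.
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # Calls straight into the CompositeImplicitAutograd
        # implementation, leaving dispatcher state untouched; returns
        # NotImplemented if this overload has no such implementation.
        out = func.decompose(*args, **kwargs)
        if out is not NotImplemented:
            return out
        return func(*args, **kwargs)

with DecomposeMode():
    # aten::matmul is CompositeImplicitAutograd, so under this mode it
    # executes via its decomposition (mm/bmm/...).
    r = torch.randn(2, 3) @ torch.randn(3, 4)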

Hypothetically, we should also make these available in the
decompositions dictionary, but I'm leaving this as future work since
enumerating these decompositions is annoying (operators are lazily
registered).
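
In the meantime, you can probe a given overload at call time instead of
consulting a dictionary. A small sketch (op choices arbitrary; assumes,
per this PR, that decompose returns NotImplemented when there is no
CompositeImplicitAutograd kernel):

import torch

a, b = torch.randn(2, 3), torch.randn(3, 4)

# aten::matmul has a CompositeImplicitAutograd implementation, so
# decompose() runs it and returns the result:
out = torch.ops.aten.matmul.default.decompose(a, b)
assert torch.allclose(out, a @ b)

# aten::mm has real kernels but no CompositeImplicitAutograd
# implementation, so there is nothing to decompose into:
assert torch.ops.aten.mm.default.decompose(a, b) is NotImplemented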

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83075
Approved by: https://github.com/albanD
Author: Edward Z. Yang
Date: 2022-08-09 08:35:50 -07:00
Committed by: PyTorch MergeBot
Commit: 988bd0173c (parent 9b4dc56c83)
12 changed files with 138 additions and 36 deletions

@@ -782,8 +782,7 @@ def forward(self, a_1):
 def forward(self, a_1):
     zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
     diagonal_copy_default = torch.ops.aten.diagonal_copy.default(zeros); zeros = None
-    copy_default = torch.ops.aten.copy.default(diagonal_copy_default, a_1); diagonal_copy_default = None
-    add_tensor = torch.ops.aten.add.Tensor(copy_default, a_1); copy_default = a_1 = None
+    add_tensor = torch.ops.aten.add.Tensor(a_1, a_1); a_1 = None
     return add_tensor
 """)
@@ -795,9 +794,8 @@ def forward(self, a_1):
 def forward(self, a_1):
     zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
     diagonal_default = torch.ops.aten.diagonal.default(zeros); zeros = None
-    copy_default = torch.ops.aten.copy_.default(diagonal_default, a_1)
-    add_tensor = torch.ops.aten.add_.Tensor(diagonal_default, a_1); a_1 = None
-    return diagonal_default
+    add_tensor = torch.ops.aten.add.Tensor(a_1, a_1); a_1 = None
+    return add_tensor
 """)
         # Test 2: copy_() with same dtype, different shape
@@ -810,8 +808,8 @@ def forward(self, a_1):
 def forward(self, a_1):
     zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
     diagonal_copy_default = torch.ops.aten.diagonal_copy.default(zeros); zeros = None
-    copy_default = torch.ops.aten.copy.default(diagonal_copy_default, a_1); diagonal_copy_default = None
-    add_tensor = torch.ops.aten.add.Tensor(copy_default, a_1); copy_default = a_1 = None
+    expand_copy_default = torch.ops.aten.expand_copy.default(a_1, [2])
+    add_tensor = torch.ops.aten.add.Tensor(expand_copy_default, a_1); expand_copy_default = a_1 = None
     return add_tensor
 """)
@@ -823,9 +821,9 @@ def forward(self, a_1):
 def forward(self, a_1):
     zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
     diagonal_default = torch.ops.aten.diagonal.default(zeros); zeros = None
-    copy_default = torch.ops.aten.copy_.default(diagonal_default, a_1)
-    add_tensor = torch.ops.aten.add_.Tensor(diagonal_default, a_1); a_1 = None
-    return diagonal_default
+    expand_copy_default = torch.ops.aten.expand_copy.default(a_1, [2])
+    add_tensor = torch.ops.aten.add_.Tensor(expand_copy_default, a_1); a_1 = None
+    return expand_copy_default
 """)
         # Test 3: copy_() with different dtype, same shape
@@ -838,10 +836,10 @@ def forward(self, a_1):
 def forward(self, a_1):
     zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
     diagonal_copy_default = torch.ops.aten.diagonal_copy.default(zeros); zeros = None
-    copy_default = torch.ops.aten.copy.default(diagonal_copy_default, a_1); diagonal_copy_default = None
-    add_tensor = torch.ops.aten.add.Tensor(copy_default, a_1); copy_default = a_1 = None
+    _to_copy_default = torch.ops.aten._to_copy.default(a_1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
+    add_tensor = torch.ops.aten.add.Tensor(_to_copy_default, a_1); _to_copy_default = a_1 = None
     return add_tensor
-""")
+""") # noqa: B950
         reinplaced_logs = self.get_logs(f, torch.ones(2, dtype=torch.long), reapply_views=True, run_reinplace=True)
         self.assertExpectedInline(reinplaced_logs, """\
@@ -851,9 +849,9 @@ def forward(self, a_1):
 def forward(self, a_1):
     zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
     diagonal_default = torch.ops.aten.diagonal.default(zeros); zeros = None
-    copy_default = torch.ops.aten.copy_.default(diagonal_default, a_1)
-    add_tensor = torch.ops.aten.add_.Tensor(diagonal_default, a_1); a_1 = None
-    return diagonal_default
+    _to_copy_default = torch.ops.aten._to_copy.default(a_1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
+    add_tensor = torch.ops.aten.add_.Tensor(_to_copy_default, a_1); a_1 = None
+    return _to_copy_default
 """) # noqa: B950
         # Test 4: copy_() with different dtype, different shape
@@ -866,10 +864,11 @@ def forward(self, a_1):
 def forward(self, a_1):
     zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
     diagonal_copy_default = torch.ops.aten.diagonal_copy.default(zeros); zeros = None
-    copy_default = torch.ops.aten.copy.default(diagonal_copy_default, a_1); diagonal_copy_default = None
-    add_tensor = torch.ops.aten.add.Tensor(copy_default, a_1); copy_default = a_1 = None
+    _to_copy_default = torch.ops.aten._to_copy.default(a_1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
+    expand_copy_default = torch.ops.aten.expand_copy.default(_to_copy_default, [2]); _to_copy_default = None
+    add_tensor = torch.ops.aten.add.Tensor(expand_copy_default, a_1); expand_copy_default = a_1 = None
     return add_tensor
-""")
+""") # noqa: B950
         reinplaced_logs = self.get_logs(f, torch.ones(1, dtype=torch.long), reapply_views=True, run_reinplace=True)
         self.assertExpectedInline(reinplaced_logs, """\
@@ -879,9 +878,10 @@ def forward(self, a_1):
 def forward(self, a_1):
     zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
     diagonal_default = torch.ops.aten.diagonal.default(zeros); zeros = None
-    copy_default = torch.ops.aten.copy_.default(diagonal_default, a_1)
-    add_tensor = torch.ops.aten.add_.Tensor(diagonal_default, a_1); a_1 = None
-    return diagonal_default
+    _to_copy_default = torch.ops.aten._to_copy.default(a_1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
+    expand_copy_default = torch.ops.aten.expand_copy.default(_to_copy_default, [2]); _to_copy_default = None
+    add_tensor = torch.ops.aten.add_.Tensor(expand_copy_default, a_1); a_1 = None
+    return expand_copy_default
 """) # noqa: B950
     def test_expand_symint(self):