mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add OpOverload.decompose API (#83075)
This allows you to directly call into the CompositeImplicitAutograd implementation of an operator, *without* changing any aspects of the dispatcher state. In particular, you can use this to recursively call into a decomposition, dispatching back to your tensor subclass/mode as desired. Hypothetically, we should also make these available in the decompositions dictionary, but I'm leaving this as future work as enumerating these decompositions is annoying (as operators are lazily registered.) Signed-off-by: Edward Z. Yang <ezyang@fb.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/83075 Approved by: https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
9b4dc56c83
commit
988bd0173c
@ -782,8 +782,7 @@ def forward(self, a_1):
|
||||
def forward(self, a_1):
|
||||
zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
|
||||
diagonal_copy_default = torch.ops.aten.diagonal_copy.default(zeros); zeros = None
|
||||
copy_default = torch.ops.aten.copy.default(diagonal_copy_default, a_1); diagonal_copy_default = None
|
||||
add_tensor = torch.ops.aten.add.Tensor(copy_default, a_1); copy_default = a_1 = None
|
||||
add_tensor = torch.ops.aten.add.Tensor(a_1, a_1); a_1 = None
|
||||
return add_tensor
|
||||
""")
|
||||
|
||||
@ -795,9 +794,8 @@ def forward(self, a_1):
|
||||
def forward(self, a_1):
|
||||
zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
|
||||
diagonal_default = torch.ops.aten.diagonal.default(zeros); zeros = None
|
||||
copy_default = torch.ops.aten.copy_.default(diagonal_default, a_1)
|
||||
add_tensor = torch.ops.aten.add_.Tensor(diagonal_default, a_1); a_1 = None
|
||||
return diagonal_default
|
||||
add_tensor = torch.ops.aten.add.Tensor(a_1, a_1); a_1 = None
|
||||
return add_tensor
|
||||
""")
|
||||
|
||||
# Test 2: copy_() with same dtype, different shape
|
||||
@ -810,8 +808,8 @@ def forward(self, a_1):
|
||||
def forward(self, a_1):
|
||||
zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
|
||||
diagonal_copy_default = torch.ops.aten.diagonal_copy.default(zeros); zeros = None
|
||||
copy_default = torch.ops.aten.copy.default(diagonal_copy_default, a_1); diagonal_copy_default = None
|
||||
add_tensor = torch.ops.aten.add.Tensor(copy_default, a_1); copy_default = a_1 = None
|
||||
expand_copy_default = torch.ops.aten.expand_copy.default(a_1, [2])
|
||||
add_tensor = torch.ops.aten.add.Tensor(expand_copy_default, a_1); expand_copy_default = a_1 = None
|
||||
return add_tensor
|
||||
""")
|
||||
|
||||
@ -823,9 +821,9 @@ def forward(self, a_1):
|
||||
def forward(self, a_1):
|
||||
zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
|
||||
diagonal_default = torch.ops.aten.diagonal.default(zeros); zeros = None
|
||||
copy_default = torch.ops.aten.copy_.default(diagonal_default, a_1)
|
||||
add_tensor = torch.ops.aten.add_.Tensor(diagonal_default, a_1); a_1 = None
|
||||
return diagonal_default
|
||||
expand_copy_default = torch.ops.aten.expand_copy.default(a_1, [2])
|
||||
add_tensor = torch.ops.aten.add_.Tensor(expand_copy_default, a_1); a_1 = None
|
||||
return expand_copy_default
|
||||
""")
|
||||
|
||||
# Test 3: copy_() with different dtype, same shape
|
||||
@ -838,10 +836,10 @@ def forward(self, a_1):
|
||||
def forward(self, a_1):
|
||||
zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
|
||||
diagonal_copy_default = torch.ops.aten.diagonal_copy.default(zeros); zeros = None
|
||||
copy_default = torch.ops.aten.copy.default(diagonal_copy_default, a_1); diagonal_copy_default = None
|
||||
add_tensor = torch.ops.aten.add.Tensor(copy_default, a_1); copy_default = a_1 = None
|
||||
_to_copy_default = torch.ops.aten._to_copy.default(a_1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
|
||||
add_tensor = torch.ops.aten.add.Tensor(_to_copy_default, a_1); _to_copy_default = a_1 = None
|
||||
return add_tensor
|
||||
""")
|
||||
""") # noqa: B950
|
||||
|
||||
reinplaced_logs = self.get_logs(f, torch.ones(2, dtype=torch.long), reapply_views=True, run_reinplace=True)
|
||||
self.assertExpectedInline(reinplaced_logs, """\
|
||||
@ -851,9 +849,9 @@ def forward(self, a_1):
|
||||
def forward(self, a_1):
|
||||
zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
|
||||
diagonal_default = torch.ops.aten.diagonal.default(zeros); zeros = None
|
||||
copy_default = torch.ops.aten.copy_.default(diagonal_default, a_1)
|
||||
add_tensor = torch.ops.aten.add_.Tensor(diagonal_default, a_1); a_1 = None
|
||||
return diagonal_default
|
||||
_to_copy_default = torch.ops.aten._to_copy.default(a_1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
|
||||
add_tensor = torch.ops.aten.add_.Tensor(_to_copy_default, a_1); a_1 = None
|
||||
return _to_copy_default
|
||||
""") # noqa: B950
|
||||
|
||||
# Test 4: copy_() with different dtype, different shape
|
||||
@ -866,10 +864,11 @@ def forward(self, a_1):
|
||||
def forward(self, a_1):
|
||||
zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
|
||||
diagonal_copy_default = torch.ops.aten.diagonal_copy.default(zeros); zeros = None
|
||||
copy_default = torch.ops.aten.copy.default(diagonal_copy_default, a_1); diagonal_copy_default = None
|
||||
add_tensor = torch.ops.aten.add.Tensor(copy_default, a_1); copy_default = a_1 = None
|
||||
_to_copy_default = torch.ops.aten._to_copy.default(a_1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
|
||||
expand_copy_default = torch.ops.aten.expand_copy.default(_to_copy_default, [2]); _to_copy_default = None
|
||||
add_tensor = torch.ops.aten.add.Tensor(expand_copy_default, a_1); expand_copy_default = a_1 = None
|
||||
return add_tensor
|
||||
""")
|
||||
""") # noqa: B950
|
||||
|
||||
reinplaced_logs = self.get_logs(f, torch.ones(1, dtype=torch.long), reapply_views=True, run_reinplace=True)
|
||||
self.assertExpectedInline(reinplaced_logs, """\
|
||||
@ -879,9 +878,10 @@ def forward(self, a_1):
|
||||
def forward(self, a_1):
|
||||
zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
|
||||
diagonal_default = torch.ops.aten.diagonal.default(zeros); zeros = None
|
||||
copy_default = torch.ops.aten.copy_.default(diagonal_default, a_1)
|
||||
add_tensor = torch.ops.aten.add_.Tensor(diagonal_default, a_1); a_1 = None
|
||||
return diagonal_default
|
||||
_to_copy_default = torch.ops.aten._to_copy.default(a_1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
|
||||
expand_copy_default = torch.ops.aten.expand_copy.default(_to_copy_default, [2]); _to_copy_default = None
|
||||
add_tensor = torch.ops.aten.add_.Tensor(expand_copy_default, a_1); a_1 = None
|
||||
return expand_copy_default
|
||||
""") # noqa: B950
|
||||
|
||||
def test_expand_symint(self):
|
||||
|
Reference in New Issue
Block a user