Add optional recomputable_ops argument for the min cut partitioner (#86686)

`min_cut_rematerialization_partition` has a hard-coded default set of operations that are allowed to be recomputed in the backward pass.
This PR makes that set customizable: users can pass their own `recomputable_ops` to control which operations may be recomputed instead of relying on the default setting (a short usage sketch is included below, before the diff).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86686
Approved by: https://github.com/Chillee
Author:    Ivan Yashchuk
Date:      2022-10-14 12:15:28 +00:00
Committed: PyTorch MergeBot
Commit:    2cfc4cb367 (parent fd80684784)

2 changed files with 68 additions and 5 deletions
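For illustration, a minimal usage sketch of the new argument (not part of this PR). It assumes the `functorch.compile` entry points `aot_function` and `min_cut_rematerialization_partition`, and uses a hypothetical no-op compiler so that only the partitioner is exercised:

from functools import partial

import torch
from functorch.compile import aot_function, min_cut_rematerialization_partition


def f(x):
    return (x * x * x).sum()


# Allow only aten.mul to be recomputed in the backward pass; any other
# intermediate produced in the forward pass will be saved instead.
partition_fn = partial(
    min_cut_rematerialization_partition,
    recomputable_ops=[torch.ops.aten.mul],
)


def nop_compiler(fx_module, example_inputs):
    # Return the traced fx.GraphModule unchanged; it is directly callable.
    return fx_module


compiled_f = aot_function(f, fw_compiler=nop_compiler, bw_compiler=nop_compiler, partition_fn=partition_fn)
compiled_f(torch.randn(3, requires_grad=True)).backward()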

File 1 of 2:

@@ -231,7 +231,7 @@ def _count_ops(graph):
 def min_cut_rematerialization_partition(
-    joint_module: fx.GraphModule, _joint_inputs, compiler="nvfuser"
+    joint_module: fx.GraphModule, _joint_inputs, compiler="nvfuser", recomputable_ops=None,
 ) -> Tuple[fx.GraphModule, fx.GraphModule]:
     """
     Partitions the joint graph such that the backward recomputes the forward.
@@ -247,6 +247,12 @@ def min_cut_rematerialization_partition(
     Args:
         joint_module(fx.GraphModule): The joint forward and backward graph. This
             is the result of AOT Autograd tracing.
+        _joint_inputs: The inputs to the joint graph. This is unused.
+        compiler: This option determines the default set of recomputable ops.
+            Currently, there are two options: ``nvfuser`` and ``inductor``.
+        recomputable_ops: This is an optional set of recomputable ops. If this
+            is not None, then this set of ops will be used instead of the
+            default set of ops.

     Returns:
         Returns the generated forward and backward Fx graph modules.
@@ -299,13 +305,14 @@ def min_cut_rematerialization_partition(
     aten = torch.ops.aten
     prims = torch.ops.prims
-    recomputable_ops = [aten.add, aten.sub, aten.div, aten.atan2, aten.mul, aten.max, aten.min, aten.pow, aten.remainder, aten.fmod, aten.__and__, aten.__or__, aten.__xor__, aten.__lshift__, aten.__rshift__, aten.eq, aten.ne, aten.ge, aten.gt, aten.le, aten.lt, aten.abs, aten.bitwise_not, aten.ceil, aten.floor, aten.frac, aten.neg, aten.relu, aten.round, aten.silu, aten.trunc, aten.log, aten.log10, aten.log1p, aten.log2, aten.lgamma, aten.exp, aten.expm1, aten.erf, aten.erfc, aten.cos, aten.acos, aten.cosh, aten.sin, aten.asin, aten.sinh, aten.tan, aten.atan, aten.tanh, aten.atanh, aten.sqrt, aten.rsqrt, aten.reciprocal, aten.sigmoid, aten.softplus, aten.threshold, aten.threshold_backward, aten.clamp, aten.where, aten.lerp, aten.addcmul, aten.gelu, aten.gelu_backward, aten.alias, aten.sum, aten.mean, aten._grad_sum_to_size, aten.sum_to_size, aten.amax, aten.to, aten.type_as, operator.getitem, aten.squeeze, aten.unsqueeze, aten.rsub, aten._to_copy]  # noqa: E501
+    # compiler == "nvfuser" is the default set of recomputable ops
+    default_recomputable_ops = [aten.add, aten.sub, aten.div, aten.atan2, aten.mul, aten.max, aten.min, aten.pow, aten.remainder, aten.fmod, aten.__and__, aten.__or__, aten.__xor__, aten.__lshift__, aten.__rshift__, aten.eq, aten.ne, aten.ge, aten.gt, aten.le, aten.lt, aten.abs, aten.bitwise_not, aten.ceil, aten.floor, aten.frac, aten.neg, aten.relu, aten.round, aten.silu, aten.trunc, aten.log, aten.log10, aten.log1p, aten.log2, aten.lgamma, aten.exp, aten.expm1, aten.erf, aten.erfc, aten.cos, aten.acos, aten.cosh, aten.sin, aten.asin, aten.sinh, aten.tan, aten.atan, aten.tanh, aten.atanh, aten.sqrt, aten.rsqrt, aten.reciprocal, aten.sigmoid, aten.softplus, aten.threshold, aten.threshold_backward, aten.clamp, aten.where, aten.lerp, aten.addcmul, aten.gelu, aten.gelu_backward, aten.alias, aten.sum, aten.mean, aten._grad_sum_to_size, aten.sum_to_size, aten.amax, aten.to, aten.type_as, operator.getitem, aten.squeeze, aten.unsqueeze, aten.rsub, aten._to_copy]  # noqa: E501
     if compiler == "inductor":
-        recomputable_ops += [prims.div, prims.convert_element_type, aten.sign, aten.clone, aten._to_copy, aten.full_like, prims.var, prims.sum, aten.var, aten.std, prims.broadcast_in_dim, aten.select, aten.permute, aten._unsafe_view, aten.view, aten.expand, aten.slice, aten.reshape, aten.broadcast_tensors, aten.scalar_tensor, aten.ones, aten.new_zeros, aten.lift_fresh_copy, aten.minimum, aten.arange, aten.bitwise_and, aten.triu, aten.var_mean, aten.isinf, aten.any, aten.isnan, aten.full, aten.as_strided, aten.zeros, aten.argmax, aten.maximum, aten.bitwise_or, aten.logical_and, aten.logical_or]  # noqa: E501
+        default_recomputable_ops += [prims.div, prims.convert_element_type, aten.sign, aten.clone, aten._to_copy, aten.full_like, prims.var, prims.sum, aten.var, aten.std, prims.broadcast_in_dim, aten.select, aten.permute, aten._unsafe_view, aten.view, aten.expand, aten.slice, aten.reshape, aten.broadcast_tensors, aten.scalar_tensor, aten.ones, aten.new_zeros, aten.lift_fresh_copy, aten.minimum, aten.arange, aten.bitwise_and, aten.triu, aten.var_mean, aten.isinf, aten.any, aten.isnan, aten.full, aten.as_strided, aten.zeros, aten.argmax, aten.maximum, aten.bitwise_or, aten.logical_and, aten.logical_or]  # noqa: E501
     # Natalia said that we should allow recomputing indexing :)
-    recomputable_ops += [aten.index]
+    default_recomputable_ops += [aten.index]
-    recomputable_ops = set(recomputable_ops)
+    recomputable_ops = set(recomputable_ops) if recomputable_ops is not None else set(default_recomputable_ops)
     random_ops = [aten.native_dropout, aten.rand_like, aten.randn_like]
     compute_intensive_ops = [aten.mm, aten.convolution, aten.convolution_backward, aten.bmm, aten.addmm, aten.upsample_bilinear2d, aten._softmax, aten._softmax_backward_data, aten.native_layer_norm, aten.native_layer_norm_backward, aten.native_batch_norm, aten.native_batch_norm_backward]  # noqa: E501
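
Note that a user-provided `recomputable_ops` replaces the default allow-list entirely rather than extending it; passing an empty list therefore disables recomputation altogether, as the new test below shows. A condensed, hypothetical paraphrase of the selection logic above (not code from the PR):

# Hypothetical condensed paraphrase of the diff above.
def _resolve_recomputable_ops(recomputable_ops, compiler):
    default_recomputable_ops = [...]       # pointwise/reduction aten ops listed above
    if compiler == "inductor":
        default_recomputable_ops += [...]  # extra prims/aten ops listed above
    # A user-supplied set wins outright; None falls back to the defaults.
    return set(recomputable_ops) if recomputable_ops is not None else set(default_recomputable_ops)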

File 2 of 2:

@@ -757,6 +757,62 @@ class TestPartitioning(AOTTestCase):
         ins, outs = get_ins_outs(fw_graph)
         self.assertEqual(outs[1].target, torch.ops.aten.mm.default)
 
+    @unittest.skipIf(not USE_NETWORKX, "networkx not available")
+    def test_min_cut_partitioner_recomputable_ops(self):
+        def f(x):
+            return x * x * x
+
+        recomputable_ops = []
+        partition_fn = partial(min_cut_rematerialization_partition, recomputable_ops=recomputable_ops)
+
+        fw_graph, bw_graph = get_fw_bw_graph(f, [torch.randn(3, requires_grad=True)], partition_fn)
+        # Expected forward graph:
+        # opcode         name       target           args                        kwargs
+        # -------------  ---------  ---------------  --------------------------  --------
+        # placeholder    primals_1  primals_1        ()                          {}
+        # call_function  mul        aten.mul.Tensor  (primals_1, primals_1)      {}
+        # call_function  mul_1      aten.mul.Tensor  (mul, primals_1)            {}
+        # output         output     output           ([mul_1, primals_1, mul],)  {}
+        self.assertEqual(get_num_ins_outs(fw_graph), (1, 3))
+        # Expected backward graph:
+        # opcode         name        target           args                     kwargs
+        # -------------  ----------  ---------------  -----------------------  --------
+        # placeholder    primals_1   primals_1        ()                       {}
+        # placeholder    mul         mul              ()                       {}
+        # placeholder    tangents_1  tangents_1       ()                       {}
+        # call_function  mul_2       aten.mul.Tensor  (tangents_1, mul)        {}
+        # call_function  mul_3       aten.mul.Tensor  (tangents_1, primals_1)  {}
+        # call_function  mul_4       aten.mul.Tensor  (mul_3, primals_1)       {}
+        # call_function  add         aten.add.Tensor  (mul_2, mul_4)           {}
+        # call_function  add_1       aten.add.Tensor  (add, mul_4)             {}
+        # output         output      output           ([add_1],)               {}
+        self.assertEqual(get_num_ins_outs(bw_graph), (3, 1))
+
+        recomputable_ops = [torch.ops.aten.mul]
+        partition_fn = partial(min_cut_rematerialization_partition, recomputable_ops=recomputable_ops)
+        fw_graph, bw_graph = get_fw_bw_graph(f, [torch.randn(3, requires_grad=True)], partition_fn)
+        # Expected forward graph:
+        # opcode         name       target           args                    kwargs
+        # -------------  ---------  ---------------  ----------------------  --------
+        # placeholder    primals_1  primals_1        ()                      {}
+        # call_function  mul        aten.mul.Tensor  (primals_1, primals_1)  {}
+        # call_function  mul_1      aten.mul.Tensor  (mul, primals_1)        {}
+        # output         output     output           ([mul_1, primals_1],)   {}
+        self.assertEqual(get_num_ins_outs(fw_graph), (1, 2))
+        # Expected backward graph:
+        # opcode         name        target           args                     kwargs
+        # -------------  ----------  ---------------  -----------------------  --------
+        # placeholder    primals_1   primals_1        ()                       {}
+        # placeholder    tangents_1  tangents_1       ()                       {}
+        # call_function  mul         aten.mul.Tensor  (primals_1, primals_1)   {}  # RECOMPUTED
+        # call_function  mul_2       aten.mul.Tensor  (tangents_1, mul)        {}
+        # call_function  mul_3       aten.mul.Tensor  (tangents_1, primals_1)  {}
+        # call_function  mul_4       aten.mul.Tensor  (mul_3, primals_1)       {}
+        # call_function  add         aten.add.Tensor  (mul_2, mul_4)           {}
+        # call_function  add_1       aten.add.Tensor  (add, mul_4)             {}
+        # output         output      output           ([add_1],)               {}
+        self.assertEqual(get_num_ins_outs(bw_graph), (2, 1))
+
     def test_contiguous(self):
         # The test simulates the condition where transpose followed by view
         # happens in the backward pass.