Add optional recomputable_ops argument for the min cut partitioner (#86686)
`min_cut_rematerialization_partition` has a default, hard-coded set of operations that are allowed to be recomputed in the backward pass. This PR makes that set customizable: users can control the behavior by passing their own `recomputable_ops` instead of relying on the default set.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86686
Approved by: https://github.com/Chillee
Committed by: PyTorch MergeBot
Parent: fd80684784
Commit: 2cfc4cb367
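For context, a minimal sketch of how a caller might wire the new argument in. This is not part of the diff; it assumes the `functorch.compile` entry points available around this release (`aot_function`, `nop`, and the partitioner itself), and the restriction to `aten.mul` is purely illustrative.

from functools import partial

import torch
from functorch.compile import aot_function, min_cut_rematerialization_partition, nop

def f(x):
    return x * x * x

# Only aten.mul may be recomputed in the backward pass; anything else the
# backward needs must be saved from the forward instead of rematerialized.
partition_fn = partial(
    min_cut_rematerialization_partition,
    recomputable_ops=[torch.ops.aten.mul],
)

# nop leaves the partitioned fx graphs uncompiled, which is enough to
# exercise the partitioner itself.
compiled_f = aot_function(f, fw_compiler=nop, bw_compiler=nop, partition_fn=partition_fn)
out = compiled_f(torch.randn(3, requires_grad=True))
out.sum().backward()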
@@ -231,7 +231,7 @@ def _count_ops(graph):
 
 
 def min_cut_rematerialization_partition(
-    joint_module: fx.GraphModule, _joint_inputs, compiler="nvfuser"
+    joint_module: fx.GraphModule, _joint_inputs, compiler="nvfuser", recomputable_ops=None,
 ) -> Tuple[fx.GraphModule, fx.GraphModule]:
     """
     Partitions the joint graph such that the backward recomputes the forward.
@@ -247,6 +247,12 @@ def min_cut_rematerialization_partition(
     Args:
         joint_module(fx.GraphModule): The joint forward and backward graph. This
             is the result of AOT Autograd tracing.
+        _joint_inputs: The inputs to the joint graph. This is unused.
+        compiler: This option determines the default set of recomputable ops.
+            Currently, there are two options: ``nvfuser`` and ``inductor``.
+        recomputable_ops: This is an optional set of recomputable ops. If this
+            is not None, then this set of ops will be used instead of the
+            default set of ops.
 
     Returns:
         Returns the generated forward and backward Fx graph modules.
@@ -299,13 +305,14 @@ def min_cut_rematerialization_partition(
     aten = torch.ops.aten
     prims = torch.ops.prims
 
-    recomputable_ops = [aten.add, aten.sub, aten.div, aten.atan2, aten.mul, aten.max, aten.min, aten.pow, aten.remainder, aten.fmod, aten.__and__, aten.__or__, aten.__xor__, aten.__lshift__, aten.__rshift__, aten.eq, aten.ne, aten.ge, aten.gt, aten.le, aten.lt, aten.abs, aten.bitwise_not, aten.ceil, aten.floor, aten.frac, aten.neg, aten.relu, aten.round, aten.silu, aten.trunc, aten.log, aten.log10, aten.log1p, aten.log2, aten.lgamma, aten.exp, aten.expm1, aten.erf, aten.erfc, aten.cos, aten.acos, aten.cosh, aten.sin, aten.asin, aten.sinh, aten.tan, aten.atan, aten.tanh, aten.atanh, aten.sqrt, aten.rsqrt, aten.reciprocal, aten.sigmoid, aten.softplus, aten.threshold, aten.threshold_backward, aten.clamp, aten.where, aten.lerp, aten.addcmul, aten.gelu, aten.gelu_backward, aten.alias, aten.sum, aten.mean, aten._grad_sum_to_size, aten.sum_to_size, aten.amax, aten.to, aten.type_as, operator.getitem, aten.squeeze, aten.unsqueeze, aten.rsub, aten._to_copy]  # noqa: E501
+    # compiler == "nvfuser" is the default set of recomputable ops
+    default_recomputable_ops = [aten.add, aten.sub, aten.div, aten.atan2, aten.mul, aten.max, aten.min, aten.pow, aten.remainder, aten.fmod, aten.__and__, aten.__or__, aten.__xor__, aten.__lshift__, aten.__rshift__, aten.eq, aten.ne, aten.ge, aten.gt, aten.le, aten.lt, aten.abs, aten.bitwise_not, aten.ceil, aten.floor, aten.frac, aten.neg, aten.relu, aten.round, aten.silu, aten.trunc, aten.log, aten.log10, aten.log1p, aten.log2, aten.lgamma, aten.exp, aten.expm1, aten.erf, aten.erfc, aten.cos, aten.acos, aten.cosh, aten.sin, aten.asin, aten.sinh, aten.tan, aten.atan, aten.tanh, aten.atanh, aten.sqrt, aten.rsqrt, aten.reciprocal, aten.sigmoid, aten.softplus, aten.threshold, aten.threshold_backward, aten.clamp, aten.where, aten.lerp, aten.addcmul, aten.gelu, aten.gelu_backward, aten.alias, aten.sum, aten.mean, aten._grad_sum_to_size, aten.sum_to_size, aten.amax, aten.to, aten.type_as, operator.getitem, aten.squeeze, aten.unsqueeze, aten.rsub, aten._to_copy]  # noqa: E501
     if compiler == "inductor":
-        recomputable_ops += [prims.div, prims.convert_element_type, aten.sign, aten.clone, aten._to_copy, aten.full_like, prims.var, prims.sum, aten.var, aten.std, prims.broadcast_in_dim, aten.select, aten.permute, aten._unsafe_view, aten.view, aten.expand, aten.slice, aten.reshape, aten.broadcast_tensors, aten.scalar_tensor, aten.ones, aten.new_zeros, aten.lift_fresh_copy, aten.minimum, aten.arange, aten.bitwise_and, aten.triu, aten.var_mean, aten.isinf, aten.any, aten.isnan, aten.full, aten.as_strided, aten.zeros, aten.argmax, aten.maximum, aten.bitwise_or, aten.logical_and, aten.logical_or]  # noqa: E501
+        default_recomputable_ops += [prims.div, prims.convert_element_type, aten.sign, aten.clone, aten._to_copy, aten.full_like, prims.var, prims.sum, aten.var, aten.std, prims.broadcast_in_dim, aten.select, aten.permute, aten._unsafe_view, aten.view, aten.expand, aten.slice, aten.reshape, aten.broadcast_tensors, aten.scalar_tensor, aten.ones, aten.new_zeros, aten.lift_fresh_copy, aten.minimum, aten.arange, aten.bitwise_and, aten.triu, aten.var_mean, aten.isinf, aten.any, aten.isnan, aten.full, aten.as_strided, aten.zeros, aten.argmax, aten.maximum, aten.bitwise_or, aten.logical_and, aten.logical_or]  # noqa: E501
     # Natalia said that we should allow recomputing indexing :)
-    recomputable_ops += [aten.index]
+    default_recomputable_ops += [aten.index]
 
-    recomputable_ops = set(recomputable_ops)
+    recomputable_ops = set(recomputable_ops) if recomputable_ops is not None else set(default_recomputable_ops)
 
     random_ops = [aten.native_dropout, aten.rand_like, aten.randn_like]
     compute_intensive_ops = [aten.mm, aten.convolution, aten.convolution_backward, aten.bmm, aten.addmm, aten.upsample_bilinear2d, aten._softmax, aten._softmax_backward_data, aten.native_layer_norm, aten.native_layer_norm_backward, aten.native_batch_norm, aten.native_batch_norm_backward]  # noqa: E501
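Note that, per the guard above, a non-None `recomputable_ops` replaces the default list rather than extending it, so a caller who only wants to add or keep a few ops has to spell out every op they still want recomputed. A small sketch of that pattern; the particular ops listed are illustrative only:

from functools import partial

import torch
from functorch.compile import min_cut_rematerialization_partition

aten = torch.ops.aten

# Everything NOT in this set -- including ops from the built-in default list --
# will be saved by the forward pass rather than recomputed in the backward.
my_recomputable_ops = {aten.add, aten.mul, aten.neg, aten.sigmoid, aten.index}

partition_fn = partial(min_cut_rematerialization_partition, recomputable_ops=my_recomputable_ops)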
@@ -757,6 +757,62 @@ class TestPartitioning(AOTTestCase):
         ins, outs = get_ins_outs(fw_graph)
         self.assertEqual(outs[1].target, torch.ops.aten.mm.default)
 
+    @unittest.skipIf(not USE_NETWORKX, "networkx not available")
+    def test_min_cut_partitioner_recomputable_ops(self):
+        def f(x):
+            return x * x * x
+
+        recomputable_ops = []
+        partition_fn = partial(min_cut_rematerialization_partition, recomputable_ops=recomputable_ops)
+
+        fw_graph, bw_graph = get_fw_bw_graph(f, [torch.randn(3, requires_grad=True)], partition_fn)
+        # Expected forward graph:
+        # opcode         name       target           args                        kwargs
+        # -------------  ---------  ---------------  --------------------------  --------
+        # placeholder    primals_1  primals_1        ()                          {}
+        # call_function  mul        aten.mul.Tensor  (primals_1, primals_1)      {}
+        # call_function  mul_1      aten.mul.Tensor  (mul, primals_1)            {}
+        # output         output     output           ([mul_1, primals_1, mul],)  {}
+        self.assertEqual(get_num_ins_outs(fw_graph), (1, 3))
+        # Expected backward graph:
+        # opcode         name        target           args                     kwargs
+        # -------------  ----------  ---------------  -----------------------  --------
+        # placeholder    primals_1   primals_1        ()                       {}
+        # placeholder    mul         mul              ()                       {}
+        # placeholder    tangents_1  tangents_1       ()                       {}
+        # call_function  mul_2       aten.mul.Tensor  (tangents_1, mul)        {}
+        # call_function  mul_3       aten.mul.Tensor  (tangents_1, primals_1)  {}
+        # call_function  mul_4       aten.mul.Tensor  (mul_3, primals_1)       {}
+        # call_function  add         aten.add.Tensor  (mul_2, mul_4)           {}
+        # call_function  add_1       aten.add.Tensor  (add, mul_4)             {}
+        # output         output      output           ([add_1],)               {}
+        self.assertEqual(get_num_ins_outs(bw_graph), (3, 1))
+
+        recomputable_ops = [torch.ops.aten.mul]
+        partition_fn = partial(min_cut_rematerialization_partition, recomputable_ops=recomputable_ops)
+        fw_graph, bw_graph = get_fw_bw_graph(f, [torch.randn(3, requires_grad=True)], partition_fn)
+        # Expected forward graph:
+        # opcode         name       target           args                    kwargs
+        # -------------  ---------  ---------------  ----------------------  --------
+        # placeholder    primals_1  primals_1        ()                      {}
+        # call_function  mul        aten.mul.Tensor  (primals_1, primals_1)  {}
+        # call_function  mul_1      aten.mul.Tensor  (mul, primals_1)        {}
+        # output         output     output           ([mul_1, primals_1],)   {}
+        self.assertEqual(get_num_ins_outs(fw_graph), (1, 2))
+        # Expected backward graph:
+        # opcode         name        target           args                     kwargs
+        # -------------  ----------  ---------------  -----------------------  --------
+        # placeholder    primals_1   primals_1        ()                       {}
+        # placeholder    tangents_1  tangents_1       ()                       {}
+        # call_function  mul         aten.mul.Tensor  (primals_1, primals_1)   {}  # RECOMPUTED
+        # call_function  mul_2       aten.mul.Tensor  (tangents_1, mul)        {}
+        # call_function  mul_3       aten.mul.Tensor  (tangents_1, primals_1)  {}
+        # call_function  mul_4       aten.mul.Tensor  (mul_3, primals_1)       {}
+        # call_function  add         aten.add.Tensor  (mul_2, mul_4)           {}
+        # call_function  add_1       aten.add.Tensor  (add, mul_4)             {}
+        # output         output      output           ([add_1],)               {}
+        self.assertEqual(get_num_ins_outs(bw_graph), (2, 1))
+
     def test_contiguous(self):
         # The test simulates the condition where transpose followed by view
         # happens in the backward pass.
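The commented tables in the new test follow the opcode/name/target/args/kwargs layout of `torch.fx.Graph.print_tabular()` (which requires the optional `tabulate` package). A hedged sketch of how one could regenerate them without the test file's `get_fw_bw_graph` helper, by capturing the partitioned graph modules from the compiler callbacks; this mirrors what that helper presumably does internally:

from functools import partial

import torch
from functorch.compile import aot_function, min_cut_rematerialization_partition

graphs = []

def capture(fx_g, _):
    # Record each partitioned graph module in the order it is compiled
    # (forward first, then backward once backward actually runs).
    graphs.append(fx_g)
    return fx_g

def f(x):
    return x * x * x

partition_fn = partial(min_cut_rematerialization_partition, recomputable_ops=[torch.ops.aten.mul])
compiled_f = aot_function(f, fw_compiler=capture, bw_compiler=capture, partition_fn=partition_fn)
compiled_f(torch.randn(3, requires_grad=True)).sum().backward()

fw_graph, bw_graph = graphs
fw_graph.graph.print_tabular()
bw_graph.graph.print_tabular()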