[hop] add discard_graph_changes to remove the empty calls before hop (#140334)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140334
Approved by: https://github.com/zou3519
Committed by: PyTorch MergeBot
Parent: eecc8e362c
Commit: 45bc9165fe
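
For intuition, here is a minimal, self-contained sketch of the idea behind discard_graph_changes. ToyGraph and scratch_graph are hypothetical names invented for this illustration, not part of the Dynamo API touched by this PR; the real helper achieves the same effect by entering a dummy subtracer named "subgraph_wrapper", as shown in the diff below.

# Illustrative sketch only: a toy graph builder that mimics what
# discard_graph_changes does. ToyGraph / scratch_graph are made-up names,
# not the Dynamo API; the real helper enters a throwaway subtracer.
import contextlib


class ToyGraph:
    def __init__(self):
        # Nodes that end up in the final (main) graph.
        self.nodes = []

    def add_node(self, op):
        self.nodes.append(op)

    @contextlib.contextmanager
    def scratch_graph(self):
        # Temporarily record nodes into a throwaway list so that sample-input
        # ops (the clone/new_empty/select_copy calls in the test diffs below)
        # never reach the main graph.
        saved, self.nodes = self.nodes, []
        try:
            yield
        finally:
            # Drop whatever was recorded inside the block.
            self.nodes = saved


g = ToyGraph()
g.add_node("add")
with g.scratch_graph():
    g.add_node("clone")        # sample input for speculating the subgraph
    g.add_node("select_copy")  # sample input for speculating the subgraph
g.add_node("scan")             # the higher-order op itself
print(g.nodes)                 # ['add', 'scan'] -- no stray sample-input calls

With the real helper, the sample inputs Dynamo creates in order to speculate a HOP body land in the discarded subtracer instead of the output graph, which is why the clone/new_empty/select_copy lines disappear from the expected graphs in the test diffs below.
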
@@ -2414,8 +2414,6 @@ def forward(self, fct_1, init_1, xs_1):
add_1 = torch.ops.aten.add.Tensor(init_1, select); select = add_1 = None
sym_size_int_1 = torch.ops.aten.sym_size.int(init_1, 1)
sym_size_int_2 = torch.ops.aten.sym_size.int(init_1, 2)
clone = torch.ops.aten.clone.default(init_1); clone = None
select_copy = torch.ops.aten.select_copy.int(xs_1, 0, 0); select_copy = None
sym_size_int_3 = torch.ops.aten.sym_size.int(xs_1, 1)
sym_size_int_4 = torch.ops.aten.sym_size.int(xs_1, 2)
scan_combine_graph_0 = self.scan_combine_graph_0

@@ -2439,8 +2437,6 @@ def forward(self, L_init_ : torch.Tensor, L_xs_ : torch.Tensor):
select = l_xs_.select(0, 0)
new_carry = l_init_ + select; new_carry = None
add_1 = l_init_ + select; select = add_1 = None
child = l_init_.clone(); child = None
child_1 = torch.select_copy(l_xs_, 0, 0); child_1 = None
scan_combine_fn_0 = self.scan_combine_fn_0
scan = torch.ops.higher_order.scan(scan_combine_fn_0, [l_init_], [l_xs_], 0, True, []); scan_combine_fn_0 = l_init_ = l_xs_ = None
getitem = scan[0]

@@ -5932,8 +5928,6 @@ def forward(self, L_init_ : torch.Tensor, L_xs_ : torch.Tensor, L_add_closure_0_
r_2 = r_1.matmul(r); r_1 = r = None
r_3 = r_2.add(l_add_closure_0_cell_contents_1_0_); r_2 = None
r_4 = r_3.sum(); r_3 = r_4 = None
r_5 = l_init_.clone(); r_5 = None
r_6 = torch.select_copy(l_xs_, 0, 0); r_6 = None
scan_combine_fn_0 = self.scan_combine_fn_0
scan = torch.ops.higher_order.scan(scan_combine_fn_0, [l_init_], [l_xs_], 0, False, [l_add_closure_0_cell_contents_0_param_, l_add_closure_0_cell_contents_1_0_]); scan_combine_fn_0 = l_init_ = l_xs_ = l_add_closure_0_cell_contents_0_param_ = l_add_closure_0_cell_contents_1_0_ = None
getitem = scan[0]

@@ -5955,8 +5949,6 @@ def forward(self, L_init_ : torch.Tensor, L_xs_ : torch.Tensor, L_add_closure_0_
matmul_1 = matmul @ select; matmul = select = None
ret = matmul_1 + l_add_closure_0_cell_contents_1_0_; matmul_1 = None
sum_1 = ret.sum(); ret = sum_1 = None
child = l_init_.clone(); child = None
child_1 = torch.select_copy(l_xs_, 0, 0); child_1 = None
scan_combine_fn_0 = self.scan_combine_fn_0
scan = torch.ops.higher_order.scan(scan_combine_fn_0, [l_init_], [l_xs_], 0, False, [l_add_closure_0_cell_contents_0_param_, l_add_closure_0_cell_contents_1_0_]); scan_combine_fn_0 = l_init_ = l_xs_ = l_add_closure_0_cell_contents_0_param_ = l_add_closure_0_cell_contents_1_0_ = None
getitem = scan[0]

@@ -3197,8 +3197,8 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
self.assertExpectedInline(
    backend.graphs[0].score_mod_0.code.strip(),
    """\
def forward(self, child_4 : torch.Tensor, child_5 : torch.Tensor, child_6 : torch.Tensor, child_7 : torch.Tensor, child_8 : torch.Tensor, getitem : torch.SymInt):
add = child_4 + getitem; child_4 = getitem = None
def forward(self, child : torch.Tensor, child_1 : torch.Tensor, child_2 : torch.Tensor, child_3 : torch.Tensor, child_4 : torch.Tensor, getitem : torch.SymInt):
add = child + getitem; child = getitem = None
return add""",
)

@@ -3244,16 +3244,7 @@ class GraphModule(torch.nn.Module):
l_block_mask_full_q_num_blocks = L_block_mask_full_q_num_blocks
l_block_mask_full_q_indices = L_block_mask_full_q_indices

child_1: "i32[]" = l_query_.new_empty([], dtype = torch.int32); child_1 = None
child_2: "i32[]" = l_query_.new_empty([], dtype = torch.int32); child_2 = None
child_3: "i32[]" = l_query_.new_empty([], dtype = torch.int32); child_3 = None
child_4: "i32[]" = l_query_.new_empty([], dtype = torch.int32); child_4 = None
child: "f64[]" = l_query_.new_empty([], requires_grad = True); child = None
score_mod_0 = self.score_mod_0
child_5: "i32[]" = l_query_.new_empty([], dtype = torch.int32); child_5 = None
child_6: "i32[]" = l_query_.new_empty([], dtype = torch.int32); child_6 = None
child_7: "i32[]" = l_query_.new_empty([], dtype = torch.int32); child_7 = None
child_8: "i32[]" = l_query_.new_empty([], dtype = torch.int32); child_8 = None
mask_fn_0 = self.mask_fn_0
flex_attention = torch.ops.higher_order.flex_attention(l_query_, l_key_, l_value_, score_mod_0, (l_block_mask_kv_num_blocks, l_block_mask_kv_indices, l_block_mask_full_kv_num_blocks, l_block_mask_full_kv_indices, l_block_mask_q_num_blocks, l_block_mask_q_indices, l_block_mask_full_q_num_blocks, l_block_mask_full_q_indices, 128, 128, mask_fn_0), 0.5, {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'OUTPUT_LOGSUMEXP': True}, (), ()); l_query_ = l_key_ = l_value_ = score_mod_0 = l_block_mask_kv_num_blocks = l_block_mask_kv_indices = l_block_mask_full_kv_num_blocks = l_block_mask_full_kv_indices = l_block_mask_q_num_blocks = l_block_mask_q_indices = l_block_mask_full_q_num_blocks = l_block_mask_full_q_indices = mask_fn_0 = None
out: "f64[2, 2, 128, 4]" = flex_attention[0]; flex_attention = None

@@ -3265,8 +3256,8 @@ class GraphModule(torch.nn.Module):
return mul

class mask_fn_0(torch.nn.Module):
def forward(self, child_5: "i32[]", child_6: "i32[]", child_7: "i32[]", child_8: "i32[]"):
ge: "b8[]" = child_7 >= child_8; child_7 = child_8 = None
def forward(self, child: "i32[]", child_1: "i32[]", child_2: "i32[]", child_3: "i32[]"):
ge: "b8[]" = child_2 >= child_3; child_2 = child_3 = None
return ge
""", # noqa: B950
)

@@ -62,6 +62,26 @@ def raise_hard_error_if_graph_break(reason):
    return deco


# This function is syntactic sugar for creating a dummy new subtracer so that
# newly added nodes are added to a separate subgraph in this subtracer instead of affecting
# the main graph. This is useful for creating sample inputs for tracing the subgraph.
# For example, in FlexAttentionHigherOrderVariable, we want to create several scalars
# to trace the score_mod function, but we don't want the operators that create the scalars to
# show up in the graph, so we can use this function to discard the graph changes.
# Example usage:
# with discard_graph_changes():
#     sample_inputs = create_sample_inputs()
# speculate_subgraph(tx, f, sample_inputs, {})
@contextlib.contextmanager
def discard_graph_changes(tx):
    ctx = tx.output.subtracer("subgraph_wrapper", None)
    try:
        ctx.__enter__()
        yield
    finally:
        ctx.__exit__(None, None, None)


@contextlib.contextmanager
def dynamo_enable_grad(tx: "InstructionTranslator", enable=True):
    from . import GradModeVariable

@@ -1189,13 +1209,13 @@ class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
assert isinstance(xs, torch._dynamo.variables.lists.BaseListVariable)

# Trace the subgraph
# TODO: Fix these pointless new_empty calls appearing in the dynamo output graph.
# The sub_args is a slice of original input, e.g. if input.size is (3, 4), and scan dim=0
# the sub_args shape will be (4, ).
sub_args = [
    _make_inlined(tx, first_slice_copy)(leaf, dim)
    for leaf in itertools.chain(xs.items, xs.items)
]
with discard_graph_changes(tx):
    sub_args = [
        _make_inlined(tx, first_slice_copy)(leaf, dim)
        for leaf in itertools.chain(xs.items, xs.items)
    ]
(
    (combine_result, combine_treespec),
    combine_graph,

@@ -1313,20 +1333,19 @@ class ScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
unimplemented("scan() operator requires init leaves.")

# Trace the subgraph
# TODO: Fix these pointless new_empty calls appearing in the dynamo output graph.
# TODO: Unify handling of sub_args across control flow ops, such as cond, while_loop, etc.
sub_args_init = [
    ini.call_method(tx, "clone", args=(), kwargs={}) for ini in init.items
]
# The sub_args_inp is a slice of original input, e.g. if input.size is (3, 4), and scan dim=0
# the sub_args_inp shape will be (4, ).
sub_args_inp = [
    _make_inlined(tx, first_slice_copy)(inp, dim) for inp in xs.items
]
sub_args_additional_inputs = [
    t.call_method(tx, "clone", args=(), kwargs={})
    for t in additional_inputs.items
]
with discard_graph_changes(tx):
    sub_args_init = [
        ini.call_method(tx, "clone", args=(), kwargs={}) for ini in init.items
    ]
    # The sub_args_inp is a slice of original input, e.g. if input.size is (3, 4), and scan dim=0
    # the sub_args_inp shape will be (4, ).
    sub_args_inp = [
        _make_inlined(tx, first_slice_copy)(inp, dim) for inp in xs.items
    ]
    sub_args_additional_inputs = [
        t.call_method(tx, "clone", args=(), kwargs={})
        for t in additional_inputs.items
    ]
sub_args = sub_args_init + sub_args_inp + sub_args_additional_inputs
(
    (combine_result, combine_treespec),

@@ -1460,9 +1479,10 @@ class MapHigherOrderVariable(TorchHigherOrderOperatorVariable):
# To get the example output from map() we will need to provide at least one sample to
# the loop body. In our case we will always use xs[0], and our map() won't support zero
# sized tensor during tracing.
first_dim = wrap_fx_proxy_cls(
    target_cls=TensorVariable, tx=tx, proxy=args[1].as_proxy()[0]
)
with discard_graph_changes(tx):
    first_dim = wrap_fx_proxy_cls(
        target_cls=TensorVariable, tx=tx, proxy=args[1].as_proxy()[0]
    )

# TODO: Support kwargs
(

@@ -2210,19 +2230,20 @@ class FlexAttentionHigherOrderVariable(TorchHigherOrderOperatorVariable):
    },
)

bhmn = [create_scalar() for _ in range(4)]
if fn_name == "score_mod":
    scores_require_grad: bool = query.requires_grad
    score = query.call_method(
        tx,
        "new_empty",
        (VariableTracker.build(tx, []),),
        {"requires_grad": VariableTracker.build(tx, scores_require_grad)},
    )
    new_args = [score, *bhmn]
else:
    assert fn_name == "mask_fn", "Illegal function name: " + fn_name
    new_args = [*bhmn]
with discard_graph_changes(tx):
    bhmn = [create_scalar() for _ in range(4)]
    if fn_name == "score_mod":
        scores_require_grad: bool = query.requires_grad
        score = query.call_method(
            tx,
            "new_empty",
            (VariableTracker.build(tx, []),),
            {"requires_grad": VariableTracker.build(tx, scores_require_grad)},
        )
        new_args = [score, *bhmn]
    else:
        assert fn_name == "mask_fn", "Illegal function name: " + fn_name
        new_args = [*bhmn]

with TransformGetItemToIndex():
    (