mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Support calling torch.compile inside non-strict export (#164171)
So this fixes at least two issues: 1) When we are invoking inductor backend, we apply pre-grad passes which try to find correct fake mode to use. In the nested case, we will run into clash when there is closure variable in the inductor region because non-strict would have fakified this variable before hand and inner torch.compile would have created a new fresh fake mode. This is not a problem in regular torch.compile because inner torch.compile gets ignored. I don't know if we are supposed to inherit fake mode from parent context in this case. But we can avoid this problem if we just default to eager backend which is fine in this case because the point of export is to capture aten operators. Going to inductor would mean we will lose inner torch.compile ops. 2) There is custom torch function modes in export that track number of torch fns executed and inner compile itself doesn't work because of guard failure as this mode state gets changed. I noticed torch.cond fixes this problem by carefully stashing the torch function mode and defer it in the backend. So the correct thing to do here is just re-use torch.cond implementation unconditionally. So the things i did for fixing above were: 1) Always default to eager backend when compile is invoked inside export. I needed to make how torch.cond sets up the fresh tracing env into an util that can be shared. 2) The previous eager backend for torch.cond was wrong because the context managers didn't actually persist until the backend is invoked. 3) torch.cond used only disable TorchFunctionMetadata tf mode and stash it for later, but in fact, we should do both TorchFunctionMetadata and PreDispatchTorchFunctionMode. With above fixes, we are able to export flex attention in export. Pull Request resolved: https://github.com/pytorch/pytorch/pull/164171 Approved by: https://github.com/ydwu4
This commit is contained in:
committed by
PyTorch MergeBot
parent
3288fbf374
commit
2a11ce2c78
@ -717,6 +717,248 @@ class TestExport(TestCase):
|
||||
)
|
||||
self.assertEqual(node.meta["from_node"][-1].graph_id, graph_id)
|
||||
|
||||
@requires_gpu
|
||||
def test_flex_attention_export(self):
|
||||
from torch.nn.attention.flex_attention import create_block_mask, flex_attention
|
||||
|
||||
class MixedFakeModeModel(torch.nn.Module):
|
||||
def __init__(self, dim=64, use_inductor=True):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.q_proj = torch.nn.Linear(64, 64)
|
||||
self.k_proj = torch.nn.Linear(64, 64)
|
||||
self.v_proj = torch.nn.Linear(64, 64)
|
||||
self.use_inductor = use_inductor
|
||||
|
||||
def forward(self, x):
|
||||
batch_size, seq_len, _ = x.shape
|
||||
|
||||
# Process input first - this creates fake tensors in export's fake mode
|
||||
processed = self.q_proj(x)
|
||||
|
||||
# Create some computation that depends on processed tensor
|
||||
intermediate = processed.sum(dim=-1).detach() # Shape: (batch, seq_len)
|
||||
|
||||
# Now call create_block_mask which internally calls torch.compile
|
||||
# The mask function will capture 'intermediate' which is a fake tensor
|
||||
# from export's fake mode, but create_block_mask will create its own fake mode
|
||||
def dynamic_mask_function(batch_idx, head_idx, q_idx, kv_idx):
|
||||
# This captures the intermediate tensor from the outer scope
|
||||
# When torch.compile is called inside create_block_mask,
|
||||
# this tensor will be from export's fake mode while new tensors
|
||||
# created inside will be from the nested fake mode
|
||||
threshold = intermediate[
|
||||
batch_idx, q_idx % seq_len
|
||||
] # Access the captured tensor
|
||||
return (kv_idx <= q_idx) & (threshold > 0) # Mix fake modes
|
||||
|
||||
block_mask = create_block_mask(
|
||||
mask_mod=dynamic_mask_function,
|
||||
B=batch_size,
|
||||
H=None,
|
||||
Q_LEN=seq_len,
|
||||
KV_LEN=seq_len,
|
||||
device=x.device,
|
||||
)
|
||||
q = self.q_proj(processed).view(batch_size, 1, seq_len, self.dim)
|
||||
k = self.k_proj(processed).view(batch_size, 1, seq_len, self.dim)
|
||||
v = self.v_proj(processed).view(batch_size, 1, seq_len, self.dim)
|
||||
|
||||
# Use flex_attention with the problematic block_mask
|
||||
backend = "inductor" if self.use_inductor else "eager"
|
||||
out = torch.compile(flex_attention, backend=backend)(
|
||||
q, k, v, block_mask=block_mask
|
||||
)
|
||||
|
||||
return out
|
||||
|
||||
model = MixedFakeModeModel(use_inductor=False)
|
||||
x = torch.randn(2, 128, 64)
|
||||
# Inductor doesn't work in eager mode flex attention
|
||||
eager_out = model(x)
|
||||
model.use_inductor = True
|
||||
exported_mod = torch.export.export(model, (x,), strict=False).module()
|
||||
self.assertExpectedInline(
|
||||
str(exported_mod.code).strip(),
|
||||
"""\
|
||||
def forward(self, x):
|
||||
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
|
||||
q_proj_weight = self.q_proj.weight
|
||||
q_proj_bias = self.q_proj.bias
|
||||
k_proj_weight = self.k_proj.weight
|
||||
k_proj_bias = self.k_proj.bias
|
||||
v_proj_weight = self.v_proj.weight
|
||||
v_proj_bias = self.v_proj.bias
|
||||
_guards_fn = self._guards_fn(x); _guards_fn = None
|
||||
linear = torch.ops.aten.linear.default(x, q_proj_weight, q_proj_bias); x = None
|
||||
sum_1 = torch.ops.aten.sum.dim_IntList(linear, [-1])
|
||||
detach = torch.ops.aten.detach.default(sum_1); sum_1 = None
|
||||
arange = torch.ops.aten.arange.start(0, 2, device = device(type='cpu'), pin_memory = False)
|
||||
arange_1 = torch.ops.aten.arange.start(0, 1, device = device(type='cpu'), pin_memory = False)
|
||||
arange_2 = torch.ops.aten.arange.start(0, 128, device = device(type='cpu'), pin_memory = False)
|
||||
arange_3 = torch.ops.aten.arange.start(0, 128, device = device(type='cpu'), pin_memory = False)
|
||||
lazy_load_decompositions = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions = None
|
||||
_vmap_increment_nesting = torch._functorch.predispatch._vmap_increment_nesting(2, 'error'); _vmap_increment_nesting = None
|
||||
_add_batch_dim = torch._functorch.predispatch._add_batch_dim(arange, 0, 1); arange = None
|
||||
lazy_load_decompositions_1 = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions_1 = None
|
||||
_vmap_increment_nesting_1 = torch._functorch.predispatch._vmap_increment_nesting(1, 'error'); _vmap_increment_nesting_1 = None
|
||||
_add_batch_dim_1 = torch._functorch.predispatch._add_batch_dim(arange_1, 0, 2); arange_1 = _add_batch_dim_1 = None
|
||||
lazy_load_decompositions_2 = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions_2 = None
|
||||
_vmap_increment_nesting_2 = torch._functorch.predispatch._vmap_increment_nesting(128, 'error'); _vmap_increment_nesting_2 = None
|
||||
_add_batch_dim_2 = torch._functorch.predispatch._add_batch_dim(arange_2, 0, 3); arange_2 = None
|
||||
lazy_load_decompositions_3 = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions_3 = None
|
||||
_vmap_increment_nesting_3 = torch._functorch.predispatch._vmap_increment_nesting(128, 'error'); _vmap_increment_nesting_3 = None
|
||||
_add_batch_dim_3 = torch._functorch.predispatch._add_batch_dim(arange_3, 0, 4); arange_3 = None
|
||||
remainder = torch.ops.aten.remainder.Scalar(_add_batch_dim_2, 128)
|
||||
torch__dynamo__trace_wrapped_higher_order_op_mod_index0 = self.torch__dynamo__trace_wrapped_higher_order_op_ModIndex0
|
||||
function_const_func_spec0 = self.function_const_func_spec0
|
||||
flat_apply = torch.ops.higher_order.flat_apply(function_const_func_spec0, torch__dynamo__trace_wrapped_higher_order_op_mod_index0, 'torch._dynamo._trace_wrapped_higher_order_op.ModIndex', detach, _add_batch_dim, remainder); function_const_func_spec0 = torch__dynamo__trace_wrapped_higher_order_op_mod_index0 = _add_batch_dim = remainder = None
|
||||
le = torch.ops.aten.le.Tensor(_add_batch_dim_3, _add_batch_dim_2); _add_batch_dim_3 = _add_batch_dim_2 = None
|
||||
gt = torch.ops.aten.gt.Scalar(flat_apply, 0); flat_apply = None
|
||||
and_1 = torch.ops.aten.__and__.Tensor(le, gt); le = gt = None
|
||||
_remove_batch_dim = torch._functorch.predispatch._remove_batch_dim(and_1, 4, 128, 0); and_1 = None
|
||||
_vmap_decrement_nesting = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
|
||||
_remove_batch_dim_1 = torch._functorch.predispatch._remove_batch_dim(_remove_batch_dim, 3, 128, 0); _remove_batch_dim = None
|
||||
_vmap_decrement_nesting_1 = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting_1 = None
|
||||
_remove_batch_dim_2 = torch._functorch.predispatch._remove_batch_dim(_remove_batch_dim_1, 2, 1, 0)
|
||||
expand = torch.ops.aten.expand.default(_remove_batch_dim_1, [1, 128, 128]); _remove_batch_dim_1 = expand = None
|
||||
_vmap_decrement_nesting_2 = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting_2 = None
|
||||
_remove_batch_dim_3 = torch._functorch.predispatch._remove_batch_dim(_remove_batch_dim_2, 1, 2, 0); _remove_batch_dim_2 = None
|
||||
_vmap_decrement_nesting_3 = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting_3 = None
|
||||
pad = torch.ops.aten.pad.default(_remove_batch_dim_3, [0, 0, 0, 0]); _remove_batch_dim_3 = None
|
||||
view = torch.ops.aten.view.default(pad, [2, 1, 1, 128, 1, 128]); pad = None
|
||||
permute = torch.ops.aten.permute.default(view, [0, 1, 2, 4, 3, 5]); view = None
|
||||
sum_2 = torch.ops.aten.sum.dim_IntList(permute, [-2, -1]); permute = None
|
||||
eq = torch.ops.aten.eq.Scalar(sum_2, 16384)
|
||||
gt_1 = torch.ops.aten.gt.Scalar(sum_2, 0)
|
||||
lt = torch.ops.aten.lt.Scalar(sum_2, 16384); sum_2 = None
|
||||
and_2 = torch.ops.aten.__and__.Tensor(gt_1, lt); gt_1 = lt = None
|
||||
_assert_tensor_metadata_default = torch.ops.aten._assert_tensor_metadata.default(and_2, dtype = torch.bool, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default = None
|
||||
to = torch.ops.aten.to.dtype(and_2, torch.int8); and_2 = None
|
||||
_assert_tensor_metadata_default_1 = torch.ops.aten._assert_tensor_metadata.default(eq, dtype = torch.bool, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_1 = None
|
||||
to_1 = torch.ops.aten.to.dtype(eq, torch.int8); eq = None
|
||||
_assert_tensor_metadata_default_2 = torch.ops.aten._assert_tensor_metadata.default(to, dtype = torch.int8, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_2 = None
|
||||
to_2 = torch.ops.aten.to.dtype(to, torch.int32); to = None
|
||||
sum_3 = torch.ops.aten.sum.dim_IntList(to_2, [-1])
|
||||
argsort = torch.ops.aten.argsort.stable(to_2, stable = True, descending = True); to_2 = None
|
||||
_assert_tensor_metadata_default_3 = torch.ops.aten._assert_tensor_metadata.default(sum_3, dtype = torch.int64, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_3 = None
|
||||
to_3 = torch.ops.aten.to.dtype(sum_3, torch.int32, False, False, torch.contiguous_format); sum_3 = None
|
||||
_assert_tensor_metadata_default_4 = torch.ops.aten._assert_tensor_metadata.default(argsort, dtype = torch.int64, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_4 = None
|
||||
to_4 = torch.ops.aten.to.dtype(argsort, torch.int32, False, False, torch.contiguous_format); argsort = None
|
||||
_assert_tensor_metadata_default_5 = torch.ops.aten._assert_tensor_metadata.default(to_1, dtype = torch.int8, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_5 = None
|
||||
to_5 = torch.ops.aten.to.dtype(to_1, torch.int32); to_1 = None
|
||||
sum_4 = torch.ops.aten.sum.dim_IntList(to_5, [-1])
|
||||
argsort_1 = torch.ops.aten.argsort.stable(to_5, stable = True, descending = True); to_5 = None
|
||||
_assert_tensor_metadata_default_6 = torch.ops.aten._assert_tensor_metadata.default(sum_4, dtype = torch.int64, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_6 = None
|
||||
to_6 = torch.ops.aten.to.dtype(sum_4, torch.int32, False, False, torch.contiguous_format); sum_4 = None
|
||||
_assert_tensor_metadata_default_7 = torch.ops.aten._assert_tensor_metadata.default(argsort_1, dtype = torch.int64, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_7 = None
|
||||
to_7 = torch.ops.aten.to.dtype(argsort_1, torch.int32, False, False, torch.contiguous_format); argsort_1 = None
|
||||
lazy_load_decompositions_4 = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions_4 = None
|
||||
_vmap_increment_nesting_4 = torch._functorch.predispatch._vmap_increment_nesting(2, 'error'); _vmap_increment_nesting_4 = None
|
||||
_add_batch_dim_4 = torch._functorch.predispatch._add_batch_dim(to_3, 0, 1)
|
||||
_add_batch_dim_5 = torch._functorch.predispatch._add_batch_dim(to_4, 0, 1)
|
||||
lazy_load_decompositions_5 = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions_5 = None
|
||||
_vmap_increment_nesting_5 = torch._functorch.predispatch._vmap_increment_nesting(1, 'error'); _vmap_increment_nesting_5 = None
|
||||
_add_batch_dim_6 = torch._functorch.predispatch._add_batch_dim(_add_batch_dim_4, 0, 2); _add_batch_dim_4 = None
|
||||
_add_batch_dim_7 = torch._functorch.predispatch._add_batch_dim(_add_batch_dim_5, 0, 2); _add_batch_dim_5 = None
|
||||
new_zeros = torch.ops.aten.new_zeros.default(_add_batch_dim_7, [1, 2], dtype = torch.int32, pin_memory = False)
|
||||
arange_4 = torch.ops.aten.arange.default(1, dtype = torch.int32, device = device(type='cpu'), pin_memory = False)
|
||||
unsqueeze = torch.ops.aten.unsqueeze.default(arange_4, -1); arange_4 = None
|
||||
arange_5 = torch.ops.aten.arange.default(1, dtype = torch.int32, device = device(type='cpu'), pin_memory = False)
|
||||
unsqueeze_1 = torch.ops.aten.unsqueeze.default(_add_batch_dim_6, -1); _add_batch_dim_6 = None
|
||||
lt_1 = torch.ops.aten.lt.Tensor(arange_5, unsqueeze_1); arange_5 = unsqueeze_1 = None
|
||||
where = torch.ops.aten.where.ScalarOther(lt_1, _add_batch_dim_7, 1); lt_1 = _add_batch_dim_7 = None
|
||||
new_ones = torch.ops.aten.new_ones.default(new_zeros, [], pin_memory = False)
|
||||
index_put_ = torch.ops.aten.index_put_.default(new_zeros, [unsqueeze, where], new_ones); new_zeros = unsqueeze = where = new_ones = None
|
||||
slice_1 = torch.ops.aten.slice.Tensor(index_put_, 1, 0, 1); index_put_ = None
|
||||
_remove_batch_dim_4 = torch._functorch.predispatch._remove_batch_dim(slice_1, 2, 1, 0); slice_1 = None
|
||||
_vmap_decrement_nesting_4 = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting_4 = None
|
||||
_remove_batch_dim_5 = torch._functorch.predispatch._remove_batch_dim(_remove_batch_dim_4, 1, 2, 0); _remove_batch_dim_4 = None
|
||||
_vmap_decrement_nesting_5 = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting_5 = None
|
||||
transpose = torch.ops.aten.transpose.int(_remove_batch_dim_5, -2, -1); _remove_batch_dim_5 = None
|
||||
_assert_tensor_metadata_default_8 = torch.ops.aten._assert_tensor_metadata.default(transpose, dtype = torch.int32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_8 = None
|
||||
to_8 = torch.ops.aten.to.dtype(transpose, torch.int32); transpose = None
|
||||
sum_5 = torch.ops.aten.sum.dim_IntList(to_8, [-1])
|
||||
argsort_2 = torch.ops.aten.argsort.stable(to_8, stable = True, descending = True); to_8 = None
|
||||
_assert_tensor_metadata_default_9 = torch.ops.aten._assert_tensor_metadata.default(sum_5, dtype = torch.int64, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_9 = None
|
||||
to_9 = torch.ops.aten.to.dtype(sum_5, torch.int32, False, False, torch.contiguous_format); sum_5 = None
|
||||
_assert_tensor_metadata_default_10 = torch.ops.aten._assert_tensor_metadata.default(argsort_2, dtype = torch.int64, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_10 = None
|
||||
to_10 = torch.ops.aten.to.dtype(argsort_2, torch.int32, False, False, torch.contiguous_format); argsort_2 = None
|
||||
lazy_load_decompositions_6 = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions_6 = None
|
||||
_vmap_increment_nesting_6 = torch._functorch.predispatch._vmap_increment_nesting(2, 'error'); _vmap_increment_nesting_6 = None
|
||||
_add_batch_dim_8 = torch._functorch.predispatch._add_batch_dim(to_6, 0, 1)
|
||||
_add_batch_dim_9 = torch._functorch.predispatch._add_batch_dim(to_7, 0, 1)
|
||||
lazy_load_decompositions_7 = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions_7 = None
|
||||
_vmap_increment_nesting_7 = torch._functorch.predispatch._vmap_increment_nesting(1, 'error'); _vmap_increment_nesting_7 = None
|
||||
_add_batch_dim_10 = torch._functorch.predispatch._add_batch_dim(_add_batch_dim_8, 0, 2); _add_batch_dim_8 = None
|
||||
_add_batch_dim_11 = torch._functorch.predispatch._add_batch_dim(_add_batch_dim_9, 0, 2); _add_batch_dim_9 = None
|
||||
new_zeros_1 = torch.ops.aten.new_zeros.default(_add_batch_dim_11, [1, 2], dtype = torch.int32, pin_memory = False)
|
||||
arange_6 = torch.ops.aten.arange.default(1, dtype = torch.int32, device = device(type='cpu'), pin_memory = False)
|
||||
unsqueeze_2 = torch.ops.aten.unsqueeze.default(arange_6, -1); arange_6 = None
|
||||
arange_7 = torch.ops.aten.arange.default(1, dtype = torch.int32, device = device(type='cpu'), pin_memory = False)
|
||||
unsqueeze_3 = torch.ops.aten.unsqueeze.default(_add_batch_dim_10, -1); _add_batch_dim_10 = None
|
||||
lt_2 = torch.ops.aten.lt.Tensor(arange_7, unsqueeze_3); arange_7 = unsqueeze_3 = None
|
||||
where_1 = torch.ops.aten.where.ScalarOther(lt_2, _add_batch_dim_11, 1); lt_2 = _add_batch_dim_11 = None
|
||||
new_ones_1 = torch.ops.aten.new_ones.default(new_zeros_1, [], pin_memory = False)
|
||||
index_put__1 = torch.ops.aten.index_put_.default(new_zeros_1, [unsqueeze_2, where_1], new_ones_1); new_zeros_1 = unsqueeze_2 = where_1 = new_ones_1 = None
|
||||
slice_2 = torch.ops.aten.slice.Tensor(index_put__1, 1, 0, 1); index_put__1 = None
|
||||
_remove_batch_dim_6 = torch._functorch.predispatch._remove_batch_dim(slice_2, 2, 1, 0); slice_2 = None
|
||||
_vmap_decrement_nesting_6 = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting_6 = None
|
||||
_remove_batch_dim_7 = torch._functorch.predispatch._remove_batch_dim(_remove_batch_dim_6, 1, 2, 0); _remove_batch_dim_6 = None
|
||||
_vmap_decrement_nesting_7 = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting_7 = None
|
||||
transpose_1 = torch.ops.aten.transpose.int(_remove_batch_dim_7, -2, -1); _remove_batch_dim_7 = None
|
||||
_assert_tensor_metadata_default_11 = torch.ops.aten._assert_tensor_metadata.default(transpose_1, dtype = torch.int32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_11 = None
|
||||
to_11 = torch.ops.aten.to.dtype(transpose_1, torch.int32); transpose_1 = None
|
||||
sum_6 = torch.ops.aten.sum.dim_IntList(to_11, [-1])
|
||||
argsort_3 = torch.ops.aten.argsort.stable(to_11, stable = True, descending = True); to_11 = None
|
||||
_assert_tensor_metadata_default_12 = torch.ops.aten._assert_tensor_metadata.default(sum_6, dtype = torch.int64, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_12 = None
|
||||
to_12 = torch.ops.aten.to.dtype(sum_6, torch.int32, False, False, torch.contiguous_format); sum_6 = None
|
||||
_assert_tensor_metadata_default_13 = torch.ops.aten._assert_tensor_metadata.default(argsort_3, dtype = torch.int64, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_13 = None
|
||||
to_13 = torch.ops.aten.to.dtype(argsort_3, torch.int32, False, False, torch.contiguous_format); argsort_3 = None
|
||||
linear_1 = torch.ops.aten.linear.default(linear, q_proj_weight, q_proj_bias); q_proj_weight = q_proj_bias = None
|
||||
view_1 = torch.ops.aten.view.default(linear_1, [2, 1, 128, 64]); linear_1 = None
|
||||
linear_2 = torch.ops.aten.linear.default(linear, k_proj_weight, k_proj_bias); k_proj_weight = k_proj_bias = None
|
||||
view_2 = torch.ops.aten.view.default(linear_2, [2, 1, 128, 64]); linear_2 = None
|
||||
linear_3 = torch.ops.aten.linear.default(linear, v_proj_weight, v_proj_bias); linear = v_proj_weight = v_proj_bias = None
|
||||
view_3 = torch.ops.aten.view.default(linear_3, [2, 1, 128, 64]); linear_3 = None
|
||||
sdpa_score0 = self.sdpa_score0
|
||||
sdpa_mask0 = self.sdpa_mask0
|
||||
flex_attention = torch.ops.higher_order.flex_attention(view_1, view_2, view_3, sdpa_score0, (128, 128, to_3, to_4, to_6, to_7, to_9, to_10, to_12, to_13, 128, 128, sdpa_mask0), 0.125, {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': False, 'OUTPUT_MAX': False}, (), (detach,)); view_1 = view_2 = view_3 = sdpa_score0 = to_3 = to_4 = to_6 = to_7 = to_9 = to_10 = to_12 = to_13 = sdpa_mask0 = detach = None
|
||||
getitem = flex_attention[0]
|
||||
getitem_1 = flex_attention[1]; getitem_1 = None
|
||||
getitem_2 = flex_attention[2]; flex_attention = getitem_2 = None
|
||||
return pytree.tree_unflatten((getitem,), self._out_spec)""",
|
||||
)
|
||||
exported_out = exported_mod(x)
|
||||
self.assertEqual(exported_out, eager_out)
|
||||
|
||||
def test_inductor_backend_inside_nonstrict(self):
|
||||
class Foo(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
def i_want_faster_code(inp1, inp2):
|
||||
nonlocal x
|
||||
return x + inp1 + inp2
|
||||
|
||||
out = torch.compile(i_want_faster_code)(x, x)
|
||||
return x + out
|
||||
|
||||
foo = Foo()
|
||||
with self.assertWarnsRegex(
|
||||
UserWarning, "You are calling torch.compile inside torch.export region"
|
||||
):
|
||||
ep = export(foo, (torch.randn(4, 4),), strict=False).module()
|
||||
self.assertExpectedInline(
|
||||
str(ep.graph).strip(),
|
||||
"""\
|
||||
graph():
|
||||
%x : [num_users=4] = placeholder[target=x]
|
||||
%_guards_fn : [num_users=0] = call_module[target=_guards_fn](args = (%x,), kwargs = {})
|
||||
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %x), kwargs = {})
|
||||
%add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add, %x), kwargs = {})
|
||||
%add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %add_1), kwargs = {})
|
||||
return (add_2,)""",
|
||||
)
|
||||
|
||||
def test_bincount(self):
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
@ -4945,8 +4945,6 @@ class <lambda>(torch.nn.Module):
|
||||
):
|
||||
cos: "f32[2, 2]" = torch.ops.aten.cos.default(arg0_1); arg0_1 = None
|
||||
|
||||
_set_grad_enabled = torch._C._set_grad_enabled(True); _set_grad_enabled = None
|
||||
|
||||
body_graph_0 = self.body_graph_0
|
||||
map_impl = torch.ops.higher_order.map_impl(body_graph_0, [cos], [arg1_1]); body_graph_0 = arg1_1 = None
|
||||
getitem_2: "f32[2, 2]" = map_impl[0]; map_impl = None
|
||||
|
@ -22,6 +22,7 @@ import platform
|
||||
import sys
|
||||
import textwrap
|
||||
import threading
|
||||
import warnings
|
||||
from collections.abc import Callable as _Callable
|
||||
from typing import (
|
||||
Any as _Any,
|
||||
@ -2634,6 +2635,28 @@ def compile(
|
||||
if options and isinstance(options, dict):
|
||||
guard_filter_fn = options.pop("guard_filter_fn", None)
|
||||
|
||||
if torch.compiler.is_exporting():
|
||||
warnings.warn(
|
||||
"You are calling torch.compile inside torch.export region. "
|
||||
"To capture an useful graph, we will implicitly switch to torch.compile(backend=eager)"
|
||||
)
|
||||
from torch._higher_order_ops.utils import setup_compilation_env
|
||||
|
||||
# Create wrapper that always uses eager backend during export
|
||||
def export_wrapped_fn(*args, **kwargs):
|
||||
with setup_compilation_env() as backend: # type: ignore[attr-defined]
|
||||
# Force eager backend regardless of original backend
|
||||
backend_wrapper = _TorchCompileWrapper(backend, mode, options, dynamic)
|
||||
return torch._dynamo.optimize(
|
||||
backend=backend_wrapper,
|
||||
nopython=fullgraph,
|
||||
dynamic=dynamic,
|
||||
disable=disable,
|
||||
guard_filter_fn=guard_filter_fn,
|
||||
)(model)(*args, **kwargs)
|
||||
|
||||
return export_wrapped_fn
|
||||
|
||||
if backend == "inductor":
|
||||
backend = _TorchCompileInductorWrapper(mode, options, dynamic)
|
||||
else:
|
||||
|
@ -74,13 +74,13 @@ def make_eager_backend_with_torch_function_modes(
|
||||
def fn(
|
||||
gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], **kwargs: Any
|
||||
) -> Callable[..., Any]:
|
||||
stack = ExitStack()
|
||||
for mode in modes:
|
||||
stack.enter_context(mode)
|
||||
def wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
with ExitStack() as stack:
|
||||
for mode in modes:
|
||||
stack.enter_context(mode)
|
||||
return gm.forward(*args, **kwargs)
|
||||
|
||||
result = gm.forward
|
||||
stack.close()
|
||||
return result
|
||||
return wrapper
|
||||
|
||||
return fn
|
||||
|
||||
|
@ -19,7 +19,6 @@ from torch._C._functorch import (
|
||||
from torch._functorch.utils import exposed_in
|
||||
from torch._higher_order_ops.utils import (
|
||||
_maybe_run_with_interpreter,
|
||||
_set_compilation_env,
|
||||
check_input_alias_and_mutation_return_outputs,
|
||||
create_bw_fn,
|
||||
fill_none_with_masks,
|
||||
@ -33,12 +32,7 @@ from torch._higher_order_ops.utils import (
|
||||
)
|
||||
from torch._ops import HigherOrderOperator
|
||||
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
|
||||
from torch.fx.experimental.proxy_tensor import (
|
||||
_temp_remove_metadata_torch_function_mode,
|
||||
_temp_remove_pre_dispatch_torch_function_mode,
|
||||
ProxyTorchDispatchMode,
|
||||
track_tensor_tree,
|
||||
)
|
||||
from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree
|
||||
from torch.utils._python_dispatch import _get_current_dispatch_mode
|
||||
|
||||
|
||||
@ -174,10 +168,6 @@ def cond(
|
||||
if torch.compiler.is_dynamo_compiling():
|
||||
return cond_op(pred, true_fn, false_fn, operands)
|
||||
|
||||
from torch._dynamo.backends.debugging import (
|
||||
make_eager_backend_with_torch_function_mode,
|
||||
)
|
||||
|
||||
if isinstance(pred, (bool, int, float)):
|
||||
# This is the non-strict export case. Strict export and torch.compile are
|
||||
# handled above in dynamo.
|
||||
@ -223,21 +213,12 @@ def cond(
|
||||
def _cond_op_wrapper(*args, **kwargs):
|
||||
return cond_op(*args, **kwargs)
|
||||
|
||||
with (
|
||||
_set_compilation_env(),
|
||||
torch._dynamo.utils.disable_cache_limit(),
|
||||
_temp_remove_pre_dispatch_torch_function_mode(),
|
||||
):
|
||||
with _temp_remove_metadata_torch_function_mode() as metadata_mode:
|
||||
if metadata_mode:
|
||||
backend: Union[str, Callable[..., Any]] = (
|
||||
make_eager_backend_with_torch_function_mode(metadata_mode)
|
||||
)
|
||||
else:
|
||||
backend = "eager"
|
||||
return torch.compile(_cond_op_wrapper, backend=backend, fullgraph=True)(
|
||||
pred, true_fn, false_fn, operands
|
||||
)
|
||||
from torch._higher_order_ops.utils import setup_compilation_env
|
||||
|
||||
with setup_compilation_env() as backend:
|
||||
return torch.compile(_cond_op_wrapper, backend=backend, fullgraph=True)(
|
||||
pred, true_fn, false_fn, operands
|
||||
)
|
||||
|
||||
|
||||
def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
|
||||
|
@ -96,19 +96,8 @@ def _maybe_run_with_interpreter(fn):
|
||||
|
||||
def _maybe_compile_and_run_fn(fn, *args):
|
||||
if not torch.compiler.is_dynamo_compiling():
|
||||
from torch._dynamo.backends.debugging import (
|
||||
make_eager_backend_with_torch_function_mode,
|
||||
)
|
||||
|
||||
with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit():
|
||||
with _temp_remove_metadata_torch_function_mode() as metadata_mode:
|
||||
if metadata_mode:
|
||||
backend: Union[str, Callable[..., Any]] = (
|
||||
make_eager_backend_with_torch_function_mode(metadata_mode)
|
||||
)
|
||||
else:
|
||||
backend = "eager"
|
||||
return torch.compile(fn, backend=backend, fullgraph=True)(*args)
|
||||
with setup_compilation_env() as backend: # type: ignore[attr-defined]
|
||||
return torch.compile(fn, backend=backend, fullgraph=True)(*args)
|
||||
else:
|
||||
return fn(*args)
|
||||
|
||||
@ -236,6 +225,34 @@ def check_meta_consistency(
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def setup_compilation_env():
|
||||
"""
|
||||
Context manager that sets up proper environment and backend when invoking torch.compile
|
||||
inside torch.export region or inside HOP.
|
||||
"""
|
||||
from torch._dynamo.backends.debugging import (
|
||||
make_eager_backend_with_torch_function_modes,
|
||||
)
|
||||
from torch.fx.experimental.proxy_tensor import (
|
||||
_temp_remove_pre_dispatch_torch_function_mode,
|
||||
)
|
||||
|
||||
with (
|
||||
_set_compilation_env(),
|
||||
torch._dynamo.utils.disable_cache_limit(),
|
||||
_temp_remove_pre_dispatch_torch_function_mode() as pre_dispatch_mode,
|
||||
_temp_remove_metadata_torch_function_mode() as metadata_mode,
|
||||
):
|
||||
modes = [
|
||||
mode for mode in (pre_dispatch_mode, metadata_mode) if mode is not None
|
||||
]
|
||||
if modes:
|
||||
yield make_eager_backend_with_torch_function_modes(modes)
|
||||
else:
|
||||
yield "eager"
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _set_compilation_env():
|
||||
_old_is_tracing = torch.fx._symbolic_trace._is_fx_tracing_flag
|
||||
|
Reference in New Issue
Block a user