Support calling torch.compile inside non-strict export (#164171)

This fixes at least two issues:
1) When the inductor backend is invoked, its pre-grad passes try to find the correct fake mode to use. In the nested case, this clashes whenever a closure variable is captured in the compiled region: non-strict export has already fakified that variable, while the inner torch.compile creates a fresh fake mode of its own. This is not a problem in regular torch.compile because the inner torch.compile is simply ignored. It is unclear whether we are supposed to inherit the fake mode from the parent context here; instead, we avoid the problem by defaulting to the eager backend, which is fine because the point of export is to capture aten operators. Going through inductor would mean losing the ops inside the inner torch.compile region. (A minimal repro sketch follows this list.)
2) Export installs custom torch function modes that track the number of torch functions executed, and the inner compile fails with a guard failure because that mode state changes. torch.cond already fixes this problem by carefully stashing the torch function mode and deferring it to the backend, so the correct thing to do here is to reuse the torch.cond implementation unconditionally.
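
Below is a minimal repro sketch of issue 1, with hypothetical module and variable names (Repro, captured, inner); it mirrors the closure-capture pattern that create_block_mask hits in the new test further down:

    import torch

    class Repro(torch.nn.Module):
        def forward(self, x):
            # Non-strict export has already fakified 'captured' with its own fake mode.
            captured = x.sum(dim=-1)

            def inner(a, b):
                # Closure over 'captured': before this fix, inductor's pre-grad passes
                # in the nested torch.compile would see a tensor from export's fake
                # mode clashing with the inner compile's freshly created fake mode.
                return a + b + captured.unsqueeze(-1)

            # With this change, the inner compile falls back to the eager backend
            # (with a warning) during export, so only aten ops are captured.
            return torch.compile(inner, backend="inductor")(x, x)

    ep = torch.export.export(Repro(), (torch.randn(2, 4),), strict=False)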

To fix the above, I did the following:
1) Always default to the eager backend when torch.compile is invoked inside export. This required factoring the way torch.cond sets up a fresh tracing environment into a utility that can be shared (a usage sketch follows this list).
2) The previous eager backend for torch.cond was wrong because its context managers did not actually persist until the backend was invoked.
3) torch.cond previously only disabled the TorchFunctionMetadata torch function mode and stashed it for later; we should in fact do this for both TorchFunctionMetadata and PreDispatchTorchFunctionMode.
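
For reference, a rough sketch of how the shared helper is used by torch.cond and _maybe_compile_and_run_fn after this change (compile_and_run is a hypothetical stand-in for those call sites; the real definition of setup_compilation_env appears in the diff below):

    import torch
    from torch._higher_order_ops.utils import setup_compilation_env

    def compile_and_run(fn, *args):
        # setup_compilation_env stashes both PreDispatchTorchFunctionMode and
        # TorchFunctionMetadataMode (and disables the dynamo cache limit), yielding
        # either "eager" or an eager backend that re-enters those modes only once
        # the backend is actually invoked.
        with setup_compilation_env() as backend:
            return torch.compile(fn, backend=backend, fullgraph=True)(*args)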

With the above fixes, we can export flex attention under non-strict export.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164171
Approved by: https://github.com/ydwu4
Tugsbayasgalan Manlaibaatar
2025-10-02 08:27:10 -07:00
committed by PyTorch MergeBot
parent 3288fbf374
commit 2a11ce2c78
6 changed files with 308 additions and 47 deletions

View File

@ -717,6 +717,248 @@ class TestExport(TestCase):
)
self.assertEqual(node.meta["from_node"][-1].graph_id, graph_id)
@requires_gpu
def test_flex_attention_export(self):
from torch.nn.attention.flex_attention import create_block_mask, flex_attention
class MixedFakeModeModel(torch.nn.Module):
def __init__(self, dim=64, use_inductor=True):
super().__init__()
self.dim = dim
self.q_proj = torch.nn.Linear(64, 64)
self.k_proj = torch.nn.Linear(64, 64)
self.v_proj = torch.nn.Linear(64, 64)
self.use_inductor = use_inductor
def forward(self, x):
batch_size, seq_len, _ = x.shape
# Process input first - this creates fake tensors in export's fake mode
processed = self.q_proj(x)
# Create some computation that depends on processed tensor
intermediate = processed.sum(dim=-1).detach() # Shape: (batch, seq_len)
# Now call create_block_mask which internally calls torch.compile
# The mask function will capture 'intermediate' which is a fake tensor
# from export's fake mode, but create_block_mask will create its own fake mode
def dynamic_mask_function(batch_idx, head_idx, q_idx, kv_idx):
# This captures the intermediate tensor from the outer scope
# When torch.compile is called inside create_block_mask,
# this tensor will be from export's fake mode while new tensors
# created inside will be from the nested fake mode
threshold = intermediate[
batch_idx, q_idx % seq_len
] # Access the captured tensor
return (kv_idx <= q_idx) & (threshold > 0) # Mix fake modes
block_mask = create_block_mask(
mask_mod=dynamic_mask_function,
B=batch_size,
H=None,
Q_LEN=seq_len,
KV_LEN=seq_len,
device=x.device,
)
q = self.q_proj(processed).view(batch_size, 1, seq_len, self.dim)
k = self.k_proj(processed).view(batch_size, 1, seq_len, self.dim)
v = self.v_proj(processed).view(batch_size, 1, seq_len, self.dim)
# Use flex_attention with the problematic block_mask
backend = "inductor" if self.use_inductor else "eager"
out = torch.compile(flex_attention, backend=backend)(
q, k, v, block_mask=block_mask
)
return out
model = MixedFakeModeModel(use_inductor=False)
x = torch.randn(2, 128, 64)
# Inductor doesn't work with flex attention in eager mode, so use the eager backend for the reference run
eager_out = model(x)
model.use_inductor = True
exported_mod = torch.export.export(model, (x,), strict=False).module()
self.assertExpectedInline(
str(exported_mod.code).strip(),
"""\
def forward(self, x):
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
q_proj_weight = self.q_proj.weight
q_proj_bias = self.q_proj.bias
k_proj_weight = self.k_proj.weight
k_proj_bias = self.k_proj.bias
v_proj_weight = self.v_proj.weight
v_proj_bias = self.v_proj.bias
_guards_fn = self._guards_fn(x); _guards_fn = None
linear = torch.ops.aten.linear.default(x, q_proj_weight, q_proj_bias); x = None
sum_1 = torch.ops.aten.sum.dim_IntList(linear, [-1])
detach = torch.ops.aten.detach.default(sum_1); sum_1 = None
arange = torch.ops.aten.arange.start(0, 2, device = device(type='cpu'), pin_memory = False)
arange_1 = torch.ops.aten.arange.start(0, 1, device = device(type='cpu'), pin_memory = False)
arange_2 = torch.ops.aten.arange.start(0, 128, device = device(type='cpu'), pin_memory = False)
arange_3 = torch.ops.aten.arange.start(0, 128, device = device(type='cpu'), pin_memory = False)
lazy_load_decompositions = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions = None
_vmap_increment_nesting = torch._functorch.predispatch._vmap_increment_nesting(2, 'error'); _vmap_increment_nesting = None
_add_batch_dim = torch._functorch.predispatch._add_batch_dim(arange, 0, 1); arange = None
lazy_load_decompositions_1 = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions_1 = None
_vmap_increment_nesting_1 = torch._functorch.predispatch._vmap_increment_nesting(1, 'error'); _vmap_increment_nesting_1 = None
_add_batch_dim_1 = torch._functorch.predispatch._add_batch_dim(arange_1, 0, 2); arange_1 = _add_batch_dim_1 = None
lazy_load_decompositions_2 = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions_2 = None
_vmap_increment_nesting_2 = torch._functorch.predispatch._vmap_increment_nesting(128, 'error'); _vmap_increment_nesting_2 = None
_add_batch_dim_2 = torch._functorch.predispatch._add_batch_dim(arange_2, 0, 3); arange_2 = None
lazy_load_decompositions_3 = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions_3 = None
_vmap_increment_nesting_3 = torch._functorch.predispatch._vmap_increment_nesting(128, 'error'); _vmap_increment_nesting_3 = None
_add_batch_dim_3 = torch._functorch.predispatch._add_batch_dim(arange_3, 0, 4); arange_3 = None
remainder = torch.ops.aten.remainder.Scalar(_add_batch_dim_2, 128)
torch__dynamo__trace_wrapped_higher_order_op_mod_index0 = self.torch__dynamo__trace_wrapped_higher_order_op_ModIndex0
function_const_func_spec0 = self.function_const_func_spec0
flat_apply = torch.ops.higher_order.flat_apply(function_const_func_spec0, torch__dynamo__trace_wrapped_higher_order_op_mod_index0, 'torch._dynamo._trace_wrapped_higher_order_op.ModIndex', detach, _add_batch_dim, remainder); function_const_func_spec0 = torch__dynamo__trace_wrapped_higher_order_op_mod_index0 = _add_batch_dim = remainder = None
le = torch.ops.aten.le.Tensor(_add_batch_dim_3, _add_batch_dim_2); _add_batch_dim_3 = _add_batch_dim_2 = None
gt = torch.ops.aten.gt.Scalar(flat_apply, 0); flat_apply = None
and_1 = torch.ops.aten.__and__.Tensor(le, gt); le = gt = None
_remove_batch_dim = torch._functorch.predispatch._remove_batch_dim(and_1, 4, 128, 0); and_1 = None
_vmap_decrement_nesting = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
_remove_batch_dim_1 = torch._functorch.predispatch._remove_batch_dim(_remove_batch_dim, 3, 128, 0); _remove_batch_dim = None
_vmap_decrement_nesting_1 = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting_1 = None
_remove_batch_dim_2 = torch._functorch.predispatch._remove_batch_dim(_remove_batch_dim_1, 2, 1, 0)
expand = torch.ops.aten.expand.default(_remove_batch_dim_1, [1, 128, 128]); _remove_batch_dim_1 = expand = None
_vmap_decrement_nesting_2 = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting_2 = None
_remove_batch_dim_3 = torch._functorch.predispatch._remove_batch_dim(_remove_batch_dim_2, 1, 2, 0); _remove_batch_dim_2 = None
_vmap_decrement_nesting_3 = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting_3 = None
pad = torch.ops.aten.pad.default(_remove_batch_dim_3, [0, 0, 0, 0]); _remove_batch_dim_3 = None
view = torch.ops.aten.view.default(pad, [2, 1, 1, 128, 1, 128]); pad = None
permute = torch.ops.aten.permute.default(view, [0, 1, 2, 4, 3, 5]); view = None
sum_2 = torch.ops.aten.sum.dim_IntList(permute, [-2, -1]); permute = None
eq = torch.ops.aten.eq.Scalar(sum_2, 16384)
gt_1 = torch.ops.aten.gt.Scalar(sum_2, 0)
lt = torch.ops.aten.lt.Scalar(sum_2, 16384); sum_2 = None
and_2 = torch.ops.aten.__and__.Tensor(gt_1, lt); gt_1 = lt = None
_assert_tensor_metadata_default = torch.ops.aten._assert_tensor_metadata.default(and_2, dtype = torch.bool, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default = None
to = torch.ops.aten.to.dtype(and_2, torch.int8); and_2 = None
_assert_tensor_metadata_default_1 = torch.ops.aten._assert_tensor_metadata.default(eq, dtype = torch.bool, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_1 = None
to_1 = torch.ops.aten.to.dtype(eq, torch.int8); eq = None
_assert_tensor_metadata_default_2 = torch.ops.aten._assert_tensor_metadata.default(to, dtype = torch.int8, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_2 = None
to_2 = torch.ops.aten.to.dtype(to, torch.int32); to = None
sum_3 = torch.ops.aten.sum.dim_IntList(to_2, [-1])
argsort = torch.ops.aten.argsort.stable(to_2, stable = True, descending = True); to_2 = None
_assert_tensor_metadata_default_3 = torch.ops.aten._assert_tensor_metadata.default(sum_3, dtype = torch.int64, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_3 = None
to_3 = torch.ops.aten.to.dtype(sum_3, torch.int32, False, False, torch.contiguous_format); sum_3 = None
_assert_tensor_metadata_default_4 = torch.ops.aten._assert_tensor_metadata.default(argsort, dtype = torch.int64, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_4 = None
to_4 = torch.ops.aten.to.dtype(argsort, torch.int32, False, False, torch.contiguous_format); argsort = None
_assert_tensor_metadata_default_5 = torch.ops.aten._assert_tensor_metadata.default(to_1, dtype = torch.int8, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_5 = None
to_5 = torch.ops.aten.to.dtype(to_1, torch.int32); to_1 = None
sum_4 = torch.ops.aten.sum.dim_IntList(to_5, [-1])
argsort_1 = torch.ops.aten.argsort.stable(to_5, stable = True, descending = True); to_5 = None
_assert_tensor_metadata_default_6 = torch.ops.aten._assert_tensor_metadata.default(sum_4, dtype = torch.int64, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_6 = None
to_6 = torch.ops.aten.to.dtype(sum_4, torch.int32, False, False, torch.contiguous_format); sum_4 = None
_assert_tensor_metadata_default_7 = torch.ops.aten._assert_tensor_metadata.default(argsort_1, dtype = torch.int64, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_7 = None
to_7 = torch.ops.aten.to.dtype(argsort_1, torch.int32, False, False, torch.contiguous_format); argsort_1 = None
lazy_load_decompositions_4 = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions_4 = None
_vmap_increment_nesting_4 = torch._functorch.predispatch._vmap_increment_nesting(2, 'error'); _vmap_increment_nesting_4 = None
_add_batch_dim_4 = torch._functorch.predispatch._add_batch_dim(to_3, 0, 1)
_add_batch_dim_5 = torch._functorch.predispatch._add_batch_dim(to_4, 0, 1)
lazy_load_decompositions_5 = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions_5 = None
_vmap_increment_nesting_5 = torch._functorch.predispatch._vmap_increment_nesting(1, 'error'); _vmap_increment_nesting_5 = None
_add_batch_dim_6 = torch._functorch.predispatch._add_batch_dim(_add_batch_dim_4, 0, 2); _add_batch_dim_4 = None
_add_batch_dim_7 = torch._functorch.predispatch._add_batch_dim(_add_batch_dim_5, 0, 2); _add_batch_dim_5 = None
new_zeros = torch.ops.aten.new_zeros.default(_add_batch_dim_7, [1, 2], dtype = torch.int32, pin_memory = False)
arange_4 = torch.ops.aten.arange.default(1, dtype = torch.int32, device = device(type='cpu'), pin_memory = False)
unsqueeze = torch.ops.aten.unsqueeze.default(arange_4, -1); arange_4 = None
arange_5 = torch.ops.aten.arange.default(1, dtype = torch.int32, device = device(type='cpu'), pin_memory = False)
unsqueeze_1 = torch.ops.aten.unsqueeze.default(_add_batch_dim_6, -1); _add_batch_dim_6 = None
lt_1 = torch.ops.aten.lt.Tensor(arange_5, unsqueeze_1); arange_5 = unsqueeze_1 = None
where = torch.ops.aten.where.ScalarOther(lt_1, _add_batch_dim_7, 1); lt_1 = _add_batch_dim_7 = None
new_ones = torch.ops.aten.new_ones.default(new_zeros, [], pin_memory = False)
index_put_ = torch.ops.aten.index_put_.default(new_zeros, [unsqueeze, where], new_ones); new_zeros = unsqueeze = where = new_ones = None
slice_1 = torch.ops.aten.slice.Tensor(index_put_, 1, 0, 1); index_put_ = None
_remove_batch_dim_4 = torch._functorch.predispatch._remove_batch_dim(slice_1, 2, 1, 0); slice_1 = None
_vmap_decrement_nesting_4 = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting_4 = None
_remove_batch_dim_5 = torch._functorch.predispatch._remove_batch_dim(_remove_batch_dim_4, 1, 2, 0); _remove_batch_dim_4 = None
_vmap_decrement_nesting_5 = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting_5 = None
transpose = torch.ops.aten.transpose.int(_remove_batch_dim_5, -2, -1); _remove_batch_dim_5 = None
_assert_tensor_metadata_default_8 = torch.ops.aten._assert_tensor_metadata.default(transpose, dtype = torch.int32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_8 = None
to_8 = torch.ops.aten.to.dtype(transpose, torch.int32); transpose = None
sum_5 = torch.ops.aten.sum.dim_IntList(to_8, [-1])
argsort_2 = torch.ops.aten.argsort.stable(to_8, stable = True, descending = True); to_8 = None
_assert_tensor_metadata_default_9 = torch.ops.aten._assert_tensor_metadata.default(sum_5, dtype = torch.int64, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_9 = None
to_9 = torch.ops.aten.to.dtype(sum_5, torch.int32, False, False, torch.contiguous_format); sum_5 = None
_assert_tensor_metadata_default_10 = torch.ops.aten._assert_tensor_metadata.default(argsort_2, dtype = torch.int64, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_10 = None
to_10 = torch.ops.aten.to.dtype(argsort_2, torch.int32, False, False, torch.contiguous_format); argsort_2 = None
lazy_load_decompositions_6 = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions_6 = None
_vmap_increment_nesting_6 = torch._functorch.predispatch._vmap_increment_nesting(2, 'error'); _vmap_increment_nesting_6 = None
_add_batch_dim_8 = torch._functorch.predispatch._add_batch_dim(to_6, 0, 1)
_add_batch_dim_9 = torch._functorch.predispatch._add_batch_dim(to_7, 0, 1)
lazy_load_decompositions_7 = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions_7 = None
_vmap_increment_nesting_7 = torch._functorch.predispatch._vmap_increment_nesting(1, 'error'); _vmap_increment_nesting_7 = None
_add_batch_dim_10 = torch._functorch.predispatch._add_batch_dim(_add_batch_dim_8, 0, 2); _add_batch_dim_8 = None
_add_batch_dim_11 = torch._functorch.predispatch._add_batch_dim(_add_batch_dim_9, 0, 2); _add_batch_dim_9 = None
new_zeros_1 = torch.ops.aten.new_zeros.default(_add_batch_dim_11, [1, 2], dtype = torch.int32, pin_memory = False)
arange_6 = torch.ops.aten.arange.default(1, dtype = torch.int32, device = device(type='cpu'), pin_memory = False)
unsqueeze_2 = torch.ops.aten.unsqueeze.default(arange_6, -1); arange_6 = None
arange_7 = torch.ops.aten.arange.default(1, dtype = torch.int32, device = device(type='cpu'), pin_memory = False)
unsqueeze_3 = torch.ops.aten.unsqueeze.default(_add_batch_dim_10, -1); _add_batch_dim_10 = None
lt_2 = torch.ops.aten.lt.Tensor(arange_7, unsqueeze_3); arange_7 = unsqueeze_3 = None
where_1 = torch.ops.aten.where.ScalarOther(lt_2, _add_batch_dim_11, 1); lt_2 = _add_batch_dim_11 = None
new_ones_1 = torch.ops.aten.new_ones.default(new_zeros_1, [], pin_memory = False)
index_put__1 = torch.ops.aten.index_put_.default(new_zeros_1, [unsqueeze_2, where_1], new_ones_1); new_zeros_1 = unsqueeze_2 = where_1 = new_ones_1 = None
slice_2 = torch.ops.aten.slice.Tensor(index_put__1, 1, 0, 1); index_put__1 = None
_remove_batch_dim_6 = torch._functorch.predispatch._remove_batch_dim(slice_2, 2, 1, 0); slice_2 = None
_vmap_decrement_nesting_6 = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting_6 = None
_remove_batch_dim_7 = torch._functorch.predispatch._remove_batch_dim(_remove_batch_dim_6, 1, 2, 0); _remove_batch_dim_6 = None
_vmap_decrement_nesting_7 = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting_7 = None
transpose_1 = torch.ops.aten.transpose.int(_remove_batch_dim_7, -2, -1); _remove_batch_dim_7 = None
_assert_tensor_metadata_default_11 = torch.ops.aten._assert_tensor_metadata.default(transpose_1, dtype = torch.int32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_11 = None
to_11 = torch.ops.aten.to.dtype(transpose_1, torch.int32); transpose_1 = None
sum_6 = torch.ops.aten.sum.dim_IntList(to_11, [-1])
argsort_3 = torch.ops.aten.argsort.stable(to_11, stable = True, descending = True); to_11 = None
_assert_tensor_metadata_default_12 = torch.ops.aten._assert_tensor_metadata.default(sum_6, dtype = torch.int64, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_12 = None
to_12 = torch.ops.aten.to.dtype(sum_6, torch.int32, False, False, torch.contiguous_format); sum_6 = None
_assert_tensor_metadata_default_13 = torch.ops.aten._assert_tensor_metadata.default(argsort_3, dtype = torch.int64, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_13 = None
to_13 = torch.ops.aten.to.dtype(argsort_3, torch.int32, False, False, torch.contiguous_format); argsort_3 = None
linear_1 = torch.ops.aten.linear.default(linear, q_proj_weight, q_proj_bias); q_proj_weight = q_proj_bias = None
view_1 = torch.ops.aten.view.default(linear_1, [2, 1, 128, 64]); linear_1 = None
linear_2 = torch.ops.aten.linear.default(linear, k_proj_weight, k_proj_bias); k_proj_weight = k_proj_bias = None
view_2 = torch.ops.aten.view.default(linear_2, [2, 1, 128, 64]); linear_2 = None
linear_3 = torch.ops.aten.linear.default(linear, v_proj_weight, v_proj_bias); linear = v_proj_weight = v_proj_bias = None
view_3 = torch.ops.aten.view.default(linear_3, [2, 1, 128, 64]); linear_3 = None
sdpa_score0 = self.sdpa_score0
sdpa_mask0 = self.sdpa_mask0
flex_attention = torch.ops.higher_order.flex_attention(view_1, view_2, view_3, sdpa_score0, (128, 128, to_3, to_4, to_6, to_7, to_9, to_10, to_12, to_13, 128, 128, sdpa_mask0), 0.125, {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': False, 'OUTPUT_MAX': False}, (), (detach,)); view_1 = view_2 = view_3 = sdpa_score0 = to_3 = to_4 = to_6 = to_7 = to_9 = to_10 = to_12 = to_13 = sdpa_mask0 = detach = None
getitem = flex_attention[0]
getitem_1 = flex_attention[1]; getitem_1 = None
getitem_2 = flex_attention[2]; flex_attention = getitem_2 = None
return pytree.tree_unflatten((getitem,), self._out_spec)""",
)
exported_out = exported_mod(x)
self.assertEqual(exported_out, eager_out)
def test_inductor_backend_inside_nonstrict(self):
class Foo(torch.nn.Module):
def forward(self, x):
def i_want_faster_code(inp1, inp2):
nonlocal x
return x + inp1 + inp2
out = torch.compile(i_want_faster_code)(x, x)
return x + out
foo = Foo()
with self.assertWarnsRegex(
UserWarning, "You are calling torch.compile inside torch.export region"
):
ep = export(foo, (torch.randn(4, 4),), strict=False).module()
self.assertExpectedInline(
str(ep.graph).strip(),
"""\
graph():
%x : [num_users=4] = placeholder[target=x]
%_guards_fn : [num_users=0] = call_module[target=_guards_fn](args = (%x,), kwargs = {})
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %x), kwargs = {})
%add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add, %x), kwargs = {})
%add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %add_1), kwargs = {})
return (add_2,)""",
)
def test_bincount(self):
class M(torch.nn.Module):
def __init__(self):

View File

@ -4945,8 +4945,6 @@ class <lambda>(torch.nn.Module):
):
cos: "f32[2, 2]" = torch.ops.aten.cos.default(arg0_1); arg0_1 = None
_set_grad_enabled = torch._C._set_grad_enabled(True); _set_grad_enabled = None
body_graph_0 = self.body_graph_0
map_impl = torch.ops.higher_order.map_impl(body_graph_0, [cos], [arg1_1]); body_graph_0 = arg1_1 = None
getitem_2: "f32[2, 2]" = map_impl[0]; map_impl = None

View File

@ -22,6 +22,7 @@ import platform
import sys
import textwrap
import threading
import warnings
from collections.abc import Callable as _Callable
from typing import (
Any as _Any,
@ -2634,6 +2635,28 @@ def compile(
if options and isinstance(options, dict):
guard_filter_fn = options.pop("guard_filter_fn", None)
if torch.compiler.is_exporting():
warnings.warn(
"You are calling torch.compile inside torch.export region. "
"To capture an useful graph, we will implicitly switch to torch.compile(backend=eager)"
)
from torch._higher_order_ops.utils import setup_compilation_env
# Create wrapper that always uses eager backend during export
def export_wrapped_fn(*args, **kwargs):
with setup_compilation_env() as backend: # type: ignore[attr-defined]
# Force eager backend regardless of original backend
backend_wrapper = _TorchCompileWrapper(backend, mode, options, dynamic)
return torch._dynamo.optimize(
backend=backend_wrapper,
nopython=fullgraph,
dynamic=dynamic,
disable=disable,
guard_filter_fn=guard_filter_fn,
)(model)(*args, **kwargs)
return export_wrapped_fn
if backend == "inductor":
backend = _TorchCompileInductorWrapper(mode, options, dynamic)
else:

View File

@ -74,13 +74,13 @@ def make_eager_backend_with_torch_function_modes(
def fn(
gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], **kwargs: Any
) -> Callable[..., Any]:
stack = ExitStack()
for mode in modes:
stack.enter_context(mode)
def wrapper(*args: Any, **kwargs: Any) -> Any:
with ExitStack() as stack:
for mode in modes:
stack.enter_context(mode)
return gm.forward(*args, **kwargs)
result = gm.forward
stack.close()
return result
return wrapper
return fn

View File

@ -19,7 +19,6 @@ from torch._C._functorch import (
from torch._functorch.utils import exposed_in
from torch._higher_order_ops.utils import (
_maybe_run_with_interpreter,
_set_compilation_env,
check_input_alias_and_mutation_return_outputs,
create_bw_fn,
fill_none_with_masks,
@ -33,12 +32,7 @@ from torch._higher_order_ops.utils import (
)
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
from torch.fx.experimental.proxy_tensor import (
_temp_remove_metadata_torch_function_mode,
_temp_remove_pre_dispatch_torch_function_mode,
ProxyTorchDispatchMode,
track_tensor_tree,
)
from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree
from torch.utils._python_dispatch import _get_current_dispatch_mode
@ -174,10 +168,6 @@ def cond(
if torch.compiler.is_dynamo_compiling():
return cond_op(pred, true_fn, false_fn, operands)
from torch._dynamo.backends.debugging import (
make_eager_backend_with_torch_function_mode,
)
if isinstance(pred, (bool, int, float)):
# This is the non-strict export case. Strict export and torch.compile are
# handled above in dynamo.
@ -223,21 +213,12 @@ def cond(
def _cond_op_wrapper(*args, **kwargs):
return cond_op(*args, **kwargs)
with (
_set_compilation_env(),
torch._dynamo.utils.disable_cache_limit(),
_temp_remove_pre_dispatch_torch_function_mode(),
):
with _temp_remove_metadata_torch_function_mode() as metadata_mode:
if metadata_mode:
backend: Union[str, Callable[..., Any]] = (
make_eager_backend_with_torch_function_mode(metadata_mode)
)
else:
backend = "eager"
return torch.compile(_cond_op_wrapper, backend=backend, fullgraph=True)(
pred, true_fn, false_fn, operands
)
from torch._higher_order_ops.utils import setup_compilation_env
with setup_compilation_env() as backend:
return torch.compile(_cond_op_wrapper, backend=backend, fullgraph=True)(
pred, true_fn, false_fn, operands
)
def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):

View File

@ -96,19 +96,8 @@ def _maybe_run_with_interpreter(fn):
def _maybe_compile_and_run_fn(fn, *args):
if not torch.compiler.is_dynamo_compiling():
from torch._dynamo.backends.debugging import (
make_eager_backend_with_torch_function_mode,
)
with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit():
with _temp_remove_metadata_torch_function_mode() as metadata_mode:
if metadata_mode:
backend: Union[str, Callable[..., Any]] = (
make_eager_backend_with_torch_function_mode(metadata_mode)
)
else:
backend = "eager"
return torch.compile(fn, backend=backend, fullgraph=True)(*args)
with setup_compilation_env() as backend: # type: ignore[attr-defined]
return torch.compile(fn, backend=backend, fullgraph=True)(*args)
else:
return fn(*args)
@ -236,6 +225,34 @@ def check_meta_consistency(
)
@contextmanager
def setup_compilation_env():
"""
Context manager that sets up proper environment and backend when invoking torch.compile
inside torch.export region or inside HOP.
"""
from torch._dynamo.backends.debugging import (
make_eager_backend_with_torch_function_modes,
)
from torch.fx.experimental.proxy_tensor import (
_temp_remove_pre_dispatch_torch_function_mode,
)
with (
_set_compilation_env(),
torch._dynamo.utils.disable_cache_limit(),
_temp_remove_pre_dispatch_torch_function_mode() as pre_dispatch_mode,
_temp_remove_metadata_torch_function_mode() as metadata_mode,
):
modes = [
mode for mode in (pre_dispatch_mode, metadata_mode) if mode is not None
]
if modes:
yield make_eager_backend_with_torch_function_modes(modes)
else:
yield "eager"
@contextmanager
def _set_compilation_env():
_old_is_tracing = torch.fx._symbolic_trace._is_fx_tracing_flag