This reverts commit 7743149b2be4a9eba7e0997ccdc6abe552bec266.

Reverts
* https://github.com/pytorch/pytorch/pull/135503
* https://github.com/pytorch/pytorch/pull/135502
* https://github.com/pytorch/pytorch/pull/135422

This revert makes the test below pass again. Previously, the getitem stayed as a getitem node in the FX graph, but now fake tensor propagation fails, reporting that `.item()` was called. It seems the torch function mode is not triggered during fake tensor propagation (see the sketch after the repro).

```python
import torch
from torch.nn.attention.flex_attention import BlockMask, _mask_mod_signature, _score_mod_signature, flex_attention
from torch._inductor.lowering import make_pointwise, register_lowering
from torch._inductor.virtualized import ops
from torch.nn.attention.flex_attention import create_block_mask

torch.set_default_device('cuda')

flex_attention = torch.compile(flex_attention, dynamic=False)

prefix_lengths = torch.arange(8)

def prefix_lm(b, h, q, kv):
    # Indexing prefix_lengths with the batch index is the getitem in question.
    return prefix_lengths[b] >= kv

mask = create_block_mask(prefix_lm, 8, None, 512, 512, _compile=True)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136590
Approved by: https://github.com/Chillee
Committed by: PyTorch MergeBot
Parent: 529b6ab0bb
Commit: 289df45cee
```diff
@@ -187,7 +187,6 @@ def debug_insert_nops(
         local_scope=locals(),
         global_scope=globals(),
         f_code=frame.f_code,
-        torch_function_mode_stack=[],
     )
 
     return GuardedCode(code, CheckFunctionManager(graph).check_fn, CompileId(0, 0))
```