Update mutation checking in pattern matcher (#137448)

Fix for https://github.com/pytorch/pytorch/issues/137229

The current mutation checking is complicated because it must work for pre-grad IR. Once pre-grad IR has been traced to OpOverloads, checking is much easier. I am also special-casing the auto-functionalized HOP, although, as I discussed with @zou3519, it would be nice to have a way of querying HOPs that mimics schemas.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137448
Approved by: https://github.com/zou3519
This commit is contained in:
eellison
2024-10-07 16:37:54 -07:00
committed by PyTorch MergeBot
parent 28493efe6e
commit 8184e202d8
2 changed files with 32 additions and 0 deletions

View File

@ -1460,6 +1460,18 @@ class TestPatternMatcher(TestCase):
expect=False,
)
# Regression test for pytorch/pytorch#137229: a custom op that declares no
# mutated arguments (mutates_args=[]) must not be classified as a mutation
# op by the pattern matcher, hence expect=False below.
# NOTE(review): indentation was stripped by the diff extraction; these lines
# live inside a TestPatternMatcher method where `check` and `t` are defined.
@torch.library.custom_op("vllm::fused_rms_norm_quant_static", mutates_args=[])
def fused_rms_norm_quant_static(out: torch.Tensor, input: torch.Tensor) -> None:
    # Body is irrelevant; only the registered schema matters for the check.
    pass
# expect=False: the op's schema declares no mutations, so is_mutation_op
# should report False for this call_function node.
check(
    "call_function",
    torch.ops.vllm.fused_rms_norm_quant_static,
    (t, t),
    {},
    expect=False,
)
if __name__ == "__main__":
if IS_LINUX and HAS_CUDA:

View File

@ -1616,7 +1616,27 @@ def is_start_of_fx_graph(graph: torch.fx.Graph, node: torch.fx.Node) -> bool:
# Name-based mutation heuristic: matches a trailing underscore (in-place op
# convention, e.g. "add_"), "_." (in-place overload spelling), or the words
# set/enter/exit/seed bounded by word breaks or underscores.  The (?<!_)/(?!_)
# guards avoid matching dunder-style double underscores.  Per is_mutation_op
# below, this is only consulted for targets that are not OpOverloads (which
# carry an explicit schema).
_mutation_op_re = re.compile(r"(?<!_)(_$|_[.]|(\b|_)(set|enter|exit|seed)(\b|_))(?!_)")
def fixme_incorrect_inductor_schema_op(op: torch._ops.OpOverload) -> bool:
    """Return True if ``op`` is a known "inductor"-namespace op whose schema
    is wrong and therefore cannot be trusted for mutation checking.

    Anything outside the ``inductor`` namespace is never flagged.
    """
    # Fast path: only ops in the "inductor" namespace can be on the list.
    if op.namespace != "inductor":
        return False

    # TODO - fix these schemas so this allowlist can be deleted.
    # Do NOT add any more entries!
    known_bad_schema_ops = (
        torch.ops.inductor.accumulate_grad_.default,
        torch.ops.inductor.resize_storage_bytes_.default,
    )
    return op in known_bad_schema_ops
def is_mutation_op(node: torch.fx.Node) -> bool:
    """Return True if ``node`` represents an op that mutates one of its inputs.

    For OpOverload targets the answer comes directly from the op's schema
    (``is_mutable``) — except for the known-bad "inductor" ops excluded by
    fixme_incorrect_inductor_schema_op.  The auto_functionalized HOP is
    treated as non-mutating.  Other call_function targets fall back to the
    name-based ``_mutation_op_re`` heuristic.
    """
    # Post-grad IR: an OpOverload carries an authoritative schema, so trust
    # it unless the op is on the known-incorrect-schema allowlist.
    if isinstance(
        node.target, torch._ops.OpOverload
    ) and not fixme_incorrect_inductor_schema_op(node.target):
        return node.target._schema.is_mutable
    elif isinstance(
        node.target, torch._higher_order_ops.auto_functionalize.AutoFunctionalized
    ):
        # NOTE(review): presumably safe because auto_functionalized is the
        # functionalized wrapper around a mutating op — the HOP call itself
        # is treated as non-mutating here.
        return False
    # Fallback for non-OpOverload targets (e.g. pre-grad IR): guess from the
    # target's name.  NOTE(review): the function continues past the visible
    # chunk; only the name-regex branch is shown here.
    if node.op == "call_function":
        if _mutation_op_re.search(node.target.__name__): # type: ignore[union-attr]
            return True