mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Update mutation checking in pattern matcher (#137448)
Fix for https://github.com/pytorch/pytorch/issues/137229 The current mutation checking is complicated because it works for pre grad IR. When pre grad ir has been traced to OpOverloads checking is much easier. I am also special casing auto functional hop although I discussed with @zou3519 it would be nice to have a way of querying HOPs that mimic schemas. Pull Request resolved: https://github.com/pytorch/pytorch/pull/137448 Approved by: https://github.com/zou3519
This commit is contained in:
committed by
PyTorch MergeBot
parent
28493efe6e
commit
8184e202d8
@ -1460,6 +1460,18 @@ class TestPatternMatcher(TestCase):
|
||||
expect=False,
|
||||
)
|
||||
|
||||
@torch.library.custom_op("vllm::fused_rms_norm_quant_static", mutates_args=[])
|
||||
def fused_rms_norm_quant_static(out: torch.Tensor, input: torch.Tensor) -> None:
|
||||
pass
|
||||
|
||||
check(
|
||||
"call_function",
|
||||
torch.ops.vllm.fused_rms_norm_quant_static,
|
||||
(t, t),
|
||||
{},
|
||||
expect=False,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if IS_LINUX and HAS_CUDA:
|
||||
|
@ -1616,7 +1616,27 @@ def is_start_of_fx_graph(graph: torch.fx.Graph, node: torch.fx.Node) -> bool:
|
||||
_mutation_op_re = re.compile(r"(?<!_)(_$|_[.]|(\b|_)(set|enter|exit|seed)(\b|_))(?!_)")
|
||||
|
||||
|
||||
def fixme_incorrect_inductor_schema_op(op: torch._ops.OpOverload) -> bool:
|
||||
if op.namespace != "inductor":
|
||||
return False
|
||||
|
||||
# TODO - fix schema
|
||||
# Dont add any more !
|
||||
return op in (
|
||||
torch.ops.inductor.accumulate_grad_.default,
|
||||
torch.ops.inductor.resize_storage_bytes_.default,
|
||||
)
|
||||
|
||||
|
||||
def is_mutation_op(node: torch.fx.Node) -> bool:
|
||||
if isinstance(
|
||||
node.target, torch._ops.OpOverload
|
||||
) and not fixme_incorrect_inductor_schema_op(node.target):
|
||||
return node.target._schema.is_mutable
|
||||
elif isinstance(
|
||||
node.target, torch._higher_order_ops.auto_functionalize.AutoFunctionalized
|
||||
):
|
||||
return False
|
||||
if node.op == "call_function":
|
||||
if _mutation_op_re.search(node.target.__name__): # type: ignore[union-attr]
|
||||
return True
|
||||
|
Reference in New Issue
Block a user