mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Save indexing for getitem nodes when do custom replacements (#140193)
Fixes #137280 When we have multiple indexings for the same array as returned items in pattern replacement, we shouldn't ignore its indexing numbers. otherwise, we may create a wrong pattern_to_node mapping. A unit test is added in this PR. In this unit test, the function `rms_pattern_static` is replaced with `rms_replacement_static` when called. The function `rms_pattern_static` calls two functionalized custom operators, `torch.ops.vllm.rms_norm.default` and `torch.ops.vllm.static_scaled_int8_quant.default`, and it returns at2[1] and at2[2] as outputs. The function `rms_replacement_static` calls one functionalized custom operator `torch.ops.vllm.fused_rms_norm_quant_static.default`, which returns two corresponding items. Run `python test/inductor/test_pattern_matcher.py -k test_multioutput_register_replacement` to test. After set `TORCH_COMPILE_DEBUG` to 1, the final part of the `fx_graph_readable.py` is like the following. ```python # File: /home/yhao/p9/pytorch/test/inductor/test_pattern_matcher.py:1673 in rms_pattern_static, code: at1 = auto_functionalized( auto_functionalized = torch.ops.higher_order.auto_functionalized(torch.ops.vllm.rms_norm.default, result = permute_1, input = convert_element_type, weight = convert_element_type_1, epsilon = 1e-06); permute_1 = convert_element_type = convert_element_type_1 = None getitem_1: "bf16[5, 4]" = auto_functionalized[1]; auto_functionalized = None # File: /home/yhao/p9/pytorch/test/inductor/test_pattern_matcher.py:1680 in rms_pattern_static, code: at2 = auto_functionalized( auto_functionalized_1 = torch.ops.higher_order.auto_functionalized(torch.ops.vllm.static_scaled_int8_quant.default, result = permute, input = getitem_1, scale = full_default, azp = None); permute = getitem_1 = full_default = None getitem_3: "i8[5, 4]" = auto_functionalized_1[1] getitem_4: "f32[1, 1]" = auto_functionalized_1[2]; auto_functionalized_1 = None return (getitem_3, getitem_4) ``` This happens before pattern matching, so is it expected to call `static_scaled_int8_quant` and `rms_norm` and return auto_functionalized_1 as outputs. However, for pytorch before this PR, the `fx_graph_transformed.py`, which is after pattern matching, has the following code. ```python # File: /home/yhao/p9/pytorch/test/inductor/test_pattern_matcher.py:1748 in my_func_static, code: scale = torch.ones((1, 1)) full_default: "f32[1, 1]" = torch.ops.aten.full.default([1, 1], 1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False) # No stacktrace found for following nodes as_strided_default: "i8[20]" = torch.ops.aten.as_strided.default(permute, [20], [1], 0) clone_default: "i8[20]" = torch.ops.aten.clone.default(as_strided_default); as_strided_default = None as_strided_default_1: "i8[5, 4]" = torch.ops.aten.as_strided.default(clone_default, [5, 4], [4, 1], 0); clone_default = None as_strided_default_2: "f32[1]" = torch.ops.aten.as_strided.default(full_default, [1], [1], 0) clone_default_1: "f32[1]" = torch.ops.aten.clone.default(as_strided_default_2); as_strided_default_2 = None as_strided_default_3: "f32[1, 1]" = torch.ops.aten.as_strided.default(clone_default_1, [1, 1], [1, 1], 0); clone_default_1 = None static_scaled_int8_quant_default = torch.ops.vllm.static_scaled_int8_quant.default(as_strided_default_1, permute_1, as_strided_default_3); as_strided_default_1 = permute_1 = static_scaled_int8_quant_default = None fused_rms_norm_quant_static_default = torch.ops.vllm.fused_rms_norm_quant_static.default(permute, convert_element_type, convert_element_type_1, full_default, None, 1e-06); convert_element_type = convert_element_type_1 = full_default = fused_rms_norm_quant_static_default = None return (permute, as_strided_default_3) ``` Here, it returns `(permute, as_strided_default_3)` while `permute` is written by fused_rms_norm_quant_static and `as_strided_default_3` is written by `static_scaled_int8_quant`. This is wrong because in our expectation, the `static_scaled_int8_quant` should be removed since it is replaced with `fused_rms_norm_quant_static`. It is supposed to return `(permute, full_default)`. The root cause is the following part. When we [generate patterns](5f4a21dc58/torch/_inductor/pattern_matcher.py (L1580)
) with traced fx graph and call the following function, the indexing numbers' type int in traced graph are ignored in `ignore_types`. So, the final arguments of patterns for those two output items are like `(CallFunction(auto_functionalized,XXX)), *)`.5f4a21dc58/torch/_inductor/pattern_matcher.py (L1839-L1847)
When we do pattern matching after we generated patterns in the following part, the `sorted(itertools.chain.from_iterable(nodes), reverse=True)` is `[getitem_4, getitem_3, getitem_1]`. The getitem_4's iteration is always FailedMatch because we always use the first element to do the pattern match here (it fails on different match functions before and after this PR, but the reason is always the indexing numbers issue)d4cdc09881/torch/_inductor/pattern_matcher.py (L848)
. However, when we do pattern matching for getitem_3, the child_match returns a match for getitem_3 again which is because the `*` pattern can match anything. Then the getitem_3's pattern matching returns a `[getitem_3, getitem_3]` as outputs which are wrong.d4cdc09881/torch/_inductor/pattern_matcher.py (L856)
d4cdc09881/torch/_inductor/pattern_matcher.py (L1750-L1774)
This PR doesn't ignore `int` type when we generate patterns for getitem functions because integer indexing numbers are important to them. Thus, the indexing information is kept in patterns, ensuring correct matchings. With this PR, the above `child_match` returns a match for getitem_4, and the final getitem_3's pattern matching returns the correct `[getitem_3, getitem_4]`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/140193 Approved by: https://github.com/eellison
This commit is contained in:
committed by
PyTorch MergeBot
parent
b37cfddeb3
commit
ab63b679e9
@ -1802,10 +1802,15 @@ def fx_to_pattern(
|
||||
inv_scalar_workaround = {v: k for k, v in scalar_workaround.items()}
|
||||
assert len(inv_scalar_workaround) == len(scalar_workaround)
|
||||
|
||||
def process_arg(x: T) -> Union[T, KeywordArg, Ignored]:
|
||||
def process_arg(
|
||||
x: T, ignore_types_override: Optional[Sequence[Type[Any]]] = None
|
||||
) -> Union[T, KeywordArg, Ignored]:
|
||||
current_ignore_types = (
|
||||
ignore_types_override if ignore_types_override is not None else ignore_types
|
||||
)
|
||||
if isinstance(x, (float, int)) and x in inv_scalar_workaround:
|
||||
return KeywordArg(inv_scalar_workaround[x])
|
||||
if type(x) in ignore_types:
|
||||
if type(x) in current_ignore_types:
|
||||
return Ignored()
|
||||
if isinstance(x, list) and all(isinstance(y, Ignored) for y in x) and x:
|
||||
return Ignored()
|
||||
@ -1838,11 +1843,25 @@ def fx_to_pattern(
|
||||
def call_function(
|
||||
self, target: str, args: Sequence[Any], kwargs: Mapping[str, Any] # type: ignore[override]
|
||||
) -> PatternExpr:
|
||||
args, kwargs = pytree.tree_map(process_arg, (args, kwargs))
|
||||
process_arg_fn = process_arg
|
||||
# Indexing is critical for matching getitem nodes, so we can't ignore int args here
|
||||
if target == operator.getitem:
|
||||
|
||||
def process_arg_fn_impl(
|
||||
x: T,
|
||||
ignore_types_override: Optional[Sequence[Type[Any]]] = tuple(
|
||||
t for t in ignore_types if t is not int
|
||||
),
|
||||
) -> Union[T, KeywordArg, Ignored]:
|
||||
return process_arg(x, ignore_types_override)
|
||||
|
||||
process_arg_fn = process_arg_fn_impl
|
||||
|
||||
args, kwargs = pytree.tree_map(process_arg_fn, (args, kwargs))
|
||||
if list in ignore_types:
|
||||
# Handle a burned in tensor size which are now [Ignored(), Ignored(), ...]
|
||||
args = [process_arg(a) for a in args]
|
||||
kwargs = {k: process_arg(a) for k, a in kwargs.items()}
|
||||
args = [process_arg_fn(a) for a in args]
|
||||
kwargs = {k: process_arg_fn(a) for k, a in kwargs.items()}
|
||||
return CallFunction(target, *args, **kwargs)
|
||||
|
||||
def run_node(self, n: torch.fx.Node) -> Any:
|
||||
|
Reference in New Issue
Block a user