Compare commits

...

44 Commits

SHA1        Message                                   Date
b5c5ce5b25  Update [ghstack-poisoned]                 2025-11-04 12:59:32 +00:00
dc657c1569  Update (base update) [ghstack-poisoned]   2025-11-04 12:59:32 +00:00
7a04e7aed5  Update [ghstack-poisoned]                 2025-11-04 12:38:30 +00:00
36924d5623  Update (base update) [ghstack-poisoned]   2025-11-04 12:38:30 +00:00
5079ee2c48  Update [ghstack-poisoned]                 2025-11-04 12:15:15 +00:00
391217370c  Update (base update) [ghstack-poisoned]   2025-11-04 12:15:15 +00:00
8ca7183d3b  Update [ghstack-poisoned]                 2025-11-04 12:00:50 +00:00
9fb6727e34  Update (base update) [ghstack-poisoned]   2025-11-04 12:00:50 +00:00
fd9903827d  Update [ghstack-poisoned]                 2025-11-03 15:55:53 +00:00
0df8fe2fe9  Update (base update) [ghstack-poisoned]   2025-11-03 15:55:53 +00:00
a8fc7a299b  Update [ghstack-poisoned]                 2025-11-03 12:35:05 +00:00
503b20611a  Update (base update) [ghstack-poisoned]   2025-11-03 11:55:11 +00:00
d538ecd4eb  Update [ghstack-poisoned]                 2025-11-03 11:55:11 +00:00
214c5393a2  Update (base update) [ghstack-poisoned]   2025-10-31 10:18:29 +00:00
203c9313d9  Update [ghstack-poisoned]                 2025-10-31 10:18:29 +00:00
6b370c69ed  Update (base update) [ghstack-poisoned]   2025-10-30 15:49:35 +00:00
5c418c2cad  Update [ghstack-poisoned]                 2025-10-30 15:49:35 +00:00
5eba9e1356  Update [ghstack-poisoned]                 2025-10-30 15:44:01 +00:00
778444e448  Update (base update) [ghstack-poisoned]   2025-10-30 14:29:25 +00:00
0293eb43d4  Update [ghstack-poisoned]                 2025-10-30 14:29:25 +00:00
4eb096c72e  Update (base update) [ghstack-poisoned]   2025-10-30 13:05:51 +00:00
bc2bd64a95  Update [ghstack-poisoned]                 2025-10-30 13:05:51 +00:00
30f5561cb2  Update (base update) [ghstack-poisoned]   2025-10-29 14:33:46 +00:00
2bb948a55b  Update [ghstack-poisoned]                 2025-10-29 14:33:46 +00:00
cdb39008fb  Update (base update) [ghstack-poisoned]   2025-10-29 13:16:06 +00:00
46994bfa97  Update [ghstack-poisoned]                 2025-10-29 13:16:06 +00:00
8261d7e31e  Update (base update) [ghstack-poisoned]   2025-10-29 10:50:25 +00:00
6c1ea73522  Update [ghstack-poisoned]                 2025-10-29 10:50:25 +00:00
4e84146ff5  Update (base update) [ghstack-poisoned]   2025-10-29 10:45:31 +00:00
2cda739df2  Update [ghstack-poisoned]                 2025-10-29 10:45:31 +00:00
8e8f0c66b3  Update (base update) [ghstack-poisoned]   2025-10-28 16:04:05 +00:00
835382cc91  Update [ghstack-poisoned]                 2025-10-28 16:04:05 +00:00
ae7df501db  Update (base update) [ghstack-poisoned]   2025-10-28 15:45:48 +00:00
287e37eca4  Update [ghstack-poisoned]                 2025-10-28 15:45:48 +00:00
87fbc25108  Update (base update) [ghstack-poisoned]   2025-10-28 15:39:23 +00:00
0da74ab1ff  Update [ghstack-poisoned]                 2025-10-28 15:39:23 +00:00
becd1b0a18  Update (base update) [ghstack-poisoned]   2025-10-28 15:20:44 +00:00
61e809099d  Update [ghstack-poisoned]                 2025-10-28 15:20:44 +00:00
e24406722e  Update (base update) [ghstack-poisoned]   2025-10-28 15:12:09 +00:00
bfbaa98f53  Update [ghstack-poisoned]                 2025-10-28 15:12:09 +00:00
d7f0fb6af1  Update (base update) [ghstack-poisoned]   2025-10-28 15:08:05 +00:00
d68f4b2141  Update [ghstack-poisoned]                 2025-10-28 15:08:05 +00:00
78f4119a3c  Update (base update) [ghstack-poisoned]   2025-10-28 14:49:44 +00:00
3ef3eb5b16  Update [ghstack-poisoned]                 2025-10-28 14:49:44 +00:00
6 changed files with 112 additions and 31 deletions

View File

@@ -500,8 +500,13 @@ class PaddingTest(TestCaseBase):
         forward_wrapper = wrapper_codes[0]
         # make sure the load for softmax is aligned
+        if bias:
+            # addmm -> mm + bias and bias is fused with softmax
+            softmax_load_str = "tl.load(in_out_ptr0 + (r0_1 + 30528*x0)"
+        else:
+            softmax_load_str = "tl.load(in_ptr0 + (r0_1 + 30528*x0)"
         self.assertTrue(
-            "tl.load(in_ptr0 + (r0_1 + 30528*x0)" in forward_wrapper,
+            softmax_load_str in forward_wrapper,
             f"forward_wrapper: {forward_wrapper}",
         )

View File

@@ -15280,7 +15280,7 @@ if RUN_GPU:
                 ),
                 (
                     fn3,
-                    "triton_poi_fused_native_layer_norm_relu",
+                    "triton_poi_fused_addmm_native_layer_norm",
                     (torch.randn(4, 4, device=GPU_TYPE),),
                 ),
             ]
@@ -15293,7 +15293,7 @@
                 ),
                 (
                     fn3,
-                    "triton_poi_fused_LayerNorm_ReLU",
+                    "triton_poi_fused_LayerNorm_Linear_ReLU",
                     (torch.randn(4, 4, device=GPU_TYPE),),
                 ),
             ]

View File

@@ -1693,7 +1693,7 @@ class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
         )
         from torch._higher_order_ops.utils import _maybe_fake_tracing
-        from torch._inductor.utils import is_pointwise_use
+        from torch._inductor.utils import has_only_pointwise_uses
         with tx.fake_mode:
             sub_args_fake = [
@@ -1712,9 +1712,7 @@ class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
             for node in fx.graph.nodes:
                 # Check that the combine_fn is pointwise, if combine_mode='pointwise'
-                if not all(
-                    is_pointwise_use(use) or use.op == "output" for use in node.users
-                ):
+                if not has_only_pointwise_uses(node, select_output=True):
                     raise RuntimeError(
                         "For combine_mode='pointwise', the combine_fn needs to be pointwise"
                     )

View File

@@ -51,8 +51,8 @@ from ..utils import (
     decode_device,
     get_all_devices,
     get_gpu_type,
+    has_uses_tagged_as,
     is_gpu,
-    is_pointwise_use,
     OPTIMUS_EXCLUDE_POST_GRAD,
 )
 from ..virtualized import V
@@ -1510,8 +1510,10 @@ def should_prefer_unfused_addmm(match):
     if not is_gpu(inp.meta["val"].device.type):
         return False
-    output = match.output_node()
-    return all(is_pointwise_use(use) for use in output.users)
+    return has_uses_tagged_as(
+        match.output_node(),
+        (torch.Tag.pointwise, torch.Tag.reduction),
+    )
 @register_graph_pattern(
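A side note on the call above, based on my reading of the helper as defined later in this diff rather than anything the patch states: has_uses_tagged_as defaults its aggregate to LogicalConnective.OR, so the rewritten check passes as soon as one transitive use (following views and getitem) carries a pointwise or reduction tag, whereas the old all(is_pointwise_use(...)) required every use to be pointwise. A minimal sketch of both spellings, with hypothetical wrapper names:

    # Hypothetical wrappers, not part of the PR; they assume the helpers
    # added to torch/_inductor/utils.py in this diff.
    import torch
    from torch.fx import Node
    from torch._inductor.utils import LogicalConnective, has_uses_tagged_as


    def any_use_is_pointwise_or_reduction(out: Node) -> bool:
        # Default aggregate is LogicalConnective.OR: true if at least one
        # (view/getitem-following) use carries either tag.
        return has_uses_tagged_as(out, (torch.Tag.pointwise, torch.Tag.reduction))


    def all_uses_are_pointwise_or_reduction(out: Node) -> bool:
        # The stricter reading spells out the AND connective explicitly.
        return has_uses_tagged_as(
            out,
            (torch.Tag.pointwise, torch.Tag.reduction),
            use_aggregate_type=LogicalConnective.AND,
        )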

View File

@@ -80,9 +80,9 @@ from .ir import (
 from .utils import (
     ceildiv,
     decode_device,
+    has_only_pointwise_uses,
     is_dynamic,
     is_gpu,
-    is_pointwise_use,
     is_view,
     needs_fallback_due_to_atomic_add_limitations,
     pad_listlike,
@@ -1850,10 +1850,7 @@ def cat(inputs, dim=0):
         (len(inputs) <= config.max_pointwise_cat_inputs)
         and all(op_count(t) <= MAX_SIMPLE_OP_COUNT for t in inputs)
     ):
-        pointwise_uses = all(
-            is_pointwise_use(use, additional_pointwise_ops)
-            for use in V.current_node.users
-        )
+        pointwise_uses = has_only_pointwise_uses(V.current_node)
         # fuse in case we will be used in a pointwise node, and there are any inputs we
         # we can prevent materialization of.
         fuse_pointwise_use = (

View File

@@ -525,28 +525,107 @@ def is_view(op: torch._ops.OpOverload) -> bool:
     return any(a.alias_info is not None for a in op._schema.arguments)
 
 
-def is_pointwise_use(
-    use: Node,
-    is_pointwise_fn: Callable[[torch._ops.OpOverload], bool] = lambda _: False,
+class LogicalConnective(enum.Enum):
+    OR = enum.auto()
+    AND = enum.auto()
+
+
+def has_uses(
+    target: Node,
+    use_selector_fn: Callable[[torch._ops.OpOverload], bool] = lambda _: False,
+    use_aggregate_type: LogicalConnective = LogicalConnective.OR,
+    *,
+    select_output: bool = False,
 ) -> bool:
     """
-    Do all uses of this op have torch.Tag.pointwise or return True for optional `is_pointwise_fn`
+    Given a target, explore the uses of `target` by applying `use_selector_fn`
+    on them, and then aggregate these booleans with the `use_aggregate_type`
+    logical connective.
 
     Uses in view ops will follow the views uses.
     """
+
+    def get_use_aggregate_fn(
+        use_aggregate_type: LogicalConnective,
+    ) -> Callable[[Iterator[Any]], bool]:
+        match use_aggregate_type:
+            case LogicalConnective.AND:
+                return all
+            case LogicalConnective.OR:
+                return any
+            case _:
+                return any
+
+    use_aggregate_fn = get_use_aggregate_fn(use_aggregate_type)
+
+    def has_uses_impl(use: Node) -> bool:
+        if select_output and use.op == "output":
+            return True
+        if use.op != "call_function":
+            return False
+        if not (
+            isinstance(use.target, torch._ops.OpOverload)
+            or use.target is operator.getitem
+        ):
+            return False
+
+        target = cast(torch._ops.OpOverload, use.target)
+        # Process getitem and view
+        if target is operator.getitem or is_view(target):
+            return use_aggregate_fn(has_uses_impl(user) for user in use.users)
+
+        return use_selector_fn(target)
+
+    return use_aggregate_fn(has_uses_impl(user) for user in target.users)
+
+
+def has_only_uses(
+    target: Node,
+    use_selector_fn: Callable[[torch._ops.OpOverload], bool] = lambda _: False,
+    *,
+    select_output: bool = False,
+) -> bool:
+    return has_uses(
+        target, use_selector_fn, LogicalConnective.AND, select_output=select_output
+    )
+
+
+def has_uses_tagged_as(
+    target: Node,
+    use_tags: Collection[torch.Tag],
+    use_aggregate_type: LogicalConnective = LogicalConnective.OR,
+    *,
+    select_output: bool = False,
+) -> bool:
+    """
+    Is there a use with given tags?
+    """
+    return has_uses(
+        target,
+        lambda use: any(tag in use_tags for tag in use.tags),
+        use_aggregate_type,
+        select_output=select_output,
+    )
+
+
+def has_only_pointwise_uses(
+    target: Node,
+    *,
+    select_output: bool = False,
+) -> bool:
+    """
+    Do all uses of target have torch.Tag.pointwise?
+
+    Uses in views ops will follow the views uses
+    """
-    if use.op != "call_function":
-        return False
-    if not (
-        isinstance(use.target, torch._ops.OpOverload) or use.target is operator.getitem
-    ):
-        return False
-
-    target = cast(torch._ops.OpOverload, use.target)
-    if target is operator.getitem or is_view(target):
-        return all(is_pointwise_use(u, is_pointwise_fn) for u in use.users)
-
-    return torch.Tag.pointwise in target.tags or is_pointwise_fn(target)
+    return has_uses_tagged_as(
+        target,
+        use_tags=(torch.Tag.pointwise,),
+        use_aggregate_type=LogicalConnective.AND,
+        select_output=select_output,
+    )
 
 
 def gen_gm_and_inputs(
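
To make the new helpers concrete, here is a minimal usage sketch. It is not part of the diff: the traced function, the use of make_fx, and the assumption that aten.relu carries torch.Tag.pointwise are mine; the helper names and signatures come from the hunk above.

    import torch
    from torch.fx.experimental.proxy_tensor import make_fx
    from torch._inductor.utils import (
        LogicalConnective,
        has_only_pointwise_uses,
        has_uses_tagged_as,
    )


    def f(x):
        y = x + 1             # traces to aten.add.Tensor
        return torch.relu(y)  # traces to aten.relu.default (assumed pointwise-tagged)


    gm = make_fx(f)(torch.randn(4))
    nodes = {n.target: n for n in gm.graph.nodes if n.op == "call_function"}
    add_node = nodes[torch.ops.aten.add.Tensor]
    relu_node = nodes[torch.ops.aten.relu.default]

    # add's only user is relu, so every use is pointwise.
    assert has_only_pointwise_uses(add_node)

    # relu's only user is the graph output node, which counts only if we opt in,
    # mirroring the select_output=True call in the associative_scan check above.
    assert not has_only_pointwise_uses(relu_node)
    assert has_only_pointwise_uses(relu_node, select_output=True)

    # The same add-node check spelled out through the general helper.
    assert has_uses_tagged_as(
        add_node,
        use_tags=(torch.Tag.pointwise,),
        use_aggregate_type=LogicalConnective.AND,
    )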