Compare commits

...

60 Commits

Author SHA1 Message Date
34f650a2bd Update
[ghstack-poisoned]
2025-11-04 12:59:31 +00:00
a724a05f74 Update
[ghstack-poisoned]
2025-11-04 12:38:30 +00:00
4df2ac6607 Update
[ghstack-poisoned]
2025-11-04 12:15:15 +00:00
67550c6aaa Update
[ghstack-poisoned]
2025-11-04 12:00:49 +00:00
5769d71e9c Update (base update)
[ghstack-poisoned]
2025-11-04 12:00:49 +00:00
2c13372f1c Update
[ghstack-poisoned]
2025-11-03 11:55:11 +00:00
0b42bb3f61 Update (base update)
[ghstack-poisoned]
2025-11-03 11:55:11 +00:00
3467d3e123 Update
[ghstack-poisoned]
2025-10-31 10:18:28 +00:00
f5ee26463a Update (base update)
[ghstack-poisoned]
2025-10-31 10:18:28 +00:00
ab62572b18 Update
[ghstack-poisoned]
2025-10-30 15:49:35 +00:00
f7f0cc0ace Update (base update)
[ghstack-poisoned]
2025-10-30 15:49:35 +00:00
59ca356557 Update
[ghstack-poisoned]
2025-10-30 14:29:17 +00:00
c3e8577183 Update (base update)
[ghstack-poisoned]
2025-10-30 14:29:17 +00:00
35613ea658 Update
[ghstack-poisoned]
2025-10-30 13:05:51 +00:00
5be840ccc2 Update (base update)
[ghstack-poisoned]
2025-10-30 13:05:51 +00:00
57b5d96fcd Update
[ghstack-poisoned]
2025-10-29 15:42:53 +00:00
5d02965b7c Update (base update)
[ghstack-poisoned]
2025-10-29 15:42:53 +00:00
d220390880 Update
[ghstack-poisoned]
2025-10-29 14:33:46 +00:00
c6cfcf49e1 Update (base update)
[ghstack-poisoned]
2025-10-29 14:33:46 +00:00
56c0ca21f0 Update
[ghstack-poisoned]
2025-10-29 13:16:06 +00:00
85b7edb52b Update (base update)
[ghstack-poisoned]
2025-10-29 13:16:06 +00:00
02fa1ad97a Update
[ghstack-poisoned]
2025-10-29 10:50:25 +00:00
c2eb709432 Update (base update)
[ghstack-poisoned]
2025-10-29 10:50:25 +00:00
c1e7268182 Update
[ghstack-poisoned]
2025-10-29 10:45:31 +00:00
acc92f8dc1 Update (base update)
[ghstack-poisoned]
2025-10-29 10:45:31 +00:00
e50c1a04b7 Update
[ghstack-poisoned]
2025-10-28 16:04:05 +00:00
983443cd20 Update
[ghstack-poisoned]
2025-10-28 15:45:48 +00:00
b76d9cfc7f Update (base update)
[ghstack-poisoned]
2025-10-28 15:39:22 +00:00
d8c4903a3e Update
[ghstack-poisoned]
2025-10-28 15:39:22 +00:00
7f855e5590 Update
[ghstack-poisoned]
2025-10-28 15:20:43 +00:00
7ba226eb14 Update
[ghstack-poisoned]
2025-10-28 15:12:09 +00:00
44bac1e070 Update (base update)
[ghstack-poisoned]
2025-10-28 15:08:05 +00:00
24cdf875b8 Update
[ghstack-poisoned]
2025-10-28 15:08:05 +00:00
a9888afe19 Update
[ghstack-poisoned]
2025-10-28 14:49:44 +00:00
10df61b3c2 Update (base update)
[ghstack-poisoned]
2025-10-28 14:07:59 +00:00
f7d934e8a7 Update
[ghstack-poisoned]
2025-10-28 14:07:59 +00:00
7af0937c58 Update (base update)
[ghstack-poisoned]
2025-10-28 13:58:07 +00:00
b41d593878 Update
[ghstack-poisoned]
2025-10-28 13:58:07 +00:00
fcf212b2b7 Update (base update)
[ghstack-poisoned]
2025-10-28 13:44:50 +00:00
8f71493b92 Update
[ghstack-poisoned]
2025-10-28 13:44:50 +00:00
a9117e9028 Update (base update)
[ghstack-poisoned]
2025-10-28 12:02:21 +00:00
d7c68ae739 Update
[ghstack-poisoned]
2025-10-28 12:02:21 +00:00
172ff9f1d3 Update (base update)
[ghstack-poisoned]
2025-10-28 11:48:26 +00:00
09ae386f48 Update
[ghstack-poisoned]
2025-10-28 11:48:26 +00:00
a137f705d2 Update (base update)
[ghstack-poisoned]
2025-10-27 17:15:01 +00:00
9af5881598 Update
[ghstack-poisoned]
2025-10-27 17:15:01 +00:00
5123a3ad68 Update (base update)
[ghstack-poisoned]
2025-10-27 15:12:20 +00:00
37da895a9b Update
[ghstack-poisoned]
2025-10-27 15:12:20 +00:00
885d7b9f8d Update
[ghstack-poisoned]
2025-10-27 12:38:27 +00:00
135a48757d Update
[ghstack-poisoned]
2025-10-27 12:28:17 +00:00
072cef4b11 Update (base update)
[ghstack-poisoned]
2025-10-27 12:04:22 +00:00
eaea290ced Update
[ghstack-poisoned]
2025-10-27 12:04:22 +00:00
df91f285d6 Update
[ghstack-poisoned]
2025-10-27 11:48:10 +00:00
f69fad4130 Update (base update)
[ghstack-poisoned]
2025-10-27 11:23:44 +00:00
3c20f6ba8d Update
[ghstack-poisoned]
2025-10-27 11:23:44 +00:00
ae6f926ede Update (base update)
[ghstack-poisoned]
2025-10-24 16:55:03 +00:00
cec4bcda84 Update
[ghstack-poisoned]
2025-10-24 16:55:03 +00:00
37aa7f9c7e Update
[ghstack-poisoned]
2025-10-24 14:27:27 +00:00
b4fffb32de Update (base update)
[ghstack-poisoned]
2025-10-24 14:08:43 +00:00
9f0c3473b0 Update
[ghstack-poisoned]
2025-10-24 14:08:43 +00:00
4 changed files with 77 additions and 6 deletions

View File

@ -500,8 +500,13 @@ class PaddingTest(TestCaseBase):
forward_wrapper = wrapper_codes[0]
# make sure the load for softmax is aligned
if bias:
# addmm -> mm + bias and bias is fused with softmax
softmax_load_str = "tl.load(in_out_ptr0 + (r0_1 + 30528*x0)"
else:
softmax_load_str = "tl.load(in_ptr0 + (r0_1 + 30528*x0)"
self.assertTrue(
"tl.load(in_ptr0 + (r0_1 + 30528*x0)" in forward_wrapper,
softmax_load_str in forward_wrapper,
f"forward_wrapper: {forward_wrapper}",
)

View File

@ -15280,7 +15280,7 @@ if RUN_GPU:
),
(
fn3,
"triton_poi_fused_native_layer_norm_relu",
"triton_poi_fused_addmm_native_layer_norm",
(torch.randn(4, 4, device=GPU_TYPE),),
),
]
@ -15293,7 +15293,7 @@ if RUN_GPU:
),
(
fn3,
"triton_poi_fused_LayerNorm_ReLU",
"triton_poi_fused_LayerNorm_Linear_ReLU",
(torch.randn(4, 4, device=GPU_TYPE),),
),
]

View File

@ -51,8 +51,8 @@ from ..utils import (
decode_device,
get_all_devices,
get_gpu_type,
has_uses_tagged_as,
is_gpu,
is_pointwise_use,
OPTIMUS_EXCLUDE_POST_GRAD,
)
from ..virtualized import V
@ -1510,8 +1510,10 @@ def should_prefer_unfused_addmm(match):
if not is_gpu(inp.meta["val"].device.type):
return False
output = match.output_node()
return all(is_pointwise_use(use) for use in output.users)
return has_uses_tagged_as(
match.output_node(),
(torch.Tag.pointwise, torch.Tag.reduction),
)
@register_graph_pattern(

View File

@ -549,6 +549,70 @@ def is_pointwise_use(
return torch.Tag.pointwise in target.tags or is_pointwise_fn(target)
class LogicalConnective(enum.Enum):
    """Logical connective used to aggregate per-use boolean results in
    `has_uses`: OR aggregates with `any`, AND aggregates with `all`."""

    OR = enum.auto()
    AND = enum.auto()
def has_uses(
    target: Node,
    use_selector_fn: Callable[[torch._ops.OpOverload], bool] = lambda _: False,
    use_aggregate_type: LogicalConnective = LogicalConnective.OR,
) -> bool:
    """
    Explore the uses of ``target``, applying ``use_selector_fn`` to each one,
    and combine the resulting booleans with the logical connective named by
    ``use_aggregate_type`` (OR -> ``any``, AND -> ``all``).

    Uses that are ``getitem`` or view ops are transparent: the walk continues
    through them to their own uses.
    """
    # AND aggregates with `all`; everything else (including OR) with `any`.
    aggregate: Callable[[Iterator[Any]], bool] = (
        all if use_aggregate_type is LogicalConnective.AND else any
    )

    def visit(node: Node) -> bool:
        # Only call_function nodes are interesting uses.
        if node.op != "call_function":
            return False
        raw_target = node.target
        if raw_target is not operator.getitem and not isinstance(
            raw_target, torch._ops.OpOverload
        ):
            return False
        op = cast(torch._ops.OpOverload, raw_target)
        # getitem and view ops are looked through: recurse into their users.
        if op is operator.getitem or is_view(op):
            return aggregate(visit(user) for user in node.users)
        return use_selector_fn(op)

    return aggregate(visit(user) for user in target.users)
def has_uses_tagged_as(
    target: Node,
    use_tags: Collection[torch.Tag],
    use_aggregate_type: LogicalConnective = LogicalConnective.OR,
) -> bool:
    """
    Report whether the uses of ``target`` carry one of ``use_tags``, combined
    with the given logical connective (OR: some use matches; AND: every use
    matches). Delegates the traversal to `has_uses`.
    """

    def tag_selector(op: torch._ops.OpOverload) -> bool:
        # An op matches when at least one of its tags is in `use_tags`.
        return any(tag in use_tags for tag in op.tags)

    return has_uses(target, tag_selector, use_aggregate_type)
def gen_gm_and_inputs(
target: Any, args: list[Any], kwargs: dict[str, Any]
) -> tuple[GraphModule, list[torch.Tensor]]: