Compare commits

...

11 Commits

SHA1 Message Date
8f28e7ccbf Make line shorter 2025-09-08 11:27:12 -07:00
dbb85070cb Fix lint 2025-09-08 11:17:59 -07:00
c0b3fc4190 Address fixes for PR 2025-09-08 11:11:51 -07:00
9fb01a19ca fix test 2025-08-29 11:25:24 -07:00
6e996f9d87 Privatize API for now 2025-08-29 09:30:18 -07:00
187c48bd09 Add more descript to comment 2025-08-29 09:30:18 -07:00
2e1523b034 Move test formally 2025-08-29 09:30:18 -07:00
9aa9b94add working testcase 2025-08-29 09:30:18 -07:00
6cb0e997f6 Update testcase and impl to accurately reflect need 2025-08-29 09:30:18 -07:00
f42dec5515 Working example to record ops 2025-08-29 09:30:18 -07:00
6af528bdfd Basic commit try to get impl supported 2025-08-29 09:30:18 -07:00
3 changed files with 136 additions and 10 deletions

View File

@@ -19,6 +19,7 @@ from torch.testing._internal.common_utils import (
    TEST_WITH_TORCHDYNAMO,
    TestCase,
)
from torch.testing._internal.triton_utils import requires_cuda_and_triton

try:
@@ -823,6 +824,63 @@ class TestFlopCounter(TestCase):
        self.assertEqual(called, 1)
        self.assertExpectedInline(get_total_flops(mode), """9001""")

    @requires_cuda_and_triton
    def test_flop_counter_custom_triton(self):
        import triton
        import triton.language as tl

        from torch.utils.flop_counter import (
            _FlopCounterMode,
            _register_flop_formula_for_triton_kernel,
        )

        @triton.jit
        def sin_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = tl.load(x_ptr + offsets, mask=mask)
            out = tl.sin(x)
            tl.store(out_ptr + offsets, out, mask=mask)

        x = torch.randn(3, device="cuda")
        out = torch.empty(3, device="cuda")

        @_register_flop_formula_for_triton_kernel(sin_kernel)
        def compute_sin_kernel_flops(*args, **kwargs) -> int:
            # Dummy implementation: report a fixed flop count for the kernel.
            return 2

        def sin_grid(meta):
            return (triton.cdiv(3, meta["BLOCK_SIZE"]),)

        # Launching the wrapped kernel directly should hit the triton flop registry.
        with FlopCounterMode() as m:
            n_elements = 3
            torch.library.wrap_triton(sin_kernel)[sin_grid](x, out, 3, 256)
        self.assertExpectedInline(get_total_flops(m), """2""")

        # Now wrap the kernel in a triton op and register the decomposition.
        @torch._library.triton.triton_op("mylib::sin_op", mutates_args=())
        def op() -> None:
            n_elements = 3
            torch.library.wrap_triton(sin_kernel)[sin_grid](x, out, 3, 256)

        def op_decompose(mode, *args, **kwargs):
            with mode:
                n_elements = 3
                torch.library.wrap_triton(sin_kernel)[sin_grid](x, out, 3, 256)

        torch.library.register_torch_dispatch(
            "mylib::sin_op", _FlopCounterMode, op_decompose
        )

        # Should now count 2 flops; previously this would have counted 0.
        with FlopCounterMode() as m2:
            torch.ops.mylib.sin_op()
        self.assertExpectedInline(get_total_flops(m2), """2""")

    @skipIfNoTorchVision
    def test_inference_mode(self):
        def get_flops(model):

View File

@@ -1016,6 +1016,10 @@ class TritonKernelWrapperFunctional(HigherOrderOperator):
triton_kernel_wrapper_functional = TritonKernelWrapperFunctional()


def get_kernel(kernel_idx: int) -> "TritonKernelType":
    return kernel_side_table.get_kernel(kernel_idx)


@triton_kernel_wrapper_mutation.py_impl(DispatchKey.CompositeExplicitAutograd)
def triton_kernel_wrapper_mutation_dense(
    *,

View File

@@ -54,6 +54,60 @@ def register_flop_formula(targets, get_raw=False) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
    return register_fun


triton_flop_registry: dict[Any, Any] = {}


def _register_flop_formula_for_triton_kernel(targets, get_raw=False) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
    """
    Experimental API to register a flop formula for the triton kernel(s) in ``targets``.

    This is separate from ``register_flop_formula`` because a triton kernel
    currently gets decomposed into ``triton_kernel_wrapper_functional(kernel)``
    in AOT autograd. As a result, this function requires:

    1) The user to write a decomposition of the custom operator under the mode::

        op_decompose(mode, ...):
            with mode:
                <op implementation, launching the kernel>

    2) The decomposition to be registered with the operator's flop counter mode
       torch dispatch::

        torch.library.register_torch_dispatch(
            "mylib::op", _FlopCounterMode, op_decompose
        )

    3) The user to register the flop formula for the triton kernel via this API::

        @_register_flop_formula_for_triton_kernel(kernel)
        def compute_fn(*args, **kwargs) -> int:
            <impl>

    We then use ``triton_flop_registry`` to look up the flop formula for the kernel
    when we encounter ``triton_kernel_wrapper_functional(kernel)`` in the graph.

    NOTE: This API is unstable and subject to change. We expect to streamline
    steps 1) and 2) in the future, as well as support this registration by
    default for 1:1 op-to-kernel mappings.
    """
    def register_fun(flop_formula: Callable[_P, _T]) -> Callable[_P, _T]:
        if not get_raw:
            flop_formula = shape_wrapper(flop_formula)

        def register(target):
            if target in triton_flop_registry:
                raise RuntimeError(f"duplicate registrations for {target}")
            triton_flop_registry[target] = flop_formula

        # Allow registering multiple kernels at once.
        torch.utils._pytree.tree_map_(register, targets)
        return flop_formula

    return register_fun


@register_flop_formula(aten.mm)
def mm_flop(a_shape, b_shape, *args, out_shape=None, **kwargs) -> int:
    """Count flops for matmul."""
@@ -652,6 +706,7 @@ class FlopCounterMode:
            **flop_registry,
            **{k: v if getattr(v, "_get_raw", False) else shape_wrapper(v) for k, v in custom_mapping.items()}
        }
        self.triton_flop_registry = {**triton_flop_registry}
        self.mod_tracker = ModuleTracker()

    def get_total_flops(self) -> int:
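For orientation, the hunks below extend the counting path inside FlopCounterMode, which users reach through the existing public API; a small CPU-only sketch of that baseline usage, unchanged by this PR (the 4096 figure follows from mm_flop as 2 * 4 * 16 * 32):

import torch
from torch.utils.flop_counter import FlopCounterMode

lin = torch.nn.Linear(16, 32, bias=False)
x = torch.randn(4, 16)

with FlopCounterMode(display=False) as mode:
    lin(x)

# (4x16) @ (16x32) matmul: 2 * 4 * 16 * 32 = 4096 flops
print(mode.get_total_flops())  # 4096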
@@ -752,7 +807,11 @@ class FlopCounterMode:
            flop_count = flop_count_func(*args, **kwargs, out_val=out)  # type: ignore[operator]
            for par in set(self.mod_tracker.parents):
                self.flop_counts[par][func_packet] += flop_count
        elif func_packet in self.triton_flop_registry:
            flop_count_func = self.triton_flop_registry[func_packet]
            flop_count = flop_count_func(*args, **kwargs, out_val=out)  # type: ignore[operator]
            for par in set(self.mod_tracker.parents):
                self.flop_counts[par][func_packet] += flop_count

        return out


class _FlopCounterMode(TorchDispatchMode):
@@ -782,15 +841,18 @@ class _FlopCounterMode(TorchDispatchMode):
        return result, flop_counts

    def _handle_higher_order_ops(self, func, types, args, kwargs):
        if func not in {torch.ops.higher_order.cond, }:
            return NotImplemented

        # The flop counter for cond counts the upper bound of flops.
        # For example, if a matmul is executed 2 times in true branch
        # but only 1 time in the false branch, the flop counter will
        # record the larger number of flops, i.e. 2 times.
        if func is torch.ops.higher_order.cond:
        is_triton = func in {torch.ops.higher_order.triton_kernel_wrapper_mutation,
                             torch.ops.higher_order.triton_kernel_wrapper_functional}
        if is_triton:
            from torch._higher_order_ops.triton_kernel_wrap import get_kernel

            # Special case - look in the triton flop registry for the kernel
            kernel_name = get_kernel(kwargs["kernel_idx"])
            return self.counter._count_flops(kernel_name, None, args, kwargs)
        elif func is torch.ops.higher_order.cond:
            # The flop counter for cond counts the upper bound of flops.
            # For example, if a matmul is executed 2 times in true branch
            # but only 1 time in the false branch, the flop counter will
            # record the larger number of flops, i.e. 2 times.
            pred, true_branch, false_branch, operands = args
            # Step 1: Count flops for true branch and false branch separately
            true_out, true_flop_counts = self._execute_with_isolated_flop_counting(

@@ -829,6 +891,8 @@ class _FlopCounterMode(TorchDispatchMode):
            # It doesn't matter which one we return since true_fn and false_fn return
            # output with the same structure.
            return true_out
        else:
            return NotImplemented

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs if kwargs else {}