Compare commits

...

8 Commits

12 changed files with 340 additions and 149 deletions

View File

@@ -388,6 +388,47 @@ def forward(self, b_parametrizations_buffer_original0, x):
res = opt_fn(x, y)
self.assertEqual(res, ref)
def test_dtensor_dynamic_recompiles(self):
cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
def inp(*shape):
param = torch.randn(*shape, requires_grad=True)
x = DTensor.from_local(param, mesh, [Shard(0)], run_check=False)
torch._dynamo.mark_dynamic(x, 0)
torch._dynamo.mark_dynamic(x, 1)
return x
def run(func, *shape):
res = func(inp(*shape))
res.sum().backward()
@torch.compile(backend=cnt, fullgraph=True)
def f(x):
y = x * x
return y.to_local()
run(f, 4, 4)
run(f, 6, 8)
run(f, 10, 10)
self.assertEqual(cnt.frame_count, 1)
# sanity check that shape guard recompiles are still handled
@torch.compile(backend=cnt, fullgraph=True)
def g(x):
if x.size(0) <= 16:
y = x * x
else:
y = x + x
return y.to_local()
cnt.clear()
run(g, 4, 4)
run(g, 8, 8)
self.assertEqual(cnt.frame_count, 1)
run(g, 64, 8)
self.assertEqual(cnt.frame_count, 2)
def test_dtensor_attribute_access_on_intermediate(self):
mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
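A minimal single-process sketch of the recompile-counting pattern that test_dtensor_dynamic_recompiles above exercises, using plain tensors instead of DTensor (the device mesh and Shard placement are assumed away); CompileCounterWithBackend and mark_dynamic are the same torch._dynamo utilities the test itself uses.

import torch
import torch._dynamo.testing

cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")

@torch.compile(backend=cnt, fullgraph=True)
def f(x):
    # same shape-agnostic body as the test: elementwise multiply, then reduce
    return (x * x).sum()

for shape in [(4, 4), (6, 8), (10, 10)]:
    x = torch.randn(*shape, requires_grad=True)
    # mark both dims dynamic so one compiled graph covers all three shapes
    torch._dynamo.mark_dynamic(x, 0)
    torch._dynamo.mark_dynamic(x, 1)
    f(x).backward()

assert cnt.frame_count == 1  # no shape-driven recompiles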

View File

@@ -36,6 +36,17 @@ class SimpleModel(torch.nn.Module):
def forward(self, input):
return self.mlp_1(self.mlp_0(input))
class SimpleModelDynamicShapes(torch.nn.Module):
def __init__(self, device):
super().__init__()
self.mlp_0 = MLPModule(device)
self.mlp_1 = MLPModule(device)
def forward(self, input):
if input.shape[0] > 4:
return self.mlp_0(input.sin())
return self.mlp_1(input.cos())
def strict_export_and_aot_export_joint_with_descriptors(model, inputs):
# needed for strict export
@@ -150,6 +161,37 @@ class DTensorExportTest(TestCase):
def test_strict_export_parallelize_module_with_dtensor_input(self):
self._run_test(strict_export_and_aot_export_joint_with_descriptors)
def test_dynamic_shapes(self):
dp_degree = 2
tp_degree = self.world_size // dp_degree
# 2-D mesh is [dp, tp]
mesh_2d = init_device_mesh(
self.device_type,
mesh_shape=(dp_degree, tp_degree),
mesh_dim_names=["dp", "tp"],
)
model = SimpleModelDynamicShapes(self.device_type)
parallelize_plan = {
"mlp_0.net1": ColwiseParallel(),
"mlp_0.net2": RowwiseParallel(),
"mlp_1.net1": ColwiseParallel(),
"mlp_1.net2": RowwiseParallel(),
}
tp_model = parallelize_module(model, mesh_2d["tp"], parallelize_plan)
inputs = torch.rand(20, 10, device=self.device_type)
inputs = distribute_tensor(inputs, mesh_2d["tp"], placements=[Replicate()])
torch._dynamo.mark_dynamic(inputs, 0, min=5, max=100)
torch.compile(tp_model, fullgraph=True)(inputs)
# joint_gm = graph_capture_and_aot_export_joint_with_descriptors(tp_model, inputs)
# fw_gm, bw_gm = min_cut_rematerialization_partition(
# joint_gm, None, num_fwd_outputs=1
# )
# print(fw_gm, bw_gm)
instantiate_parametrized_tests(DTensorExportTest)
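A hedged single-process sketch of the dynamic-dimension declaration used in test_dynamic_shapes above, with a plain tensor and no parallelize_module; the min/max bounds mirror the mark_dynamic call in the test, and the branching body mirrors SimpleModelDynamicShapes.forward.

import torch

def f(x):
    # same data-dependent branch as SimpleModelDynamicShapes.forward
    if x.shape[0] > 4:
        return x.sin()
    return x.cos()

x = torch.rand(20, 10)
# declare dim 0 dynamic with an explicit range, as the test does on the DTensor input
torch._dynamo.mark_dynamic(x, 0, min=5, max=100)
out = torch.compile(f, fullgraph=True)(x)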

View File

@@ -1702,7 +1702,11 @@ class SIMDScheduling(BaseScheduling):
from ..ir import IRNode
def get_size(arg):
if not isinstance(arg, IRNode) or (size := arg.maybe_get_size()) is None:
if not isinstance(arg, IRNode):
return None
if isinstance(arg, ir.BaseView): # triton templates want the base tensor.
arg = arg.unwrap_view()
if (size := arg.maybe_get_size()) is None:
return None
return tuple(s for s in size)
@@ -1725,21 +1729,21 @@ class SIMDScheduling(BaseScheduling):
)
def _make_shape_cache_key(
self, node: MultiTemplateBuffer, hint: int
self, node: MultiTemplateBuffer, hint_key: Any
) -> tuple[tuple[int, ...], ...]:
"""
Returns the cache key for hint-based multi-graph dispatch; the key is a tuple of shapes with the hint overrides substituted in.
"""
shapes = self._get_multikernel_shapes(node)
return tuple(
hints = {k: v for k, v in hint_key}
out = tuple(
tuple(
hint
if isinstance(s, sympy.Expr) and not isinstance(s, sympy.Integer)
else s
V.graph.sizevars.size_hint(s.subs(hints))
for s in shape
)
for shape in shapes
)
return out
def codegen_template(
self,
@@ -1770,41 +1774,47 @@ class SIMDScheduling(BaseScheduling):
src_codes = []
for (
size_hint,
size_hint_key,
make_kernel_render,
) in template_node.node._make_kernel_renders.items():
kernel, render = make_kernel_render(
template_node.node, hint_override=hint_override
)
if only_gen_src_code:
src_code = self._codegen_single_template(
kernel,
render,
template_node,
epilogue_nodes,
prologue_nodes,
only_gen_src_code=True,
)
assert isinstance(src_code, str)
src_codes.append(src_code)
if size_hint_key is None:
ctx = contextlib.nullcontext()
else:
if size_hint is None:
continue # skip kernel generation based on real runtime value; only use hints
kernel = self._codegen_single_template(
kernel,
render,
template_node,
epilogue_nodes,
prologue_nodes,
only_gen_src_code=False,
size_hint_overrides = {k: v for k, v in size_hint_key}
ctx = V.graph.sizevars.set_hint_overrides(size_hint_overrides)
with ctx:
kernel, render = make_kernel_render(
template_node.node, hint_override=hint_override
)
shape_cache_key = (
None
if size_hint is None
else self._make_shape_cache_key(template_node.node, size_hint)
)
kernels[shape_cache_key] = kernel
if only_gen_src_code:
src_code = self._codegen_single_template(
kernel,
render,
template_node,
epilogue_nodes,
prologue_nodes,
only_gen_src_code=True,
)
assert isinstance(src_code, str)
src_codes.append(src_code)
else:
if size_hint_key is None:
continue # skip kernel generation based on real runtime value; only use hints
kernel = self._codegen_single_template(
kernel,
render,
template_node,
epilogue_nodes,
prologue_nodes,
only_gen_src_code=False,
)
shape_cache_key = (
None
if size_hint_key is None
else self._make_shape_cache_key(template_node.node, size_hint_key)
)
kernels[shape_cache_key] = kernel
if only_gen_src_code:
return "\n\n".join(src_codes)
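The _make_shape_cache_key change above keys the per-shape kernel cache on a tuple of (symbol, hint) pairs rather than on a single integer hint. A standalone sketch of that key shape, where s0, s1 and the shapes are illustrative rather than taken from the PR:

import sympy

s0, s1 = sympy.symbols("s0 s1", integer=True, positive=True)
shapes = [(s0, 64), (s0, s1)]          # symbolic output shapes of a multi-template node
hint_key = ((s0, 256), (s1, 32))       # hashable key: tuple of (symbol, hint) pairs
hints = dict(hint_key)

cache_key = tuple(
    tuple(int(s.subs(hints)) if isinstance(s, sympy.Expr) else s for s in shape)
    for shape in shapes
)
assert cache_key == ((256, 64), (256, 32))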

View File

@@ -4034,12 +4034,16 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
args.append(str(arg))
elif isinstance(arg, SymbolicCallArg):
hint = V.graph.sizevars.size_hint(
arg.inner_expr, fallback=config.unbacked_symint_fallback
arg.inner_expr,
hint_override=self.hint_override,
fallback=config.unbacked_symint_fallback,
)
args.append(str(hint))
elif isinstance(arg, sympy.Expr):
hint = V.graph.sizevars.size_hint(
arg, fallback=config.unbacked_symint_fallback
arg,
hint_override=self.hint_override,
fallback=config.unbacked_symint_fallback,
)
args.append(str(hint))
else:
@@ -4110,7 +4114,9 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
result.writeline(f"{var_name} = {symval_hint}")
elif isinstance(arg_sig, WorkspaceArg):
device = V.graph.get_current_device_or_throw()
count = V.graph.sizevars.size_hint(arg_sig.count)
count = V.graph.sizevars.size_hint(
arg_sig.count, hint_override=self.hint_override
)
result.writeline(
f"{var_name} = torch.zeros({count}, device='{device}', dtype={arg_sig.dtype})"
)

View File

@@ -5081,7 +5081,7 @@ class MultiTemplateBuffer(TritonTemplateBuffer):
self, hint_override: Optional[int] = None
) -> dict[ChoiceCaller, float]:
if hint_override not in self._choice_timings:
self._choice_timings[hint_override] = self._choice_timings_fn(hint_override)
self._choice_timings[hint_override] = self._choice_timings_fn()
return self._choice_timings[hint_override]
@contextlib.contextmanager
@@ -5114,14 +5114,14 @@ class MultiTemplateBuffer(TritonTemplateBuffer):
return (min_choice, timings[min_choice])
def finalize_as_triton_callers(
self, callers: dict[Optional[int], TritonTemplateCallerBase]
self, callers: dict[Any, TritonTemplateCallerBase]
) -> None:
"""Finalize with multiple callers for different hint overrides"""
for hint_override, caller in callers.items():
self._make_kernel_renders[hint_override] = caller.get_make_kernel_render()
for hint_override_key, caller in callers.items():
self._make_kernel_renders[hint_override_key] = caller.get_make_kernel_render()
# Set the default to be the one without hint override
self.make_kernel_render = self._make_kernel_renders[None]
# # Set the default to be the one without hint override
# self.make_kernel_render = self._make_kernel_renders[None]
class CUDATemplateBuffer(TemplateBuffer):

View File

@@ -3144,19 +3144,38 @@ class Scheduler:
min_node_unfused,
torch._inductor.ir.TritonTemplateCallerBase,
):
if config.multi_kernel_hints:
callers: dict[Optional[int], TritonTemplateCallerBase] = {}
callers[None] = min_node_unfused
callers: dict[Optional[int], TritonTemplateCallerBase] = {}
callers[None] = min_node_unfused
for hint in config.multi_kernel_hints:
timings = multi_node.choice_timings(hint_override=hint)
free_symbols = set()
for inp in multi_node.inputs:
if (size := inp.maybe_get_size()) is not None:
for s in size:
free_symbols |= s.free_symbols
for s in multi_node.layout.size:
free_symbols |= s.free_symbols
overrides = {}
for sym in free_symbols:
if sym in V.graph.sizevars.var_to_hint_override_:
overrides[sym] = V.graph.sizevars.var_to_hint_override_[sym]
if overrides:
for override_vals in itertools.product(*[overrides[k] for k in free_symbols if k in overrides]):
override_vals = tuple(override_vals)
overrides_ = dict(zip(free_symbols, override_vals))
override_key = tuple((s, v) for s, v in zip(free_symbols, override_vals))
# with V.graph.sizevars.set_hint_overrides(overrides_):
timings = multi_node.choice_timings(override_key)
triton_timings = {
k: v
for k, v in timings.items()
if isinstance(k, TritonTemplateCallerBase)
}
choice = min(triton_timings.items(), key=lambda x: x[1])[0]
callers[hint] = choice
callers[override_key] = choice
node.node.finalize_as_triton_callers(callers)
else:
@@ -3308,55 +3327,69 @@ class Scheduler:
)
assert isinstance(multi_node, ir.MultiTemplateBuffer)
hint_override_best_fusion_choice: dict[
Optional[int], TritonTemplateCallerBase
] = {}
hint_override_best_fusion_choice = {}
future_choices: list[tuple[Any, Optional[LambdaFuture], ModuleType]] = []
for hint_override in config.multi_kernel_hints:
choice_timings = multi_node.choice_timings(hint_override)
for choice, unfused_time in sorted(
choice_timings.items(), key=lambda x: x[1]
):
if not isinstance(
choice, torch._inductor.select_algorithm.TritonTemplateCaller
):
continue
with multi_node.swap_as_triton_caller(choice):
future_choices.append(
(
choice,
*compile_kernel(
node_list_fused, hint_override=choice.hint_override
),
)
)
min_ms_fused = float("inf")
ms_fused_choice: Optional[TritonTemplateCallerBase] = None
new_timings = {}
for choice, future, mod_fused in future_choices:
try:
if future is not None:
future.result()
except Exception as e:
if fusion_log.isEnabledFor(logging.DEBUG):
fusion_log.debug(
"Exception in compiling %s: %s",
"prologue" if not epilogue_fusion else "epilogue",
str(e),
free_symbols = set()
for inp in multi_node.inputs:
if (size := inp.maybe_get_size()) is not None:
for s in size:
free_symbols |= s.free_symbols
for s in multi_node.layout.size:
free_symbols |= s.free_symbols
overrides = {}
for sym in free_symbols:
if sym in V.graph.sizevars.var_to_hint_override_:
overrides[sym] = V.graph.sizevars.var_to_hint_override_[sym]
for override_vals in itertools.product(*[overrides[k] for k in free_symbols if k in overrides]):
override_vals = tuple(override_vals)
overrides_ = dict(zip(free_symbols, override_vals))
override_key = tuple((s, v) for s, v in zip(free_symbols, override_vals))
with V.graph.sizevars.set_hint_overrides(overrides_):
choice_timings = multi_node.choice_timings(override_key)
for choice, unfused_time in sorted(
choice_timings.items(), key=lambda x: x[1]
):
if not isinstance(
choice, torch._inductor.select_algorithm.TritonTemplateCaller
):
continue
with multi_node.swap_as_triton_caller(choice):
future_choices.append(
(
choice,
*compile_kernel(node_list_fused),
)
)
continue
with multi_node.swap_as_triton_caller(choice):
ms_fused, path = self.benchmark_codegened_module(
mod_fused, device
)
new_timings[choice] = ms_fused
if ms_fused < min_ms_fused:
min_ms_fused = ms_fused
ms_fused_choice = choice
multi_node._choice_timings[hint_override] = new_timings
assert isinstance(ms_fused_choice, TritonTemplateCallerBase)
hint_override_best_fusion_choice[hint_override] = ms_fused_choice
min_ms_fused = float("inf")
ms_fused_choice: Optional[TritonTemplateCallerBase] = None
new_timings = {}
for choice, future, mod_fused in future_choices:
try:
if future is not None:
future.result()
except Exception as e:
if fusion_log.isEnabledFor(logging.DEBUG):
fusion_log.debug(
"Exception in compiling %s: %s",
"prologue" if not epilogue_fusion else "epilogue",
str(e),
)
continue
with multi_node.swap_as_triton_caller(choice):
ms_fused, path = self.benchmark_codegened_module(
mod_fused, device
)
new_timings[choice] = ms_fused
if ms_fused < min_ms_fused:
min_ms_fused = ms_fused
ms_fused_choice = choice
multi_node._choice_timings[override_key] = new_timings
assert isinstance(ms_fused_choice, TritonTemplateCallerBase)
hint_override_best_fusion_choice[override_key] = ms_fused_choice
# Eagerly compile and benchmark non-template nodes
choice_timings = multi_node.choice_timings()
@@ -3435,7 +3468,6 @@ class Scheduler:
if min_ms_fused < (ms1 + ms2) and ms_fused_choice is not None:
if config.multi_kernel_hints:
hint_override_best_fusion_choice[None] = ms_fused_choice
multi_node.finalize_as_triton_callers(
hint_override_best_fusion_choice
)
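Both Scheduler hunks above enumerate hint-override keys as the Cartesian product of the candidate hints registered for each free size symbol of the node. A standalone sketch of that enumeration, with made-up symbols and hint lists:

import itertools
import sympy

s0, s1 = sympy.symbols("s0 s1")
var_to_hint_override = {s0: [64, 4096], s1: [128]}   # candidate hints per symbol
free_symbols = [s0, s1]

override_keys = []
for vals in itertools.product(
    *[var_to_hint_override[s] for s in free_symbols if s in var_to_hint_override]
):
    # hashable key used to index per-shape kernels and choice timings
    override_keys.append(tuple(zip(free_symbols, vals)))

assert override_keys == [((s0, 64), (s1, 128)), ((s0, 4096), (s1, 128))]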

View File

@@ -3,6 +3,7 @@ import functools
import itertools
import logging
from collections.abc import Iterable, Sequence
from contextlib import contextmanager
from typing import Any, Callable, cast, Optional, Union
import sympy
@@ -55,6 +56,9 @@ def statically_known_true(
return False
SET_COUNT = 0
REACHED_3200 = False
# This class is a little awkward, because ShapeEnv is doing most of the heavy
# lifting and in some cases we should be directly passing through to ShapeEnv,
# but there is some extra inductor logic that needs to be handled here
@@ -76,7 +80,8 @@ class SizeVarAllocator:
shape_env = ShapeEnv()
self.shape_env = shape_env
self.var_to_val = self.shape_env.var_to_val
self.var_to_hint_override = self.shape_env.var_to_hint_override
self.var_to_hint_override_ = self.shape_env.var_to_hint_override
self.var_to_hint_override = {}
self.replacements: dict[sympy.Symbol, Expr] = self.shape_env.replacements
self.unbacked_replacements: Optional[dict[Expr, Expr]] = None
# Maps of dynamic sizes that have to be precomputed on the host to the kernel args.
@@ -94,6 +99,30 @@ class SizeVarAllocator:
self.stride_vars = self.make_stride_vars_cache()
self.simplify_with_ranges = self.make_simplify_with_ranges_cache()
self._simplify_loops = self.make_simplify_loops_cache()
self.active_hint_override = False
@contextmanager
def set_hint_overrides(self, overrides):
# global SET_COUNT, REACHED_3200
# if SET_COUNT == 3200:
# REACHED_3200 = True
# SET_COUNT += 1
# print(f"set_hint_overrides: {overrides}")
old_overrides = self.var_to_hint_override
new_overrides = {}
for k, v in overrides.items():
new_overrides[k] = v
if k in self.replacements:
kr = self.replacements[k]
if isinstance(kr, sympy.Symbol):
new_overrides[kr] = v
try:
self.active_hint_override = True
self.var_to_hint_override = {**old_overrides, **new_overrides}
yield
finally:
self.active_hint_override = False
self.var_to_hint_override = old_overrides
def simplify(self, expr: Expr):
return sympy.expand(expr).xreplace(self.replacements)
@@ -559,6 +588,8 @@ class SizeVarAllocator:
return expr
# Substitute all hints into expr, but leave unbacked symints alone
expr = self.simplify(expr)
expr = self.remove_precomputed_replacements(expr)
# print(f"symbolic_hint (processed): {expr}")
if not isinstance(expr, Expr):
assert isinstance(expr, int)
return expr
@@ -570,9 +601,9 @@ class SizeVarAllocator:
return expr # inf/nan/I
if hint_override:
return hint_override
expr = self.remove_precomputed_replacements(expr)
out = expr.subs({symbol: hint_override for symbol in free_symbols})
assert isinstance(out, sympy.Integer)
return out
if use_user_provided_hint_override:
expr = sympy_subs(expr, self.var_to_hint_override)
@@ -586,10 +617,18 @@ class SizeVarAllocator:
fallback: Optional[int] = None,
hint_override: Optional[int] = None,
) -> int:
# print(f"size_hint {expr}, fallback={fallback}, hint_override={hint_override}")
# print(f"var_to_val: {self.var_to_val}")
# print(f"var_to_hint_override: {self.var_to_hint_override}, active: {self.active_hint_override}")
# if REACHED_3200:
# print(f"size_hint {expr}, fallback={fallback}, hint_override={hint_override}")
# print(f"var_to_val: {self.var_to_val}")
# print(f"var_to_hint_override: {self.var_to_hint_override}, active: {self.active_hint_override}")
# breakpoint()
out = self.symbolic_hint(
expr,
hint_override=hint_override,
use_user_provided_hint_override=fallback is not None,
use_user_provided_hint_override=fallback is not None or self.active_hint_override,
)
if not isinstance(out, (int, sympy.Integer)) and fallback is not None:
# Use the provided heuristic fallback hint
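The set_hint_overrides context manager added above layers the requested overrides on top of the current mapping, mirrors them through symbol-to-symbol replacements, and restores the previous state on exit. A standalone sketch of that swap-and-restore pattern with a minimal stand-in class (not the Inductor SizeVarAllocator itself):

from contextlib import contextmanager
import sympy

class HintOverrides:
    def __init__(self):
        self.var_to_hint_override = {}
        self.active_hint_override = False
        self.replacements = {}

    @contextmanager
    def set_hint_overrides(self, overrides):
        old = self.var_to_hint_override
        new = dict(overrides)
        # follow symbol-to-symbol replacements so the override reaches the canonical symbol
        for k, v in overrides.items():
            r = self.replacements.get(k)
            if isinstance(r, sympy.Symbol):
                new[r] = v
        try:
            self.active_hint_override = True
            self.var_to_hint_override = {**old, **new}
            yield
        finally:
            self.active_hint_override = False
            self.var_to_hint_override = old

s0, s1 = sympy.symbols("s0 s1")
sv = HintOverrides()
sv.replacements[s0] = s1                          # s0 has been replaced by s1
with sv.set_hint_overrides({s0: 4096}):
    assert sv.var_to_hint_override[s1] == 4096    # override follows the replacement
assert sv.var_to_hint_override == {}              # previous mapping restored on exit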

View File

@@ -582,51 +582,52 @@ class BaseConfigHeuristic(metaclass=BaseHeuristicSingleton):
min_block_size_k = 32 if (has_int8_tensor or self.has_int8_tensor) else 16
scaled_configs = []
for hint_override in [None] + config.multi_kernel_hints:
m_hint = max(
next_power_of_2(
V.graph.sizevars.size_hint(
m,
fallback=config.unbacked_symint_fallback, # type: ignore[arg-type]
hint_override=hint_override,
)
),
min_block_size,
)
n_hint = max(
next_power_of_2(
V.graph.sizevars.size_hint(
n,
fallback=config.unbacked_symint_fallback, # type: ignore[arg-type]
hint_override=hint_override,
)
),
min_block_size,
)
k_hint = max(
next_power_of_2(
V.graph.sizevars.size_hint(
k,
fallback=config.unbacked_symint_fallback, # type: ignore[arg-type]
hint_override=hint_override,
)
),
min_block_size_k,
)
for c in configs:
scaled_config = dataclasses.replace(
c,
block_m=max(min(int(c.block_m * scale), m_hint), min_block_size),
block_n=max(min(int(c.block_n * scale), n_hint), min_block_size),
block_k=max(min(int(c.block_k * scale), k_hint), min_block_size_k),
hint_override=hint_override,
free_symbols = set()
for expr in [m, n, k]:
free_symbols |= expr.free_symbols
free_symbols = list(free_symbols)
import itertools
overrides = {}
for sym in free_symbols:
if sym in V.graph.sizevars.var_to_hint_override_:
overrides[sym] = V.graph.sizevars.var_to_hint_override_[sym]
for override_vals in itertools.product(*[overrides[k] for k in free_symbols if k in overrides]):
overrides_ = dict(zip(free_symbols, override_vals))
with V.graph.sizevars.set_hint_overrides(overrides_):
m_hint = max(
next_power_of_2(
V.graph.sizevars.size_hint(m)
),
min_block_size,
)
n_hint = max(
next_power_of_2(
V.graph.sizevars.size_hint(n)
),
min_block_size,
)
k_hint = max(
next_power_of_2(
V.graph.sizevars.size_hint(k)
),
min_block_size_k,
)
if not exclude(
scaled_config.block_m, scaled_config.block_n, scaled_config.block_k
):
scaled_configs.append(scaled_config)
for c in configs:
scaled_config = dataclasses.replace(
c,
block_m=max(min(int(c.block_m * scale), m_hint), min_block_size),
block_n=max(min(int(c.block_n * scale), n_hint), min_block_size),
block_k=max(min(int(c.block_k * scale), k_hint), min_block_size_k),
)
if not exclude(
scaled_config.block_m, scaled_config.block_n, scaled_config.block_k
):
scaled_configs.append(scaled_config)
return scaled_configs
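A worked numeric sketch of the block-size clamping above for one config under one hint override; next_power_of_2 is re-implemented locally rather than imported from the Inductor utilities, and the sizes and scale are invented:

def next_power_of_2(n: int) -> int:
    return 1 << max(n - 1, 0).bit_length()

min_block_size, min_block_size_k, scale = 16, 16, 1.0

# suppose the active hint override resolves m=100, n=24, k=512 for this shape bucket
m_hint = max(next_power_of_2(100), min_block_size)    # 128
n_hint = max(next_power_of_2(24), min_block_size)     # 32
k_hint = max(next_power_of_2(512), min_block_size_k)  # 512

# a base config with block_m = block_n = block_k = 64 keeps M and K but clamps N to the hint
block_m = max(min(int(64 * scale), m_hint), min_block_size)    # 64
block_n = max(min(int(64 * scale), n_hint), min_block_size)    # 32
block_k = max(min(int(64 * scale), k_hint), min_block_size_k)  # 64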

View File

@@ -1394,6 +1394,8 @@ class IndentedBuffer:
self.writeline("\n")
def writeline(self, line: Union[LineContext, DeferredLineBase, str]) -> None:
# if "xindex = idx_n + 3584*idx_m" in line:
# breakpoint()
if isinstance(line, LineContext):
self._lines.append(line)
elif isinstance(line, DeferredLineBase):

View File

@@ -651,6 +651,17 @@ class DTensor(torch.Tensor):
else:
raise RuntimeError("Unsupported tensor type!")
@classmethod
def __metadata_guard__(
cls, orig: tuple[DTensorSpec, bool], other: tuple[DTensorSpec, bool]
) -> bool:
orig_spec, orig_requires_grad = orig
other_spec, other_requires_grad = other
return (
orig_spec._check_equals(other_spec, skip_shapes=True)
and orig_requires_grad == other_requires_grad
)
def distribute_tensor(
tensor: torch.Tensor,

View File

@@ -78,7 +78,7 @@ class DTensorSpec:
self._hash = self._hash_impl()
return self._hash
def __eq__(self, other: object, /) -> bool:
def _check_equals(self, other: object, skip_shapes: bool = False) -> bool:
if not (
isinstance(other, DTensorSpec)
and self.mesh == other.mesh
@@ -88,12 +88,17 @@ class DTensorSpec:
if self.tensor_meta is None or other.tensor_meta is None:
return self.tensor_meta == other.tensor_meta
if skip_shapes:
return self.tensor_meta.dtype == other.tensor_meta.dtype
return (
self.tensor_meta.shape == other.tensor_meta.shape # type: ignore[union-attr]
and self.tensor_meta.stride == other.tensor_meta.stride # type: ignore[union-attr]
and self.tensor_meta.dtype == other.tensor_meta.dtype # type: ignore[union-attr]
)
def __eq__(self, other: object, /) -> bool:
return self._check_equals(other)
def __str__(self) -> str:
"""
human readable representation of the DTensorSpec
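A standalone sketch (plain dataclasses, not the real DTensorSpec) of what the skip_shapes path above buys: two specs that differ only in shape and stride still compare equal, which is what lets DTensor's __metadata_guard__ avoid shape-driven recompiles while still checking mesh, placements, and dtype:

from dataclasses import dataclass

@dataclass
class FakeMeta:
    shape: tuple
    stride: tuple
    dtype: str

@dataclass
class FakeSpec:
    mesh: str
    placements: tuple
    tensor_meta: FakeMeta

    def check_equals(self, other, skip_shapes: bool = False) -> bool:
        if self.mesh != other.mesh or self.placements != other.placements:
            return False
        if skip_shapes:
            return self.tensor_meta.dtype == other.tensor_meta.dtype
        return self.tensor_meta == other.tensor_meta

a = FakeSpec("mesh0", ("Shard(0)",), FakeMeta((4, 4), (4, 1), "float32"))
b = FakeSpec("mesh0", ("Shard(0)",), FakeMeta((8, 8), (8, 1), "float32"))
assert not a.check_equals(b)                  # strict equality still sees the shapes
assert a.check_equals(b, skip_shapes=True)    # guard comparison ignores them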

View File

@@ -4378,7 +4378,8 @@ class ShapeEnv:
size = []
for i, val in enumerate(tensor_size):
sym = self.create_symbol(
val if i not in hint_overrides else hint_overrides[i],
# val if i not in hint_overrides else hint_overrides[i],
val,
TensorPropertySource(source, TensorProperty.SIZE, i),
dynamic_dims[i],
constraint_dims[i],
@@ -4579,7 +4580,8 @@ class ShapeEnv:
sym_sizes = [
self.create_symintnode(
sym,
hint=hint if i not in hint_overrides else hint_overrides[i],
# hint=hint if i not in hint_overrides else hint_overrides[i],
hint=hint,
source=TensorPropertySource(source, TensorProperty.SIZE, i),
)
for i, (sym, hint) in enumerate(zip(size, ex_size))