Compare commits

...

1 Commit

Author SHA1 Message Date
b4310afe8b Hack city 2024-07-02 19:17:43 -07:00
7 changed files with 87 additions and 31 deletions

View File

@ -418,9 +418,9 @@ def cast_to(dtype, model, inputs):
model = cast_dtype_args_to_fp64(model)
inputs = tree_map(
lambda x: x.to(dtype)
if isinstance(x, torch.Tensor) and x.is_floating_point()
else x,
lambda x: (
x.to(dtype) if isinstance(x, torch.Tensor) and x.is_floating_point() else x
),
inputs,
)
return model, inputs
@ -733,14 +733,15 @@ def aot_graph_input_parser(
# Resolve symbolic shapes to concrete values
resolved_shape = []
dynamic_dims = []
for i, dim in enumerate(shape):
dim = dim.strip()
if "s" in dim:
s = get_sym_int(dim)
resolved_shape.append(s)
dynamic_dims.append(i)
else:
resolved_shape.append(int(dim))
if shape != ("",):
for i, dim in enumerate(shape):
dim = dim.strip()
if "s" in dim:
s = get_sym_int(dim)
resolved_shape.append(s)
dynamic_dims.append(i)
else:
resolved_shape.append(int(dim))
constructor = torch.randn if dtype.is_floating_point else torch.zeros
out = constructor(resolved_shape, dtype=dtype, device=device) # type: ignore[call-arg]

View File

@ -1147,6 +1147,9 @@ class KernelArgs:
@staticmethod
def _lookup(prefix, odict, name):
    """Return the stable codegen alias for *name*, creating one on first use.

    Each distinct name is assigned ``f"{prefix}{index}"`` in insertion
    order (index = size of *odict* at first lookup); repeated lookups for
    the same name return the same alias.

    Fix: removed the leftover ``breakpoint()`` and debug ``print`` that
    were gated on the hard-coded name ``"buf394"`` — ``breakpoint()``
    drops any run that touches that buffer into pdb, which must never
    ship in committed code.
    """
    assert isinstance(name, (str, sympy.Symbol))
    if name not in odict:
        odict[name] = f"{prefix}{len(odict)}"
    return odict[name]
@ -1842,6 +1845,7 @@ class Kernel(CodeGen):
@staticmethod
def load(name: str, index: sympy.Expr) -> CSEVariable:
# print(name)
if name in self.cse.invalidated_stores:
# A load from an invalidated store requires us to
# keep the actual buffer around

View File

@ -31,6 +31,7 @@ import torch._logging
from torch.utils._sympy.functions import FloorDiv, ModularIndexing
from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT
from ..._dynamo.utils import counters
from .. import config, ir, scheduler
from ..codecache import code_hash
@ -845,23 +846,31 @@ class SIMDKernel(Kernel):
log.warning(msg)
stride_order_list = [
ir.get_stride_order(V.graph.get_buffer(name).layout.stride)
if V.graph.get_buffer(name)
else None
(
ir.get_stride_order(V.graph.get_buffer(name).layout.stride)
if V.graph.get_buffer(name)
else None
)
for name in call_args
]
size_list = [
V.graph.get_buffer(name).layout.size
if V.graph.get_buffer(name)
else None
(
V.graph.get_buffer(name).layout.size
if V.graph.get_buffer(name)
else None
)
for name in call_args
]
source_list = [
"GraphInput"
if name in V.graph.graph_inputs
else "IntermediateBuffer"
if name in V.graph.name_to_buffer
else None
(
"GraphInput"
if name in V.graph.graph_inputs
else (
"IntermediateBuffer"
if name in V.graph.name_to_buffer
else None
)
)
for name in call_args
]
@ -1449,9 +1458,14 @@ class SIMDScheduling(BaseScheduling):
def codegen_foreach(self, foreach_node):
from .triton_foreach import ForeachKernel
i = 0
for partitions_with_metadata in ForeachKernel.horizontal_partition(
foreach_node.get_subkernel_nodes(), self
):
i += 1
if i == 2:
print("----------------------------------------")
print(len(partitions_with_metadata))
kernel = ForeachKernel()
for nodes, tiled_groups, numel, rnumel in partitions_with_metadata:
node_schedule = self.generate_node_schedule(nodes, numel, rnumel)
@ -1475,6 +1489,7 @@ class SIMDScheduling(BaseScheduling):
with V.set_kernel_handler(subkernel):
for node in node_schedule:
# print(f"{node.get_name()}")
if node not in (EnableReduction, DisableReduction):
node.mark_run()
V.graph.removed_buffers |= subkernel.removed_buffers
@ -1482,8 +1497,14 @@ class SIMDScheduling(BaseScheduling):
src_code = kernel.codegen_kernel()
kernel_name = self.define_kernel(src_code, [foreach_node], kernel)
if i == 2:
print(kernel_name)
print(kernel.args)
# print(kernel.args.python_argdefs())
self.codegen_comment([foreach_node])
kernel.call_kernel(V.graph.wrapper_code, kernel_name)
if i == 2:
print("----------------------------------------------")
self.scheduler.free_buffers()

View File

@ -1744,6 +1744,7 @@ class TritonKernel(SIMDKernel):
def store(
self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None
) -> None:
print(name)
var = self.args.output(name)
original_index = index
indexing = self.indexing(index, dense_indexing=True, block_ptr=mode is None)

View File

@ -148,6 +148,7 @@ class ForeachKernel(Kernel):
self.blocking_2d |= groups[1] != 1 and len(groups) == 3
metrics.generated_kernel_count -= 1
sub_kernel.args = self.args
# print(sub_kernel.args)
sub_kernel.iter_vars_count = self.iter_vars_count
sub_kernel.cse.iter_buffer_ids = self.cse.iter_buffer_ids
self.sub_kernels.append(sub_kernel)
@ -181,9 +182,11 @@ class ForeachKernel(Kernel):
def grid(self):
return (
self.x_block_count,
ceildiv(int(self.sub_kernels[0].numels[0]), self.block_size_2d)
if self.blocking_2d
else 1,
(
ceildiv(int(self.sub_kernels[0].numels[0]), self.block_size_2d)
if self.blocking_2d
else 1
),
1,
)

View File

@ -1404,6 +1404,7 @@ class WrapperCodeGen(CodeGen):
device_index, call_args = self.prepare_triton_kernel_call(
device_index, call_args
)
print(call_args)
call_args_str = ", ".join(call_args)
stream_name = self.write_get_raw_stream(device_index, V.graph)
if triton:

View File

@ -828,6 +828,7 @@ class SchedulerNode(BaseSchedulerNode):
self.codegen(index_vars)
def mark_run(self) -> None:
    """Mark this node as run by performing its allocation step.

    Delegates to ``self.allocate()``; kept as a thin named wrapper so
    call sites read as intent ("this node has been scheduled to run").

    Fix: deleted the commented-out debug ``print`` — commented-out code
    is dead weight and should not be committed.
    """
    self.allocate()
def ranges_from_index_vars(
@ -1091,15 +1092,27 @@ class ForeachKernelSchedulerNode(FusedSchedulerNode):
def get_producer_subnode_for(
self, consumer: BaseSchedulerNode
) -> Optional[BaseSchedulerNode]:
producers = []
for rd in consumer.read_writes.reads:
if rd.name in self.name_to_node:
return self.name_to_node[rd.name]
producers.append(self.name_to_node[rd.name])
return None
# Don't permit fusion if there are multiple subnodes
# that this consumer reads from
if len(producers) == 1:
return producers[0]
else:
return None
@classmethod
def can_fuse(cls, producer: BaseSchedulerNode, consumer: BaseSchedulerNode) -> bool:
why = WhyNoFuse(producer, consumer)
if producer.get_name() == "buf486_buf526":
breakpoint()
if consumer.get_name() == "buf486_buf526":
breakpoint()
if producer.is_foreach() and consumer.is_foreach():
producer = typing.cast(ForeachKernelSchedulerNode, producer)
consumer = typing.cast(ForeachKernelSchedulerNode, consumer)
@ -1197,15 +1210,15 @@ class ForeachKernelSchedulerNode(FusedSchedulerNode):
prev_node_1: Optional[BaseSchedulerNode] = None,
prev_node_2: Optional[BaseSchedulerNode] = None,
) -> None:
self.read_to_node = {}
self.name_to_node = {}
self.read_to_node = collections.defaultdict(list)
self.name_to_node = collections.defaultdict(list)
if prev_node_1 is None or prev_node_2 is None:
super().__init__(scheduler, nodes)
for node in nodes:
for read in node.read_writes.reads:
self.read_to_node[read.name] = node
self.read_to_node[read.name].append(node)
for name in node.get_names():
self.name_to_node[name] = node
@ -1783,7 +1796,6 @@ class Scheduler:
updated_nodes.append(node)
else:
# dead code
log.debug("removed dead node: %s", node.get_name())
V.graph.removed_buffers.add(node.get_name())
again = len(self.nodes) > len(updated_nodes)
@ -1846,6 +1858,14 @@ class Scheduler:
)
self.fuse_nodes_once()
new_len = len(self.nodes)
print(f"{i + 1}----------------------")
for node in self.nodes:
for read in node.read_writes.reads:
if read.name == "buf394":
print(node)
print("----------------------")
fusion_log.debug(
"completed fusion round (%d/10): fused %d nodes into %d nodes\n",
i + 1,
@ -2087,6 +2107,11 @@ class Scheduler:
for node1, node2 in self.get_possible_fusions():
node1 = self.name_to_fused_node[node1.get_first_name()]
node2 = self.name_to_fused_node[node2.get_first_name()]
if node1.get_name() == "buf486_buf526":
breakpoint()
if node2.get_name() == "buf486_buf526":
breakpoint()
if self.can_fuse(node1, node2) and not self.will_fusion_create_cycle(
node1, node2
):