Compare commits

2 Commits

SHA1        Message                                    Date
bc1a95a904  Update [ghstack-poisoned]                  2025-11-11 14:26:17 -08:00
feb88419cb  Update (base update) [ghstack-poisoned]    2025-11-11 14:26:17 -08:00
2 changed files with 59 additions and 9 deletions

@@ -436,9 +436,8 @@ if test_torchinductor.HAS_CPU and HAS_PALLAS:
 if test_torchinductor.HAS_GPU and HAS_PALLAS:
-    # make_pallas(test_torchinductor.SweepInputsGPUTest)
+    make_pallas(test_torchinductor.SweepInputsGPUTest)
     # make_pallas(test_torchinductor.GPUTests)
-    pass

 if __name__ == "__main__":

@@ -504,6 +504,7 @@ class PallasKernel(SIMDKernel):
             import jax
             import jax.numpy as jnp
             from jax.experimental import pallas as pl
+            from torch._inductor.runtime.runtime_utils import next_power_of_2
             """,
             strip=True,
         )
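
Note: `next_power_of_2` is PyTorch's helper from `torch._inductor.runtime.runtime_utils`, which rounds an integer up to the nearest power of two. A minimal self-contained sketch of the equivalent behavior (for illustration; not the PyTorch implementation itself):

    def next_power_of_2(n: int) -> int:
        # Smallest power of two >= n, e.g. 5 -> 8, 8 -> 8, 1 -> 1.
        return 1 if n <= 1 else 1 << (n - 1).bit_length()

    assert [next_power_of_2(n) for n in (1, 3, 5, 8, 17)] == [1, 4, 8, 8, 32]

This is what makes the GPU padding in the next hunk legal for Pallas' Triton backend, which requires power-of-2 block dimensions.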
@@ -529,16 +530,65 @@ class PallasKernel(SIMDKernel):
             code.writeline(str(line))

         jit_wrapper_name = f"{kernel_name}_jit_wrapper"
-        code.writeline("@functools.partial(jax.jit, static_argnums=(0, 1))")
-        code.writeline(f"def {jit_wrapper_name}(out_shape, out_dtype, *kernel_refs):")
+        code.writeline("@functools.partial(jax.jit, static_argnums=(0, 1, 2))")
+        code.writeline(
+            f"def {jit_wrapper_name}(out_shape, out_dtype, device_type, *kernel_refs):"
+        )
         with code.indent():
-            code.writeline("out_spec = jax.ShapeDtypeStruct(out_shape, out_dtype)")
-            code.writeline("return pl.pallas_call(")
+            code.writeline(
+                "# On GPU (Triton backend), Pallas requires each dimension to be power of 2"
+            )
+            code.writeline("if device_type == 'cuda':")
+            with code.indent():
+                code.writeline("# Pad each dimension of each input to power of 2")
+                code.writeline("kernel_refs_padded = []")
+                code.writeline("for ref in kernel_refs:")
+                with code.indent():
+                    code.writeline(
+                        "padded_shape = tuple(next_power_of_2(d) for d in ref.shape)"
+                    )
+                    code.writeline(
+                        "pad_widths = [(0, padded_shape[i] - ref.shape[i]) for i in range(len(ref.shape))]"
+                    )
+                    code.writeline(
+                        "padded_ref = jnp.pad(ref, pad_widths, mode='constant', constant_values=0)"
+                    )
+                    code.writeline("kernel_refs_padded.append(padded_ref)")
+                code.writeline("kernel_refs_padded = tuple(kernel_refs_padded)")
+                code.writeline("")
+                code.writeline("# Pad output shape to power of 2")
+                code.writeline(
+                    "padded_out_shape = tuple(next_power_of_2(d) for d in out_shape)"
+                )
+            code.writeline("else:")
+            with code.indent():
+                code.writeline("# On CPU, no padding needed")
+                code.writeline("kernel_refs_padded = kernel_refs")
+                code.writeline("padded_out_shape = out_shape")
+            code.writeline("")
+            code.writeline("# Output spec with padded shape")
+            code.writeline(
+                "out_spec = jax.ShapeDtypeStruct(padded_out_shape, out_dtype)"
+            )
+            code.writeline("")
+            code.writeline("result_padded = pl.pallas_call(")
             code.writeline(f"    {kernel_name}_kernel,")
             code.writeline("    out_shape=out_spec,")
             code.writeline(f"    interpret={interpret_literal},")
             code.writeline("    grid=(1,),")
-            code.writeline(")(*kernel_refs)")
+            code.writeline(")(*kernel_refs_padded)")
+            code.writeline("")
+            code.writeline("# Slice back to original shape")
+            code.writeline("if device_type == 'cuda':")
+            with code.indent():
+                code.writeline("# Build slice for each dimension")
+                code.writeline(
+                    "slices = tuple(slice(0, out_shape[i]) for i in range(len(out_shape)))"
+                )
+                code.writeline("return result_padded[slices]")
+            code.writeline("else:")
+            with code.indent():
+                code.writeline("return result_padded")

         # Host entry: convert torch tensors <-> jax, call pallas_call and copy back
         main_name = f"{kernel_name}_main"
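
Hand-expanding the writelines above, the emitted JIT wrapper would read roughly as follows for a kernel named `kernel0` (illustrative only; the kernel name and `interpret=False` are assumptions standing in for the real `{kernel_name}` and `{interpret_literal}` values):

    @functools.partial(jax.jit, static_argnums=(0, 1, 2))
    def kernel0_jit_wrapper(out_shape, out_dtype, device_type, *kernel_refs):
        # On GPU (Triton backend), Pallas requires each dimension to be power of 2
        if device_type == 'cuda':
            # Pad each dimension of each input to power of 2
            kernel_refs_padded = []
            for ref in kernel_refs:
                padded_shape = tuple(next_power_of_2(d) for d in ref.shape)
                pad_widths = [(0, padded_shape[i] - ref.shape[i]) for i in range(len(ref.shape))]
                kernel_refs_padded.append(jnp.pad(ref, pad_widths, mode='constant', constant_values=0))
            kernel_refs_padded = tuple(kernel_refs_padded)
            padded_out_shape = tuple(next_power_of_2(d) for d in out_shape)
        else:
            # On CPU, no padding needed
            kernel_refs_padded = kernel_refs
            padded_out_shape = out_shape
        out_spec = jax.ShapeDtypeStruct(padded_out_shape, out_dtype)
        result_padded = pl.pallas_call(
            kernel0_kernel,
            out_shape=out_spec,
            interpret=False,
            grid=(1,),
        )(*kernel_refs_padded)
        # Slice the padded result back to the requested output shape
        if device_type == 'cuda':
            return result_padded[tuple(slice(0, out_shape[i]) for i in range(len(out_shape)))]
        return result_padded

Because `static_argnums=(0, 1, 2)` covers `out_shape`, `out_dtype`, and `device_type` (all hashable), JAX retraces per distinct shape/dtype/device combination, and tracing drops the untaken padding branch on each device.
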
@@ -560,7 +610,7 @@ class PallasKernel(SIMDKernel):
         output_param = output_params[0]

-        # Convert inputs to JAX arrays
+        # Convert inputs to JAX arrays (keep original shapes for broadcasting)
         for inp in input_params:
             code.writeline(
                 f"{inp}_jax = jax.dlpack.from_dlpack({inp}.contiguous())"
             )
@@ -570,8 +620,9 @@ class PallasKernel(SIMDKernel):
         code.writeline("# Prepare output metadata from PyTorch tensor")
         code.writeline(f"out_shape = tuple({output_param}.shape)")
         code.writeline(f"out_dtype = {output_dtype_jax}")
+        code.writeline(f"device_type = {output_param}.device.type")
-        call_args = ["out_shape", "out_dtype"] + [
+        call_args = ["out_shape", "out_dtype", "device_type"] + [
            f"{inp}_jax" for inp in input_params
        ]
        call_arg_str = ", ".join(call_args)
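
For a single-input kernel, the host entry this codegen produces would look roughly like the sketch below. The parameter names `in_ptr0`/`out_ptr0`, the `jnp.float32` dtype, and the final copy-back step are assumptions for illustration, not the literal generated output:

    def kernel0_main(in_ptr0, out_ptr0):
        # Convert inputs to JAX arrays (keep original shapes for broadcasting)
        in_ptr0_jax = jax.dlpack.from_dlpack(in_ptr0.contiguous())
        # Prepare output metadata from PyTorch tensor
        out_shape = tuple(out_ptr0.shape)
        out_dtype = jnp.float32  # stands in for the codegen'd {output_dtype_jax}
        device_type = out_ptr0.device.type  # 'cuda' or 'cpu'; selects the padding path
        res = kernel0_jit_wrapper(out_shape, out_dtype, device_type, in_ptr0_jax)
        # Copy the result back into the PyTorch output tensor (assumed step)
        out_ptr0.copy_(torch.from_dlpack(res))

Threading `device_type` through as a static argument is what lets one wrapper serve both devices: the CPU path compiles with no padding or slicing ops at all.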