Update (base update)

[ghstack-poisoned]
Oguz Ulgen
2025-11-07 15:57:34 -08:00
parent ed4aa449b6
commit 64c4cceb3f
11 changed files with 168 additions and 98 deletions

View File

@@ -248,6 +248,12 @@ case "$tag" in
HALIDE=yes
TRITON=yes
;;
pytorch-linux-jammy-cuda13.0-py3.12-pallas)
CUDA_VERSION=13.0.0
ANACONDA_PYTHON_VERSION=3.12
GCC_VERSION=11
PALLAS=yes
;;
pytorch-linux-jammy-py3.12-triton-cpu)
CUDA_VERSION=12.6
ANACONDA_PYTHON_VERSION=3.12
@@ -369,6 +375,7 @@ docker build \
--build-arg "INDUCTOR_BENCHMARKS=${INDUCTOR_BENCHMARKS}" \
--build-arg "EXECUTORCH=${EXECUTORCH}" \
--build-arg "HALIDE=${HALIDE}" \
--build-arg "PALLAS=${PALLAS}" \
--build-arg "XPU_VERSION=${XPU_VERSION}" \
--build-arg "UNINSTALL_DILL=${UNINSTALL_DILL}" \
--build-arg "ACL=${ACL:-}" \

View File

@@ -0,0 +1 @@
0.8.0

View File

@@ -0,0 +1,40 @@
#!/bin/bash
set -ex
source "$(dirname "${BASH_SOURCE[0]}")/common_utils.sh"
# Get the pinned JAX version (same for all CUDA versions)
JAX_VERSION=$(get_pinned_commit /ci_commit_pins/jax)
function install_jax_12() {
echo "Installing JAX ${JAX_VERSION} with CUDA 12 support"
pip_install "jax[cuda12]==${JAX_VERSION}" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# Verify installation
python -c "import jax" # check for errors
echo "JAX ${JAX_VERSION} installation completed successfully for CUDA 12"
}
function install_jax_13() {
echo "Installing JAX ${JAX_VERSION} with CUDA 13 support"
pip_install "jax[cuda13]==${JAX_VERSION}" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# Verify installation
python -c "import jax" # check for errors
echo "JAX ${JAX_VERSION} installation completed successfully for CUDA 13"
}
# idiomatic parameter and option handling in sh
while test $# -gt 0
do
case "$1" in
12.4|12.6|12.6.*|12.8|12.8.*|12.9|12.9.*) install_jax_12;
;;
13.0|13.0.*) install_jax_13;
;;
*) echo "bad argument $1"; exit 1
;;
esac
shift
done
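
Not part of the diff, but a quick way to sanity-check the resulting image: the sketch below (illustrative, assuming the install above succeeded) confirms the pinned JAX wheel imports, lists the devices JAX can see, and runs a trivial Pallas kernel in interpret mode so it also passes on a CPU-only builder.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

# Confirm the pinned wheel is importable and see which devices JAX found.
print(jax.__version__)
print(jax.devices())

# Trivial Pallas kernel: copy input to output. interpret=True keeps it
# runnable even without a GPU.
def copy_kernel(x_ref, o_ref):
    o_ref[...] = x_ref[...]

x = jnp.arange(8, dtype=jnp.float32)
y = pl.pallas_call(
    copy_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    interpret=True,
)(x)
assert (y == x).all()
```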

View File

@@ -143,6 +143,15 @@ COPY ci_commit_pins/halide.txt halide.txt
RUN if [ -n "${HALIDE}" ]; then bash ./install_halide.sh; fi
RUN rm install_halide.sh common_utils.sh halide.txt
ARG PALLAS
ARG CUDA_VERSION
# Install JAX with CUDA support (for Pallas)
COPY ./common/install_jax.sh install_jax.sh
COPY ./common/common_utils.sh common_utils.sh
COPY ./ci_commit_pins/jax.txt /ci_commit_pins/jax.txt
RUN if [ -n "${PALLAS}" ]; then bash ./install_jax.sh ${CUDA_VERSION}; fi
RUN rm -f install_jax.sh common_utils.sh /ci_commit_pins/jax.txt
ARG ONNX
# Install ONNX dependencies
COPY ./common/install_onnx.sh ./common/common_utils.sh ./

View File

@@ -824,6 +824,11 @@ test_inductor_halide() {
assert_git_not_dirty
}
test_inductor_pallas() {
python test/run_test.py --include inductor/test_pallas.py --verbose
assert_git_not_dirty
}
test_inductor_triton_cpu() {
python test/run_test.py --include inductor/test_triton_cpu_backend.py inductor/test_torchinductor_strided_blocks.py --verbose
assert_git_not_dirty
@@ -1724,6 +1729,8 @@ elif [[ "${TEST_CONFIG}" == *inductor_distributed* ]]; then
test_inductor_distributed
elif [[ "${TEST_CONFIG}" == *inductor-halide* ]]; then
test_inductor_halide
elif [[ "${TEST_CONFIG}" == *inductor-pallas* ]]; then
test_inductor_pallas
elif [[ "${TEST_CONFIG}" == *inductor-triton-cpu* ]]; then
test_inductor_triton_cpu
elif [[ "${TEST_CONFIG}" == *inductor-micro-benchmark* ]]; then

View File

@@ -65,6 +65,7 @@ jobs:
pytorch-linux-jammy-py3.10-gcc11,
pytorch-linux-jammy-py3-gcc11-inductor-benchmarks,
pytorch-linux-jammy-py3.12-halide,
pytorch-linux-jammy-cuda13.0-py3.12-pallas,
pytorch-linux-jammy-xpu-n-1-py3,
pytorch-linux-jammy-xpu-n-py3,
pytorch-linux-jammy-xpu-n-py3-inductor-benchmarks,

View File

@@ -81,6 +81,32 @@ jobs:
test-matrix: ${{ needs.inductor-halide-build.outputs.test-matrix }}
secrets: inherit
inductor-pallas-build:
name: inductor-pallas-build
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
build-environment: linux-jammy-py3.12-gcc11
docker-image-name: ci-image:pytorch-linux-jammy-cuda13.0-py3.12-pallas
cuda-arch-list: '8.9'
runner: linux.8xlarge.memory
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
test-matrix: |
{ include: [
{ config: "inductor-pallas", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" },
]}
secrets: inherit
inductor-pallas-test:
name: inductor-pallas-test
uses: ./.github/workflows/_linux-test.yml
needs: inductor-pallas-build
with:
build-environment: linux-jammy-py3.12-gcc11
docker-image: ${{ needs.inductor-pallas-build.outputs.docker-image }}
test-matrix: ${{ needs.inductor-pallas-build.outputs.test-matrix }}
secrets: inherit
inductor-triton-cpu-build:
name: inductor-triton-cpu-build
uses: ./.github/workflows/_linux-build.yml

View File

@@ -52,9 +52,12 @@ def make_pallas(cls):
return test_class
@unittest.skipUnless(HAS_PALLAS, "requires jax and pallas")
class PallasTests(TestCase):
"""Basic tests for Pallas backend functionality."""
class PallasTestsMixin:
"""Basic tests for Pallas backend functionality (parameterized by DEVICE). Mixin only, not collected."""
def _compile(self, fn):
key = "cuda_backend" if self.DEVICE == "cuda" else "cpu_backend"
return torch.compile(fn, backend="inductor", options={key: "pallas"})
def test_simple_add(self):
"""Test basic element-wise addition."""
@@ -62,12 +65,10 @@ class PallasTests(TestCase):
def fn(a, b):
return a + b
compiled = torch.compile(
fn, backend="inductor", options={"cuda_backend": "pallas"}
)
compiled = self._compile(fn)
a = torch.randn(1024, device="cuda")
b = torch.randn(1024, device="cuda")
a = torch.randn(1024, device=self.DEVICE)
b = torch.randn(1024, device=self.DEVICE)
result = compiled(a, b)
expected = fn(a, b)
self.assertEqual(result, expected)
@@ -78,12 +79,10 @@ class PallasTests(TestCase):
def fn(a, b):
return a * b
compiled = torch.compile(
fn, backend="inductor", options={"cuda_backend": "pallas"}
)
compiled = self._compile(fn)
a = torch.randn(1024, device="cuda")
b = torch.randn(1024, device="cuda")
a = torch.randn(1024, device=self.DEVICE)
b = torch.randn(1024, device=self.DEVICE)
result = compiled(a, b)
expected = fn(a, b)
self.assertEqual(result, expected)
@@ -94,11 +93,9 @@ class PallasTests(TestCase):
def fn(x):
return torch.sin(x)
compiled = torch.compile(
fn, backend="inductor", options={"cuda_backend": "pallas"}
)
compiled = self._compile(fn)
x = torch.randn(1024, device="cuda")
x = torch.randn(1024, device=self.DEVICE)
result = compiled(x)
expected = fn(x)
self.assertEqual(result, expected)
@@ -109,12 +106,10 @@ class PallasTests(TestCase):
def fn(x, y):
return x.sin() + y
compiled = torch.compile(
fn, backend="inductor", options={"cuda_backend": "pallas"}
)
compiled = self._compile(fn)
x = torch.randn(1024, device="cuda")
y = torch.randn(1024, device="cuda")
x = torch.randn(1024, device=self.DEVICE)
y = torch.randn(1024, device=self.DEVICE)
result = compiled(x, y)
expected = fn(x, y)
self.assertEqual(result, expected)
@@ -125,11 +120,9 @@ class PallasTests(TestCase):
def fn(x):
return torch.log(torch.exp(x))
compiled = torch.compile(
fn, backend="inductor", options={"cuda_backend": "pallas"}
)
compiled = self._compile(fn)
x = torch.randn(1024, device="cuda")
x = torch.randn(1024, device=self.DEVICE)
result = compiled(x)
expected = fn(x)
self.assertEqual(result, expected)
@@ -140,11 +133,9 @@ class PallasTests(TestCase):
def fn(x):
return torch.sqrt(x)
compiled = torch.compile(
fn, backend="inductor", options={"cuda_backend": "pallas"}
)
compiled = self._compile(fn)
x = torch.randn(1024, device="cuda").abs() # Ensure positive for sqrt
x = torch.randn(1024, device=self.DEVICE).abs() # Ensure positive for sqrt
result = compiled(x)
expected = fn(x)
self.assertEqual(result, expected)
@@ -155,11 +146,9 @@ class PallasTests(TestCase):
def fn(x):
return torch.tanh(x)
compiled = torch.compile(
fn, backend="inductor", options={"cuda_backend": "pallas"}
)
compiled = self._compile(fn)
x = torch.randn(1024, device="cuda")
x = torch.randn(1024, device=self.DEVICE)
result = compiled(x)
expected = fn(x)
self.assertEqual(result, expected)
@@ -170,11 +159,9 @@ class PallasTests(TestCase):
def fn(x):
return torch.abs(-x)
compiled = torch.compile(
fn, backend="inductor", options={"cuda_backend": "pallas"}
)
compiled = self._compile(fn)
x = torch.randn(1024, device="cuda")
x = torch.randn(1024, device=self.DEVICE)
result = compiled(x)
expected = fn(x)
self.assertEqual(result, expected)
@@ -185,12 +172,10 @@ class PallasTests(TestCase):
def fn(a, b):
return torch.maximum(a, b) + torch.minimum(a, b)
compiled = torch.compile(
fn, backend="inductor", options={"cuda_backend": "pallas"}
)
compiled = self._compile(fn)
a = torch.randn(1024, device="cuda")
b = torch.randn(1024, device="cuda")
a = torch.randn(1024, device=self.DEVICE)
b = torch.randn(1024, device=self.DEVICE)
result = compiled(a, b)
expected = fn(a, b)
self.assertEqual(result, expected)
@@ -228,15 +213,17 @@ class PallasTests(TestCase):
@torch.compile(
backend="inductor",
options={"cuda_backend": "pallas"},
options={
("cuda_backend" if self.DEVICE == "cuda" else "cpu_backend"): "pallas"
},
)
def pallas_fn(a, b):
return a.sin() + b.cos()
_, (code,) = run_and_get_code(
pallas_fn,
torch.randn(64, device="cuda"),
torch.randn(64, device="cuda"),
torch.randn(64, device=self.DEVICE),
torch.randn(64, device=self.DEVICE),
)
# Verify Pallas-specific code generation
self.assertIn("import jax", code)
@@ -249,12 +236,10 @@ class PallasTests(TestCase):
def fn(x, y):
return x + y
compiled = torch.compile(
fn, backend="inductor", options={"cuda_backend": "pallas"}
)
compiled = self._compile(fn)
x = torch.randn(32, 32, device="cuda")
y = torch.randn(32, 32, device="cuda")
x = torch.randn(32, 32, device=self.DEVICE)
y = torch.randn(32, 32, device=self.DEVICE)
result = compiled(x, y)
expected = fn(x, y)
self.assertEqual(result, expected)
@@ -265,12 +250,10 @@ class PallasTests(TestCase):
def fn(x):
return x * 2.0
compiled = torch.compile(
fn, backend="inductor", options={"cuda_backend": "pallas"}
)
compiled = self._compile(fn)
for shape in [(64,), (128,), (256,), (1024,)]:
x = torch.randn(shape, device="cuda")
x = torch.randn(shape, device=self.DEVICE)
result = compiled(x)
expected = fn(x)
self.assertEqual(result, expected)
@@ -282,12 +265,10 @@ class PallasTests(TestCase):
def contiguous_add(a, b):
return a + b
compiled = torch.compile(
contiguous_add, backend="inductor", options={"cuda_backend": "pallas"}
)
compiled = self._compile(contiguous_add)
a = torch.randn(1024, device="cuda")
b = torch.randn(1024, device="cuda")
a = torch.randn(1024, device=self.DEVICE)
b = torch.randn(1024, device=self.DEVICE)
result = compiled(a, b)
expected = contiguous_add(a, b)
self.assertEqual(result, expected)
@@ -296,44 +277,31 @@ class PallasTests(TestCase):
def contiguous_mul(x):
return x * 2.0
compiled = torch.compile(
contiguous_mul, backend="inductor", options={"cuda_backend": "pallas"}
)
compiled = self._compile(contiguous_mul)
x = torch.randn(128, 8, device="cuda")
x = torch.randn(128, 8, device=self.DEVICE)
result = compiled(x)
expected = contiguous_mul(x)
self.assertEqual(result, expected)
# Test 3: Non-contiguous views will fail at runtime with JAX/Pallas
# This demonstrates that the Pallas backend requires contiguous memory layout
# Test 3: Non-contiguous views should work with the simplified dlpack approach
# The direct dlpack conversion handles non-contiguous tensors correctly
def operate_on_tensor(x):
return x.sin()
compiled = torch.compile(
operate_on_tensor, backend="inductor", options={"cuda_backend": "pallas"}
)
compiled = self._compile(operate_on_tensor)
# Create a transposed (non-contiguous) view
x = torch.randn(64, 32, device="cuda")
x = torch.randn(64, 32, device=self.DEVICE)
x_t = x.t() # Non-contiguous view
self.assertFalse(x_t.is_contiguous())
# This will fail because JAX/Pallas cannot handle non-contiguous layout via DLPack
# The error indicates that our contiguous-only approach is correct
with self.assertRaises((RuntimeError, Exception)) as cm:
result = compiled(x_t)
# With the simplified dlpack approach, non-contiguous tensors now work
result = compiled(x_t)
expected = operate_on_tensor(x_t)
self.assertEqual(result, expected)
# Verify the error is related to layout/contiguous issues
error_msg = str(cm.exception)
self.assertTrue(
"layout" in error_msg.lower()
or "contiguous" in error_msg.lower()
or "non-default" in error_msg.lower(),
f"Expected layout/contiguous error, got: {error_msg}",
)
# But if we make it contiguous first, it should work
# Contiguous tensors should also continue to work
x_t_contiguous = x_t.contiguous()
self.assertTrue(x_t_contiguous.is_contiguous())
result = compiled(x_t_contiguous)
@@ -341,13 +309,24 @@ class PallasTests(TestCase):
self.assertEqual(result, expected)
@unittest.skipUnless(HAS_PALLAS, "requires jax and pallas")
class PallasTestsCUDA(PallasTestsMixin, TestCase):
DEVICE = "cuda"
@unittest.skipUnless(HAS_PALLAS, "requires jax and pallas")
class PallasTestsCPU(PallasTestsMixin, TestCase):
DEVICE = "cpu"
# Create test variants using the main test suite
# Note: Only enable GPU tests since Pallas primarily targets GPU
if test_torchinductor.HAS_GPU and HAS_PALLAS:
# Uncomment these to run full test suite with Pallas backend
# make_pallas(test_torchinductor.SweepInputsGPUTest)
# make_pallas(test_torchinductor.GPUTests)
pass
if hasattr(sys.modules.get(__name__), "test_torchinductor") and HAS_PALLAS:
if getattr(test_torchinductor, "HAS_GPU", False):
# Uncomment these to run full test suite with Pallas backend
# make_pallas(test_torchinductor.SweepInputsGPUTest)
# make_pallas(test_torchinductor.GPUTests)
pass
if __name__ == "__main__":
if HAS_PALLAS:

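For readers skimming the refactor above: the old module defined a single PallasTests(TestCase) hard-coded to CUDA; the new PallasTestsMixin holds the device-agnostic test bodies and deliberately does not inherit from TestCase, so unittest only collects the thin PallasTestsCUDA/PallasTestsCPU subclasses that bind DEVICE. Below is a condensed, standalone sketch of the same pattern (class names are illustrative, and it assumes a PyTorch build that includes this commit's Pallas backend).

```python
import unittest
import torch

class ElementwiseTestsMixin:
    """Mixin only: no TestCase base, so unittest never collects it directly."""
    DEVICE = None  # bound by the concrete subclasses below

    def _compile(self, fn):
        # Mirror the diff: choose the per-device Inductor backend option.
        key = "cuda_backend" if self.DEVICE == "cuda" else "cpu_backend"
        return torch.compile(fn, backend="inductor", options={key: "pallas"})

    def test_sin(self):
        fn = torch.sin
        x = torch.randn(1024, device=self.DEVICE)
        torch.testing.assert_close(self._compile(fn)(x), fn(x))

class ElementwiseTestsCPU(ElementwiseTestsMixin, unittest.TestCase):
    DEVICE = "cpu"

@unittest.skipUnless(torch.cuda.is_available(), "requires CUDA")
class ElementwiseTestsCUDA(ElementwiseTestsMixin, unittest.TestCase):
    DEVICE = "cuda"

if __name__ == "__main__":
    unittest.main()
```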
View File

@@ -521,6 +521,7 @@ def init_backend_registration() -> None:
"cpp": CppScheduling,
"halide": HalideScheduling,
"triton": TritonScheduling,
"pallas": PallasScheduling,
}
register_backend_for_device(
"cpu",

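With "pallas" added to the CPU scheduling map above, the backend can be requested per compile call through Inductor options; the tests in this commit do the same for CUDA via options={"cuda_backend": "pallas"}. A hedged sketch, assuming a build that contains this commit; run_and_get_code comes from torch._inductor.utils and is the same helper the new test uses to check the generated wrapper.

```python
import torch
from torch._inductor.utils import run_and_get_code

def fn(a, b):
    return a.sin() + b.cos()

# Ask Inductor's CPU codegen to use the newly registered Pallas scheduling;
# "cuda_backend": "pallas" is the GPU-side equivalent exercised by the tests.
compiled = torch.compile(fn, backend="inductor", options={"cpu_backend": "pallas"})

a = torch.randn(64)
b = torch.randn(64)
_, code_list = run_and_get_code(compiled, a, b)

# The emitted source should contain the JAX/Pallas wrapper, analogous to the
# "import jax" assertion in test_pallas.py.
assert any("import jax" in code for code in code_list)
torch.testing.assert_close(compiled(a, b), fn(a, b))
```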
View File

@@ -291,7 +291,6 @@ class PallasKernel(SIMDKernel):
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from torch.utils import dlpack as torch_dlpack
""",
strip=True,
)
@@ -314,6 +313,8 @@ class PallasKernel(SIMDKernel):
main_name = f"{kernel_name}_main"
code.writeline(f"def {main_name}({', '.join(kernel_params)}, stream=None):")
with code.indent():
# Determine interpret statically based on codegen device
interpret_literal = "True" if V.graph.get_current_device_or_throw().type == "cpu" else "False"
# Identify inputs (in_ptr*) and output (out_ptr*)
input_params = [
p for p in kernel_params if p.startswith(("in_ptr", "in_out_ptr"))
@@ -330,9 +331,7 @@ class PallasKernel(SIMDKernel):
# Convert inputs to JAX arrays
code.writeline("# Convert Torch -> JAX for inputs")
for inp in input_params:
code.writeline(
f"{inp}_jax = jax.dlpack.from_dlpack(torch_dlpack.to_dlpack({inp}))"
)
code.writeline(f"{inp}_jax = jax.dlpack.from_dlpack({inp})")
# Get output spec from PyTorch tensor
code.writeline("# Prepare output spec from PyTorch tensor")
@@ -351,9 +350,11 @@ class PallasKernel(SIMDKernel):
)
# Call pallas
# Pass interpret=True on CPU, False otherwise (single call, no duplication)
code.writeline("compiled = pl.pallas_call(")
code.writeline(f" lambda *refs: {kernel_name}_kernel(*refs),")
code.writeline(" out_shape=out_spec,")
code.writeline(f" interpret={interpret_literal},")
code.writeline(" grid=(1,),")
code.writeline(")")
@@ -362,9 +363,7 @@ class PallasKernel(SIMDKernel):
# Copy result back
code.writeline("# Copy result back into the provided torch output tensor")
code.writeline(
"res_t = torch_dlpack.from_dlpack(jax.dlpack.to_dlpack(res))"
)
code.writeline("res_t = torch.from_dlpack(res)")
code.writeline(f"{output_param}.copy_(res_t)")
return code.getvalue()

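To make the wrapper changes concrete, here is a hand-written approximation of what the generated *_main host function now looks like for a simple elementwise add (identifiers such as kernel_main are illustrative, not the exact generated names): inputs cross Torch -> JAX via jax.dlpack.from_dlpack on the tensor itself, interpret is baked in as a literal chosen from the codegen device, and the result comes back through torch.from_dlpack rather than the removed torch.utils.dlpack round-trip.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
import torch

def add_kernel(in_ptr0_ref, in_ptr1_ref, out_ptr0_ref):
    out_ptr0_ref[...] = in_ptr0_ref[...] + in_ptr1_ref[...]

def kernel_main(in_ptr0, in_ptr1, out_ptr0, stream=None):
    # interpret is a compile-time literal: True when codegen targets CPU,
    # False for CUDA.
    interpret = True
    # Convert Torch -> JAX for inputs (torch tensors implement __dlpack__).
    in_ptr0_jax = jax.dlpack.from_dlpack(in_ptr0)
    in_ptr1_jax = jax.dlpack.from_dlpack(in_ptr1)
    # Prepare output spec from the provided PyTorch output tensor.
    out_spec = jax.ShapeDtypeStruct(tuple(out_ptr0.shape), jnp.float32)
    compiled = pl.pallas_call(
        lambda *refs: add_kernel(*refs),
        out_shape=out_spec,
        interpret=interpret,
        grid=(1,),
    )
    res = compiled(in_ptr0_jax, in_ptr1_jax)
    # Copy the result back into the provided torch output tensor.
    out_ptr0.copy_(torch.from_dlpack(res))

a, b = torch.randn(1024), torch.randn(1024)
out = torch.empty_like(a)
kernel_main(a, b, out)
torch.testing.assert_close(out, a + b)
```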
View File

@@ -1958,8 +1958,8 @@ class rocm:
contiguous_threshold: int = 16
# Backend to use for CPU codegen either "cpp" or "triton" (experimental) or "halide" (experimental)
cpu_backend: Literal["cpp", "triton", "halide"] = "cpp"
# Backend to use for CPU codegen either "cpp" or "triton" (experimental) or "halide" (experimental) or "pallas" (experimental)
cpu_backend: Literal["cpp", "triton", "halide", "pallas"] = "cpp"
# Backend to use for CUDA codegen either
# "triton", "halide" (experimental) or "pallas" (experimental)