Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-12 14:54:55 +08:00)
Update (base update)
[ghstack-poisoned]
@@ -248,6 +248,12 @@ case "$tag" in
     HALIDE=yes
     TRITON=yes
     ;;
+  pytorch-linux-jammy-cuda13.0-py3.12-pallas)
+    CUDA_VERSION=13.0.0
+    ANACONDA_PYTHON_VERSION=3.12
+    GCC_VERSION=11
+    PALLAS=yes
+    ;;
   pytorch-linux-jammy-py3.12-triton-cpu)
     CUDA_VERSION=12.6
     ANACONDA_PYTHON_VERSION=3.12
@@ -369,6 +375,7 @@ docker build \
       --build-arg "INDUCTOR_BENCHMARKS=${INDUCTOR_BENCHMARKS}" \
       --build-arg "EXECUTORCH=${EXECUTORCH}" \
       --build-arg "HALIDE=${HALIDE}" \
+      --build-arg "PALLAS=${PALLAS}" \
       --build-arg "XPU_VERSION=${XPU_VERSION}" \
       --build-arg "UNINSTALL_DILL=${UNINSTALL_DILL}" \
       --build-arg "ACL=${ACL:-}" \
.ci/docker/ci_commit_pins/jax.txt (new file, +1)
@@ -0,0 +1 @@
+0.8.0
.ci/docker/common/install_jax.sh (new executable file, +40)
@@ -0,0 +1,40 @@
+#!/bin/bash
+
+set -ex
+
+source "$(dirname "${BASH_SOURCE[0]}")/common_utils.sh"
+
+# Get the pinned JAX version (same for all CUDA versions)
+JAX_VERSION=$(get_pinned_commit /ci_commit_pins/jax)
+
+function install_jax_12() {
+  echo "Installing JAX ${JAX_VERSION} with CUDA 12 support"
+  pip_install "jax[cuda12]==${JAX_VERSION}" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+
+  # Verify installation
+  python -c "import jax"  # check for errors
+  echo "JAX ${JAX_VERSION} installation completed successfully for CUDA 12"
+}
+
+function install_jax_13() {
+  echo "Installing JAX ${JAX_VERSION} with CUDA 13 support"
+  pip_install "jax[cuda13]==${JAX_VERSION}" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+
+  # Verify installation
+  python -c "import jax"  # check for errors
+  echo "JAX ${JAX_VERSION} installation completed successfully for CUDA 13"
+}
+
+# idiomatic parameter and option handling in sh
+while test $# -gt 0
+do
+    case "$1" in
+    12.4|12.6|12.6.*|12.8|12.8.*|12.9|12.9.*) install_jax_12;
+        ;;
+    13.0|13.0.*) install_jax_13;
+        ;;
+    *) echo "bad argument $1"; exit 1
+        ;;
+    esac
+    shift
+done
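The installer dispatches on the CUDA version string passed in by the Dockerfile (for this image, 13.0.0, which matches the 13.0.* arm and selects the cuda13 wheel). As an illustrative post-install sanity check (not part of the commit), the pin and the visible devices can be confirmed from Python inside the built image:

import jax

# The 0.8.0 value is the pin from ci_commit_pins/jax.txt above.
assert jax.__version__ == "0.8.0", jax.__version__
print(jax.devices())  # expect CUDA devices on the GPU image, CPU devices otherwise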
@@ -143,6 +143,15 @@ COPY ci_commit_pins/halide.txt halide.txt
 RUN if [ -n "${HALIDE}" ]; then bash ./install_halide.sh; fi
 RUN rm install_halide.sh common_utils.sh halide.txt
 
+ARG PALLAS
+ARG CUDA_VERSION
+# Install JAX with CUDA support (for Pallas)
+COPY ./common/install_jax.sh install_jax.sh
+COPY ./common/common_utils.sh common_utils.sh
+COPY ./ci_commit_pins/jax.txt /ci_commit_pins/jax.txt
+RUN if [ -n "${PALLAS}" ]; then bash ./install_jax.sh ${CUDA_VERSION}; fi
+RUN rm -f install_jax.sh common_utils.sh /ci_commit_pins/jax.txt
+
 ARG ONNX
 # Install ONNX dependencies
 COPY ./common/install_onnx.sh ./common/common_utils.sh ./
@@ -824,6 +824,11 @@ test_inductor_halide() {
   assert_git_not_dirty
 }
 
+test_inductor_pallas() {
+  python test/run_test.py --include inductor/test_pallas.py --verbose
+  assert_git_not_dirty
+}
+
 test_inductor_triton_cpu() {
   python test/run_test.py --include inductor/test_triton_cpu_backend.py inductor/test_torchinductor_strided_blocks.py --verbose
   assert_git_not_dirty
@@ -1724,6 +1729,8 @@ elif [[ "${TEST_CONFIG}" == *inductor_distributed* ]]; then
   test_inductor_distributed
 elif [[ "${TEST_CONFIG}" == *inductor-halide* ]]; then
   test_inductor_halide
+elif [[ "${TEST_CONFIG}" == *inductor-pallas* ]]; then
+  test_inductor_pallas
 elif [[ "${TEST_CONFIG}" == *inductor-triton-cpu* ]]; then
   test_inductor_triton_cpu
 elif [[ "${TEST_CONFIG}" == *inductor-micro-benchmark* ]]; then
.github/workflows/docker-builds.yml (vendored, +1)
@@ -65,6 +65,7 @@ jobs:
           pytorch-linux-jammy-py3.10-gcc11,
           pytorch-linux-jammy-py3-gcc11-inductor-benchmarks,
           pytorch-linux-jammy-py3.12-halide,
+          pytorch-linux-jammy-cuda13.0-py3.12-pallas,
           pytorch-linux-jammy-xpu-n-1-py3,
           pytorch-linux-jammy-xpu-n-py3,
           pytorch-linux-jammy-xpu-n-py3-inductor-benchmarks,
.github/workflows/inductor-unittest.yml (vendored, +26)
@@ -81,6 +81,32 @@ jobs:
       test-matrix: ${{ needs.inductor-halide-build.outputs.test-matrix }}
     secrets: inherit
 
+  inductor-pallas-build:
+    name: inductor-pallas-build
+    uses: ./.github/workflows/_linux-build.yml
+    needs: get-label-type
+    with:
+      build-environment: linux-jammy-py3.12-gcc11
+      docker-image-name: ci-image:pytorch-linux-jammy-cuda13.0-py3.12-pallas
+      cuda-arch-list: '8.9'
+      runner: linux.8xlarge.memory
+      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
+      test-matrix: |
+        { include: [
+          { config: "inductor-pallas", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" },
+        ]}
+    secrets: inherit
+
+  inductor-pallas-test:
+    name: inductor-pallas-test
+    uses: ./.github/workflows/_linux-test.yml
+    needs: inductor-pallas-build
+    with:
+      build-environment: linux-jammy-py3.12-gcc11
+      docker-image: ${{ needs.inductor-pallas-build.outputs.docker-image }}
+      test-matrix: ${{ needs.inductor-pallas-build.outputs.test-matrix }}
+    secrets: inherit
+
   inductor-triton-cpu-build:
     name: inductor-triton-cpu-build
     uses: ./.github/workflows/_linux-build.yml
@@ -52,9 +52,12 @@ def make_pallas(cls):
     return test_class
 
 
-@unittest.skipUnless(HAS_PALLAS, "requires jax and pallas")
-class PallasTests(TestCase):
-    """Basic tests for Pallas backend functionality."""
+class PallasTestsMixin:
+    """Basic tests for Pallas backend functionality (parameterized by DEVICE). Mixin only, not collected."""
+
+    def _compile(self, fn):
+        key = "cuda_backend" if self.DEVICE == "cuda" else "cpu_backend"
+        return torch.compile(fn, backend="inductor", options={key: "pallas"})
 
     def test_simple_add(self):
         """Test basic element-wise addition."""
@@ -62,12 +65,10 @@ class PallasTests(TestCase):
         def fn(a, b):
             return a + b
 
-        compiled = torch.compile(
-            fn, backend="inductor", options={"cuda_backend": "pallas"}
-        )
+        compiled = self._compile(fn)
 
-        a = torch.randn(1024, device="cuda")
-        b = torch.randn(1024, device="cuda")
+        a = torch.randn(1024, device=self.DEVICE)
+        b = torch.randn(1024, device=self.DEVICE)
         result = compiled(a, b)
         expected = fn(a, b)
         self.assertEqual(result, expected)
@@ -78,12 +79,10 @@ class PallasTests(TestCase):
         def fn(a, b):
             return a * b
 
-        compiled = torch.compile(
-            fn, backend="inductor", options={"cuda_backend": "pallas"}
-        )
+        compiled = self._compile(fn)
 
-        a = torch.randn(1024, device="cuda")
-        b = torch.randn(1024, device="cuda")
+        a = torch.randn(1024, device=self.DEVICE)
+        b = torch.randn(1024, device=self.DEVICE)
         result = compiled(a, b)
         expected = fn(a, b)
         self.assertEqual(result, expected)
@@ -94,11 +93,9 @@ class PallasTests(TestCase):
         def fn(x):
             return torch.sin(x)
 
-        compiled = torch.compile(
-            fn, backend="inductor", options={"cuda_backend": "pallas"}
-        )
+        compiled = self._compile(fn)
 
-        x = torch.randn(1024, device="cuda")
+        x = torch.randn(1024, device=self.DEVICE)
         result = compiled(x)
         expected = fn(x)
         self.assertEqual(result, expected)
@@ -109,12 +106,10 @@ class PallasTests(TestCase):
         def fn(x, y):
             return x.sin() + y
 
-        compiled = torch.compile(
-            fn, backend="inductor", options={"cuda_backend": "pallas"}
-        )
+        compiled = self._compile(fn)
 
-        x = torch.randn(1024, device="cuda")
-        y = torch.randn(1024, device="cuda")
+        x = torch.randn(1024, device=self.DEVICE)
+        y = torch.randn(1024, device=self.DEVICE)
         result = compiled(x, y)
         expected = fn(x, y)
         self.assertEqual(result, expected)
@@ -125,11 +120,9 @@ class PallasTests(TestCase):
         def fn(x):
             return torch.log(torch.exp(x))
 
-        compiled = torch.compile(
-            fn, backend="inductor", options={"cuda_backend": "pallas"}
-        )
+        compiled = self._compile(fn)
 
-        x = torch.randn(1024, device="cuda")
+        x = torch.randn(1024, device=self.DEVICE)
         result = compiled(x)
         expected = fn(x)
         self.assertEqual(result, expected)
@@ -140,11 +133,9 @@ class PallasTests(TestCase):
         def fn(x):
             return torch.sqrt(x)
 
-        compiled = torch.compile(
-            fn, backend="inductor", options={"cuda_backend": "pallas"}
-        )
+        compiled = self._compile(fn)
 
-        x = torch.randn(1024, device="cuda").abs()  # Ensure positive for sqrt
+        x = torch.randn(1024, device=self.DEVICE).abs()  # Ensure positive for sqrt
         result = compiled(x)
         expected = fn(x)
         self.assertEqual(result, expected)
@@ -155,11 +146,9 @@ class PallasTests(TestCase):
         def fn(x):
             return torch.tanh(x)
 
-        compiled = torch.compile(
-            fn, backend="inductor", options={"cuda_backend": "pallas"}
-        )
+        compiled = self._compile(fn)
 
-        x = torch.randn(1024, device="cuda")
+        x = torch.randn(1024, device=self.DEVICE)
         result = compiled(x)
         expected = fn(x)
         self.assertEqual(result, expected)
@@ -170,11 +159,9 @@ class PallasTests(TestCase):
         def fn(x):
             return torch.abs(-x)
 
-        compiled = torch.compile(
-            fn, backend="inductor", options={"cuda_backend": "pallas"}
-        )
+        compiled = self._compile(fn)
 
-        x = torch.randn(1024, device="cuda")
+        x = torch.randn(1024, device=self.DEVICE)
         result = compiled(x)
         expected = fn(x)
         self.assertEqual(result, expected)
@@ -185,12 +172,10 @@ class PallasTests(TestCase):
         def fn(a, b):
             return torch.maximum(a, b) + torch.minimum(a, b)
 
-        compiled = torch.compile(
-            fn, backend="inductor", options={"cuda_backend": "pallas"}
-        )
+        compiled = self._compile(fn)
 
-        a = torch.randn(1024, device="cuda")
-        b = torch.randn(1024, device="cuda")
+        a = torch.randn(1024, device=self.DEVICE)
+        b = torch.randn(1024, device=self.DEVICE)
         result = compiled(a, b)
         expected = fn(a, b)
         self.assertEqual(result, expected)
@@ -228,15 +213,17 @@ class PallasTests(TestCase):
 
         @torch.compile(
             backend="inductor",
-            options={"cuda_backend": "pallas"},
+            options={
+                ("cuda_backend" if self.DEVICE == "cuda" else "cpu_backend"): "pallas"
+            },
         )
         def pallas_fn(a, b):
             return a.sin() + b.cos()
 
         _, (code,) = run_and_get_code(
             pallas_fn,
-            torch.randn(64, device="cuda"),
-            torch.randn(64, device="cuda"),
+            torch.randn(64, device=self.DEVICE),
+            torch.randn(64, device=self.DEVICE),
         )
         # Verify Pallas-specific code generation
         self.assertIn("import jax", code)
@@ -249,12 +236,10 @@ class PallasTests(TestCase):
         def fn(x, y):
             return x + y
 
-        compiled = torch.compile(
-            fn, backend="inductor", options={"cuda_backend": "pallas"}
-        )
+        compiled = self._compile(fn)
 
-        x = torch.randn(32, 32, device="cuda")
-        y = torch.randn(32, 32, device="cuda")
+        x = torch.randn(32, 32, device=self.DEVICE)
+        y = torch.randn(32, 32, device=self.DEVICE)
         result = compiled(x, y)
         expected = fn(x, y)
         self.assertEqual(result, expected)
@@ -265,12 +250,10 @@ class PallasTests(TestCase):
         def fn(x):
             return x * 2.0
 
-        compiled = torch.compile(
-            fn, backend="inductor", options={"cuda_backend": "pallas"}
-        )
+        compiled = self._compile(fn)
 
         for shape in [(64,), (128,), (256,), (1024,)]:
-            x = torch.randn(shape, device="cuda")
+            x = torch.randn(shape, device=self.DEVICE)
             result = compiled(x)
             expected = fn(x)
             self.assertEqual(result, expected)
@@ -282,12 +265,10 @@ class PallasTests(TestCase):
         def contiguous_add(a, b):
             return a + b
 
-        compiled = torch.compile(
-            contiguous_add, backend="inductor", options={"cuda_backend": "pallas"}
-        )
+        compiled = self._compile(contiguous_add)
 
-        a = torch.randn(1024, device="cuda")
-        b = torch.randn(1024, device="cuda")
+        a = torch.randn(1024, device=self.DEVICE)
+        b = torch.randn(1024, device=self.DEVICE)
         result = compiled(a, b)
         expected = contiguous_add(a, b)
         self.assertEqual(result, expected)
@@ -296,44 +277,31 @@ class PallasTests(TestCase):
         def contiguous_mul(x):
             return x * 2.0
 
-        compiled = torch.compile(
-            contiguous_mul, backend="inductor", options={"cuda_backend": "pallas"}
-        )
+        compiled = self._compile(contiguous_mul)
 
-        x = torch.randn(128, 8, device="cuda")
+        x = torch.randn(128, 8, device=self.DEVICE)
         result = compiled(x)
         expected = contiguous_mul(x)
         self.assertEqual(result, expected)
 
-        # Test 3: Non-contiguous views will fail at runtime with JAX/Pallas
-        # This demonstrates that the Pallas backend requires contiguous memory layout
+        # Test 3: Non-contiguous views should work with the simplified dlpack approach
+        # The direct dlpack conversion handles non-contiguous tensors correctly
        def operate_on_tensor(x):
            return x.sin()
 
-        compiled = torch.compile(
-            operate_on_tensor, backend="inductor", options={"cuda_backend": "pallas"}
-        )
+        compiled = self._compile(operate_on_tensor)
 
         # Create a transposed (non-contiguous) view
-        x = torch.randn(64, 32, device="cuda")
+        x = torch.randn(64, 32, device=self.DEVICE)
         x_t = x.t()  # Non-contiguous view
         self.assertFalse(x_t.is_contiguous())
 
-        # This will fail because JAX/Pallas cannot handle non-contiguous layout via DLPack
-        # The error indicates that our contiguous-only approach is correct
-        with self.assertRaises((RuntimeError, Exception)) as cm:
-            result = compiled(x_t)
+        # With the simplified dlpack approach, non-contiguous tensors now work
+        result = compiled(x_t)
+        expected = operate_on_tensor(x_t)
+        self.assertEqual(result, expected)
 
-        # Verify the error is related to layout/contiguous issues
-        error_msg = str(cm.exception)
-        self.assertTrue(
-            "layout" in error_msg.lower()
-            or "contiguous" in error_msg.lower()
-            or "non-default" in error_msg.lower(),
-            f"Expected layout/contiguous error, got: {error_msg}",
-        )
-
-        # But if we make it contiguous first, it should work
+        # Contiguous tensors should also continue to work
         x_t_contiguous = x_t.contiguous()
         self.assertTrue(x_t_contiguous.is_contiguous())
         result = compiled(x_t_contiguous)
@@ -341,13 +309,24 @@ class PallasTests(TestCase):
         self.assertEqual(result, expected)
 
 
+@unittest.skipUnless(HAS_PALLAS, "requires jax and pallas")
+class PallasTestsCUDA(PallasTestsMixin, TestCase):
+    DEVICE = "cuda"
+
+
+@unittest.skipUnless(HAS_PALLAS, "requires jax and pallas")
+class PallasTestsCPU(PallasTestsMixin, TestCase):
+    DEVICE = "cpu"
+
+
 # Create test variants using the main test suite
 # Note: Only enable GPU tests since Pallas primarily targets GPU
-if test_torchinductor.HAS_GPU and HAS_PALLAS:
-    # Uncomment these to run full test suite with Pallas backend
-    # make_pallas(test_torchinductor.SweepInputsGPUTest)
-    # make_pallas(test_torchinductor.GPUTests)
-    pass
+if hasattr(sys.modules.get(__name__), "test_torchinductor") and HAS_PALLAS:
+    if getattr(test_torchinductor, "HAS_GPU", False):
+        # Uncomment these to run full test suite with Pallas backend
+        # make_pallas(test_torchinductor.SweepInputsGPUTest)
+        # make_pallas(test_torchinductor.GPUTests)
+        pass
 
 if __name__ == "__main__":
     if HAS_PALLAS:
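All of the rewritten tests go through the mixin's _compile helper, which simply maps the chosen device string onto the matching Inductor option key. A minimal standalone sketch of that pattern, assuming jax and pallas are installed as in the new CI image (the function name below is illustrative, not from the diff):

import torch

def compile_with_pallas(fn, device: str):
    # Mirrors PallasTestsMixin._compile: pick the Inductor backend option by device.
    key = "cuda_backend" if device == "cuda" else "cpu_backend"
    return torch.compile(fn, backend="inductor", options={key: "pallas"})

compiled = compile_with_pallas(lambda a, b: a + b, "cpu")
print(compiled(torch.randn(1024), torch.randn(1024)))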
@@ -521,6 +521,7 @@ def init_backend_registration() -> None:
             "cpp": CppScheduling,
             "halide": HalideScheduling,
             "triton": TritonScheduling,
+            "pallas": PallasScheduling,
         }
         register_backend_for_device(
             "cpu",
@@ -291,7 +291,6 @@ class PallasKernel(SIMDKernel):
             import jax
             import jax.numpy as jnp
             from jax.experimental import pallas as pl
-            from torch.utils import dlpack as torch_dlpack
             """,
             strip=True,
         )
@@ -314,6 +313,8 @@ class PallasKernel(SIMDKernel):
         main_name = f"{kernel_name}_main"
         code.writeline(f"def {main_name}({', '.join(kernel_params)}, stream=None):")
         with code.indent():
+            # Determine interpret statically based on codegen device
+            interpret_literal = "True" if V.graph.get_current_device_or_throw().type == "cpu" else "False"
             # Identify inputs (in_ptr*) and output (out_ptr*)
             input_params = [
                 p for p in kernel_params if p.startswith(("in_ptr", "in_out_ptr"))
@@ -330,9 +331,7 @@ class PallasKernel(SIMDKernel):
             # Convert inputs to JAX arrays
             code.writeline("# Convert Torch -> JAX for inputs")
             for inp in input_params:
-                code.writeline(
-                    f"{inp}_jax = jax.dlpack.from_dlpack(torch_dlpack.to_dlpack({inp}))"
-                )
+                code.writeline(f"{inp}_jax = jax.dlpack.from_dlpack({inp})")
 
             # Get output spec from PyTorch tensor
             code.writeline("# Prepare output spec from PyTorch tensor")
@@ -351,9 +350,11 @@ class PallasKernel(SIMDKernel):
             )
 
             # Call pallas
+            # Pass interpret=True on CPU, False otherwise (single call, no duplication)
             code.writeline("compiled = pl.pallas_call(")
             code.writeline(f"    lambda *refs: {kernel_name}_kernel(*refs),")
             code.writeline("    out_shape=out_spec,")
+            code.writeline(f"    interpret={interpret_literal},")
             code.writeline("    grid=(1,),")
             code.writeline(")")
 
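The interpret flag chosen above is what lets the same generated wrapper run on CPU: pl.pallas_call(..., interpret=True) executes the kernel through JAX's interpreter instead of compiling it for a GPU. A standalone sketch of that call shape, assuming the pinned jax/pallas is installed (the kernel below is illustrative, not generated code):

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_one_kernel(x_ref, o_ref):
    # Refs cover the whole array here because grid=(1,) and no block specs are given.
    o_ref[...] = x_ref[...] + 1.0

x = jnp.arange(8, dtype=jnp.float32)
out = pl.pallas_call(
    lambda *refs: add_one_kernel(*refs),  # same lambda wrapping as the generated code
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    interpret=True,                       # CPU interpreter path; False on CUDA
    grid=(1,),
)(x)
print(out)  # [1. 2. ... 8.]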
@@ -362,9 +363,7 @@ class PallasKernel(SIMDKernel):
 
             # Copy result back
             code.writeline("# Copy result back into the provided torch output tensor")
-            code.writeline(
-                "res_t = torch_dlpack.from_dlpack(jax.dlpack.to_dlpack(res))"
-            )
+            code.writeline("res_t = torch.from_dlpack(res)")
             code.writeline(f"{output_param}.copy_(res_t)")
 
         return code.getvalue()
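The simplification above leans on both frameworks speaking the DLPack protocol directly: a recent JAX accepts a torch.Tensor in jax.dlpack.from_dlpack, and torch.from_dlpack accepts a JAX array, so the explicit to_dlpack capsules are no longer needed. A hedged round-trip sketch, assuming a JAX version new enough to take __dlpack__-capable objects directly (as the 0.8.0 pin is expected to be):

import jax.dlpack
import jax.numpy as jnp
import torch

x = torch.arange(4, dtype=torch.float32)
x_jax = jax.dlpack.from_dlpack(x)   # torch -> JAX, as in the generated wrapper
y_jax = jnp.sin(x_jax)              # stand-in for the pallas_call result
y = torch.from_dlpack(y_jax)        # JAX -> torch; the wrapper then copy_()s into the output buffer
torch.testing.assert_close(y, torch.sin(x))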
@@ -1958,8 +1958,8 @@ class rocm:
     contiguous_threshold: int = 16
 
 
-# Backend to use for CPU codegen either "cpp" or "triton" (experimental) or "halide" (experimental)
-cpu_backend: Literal["cpp", "triton", "halide"] = "cpp"
+# Backend to use for CPU codegen either "cpp" or "triton" (experimental) or "halide" (experimental) or "pallas" (experimental)
+cpu_backend: Literal["cpp", "triton", "halide", "pallas"] = "cpp"
 
 # Backend to use for CUDA codegen either
 # "triton", "halide" (experimental) or "pallas" (experimental)
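With the widened Literal, the experimental Pallas CPU path can also be enabled globally through the config module rather than per torch.compile call. A hedged example, assuming jax and pallas are available in the environment (the function below is illustrative):

import torch
import torch._inductor.config as inductor_config

inductor_config.cpu_backend = "pallas"  # newly accepted value; still experimental

@torch.compile(backend="inductor")
def f(x):
    return torch.tanh(x) + 1.0

print(f(torch.randn(8)))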