f2fbc81c50
[RFC] Add experimental Pallas TorchInductor backend (#166822)
This adds a very simple Pallas TorchInductor backend.
Given
```
import torch

def f(x, y):
    return x.sin() + y

torch._inductor.config.cuda_backend = "pallas"

x = torch.randn(4).cuda()
y = torch.randn(4).cuda()

compiled = torch.compile(f, backend="inductor", fullgraph=True)
torch.testing.assert_close(compiled(x, y), f(x, y))
```
it outputs:
```
import torch
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from torch.utils import dlpack as torch_dlpack


def pallas_fused_add_sin_56b646d2_kernel(in_ptr0, in_ptr1, out_ptr0):
    tmp0 = in_ptr0[...]
    tmp1 = jnp.sin(tmp0)
    tmp2 = in_ptr1[...]
    tmp3 = tmp1 + tmp2
    out_ptr0[...] = tmp3


def pallas_fused_add_sin_56b646d2_main(in_ptr0, in_ptr1, out_ptr0, stream=None):
    # Convert Torch -> JAX for inputs
    in_ptr0_jax = jax.dlpack.from_dlpack(torch_dlpack.to_dlpack(in_ptr0))
    in_ptr1_jax = jax.dlpack.from_dlpack(torch_dlpack.to_dlpack(in_ptr1))

    # Prepare output spec from PyTorch tensor
    # Map PyTorch dtype to JAX dtype
    _torch_dtype_to_jax = {
        torch.float32: jnp.float32, torch.float64: jnp.float64, torch.float16: jnp.float16,
        torch.int32: jnp.int32, torch.int64: jnp.int64, torch.int16: jnp.int16, torch.int8: jnp.int8,
        torch.uint8: jnp.uint8, torch.bool: jnp.bool_,
    }
    out_spec = jax.ShapeDtypeStruct(out_ptr0.shape, _torch_dtype_to_jax[out_ptr0.dtype])

    compiled = pl.pallas_call(
        lambda *refs: pallas_fused_add_sin_56b646d2_kernel(*refs),
        out_shape=out_spec,
        grid=(1,),
    )
    res = compiled(in_ptr0_jax, in_ptr1_jax)

    # Copy result back into the provided torch output tensor
    res_t = torch_dlpack.from_dlpack(jax.dlpack.to_dlpack(res))
    out_ptr0.copy_(res_t)
```
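For reference, the generated `_main` wrapper can also be exercised by hand. The snippet below is a minimal sketch (not part of the PR), assuming the generated module above has been imported and that JAX can see the same CUDA device as PyTorch; the output tensor is preallocated by the caller and filled in place via `out_ptr0.copy_(...)`:
```
import torch

# Hypothetical manual call into the generated wrapper; Inductor normally
# drives this itself. Assumes CUDA builds of both PyTorch and JAX.
x = torch.randn(4, device="cuda")
y = torch.randn(4, device="cuda")
out = torch.empty_like(x)  # preallocated output buffer, written in place

pallas_fused_add_sin_56b646d2_main(x, y, out)
torch.testing.assert_close(out, x.sin() + y)
```
Since the DLPack conversions in both directions should be zero-copy views, the main extra data movement in this path is the final `copy_` back into the Torch output buffer.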
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166822
Approved by: https://github.com/jansel
ghstack dependencies: #166976, #166982
2025-11-05 00:52:41 +00:00