Commit Graph

7 Commits

Author SHA1 Message Date
f6a79b2a4a [inductor] Wrap pallas_call in jax.jit (#167441)
My understanding is this is needed for performance.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167441
Approved by: https://github.com/oulgen
2025-11-10 17:29:56 +00:00
325ec98009 [13/N] Apply ruff UP035 rule (#167048)
This PR continues to apply ruff UP035 rule to test code and some remaining torch files.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167048
Approved by: https://github.com/Skylion007
2025-11-09 01:47:38 +00:00
e342a7509a [pallas backend] add cpu backend and parametrize the tests (#167388)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167388
Approved by: https://github.com/jansel
2025-11-08 23:11:35 +00:00
6392b986e7 Revert "[13/N] Apply ruff UP035 rule (#167048)"
This reverts commit ea44f12bce3eb05eaa9fa34943a3ffae04647fa5.

Reverted https://github.com/pytorch/pytorch/pull/167048 on behalf of https://github.com/donigian due to breaking internal tests D86342860 ([comment](https://github.com/pytorch/pytorch/pull/167048#issuecomment-3505232522))
2025-11-07 22:25:01 +00:00
faba6e205f [pallas backend] use dlpack directly (#167243)
previous version does not work on jax 0.8

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167243
Approved by: https://github.com/yf225, https://github.com/jansel
2025-11-07 05:54:51 +00:00
ea44f12bce [13/N] Apply ruff UP035 rule (#167048)
This PR continues to apply ruff UP035 rule to test code and some remaining torch files.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167048
Approved by: https://github.com/Skylion007
2025-11-05 20:51:53 +00:00
f2fbc81c50 [RFC] Add experimental Pallas TorchInductor backend (#166822)
Very simple Pallas TorchInductor backend
Given
```
import torch

def f(x, y):
    return x.sin() + y

torch._inductor.config.cuda_backend="pallas"

x = torch.randn(4).cuda()
y = torch.randn(4).cuda()

compiled = torch.compile(f, backend="inductor", fullgraph=True)
torch.testing.assert_close(compiled(x, y), f(x, y))
```
it outputs
```
import torch
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from torch.utils import dlpack as torch_dlpack
def pallas_fused_add_sin_56b646d2_kernel(in_ptr0, in_ptr1, out_ptr0):
    tmp0 = in_ptr0[...]
    tmp1 = jnp.sin(tmp0)
    tmp2 = in_ptr1[...]
    tmp3 = tmp1 + tmp2
    out_ptr0[...] = tmp3
def pallas_fused_add_sin_56b646d2_main(in_ptr0, in_ptr1, out_ptr0, stream=None):
    # Convert Torch -> JAX for inputs
    in_ptr0_jax = jax.dlpack.from_dlpack(torch_dlpack.to_dlpack(in_ptr0))
    in_ptr1_jax = jax.dlpack.from_dlpack(torch_dlpack.to_dlpack(in_ptr1))
    # Prepare output spec from PyTorch tensor
    # Map PyTorch dtype to JAX dtype string
    _torch_dtype_to_jax = {
        torch.float32: jnp.float32, torch.float64: jnp.float64, torch.float16: jnp.float16,
        torch.int32: jnp.int32, torch.int64: jnp.int64, torch.int16: jnp.int16, torch.int8: jnp.int8,
        torch.uint8: jnp.uint8, torch.bool: jnp.bool_,
    }
    out_spec = jax.ShapeDtypeStruct(out_ptr0.shape, _torch_dtype_to_jax[out_ptr0.dtype])
    compiled = pl.pallas_call(
        lambda *refs: pallas_fused_add_sin_56b646d2_kernel(*refs),
        out_shape=out_spec,
        grid=(1,),
    )
    res = compiled(in_ptr0_jax, in_ptr1_jax)
    # Copy result back into the provided torch output tensor
    res_t = torch_dlpack.from_dlpack(jax.dlpack.to_dlpack(res))
    out_ptr0.copy_(res_t)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166822
Approved by: https://github.com/jansel
ghstack dependencies: #166976, #166982
2025-11-05 00:52:41 +00:00