[bazel] add build for functorch (#101475)

Fixes #101469

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101475
Approved by: https://github.com/ezyang
This commit is contained in:
Sergei Vorobev
2023-05-18 20:29:03 +00:00
committed by PyTorch MergeBot
parent 7ffdd4fedc
commit 2c0d607882
4 changed files with 66 additions and 17 deletions

View File

@ -1697,6 +1697,36 @@ pybind_extension(
],
)
cc_library(
name = "functorch",
hdrs = glob([
"functorch/csrc/dim/*.h",
]),
srcs = glob([
"functorch/csrc/dim/*.cpp",
]),
deps = [
":aten_nvrtc",
":torch_python",
"@pybind11",
],
)
pybind_extension(
name = "functorch/_C",
copts=[
"-DTORCH_EXTENSION_NAME=_C"
],
srcs = [
"functorch/csrc/init_dim_only.cpp",
],
deps = [
":functorch",
":torch_python",
":aten_nvrtc",
],
)
cc_binary(
name = "torch/bin/torch_shm_manager",
srcs = [
@ -1725,7 +1755,7 @@ template_rule(
rules.py_library(
name = "pytorch_py",
visibility = ["//visibility:public"],
srcs = glob(["torch/**/*.py"], exclude = ["torch/version.py"]) + [":torch/version.py"],
srcs = glob(["torch/**/*.py"], exclude = ["torch/version.py"]) + [":torch/version.py"] + glob(["functorch/**/*.py"]),
deps = [
rules.requirement("future"),
rules.requirement("numpy"),
@ -1738,6 +1768,7 @@ rules.py_library(
],
data = [
":torch/_C.so",
":functorch/_C.so",
":torch/bin/torch_shm_manager",
],
)
@ -1904,7 +1935,8 @@ cc_test(
py_test(
name = "test_bazel",
srcs = ["test/test_bazel.py"],
srcs = ["test/_test_bazel.py"],
main = "test/_test_bazel.py",
deps = [":pytorch_py"],
)

30
test/_test_bazel.py Normal file
View File

@ -0,0 +1,30 @@
# Owner(s): ["module: bazel"]
"""
This test module contains a minimalistic "smoke tests" for the bazel build.
Currently it doesn't use any testing framework (i.e. pytest)
TODO: integrate this into the existing pytorch testing framework.
The name uses underscore `_test_bazel.py` to avoid globbing into other non-bazel configurations.
"""
import torch
def test_sum() -> None:
assert torch.eq(torch.tensor([[1, 2, 3]]) + torch.tensor([[4, 5, 6]]), torch.tensor([[5, 7, 9]])).all()
def test_simple_compile_eager() -> None:
def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
a = torch.sin(x)
b = torch.cos(y)
return a + b
opt_foo1 = torch.compile(foo, backend="eager")
# just check that we can run without raising an Exception
assert opt_foo1(torch.randn(10, 10), torch.randn(10, 10)) is not None
test_sum()
test_simple_compile_eager()

View File

@ -1,15 +0,0 @@
# Owner(s): ["module: bazel"]
"""
This test module contains a minimalistic "smoke tests" for the bazel build.
Currently it doesn't use any testing framework (i.e. pytest)
TODO: integrate this into the existing pytorch testing framework.
"""
import torch
def test_sum():
assert torch.eq(torch.tensor([[1, 2, 3]]) + torch.tensor([[4, 5, 6]]), torch.tensor([[5, 7, 9]])).all()
test_sum()

View File

@ -1,6 +1,8 @@
__version__ = '{{VERSION}}'
debug = False
cuda = '{{CUDA_VERSION}}'
# TODO: use workspace status to stamp the correct version
git_version = ""
hip = None
# This is a gross monkey-patch hack that depends on the order of imports