mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[bazel] add build for functorch (#101475)
Fixes #101469 Pull Request resolved: https://github.com/pytorch/pytorch/pull/101475 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
7ffdd4fedc
commit
2c0d607882
36
BUILD.bazel
36
BUILD.bazel
@ -1697,6 +1697,36 @@ pybind_extension(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "functorch",
|
||||
hdrs = glob([
|
||||
"functorch/csrc/dim/*.h",
|
||||
]),
|
||||
srcs = glob([
|
||||
"functorch/csrc/dim/*.cpp",
|
||||
]),
|
||||
deps = [
|
||||
":aten_nvrtc",
|
||||
":torch_python",
|
||||
"@pybind11",
|
||||
],
|
||||
)
|
||||
|
||||
pybind_extension(
|
||||
name = "functorch/_C",
|
||||
copts=[
|
||||
"-DTORCH_EXTENSION_NAME=_C"
|
||||
],
|
||||
srcs = [
|
||||
"functorch/csrc/init_dim_only.cpp",
|
||||
],
|
||||
deps = [
|
||||
":functorch",
|
||||
":torch_python",
|
||||
":aten_nvrtc",
|
||||
],
|
||||
)
|
||||
|
||||
cc_binary(
|
||||
name = "torch/bin/torch_shm_manager",
|
||||
srcs = [
|
||||
@ -1725,7 +1755,7 @@ template_rule(
|
||||
rules.py_library(
|
||||
name = "pytorch_py",
|
||||
visibility = ["//visibility:public"],
|
||||
srcs = glob(["torch/**/*.py"], exclude = ["torch/version.py"]) + [":torch/version.py"],
|
||||
srcs = glob(["torch/**/*.py"], exclude = ["torch/version.py"]) + [":torch/version.py"] + glob(["functorch/**/*.py"]),
|
||||
deps = [
|
||||
rules.requirement("future"),
|
||||
rules.requirement("numpy"),
|
||||
@ -1738,6 +1768,7 @@ rules.py_library(
|
||||
],
|
||||
data = [
|
||||
":torch/_C.so",
|
||||
":functorch/_C.so",
|
||||
":torch/bin/torch_shm_manager",
|
||||
],
|
||||
)
|
||||
@ -1904,7 +1935,8 @@ cc_test(
|
||||
|
||||
py_test(
|
||||
name = "test_bazel",
|
||||
srcs = ["test/test_bazel.py"],
|
||||
srcs = ["test/_test_bazel.py"],
|
||||
main = "test/_test_bazel.py",
|
||||
deps = [":pytorch_py"],
|
||||
)
|
||||
|
||||
|
30
test/_test_bazel.py
Normal file
30
test/_test_bazel.py
Normal file
@ -0,0 +1,30 @@
|
||||
# Owner(s): ["module: bazel"]
|
||||
|
||||
"""
|
||||
This test module contains a minimalistic "smoke tests" for the bazel build.
|
||||
|
||||
Currently it doesn't use any testing framework (i.e. pytest)
|
||||
TODO: integrate this into the existing pytorch testing framework.
|
||||
|
||||
The name uses underscore `_test_bazel.py` to avoid globbing into other non-bazel configurations.
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
def test_sum() -> None:
|
||||
assert torch.eq(torch.tensor([[1, 2, 3]]) + torch.tensor([[4, 5, 6]]), torch.tensor([[5, 7, 9]])).all()
|
||||
|
||||
def test_simple_compile_eager() -> None:
|
||||
|
||||
def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
a = torch.sin(x)
|
||||
b = torch.cos(y)
|
||||
return a + b
|
||||
|
||||
opt_foo1 = torch.compile(foo, backend="eager")
|
||||
# just check that we can run without raising an Exception
|
||||
assert opt_foo1(torch.randn(10, 10), torch.randn(10, 10)) is not None
|
||||
|
||||
|
||||
test_sum()
|
||||
test_simple_compile_eager()
|
@ -1,15 +0,0 @@
|
||||
# Owner(s): ["module: bazel"]
|
||||
|
||||
"""
|
||||
This test module contains a minimalistic "smoke tests" for the bazel build.
|
||||
|
||||
Currently it doesn't use any testing framework (i.e. pytest)
|
||||
TODO: integrate this into the existing pytorch testing framework.
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
def test_sum():
|
||||
assert torch.eq(torch.tensor([[1, 2, 3]]) + torch.tensor([[4, 5, 6]]), torch.tensor([[5, 7, 9]])).all()
|
||||
|
||||
test_sum()
|
@ -1,6 +1,8 @@
|
||||
__version__ = '{{VERSION}}'
|
||||
debug = False
|
||||
cuda = '{{CUDA_VERSION}}'
|
||||
# TODO: use workspace status to stamp the correct version
|
||||
git_version = ""
|
||||
hip = None
|
||||
|
||||
# This is a gross monkey-patch hack that depends on the order of imports
|
||||
|
Reference in New Issue
Block a user