mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[bazel] add build for functorch (#101475)
Fixes #101469 Pull Request resolved: https://github.com/pytorch/pytorch/pull/101475 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
7ffdd4fedc
commit
2c0d607882
36
BUILD.bazel
36
BUILD.bazel
@ -1697,6 +1697,36 @@ pybind_extension(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "functorch",
|
||||||
|
hdrs = glob([
|
||||||
|
"functorch/csrc/dim/*.h",
|
||||||
|
]),
|
||||||
|
srcs = glob([
|
||||||
|
"functorch/csrc/dim/*.cpp",
|
||||||
|
]),
|
||||||
|
deps = [
|
||||||
|
":aten_nvrtc",
|
||||||
|
":torch_python",
|
||||||
|
"@pybind11",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
pybind_extension(
|
||||||
|
name = "functorch/_C",
|
||||||
|
copts=[
|
||||||
|
"-DTORCH_EXTENSION_NAME=_C"
|
||||||
|
],
|
||||||
|
srcs = [
|
||||||
|
"functorch/csrc/init_dim_only.cpp",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":functorch",
|
||||||
|
":torch_python",
|
||||||
|
":aten_nvrtc",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_binary(
|
cc_binary(
|
||||||
name = "torch/bin/torch_shm_manager",
|
name = "torch/bin/torch_shm_manager",
|
||||||
srcs = [
|
srcs = [
|
||||||
@ -1725,7 +1755,7 @@ template_rule(
|
|||||||
rules.py_library(
|
rules.py_library(
|
||||||
name = "pytorch_py",
|
name = "pytorch_py",
|
||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
srcs = glob(["torch/**/*.py"], exclude = ["torch/version.py"]) + [":torch/version.py"],
|
srcs = glob(["torch/**/*.py"], exclude = ["torch/version.py"]) + [":torch/version.py"] + glob(["functorch/**/*.py"]),
|
||||||
deps = [
|
deps = [
|
||||||
rules.requirement("future"),
|
rules.requirement("future"),
|
||||||
rules.requirement("numpy"),
|
rules.requirement("numpy"),
|
||||||
@ -1738,6 +1768,7 @@ rules.py_library(
|
|||||||
],
|
],
|
||||||
data = [
|
data = [
|
||||||
":torch/_C.so",
|
":torch/_C.so",
|
||||||
|
":functorch/_C.so",
|
||||||
":torch/bin/torch_shm_manager",
|
":torch/bin/torch_shm_manager",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -1904,7 +1935,8 @@ cc_test(
|
|||||||
|
|
||||||
py_test(
|
py_test(
|
||||||
name = "test_bazel",
|
name = "test_bazel",
|
||||||
srcs = ["test/test_bazel.py"],
|
srcs = ["test/_test_bazel.py"],
|
||||||
|
main = "test/_test_bazel.py",
|
||||||
deps = [":pytorch_py"],
|
deps = [":pytorch_py"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
30
test/_test_bazel.py
Normal file
30
test/_test_bazel.py
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
# Owner(s): ["module: bazel"]
|
||||||
|
|
||||||
|
"""
|
||||||
|
This test module contains a minimalistic "smoke tests" for the bazel build.
|
||||||
|
|
||||||
|
Currently it doesn't use any testing framework (i.e. pytest)
|
||||||
|
TODO: integrate this into the existing pytorch testing framework.
|
||||||
|
|
||||||
|
The name uses underscore `_test_bazel.py` to avoid globbing into other non-bazel configurations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
def test_sum() -> None:
|
||||||
|
assert torch.eq(torch.tensor([[1, 2, 3]]) + torch.tensor([[4, 5, 6]]), torch.tensor([[5, 7, 9]])).all()
|
||||||
|
|
||||||
|
def test_simple_compile_eager() -> None:
|
||||||
|
|
||||||
|
def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||||
|
a = torch.sin(x)
|
||||||
|
b = torch.cos(y)
|
||||||
|
return a + b
|
||||||
|
|
||||||
|
opt_foo1 = torch.compile(foo, backend="eager")
|
||||||
|
# just check that we can run without raising an Exception
|
||||||
|
assert opt_foo1(torch.randn(10, 10), torch.randn(10, 10)) is not None
|
||||||
|
|
||||||
|
|
||||||
|
test_sum()
|
||||||
|
test_simple_compile_eager()
|
@ -1,15 +0,0 @@
|
|||||||
# Owner(s): ["module: bazel"]
|
|
||||||
|
|
||||||
"""
|
|
||||||
This test module contains a minimalistic "smoke tests" for the bazel build.
|
|
||||||
|
|
||||||
Currently it doesn't use any testing framework (i.e. pytest)
|
|
||||||
TODO: integrate this into the existing pytorch testing framework.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
def test_sum():
|
|
||||||
assert torch.eq(torch.tensor([[1, 2, 3]]) + torch.tensor([[4, 5, 6]]), torch.tensor([[5, 7, 9]])).all()
|
|
||||||
|
|
||||||
test_sum()
|
|
@ -1,6 +1,8 @@
|
|||||||
__version__ = '{{VERSION}}'
|
__version__ = '{{VERSION}}'
|
||||||
debug = False
|
debug = False
|
||||||
cuda = '{{CUDA_VERSION}}'
|
cuda = '{{CUDA_VERSION}}'
|
||||||
|
# TODO: use workspace status to stamp the correct version
|
||||||
|
git_version = ""
|
||||||
hip = None
|
hip = None
|
||||||
|
|
||||||
# This is a gross monkey-patch hack that depends on the order of imports
|
# This is a gross monkey-patch hack that depends on the order of imports
|
||||||
|
Reference in New Issue
Block a user