mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Useful for code reuse in Metal shader builds for both eager mode and MPSInductor, but it requires implementing the `_cpp_embed_headers` tool that, as its name suggests, preprocesses and embeds the headers into the shader source to be used in dynamic compilation. Tested using: - `TestMetalLibrary.test_metal_include` - Moving the `i0`/`i1` implementation to `c10/util/metal_special_math.h` and calling it from the `SpecialOps.metal` shader, which now looks much more compact: ```metal template <typename T, typename Tout = T> void kernel i0(constant T* input, device Tout* output, uint index [[thread_position_in_grid]]) { output[index] = c10::i0(static_cast<Tout>(input[index])); } ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/145087 Approved by: https://github.com/dcci ghstack dependencies: #145023
53 lines
1.5 KiB
Python
53 lines
1.5 KiB
Python
from pathlib import Path
|
|
from re import match as _match
|
|
from typing import List, Optional, Sequence, Set, Union
|
|
|
|
|
|
def read_file(fname: Union[Path, str]) -> List[str]:
    """Read a UTF-8 text file and return its lines.

    Each returned string keeps its line terminator (the final line may lack
    one if the file does not end with a newline); an empty file yields ``[]``.
    """
    with open(fname, encoding="utf-8") as source:
        lines = list(source)
    return lines
|
|
|
|
|
|
def _embed_headers(
    content: List[str], include_dirs: List[Path], processed_files: Set[str]
) -> str:
    """Recursively inline ``#include`` directives that resolve under *include_dirs*.

    Each line of *content* matching ``#include <...>`` or ``#include "..."`` is
    looked up in *include_dirs* in order, and the first directory containing
    the named file wins. A header seen for the first time is replaced by its
    own recursively processed text. A header whose path string is already in
    *processed_files* is replaced by an empty string, so each such path is
    emitted at most once. Includes not found in any directory are left as-is.

    Args:
        content: Lines of the file being processed, each keeping its line
            terminator. Modified in place.
        include_dirs: Directories searched, in order, for included files.
        processed_files: ``str`` forms of header paths already embedded;
            updated in place with every newly embedded header.

    Returns:
        The processed text joined into a single string.
    """
    for line_idx, cur_line in enumerate(content):
        m = _match(r'^\s*#include\s*[<"]([^>"]+)[>"]', cur_line)
        if m is None:
            continue
        for include_dir in include_dirs:
            path = include_dir / m[1]
            # is_file() rather than exists(): a directory of the same name
            # must not be opened as a header.
            if not path.is_file():
                continue
            if str(path) in processed_files:
                # Already embedded: drop the directive and stop searching, so a
                # later include dir cannot embed another copy of this header.
                content[line_idx] = ""
                break
            processed_files.add(str(path))
            embedded = _embed_headers(read_file(path), include_dirs, processed_files)
            # The embedded text replaces the whole directive line, including
            # its newline. Without a trailing newline, the header's last line
            # would be glued onto the next line of *content* (a trailing
            # "// comment" would swallow the following code).
            if embedded and not embedded.endswith("\n"):
                embedded += "\n"
            content[line_idx] = embedded
            break
    return "".join(content)
|
|
|
|
|
|
def embed_headers(
    fname: Union[Path, str],
    include_dirs: Optional[Union[Sequence[str], Sequence[Path], str, Path]] = None,
) -> str:
    """Return the text of *fname* with ``#include``-d headers inlined.

    Args:
        fname: Source file to process (``str`` or ``Path``).
        include_dirs: Directories searched, in order, for included headers.
            May be a single directory (``str`` or ``Path``) or a sequence of
            them. Defaults to ``Path(__file__).parent.parent.parent``.

    Returns:
        The contents of *fname* where every include found in *include_dirs* is
        replaced by that header's (recursively processed) text, each header
        path emitted at most once; includes not found are kept verbatim.
    """
    if include_dirs is None:
        dirs = [Path(__file__).parent.parent.parent]
    elif isinstance(include_dirs, (str, Path)):
        # A bare str/Path names one directory: iterating it would yield
        # characters (str) or raise TypeError (Path).
        dirs = [Path(include_dirs)]
    else:
        dirs = [Path(x) for x in include_dirs]
    # Seed with the top-level file as a str, since the recursive helper compares
    # str(path) values; an include resolving to the same path string is then
    # not re-embedded.
    return _embed_headers(read_file(fname), dirs, {str(fname)})
|
|
|
|
|
|
if __name__ == "__main__":
    import sys

    if len(sys.argv) < 2:
        # f-string so the real program name is shown; usage goes to stderr
        # because this is an error exit.
        print(f"Usage:\n {sys.argv[0]} filename [include_dir ...]", file=sys.stderr)
        sys.exit(1)
    # Any extra arguments are include directories; with none given,
    # embed_headers falls back to its default search directory.
    print(embed_headers(sys.argv[1], sys.argv[2:] or None))
|