Files
pytorch/torch/utils/_cpp_embed_headers.py
Nikita Shulga dc9b77cc55 [MPS] Support includes in metal objects (#145087)
Useful for code reuse between Metal shaders built for eager mode and for MPSInductor, but it requires implementing a `_cpp_embed_headers` tool that, as the name suggests, preprocesses the shader source and embeds its included headers so that the shader can be used in dynamic compilation.
Tested using:
 - `TestMetalLibrary.test_metal_include`
 - Moving the `i0`/`i1` implementation to `c10/util/metal_special_math.h` and calling it from the `SpecialOps.metal` shader, which now looks much more compact:
 ```metal
template <typename T, typename Tout = T>
void kernel
i0(constant T* input,
   device Tout* output,
   uint index [[thread_position_in_grid]]) {
  output[index] = c10::i0(static_cast<Tout>(input[index]));
}
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145087
Approved by: https://github.com/dcci
ghstack dependencies: #145023
2025-01-18 05:35:22 +00:00

53 lines
1.5 KiB
Python

from pathlib import Path
from re import match as _match
from typing import List, Optional, Sequence, Set, Union
def read_file(fname: Union[Path, str]) -> List[str]:
with open(fname, encoding="utf-8") as f:
return f.readlines()
def _embed_headers(
content: List[str], include_dirs: List[Path], processed_files: Set[str]
) -> str:
for line_idx, cur_line in enumerate(content):
m = _match('^\\s*#include\\s*[<"]([^>"]+)[>"]', cur_line)
if m is None:
continue
for include_dir in include_dirs:
path = include_dir / m[1]
if not path.exists():
continue
if str(path) in processed_files:
content[line_idx] = ""
continue
processed_files.add(str(path))
content[line_idx] = _embed_headers(
read_file(path), include_dirs, processed_files
)
break
return "".join(content)
def embed_headers(
    fname: Union[str, Path],
    include_dirs: Optional[Union[Sequence[str], Sequence[Path], str, Path]] = None,
) -> str:
    """Return the contents of ``fname`` with its local ``#include``s inlined.

    Args:
        fname: Path of the source file to expand.
        include_dirs: Directories in which included files are searched. May be
            a single ``str``/``Path`` or a sequence of them. Defaults to the
            directory three levels above this module's file.

    Returns:
        The expanded source text.
    """
    if include_dirs is None:
        # Three parents up from this module's file (the repository root when
        # this file lives at torch/utils/ -- NOTE(review): confirm layout).
        dirs = [Path(__file__).parent.parent.parent]
    elif isinstance(include_dirs, (str, Path)):
        # A single directory. A bare Path must not be iterated element-wise.
        dirs = [Path(include_dirs)]
    else:
        dirs = [Path(x) for x in include_dirs]
    # processed_files is compared against str(path) downstream, so store the
    # root file as a str even when a Path was passed in.
    return _embed_headers(read_file(fname), dirs, {str(fname)})
if __name__ == "__main__":
    # Command-line entry point: print the expanded contents of the given file.
    import sys

    if len(sys.argv) < 2:
        # Usage errors go to stderr with a non-zero exit status.
        print(f"Usage:\n {sys.argv[0]} filename", file=sys.stderr)
        sys.exit(1)
    print(embed_headers(sys.argv[1]))