Commit Graph

4 Commits

SHA1 Message Date
541297daae [Build] Allow metal shaders to include ATen headers (#156256)
No-op change that will be used later to share structs between CPU and Metal
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156256
Approved by: https://github.com/dcci
2025-06-18 01:03:25 +00:00
e3839bd603 [BE] Strip #pragma once when embedding the headers (#146871)
This eliminates compiler warnings, for example when compiling a Metal shader with embedded headers:
```
 with program_source:6:9: warning: #pragma once in main file [-Wpragma-once-outside-header]
#pragma once
        ^
program_source:81:9: warning: #pragma once in main file [-Wpragma-once-outside-header]
#pragma once
        ^
program_source:588:9: warning: #pragma once in main file [-Wpragma-once-outside-header]
#pragma once
        ^
program_source:719:9: warning: #pragma once in main file [-Wpragma-once-outside-header]
#pragma once
        ^
program_source:829:29: error: use of undeclared identifier 'r0_2'
        auto tmp8 = in_ptr2[r0_2 + 768*x0];
```

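For illustration, a minimal sketch of the stripping step (the `embed_headers` helper and its regex are hypothetical, not the actual build code): headers are concatenated into one program source and any `#pragma once` line is dropped, since the pragma only makes sense inside a standalone header file.

```python
import re
from pathlib import Path

# Hypothetical helper: inline a list of headers into a single program source,
# dropping `#pragma once` so the Metal compiler does not emit
# -Wpragma-once-outside-header warnings for the embedded copies.
PRAGMA_ONCE = re.compile(r"^\s*#\s*pragma\s+once\s*$")

def embed_headers(header_paths: list[str]) -> str:
    chunks = []
    for path in header_paths:
        lines = Path(path).read_text().splitlines()
        chunks.append("\n".join(ln for ln in lines if not PRAGMA_ONCE.match(ln)))
    return "\n".join(chunks)
```
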
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146871
Approved by: https://github.com/dcci
2025-02-11 16:49:00 +00:00
7178b827d7 PEP585: Missed conversions (#145342)
Differential Revision: [D68785969](https://our.internmc.facebook.com/intern/diff/D68785969)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145342
Approved by: https://github.com/bobrenjc93
2025-01-29 05:24:36 +00:00
dc9b77cc55 [MPS] Support includes in metal objects (#145087)
Useful for code reuse when building Metal shaders for both eager mode and MPSInductor, but it requires implementing a `_cpp_embed_headers` tool that, as the name suggests, preprocesses the shader and embeds the included headers so it can be used in dynamic compilation (a rough sketch follows the example below).
Tested using:
 - `TestMetalLibrary.test_metal_include`
 - Moving the `i0`/`i1` implementations to `c10/util/metal_special_math.h` and calling them from the `SpecialOps.metal` shader, which now looks much more compact:
```metal
template <typename T, typename Tout = T>
void kernel
i0(constant T* input,
   device Tout* output,
   uint index [[thread_position_in_grid]]) {
  output[index] = c10::i0(static_cast<Tout>(input[index]));
}
```
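A rough sketch of the header-embedding step, under the assumption that the tool simply resolves `#include` directives recursively against a set of search directories (the `embed_includes` name, regex, and search-path handling are illustrative, not the actual implementation):
```python
import re
from pathlib import Path

INCLUDE = re.compile(r'^\s*#include\s+[<"]([^">]+)[">]')

def embed_includes(src: Path, search_dirs: list[Path], seen: set[str] | None = None) -> str:
    """Recursively inline #include directives so the shader can be handed to
    the runtime compiler as a single self-contained source string."""
    seen = set() if seen is None else seen
    out = []
    for line in src.read_text().splitlines():
        m = INCLUDE.match(line)
        if m is None:
            out.append(line)
            continue
        name = m.group(1)
        if name in seen:
            continue  # embed every header at most once
        seen.add(name)
        for d in search_dirs:
            candidate = d / name
            if candidate.exists():
                out.append(embed_includes(candidate, search_dirs, seen))
                break
        else:
            out.append(line)  # not found locally: keep the directive as-is
    return "\n".join(out)
```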
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145087
Approved by: https://github.com/dcci
ghstack dependencies: #145023
2025-01-18 05:35:22 +00:00