Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-30 19:54:53 +08:00)
[Inductor][CPP] Pass weight dtype explicitly for cpp gemm template (#129221)
**Summary**

This PR mainly refactors two things:

1. Pass the weight's data type explicitly to `create_micro_gemm` as `input2.dtype`. When registering a `CppMicroGemmConfig`, `input.dtype` is reused if `input2.dtype` is not explicitly registered.
2. Add a utility function that derives the output data type and the compute data type from the input data type.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129221
Approved by: https://github.com/jgong5, https://github.com/jansel
ghstack dependencies: #128825, #129048, #129049, #129103, #129220
This commit is contained in:
commit b6379591a9 (parent 72fa864098), committed by PyTorch MergeBot
@@ -441,3 +441,10 @@ def unify_mask_base_type(
         for var in vars
     )
     return new_vars
+
+
+def get_gemm_template_output_and_compute_dtype(input_dtype):
+    if input_dtype == torch.uint8:
+        return (torch.int32, torch.int32)
+    else:
+        return (torch.float32, torch.float32)
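The helper added by the diff above is self-contained, so its behavior can be sketched in isolation. The function body below is taken verbatim from the diff; the calls at the bottom are illustrative only. For `torch.uint8` inputs (quantized GEMM) the template both accumulates and outputs in `int32`; any other input dtype falls back to `float32` for output and compute.

```python
import torch

# Helper from the diff above: map a GEMM input dtype to a
# (output_dtype, compute_dtype) pair for the CPP GEMM template.
def get_gemm_template_output_and_compute_dtype(input_dtype):
    if input_dtype == torch.uint8:
        # uint8 (quantized) GEMM accumulates and outputs in int32.
        return (torch.int32, torch.int32)
    else:
        # All other input dtypes compute and output in fp32.
        return (torch.float32, torch.float32)

# Illustrative usage (not part of the PR):
print(get_gemm_template_output_and_compute_dtype(torch.uint8))
print(get_gemm_template_output_and_compute_dtype(torch.bfloat16))
```

Note that the fallback branch covers `float32`, `bfloat16`, and `float16` inputs alike: lower-precision floating-point GEMMs still compute and emit results in `float32`.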