[Inductor][CPP] Pass weight dtype explicitly for cpp gemm template (#129221)

**Summary**
This PR mainly refactors two things:

1. Pass the weight's data type explicitly to `create_micro_gemm` as `input2.dtype`. When registering a `CppMicroGemmConfig`, `input.dtype` is reused if `input2.dtype` is not explicitly registered.
2. Add a utility function that returns the output data type and compute data type for a given input data type.
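
The fallback described in item 1 can be sketched as follows. This is a minimal illustration, not the code from this PR: `MicroGemmConfigSketch`, `resolved_input2_dtype`, and `select_config` are hypothetical names, dtypes are plain strings so the snippet runs without `torch`, and the `uint8` activation with `int8` weight pairing is only an assumed example.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class MicroGemmConfigSketch:
    # Hypothetical stand-in for CppMicroGemmConfig; field names are assumptions.
    input_dtype: str
    output_dtype: str
    compute_dtype: str
    input2_dtype: Optional[str] = None  # weight dtype; None means "not registered"

    def resolved_input2_dtype(self) -> str:
        # Reuse the activation dtype when no weight dtype was registered.
        return self.input_dtype if self.input2_dtype is None else self.input2_dtype


def select_config(configs, input_dtype, input2_dtype):
    # Return the first config whose activation and weight dtypes both match.
    for config in configs:
        if (
            config.input_dtype == input_dtype
            and config.resolved_input2_dtype() == input2_dtype
        ):
            return config
    return None


configs = [
    # No weight dtype registered: falls back to "float32".
    MicroGemmConfigSketch("float32", "float32", "float32"),
    # Weight dtype registered explicitly (assumed pairing, for illustration).
    MicroGemmConfigSketch("uint8", "int32", "int32", input2_dtype="int8"),
]
```

With these two entries, a lookup for `("float32", "float32")` matches the first config through the fallback, while `("uint8", "uint8")` matches nothing because the second config registered `int8` as its weight dtype.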

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129221
Approved by: https://github.com/jgong5, https://github.com/jansel
ghstack dependencies: #128825, #129048, #129049, #129103, #129220
leslie-fang-intel
2024-06-27 21:41:47 -07:00
committed by PyTorch MergeBot
parent 72fa864098
commit b6379591a9
4 changed files with 50 additions and 14 deletions

@@ -441,3 +441,10 @@ def unify_mask_base_type(
        for var in vars
    )
    return new_vars


def get_gemm_template_output_and_compute_dtype(input_dtype):
    if input_dtype == torch.uint8:
        return (torch.int32, torch.int32)
    else:
        return (torch.float32, torch.float32)