mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
This is the continuation to #90134 and hopefully the final PR in this series. Pull Request resolved: https://github.com/pytorch/pytorch/pull/90271 Approved by: https://github.com/kit1980
112 lines
4.5 KiB
Python
112 lines
4.5 KiB
Python
import copy
|
|
import os
|
|
|
|
from collections import OrderedDict
|
|
|
|
import yaml
|
|
from torchgen.code_template import CodeTemplate
|
|
from yaml.constructor import ConstructorError
|
|
from yaml.nodes import MappingNode
|
|
|
|
try:
|
|
from yaml import CLoader as Loader
|
|
except ImportError:
|
|
from yaml import Loader # type: ignore[misc]
|
|
|
|
# https://gist.github.com/pypt/94d747fe5180851196eb
|
|
class UniqueKeyLoader(Loader):
    """YAML loader that refuses mappings containing the same key twice.

    ``Loader`` is whichever loader the module-level import resolved to
    (``CLoader`` when available, otherwise the pure-Python ``Loader``).
    """

    def construct_mapping(self, node, deep=False):  # type: ignore[no-untyped-def]
        """Build a dict from ``node``, failing on unhashable or repeated keys.

        Args:
            node: The YAML node to convert; must be a ``MappingNode``.
            deep: Forwarded to ``construct_object`` for every key and value.

        Returns:
            A dict holding the constructed keys and values in document order.

        Raises:
            ConstructorError: If ``node`` is not a mapping node, if a key
                cannot be hashed, or if a key occurs more than once.
        """
        if not isinstance(node, MappingNode):
            raise ConstructorError(
                None,
                None,
                f"expected a mapping node, but found {node.id}",
                node.start_mark,
            )

        mapping = {}
        for key_node, value_node in node.value:
            key = self.construct_object(key_node, deep=deep)  # type: ignore[no-untyped-call]
            try:
                hash(key)
            except TypeError as e:
                # Include the reason reported by hash() so the user can see
                # why the key was rejected, not just where.
                raise ConstructorError(
                    "while constructing a mapping",
                    node.start_mark,
                    f"found unacceptable key ({e})",
                    key_node.start_mark,
                ) from e
            # A plain dict would silently keep only the last value of a
            # repeated key; treat the repetition as an error instead.
            if key in mapping:
                raise ConstructorError(
                    "while constructing a mapping",
                    node.start_mark,
                    "found duplicate key",
                    key_node.start_mark,
                )
            mapping[key] = self.construct_object(value_node, deep=deep)  # type: ignore[no-untyped-call]
        return mapping
|
|
|
|
|
|
class GLSLGenerator:
    """Expand ``.glslt`` shader templates into concrete ``.glsl`` files.

    Parameter sets are registered per op, either from a YAML file
    (``add_params_yaml``) or from an already-loaded mapping
    (``validate_and_construct_op_params``). ``generate`` then writes one
    output file for every parameter set registered for the template's op.
    """

    # Prepended verbatim to every generated file. The ``$precision`` and
    # ``$format`` placeholders are NOT substituted by this class; presumably a
    # later build step fills them in -- TODO confirm.
    standard_header = """
#version 450 core
#define PRECISION $precision
#define FORMAT $format

"""

    def __init__(self):  # type: ignore[no-untyped-def]
        """Create a generator with no registered ops."""
        # Maps op name -> list of parameter dicts; element 0 is the defaults,
        # the rest are the defaults with one set of overrides applied.
        self.ops_template_params = {}

    def add_params_yaml(self, parameters_yaml_file):  # type: ignore[no-untyped-def]
        """Load op parameter definitions from a YAML file and register them.

        Args:
            parameters_yaml_file: Path of the YAML file to read. Its top level
                must be a mapping from op name to that op's parameter spec.
                An empty file registers nothing.

        Raises:
            TypeError: If the top level of the YAML document is not a mapping.
            KeyError: Propagated from ``validate_and_construct_op_params``.
        """
        with open(parameters_yaml_file) as f:
            # NOTE(review): UniqueKeyLoader derives from yaml's Loader/CLoader,
            # not SafeLoader, so only trusted parameter files should be passed
            # here.
            contents = yaml.load(f, Loader=UniqueKeyLoader)

        # yaml.load returns None for an empty document.
        if contents is None:
            contents = {}
        if not isinstance(contents, dict):
            raise TypeError(
                f"expected a mapping at the top level of {parameters_yaml_file}, "
                f"but found {type(contents).__name__}"
            )
        self.validate_and_construct_op_params(OrderedDict(contents))  # type: ignore[no-untyped-call]

    def validate_and_construct_op_params(self, all_template_params):  # type: ignore[no-untyped-def]
        """Validate per-op parameter specs and record the expanded variants.

        Args:
            all_template_params: Mapping from op name to a spec containing
                ``"parameter_names_with_default_values"`` (a dict of defaults)
                and ``"parameter_values"`` (an iterable of override dicts).

        Raises:
            KeyError: If an op has already been registered, or if an override
                dict uses a key that is not among the op's default parameters.
        """
        for op, op_spec in all_template_params.items():
            if op in self.ops_template_params:
                raise KeyError(f"{op} params file has already been parsed")

            defaults = op_spec["parameter_names_with_default_values"]
            allowed_keys = set(defaults.keys())

            # Build the variant list locally and publish it only after every
            # override has been validated, so a rejected spec does not leave
            # the op half-registered (which would block a corrected retry).
            variants = [defaults]
            for overrides in op_spec["parameter_values"]:
                invalid_keys = set(overrides.keys()) - allowed_keys
                if invalid_keys:
                    raise KeyError(f"Invalid keys {invalid_keys} are found")
                variant = copy.deepcopy(defaults)
                variant.update(overrides)
                variants.append(variant)
            self.ops_template_params[op] = variants

    def generate(self, glsl_template_in, out_dir):  # type: ignore[no-untyped-def]
        """Write one ``.glsl`` file per registered parameter set of a template.

        The op name is the template's file name without its ``.glslt``
        extension. Each output is named ``<op>_<v1>x<v2>x....glsl``, where the
        ``v``s are the parameter values in dict order, and consists of
        ``standard_header`` followed by the substituted template.

        Args:
            glsl_template_in: Path to the ``.glslt`` template file.
            out_dir: Directory the generated files are written into.

        Raises:
            TypeError: If the template file does not have a ``glslt`` extension.
            KeyError: If no parameters were registered for the template's op.
        """
        glsl_template_name = os.path.basename(glsl_template_in)
        # splitext splits on the last dot only, so names with extra dots or no
        # dot at all reach the explicit checks below instead of failing while
        # unpacking.
        op_name, dotted_extension = os.path.splitext(glsl_template_name)
        extension_name = dotted_extension[1:]
        if extension_name != "glslt":
            raise TypeError(f"invalid file type for glsl template {extension_name}")
        if op_name not in self.ops_template_params:
            raise KeyError(f"{op_name} params have not been populated")

        code_template = CodeTemplate.from_file(glsl_template_in)
        for template_params in self.ops_template_params[op_name]:
            param_vals_string = "x".join(str(v) for v in template_params.values())
            output_file_name = op_name + "_" + param_vals_string + ".glsl"
            content = GLSLGenerator.standard_header + code_template.substitute(
                template_params
            )
            output_file = os.path.join(out_dir, output_file_name)
            with open(output_file, "w") as f:
                f.write(content)
|
|
|
|
|
|
# Placeholder script entry point: running this module directly does nothing.
# TODO: remove this guard.
if __name__ == "__main__":
    pass
|