Files
pytorch/tools/autograd/utils.py
anderspapitto fcd9af8a25 changes to support ATen code generation inside fbcode (#8397)
* Back out "Back out "Add support for generating ATen files during fbcode build""

Original commit changeset: 7b8de22d1613

I'm re-sending this diff exactly as it was approved and
committed. Fixes to support @mode/opt will be sent separately for ease
of review.

* Enable building //caffe2:torch with @mode/opt

In @mode/opt, Python runs out of a PAR, which breaks a lot of
assumptions in the code about where templates/ folders live relative
to __file__. Rather than introduce hacks with parutil, I simply turn
template_path into a parameter for all the relevant functions and
thread it through from the top level.
2018-06-12 14:57:29 -07:00

66 lines
2.0 KiB
Python

import re
import os
from .nested_dict import nested_dict
# Public names of this module, i.e. what ``from ... import *`` re-exports.
# ``uninplace_api_name`` is a public helper defined in this module as well,
# so it is listed alongside the other helpers.
__all__ = [
    'CodeTemplate', 'IDENT_REGEX', 'YamlLoader', 'nested_dict',
    'split_name_params', 'uninplace_api_name', 'write',
]
# Obtain CodeTemplate. First try importing it as a package module
# (src.ATen.code_template). If that import path is unavailable, fall back to
# tools.shared.module_loader.import_module, which presumably loads the module
# from the given file path. NOTE(review): that path is relative, so the
# fallback likely assumes the working directory is the repository root.
try:
    from src.ATen.code_template import CodeTemplate
except ImportError:
    from tools.shared.module_loader import import_module
    CodeTemplate = import_module('code_template', 'aten/src/ATen/code_template.py').CodeTemplate
# Pick the YAML loader class exposed as YamlLoader.
# NOTE(review): Loader/CLoader are PyYAML's full loaders, not SafeLoader.
# That is fine for trusted in-repo YAML but should not be used on untrusted input.
try:
    # use faster C loader if available
    from yaml import CLoader as YamlLoader
except ImportError:
    # C loader not available; fall back to the default (non-C) Loader
    from yaml import Loader as YamlLoader
# Header template whose expansion write() stores in env['generated_comment'].
# ${filename} receives the name of the file being written.
# The marker word is spelled as two adjacent string literals so that this
# source file never contains the contiguous tag itself. This is presumably so
# tools that scan for that tag do not treat this file as generated.
GENERATED_COMMENT = CodeTemplate(
    "@" "generated from tools/autograd/templates/${filename}"
)
# Regex template that matches an identifier only as a whole word. For example,
# it matches "foo" in "foo, bar" but not in "foobar". It is used to search for
# the occurrence of a parameter in the derivative formula. The "{}" slot is
# presumably filled via str.format(); escape the inserted name first if it
# could contain regex metacharacters.
IDENT_REGEX = '(^|\\W){}($|\\W)'
# TODO: Use a real parser here; this will get bamboozled
# by signatures that contain things like std::array<bool, 2> (note the space)
def split_name_params(prototype):
    """Split a prototype such as ``"foo(Tensor self, int dim)"`` into parts.

    Returns a ``(name, params)`` tuple:

    - ``name`` is the leading identifier.
    - ``params`` is the text between the first ``(`` and the last ``)``, split
      on ``", "``. Note that ``"foo()"`` therefore yields ``("foo", [""])``.

    Raises:
        ValueError: if ``prototype`` does not start with ``identifier(...)``.
    """
    # Raw string: "\w" and "\(" are invalid escapes in a plain string literal
    # and emit DeprecationWarning (SyntaxWarning on Python 3.12+).
    match = re.match(r'(\w+)\((.*)\)', prototype)
    if match is None:
        # Previously this surfaced as an opaque AttributeError on None.groups().
        raise ValueError(
            "expected a prototype of the form 'name(params)', got {!r}".format(prototype))
    name, params = match.groups()
    return name, params.split(', ')
# Tracing records in-place operations as their out-of-place counterparts,
# because the IR does not yet have a story for side effects.
#
# The un-inplacing below is delicate: a trailing underscore usually means
# "in-place", but dunder names such as __and__ are NOT in-place!
# TODO: Do something more robust
def uninplace_api_name(api_name):
    """Return ``api_name`` with a single trailing in-place underscore removed.

    ``add_`` becomes ``add``. Names ending in a double underscore (for example
    ``__and__``) and names without a trailing underscore are returned as-is.
    """
    has_inplace_suffix = api_name.endswith('_') and not api_name.endswith('__')
    return api_name[:-1] if has_inplace_suffix else api_name
def write(dirname, name, template, env):
    """Render ``template`` with ``env`` into the file ``dirname/name``.

    Side effect: stores the expanded GENERATED_COMMENT header for ``name`` in
    ``env['generated_comment']``, mutating the caller's mapping.

    The target file is rewritten only when the rendered text differs from
    what is already on disk. Otherwise the file is left untouched (see Note
    [Unchanging results for ninja]). A line is printed saying whether the file
    was written or skipped.
    """
    env['generated_comment'] = GENERATED_COMMENT.substitute(filename=name)
    path = os.path.join(dirname, name)
    # See Note [Unchanging results for ninja]
    # A file that cannot be opened/read counts as having no previous contents.
    try:
        with open(path, 'r') as existing:
            previous = existing.read()
    except IOError:
        previous = None
    rendered = template.substitute(env)
    if previous == rendered:
        print("Skipped writing {}".format(path))
        return
    with open(path, 'w') as out:
        print("Writing {}".format(path))
        out.write(rendered)