Resubmit: Set the correct engine name for position weighted pooling when fp16 is used for training

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/13768

Reviewed By: xianjiec

Differential Revision: D12996103

fbshipit-source-id: 5ca4cda4210f68ece2b5d6eced8cf52ee91fb36f
This commit is contained in:
Jiyan Yang
2018-11-27 14:49:28 -08:00
committed by Facebook Github Bot
parent ae1b37650c
commit 0199d59d3a

View File

@ -173,6 +173,8 @@ class SparseLookup(ModelLayer):
"Train version {} is not currently supported".format(trainer_version)
)
self.trainer_version = trainer_version
return default_weight_init
def _gather_wrapper(self, net, version, in_indices, out):
@ -215,11 +217,22 @@ class SparseLookup(ModelLayer):
if version in ['fp32', 'fp16']:
# SparseLengths* Ops will accept either fp16 or fp32 embedding
# matrix and output fp32 pooled embedding
net.__getattr__(layer_name)(
op_input,
self.output_schema.field_blobs(),
grad_on_weights=grad_on_weights,
)
# A special case here is that we need FP16 engine for
# SparseLengthsWeightedSum when FP16 embeedings are used for
# correct backward updates
if reducer == "WeightedSum" and version == "fp16":
net.SparseLengthsWeightedSum(
op_input,
self.output_schema.field_blobs(),
grad_on_weights=grad_on_weights,
engine='FP16',
)
else:
net.__getattr__(layer_name)(
op_input,
self.output_schema.field_blobs(),
grad_on_weights=grad_on_weights,
)
elif version == 'uint8rowwise':
op_input.insert(len(op_input), self.scale_bias)
net.__getattr__(layer_name + '8BitsRowwise')(
@ -345,6 +358,17 @@ class SparseLookup(ModelLayer):
raise "Only Sum, Mean, None are supported for IdScoreList input." +\
"Trying to create with {}".format(self.reducer)
def _add_ops(self, net, version='fp32'):
if _is_id_list(self.input_record):
self._add_ops_id_list(net, version=version)
elif _is_id_score_list(self.input_record):
self._add_ops_id_score_list(net, version=version)
else:
raise "Unsupported input type {0}".format(self.input_record)
def add_train_ops(self, net):
self._add_ops(net, self.trainer_version)
def add_ops(self, net):
cur_scope = get_current_scope()
version = get_sparse_lookup_predictor_version(
@ -357,9 +381,4 @@ class SparseLookup(ModelLayer):
'fused_uint8rowwise'}:
version = 'fp32'
if _is_id_list(self.input_record):
self._add_ops_id_list(net, version=version)
elif _is_id_score_list(self.input_record):
self._add_ops_id_score_list(net, version=version)
else:
raise "Unsupported input type {0}".format(self.input_record)
self._add_ops(net, version)