Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Resubmit: Set the correct engine name for position weighted pooling when fp16 is used for training
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/13768
Reviewed By: xianjiec
Differential Revision: D12996103
fbshipit-source-id: 5ca4cda4210f68ece2b5d6eced8cf52ee91fb36f
Committed by: Facebook Github Bot
Parent: ae1b37650c
Commit: 0199d59d3a
@@ -173,6 +173,8 @@ class SparseLookup(ModelLayer):
                 "Train version {} is not currently supported".format(trainer_version)
             )
 
+        self.trainer_version = trainer_version
+
         return default_weight_init
 
     def _gather_wrapper(self, net, version, in_indices, out):
@@ -215,11 +217,22 @@ class SparseLookup(ModelLayer):
         if version in ['fp32', 'fp16']:
             # SparseLengths* Ops will accept either fp16 or fp32 embedding
             # matrix and output fp32 pooled embedding
-            net.__getattr__(layer_name)(
-                op_input,
-                self.output_schema.field_blobs(),
-                grad_on_weights=grad_on_weights,
-            )
+            # A special case here is that we need FP16 engine for
+            # SparseLengthsWeightedSum when FP16 embeedings are used for
+            # correct backward updates
+            if reducer == "WeightedSum" and version == "fp16":
+                net.SparseLengthsWeightedSum(
+                    op_input,
+                    self.output_schema.field_blobs(),
+                    grad_on_weights=grad_on_weights,
+                    engine='FP16',
+                )
+            else:
+                net.__getattr__(layer_name)(
+                    op_input,
+                    self.output_schema.field_blobs(),
+                    grad_on_weights=grad_on_weights,
+                )
         elif version == 'uint8rowwise':
             op_input.insert(len(op_input), self.scale_bias)
             net.__getattr__(layer_name + '8BitsRowwise')(
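
To make the dispatch in the hunk above easier to follow, the sketch below restates it as a small standalone Python helper. The helper name select_pooling_op and its (op name, engine) return convention are illustrative assumptions, not part of the Caffe2 layer; only the SparseLengthsWeightedSum op name and the 'FP16' engine string come from the diff itself.

# Illustrative sketch only: mirrors the reducer/version dispatch from the hunk
# above. "select_pooling_op" is a hypothetical helper, not a Caffe2 API.
def select_pooling_op(layer_name, reducer, version):
    """Return (op_name, engine) for the fp32/fp16 pooled embedding paths."""
    assert version in ('fp32', 'fp16'), "only the dense float paths are modeled"
    # FP16 embeddings with position-weighted pooling need the FP16 engine so
    # the backward pass applies the weight-gradient updates correctly.
    if reducer == "WeightedSum" and version == "fp16":
        return "SparseLengthsWeightedSum", "FP16"
    # Every other case keeps the default engine for the SparseLengths* op.
    return layer_name, ""

print(select_pooling_op("SparseLengthsWeightedSum", "WeightedSum", "fp16"))
# ('SparseLengthsWeightedSum', 'FP16')
print(select_pooling_op("SparseLengthsSum", "Sum", "fp32"))
# ('SparseLengthsSum', '')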
@@ -345,6 +358,17 @@ class SparseLookup(ModelLayer):
             raise "Only Sum, Mean, None are supported for IdScoreList input." +\
                 "Trying to create with {}".format(self.reducer)
 
+    def _add_ops(self, net, version='fp32'):
+        if _is_id_list(self.input_record):
+            self._add_ops_id_list(net, version=version)
+        elif _is_id_score_list(self.input_record):
+            self._add_ops_id_score_list(net, version=version)
+        else:
+            raise "Unsupported input type {0}".format(self.input_record)
+
+    def add_train_ops(self, net):
+        self._add_ops(net, self.trainer_version)
+
     def add_ops(self, net):
         cur_scope = get_current_scope()
         version = get_sparse_lookup_predictor_version(
@@ -357,9 +381,4 @@ class SparseLookup(ModelLayer):
                 'fused_uint8rowwise'}:
             version = 'fp32'
 
-        if _is_id_list(self.input_record):
-            self._add_ops_id_list(net, version=version)
-        elif _is_id_score_list(self.input_record):
-            self._add_ops_id_score_list(net, version=version)
-        else:
-            raise "Unsupported input type {0}".format(self.input_record)
+        self._add_ops(net, version)
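
For orientation, here is a minimal sketch of the control flow after the last two hunks: both the training and prediction paths now funnel through a single _add_ops dispatch. The class and stub names below are simplified stand-ins, not the real Caffe2 layer.

# Simplified stand-in for the refactored dispatch; not the real Caffe2 layer.
class SparseLookupSketch:
    def __init__(self, input_kind, trainer_version):
        self.input_kind = input_kind            # 'id_list' or 'id_score_list'
        self.trainer_version = trainer_version  # stored once at construction

    def _add_ops(self, net, version='fp32'):
        # Single dispatch point shared by training and prediction.
        if self.input_kind == 'id_list':
            return ('id_list ops', version)
        elif self.input_kind == 'id_score_list':
            return ('id_score_list ops', version)
        raise ValueError("Unsupported input type {0}".format(self.input_kind))

    def add_train_ops(self, net):
        # Training uses the trainer version chosen in __init__ (e.g. 'fp16').
        return self._add_ops(net, self.trainer_version)

    def add_ops(self, net, predictor_version='fp32'):
        # Prediction uses the (possibly downgraded) predictor version.
        return self._add_ops(net, predictor_version)

print(SparseLookupSketch('id_score_list', 'fp16').add_train_ops(net=None))
# ('id_score_list ops', 'fp16')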