[caffe2] fix type and shape inference for common gradient ops (#35857)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/35857

This fixes shape and type inference (InferBlobShapesAndTypes) for a number of common ops and adds support for testing the inferred shapes and types of gradient ops.

Ops:
* Concat
* Split
* LeakyReLU
* Relu
* Prelu
* Gelu
* Elu
* Sinh, Tanh, Cosh
* Abs
* ... and a number of other simple element-wise ops
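
For context, here is a minimal sketch (not part of this change; the blob names and sizes are made up) of how the fixed inference can be exercised from Python for one of the ops above:

  # Illustrative only: query shape/type inference for a small Concat net.
  import numpy as np
  from caffe2.python import core, workspace

  workspace.FeedBlob("A", np.random.randn(2, 3).astype(np.float32))
  workspace.FeedBlob("B", np.random.randn(2, 5).astype(np.float32))

  net = core.Net("concat_example")
  net.Concat(["A", "B"], ["C", "split_info"], axis=1)

  # InferShapesAndTypes wraps InferBlobShapesAndTypes; with the fixes here it
  # should report C as shape [2, 8] with type TensorProto.FLOAT.
  shapes, types = workspace.InferShapesAndTypes([net])
  print(shapes["C"], types["C"])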

Test Plan:
Added support to the hypothesis tests for checking the shapes and types of gradient op outputs.

Enabled it for all the ops I fixed the shape and type inference for.

  buck test caffe2/caffe2/python/operator_test:
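
As an illustration, a hedged sketch of how an operator test opts in; this assumes hypothesis_test_util.assertGradientChecks forwards the new ensure_outputs_are_inferred flag added by this change, and the op and shapes are made up:

  import numpy as np
  from hypothesis import given
  import caffe2.python.hypothesis_test_util as hu
  from caffe2.python import core

  class TestTanhGradientInference(hu.HypothesisTestCase):
      @given(X=hu.tensor(dtype=np.float32), **hu.gcs)
      def test_tanh_grad_shape_inference(self, X, gc, dc):
          op = core.CreateOperator("Tanh", ["X"], ["Y"])
          # Numerical gradient check plus an assertion that the gradient op's
          # output shapes and types are inferred correctly.
          self.assertGradientChecks(
              gc, op, [X], 0, [0],
              ensure_outputs_are_inferred=True,
          )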

Reviewed By: pradeepd24

Differential Revision: D20806284

fbshipit-source-id: 77f796d9ff208e09e871bdbadf9a0a7c196b77f2
Author: Tristan Rice
Date: 2020-04-02 11:12:27 -07:00
Committed by: Facebook GitHub Bot
Commit: 676fc929b7 (parent c4f56e9685)
18 changed files with 270 additions and 107 deletions

caffe2/python/gradient_checker.py

@@ -233,7 +233,8 @@ class GradientChecker:
         input_to_check,
         outputs_with_grads,
         grad_ops=None,
-        input_device_options=None
+        input_device_options=None,
+        ensure_outputs_are_inferred=False,
     ):
         """Checks the operator in a very simple fashion by stacking a sum of
         squares on the top.
@@ -250,6 +251,8 @@
               gradient operator from the gradient registry.
           input_device_options: an optional mapping from input names to
               DeviceOptions (to override the default DeviceOption)
+          ensure_outputs_are_inferred: if set will assert that the gradient output
+              shapes matches the inferred shapes
         Outputs:
           boolean: True if it passes, False if it does not pass.
         """
@@ -278,7 +281,7 @@
         grad_name = g_input[input_to_check]
         loss, grad = self.GetLossAndGrad(
             op, grad_ops, inputs, op.input, input_to_check, grad_name,
-            outputs_with_grads
+            outputs_with_grads,
         )
         grad_estimate = np.zeros_like(inputs[input_to_check])
         if grad_estimate.shape != grad.shape:
@@ -286,6 +289,9 @@
                 "Mismatched gradient shapes: estimated ({}), grad ({})".format(
                     grad_estimate.shape, grad.shape))
 
+        if ensure_outputs_are_inferred:
+            self._assertInferTensorChecks(op, grad_ops)
+
         dims_to_check = inputs[input_to_check].size
         for current_dim in range(dims_to_check):
             # Positive gradient
@@ -322,3 +328,47 @@
         workspace.ResetWorkspace()
         workspace.SwitchWorkspace(old_ws_name)
         return ret, grad, grad_estimate
+
+    def _assertInferTensorChecks(self, op, grad_ops):
+        tmp_net = caffe2_pb2.NetDef()
+        tmp_net.op.extend([op])
+        tmp_net.op.extend(grad_ops)
+        inferred_shapes, inferred_types = workspace.InferShapesAndTypes(
+            [tmp_net],
+            nets_proto=True,
+        )
+
+        outputs = set()
+        for grad_op in grad_ops:
+            outputs.update(grad_op.output)
+
+        for output in outputs:
+            if output not in inferred_shapes:
+                raise Exception(
+                    "expected output {} to be inferred".format(output))
+            blob = workspace.FetchBlob(output)
+            correct_shape = list(blob.shape)
+            inferred_shape = list(inferred_shapes[output])
+            if correct_shape != inferred_shape:
+                raise Exception(
+                    "Mismatched inferred shape: want({}), got({})".format(
+                        correct_shape, inferred_shape))
+            if type(blob) is np.ndarray:
+                if blob.dtype == np.dtype('float64'):
+                    correct_type = caffe2_pb2.TensorProto.DOUBLE
+                elif blob.dtype == np.dtype('float32'):
+                    correct_type = caffe2_pb2.TensorProto.FLOAT
+                elif blob.dtype == np.dtype('int32'):
+                    correct_type = caffe2_pb2.TensorProto.INT32
+                elif blob.dtype == np.dtype('int64'):
+                    correct_type = caffe2_pb2.TensorProto.INT64
+                else:
+                    correct_type = "unknown {}".format(np.dtype)
+            else:
+                correct_type = str(type(blob))
+
+            inferred_type = inferred_types[output]
+            if correct_type != inferred_type:
+                raise Exception(
+                    "Mismatched inferred type: want({}), got({})".format(
+                        correct_type, inferred_type))
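
For reference, a minimal usage sketch of the new flag on GradientChecker.CheckSimple (illustrative, not taken from the diff; the op, blob names and sizes are made up):

  import numpy as np
  from caffe2.python import core, gradient_checker

  op = core.CreateOperator("Tanh", ["X"], ["Y"])
  X = np.random.randn(3, 5).astype(np.float32)

  checker = gradient_checker.GradientChecker(stepsize=0.05, threshold=0.005)
  # With ensure_outputs_are_inferred=True, CheckSimple also calls
  # _assertInferTensorChecks to compare the inferred shape/type of every
  # gradient output against the blobs actually produced.
  passed, grad, grad_estimate = checker.CheckSimple(
      op, [X], 0, [0],
      ensure_outputs_are_inferred=True,
  )
  assert passed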