mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-29 19:24:55 +08:00
[caffe2] fix type and shape inference for common gradient ops (#35857)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/35857 This fixes a lot of common ops for InferBlobShapesAndTypes as well as adds support for testing the inferred shapes and types of gradient ops. Ops: * Concat * Split * LeakyReLU * Relu * Prelu * Gelu * Elu * Sinh, Tanh, Cosh * Abs * ... and a number of other simple element wise ops Test Plan: Added support to hypothesis test to check the shape and type of gradient ops. Enabled it for all the ops I fixed the shape and type inference for. buck test caffe2/caffe2/python/operator_test: Reviewed By: pradeepd24 Differential Revision: D20806284 fbshipit-source-id: 77f796d9ff208e09e871bdbadf9a0a7c196b77f2
This commit is contained in:
committed by
Facebook GitHub Bot
parent
c4f56e9685
commit
676fc929b7
@ -233,7 +233,8 @@ class GradientChecker:
|
||||
input_to_check,
|
||||
outputs_with_grads,
|
||||
grad_ops=None,
|
||||
input_device_options=None
|
||||
input_device_options=None,
|
||||
ensure_outputs_are_inferred=False,
|
||||
):
|
||||
"""Checks the operator in a very simple fashion by stacking a sum of
|
||||
squares on the top.
|
||||
@ -250,6 +251,8 @@ class GradientChecker:
|
||||
gradient operator from the gradient registry.
|
||||
input_device_options: an optional mapping from input names to
|
||||
DeviceOptions (to override the default DeviceOption)
|
||||
ensure_outputs_are_inferred: if set will assert that the gradient output
|
||||
shapes matches the inferred shapes
|
||||
Outputs:
|
||||
boolean: True if it passes, False if it does not pass.
|
||||
"""
|
||||
@ -278,7 +281,7 @@ class GradientChecker:
|
||||
grad_name = g_input[input_to_check]
|
||||
loss, grad = self.GetLossAndGrad(
|
||||
op, grad_ops, inputs, op.input, input_to_check, grad_name,
|
||||
outputs_with_grads
|
||||
outputs_with_grads,
|
||||
)
|
||||
grad_estimate = np.zeros_like(inputs[input_to_check])
|
||||
if grad_estimate.shape != grad.shape:
|
||||
@ -286,6 +289,9 @@ class GradientChecker:
|
||||
"Mismatched gradient shapes: estimated ({}), grad ({})".format(
|
||||
grad_estimate.shape, grad.shape))
|
||||
|
||||
if ensure_outputs_are_inferred:
|
||||
self._assertInferTensorChecks(op, grad_ops)
|
||||
|
||||
dims_to_check = inputs[input_to_check].size
|
||||
for current_dim in range(dims_to_check):
|
||||
# Positive gradient
|
||||
@ -322,3 +328,47 @@ class GradientChecker:
|
||||
workspace.ResetWorkspace()
|
||||
workspace.SwitchWorkspace(old_ws_name)
|
||||
return ret, grad, grad_estimate
|
||||
|
||||
def _assertInferTensorChecks(self, op, grad_ops):
|
||||
tmp_net = caffe2_pb2.NetDef()
|
||||
tmp_net.op.extend([op])
|
||||
tmp_net.op.extend(grad_ops)
|
||||
inferred_shapes, inferred_types = workspace.InferShapesAndTypes(
|
||||
[tmp_net],
|
||||
nets_proto=True,
|
||||
)
|
||||
|
||||
outputs = set()
|
||||
for grad_op in grad_ops:
|
||||
outputs.update(grad_op.output)
|
||||
|
||||
for output in outputs:
|
||||
if output not in inferred_shapes:
|
||||
raise Exception(
|
||||
"expected output {} to be inferred".format(output))
|
||||
blob = workspace.FetchBlob(output)
|
||||
correct_shape = list(blob.shape)
|
||||
inferred_shape = list(inferred_shapes[output])
|
||||
if correct_shape != inferred_shape:
|
||||
raise Exception(
|
||||
"Mismatched inferred shape: want({}), got({})".format(
|
||||
correct_shape, inferred_shape))
|
||||
|
||||
if type(blob) is np.ndarray:
|
||||
if blob.dtype == np.dtype('float64'):
|
||||
correct_type = caffe2_pb2.TensorProto.DOUBLE
|
||||
elif blob.dtype == np.dtype('float32'):
|
||||
correct_type = caffe2_pb2.TensorProto.FLOAT
|
||||
elif blob.dtype == np.dtype('int32'):
|
||||
correct_type = caffe2_pb2.TensorProto.INT32
|
||||
elif blob.dtype == np.dtype('int64'):
|
||||
correct_type = caffe2_pb2.TensorProto.INT64
|
||||
else:
|
||||
correct_type = "unknown {}".format(np.dtype)
|
||||
else:
|
||||
correct_type = str(type(blob))
|
||||
inferred_type = inferred_types[output]
|
||||
if correct_type != inferred_type:
|
||||
raise Exception(
|
||||
"Mismatched inferred type: want({}), got({})".format(
|
||||
correct_type, inferred_type))
|
||||
|
||||
Reference in New Issue
Block a user