Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-21 05:34:18 +08:00)
* Adding instance weight to batch distill loss (as title)
* Add bfloat 16-31 and their respective unit tests
* [CUDA9] Upgrade - fbcode CUDA9 upgrade. Diff D5654023 has been out for a while thanks to Pieter, but with time it has become quite hard to rebase because of the symlinks and auto-generated build/config files in tp2. Break D5654023 into two diffs: one touching tp2 config files, and another touching the fbcode TARGETS file (adding the nvcc flag). These two should be a bit easier to rebase (for the detailed procedure see "Test Plan"). This diff can only be committed if: 1. the CUDA 9 rpm is rolled out fleet-wide (TBD); 2. NVIDIA driver 390.40 is rolled out fleet-wide (done); 3. upgrade to CUDA 9.1, cuDNN 7.1, NCCL 2.1 (done); 4. all dependents are built (done); 5. all C2 operators and PyTorch are tested (see test plan)
* Share intermediate int32 buffer across Conv ops, adding a known type
* [C2 fix] Infer function for ensure_cpu_output_op: adds the missing device function for ensure_cpu_output_op
* [int8] Add blob serializer/deserializer for Int8TensorCPU, to export to logfiledb
* [nomnigraph] Add try/catch block to optimization passes in predictor; this will catch failures that happen in the optimization pass
* Caffe2: avoid static initialization order fiasco for CAFFE_ENFORCE. CAFFE_ENFORCE uses a stack trace fetcher, which is currently a global static variable. If CAFFE_ENFORCE is used at static initialization time, this is a SIOF. Recently CAFFE_ENFORCE was added into init function registration, so we started to see this. A Meyers singleton provides safety here; if the stack trace fetcher has not been registered yet, a dummy one is used.
* NUMA support in SparseNN CPU benchmark: adds support for NUMA in the SparseNN CPU benchmark
* [mobile-roofline] Add logging needed for the roofline model; this should be all that's needed
* Let the operators use the same input if the operators are not chained; otherwise, we have to change the input data dims
* Fix null-pointer-use UBSAN errors in reshape_op.h
* Revert previous fix on input blob name (as title)
* Add flag to let MineHardNegative automatically extract a single value from a dict. Model exporter requires the output of the model to be a struct. This makes it convenient to use those models directly in MineHardNegative by allowing automatic extraction of the single element of the dict, which is a common use case.
* Revert change that broke internal tests back to OSS-compatible state
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import functools

import hypothesis
from hypothesis import given, settings, HealthCheck
import hypothesis.strategies as st
import numpy as np

from caffe2.python import core
import caffe2.python.hypothesis_test_util as hu
from caffe2.python.operator_test.adagrad_test_helper import (
    ref_adagrad, adagrad_sparse_test_helper
)
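

# The dense reference used throughout these tests, ref_adagrad, is imported
# from adagrad_test_helper and is not defined in this file.  As a minimal
# illustrative sketch (an assumption modeled on ref_row_wise_adagrad below,
# not the helper's actual code), the update it checks against looks like:
def _dense_adagrad_sketch(param_in, mom_in, grad, lr, epsilon):
    # Accumulate the squared gradient into the moment buffer.
    mom_out = mom_in + np.square(grad)
    # Per-element effective learning rate: lr / (sqrt(moment) + epsilon).
    grad_adj = lr * grad / (np.sqrt(mom_out) + epsilon)
    # Caffe2's Adagrad convention applies the adjusted gradient additively.
    param_out = param_in + grad_adj
    return (param_out, mom_out)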


class TestAdagrad(hu.HypothesisTestCase):
    @staticmethod
    def ref_row_wise_adagrad(param_in, mom_in, grad, lr, epsilon):
        mom_out = mom_in + np.mean(np.square(grad))
        grad_adj = lr * grad / (np.sqrt(mom_out) + epsilon)
        param_out = param_in + grad_adj
        return (param_out, mom_out)
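
    # Worked example for the reference above (illustrative numbers, not taken
    # from the original file): with param_in = [1., 1.], mom_in = 5.,
    # grad = [2., 2.], lr = 2. and epsilon = 1., the row-wise moment becomes
    # mom_out = 5 + mean([4, 4]) = 9, the denominator is sqrt(9) + 1 = 4, so
    # grad_adj = 2 * [2, 2] / 4 = [1, 1] and param_out = [2., 2.].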

    @given(inputs=hu.tensors(n=3),
           lr=st.floats(min_value=0.01, max_value=0.99,
                        allow_nan=False, allow_infinity=False),
           epsilon=st.floats(min_value=0.01, max_value=0.99,
                             allow_nan=False, allow_infinity=False),
           **hu.gcs)
    def test_adagrad(self, inputs, lr, epsilon, gc, dc):
        param, momentum, grad = inputs
        lr = np.array([lr], dtype=np.float32)

        op = core.CreateOperator(
            "Adagrad",
            ["param", "momentum", "grad", "lr"],
            ["param", "momentum"],
            epsilon=epsilon,
            device_option=gc,
        )

        self.assertReferenceChecks(
            gc, op,
            [param, momentum, grad, lr],
            functools.partial(ref_adagrad, epsilon=epsilon))

    @given(inputs=hu.tensors(n=3),
           lr=st.floats(min_value=0.01, max_value=0.99,
                        allow_nan=False, allow_infinity=False),
           epsilon=st.floats(min_value=0.01, max_value=0.99,
                             allow_nan=False, allow_infinity=False),
           **hu.gcs_cpu_only)
    def test_adagrad_output_effective_lr(self, inputs, lr, epsilon, gc, dc):
        param, momentum, grad = inputs
        lr = np.array([lr], dtype=np.float32)

        op = core.CreateOperator(
            "Adagrad",
            ["param", "momentum", "grad", "lr"],
            ["param", "momentum", "effective_lr"],
            epsilon=epsilon,
            device_option=gc,
        )

        self.assertReferenceChecks(
            gc, op,
            [param, momentum, grad, lr],
            functools.partial(ref_adagrad, epsilon=epsilon,
                              output_effective_lr=True))

    @given(inputs=hu.tensors(n=3),
           lr=st.floats(min_value=0.01, max_value=0.99,
                        allow_nan=False, allow_infinity=False),
           epsilon=st.floats(min_value=0.01, max_value=0.99,
                             allow_nan=False, allow_infinity=False),
           **hu.gcs_cpu_only)
    def test_adagrad_output_effective_lr_and_update(
            self, inputs, lr, epsilon, gc, dc):
        param, momentum, grad = inputs
        lr = np.array([lr], dtype=np.float32)

        op = core.CreateOperator(
            "Adagrad",
            ["param", "momentum", "grad", "lr"],
            ["param", "momentum", "effective_lr", "update"],
            epsilon=epsilon,
            device_option=gc,
        )

        self.assertReferenceChecks(
            gc, op,
            [param, momentum, grad, lr],
            functools.partial(ref_adagrad, epsilon=epsilon,
                              output_effective_lr_and_update=True))
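
    # For the two tests above, "effective_lr" and "update" are the operator's
    # optional extra outputs.  Based on the reference convention used in
    # ref_row_wise_adagrad above, they are expected to be (an informal note,
    # not taken from the operator's documentation):
    #     effective_lr = lr / (sqrt(momentum_out) + epsilon)
    #     update       = effective_lr * grad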

    # Suppress filter_too_much health check.
    # Likely caused by `assume` call falling through too often.
    @settings(suppress_health_check=[HealthCheck.filter_too_much])
    @given(inputs=hu.tensors(n=3),
           lr=st.floats(min_value=0.01, max_value=0.99,
                        allow_nan=False, allow_infinity=False),
           epsilon=st.floats(min_value=0.01, max_value=0.99,
                             allow_nan=False, allow_infinity=False),
           **hu.gcs)
    def test_sparse_adagrad(self, inputs, lr, epsilon, gc, dc):
        return adagrad_sparse_test_helper(self, inputs, lr, epsilon,
                                          None, ref_adagrad, gc, dc)
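
    # adagrad_sparse_test_helper builds the SparseAdagrad operator and a
    # matching reference; the helper itself is not shown in this file.  As a
    # rough sketch of the reference it is expected to use (an illustrative
    # assumption, not the helper's actual code): apply the dense ref_adagrad
    # row by row at the given indices and leave every other row untouched.
    @staticmethod
    def _ref_sparse_adagrad_sketch(param, momentum, indices, grad, lr,
                                   epsilon):
        param_out = np.copy(param)
        momentum_out = np.copy(momentum)
        for i, index in enumerate(indices):
            param_out[index], momentum_out[index] = ref_adagrad(
                param[index], momentum[index], grad[i], lr, epsilon=epsilon)
        return (param_out, momentum_out)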

    @given(inputs=hu.tensors(n=2),
           lr=st.floats(min_value=0.01, max_value=0.99,
                        allow_nan=False, allow_infinity=False),
           epsilon=st.floats(min_value=0.01, max_value=0.99,
                             allow_nan=False, allow_infinity=False),
           data_strategy=st.data(),
           **hu.gcs)
    def test_sparse_adagrad_empty(self, inputs, lr, epsilon,
                                  data_strategy, gc, dc):
        param, momentum = inputs
        momentum = np.abs(momentum)
        lr = np.array([lr], dtype=np.float32)

        grad = np.empty(shape=(0,) + param.shape[1:], dtype=np.float32)
        indices = np.empty(shape=(0,), dtype=np.int64)

        hypothesis.note('indices.shape: %s' % str(indices.shape))

        op = core.CreateOperator(
            "SparseAdagrad",
            ["param", "momentum", "indices", "grad", "lr"],
            ["param", "momentum"],
            epsilon=epsilon,
            device_option=gc)

        def ref_sparse(param, momentum, indices, grad, lr):
            param_out = np.copy(param)
            momentum_out = np.copy(momentum)
            return (param_out, momentum_out)

        ref_using_fp16_values = [False]
        if dc == hu.gpu_do:
            ref_using_fp16_values.append(True)

        for ref_using_fp16 in ref_using_fp16_values:
            if ref_using_fp16:
                print('test_sparse_adagrad_empty with half precision embedding')
                momentum_i = momentum.astype(np.float16)
                param_i = param.astype(np.float16)
            else:
                print('test_sparse_adagrad_empty with full precision embedding')
                momentum_i = momentum.astype(np.float32)
                param_i = param.astype(np.float32)

            self.assertReferenceChecks(
                gc, op, [param_i, momentum_i, indices, grad, lr], ref_sparse
            )

    # Suppress filter_too_much health check.
    # Likely caused by `assume` call falling through too often.
    @settings(suppress_health_check=[HealthCheck.filter_too_much])
    @given(inputs=hu.tensors(n=2),
           lr=st.floats(min_value=0.01, max_value=0.99,
                        allow_nan=False, allow_infinity=False),
           epsilon=st.floats(min_value=0.01, max_value=0.99,
                             allow_nan=False, allow_infinity=False),
           data_strategy=st.data(),
           **hu.gcs)
    def test_row_wise_sparse_adagrad(self, inputs, lr, epsilon,
                                     data_strategy, gc, dc):
        param, grad = inputs
        lr = np.array([lr], dtype=np.float32)

        # Create a 1D row-wise average sum of squared gradients tensor.
        momentum = data_strategy.draw(
            hu.tensor1d(min_len=param.shape[0], max_len=param.shape[0],
                        elements=hu.elements_of_type(dtype=np.float32))
        )
        momentum = np.abs(momentum)

        # Create an indexing array containing values which index into grad.
        indices = data_strategy.draw(
            hu.tensor(dtype=np.int64,
                      elements=st.sampled_from(np.arange(grad.shape[0]))),
        )

        # Note that, unlike SparseAdagrad, RowWiseSparseAdagrad uses a moment
        # tensor that is strictly 1-dimensional and equal in length to the
        # first dimension of the parameters, so indices must also be
        # 1-dimensional.
        indices = indices.flatten()

        hypothesis.note('indices.shape: %s' % str(indices.shape))

        # The indices must be unique.
        hypothesis.assume(np.array_equal(np.unique(indices), np.sort(indices)))

        # Sparsify grad.
        grad = grad[indices]

        op = core.CreateOperator(
            "RowWiseSparseAdagrad",
            ["param", "momentum", "indices", "grad", "lr"],
            ["param", "momentum"],
            epsilon=epsilon,
            device_option=gc)

        def ref_row_wise_sparse(param, momentum, indices, grad, lr):
            param_out = np.copy(param)
            momentum_out = np.copy(momentum)
            for i, index in enumerate(indices):
                param_out[index], momentum_out[index] = self.ref_row_wise_adagrad(
                    param[index], momentum[index], grad[i], lr, epsilon)
            return (param_out, momentum_out)

        self.assertReferenceChecks(
            gc, op,
            [param, momentum, indices, grad, lr],
            ref_row_wise_sparse)

    @given(inputs=hu.tensors(n=1),
           lr=st.floats(min_value=0.01, max_value=0.99,
                        allow_nan=False, allow_infinity=False),
           epsilon=st.floats(min_value=0.01, max_value=0.99,
                             allow_nan=False, allow_infinity=False),
           data_strategy=st.data(),
           **hu.gcs)
    def test_row_wise_sparse_adagrad_empty(self, inputs, lr, epsilon,
                                           data_strategy, gc, dc):
        param = inputs[0]
        lr = np.array([lr], dtype=np.float32)

        momentum = data_strategy.draw(
            hu.tensor1d(min_len=param.shape[0], max_len=param.shape[0],
                        elements=hu.elements_of_type(dtype=np.float32))
        )
        momentum = np.abs(momentum)

        grad = np.empty(shape=(0,) + param.shape[1:], dtype=np.float32)
        indices = np.empty(shape=(0,), dtype=np.int64)

        hypothesis.note('indices.shape: %s' % str(indices.shape))

        op = core.CreateOperator(
            "RowWiseSparseAdagrad",
            ["param", "momentum", "indices", "grad", "lr"],
            ["param", "momentum"],
            epsilon=epsilon,
            device_option=gc)

        def ref_row_wise_sparse(param, momentum, indices, grad, lr):
            param_out = np.copy(param)
            momentum_out = np.copy(momentum)
            return (param_out, momentum_out)

        self.assertReferenceChecks(
            gc, op,
            [param, momentum, indices, grad, lr],
            ref_row_wise_sparse)
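

# These Hypothesis-driven tests are normally collected by the regular test
# runner (for example, pytest pointed at this file).  The guard below is an
# optional convenience for running the module directly; it is not part of the
# original file.
if __name__ == "__main__":
    import unittest

    unittest.main()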