Finish implementation of prng related functions

This commit is contained in:
Adam Paszke
2016-09-29 09:39:00 -07:00
parent 2d8c2972ae
commit 3f7ab95890
8 changed files with 231 additions and 60 deletions

View File

@ -144,6 +144,69 @@ PyObject * THCPModule_getDriverVersion(PyObject *self)
return PyLong_FromLong((long) driverVersion);
}
PyObject * THCPModule_getRNGState(PyObject *_unused)
{
HANDLE_TH_ERRORS
THByteTensorPtr _t = THByteTensor_new();
THCRandom_getRNGState(state, _t.get());
PyObject *_ret = THPByteTensor_New(_t.get());
_t.release();
return _ret;
END_HANDLE_TH_ERRORS
}
PyObject * THCPModule_setRNGState(PyObject *_unused, PyObject *_new_rng_state)
{
HANDLE_TH_ERRORS
THPUtils_assert(THPByteTensor_Check(_new_rng_state), "set_rng_state expects a "
"torch.ByteTensor, but got %s", THPUtils_typename(_new_rng_state));
THByteTensor *new_rng_state = ((THPByteTensor*)_new_rng_state)->cdata;
THCRandom_setRNGState(state, new_rng_state);
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
PyObject * THCPModule_manualSeed(PyObject *_unused, PyObject *seed)
{
HANDLE_TH_ERRORS
THPUtils_assert(THPUtils_checkLong(seed), "manual_seed expected a long, "
"but got %s", THPUtils_typename(seed));
THCRandom_manualSeed(state, THPUtils_unpackLong(seed));
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
PyObject * THCPModule_manualSeedAll(PyObject *_unused, PyObject *seed)
{
HANDLE_TH_ERRORS
THPUtils_assert(THPUtils_checkLong(seed), "manual_seed expected a long, "
"but got %s", THPUtils_typename(seed));
THCRandom_manualSeedAll(state, THPUtils_unpackLong(seed));
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
PyObject * THCPModule_seed(PyObject *_unused)
{
HANDLE_TH_ERRORS
return PyLong_FromUnsignedLong(THCRandom_seed(state));
END_HANDLE_TH_ERRORS
}
PyObject * THCPModule_seedAll(PyObject *_unused)
{
HANDLE_TH_ERRORS
return PyLong_FromUnsignedLong(THCRandom_seedAll(state));
END_HANDLE_TH_ERRORS
}
PyObject * THCPModule_initialSeed(PyObject *_unused)
{
HANDLE_TH_ERRORS
return PyLong_FromUnsignedLong(THCRandom_initialSeed(state));
END_HANDLE_TH_ERRORS
}
////////////////////////////////////////////////////////////////////////////////
// Cuda module initialization
////////////////////////////////////////////////////////////////////////////////