mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-04 16:04:58 +08:00
Finish implementation of prng related functions
This commit is contained in:
@ -144,6 +144,69 @@ PyObject * THCPModule_getDriverVersion(PyObject *self)
|
||||
return PyLong_FromLong((long) driverVersion);
|
||||
}
|
||||
|
||||
PyObject * THCPModule_getRNGState(PyObject *_unused)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
THByteTensorPtr _t = THByteTensor_new();
|
||||
THCRandom_getRNGState(state, _t.get());
|
||||
PyObject *_ret = THPByteTensor_New(_t.get());
|
||||
_t.release();
|
||||
return _ret;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
PyObject * THCPModule_setRNGState(PyObject *_unused, PyObject *_new_rng_state)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
THPUtils_assert(THPByteTensor_Check(_new_rng_state), "set_rng_state expects a "
|
||||
"torch.ByteTensor, but got %s", THPUtils_typename(_new_rng_state));
|
||||
THByteTensor *new_rng_state = ((THPByteTensor*)_new_rng_state)->cdata;
|
||||
THCRandom_setRNGState(state, new_rng_state);
|
||||
Py_RETURN_NONE;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
PyObject * THCPModule_manualSeed(PyObject *_unused, PyObject *seed)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
THPUtils_assert(THPUtils_checkLong(seed), "manual_seed expected a long, "
|
||||
"but got %s", THPUtils_typename(seed));
|
||||
THCRandom_manualSeed(state, THPUtils_unpackLong(seed));
|
||||
Py_RETURN_NONE;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
PyObject * THCPModule_manualSeedAll(PyObject *_unused, PyObject *seed)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
THPUtils_assert(THPUtils_checkLong(seed), "manual_seed expected a long, "
|
||||
"but got %s", THPUtils_typename(seed));
|
||||
THCRandom_manualSeedAll(state, THPUtils_unpackLong(seed));
|
||||
Py_RETURN_NONE;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
PyObject * THCPModule_seed(PyObject *_unused)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
return PyLong_FromUnsignedLong(THCRandom_seed(state));
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
PyObject * THCPModule_seedAll(PyObject *_unused)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
return PyLong_FromUnsignedLong(THCRandom_seedAll(state));
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
PyObject * THCPModule_initialSeed(PyObject *_unused)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
return PyLong_FromUnsignedLong(THCRandom_initialSeed(state));
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Cuda module initialization
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
Reference in New Issue
Block a user