Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions include/mersenne_twister.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,9 @@ class MersenneTwister : public PseudoRandomNumberGenerator {
int m_state_index = 0;


std::vector<uint64_t> initializeRecurrenceState(const uint64_t seed);
std::vector<uint64_t> initializeRecurrenceState(const uint64_t seed) const;
uint64_t generateNextStateValue();
uint64_t tempering(uint64_t);
uint64_t tempering(uint64_t) const;

public:
MersenneTwister();
Expand Down
2 changes: 1 addition & 1 deletion include/prng.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class PseudoRandomNumberGenerator
PseudoRandomNumberGenerator(const uint64_t seed, const uint64_t minimum_value, const uint64_t maximum_value);
virtual ~PseudoRandomNumberGenerator() {}

uint64_t generateCryptographicallyInsecureSeed();
static uint64_t generateCryptographicallyInsecureSeed();

virtual uint64_t generateRandomValue() = 0;
double generateUnitNormalRandomValue();
Expand Down
18 changes: 9 additions & 9 deletions src/prngs/mersenne_twister.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,10 @@ MersenneTwister::MersenneTwister(const uint64_t seed)
m_lower_bit_mask(std::numeric_limits<uint64_t>::max() >> (m_w - m_r)),
m_recurrence_state(initializeRecurrenceState(m_seed)) {}

std::vector<uint64_t> MersenneTwister::initializeRecurrenceState(const uint64_t seed) {
std::vector<uint64_t> state(m_n);
std::vector<uint64_t> MersenneTwister::initializeRecurrenceState(const uint64_t seed) const {
std::vector<uint64_t> state(static_cast<size_t>(m_n));
state[0] = seed;
for (int i = 1; i < m_n; ++i) {
for (size_t i = 1; i < static_cast<size_t>(m_n); ++i) {
state[i] = (m_f * (state[i - 1] ^ (state[i - 1] >> (m_w - 2)))) + i;
}
return state;
Expand All @@ -65,12 +65,12 @@ uint64_t MersenneTwister::generateNextStateValue(){
lower_index += m_n;
}

uint64_t upper_part = m_recurrence_state[upper_index] & m_upper_bit_mask;
uint64_t lower_part = m_recurrence_state[lower_index] & m_lower_bit_mask;
uint64_t upper_part = m_recurrence_state[static_cast<size_t>(upper_index)] & m_upper_bit_mask;
uint64_t lower_part = m_recurrence_state[static_cast<size_t>(lower_index)] & m_lower_bit_mask;
uint64_t concatenated_value = upper_part | lower_part;

uint64_t matrix_mul_result = concatenated_value >> 1;
if (concatenated_value & 0b1){
if ((concatenated_value & 0b1) == 0b1){
matrix_mul_result ^= m_a;
}

Expand All @@ -79,11 +79,11 @@ uint64_t MersenneTwister::generateNextStateValue(){
if (middle_index < 0){
middle_index += m_n;
}
uint64_t middle_value = m_recurrence_state[middle_index];
uint64_t middle_value = m_recurrence_state[static_cast<size_t>(middle_index)];

uint64_t state_value = matrix_mul_result ^ middle_value;

m_recurrence_state[k] = state_value;
m_recurrence_state[static_cast<size_t>(k)] = state_value;

m_state_index++;
if (m_state_index >= m_n){
Expand All @@ -94,7 +94,7 @@ uint64_t MersenneTwister::generateNextStateValue(){

}

uint64_t MersenneTwister::tempering(uint64_t val){
uint64_t MersenneTwister::tempering(uint64_t val) const {
uint64_t tempered_value = val ^ (val >> m_u);
tempered_value ^= ((tempered_value << m_s) & m_b);
tempered_value ^= ((tempered_value << m_t) & m_c);
Expand Down
2 changes: 1 addition & 1 deletion src/prngs/prng.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ std::uint64_t PseudoRandomNumberGenerator::generateCryptographicallyInsecureSeed
auto current_time = std::chrono::system_clock::now();
auto current_time_duration = current_time.time_since_epoch(); // convert a bare time to a duration
auto time_as_milliseconds = std::chrono::duration_cast<std::chrono::milliseconds>(current_time_duration); // get the duration as a value in milliseconds
uint64_t seed_from_milliseconds = time_as_milliseconds.count(); // convert the milliseconds to a bare integer and use that as our seed
auto seed_from_milliseconds = static_cast<uint64_t>(time_as_milliseconds.count()); // convert the milliseconds to a bare integer and use that as our seed
return seed_from_milliseconds;
};

Expand Down