star-sp

Random number generators and distributions
git clone git://git.meso-star.fr/star-sp.git
Log | Files | Refs | README | LICENSE

commit 2ad6e27171a42a7aa5845f68e7547e6b7a2e69c6
parent 0613381152e7772d8b5d350a7870a405642da296
Author: Vincent Forest <vincent.forest@meso-star.com>
Date:   Fri,  3 Dec 2021 16:49:31 +0100

Fix the ssp_rng_proxy_read function

Cached RNG states were not cleaned when the proxy state was updated. As
a result, proxy-managed RNGs generated random numbers from the previous
proxy states saved into the cache.

Diffstat:
Msrc/ssp_rng_proxy.c | 49++++++++++++++++++++++++++++++++++++++++---------
Msrc/test_ssp_rng_proxy.c | 118++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++---
2 files changed, 154 insertions(+), 13 deletions(-)

diff --git a/src/ssp_rng_proxy.c b/src/ssp_rng_proxy.c @@ -57,7 +57,7 @@ struct rng_state_cache { CLBK(rng_proxy_cb_T, ARG1(const struct ssp_rng_proxy*)); enum rng_proxy_sig { - RNG_PROXY_SIG_READ, + RNG_PROXY_SIG_SET_STATE, RNG_PROXY_SIGS_COUNT__ }; @@ -106,7 +106,7 @@ struct ssp_rng_proxy { * |######| Bucket 0 | ... | Bucket N-1 |####| Bucket 0 | ... * |######| 1st pool | | 1st pool |####| 2nd pool | * +------+------------+- -+------------+----+------------+- - * \ / \_________sequence_size_______/ / + * \ / \_________sequence_size_______/ / * sequence \________sequence_pitch__________/ * offset */ @@ -179,6 +179,17 @@ rng_state_cache_release(struct rng_state_cache* cache) darray_char_release(&cache->state_scratch); } +static void +rng_state_cache_clear(struct rng_state_cache* cache) +{ + ASSERT(cache->stream); + rewind(cache->stream); + cache->read = cache->write = ftell(cache->stream); + cache->nstates = 0; + cache->no_wstream = 0; + cache->no_rstream = 0; +} + static char rng_state_cache_is_empty(struct rng_state_cache* cache) { @@ -305,11 +316,11 @@ struct rng_bucket { struct ssp_rng_proxy* proxy; /* The RNG proxy */ size_t name; /* Unique bucket identifier in [0, #buckets) */ size_t count; /* Remaining unique random numbers in `pool' */ - rng_proxy_cb_T cb_on_proxy_read; + rng_proxy_cb_T cb_on_proxy_set_state; }; static void -rng_bucket_on_proxy_read(const struct ssp_rng_proxy* proxy, void* ctx) +rng_bucket_on_proxy_set_state(const struct ssp_rng_proxy* proxy, void* ctx) { struct rng_bucket* rng = (struct rng_bucket*)ctx; ASSERT(proxy && ctx && rng->proxy == proxy); @@ -395,7 +406,7 @@ rng_bucket_release(void* data) struct rng_bucket* rng = (struct rng_bucket*)data; ASSERT(data && rng->proxy); ATOMIC_SET(&rng->proxy->buckets[rng->name], 0); - CLBK_DISCONNECT(&rng->cb_on_proxy_read); + CLBK_DISCONNECT(&rng->cb_on_proxy_set_state); SSP(rng_proxy_ref_put(rng->proxy)); } @@ -520,6 +531,19 @@ error: goto exit; } +static void +rng_proxy_clear_caches(struct ssp_rng_proxy* proxy) +{ + size_t ibucket; + ASSERT(proxy); + + mutex_lock(proxy->mutex); + FOR_EACH(ibucket, 0, sa_size(proxy->pools)) { + rng_state_cache_clear(proxy->states+ibucket); + } + mutex_unlock(proxy->mutex); +} + void rng_proxy_release(ref_T* ref) { @@ -683,7 +707,12 @@ ssp_rng_proxy_read(struct ssp_rng_proxy* proxy, FILE* stream) mutex_unlock(proxy->mutex); if(res != RES_OK) return res; - SIG_BROADCAST(proxy->signals+RNG_PROXY_SIG_READ, rng_proxy_cb_T, ARG1(proxy)); + /* Discard the cached RNG states */ + rng_proxy_clear_caches(proxy); + + /* Notify to bucket RNGs that the proxy RNG state was updated */ + SIG_BROADCAST + (proxy->signals+RNG_PROXY_SIG_SET_STATE, rng_proxy_cb_T, ARG1(proxy)); return RES_OK; } @@ -752,9 +781,11 @@ ssp_rng_proxy_create_rng /* The bucket RNG listens the "write" signal of the proxy to reset its * internal RNs counter on "write" invocation. */ - CLBK_INIT(&bucket->cb_on_proxy_read); - CLBK_SETUP(&bucket->cb_on_proxy_read, rng_bucket_on_proxy_read, bucket); - SIG_CONNECT_CLBK(proxy->signals+RNG_PROXY_SIG_READ, &bucket->cb_on_proxy_read); + CLBK_INIT(&bucket->cb_on_proxy_set_state); + CLBK_SETUP + (&bucket->cb_on_proxy_set_state, rng_bucket_on_proxy_set_state, bucket); + SIG_CONNECT_CLBK + (proxy->signals+RNG_PROXY_SIG_SET_STATE, &bucket->cb_on_proxy_set_state); exit: if(out_rng) *out_rng = rng; diff --git a/src/test_ssp_rng_proxy.c b/src/test_ssp_rng_proxy.c @@ -395,25 +395,28 @@ test_read(void) CHK(ssp_rng_create(NULL, SSP_RNG_MT19937_64, &rng) == RES_OK); CHK(ssp_rng_discard(rng, COUNT) == RES_OK); + /* Create a RNG state */ stream = tmpfile(); CHK(stream != NULL); CHK(ssp_rng_write(rng, stream) == RES_OK); rewind(stream); + /* Create a proxy from the RNG state */ CHK(ssp_rng_proxy_create(NULL, SSP_RNG_MT19937_64, 4, &proxy) == RES_OK); CHK(ssp_rng_proxy_read(NULL, NULL) == RES_BAD_ARG); CHK(ssp_rng_proxy_read(proxy, NULL) == RES_BAD_ARG); CHK(ssp_rng_proxy_read(NULL, stream) == RES_BAD_ARG); CHK(ssp_rng_proxy_read(proxy, stream) == RES_OK); + /* Create the list of RNG managed by the proxy */ FOR_EACH(i, 0, 4) CHK(ssp_rng_proxy_create_rng(proxy, i, &rng1[i]) == RES_OK); + /* Check the random number sequences */ r[0] = ssp_rng_get(rng); CHK(r[0] == ssp_rng_get(rng1[0])); CHK(r[0] != ssp_rng_get(rng1[1])); CHK(r[0] != ssp_rng_get(rng1[2])); CHK(r[0] != ssp_rng_get(rng1[3])); - FOR_EACH(i, 0, NRANDS) { FOR_EACH(j, 0, 4) { r[j] = ssp_rng_get(rng1[j]); @@ -424,7 +427,113 @@ test_read(void) CHK(ssp_rng_proxy_ref_put(proxy) == RES_OK); CHK(ssp_rng_ref_put(rng) == RES_OK); FOR_EACH(i, 0, 4) CHK(ssp_rng_ref_put(rng1[i]) == RES_OK); - fclose(stream); + + CHK(fclose(stream) == 0); +} + +static void +test_read_with_cached_states(void) +{ + struct ssp_rng_proxy_create2_args args = SSP_RNG_PROXY_CREATE2_ARGS_NULL; + struct ssp_rng_proxy* proxy; + struct ssp_rng* rng; + struct ssp_rng* rng0; + struct ssp_rng* rng1; + FILE* stream; + size_t iseq; + size_t i; + + CHK(ssp_rng_create(NULL, SSP_RNG_MT19937_64, &rng) == RES_OK); + + /* Create a proxy with a very small sequence size of 2 RNs per bucket in + * order to maximize the number of cached states. Furthermore, use the + * Mersenne Twister RNG type since its internal state is the greatest one of + * the proposed builtin type and is thus the one that will fill quickly the + * cache stream. */ + args.type = SSP_RNG_MT19937_64; + args.sequence_size = 4; + args.sequence_pitch = 4; + args.nbuckets = 2; + CHK(ssp_rng_proxy_create2(NULL, &args, &proxy) == RES_OK); + CHK(ssp_rng_proxy_create_rng(proxy, 0, &rng0) == RES_OK); + CHK(ssp_rng_proxy_create_rng(proxy, 1, &rng1) == RES_OK); + + /* Check RNG states */ + CHK(ssp_rng_get(rng) == ssp_rng_get(rng0)); + CHK(ssp_rng_get(rng) == ssp_rng_get(rng0)); + CHK(ssp_rng_get(rng) == ssp_rng_get(rng1)); + CHK(ssp_rng_get(rng) == ssp_rng_get(rng1)); + + /* Discard several RNs for the 1st RNG to cache several states for the rng1 */ + CHK(ssp_rng_discard(rng0, 100) == RES_OK); + + /* Generate a RNG state */ + stream = tmpfile(); + CHK(stream != NULL); + CHK(ssp_rng_discard(rng, 100000) == RES_OK); + CHK(ssp_rng_write(rng, stream) == RES_OK); + rewind(stream); + + /* Setup the proxy state and check that the RNGs managed by the proxy now use + * the new state properly, even though RNG states were previously cached */ + CHK(ssp_rng_proxy_read(proxy, stream) == RES_OK); + FOR_EACH(iseq, 0, 100) { + FOR_EACH(i, 0, args.sequence_size) { + if(i < args.sequence_size/2) { + CHK(ssp_rng_get(rng) == ssp_rng_get(rng0)); + } else { + CHK(ssp_rng_get(rng) == ssp_rng_get(rng1)); + } + } + } + + /* Discard several RNs for the first RNG only to make under pressure the + * cache stream of 'rng1'. The cache stream limit is set to 32 MB and the + * size of a Mersenne Twister RNG state is greater than 6 KB. Consquently, + * ~5500 RNG states will exceed the cache stream, i.e. 5500*2 = 11000 + * random generations (since there is 2 RNs per bucket). Above this limit, + * 'rng1' will not rely anymore on the proxy RNG to manage its state. */ + CHK(ssp_rng_discard(rng0, 20000) == RES_OK); + + /* Generate a new RNG state */ + rewind(stream); + CHK(ssp_rng_discard(rng, 100000) == RES_OK); + CHK(ssp_rng_write(rng, stream) == RES_OK); + rewind(stream); + + /* Now setup a new proxy state and check that the RNGs managed by the proxy + * correctly use the new state */ + CHK(ssp_rng_proxy_read(proxy, stream) == RES_OK); + FOR_EACH(iseq, 0, 100) { + FOR_EACH(i, 0, args.sequence_size) { + if(i < args.sequence_size/2) { + CHK(ssp_rng_get(rng) == ssp_rng_get(rng0)); + } else { + CHK(ssp_rng_get(rng) == ssp_rng_get(rng1)); + } + } + } + + /* Discard several RNs for the 1st RNG to cache several states for the rng1 */ + CHK(ssp_rng_discard(rng0, 100) == RES_OK); + + /* Finally verify that the cache was reset on proxy read */ + FOR_EACH(iseq, 0, 100) { + FOR_EACH(i, 0, args.sequence_size) { + if(i < args.sequence_size/2) { + CHK(ssp_rng_discard(rng, 1) == RES_OK); + } else { + CHK(ssp_rng_get(rng) == ssp_rng_get(rng1)); + } + } + } + + CHK(ssp_rng_proxy_ref_put(proxy) == RES_OK); + CHK(ssp_rng_ref_put(rng0) == RES_OK); + CHK(ssp_rng_ref_put(rng1) == RES_OK); + CHK(ssp_rng_ref_put(rng) == RES_OK); + + CHK(fclose(stream) == 0); } static void @@ -460,7 +569,7 @@ test_write(void) CHK(ssp_rng_ref_put(rng0) == RES_OK); CHK(ssp_rng_ref_put(rng1) == RES_OK); - fclose(stream); + CHK(fclose(stream) == 0); } static void @@ -484,7 +593,7 @@ test_cache(void) args.sequence_pitch = 4; args.nbuckets = 2; - /* Simply test that the RNs generated by the proxy are the same thant the + /* Simply test that the RNs generated by the proxy are the same than the * ones generated by a regular RNG. Since each RNG invocation are * interleaved, the cache pressure is very low, at most 1 RNG state is * cached. */ @@ -550,6 +659,7 @@ main(int argc, char** argv) test_multi_proxies(); test_proxy_from_rng(); test_read(); + test_read_with_cached_states(); test_write(); test_cache(); CHK(mem_allocated_size() == 0);