star-mc

Parallel estimation of Monte Carlo integrators
git clone git://git.meso-star.fr/star-mc.git
Log | Files | Refs | README | LICENSE

smc_device.c (6099B)


      1 /* Copyright (C) 2015-2018, 2021-2023 |Méso|Star> (contact@meso-star.com)
      2  *
      3  * This program is free software: you can redistribute it and/or modify
      4  * it under the terms of the GNU General Public License as published by
      5  * the Free Software Foundation, either version 3 of the License, or
      6  * (at your option) any later version.
      7  *
      8  * This program is distributed in the hope that it will be useful,
      9  * but WITHOUT ANY WARRANTY; without even the implied warranty of
     10  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
     11  * GNU General Public License for more details.
     12  *
     13  * You should have received a copy of the GNU General Public License
     14  * along with this program. If not, see <http://www.gnu.org/licenses/>. */
     15 
     16 #include "smc.h"
     17 #include "smc_device_c.h"
     18 
     19 #include <star/ssp.h>
     20 
     21 #include <rsys/logger.h>
     22 #include <rsys/mem_allocator.h>
     23 
     24 #include <omp.h>
     25 
     26 /*******************************************************************************
     27  * Helper functions
     28  ******************************************************************************/
     29 static INLINE res_T
     30 check_device_create_args(const struct smc_device_create_args* args)
     31 {
     32   if(!args || args->nthreads_hint == 0) return RES_BAD_ARG;
     33   return RES_OK;
     34 }
     35 
     36 static INLINE void
     37 log_msg
     38   (struct smc_device* smc,
     39    const enum log_type stream,
     40    const char* msg,
     41    va_list vargs)
     42 {
     43   ASSERT(smc && msg);
     44   if(smc->verbose) {
     45     CHK(logger_vprint(smc->logger, stream, msg, vargs) == RES_OK);
     46   }
     47 }
     48 
     49 static void
     50 release_rngs(struct smc_device* dev)
     51 {
     52   size_t i;
     53   ASSERT(dev);
     54   if(dev->rng_proxy) {
     55     ssp_rng_proxy_ref_put(dev->rng_proxy);
     56     dev->rng_proxy = NULL;
     57   }
     58   if(dev->rngs) {
     59     ASSERT(dev->nthreads == sa_size(dev->rngs));
     60     FOR_EACH(i, 0, dev->nthreads) {
     61       if(dev->rngs[i]) {
     62         ssp_rng_ref_put(dev->rngs[i]);
     63         dev->rngs[i] = NULL;
     64       }
     65     }
     66     sa_release(dev->rngs);
     67     dev->rngs = NULL;
     68   }
     69 }
     70 
     71 static void
     72 device_release(ref_T* ref)
     73 {
     74   struct smc_device* dev;
     75   ASSERT(ref);
     76   dev = CONTAINER_OF(ref, struct smc_device, ref);
     77   release_rngs(dev);
     78   MEM_RM(dev->allocator, dev);
     79 }
     80 
     81 /*******************************************************************************
     82  * Exported functions
     83  ******************************************************************************/
     84 res_T
     85 smc_device_create
     86   (const struct smc_device_create_args* args,
     87    struct smc_device** out_dev)
     88 {
     89   struct smc_device* dev = NULL;
     90   struct mem_allocator* allocator = &mem_default_allocator;
     91   struct logger* logger = LOGGER_DEFAULT;
     92   enum ssp_rng_type rng_type = SSP_RNG_THREEFRY;
     93   res_T res = RES_OK;
     94 
     95   if(!out_dev) { res = RES_BAD_ARG; goto exit; }
     96 
     97   res = check_device_create_args(args);
     98   if(res != RES_OK) goto error;
     99 
    100   if(args->allocator) allocator = args->allocator;
    101   if(args->logger) logger = args->logger;
    102   if(args->rng_type != SSP_RNG_TYPE_NULL) rng_type = args->rng_type;
    103 
    104   dev = MEM_CALLOC(allocator, 1, sizeof(struct smc_device));
    105   if(!dev) {
    106     res = RES_MEM_ERR;
    107     goto error;
    108   }
    109   ref_init(&dev->ref);
    110   dev->allocator = allocator;
    111   dev->logger = logger;
    112   dev->verbose = args->verbose;
    113 
    114   dev->nthreads = MMIN(args->nthreads_hint, (unsigned)omp_get_num_procs());
    115   omp_set_num_threads((int)dev->nthreads);
    116 
    117   res = smc_device_set_rng_type(dev, rng_type);
    118   if(res != RES_OK) goto error;
    119 
    120 exit:
    121   if(out_dev) *out_dev = dev;
    122   return res;
    123 error:
    124   if(dev) {
    125     SMC(device_ref_put(dev));
    126     dev = NULL;
    127   }
    128   goto exit;
    129 }
    130 
    131 res_T
    132 smc_device_ref_get(struct smc_device* dev)
    133 {
    134   if(!dev) return RES_BAD_ARG;
    135   ref_get(&dev->ref);
    136   return RES_OK;
    137 }
    138 
    139 res_T
    140 smc_device_ref_put(struct smc_device* dev)
    141 {
    142   if(!dev) return RES_BAD_ARG;
    143   ref_put(&dev->ref, device_release);
    144   return RES_OK;
    145 }
    146 
    147 res_T
    148 smc_device_set_rng_type(struct smc_device* dev, const enum ssp_rng_type type)
    149 {
    150   size_t i;
    151   res_T res = RES_OK;
    152   struct ssp_rng_proxy* proxy = NULL;
    153   struct ssp_rng** rngs = NULL;
    154 
    155   if(!dev || type == SSP_RNG_TYPE_NULL) {
    156     /* Skip the error block */
    157     return RES_BAD_ARG;
    158   }
    159 
    160   proxy = dev->rng_proxy;
    161   rngs = dev->rngs;
    162   dev->rng_proxy = NULL;
    163   dev->rngs = NULL;
    164 
    165   /* Create the new rng_proxy */
    166   res = ssp_rng_proxy_create
    167     (dev->allocator, type, dev->nthreads, &dev->rng_proxy);
    168   if(res != RES_OK) goto error;
    169 
    170   /* Create the new per thread rng */
    171   dev->rngs = sa_add(dev->rngs, dev->nthreads);
    172   memset(dev->rngs, 0, dev->nthreads*sizeof(struct ssp_rng*));
    173   FOR_EACH(i, 0, dev->nthreads) {
    174     res = ssp_rng_proxy_create_rng(dev->rng_proxy, i, dev->rngs + i);
    175     if(res != RES_OK) goto error;
    176   }
    177 
    178   /* Release the previous RNG data structure */
    179   if(proxy) SSP(rng_proxy_ref_put(proxy));
    180   if(rngs) {
    181     FOR_EACH(i, 0, dev->nthreads) {
    182       if(rngs[i]) SSP(rng_ref_put(rngs[i]));
    183     }
    184     sa_release(rngs);
    185   }
    186 
    187 exit:
    188   return res;
    189 error:
    190   /* Restore the previous RNG type */
    191   release_rngs(dev);
    192   dev->rng_proxy = proxy;
    193   dev->rngs = rngs;
    194   goto exit;
    195 }
    196 
    197 res_T
    198 smc_device_get_rng_type(struct smc_device* dev, enum ssp_rng_type* type)
    199 {
    200   if(!dev || !type) return RES_BAD_ARG;
    201   return ssp_rng_proxy_get_type(dev->rng_proxy, type);
    202 }
    203 
    204 res_T
    205 smc_device_get_threads_count(const struct smc_device* dev, unsigned* nthreads)
    206 {
    207   if(!dev || !nthreads) return RES_BAD_ARG;
    208   ASSERT(dev->nthreads <= UINT_MAX);
    209   *nthreads = (unsigned)dev->nthreads;
    210   return RES_OK;
    211 }
    212 
    213 
    214 /*******************************************************************************
    215  * Local functions
    216  ******************************************************************************/
    217 void
    218 log_info(struct smc_device* smc, const char* msg, ...)
    219 {
    220   va_list vargs_list;
    221   ASSERT(smc && msg);
    222   va_start(vargs_list, msg);
    223   log_msg(smc, LOG_OUTPUT, msg, vargs_list);
    224   va_end(vargs_list);
    225 }
    226 
    227 void
    228 log_err(struct smc_device* smc, const char* msg, ...)
    229 {
    230   va_list vargs_list;
    231   ASSERT(smc && msg);
    232   va_start(vargs_list, msg);
    233   log_msg(smc, LOG_ERROR, msg, vargs_list);
    234   va_end(vargs_list);
    235 }
    236 
    237 void
    238 log_warn(struct smc_device* smc, const char* msg, ...)
    239 {
    240   va_list vargs_list;
    241   ASSERT(smc && msg);
    242   va_start(vargs_list, msg);
    243   log_msg(smc, LOG_WARNING, msg, vargs_list);
    244   va_end(vargs_list);
    245 }