smc_device.c (6099B)
1 /* Copyright (C) 2015-2018, 2021-2023 |Méso|Star> (contact@meso-star.com) 2 * 3 * This program is free software: you can redistribute it and/or modify 4 * it under the terms of the GNU General Public License as published by 5 * the Free Software Foundation, either version 3 of the License, or 6 * (at your option) any later version. 7 * 8 * This program is distributed in the hope that it will be useful, 9 * but WITHOUT ANY WARRANTY; without even the implied warranty of 10 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 11 * GNU General Public License for more details. 12 * 13 * You should have received a copy of the GNU General Public License 14 * along with this program. If not, see <http://www.gnu.org/licenses/>. */ 15 16 #include "smc.h" 17 #include "smc_device_c.h" 18 19 #include <star/ssp.h> 20 21 #include <rsys/logger.h> 22 #include <rsys/mem_allocator.h> 23 24 #include <omp.h> 25 26 /******************************************************************************* 27 * Helper functions 28 ******************************************************************************/ 29 static INLINE res_T 30 check_device_create_args(const struct smc_device_create_args* args) 31 { 32 if(!args || args->nthreads_hint == 0) return RES_BAD_ARG; 33 return RES_OK; 34 } 35 36 static INLINE void 37 log_msg 38 (struct smc_device* smc, 39 const enum log_type stream, 40 const char* msg, 41 va_list vargs) 42 { 43 ASSERT(smc && msg); 44 if(smc->verbose) { 45 CHK(logger_vprint(smc->logger, stream, msg, vargs) == RES_OK); 46 } 47 } 48 49 static void 50 release_rngs(struct smc_device* dev) 51 { 52 size_t i; 53 ASSERT(dev); 54 if(dev->rng_proxy) { 55 ssp_rng_proxy_ref_put(dev->rng_proxy); 56 dev->rng_proxy = NULL; 57 } 58 if(dev->rngs) { 59 ASSERT(dev->nthreads == sa_size(dev->rngs)); 60 FOR_EACH(i, 0, dev->nthreads) { 61 if(dev->rngs[i]) { 62 ssp_rng_ref_put(dev->rngs[i]); 63 dev->rngs[i] = NULL; 64 } 65 } 66 sa_release(dev->rngs); 67 dev->rngs = NULL; 68 } 69 } 70 71 static void 72 device_release(ref_T* ref) 73 { 74 struct smc_device* dev; 75 ASSERT(ref); 76 dev = CONTAINER_OF(ref, struct smc_device, ref); 77 release_rngs(dev); 78 MEM_RM(dev->allocator, dev); 79 } 80 81 /******************************************************************************* 82 * Exported functions 83 ******************************************************************************/ 84 res_T 85 smc_device_create 86 (const struct smc_device_create_args* args, 87 struct smc_device** out_dev) 88 { 89 struct smc_device* dev = NULL; 90 struct mem_allocator* allocator = &mem_default_allocator; 91 struct logger* logger = LOGGER_DEFAULT; 92 enum ssp_rng_type rng_type = SSP_RNG_THREEFRY; 93 res_T res = RES_OK; 94 95 if(!out_dev) { res = RES_BAD_ARG; goto exit; } 96 97 res = check_device_create_args(args); 98 if(res != RES_OK) goto error; 99 100 if(args->allocator) allocator = args->allocator; 101 if(args->logger) logger = args->logger; 102 if(args->rng_type != SSP_RNG_TYPE_NULL) rng_type = args->rng_type; 103 104 dev = MEM_CALLOC(allocator, 1, sizeof(struct smc_device)); 105 if(!dev) { 106 res = RES_MEM_ERR; 107 goto error; 108 } 109 ref_init(&dev->ref); 110 dev->allocator = allocator; 111 dev->logger = logger; 112 dev->verbose = args->verbose; 113 114 dev->nthreads = MMIN(args->nthreads_hint, (unsigned)omp_get_num_procs()); 115 omp_set_num_threads((int)dev->nthreads); 116 117 res = smc_device_set_rng_type(dev, rng_type); 118 if(res != RES_OK) goto error; 119 120 exit: 121 if(out_dev) *out_dev = dev; 122 return res; 123 error: 124 if(dev) { 125 SMC(device_ref_put(dev)); 126 dev = NULL; 127 } 128 goto exit; 129 } 130 131 res_T 132 smc_device_ref_get(struct smc_device* dev) 133 { 134 if(!dev) return RES_BAD_ARG; 135 ref_get(&dev->ref); 136 return RES_OK; 137 } 138 139 res_T 140 smc_device_ref_put(struct smc_device* dev) 141 { 142 if(!dev) return RES_BAD_ARG; 143 ref_put(&dev->ref, device_release); 144 return RES_OK; 145 } 146 147 res_T 148 smc_device_set_rng_type(struct smc_device* dev, const enum ssp_rng_type type) 149 { 150 size_t i; 151 res_T res = RES_OK; 152 struct ssp_rng_proxy* proxy = NULL; 153 struct ssp_rng** rngs = NULL; 154 155 if(!dev || type == SSP_RNG_TYPE_NULL) { 156 /* Skip the error block */ 157 return RES_BAD_ARG; 158 } 159 160 proxy = dev->rng_proxy; 161 rngs = dev->rngs; 162 dev->rng_proxy = NULL; 163 dev->rngs = NULL; 164 165 /* Create the new rng_proxy */ 166 res = ssp_rng_proxy_create 167 (dev->allocator, type, dev->nthreads, &dev->rng_proxy); 168 if(res != RES_OK) goto error; 169 170 /* Create the new per thread rng */ 171 dev->rngs = sa_add(dev->rngs, dev->nthreads); 172 memset(dev->rngs, 0, dev->nthreads*sizeof(struct ssp_rng*)); 173 FOR_EACH(i, 0, dev->nthreads) { 174 res = ssp_rng_proxy_create_rng(dev->rng_proxy, i, dev->rngs + i); 175 if(res != RES_OK) goto error; 176 } 177 178 /* Release the previous RNG data structure */ 179 if(proxy) SSP(rng_proxy_ref_put(proxy)); 180 if(rngs) { 181 FOR_EACH(i, 0, dev->nthreads) { 182 if(rngs[i]) SSP(rng_ref_put(rngs[i])); 183 } 184 sa_release(rngs); 185 } 186 187 exit: 188 return res; 189 error: 190 /* Restore the previous RNG type */ 191 release_rngs(dev); 192 dev->rng_proxy = proxy; 193 dev->rngs = rngs; 194 goto exit; 195 } 196 197 res_T 198 smc_device_get_rng_type(struct smc_device* dev, enum ssp_rng_type* type) 199 { 200 if(!dev || !type) return RES_BAD_ARG; 201 return ssp_rng_proxy_get_type(dev->rng_proxy, type); 202 } 203 204 res_T 205 smc_device_get_threads_count(const struct smc_device* dev, unsigned* nthreads) 206 { 207 if(!dev || !nthreads) return RES_BAD_ARG; 208 ASSERT(dev->nthreads <= UINT_MAX); 209 *nthreads = (unsigned)dev->nthreads; 210 return RES_OK; 211 } 212 213 214 /******************************************************************************* 215 * Local functions 216 ******************************************************************************/ 217 void 218 log_info(struct smc_device* smc, const char* msg, ...) 219 { 220 va_list vargs_list; 221 ASSERT(smc && msg); 222 va_start(vargs_list, msg); 223 log_msg(smc, LOG_OUTPUT, msg, vargs_list); 224 va_end(vargs_list); 225 } 226 227 void 228 log_err(struct smc_device* smc, const char* msg, ...) 229 { 230 va_list vargs_list; 231 ASSERT(smc && msg); 232 va_start(vargs_list, msg); 233 log_msg(smc, LOG_ERROR, msg, vargs_list); 234 va_end(vargs_list); 235 } 236 237 void 238 log_warn(struct smc_device* smc, const char* msg, ...) 239 { 240 va_list vargs_list; 241 ASSERT(smc && msg); 242 va_start(vargs_list, msg); 243 log_msg(smc, LOG_WARNING, msg, vargs_list); 244 va_end(vargs_list); 245 }