#include <cuda_runtime.h>
#include "../_outside_libraries/helper_math.h"
#include <limits.h>   // UINT_MAX
#include <math.h>     // powf, expf, sqrtf
#include <stdio.h>    // fprintf, stderr
#include <stdlib.h>   // exit
#include "../_outside_libraries/Random123/philox.h"
#include "../_outside_libraries/Random123/features/compilerfeatures.h"

/* ----- CUDA error checking & device selection ----- */

// When true, cudaCheckErrorsAsync additionally checks the result of
// cudaDeviceSynchronize() after the wrapped call.
#define __DEBUG__ false

// Evaluates expr1 (a cudaError_t). On anything other than cudaSuccess, prints
// the error code/string, the file and line of the macro use, and the supplied
// generation (expr2) and population (expr3) to stderr, then calls exit(1).
// The temporaries carry a long suffix-underscore name so that arguments which
// are themselves variables called e, g or p are not shadowed by the macro.
#define cudaCheckErrors(expr1,expr2,expr3) { \
    cudaError_t cuda_check_err_ = (expr1); \
    int cuda_check_gen_ = (expr2); \
    int cuda_check_pop_ = (expr3); \
    if (cuda_check_err_ != cudaSuccess) { \
        fprintf(stderr,"error %d %s\tfile %s\tline %d\tgeneration %d\t population %d\n", cuda_check_err_, cudaGetErrorString(cuda_check_err_),__FILE__,__LINE__, cuda_check_gen_,cuda_check_pop_); \
        exit(1); \
    } \
}

// Same as cudaCheckErrors; when __DEBUG__ is true it also runs
// cudaDeviceSynchronize() and checks its result. Note that expr2 and expr3 are
// then evaluated twice, so they should be free of side effects.
#define cudaCheckErrorsAsync(expr1,expr2,expr3) { \
    cudaCheckErrors(expr1,expr2,expr3); \
    if(__DEBUG__){ cudaCheckErrors(cudaDeviceSynchronize(),expr2,expr3); } \
}

/**
 * Selects the CUDA device to run on and reports its properties.
 *
 * @param cuda_device  In: requested device index. If it lies in
 *                     [0, device count) that device is made current with
 *                     cudaSetDevice; any other value leaves the runtime's
 *                     current device untouched.
 *                     Out: index of the device that is current afterwards.
 * @return             cudaDeviceProp of the device that is current afterwards.
 *
 * Any failing CUDA call prints a diagnostic and terminates the process via
 * cudaCheckErrors (generation/population are reported as -1).
 */
__forceinline__ cudaDeviceProp set_cuda_device(int & cuda_device){
    int cudaDeviceCount = 0;
    cudaCheckErrorsAsync(cudaGetDeviceCount(&cudaDeviceCount),-1,-1);

    // Only honour the request when it names an existing device.
    if(cuda_device >= 0 && cuda_device < cudaDeviceCount){
        cudaCheckErrors(cudaSetDevice(cuda_device),-1,-1);
    }

    int myDevice = 0;
    cudaCheckErrorsAsync(cudaGetDevice(&myDevice),-1,-1);

    cudaDeviceProp devProp;
    cudaCheckErrors(cudaGetDeviceProperties(&devProp, myDevice),-1,-1);

    // Tell the caller which device was actually selected.
    cuda_device = myDevice;
    return devProp;
}
/* ----- random number generation ----- */

// ApproxRandPois1 / ApproxRandBinom1 use an inverse-CDF walk when the mean (or
// N - mean) is at or below this value, and a normal approximation otherwise.
#define RNG_MEAN_BOUNDARY_NORM 6
// ApproxRandBinom1 uses the binomial inverse CDF only for N below this value
// and the Poisson inverse CDF otherwise. The binomial calculation starts to
// become numerically unstable for large values of N; not sure where that
// starts, but it is between 200 and 200,000.
#define RNG_N_BOUNDARY_POIS_BINOM 100

/**
 * Maps an unsigned integer to a float by computing
 * in * 1/(UINT_MAX + 1) + 0.5/(UINT_MAX + 1).
 *
 * With a 32-bit unsigned int this is in * 2^-32 + 2^-33. The result is never
 * 0; because of float rounding the largest inputs map to exactly 1.0f, so
 * callers must tolerate a return value of 1.
 *
 * @param in  integer to convert (e.g. one word of Philox output)
 * @return    float in (0, 1]
 */
__host__ __device__ __forceinline__ float uint_float_01(unsigned int in){
    // Plain const keeps this helper independent of the Random123 feature
    // macros; both values are compile-time constants either way.
    const float factor = 1.0f/(UINT_MAX + 1.0f);
    const float halffactor = 0.5f*factor;
    return in*factor + halffactor;
}
/**
 * Counter-based random number generator: one call to the 10-round
 * Philox4x32 bijection from Random123.
 *
 * The key is built from seed (x, y) and the counter from
 * (k, step, population, round), so identical arguments always yield the same
 * four words and no generator state is carried between calls.
 *
 * @param seed        two integers forming the Philox key
 * @param k           counter word 0
 * @param step        counter word 1
 * @param population  counter word 2
 * @param round       counter word 3
 * @return            the four 32-bit output words as a uint4 (x, y, z, w)
 */
__host__ __device__ __forceinline__ uint4 Philox(int2 seed, int k, int step, int population, int round){
    typedef r123::Philox4x32_R<10> P;
    P rng;

    // Explicit casts: the key/counter words are unsigned, and int -> unsigned
    // inside a brace initialiser would otherwise be a narrowing conversion.
    P::key_type key = {{static_cast<unsigned int>(seed.x), static_cast<unsigned int>(seed.y)}};
    P::ctr_type count = {{static_cast<unsigned int>(k), static_cast<unsigned int>(step),
                          static_cast<unsigned int>(population), static_cast<unsigned int>(round)}};

    // Copy the words out one by one instead of type-punning through a union.
    const P::ctr_type out = rng(count, key);
    return make_uint4(out.v[0], out.v[1], out.v[2], out.v[3]);
}
/**
 * Advances a running binomial pmf/cdf pair from outcome j-1 to outcome j.
 *
 * Uses the ratio of consecutive binomial probabilities,
 * P(j) = P(j-1) * (n + 1 - j) * x / (j * (1 - x)),
 * then adds the new term to the running cumulative sum.
 *
 * @param j    outcome being advanced to (1, 2, 3, ...)
 * @param x    per-trial success probability
 * @param n    number of trials
 * @param emu  in: P(j-1); out: P(j)
 * @param cdf  in: P(X <= j-1); out: P(X <= j)
 */
__host__ __device__ __forceinline__ void binom_iter(float j, float x, float n, float & emu, float & cdf){
    emu *= ((n+1.f-j)*x)/(j*(1-x));
    // Callers test cdf after every step, so the new term must be accumulated.
    cdf += emu;
}
/**
 * Inverse binomial CDF by sequential search: returns the smallest j whose
 * cumulative probability reaches r.
 *
 * The search is fully unrolled and bounded. For small means it is cut off
 * early (for example at 11 when mean <= 1 and at 24 when mean <= 6), and it
 * never returns more than 70. These caps guarantee termination when r is at
 * or extremely close to 1 and the float cdf can never reach it.
 *
 * @param r     uniform variate, expected in (0, 1]
 * @param mean  n * x, used for the early cut-offs and the small-x fallback
 * @param x     per-trial success probability
 * @param n     number of trials
 * @return      sampled count in [0, 70]
 */
__host__ __device__ __forceinline__ int binomcdfinv(float r, float mean, float x, float n){
    float emu = powf(1-x,n);              // P(X = 0)
    // If (1-x)^n rounded to exactly 1 (x too small for float), fall back to
    // the Poisson zero term e^-mean.
    if(emu == 1) { emu = expf(-1 * mean); }
    float cdf = emu;                      // running P(X <= j), starting at j = 0
    if(cdf >= r){ return 0; }

    binom_iter(1.f, x, n, emu, cdf); if(cdf >= r){ return 1; }
    binom_iter(2.f, x, n, emu, cdf); if(cdf >= r){ return 2; }
    binom_iter(3.f, x, n, emu, cdf); if(cdf >= r){ return 3; }
    binom_iter(4.f, x, n, emu, cdf); if(cdf >= r){ return 4; }
    binom_iter(5.f, x, n, emu, cdf); if(cdf >= r){ return 5; }
    binom_iter(6.f, x, n, emu, cdf); if(cdf >= r){ return 6; }
    binom_iter(7.f, x, n, emu, cdf); if(cdf >= r){ return 7; }
    binom_iter(8.f, x, n, emu, cdf); if(cdf >= r){ return 8; }
    binom_iter(9.f, x, n, emu, cdf); if(cdf >= r){ return 9; }
    binom_iter(10.f, x, n, emu, cdf); if(cdf >= r){ return 10; }
    binom_iter(11.f, x, n, emu, cdf); if(cdf >= r || mean <= 1){ return 11; }
    binom_iter(12.f, x, n, emu, cdf); if(cdf >= r){ return 12; }
    binom_iter(13.f, x, n, emu, cdf); if(cdf >= r){ return 13; }
    binom_iter(14.f, x, n, emu, cdf); if(cdf >= r || mean <= 2){ return 14; }
    binom_iter(15.f, x, n, emu, cdf); if(cdf >= r){ return 15; }
    binom_iter(16.f, x, n, emu, cdf); if(cdf >= r){ return 16; }
    binom_iter(17.f, x, n, emu, cdf); if(cdf >= r || mean <= 3){ return 17; }
    binom_iter(18.f, x, n, emu, cdf); if(cdf >= r){ return 18; }
    binom_iter(19.f, x, n, emu, cdf); if(cdf >= r){ return 19; }
    binom_iter(20.f, x, n, emu, cdf); if(cdf >= r || mean <= 4){ return 20; }
    binom_iter(21.f, x, n, emu, cdf); if(cdf >= r){ return 21; }
    binom_iter(22.f, x, n, emu, cdf); if(cdf >= r || mean <= 5){ return 22; }
    binom_iter(23.f, x, n, emu, cdf); if(cdf >= r){ return 23; }
    binom_iter(24.f, x, n, emu, cdf); if(cdf >= r || mean <= 6){ return 24; }
    binom_iter(25.f, x, n, emu, cdf); if(cdf >= r){ return 25; }
    binom_iter(26.f, x, n, emu, cdf); if(cdf >= r || mean <= 7){ return 26; }
    binom_iter(27.f, x, n, emu, cdf); if(cdf >= r){ return 27; }
    binom_iter(28.f, x, n, emu, cdf); if(cdf >= r || mean <= 8){ return 28; }
    binom_iter(29.f, x, n, emu, cdf); if(cdf >= r){ return 29; }
    binom_iter(30.f, x, n, emu, cdf); if(cdf >= r || mean <= 9){ return 30; }
    binom_iter(31.f, x, n, emu, cdf); if(cdf >= r){ return 31; }
    binom_iter(32.f, x, n, emu, cdf); if(cdf >= r || mean <= 10){ return 32; }
    binom_iter(33.f, x, n, emu, cdf); if(cdf >= r){ return 33; }
    binom_iter(34.f, x, n, emu, cdf); if(cdf >= r || mean <= 11){ return 34; }
    binom_iter(35.f, x, n, emu, cdf); if(cdf >= r){ return 35; }
    binom_iter(36.f, x, n, emu, cdf); if(cdf >= r || mean <= 12){ return 36; }
    binom_iter(37.f, x, n, emu, cdf); if(cdf >= r){ return 37; }
    binom_iter(38.f, x, n, emu, cdf); if(cdf >= r || mean <= 13){ return 38; }
    binom_iter(39.f, x, n, emu, cdf); if(cdf >= r){ return 39; }
    binom_iter(40.f, x, n, emu, cdf); if(cdf >= r || mean <= 14){ return 40; }
    binom_iter(41.f, x, n, emu, cdf); if(cdf >= r || mean <= 15){ return 41; }
    binom_iter(42.f, x, n, emu, cdf); if(cdf >= r){ return 42; }
    binom_iter(43.f, x, n, emu, cdf); if(cdf >= r || mean <= 16){ return 43; }
    binom_iter(44.f, x, n, emu, cdf); if(cdf >= r){ return 44; }
    binom_iter(45.f, x, n, emu, cdf); if(cdf >= r || mean <= 17){ return 45; }
    binom_iter(46.f, x, n, emu, cdf); if(cdf >= r || mean <= 18){ return 46; }
    binom_iter(47.f, x, n, emu, cdf); if(cdf >= r){ return 47; }
    binom_iter(48.f, x, n, emu, cdf); if(cdf >= r || mean <= 19){ return 48; }
    binom_iter(49.f, x, n, emu, cdf); if(cdf >= r){ return 49; }
    binom_iter(50.f, x, n, emu, cdf); if(cdf >= r || mean <= 20){ return 50; }
    binom_iter(51.f, x, n, emu, cdf); if(cdf >= r || mean <= 21){ return 51; }
    binom_iter(52.f, x, n, emu, cdf); if(cdf >= r){ return 52; }
    binom_iter(53.f, x, n, emu, cdf); if(cdf >= r || mean <= 22){ return 53; }
    binom_iter(54.f, x, n, emu, cdf); if(cdf >= r){ return 54; }
    binom_iter(55.f, x, n, emu, cdf); if(cdf >= r || mean <= 23){ return 55; }
    binom_iter(56.f, x, n, emu, cdf); if(cdf >= r || mean <= 24){ return 56; }
    binom_iter(57.f, x, n, emu, cdf); if(cdf >= r){ return 57; }
    binom_iter(58.f, x, n, emu, cdf); if(cdf >= r || mean <= 25){ return 58; }
    binom_iter(59.f, x, n, emu, cdf); if(cdf >= r || mean <= 26){ return 59; }
    binom_iter(60.f, x, n, emu, cdf); if(cdf >= r){ return 60; }
    binom_iter(61.f, x, n, emu, cdf); if(cdf >= r || mean <= 27){ return 61; }
    binom_iter(62.f, x, n, emu, cdf); if(cdf >= r || mean <= 28){ return 62; }
    binom_iter(63.f, x, n, emu, cdf); if(cdf >= r){ return 63; }
    binom_iter(64.f, x, n, emu, cdf); if(cdf >= r || mean <= 29){ return 64; }
    binom_iter(65.f, x, n, emu, cdf); if(cdf >= r || mean <= 30){ return 65; }
    binom_iter(66.f, x, n, emu, cdf); if(cdf >= r){ return 66; }
    binom_iter(67.f, x, n, emu, cdf); if(cdf >= r || mean <= 31){ return 67; }
    binom_iter(68.f, x, n, emu, cdf); if(cdf >= r || mean <= 32){ return 68; }
    binom_iter(69.f, x, n, emu, cdf); if(cdf >= r){ return 69; }

    // Hard ceiling: reached only when the cdf never got to r within 69 steps.
    return 70;
}
/**
 * Advances a running Poisson pmf/cdf pair from outcome j-1 to outcome j.
 *
 * Uses the ratio of consecutive Poisson probabilities,
 * P(j) = P(j-1) * mean / j, then adds the new term to the running cumulative
 * sum. This mirrors binom_iter, which does the same for the binomial case.
 *
 * @param j     outcome being advanced to (1, 2, 3, ...)
 * @param mean  Poisson mean
 * @param emu   in: P(j-1); out: P(j)
 * @param cdf   in: P(X <= j-1); out: P(X <= j)
 */
__host__ __device__ __forceinline__ void pois_iter(float j, float mean, float & emu, float & cdf){
    emu *= mean/j;
    // Callers test cdf after every step, so the new term must be accumulated.
    cdf += emu;
}
/**
 * Inverse Poisson CDF by sequential search: returns the smallest j whose
 * cumulative probability reaches r.
 *
 * The search is fully unrolled and bounded. For small means it is cut off
 * early (for example at 11 when mean <= 1 and at 24 when mean <= 6), and it
 * never returns more than 70. These caps guarantee termination when r is at
 * or extremely close to 1 and the float cdf can never reach it.
 *
 * @param r     uniform variate, expected in (0, 1]
 * @param mean  Poisson mean
 * @return      sampled count in [0, 70]
 */
__host__ __device__ __forceinline__ int poiscdfinv(float r, float mean){
    float emu = expf(-1 * mean);          // P(X = 0)
    float cdf = emu;                      // running P(X <= j), starting at j = 0
    if(cdf >= r){ return 0; }

    pois_iter(1.f, mean, emu, cdf); if(cdf >= r){ return 1; }
    pois_iter(2.f, mean, emu, cdf); if(cdf >= r){ return 2; }
    pois_iter(3.f, mean, emu, cdf); if(cdf >= r){ return 3; }
    pois_iter(4.f, mean, emu, cdf); if(cdf >= r){ return 4; }
    pois_iter(5.f, mean, emu, cdf); if(cdf >= r){ return 5; }
    pois_iter(6.f, mean, emu, cdf); if(cdf >= r){ return 6; }
    pois_iter(7.f, mean, emu, cdf); if(cdf >= r){ return 7; }
    pois_iter(8.f, mean, emu, cdf); if(cdf >= r){ return 8; }
    pois_iter(9.f, mean, emu, cdf); if(cdf >= r){ return 9; }
    pois_iter(10.f, mean, emu, cdf); if(cdf >= r){ return 10; }
    pois_iter(11.f, mean, emu, cdf); if(cdf >= r || mean <= 1){ return 11; }
    pois_iter(12.f, mean, emu, cdf); if(cdf >= r){ return 12; }
    pois_iter(13.f, mean, emu, cdf); if(cdf >= r){ return 13; }
    pois_iter(14.f, mean, emu, cdf); if(cdf >= r || mean <= 2){ return 14; }
    pois_iter(15.f, mean, emu, cdf); if(cdf >= r){ return 15; }
    pois_iter(16.f, mean, emu, cdf); if(cdf >= r){ return 16; }
    pois_iter(17.f, mean, emu, cdf); if(cdf >= r || mean <= 3){ return 17; }
    pois_iter(18.f, mean, emu, cdf); if(cdf >= r){ return 18; }
    pois_iter(19.f, mean, emu, cdf); if(cdf >= r){ return 19; }
    pois_iter(20.f, mean, emu, cdf); if(cdf >= r || mean <= 4){ return 20; }
    pois_iter(21.f, mean, emu, cdf); if(cdf >= r){ return 21; }
    pois_iter(22.f, mean, emu, cdf); if(cdf >= r || mean <= 5){ return 22; }
    pois_iter(23.f, mean, emu, cdf); if(cdf >= r){ return 23; }
    pois_iter(24.f, mean, emu, cdf); if(cdf >= r || mean <= 6){ return 24; }
    pois_iter(25.f, mean, emu, cdf); if(cdf >= r){ return 25; }
    pois_iter(26.f, mean, emu, cdf); if(cdf >= r || mean <= 7){ return 26; }
    pois_iter(27.f, mean, emu, cdf); if(cdf >= r){ return 27; }
    pois_iter(28.f, mean, emu, cdf); if(cdf >= r || mean <= 8){ return 28; }
    pois_iter(29.f, mean, emu, cdf); if(cdf >= r){ return 29; }
    pois_iter(30.f, mean, emu, cdf); if(cdf >= r || mean <= 9){ return 30; }
    pois_iter(31.f, mean, emu, cdf); if(cdf >= r){ return 31; }
    pois_iter(32.f, mean, emu, cdf); if(cdf >= r || mean <= 10){ return 32; }
    pois_iter(33.f, mean, emu, cdf); if(cdf >= r){ return 33; }
    pois_iter(34.f, mean, emu, cdf); if(cdf >= r || mean <= 11){ return 34; }
    pois_iter(35.f, mean, emu, cdf); if(cdf >= r){ return 35; }
    pois_iter(36.f, mean, emu, cdf); if(cdf >= r || mean <= 12){ return 36; }
    pois_iter(37.f, mean, emu, cdf); if(cdf >= r){ return 37; }
    pois_iter(38.f, mean, emu, cdf); if(cdf >= r || mean <= 13){ return 38; }
    pois_iter(39.f, mean, emu, cdf); if(cdf >= r){ return 39; }
    pois_iter(40.f, mean, emu, cdf); if(cdf >= r || mean <= 14){ return 40; }
    pois_iter(41.f, mean, emu, cdf); if(cdf >= r || mean <= 15){ return 41; }
    pois_iter(42.f, mean, emu, cdf); if(cdf >= r){ return 42; }
    pois_iter(43.f, mean, emu, cdf); if(cdf >= r || mean <= 16){ return 43; }
    pois_iter(44.f, mean, emu, cdf); if(cdf >= r){ return 44; }
    pois_iter(45.f, mean, emu, cdf); if(cdf >= r || mean <= 17){ return 45; }
    pois_iter(46.f, mean, emu, cdf); if(cdf >= r || mean <= 18){ return 46; }
    pois_iter(47.f, mean, emu, cdf); if(cdf >= r){ return 47; }
    pois_iter(48.f, mean, emu, cdf); if(cdf >= r || mean <= 19){ return 48; }
    pois_iter(49.f, mean, emu, cdf); if(cdf >= r){ return 49; }
    pois_iter(50.f, mean, emu, cdf); if(cdf >= r || mean <= 20){ return 50; }
    pois_iter(51.f, mean, emu, cdf); if(cdf >= r || mean <= 21){ return 51; }
    pois_iter(52.f, mean, emu, cdf); if(cdf >= r){ return 52; }
    pois_iter(53.f, mean, emu, cdf); if(cdf >= r || mean <= 22){ return 53; }
    pois_iter(54.f, mean, emu, cdf); if(cdf >= r){ return 54; }
    pois_iter(55.f, mean, emu, cdf); if(cdf >= r || mean <= 23){ return 55; }
    pois_iter(56.f, mean, emu, cdf); if(cdf >= r || mean <= 24){ return 56; }
    pois_iter(57.f, mean, emu, cdf); if(cdf >= r){ return 57; }
    pois_iter(58.f, mean, emu, cdf); if(cdf >= r || mean <= 25){ return 58; }
    pois_iter(59.f, mean, emu, cdf); if(cdf >= r || mean <= 26){ return 59; }
    pois_iter(60.f, mean, emu, cdf); if(cdf >= r){ return 60; }
    pois_iter(61.f, mean, emu, cdf); if(cdf >= r || mean <= 27){ return 61; }
    pois_iter(62.f, mean, emu, cdf); if(cdf >= r || mean <= 28){ return 62; }
    pois_iter(63.f, mean, emu, cdf); if(cdf >= r){ return 63; }
    pois_iter(64.f, mean, emu, cdf); if(cdf >= r || mean <= 29){ return 64; }
    pois_iter(65.f, mean, emu, cdf); if(cdf >= r || mean <= 30){ return 65; }
    pois_iter(66.f, mean, emu, cdf); if(cdf >= r){ return 66; }
    pois_iter(67.f, mean, emu, cdf); if(cdf >= r || mean <= 31){ return 67; }
    pois_iter(68.f, mean, emu, cdf); if(cdf >= r || mean <= 32){ return 68; }
    pois_iter(69.f, mean, emu, cdf); if(cdf >= r){ return 69; }

    // Hard ceiling: reached only when the cdf never got to r within 69 steps.
    return 70;
}
/**
 * Draws one approximately Poisson-distributed integer from a single Philox
 * word.
 *
 * Three regimes, chosen by the mean:
 *  - mean <= RNG_MEAN_BOUNDARY_NORM: Poisson inverse CDF on the mean.
 *  - mean >= N - RNG_MEAN_BOUNDARY_NORM: Poisson inverse CDF on N - mean,
 *    with the draw subtracted from N.
 *  - otherwise: normal approximation, round(z * sqrt(var) + mean).
 *
 * @param mean        expected value
 * @param var         variance used by the normal approximation
 * @param p           accepted for signature compatibility; not used here
 * @param N           upper bound used for the flipped regime
 * @param seed        Philox key
 * @param id          Philox counter word 0
 * @param generation  Philox counter word 1
 * @param population  Philox counter word 2
 * @return            the sampled value
 *
 * NOTE(review): uint_float_01 can return exactly 1.0f, for which
 * normcdfinv yields +infinity; callers presumably clamp the result - confirm.
 */
__host__ __device__ __forceinline__ int ApproxRandPois1(float mean, float var, float p, float N, int2 seed, int id, int generation, int population){
    const uint4 i = Philox(seed, id, generation, population, 0);
    const float u = uint_float_01(i.x);   // only the first word is consumed

    if(mean <= RNG_MEAN_BOUNDARY_NORM){
        return poiscdfinv(u, mean);
    }
    else if(mean >= N-RNG_MEAN_BOUNDARY_NORM){
        // Flipped side: sample the small complement and subtract from N.
        return static_cast<int>(N - poiscdfinv(u, N-mean));
    }
    return static_cast<int>(round(normcdfinv(u)*sqrtf(var)+mean));
}
/**
 * Draws one approximately binomially distributed integer from a single
 * Philox word.
 *
 * Regimes, chosen by the mean and N:
 *  - mean <= RNG_MEAN_BOUNDARY_NORM: inverse CDF on the mean - binomial with
 *    success probability mean/N when N < RNG_N_BOUNDARY_POIS_BINOM, Poisson
 *    otherwise.
 *  - mean >= N - RNG_MEAN_BOUNDARY_NORM: the same inverse CDFs applied to
 *    N - mean, with the draw subtracted from N.
 *  - otherwise: normal approximation, round(z * sqrt(var) + mean).
 *
 * @param mean        expected value
 * @param var         variance used by the normal approximation
 * @param p           accepted for signature compatibility; not used here
 * @param N           number of trials
 * @param seed        Philox key
 * @param id          Philox counter word 0
 * @param generation  Philox counter word 1
 * @param population  Philox counter word 2
 * @return            the sampled value
 *
 * NOTE(review): uint_float_01 can return exactly 1.0f, for which
 * normcdfinv yields +infinity; callers presumably clamp the result - confirm.
 */
__host__ __device__ __forceinline__ int ApproxRandBinom1(float mean, float var, float p, float N, int2 seed, int id, int generation, int population){
    const uint4 i = Philox(seed, id, generation, population, 0);
    const float u = uint_float_01(i.x);   // only the first word is consumed

    if(mean <= RNG_MEAN_BOUNDARY_NORM){
        if(N < RNG_N_BOUNDARY_POIS_BINOM){
            return binomcdfinv(u, mean, mean/N, N);
        }
        else{
            return poiscdfinv(u, mean);
        }
    }
    else if(mean >= N-RNG_MEAN_BOUNDARY_NORM){
        // Flipped side: sample the small complement and subtract from N.
        if(N < RNG_N_BOUNDARY_POIS_BINOM){
            return static_cast<int>(N - binomcdfinv(u, N-mean, (N-mean)/N, N));
        }
        else{
            return static_cast<int>(N - poiscdfinv(u, N-mean));
        }
    }
    return static_cast<int>(round(normcdfinv(u)*sqrtf(var)+mean));
}
// Forward declaration only; the definition is not in this part of the file.
// ApproxRandBinom4 calls it once per component, passing one raw Philox word
// (i) together with that component's mean and variance and the shared N.
__device__ int ApproxRandBinomHelper(unsigned int i, float mean, float var, float N);
/**
 * Draws four approximately binomially distributed integers from a single
 * Philox call, one per component of the float4 inputs.
 *
 * Each of the four Philox output words (x, y, z, w) is handed to
 * ApproxRandBinomHelper together with the matching component of mean and
 * var; N is shared by all four draws.
 *
 * @param mean        per-component expected values
 * @param var         per-component variances
 * @param p           accepted for signature compatibility; not used here
 * @param N           number of trials, common to all components
 * @param seed        Philox key
 * @param id          Philox counter word 0
 * @param generation  Philox counter word 1
 * @param population  Philox counter word 2
 * @return            the four sampled values as an int4
 */
__device__ __forceinline__ int4 ApproxRandBinom4(float4 mean, float4 var, float4 p, float N, int2 seed, int id, int generation, int population){
    const uint4 i = Philox(seed, id, generation, population, 0);
    return make_int4(ApproxRandBinomHelper(i.x, mean.x, var.x, N),
                     ApproxRandBinomHelper(i.y, mean.y, var.y, N),
                     ApproxRandBinomHelper(i.z, mean.z, var.z, N),
                     ApproxRandBinomHelper(i.w, mean.w, var.w, N));
}