// Parallel PopGen Package
// shared.cuh
1 /*
2  * shared.cuh
3  *
4  * Author: David Lawrie
 * CUDA error-checking, device-selection, and random-number helpers shared by go_fish and sfs
6  */
7 
8 #ifndef SHARED_CUH_
9 #define SHARED_CUH_
10 
11 //includes below in sfs & go_fish
12 #include <cuda_runtime.h>
13 #include "../_outside_libraries/helper_math.h"
14 #include <limits.h>
15 #include <math.h>
16 #include <iostream>
17 #include <stdio.h>
18 #include <stdlib.h>
19 #include "../_outside_libraries/Random123/philox.h"
20 #include "../_outside_libraries/Random123/features/compilerfeatures.h"
21 
22 /* ----- cuda error checking & device setting ----- */
#define __DEBUG__ false //when true, cudaCheckErrorsAsync also synchronizes the device after the checked call
/* cudaCheckErrors(call, generation, population): evaluates `call` (a cudaError_t expression); if it is not
 * cudaSuccess, prints the error code and string, file, line, and the generation/population tags (-1 when not
 * applicable) to stderr, then calls exit(1).
 * The macro's locals use unusual names so that caller arguments named e.g. `e`, `g` or `p` cannot be captured
 * by the macro's own declarations (a capture would silently read an uninitialized local).
 * Kept as a braced block (not do/while(0)) so existing call sites with or without a trailing ';' still compile. */
#define cudaCheckErrors(expr1,expr2,expr3) { \
	cudaError_t cce_err_ = (expr1); int cce_gen_ = (expr2); int cce_pop_ = (expr3); \
	if (cce_err_ != cudaSuccess) { fprintf(stderr,"error %d %s\tfile %s\tline %d\tgeneration %d\t population %d\n", cce_err_, cudaGetErrorString(cce_err_),__FILE__,__LINE__, cce_gen_,cce_pop_); exit(1); } }
/* cudaCheckErrorsAsync: same as cudaCheckErrors; when __DEBUG__ is true it then also runs cudaDeviceSynchronize()
 * and checks that result, so errors from asynchronous work are reported here (expr2/expr3 are evaluated again). */
#define cudaCheckErrorsAsync(expr1,expr2,expr3) { cudaCheckErrors(expr1,expr2,expr3); if(__DEBUG__){ cudaCheckErrors(cudaDeviceSynchronize(),expr2,expr3); } }
26 
/* Selects the CUDA device to run on and returns its properties.
 * cuda_device (in/out): if it is a valid index (0 <= cuda_device < device count) that device is made current;
 *   otherwise the selection is left to the driver/runtime default. On return it holds the index of the device
 *   actually in use.
 * Any failing CUDA call prints a diagnostic and exits (via cudaCheckErrors).
 * Plain (non-Async) checks are used: these calls launch no work, and a debug-mode cudaDeviceSynchronize() issued
 * before cudaSetDevice() can initialize a context on the default device instead of the requested one. */
__forceinline__ cudaDeviceProp set_cuda_device(int & cuda_device){
	int device_count = 0;
	cudaCheckErrors(cudaGetDeviceCount(&device_count),-1,-1);
	if(cuda_device >= 0 && cuda_device < device_count){ cudaCheckErrors(cudaSetDevice(cuda_device),-1,-1); } //unless user specifies, driver selects the GPU to run on
	int current_device = 0;
	cudaCheckErrors(cudaGetDevice(&current_device),-1,-1);
	cudaDeviceProp properties;
	cudaCheckErrors(cudaGetDeviceProperties(&properties, current_device),-1,-1);
	cuda_device = current_device;
	return properties;
}
38 
39 /* ----- end cuda error checking ----- */
40 
41 /* ----- random number generation ----- */
42 
namespace RNG{
// NOTE: these are preprocessor macros, so despite their placement they are not scoped to namespace RNG.
#define RNG_MEAN_BOUNDARY_NORM 6 //at/below this mean (or within this distance of N) the inverse-CDF search is used instead of the normal approximation
#define RNG_N_BOUNDARY_POIS_BINOM 100 //binomial calculation starts to become numerically unstable for large values of N, not sure where that starts but is between 200 and 200,000

// uint_float_01: maps a 32-bit unsigned integer to a float in (0, 1]
// (mostly taken from Random123's "uniform.hpp").
// The input is multiplied by 2^-32 and 2^-33 is added; a good compiler reduces this to an
// int-to-float conversion followed by a multiply and an add (possibly fused).
// If the input is uniformly distributed, the result is (approximately) uniform on (0, 1].
// - The result is never exactly 0.0; the smallest value returned is 2^-33 (input 0).
// - float has 24 mantissa bits, fewer than the 32 input bits, so the largest inputs round to
//   exactly 1.0f. Anything fed to normcdfinv must therefore be prepared for +infinity.
__host__ __device__ __forceinline__ float uint_float_01(unsigned int in){
	R123_CONSTEXPR float factor = float(1.)/(UINT_MAX + float(1.));
	R123_CONSTEXPR float halffactor = float(0.5)*factor;
	return in*factor + halffactor;
}

// Philox: one 10-round Philox4x32 draw returning four 32-bit words.
// The key is (seed.x, seed.y); the counter is (k, step, population, round), so each distinct
// tuple selects a different output block.
__host__ __device__ __forceinline__ uint4 Philox(int2 seed, int k, int step, int population, int round){
	typedef r123::Philox4x32_R<10> P; //can change the 10 rounds of bijection down to 8 (lowest safe limit) to get possible extra speed!
	P rng;

	// key/counter words are unsigned 32-bit: the explicit casts keep the same bit patterns
	// while avoiding narrowing conversions inside brace-initialization
	P::key_type key = {{(unsigned int)seed.x, (unsigned int)seed.y}}; //random int to set key space + seed
	P::ctr_type count = {{(unsigned int)k, (unsigned int)step, (unsigned int)population, (unsigned int)round}};

	// copy the output words out explicitly instead of type-punning through a union
	const P::ctr_type out = rng(count, key);
	return make_uint4(out.v[0], out.v[1], out.v[2], out.v[3]);
}

// Mean-dependent cutoffs shared by binomcdfinv and poiscdfinv: after index j has been evaluated,
// the search stops and returns j if mean <= the returned bound. 0 means "no cutoff at this j"
// (every real bound is >= 1). Original tuning notes: 17 suffices for mean <= 3, 24 for mean <= 6,
// 32 for mean <= 10, 36 for mean <= 12, 41 for mean <= 15, 58 for mean <= 25, 70 for mean <= 33.
// The cutoffs also guarantee termination when float rounding keeps the running cdf below r
// (e.g. r == 1.0f, which uint_float_01 can return).
__host__ __device__ __forceinline__ int cdfinv_cutoff_mean(int j){
	switch(j){
		case 11: return 1;  case 14: return 2;  case 17: return 3;  case 20: return 4;
		case 22: return 5;  case 24: return 6;  case 26: return 7;  case 28: return 8;
		case 30: return 9;  case 32: return 10; case 34: return 11; case 36: return 12;
		case 38: return 13; case 40: return 14; case 41: return 15; case 43: return 16;
		case 45: return 17; case 46: return 18; case 48: return 19; case 50: return 20;
		case 51: return 21; case 53: return 22; case 55: return 23; case 56: return 24;
		case 58: return 25; case 59: return 26; case 61: return 27; case 62: return 28;
		case 64: return 29; case 65: return 30; case 67: return 31; case 68: return 32;
		default: return 0;
	}
}

// One binomial pmf step: emu becomes P(j) from P(j-1) via the ratio (n+1-j)/j * x/(1-x),
// and it is added to the running cdf.
__host__ __device__ __forceinline__ void binom_iter(float j, float x, float n, float & emu, float & cdf){
	emu *= ((n+1.f-j)*x)/(j*(1-x));
	cdf += emu;
}

// Inverse-CDF sample of Binomial(n, x) for uniform r: returns the first j with cdf(j) >= r, subject
// to the shared mean cutoffs (mean is expected to be n*x, as callers in this file pass it).
// If (1-x)^n rounds to exactly 1 in float, P(0) is replaced by its Poisson approximation exp(-mean).
// The result is clamped to [0, n]: rounding or a cutoff could otherwise return an impossible
// count above n.
__host__ __device__ __forceinline__ int binomcdfinv(float r, float mean, float x, float n){
	if(x >= 1.f){ return (int)n; } // every trial succeeds; the pmf ratio would divide by zero (NaN)
	float emu = powf(1-x,n);
	if(emu == 1) { emu = expf(-1 * mean); }
	float cdf = emu;
	if(cdf >= r){ return 0; }

	int j = 1; // after a full pass j == 70, the original final fallback value
	#pragma unroll // keeps the code generation of the former hand-unrolled chain
	for(; j < 70; ++j){
		binom_iter((float)j, x, n, emu, cdf);
		const int give_up = cdfinv_cutoff_mean(j);
		if(cdf >= r || (give_up > 0 && mean <= give_up)){ break; }
	}
	return (j > n) ? (int)n : j;
}

// One Poisson pmf step: emu becomes P(j) from P(j-1) via the ratio mean/j, added to the running cdf.
__host__ __device__ __forceinline__ void pois_iter(float j, float mean, float & emu, float & cdf){
	emu *= mean/j;
	cdf += emu;
}

// Inverse-CDF sample of Poisson(mean) for uniform r: returns the first j with cdf(j) >= r,
// subject to the shared mean cutoffs; 70 if no index up to 69 qualifies.
__host__ __device__ __forceinline__ int poiscdfinv(float r, float mean){
	float emu = expf(-1 * mean);
	float cdf = emu;
	if(cdf >= r){ return 0; }

	#pragma unroll // keeps the code generation of the former hand-unrolled chain
	for(int j = 1; j < 70; ++j){
		pois_iter((float)j, mean, emu, cdf);
		const int give_up = cdfinv_cutoff_mean(j);
		if(cdf >= r || (give_up > 0 && mean <= give_up)){ return j; }
	}
	return 70;
}

// Rounds a normal-approximation draw to the nearest count in [0, N]. Clamping before the float->int
// conversion keeps it defined: uint_float_01 can yield exactly 1.0f, for which normcdfinv is
// +infinity (and a NaN from a negative var becomes 0).
__host__ __device__ __forceinline__ int round_to_count(float v, float N){
	return (int)roundf(fminf(fmaxf(v, 0.f), N));
}

// Approximate Poisson count from word x of Philox(seed, id, generation, population, 0).
// mean <= RNG_MEAN_BOUNDARY_NORM: exact Poisson inverse CDF; mean >= N - RNG_MEAN_BOUNDARY_NORM:
// N minus a Poisson draw with mean N-mean; otherwise a normal approximation with variance var,
// rounded and clamped to [0, N]. p is unused.
__host__ __device__ __forceinline__ int ApproxRandPois1(float mean, float var, float p, float N, int2 seed, int id, int generation, int population){
	uint4 i = Philox(seed, id, generation, population, 0);
	const float r = uint_float_01(i.x);
	if(mean <= RNG_MEAN_BOUNDARY_NORM){ return poiscdfinv(r, mean); }
	if(mean >= N-RNG_MEAN_BOUNDARY_NORM){ return N - poiscdfinv(r, N-mean); } //flip side of poisson, when 1-p is small
	return round_to_count(normcdfinv(r)*sqrtf(var)+mean, N);
}

// Approximate Binomial(N, mean/N) count from word x of Philox(seed, id, generation, population, 0).
// Near either boundary the exact binomial inverse CDF is used for N < RNG_N_BOUNDARY_POIS_BINOM and
// the Poisson one otherwise (mirrored as N - draw near N); in between, a normal approximation with
// variance var, rounded and clamped to [0, N]. p is unused.
__host__ __device__ __forceinline__ int ApproxRandBinom1(float mean, float var, float p, float N, int2 seed, int id, int generation, int population){
	uint4 i = Philox(seed, id, generation, population, 0);
	const float r = uint_float_01(i.x);
	if(mean <= RNG_MEAN_BOUNDARY_NORM){
		if(N < RNG_N_BOUNDARY_POIS_BINOM){ return binomcdfinv(r, mean, mean/N, N); }
		return poiscdfinv(r, mean);
	}
	if(mean >= N-RNG_MEAN_BOUNDARY_NORM){ //flip side of binomial, when 1-p is small
		if(N < RNG_N_BOUNDARY_POIS_BINOM){ return N - binomcdfinv(r, N-mean, (N-mean)/N, N); }
		return N - poiscdfinv(r, N-mean);
	}
	return round_to_count(normcdfinv(r)*sqrtf(var)+mean, N);
}

//faster if NOT inlined (per the original author's measurements on both GPUs tested); defined out of line elsewhere
__device__ int ApproxRandBinomHelper(unsigned int i, float mean, float var, float N);

// Four binomial approximations from a single Philox call, one output word per component. p is unused.
__device__ __forceinline__ int4 ApproxRandBinom4(float4 mean, float4 var, float4 p, float N, int2 seed, int id, int generation, int population){
	uint4 i = Philox(seed, id, generation, population, 0);
	return make_int4(ApproxRandBinomHelper(i.x, mean.x, var.x, N), ApproxRandBinomHelper(i.y, mean.y, var.y, N), ApproxRandBinomHelper(i.z, mean.z, var.z, N), ApproxRandBinomHelper(i.w, mean.w, var.w, N));
}
/* ----- end random number generation ----- */

} /* ----- end namespace RNG ----- */
279 
280 #endif /* SHARED_CUH_ */
Definition: shared.cuh:43