00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00028 #include <stdlib.h>
00029 #include <math.h>
00030 #include <complex.h>
00031
00032 #include "nfft3util.h"
00033 #include "nfft3.h"
00034 #include "fastsum.h"
00035
/**
 * Factorial n! evaluated in double precision.
 *
 * Every argument n <= 1 (zero and negative values included) yields 1.0.
 */
double fak(int n)
{
  double product = 1.0;
  int i;

  /* Build the product upwards: 2, 3, ..., n. */
  for (i = 2; i <= n; i++)
    product *= (double)i;

  return product;
}
00047
/**
 * Binomial coefficient "n choose m" in double precision, formed from
 * factorials as n! / m! / (n-m)!.
 */
double binom(int n, int m)
{
  const double numerator = fak(n);
  const double over_m = numerator/fak(m);

  return over_m/fak(n-m);
}
00053
/**
 * Evaluates the basis polynomial
 *
 *   (xx+1)^r * (1-xx)^(m+1) / 2^(m+1) / r!
 *     * sum_{k=0}^{m-r} binom(m+k,k) * ((xx+1)/2)^k
 *
 * \param m  polynomial parameter (callers in this file pass p-1)
 * \param r  index of the basis function, also the exponent of (xx+1)
 * \param xx evaluation point
 * \return   value of the polynomial at xx
 */
double BasisPoly(int m, int r, double xx)
{
  int k;
  double sum=0.0;

  for (k=0; k<=m-r; k++) {
    sum+=binom(m+k,k)*pow((xx+1.0)/2.0,(double)k);
  }

  /* 2^(m+1) is formed with ldexp instead of the integer shift 1<<(m+1):
     the shift is undefined once m+1 reaches the width of int, whereas the
     floating point power of two stays exact. */
  return sum*pow((xx+1.0),(double)r)*pow(1.0-xx,(double)(m+1))/ldexp(1.0,m+1)/fak(r);
}
00065
00067 double _Complex regkern(kernel k, double xx, int p, const double *param, double a, double b)
00068 {
00069 int r;
00070 double _Complex sum=0.0;
00071
00072 if (xx<-0.5)
00073 xx=-0.5;
00074 if (xx>0.5)
00075 xx=0.5;
00076 if ((xx>=-0.5+b && xx<=-a) || (xx>=a && xx<=0.5-b)) {
00077 return k(xx,0,param);
00078 }
00079 else if (xx<-0.5+b) {
00080 sum=(k(-0.5,0,param)+k(0.5,0,param))/2.0
00081 *BasisPoly(p-1,0,2.0*xx/b+(1.0-b)/b);
00082 for (r=0; r<p; r++) {
00083 sum+=pow(-b/2.0,(double)r)
00084 *k(-0.5+b,r,param)
00085 *BasisPoly(p-1,r,-2.0*xx/b+(b-1.0)/b);
00086 }
00087 return sum;
00088 }
00089 else if ((xx>-a) && (xx<a)) {
00090 for (r=0; r<p; r++) {
00091 sum+=pow(a,(double)r)
00092 *( k(-a,r,param)*BasisPoly(p-1,r,xx/a)
00093 +k( a,r,param)*BasisPoly(p-1,r,-xx/a)*(r & 1 ? -1 : 1));
00094 }
00095 return sum;
00096 }
00097 else if (xx>0.5-b) {
00098 sum=(k(-0.5,0,param)+k(0.5,0,param))/2.0
00099 *BasisPoly(p-1,0,-2.0*xx/b+(1.0-b)/b);
00100 for (r=0; r<p; r++) {
00101 sum+=pow(b/2.0,(double)r)
00102 *k(0.5-b,r,param)
00103 *BasisPoly(p-1,r,2.0*xx/b-(1.0-b)/b);
00104 }
00105 return sum;
00106 }
00107 return k(xx,0,param);
00108 }
00109
00113 double _Complex regkern1(kernel k, double xx, int p, const double *param, double a, double b)
00114 {
00115 int r;
00116 double _Complex sum=0.0;
00117
00118 if (xx<-0.5)
00119 xx=-0.5;
00120 if (xx>0.5)
00121 xx=0.5;
00122 if ((xx>=-0.5+b && xx<=-a) || (xx>=a && xx<=0.5-b))
00123 {
00124 return k(xx,0,param);
00125 }
00126 else if ((xx>-a) && (xx<a))
00127 {
00128 for (r=0; r<p; r++) {
00129 sum+=pow(a,(double)r)
00130 *( k(-a,r,param)*BasisPoly(p-1,r,xx/a)
00131 +k( a,r,param)*BasisPoly(p-1,r,-xx/a)*(r & 1 ? -1 : 1));
00132 }
00133 return sum;
00134 }
00135 else if (xx<-0.5+b)
00136 {
00137 for (r=0; r<p; r++) {
00138 sum+=pow(b,(double)r)
00139 *( k(0.5-b,r,param)*BasisPoly(p-1,r,(xx+0.5)/b)
00140 +k(-0.5+b,r,param)*BasisPoly(p-1,r,-(xx+0.5)/b)*(r & 1 ? -1 : 1));
00141 }
00142 return sum;
00143 }
00144 else if (xx>0.5-b)
00145 {
00146 for (r=0; r<p; r++) {
00147 sum+=pow(b,(double)r)
00148 *( k(0.5-b,r,param)*BasisPoly(p-1,r,(xx-0.5)/b)
00149 +k(-0.5+b,r,param)*BasisPoly(p-1,r,-(xx-0.5)/b)*(r & 1 ? -1 : 1));
00150 }
00151 return sum;
00152 }
00153 return k(xx,0,param);
00154 }
00155
00157 double _Complex regkern2(kernel k, double xx, int p, const double *param, double a, double b)
00158 {
00159 int r;
00160 double _Complex sum=0.0;
00161
00162 xx=fabs(xx);
00163
00164 if (xx>0.5) {
00165 for (r=0; r<p; r++) {
00166 sum+=pow(b,(double)r)*k(0.5-b,r,param)
00167 *(BasisPoly(p-1,r,0)+BasisPoly(p-1,r,0));
00168 }
00169 return sum;
00170 }
00171 else if ((a<=xx) && (xx<=0.5-b)) {
00172 return k(xx,0,param);
00173 }
00174 else if (xx<a) {
00175 for (r=0; r<p; r++) {
00176 sum+=pow(-a,(double)r)*k(a,r,param)
00177 *(BasisPoly(p-1,r,xx/a)+BasisPoly(p-1,r,-xx/a));
00178 }
00179 return sum;
00180 }
00181 else if ((0.5-b<xx) && (xx<=0.5)) {
00182 for (r=0; r<p; r++) {
00183 sum+=pow(b,(double)r)*k(0.5-b,r,param)
00184 *(BasisPoly(p-1,r,(xx-0.5)/b)+BasisPoly(p-1,r,-(xx-0.5)/b));
00185 }
00186 return sum;
00187 }
00188 return 0.0;
00189 }
00190
00194 double _Complex regkern3(kernel k, double xx, int p, const double *param, double a, double b)
00195 {
00196 int r;
00197 double _Complex sum=0.0;
00198
00199 xx=fabs(xx);
00200
00201 if (xx>=0.5) {
00202
00203 xx=0.5;
00204 }
00205
00206 if ((a<=xx) && (xx<=0.5-b)) {
00207 return k(xx,0,param);
00208 }
00209 else if (xx<a) {
00210 for (r=0; r<p; r++) {
00211 sum+=pow(-a,(double)r)*k(a,r,param)
00212 *(BasisPoly(p-1,r,xx/a)+BasisPoly(p-1,r,-xx/a));
00213 }
00214
00215 return sum;
00216 }
00217 else if ((0.5-b<xx) && (xx<=0.5)) {
00218 sum=k(0.5,0,param)*BasisPoly(p-1,0,-2.0*xx/b+(1.0-b)/b);
00219
00220 for (r=0; r<p; r++) {
00221 sum+=pow(b/2.0,(double)r)
00222 *k(0.5-b,r,param)
00223 *BasisPoly(p-1,r,2.0*xx/b-(1.0-b)/b);
00224 }
00225 return sum;
00226 }
00227 return 0.0;
00228 }
00229
/**
 * Cubic interpolation in the table Add, which is treated as even in x.
 *
 * The position x*Ad/a is split into an integer node index (truncated,
 * absolute value) and a fractional offset. The four samples at nodes
 * index-1 .. index+2 are combined with cubic weights; at index 0 the
 * missing left sample is replaced by Add[1].
 *
 * \param x   evaluation point
 * \param Add table of samples; entries up to index+2 are read
 * \param Ad  number of table intervals covering [0,a]
 * \param a   right end of the tabulated range
 */
double _Complex kubintkern(const double x, const double _Complex *Add,
  const int Ad, const double a)
{
  const double pos = x*Ad/a;
  const int node = abs((int)pos);
  const double _Complex left  = Add[node==0 ? 1 : node-1];
  const double _Complex mid0  = Add[node];
  const double _Complex mid1  = Add[node+1];
  const double _Complex right = Add[node+2];

  /* Offsets of the evaluation point from the four nodes. */
  const double d0 = fabs(pos)-node;
  const double dm = d0+1.0;
  const double d1 = d0-1.0;
  const double d2 = d0-2.0;

  const double _Complex term_left  = -left*d0*d1*d2/6.0;
  const double _Complex term_mid0  = mid0*dm*d1*d2/2.0;
  const double _Complex term_mid1  = mid1*dm*d0*d2/2.0;
  const double _Complex term_right = right*dm*d0*d1/6.0;

  return term_left+term_mid0-term_mid1+term_right;
}
00250
/**
 * Cubic interpolation in the table Add for arguments in [-a,a].
 *
 * The table is read with an offset of two entries, so that the node index
 * derived from (x+a)*Ad/2/a may address one sample to the left of node 0.
 * The four samples at nodes index-1 .. index+2 are combined with cubic
 * weights.
 *
 * \param x   evaluation point
 * \param Add table of samples, including two leading entries
 * \param Ad  number of table intervals covering [-a,a]
 * \param a   half width of the tabulated range
 */
double _Complex kubintkern1(const double x, const double _Complex *Add,
  const int Ad, const double a)
{
  const double _Complex *tab = Add+2;
  const double pos = (x+a)*Ad/2/a;
  const int node = abs((int)pos);
  const double _Complex left  = tab[node-1];
  const double _Complex mid0  = tab[node];
  const double _Complex mid1  = tab[node+1];
  const double _Complex right = tab[node+2];

  /* Offsets of the evaluation point from the four nodes. */
  const double d0 = fabs(pos)-node;
  const double dm = d0+1.0;
  const double d1 = d0-1.0;
  const double d2 = d0-2.0;

  const double _Complex term_left  = -left*d0*d1*d2/6.0;
  const double _Complex term_mid0  = mid0*dm*d1*d2/2.0;
  const double _Complex term_mid1  = mid1*dm*d0*d2/2.0;
  const double _Complex term_right = right*dm*d0*d1/6.0;

  return term_left+term_mid0-term_mid1+term_right;
}
00273
/**
 * Recursive quicksort of N points with d coordinates each, ordered by
 * coordinate t. Whole points are swapped, and the coefficients in alpha
 * are permuted along with them.
 *
 * \param d     number of coordinates per point
 * \param t     coordinate used as sort key
 * \param x     points, stored as x[point*d+coordinate]
 * \param alpha one coefficient per point
 * \param N     number of points
 */
void quicksort(int d, int t, double *x, double _Complex *alpha, int N)
{
  int lo = 0;
  int hi = N-1;
  /* Key of the middle point serves as pivot value. */
  const double pivot = x[(N/2)*d+t];

  while (lo <= hi)
  {
    /* Skip entries that already sit on the correct side. */
    while (x[lo*d+t] < pivot)
      ++lo;
    while (x[hi*d+t] > pivot)
      --hi;

    if (lo > hi)
      break;

    {
      /* Exchange the two points and their coefficients. */
      double *pl = x+lo*d;
      double *ph = x+hi*d;
      double _Complex coeff;
      int c;

      for (c = 0; c < d; ++c)
      {
        const double coord = pl[c];
        pl[c] = ph[c];
        ph[c] = coord;
      }
      coeff = alpha[lo];
      alpha[lo] = alpha[hi];
      alpha[hi] = coeff;
    }

    ++lo;
    --hi;
  }

  /* Sort both partitions if they hold more than one point. */
  if (hi > 0)
    quicksort(d,t,x,alpha,hi+1);
  if (lo < N-1)
    quicksort(d,t,x+lo*d,alpha+lo,N-lo);
}
00313
/**
 * Recursively arranges the points in place: the range is sorted by
 * coordinate t, the middle point N/2 stays where it is, and the parts
 * before and after it are processed with the next coordinate (t+1)%d.
 *
 * \param d     number of coordinates per point
 * \param t     coordinate used at this level
 * \param x     points, stored as x[point*d+coordinate]
 * \param alpha one coefficient per point, permuted together with x
 * \param N     number of points
 */
void BuildTree(int d, int t, double *x, double _Complex *alpha, int N)
{
  int median, next_dim;

  /* Nothing to arrange for zero or one point. */
  if (N <= 1)
    return;

  median = N/2;
  next_dim = (t+1)%d;

  quicksort(d,t,x,alpha,N);

  BuildTree(d, next_dim, x, alpha, median);
  BuildTree(d, next_dim, x+(median+1)*d, alpha+(median+1), N-median-1);
}
00327
00329 double _Complex SearchTree(const int d, const int t, const double *x,
00330 const double _Complex *alpha, const double *xmin, const double *xmax,
00331 const int N, const kernel k, const double *param, const int Ad,
00332 const double _Complex *Add, const int p, const unsigned flags)
00333 {
00334 int m=N/2;
00335 double Min=xmin[t], Max=xmax[t], Median=x[m*d+t];
00336 double a=fabs(Max-Min)/2;
00337 int l;
00338 int E=0;
00339 double r;
00340
00341 if (N==0)
00342 return 0.0;
00343 if (Min>Median)
00344 return SearchTree(d,(t+1)%d,x+(m+1)*d,alpha+(m+1),xmin,xmax,N-m-1,k,param,Ad,Add,p,flags);
00345 else if (Max<Median)
00346 return SearchTree(d,(t+1)%d,x,alpha,xmin,xmax,m,k,param,Ad,Add,p,flags);
00347 else
00348 {
00349 double _Complex result = 0.0;
00350 E=0;
00351
00352 for (l=0; l<d; l++)
00353 {
00354 if ( x[m*d+l]>xmin[l] && x[m*d+l]<xmax[l] )
00355 E++;
00356 }
00357
00358 if (E==d)
00359 {
00360 if (d==1)
00361 {
00362 r = xmin[0]+a-x[m];
00363 }
00364 else
00365 {
00366 r=0.0;
00367 for (l=0; l<d; l++)
00368 r+=(xmin[l]+a-x[m*d+l])*(xmin[l]+a-x[m*d+l]);
00369 r=sqrt(r);
00370 }
00371 if (fabs(r)<a)
00372 {
00373 result += alpha[m]*k(r,0,param);
00374 if (d==1)
00375 {
00376 if (flags & EXACT_NEARFIELD)
00377 result -= alpha[m]*regkern1(k,r,p,param,a,1.0/16.0);
00378 else
00379 result -= alpha[m]*kubintkern1(r,Add,Ad,a);
00380 }
00381 else
00382 {
00383 if (flags & EXACT_NEARFIELD)
00384 result -= alpha[m]*regkern(k,r,p,param,a,1.0/16.0);
00385 else
00386 result -= alpha[m]*kubintkern(r,Add,Ad,a);
00387 }
00388 }
00389 }
00390 result += SearchTree(d,(t+1)%d,x+(m+1)*d,alpha+(m+1),xmin,xmax,N-m-1,k,param,Ad,Add,p,flags)
00391 + SearchTree(d,(t+1)%d,x,alpha,xmin,xmax,m,k,param,Ad,Add,p,flags);
00392 return result;
00393 }
00394 }
00395
00397 void fastsum_init_guru(fastsum_plan *ths, int d, int N_total, int M_total, kernel k, double *param, unsigned flags, int nn, int m, int p, double eps_I, double eps_B)
00398 {
00399 int t;
00400 int N[d], n[d];
00401 int n_total;
00402
00403 ths->d = d;
00404
00405 ths->N_total = N_total;
00406 ths->M_total = M_total;
00407
00408 ths->x = (double *)nfft_malloc(d*N_total*(sizeof(double)));
00409 ths->alpha = (double _Complex *)nfft_malloc(N_total*(sizeof(double _Complex)));
00410
00411 ths->y = (double *)nfft_malloc(d*M_total*(sizeof(double)));
00412 ths->f = (double _Complex *)nfft_malloc(M_total*(sizeof(double _Complex)));
00413
00414 ths->k = k;
00415 ths->kernel_param = param;
00416
00417 ths->flags = flags;
00418
00419 ths->p = p;
00420 ths->eps_I = eps_I;
00421 ths->eps_B = eps_B;
00424 if (!(ths->flags & EXACT_NEARFIELD))
00425 {
00426 if (ths->d==1)
00427 {
00428 ths->Ad = 4*(ths->p)*(ths->p);
00429 ths->Add = (double _Complex *)nfft_malloc((ths->Ad+5)*(sizeof(double _Complex)));
00430 }
00431 else
00432 {
00433 ths->Ad = 2*(ths->p)*(ths->p);
00434 ths->Add = (double _Complex *)nfft_malloc((ths->Ad+3)*(sizeof(double _Complex)));
00435 }
00436 }
00437
00439 ths->n = nn;
00440 for (t=0; t<d; t++)
00441 {
00442 N[t] = nn;
00443 n[t] = 2*nn;
00444 }
00445 nfft_init_guru(&(ths->mv1), d, N, N_total, n, m,
00446 PRE_PHI_HUT| PRE_PSI| MALLOC_X | MALLOC_F_HAT| MALLOC_F| FFTW_INIT | FFT_OUT_OF_PLACE,
00447 FFTW_MEASURE| FFTW_DESTROY_INPUT);
00448 nfft_init_guru(&(ths->mv2), d, N, M_total, n, m,
00449 PRE_PHI_HUT| PRE_PSI| MALLOC_X | MALLOC_F_HAT| MALLOC_F| FFTW_INIT | FFT_OUT_OF_PLACE,
00450 FFTW_MEASURE| FFTW_DESTROY_INPUT);
00451
00453 n_total = 1;
00454 for (t=0; t<d; t++)
00455 n_total *= nn;
00456
00457 ths->b = (fftw_complex *)nfft_malloc(n_total*sizeof(fftw_complex));
00458 ths->fft_plan = fftw_plan_dft(d,N,ths->b,ths->b,FFTW_FORWARD,FFTW_ESTIMATE);
00459
00460 }
00461
00463 void fastsum_finalize(fastsum_plan *ths)
00464 {
00465 nfft_free(ths->x);
00466 nfft_free(ths->alpha);
00467 nfft_free(ths->y);
00468 nfft_free(ths->f);
00469
00470 if (!(ths->flags & EXACT_NEARFIELD))
00471 nfft_free(ths->Add);
00472
00473 nfft_finalize(&(ths->mv1));
00474 nfft_finalize(&(ths->mv2));
00475
00476 fftw_destroy_plan(ths->fft_plan);
00477 nfft_free(ths->b);
00478 }
00479
00481 void fastsum_exact(fastsum_plan *ths)
00482 {
00483 int j,k;
00484 int t;
00485 double r;
00486
00487 for (j=0; j<ths->M_total; j++)
00488 {
00489 ths->f[j]=0.0;
00490 for (k=0; k<ths->N_total; k++)
00491 {
00492 if (ths->d==1)
00493 r = ths->y[j] - ths->x[k];
00494 else
00495 {
00496 r=0.0;
00497 for (t=0; t<ths->d; t++)
00498 r += (ths->y[j*ths->d+t]-ths->x[k*ths->d+t])*(ths->y[j*ths->d+t]-ths->x[k*ths->d+t]);
00499 r=sqrt(r);
00500 }
00501 ths->f[j] += ths->alpha[k] * ths->k(r,0,ths->kernel_param);
00502 }
00503 }
00504 }
00505
00507 void fastsum_precompute(fastsum_plan *ths)
00508 {
00509 int j,k,t;
00510 int n_total;
00511
00513 BuildTree(ths->d,0,ths->x,ths->alpha,ths->N_total);
00514
00516 if (!(ths->flags & EXACT_NEARFIELD))
00517 {
00518 if (ths->d==1)
00519 for (k=-ths->Ad/2-2; k <= ths->Ad/2+2; k++)
00520 ths->Add[k+ths->Ad/2+2] = regkern1(ths->k, ths->eps_I*(double)k/ths->Ad*2, ths->p, ths->kernel_param, ths->eps_I, ths->eps_B);
00521 else
00522 for (k=0; k <= ths->Ad+2; k++)
00523 ths->Add[k] = regkern3(ths->k, ths->eps_I*(double)k/ths->Ad, ths->p, ths->kernel_param, ths->eps_I, ths->eps_B);
00524 }
00525
00527 for (k=0; k<ths->mv1.M_total; k++)
00528 for (t=0; t<ths->mv1.d; t++)
00529 ths->mv1.x[ths->mv1.d*k+t] = - ths->x[ths->mv1.d*k+t];
00530
00532 if(ths->mv1.nfft_flags & PRE_LIN_PSI)
00533 nfft_precompute_lin_psi(&(ths->mv1));
00534
00535 if(ths->mv1.nfft_flags & PRE_PSI)
00536 nfft_precompute_psi(&(ths->mv1));
00537
00538 if(ths->mv1.nfft_flags & PRE_FULL_PSI)
00539 nfft_precompute_full_psi(&(ths->mv1));
00540
00542 for(k=0; k<ths->mv1.M_total;k++)
00543 ths->mv1.f[k] = ths->alpha[k];
00544
00546 for (j=0; j<ths->mv2.M_total; j++)
00547 for (t=0; t<ths->mv2.d; t++)
00548 ths->mv2.x[ths->mv2.d*j+t] = - ths->y[ths->mv2.d*j+t];
00549
00551 if(ths->mv2.nfft_flags & PRE_LIN_PSI)
00552 nfft_precompute_lin_psi(&(ths->mv2));
00553
00554 if(ths->mv2.nfft_flags & PRE_PSI)
00555 nfft_precompute_psi(&(ths->mv2));
00556
00557 if(ths->mv2.nfft_flags & PRE_FULL_PSI)
00558 nfft_precompute_full_psi(&(ths->mv2));
00559
00560
00562 n_total = 1;
00563 for (t=0; t<ths->d; t++)
00564 n_total *= ths->n;
00565
00566 for (j=0; j<n_total; j++)
00567 {
00568 if (ths->d==1)
00569 ths->b[j] = regkern1(ths->k, (double)j / (ths->n) - 0.5, ths->p, ths->kernel_param, ths->eps_I, ths->eps_B)/n_total;
00570 else
00571 {
00572 k=j;
00573 ths->b[j]=0.0;
00574 for (t=0; t<ths->d; t++)
00575 {
00576 ths->b[j] += ((double)(k % (ths->n)) / (ths->n) - 0.5) * ((double)(k % (ths->n)) / (ths->n) - 0.5);
00577 k = k / (ths->n);
00578 }
00579 ths->b[j] = regkern3(ths->k, sqrt(ths->b[j]), ths->p, ths->kernel_param, ths->eps_I, ths->eps_B)/n_total;
00580 }
00581 }
00582
00583 nfft_fftshift_complex(ths->b, ths->mv1.d, ths->mv1.N);
00584 fftw_execute(ths->fft_plan);
00585 nfft_fftshift_complex(ths->b, ths->mv1.d, ths->mv1.N);
00586
00587 }
00588
00590 void fastsum_trafo(fastsum_plan *ths)
00591 {
00592 int j,k,t;
00593 double *ymin, *ymax;
00595 ymin = (double *)nfft_malloc(ths->d*(sizeof(double)));
00596 ymax = (double *)nfft_malloc(ths->d*(sizeof(double)));
00597
00599 nfft_adjoint(&(ths->mv1));
00600
00602 for (k=0; k<ths->mv2.N_total; k++)
00603 ths->mv2.f_hat[k] = ths->b[k] * ths->mv1.f_hat[k];
00604
00606 nfft_trafo(&(ths->mv2));
00607
00609 for (j=0; j<ths->M_total; j++)
00610 {
00611 for (t=0; t<ths->d; t++)
00612 {
00613 ymin[t] = ths->y[ths->d*j+t] - ths->eps_I;
00614 ymax[t] = ths->y[ths->d*j+t] + ths->eps_I;
00615 }
00616 ths->f[j] = ths->mv2.f[j] + SearchTree(ths->d,0, ths->x, ths->alpha, ymin, ymax, ths->N_total, ths->k, ths->kernel_param, ths->Ad, ths->Add, ths->p, ths->flags);
00617
00618
00619 }
00620 }
00621
00622
00623