SHOGUN
v2.0.0
|
00001 /* This program is free software: you can redistribute it and/or modify 00002 * it under the terms of the GNU General Public License as published by 00003 * the Free Software Foundation, either version 3 of the License, or 00004 * (at your option) any later version. 00005 * 00006 * This program is distributed in the hope that it will be useful, 00007 * but WITHOUT ANY WARRANTY; without even the implied warranty of 00008 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 00009 * GNU General Public License for more details. 00010 * 00011 * You should have received a copy of the GNU General Public License 00012 * along with this program. If not, see <http://www.gnu.org/licenses/>. 00013 * 00014 * Copyright (C) 2009 - 2012 Jun Liu and Jieping Ye 00015 */ 00016 00017 #include <shogun/lib/slep/tree/general_altra.h> 00018 00019 void general_altra(double *x, double *v, int n, double *G, double *ind, int nodes, double mult) 00020 { 00021 00022 int i, j; 00023 double lambda,twoNorm, ratio; 00024 00025 /* 00026 * test whether the first node is special 00027 */ 00028 if ((int) ind[0]==-1){ 00029 00030 /* 00031 *Recheck whether ind[1] equals to zero 00032 */ 00033 if ((int) ind[1]!=-1){ 00034 printf("\n Error! \n Check ind"); 00035 exit(1); 00036 } 00037 00038 lambda=mult*ind[2]; 00039 00040 for(j=0;j<n;j++){ 00041 if (v[j]>lambda) 00042 x[j]=v[j]-lambda; 00043 else 00044 if (v[j]<-lambda) 00045 x[j]=v[j]+lambda; 00046 else 00047 x[j]=0; 00048 } 00049 00050 i=1; 00051 } 00052 else{ 00053 memcpy(x, v, sizeof(double) * n); 00054 i=0; 00055 } 00056 00057 /* 00058 * sequentially process each node 00059 * 00060 */ 00061 for(;i < nodes; i++){ 00062 /* 00063 * compute the L2 norm of this group 00064 */ 00065 twoNorm=0; 00066 for(j=(int) ind[3*i]-1;j< (int) ind[3*i+1];j++) 00067 twoNorm += x[(int) G[j]-1 ] * x[(int) G[j]-1 ]; 00068 twoNorm=sqrt(twoNorm); 00069 00070 lambda=mult*ind[3*i+2]; 00071 if (twoNorm>lambda){ 00072 ratio=(twoNorm-lambda)/twoNorm; 00073 00074 /* 00075 * shrinkage this group by ratio 00076 */ 00077 for(j=(int) ind[3*i]-1;j<(int) ind[3*i+1];j++) 00078 x[(int) G[j]-1 ]*=ratio; 00079 } 00080 else{ 00081 /* 00082 * threshold this group to zero 00083 */ 00084 for(j=(int) ind[3*i]-1;j<(int) ind[3*i+1];j++) 00085 x[(int) G[j]-1 ]=0; 00086 } 00087 } 00088 } 00089 00090 void general_altra_mt(double *X, double *V, int n, int k, double *G, double *ind, int nodes, double mult) 00091 { 00092 int i, j; 00093 00094 double *x=(double *)malloc(sizeof(double)*k); 00095 double *v=(double *)malloc(sizeof(double)*k); 00096 00097 for (i=0;i<n;i++){ 00098 /* 00099 * copy a row of V to v 00100 * 00101 */ 00102 for(j=0;j<k;j++) 00103 v[j]=V[j*n + i]; 00104 00105 general_altra(x, v, k, G, ind, nodes, mult); 00106 00107 /* 00108 * copy the solution to X 00109 */ 00110 for(j=0;j<k;j++) 00111 X[j*n+i]=x[j]; 00112 } 00113 00114 free(x); 00115 free(v); 00116 } 00117 00118 void general_computeLambda2Max(double *lambda2_max, double *x, int n, double *G, double *ind, int nodes) 00119 { 00120 int i, j; 00121 double twoNorm; 00122 00123 *lambda2_max=0; 00124 00125 00126 00127 for(i=0;i < nodes; i++){ 00128 /* 00129 * compute the L2 norm of this group 00130 */ 00131 twoNorm=0; 00132 for(j=(int) ind[3*i]-1;j< (int) ind[3*i+1];j++) 00133 twoNorm += x[(int) G[j]-1 ] * x[(int) G[j]-1 ]; 00134 twoNorm=sqrt(twoNorm); 00135 00136 twoNorm=twoNorm/ind[3*i+2]; 00137 00138 if (twoNorm >*lambda2_max ) 00139 *lambda2_max=twoNorm; 00140 } 00141 } 00142 00143 double general_treeNorm(double *x, int ldx, int n, double *G, double *ind, int nodes) 00144 { 00145 00146 int i, j; 00147 double twoNorm, lambda; 00148 00149 double tree_norm=0; 00150 00151 /* 00152 * test whether the first node is special 00153 */ 00154 if ((int) ind[0]==-1){ 00155 00156 /* 00157 *Recheck whether ind[1] equals to zero 00158 */ 00159 if ((int) ind[1]!=-1){ 00160 printf("\n Error! \n Check ind"); 00161 exit(1); 00162 } 00163 00164 lambda=ind[2]; 00165 00166 for(j=0;j<n;j+=ldx){ 00167 tree_norm+=fabs(x[j]); 00168 } 00169 00170 tree_norm=tree_norm * lambda; 00171 00172 i=1; 00173 } 00174 else{ 00175 i=0; 00176 } 00177 00178 /* 00179 * sequentially process each node 00180 * 00181 */ 00182 for(;i < nodes; i++){ 00183 /* 00184 * compute the L2 norm of this group 00185 00186 */ 00187 twoNorm=0; 00188 for(j=(int) ind[3*i]-1;j< (int) ind[3*i+1];j++) 00189 twoNorm += x[(int) G[j]-1 ] * x[(int) G[j]-1 ]; 00190 twoNorm=sqrt(twoNorm); 00191 00192 lambda=ind[3*i+2]; 00193 00194 tree_norm=tree_norm + lambda*twoNorm; 00195 } 00196 return tree_norm; 00197 } 00198 00199 double general_findLambdaMax(double *v, int n, double *G, double *ind, int nodes) 00200 { 00201 00202 int i; 00203 double lambda=0,squaredWeight=0, lambda1,lambda2; 00204 double *x=(double *)malloc(sizeof(double)*n); 00205 double *ind2=(double *)malloc(sizeof(double)*nodes*3); 00206 int num=0; 00207 00208 for(i=0;i<n;i++){ 00209 lambda+=v[i]*v[i]; 00210 } 00211 00212 if ( (int)ind[0]==-1 ) 00213 squaredWeight=n*ind[2]*ind[2]; 00214 else 00215 squaredWeight=ind[2]*ind[2]; 00216 00217 for (i=1;i<nodes;i++){ 00218 squaredWeight+=ind[3*i+2]*ind[3*i+2]; 00219 } 00220 00221 /* set lambda to an initial guess 00222 */ 00223 lambda=sqrt(lambda/squaredWeight); 00224 00225 /* 00226 printf("\n\n lambda=%2.5f",lambda); 00227 */ 00228 00229 /* 00230 *copy ind to ind2, 00231 *and scale the weight 3*i+2 00232 */ 00233 for(i=0;i<nodes;i++){ 00234 ind2[3*i]=ind[3*i]; 00235 ind2[3*i+1]=ind[3*i+1]; 00236 ind2[3*i+2]=ind[3*i+2]*lambda; 00237 } 00238 00239 /* test whether the solution is zero or not 00240 */ 00241 general_altra(x, v, n, G, ind2, nodes); 00242 for(i=0;i<n;i++){ 00243 if (x[i]!=0) 00244 break; 00245 } 00246 00247 if (i>=n) { 00248 /*x is a zero vector*/ 00249 lambda2=lambda; 00250 lambda1=lambda; 00251 00252 num=0; 00253 00254 while(1){ 00255 num++; 00256 00257 lambda2=lambda; 00258 lambda1=lambda1/2; 00259 /* update ind2 00260 */ 00261 for(i=0;i<nodes;i++){ 00262 ind2[3*i+2]=ind[3*i+2]*lambda1; 00263 } 00264 00265 /* compute and test whether x is zero 00266 */ 00267 general_altra(x, v, n, G, ind2, nodes); 00268 for(i=0;i<n;i++){ 00269 if (x[i]!=0) 00270 break; 00271 } 00272 00273 if (i<n){ 00274 break; 00275 /*x is not zero 00276 *we have found lambda1 00277 */ 00278 } 00279 } 00280 00281 } 00282 else{ 00283 /*x is a non-zero vector*/ 00284 lambda2=lambda; 00285 lambda1=lambda; 00286 00287 num=0; 00288 while(1){ 00289 num++; 00290 00291 lambda1=lambda2; 00292 lambda2=lambda2*2; 00293 /* update ind2 00294 */ 00295 for(i=0;i<nodes;i++){ 00296 ind2[3*i+2]=ind[3*i+2]*lambda2; 00297 } 00298 00299 /* compute and test whether x is zero 00300 */ 00301 general_altra(x, v, n, G, ind2, nodes); 00302 for(i=0;i<n;i++){ 00303 if (x[i]!=0) 00304 break; 00305 } 00306 00307 if (i>=n){ 00308 break; 00309 /*x is a zero vector 00310 *we have found lambda2 00311 */ 00312 } 00313 } 00314 } 00315 00316 /* 00317 printf("\n num=%d, lambda1=%2.5f, lambda2=%2.5f",num, lambda1,lambda2); 00318 */ 00319 00320 while ( fabs(lambda2-lambda1) > lambda2 * 1e-10 ){ 00321 00322 num++; 00323 00324 lambda=(lambda1+lambda2)/2; 00325 00326 /* update ind2 00327 */ 00328 for(i=0;i<nodes;i++){ 00329 ind2[3*i+2]=ind[3*i+2]*lambda; 00330 } 00331 00332 /* compute and test whether x is zero 00333 */ 00334 general_altra(x, v, n, G, ind2, nodes); 00335 for(i=0;i<n;i++){ 00336 if (x[i]!=0) 00337 break; 00338 } 00339 00340 if (i>=n){ 00341 lambda2=lambda; 00342 } 00343 else{ 00344 lambda1=lambda; 00345 } 00346 00347 /* 00348 printf("\n lambda1=%2.5f, lambda2=%2.5f",lambda1,lambda2); 00349 */ 00350 } 00351 00352 /* 00353 printf("\n num=%d",num); 00354 00355 printf(" lambda1=%2.5f, lambda2=%2.5f",lambda1,lambda2); 00356 */ 00357 00358 free(x); 00359 free(ind2); 00360 00361 return lambda2; 00362 } 00363 00364 double general_findLambdaMax_mt(double *V, int n, int k, double *G, double *ind, int nodes) 00365 { 00366 int i, j; 00367 00368 double *v=(double *)malloc(sizeof(double)*k); 00369 double lambda; 00370 00371 double lambdaMax=0; 00372 00373 for (i=0;i<n;i++){ 00374 /* 00375 * copy a row of V to v 00376 * 00377 */ 00378 for(j=0;j<k;j++) 00379 v[j]=V[j*n + i]; 00380 00381 lambda = general_findLambdaMax(v, k, G, ind, nodes); 00382 00383 /* 00384 printf("\n lambda=%5.2f",lambda); 00385 */ 00386 00387 00388 if (lambda>lambdaMax) 00389 lambdaMax=lambda; 00390 } 00391 00392 /* 00393 printf("\n *lambdaMax=%5.2f",*lambdaMax); 00394 */ 00395 00396 free(v); 00397 return lambdaMax; 00398 }