altra.cpp

Go to the documentation of this file.
00001 /*   This program is free software: you can redistribute it and/or modify
00002  *   it under the terms of the GNU General Public License as published by
00003  *   the Free Software Foundation, either version 3 of the License, or
00004  *   (at your option) any later version.
00005  *
00006  *   This program is distributed in the hope that it will be useful,
00007  *   but WITHOUT ANY WARRANTY; without even the implied warranty of
00008  *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
00009  *   GNU General Public License for more details.
00010  *
00011  *   You should have received a copy of the GNU General Public License
00012  *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
00013  *
00014  *   Copyright (C) 2009 - 2012 Jun Liu and Jieping Ye 
00015  */
00016 
00017 #include <shogun/lib/slep/tree/altra.h>
00018 
/*
 * altra: apply the per-node shrinkage described by ind to v, writing the
 * result to x.
 *
 *   x, v   vectors of length n (v is the input, x the output)
 *   ind    nodes consecutive triples (first, last, weight); first/last are
 *          1-based inclusive element indices of the node's group
 *   mult   factor applied to every weight before it is used
 *
 * If the first triple starts with (-1, -1) it denotes an element-wise
 * soft-threshold of v with threshold mult*ind[2]; otherwise x starts out as
 * a plain copy of v.  All remaining nodes are then processed in order: a
 * group whose Euclidean norm exceeds its threshold is scaled by
 * (norm - threshold) / norm, any other group is set to zero.
 *
 * A first triple of the form (-1, not -1) prints an error and terminates
 * the process with exit(1).
 */
void altra(double *x, double *v, int n, double *ind, int nodes, double mult)
{
    int node = 0;

    if ((int) ind[0] != -1) {
        /* no special leading node: start from an untouched copy of v */
        memcpy(x, v, sizeof(double) * n);
    } else {
        /* a special leading node must carry -1 in its second slot as well */
        if ((int) ind[1] != -1) {
            printf("\n Error! \n Check ind");
            exit(1);
        }

        const double threshold = mult * ind[2];

        /* element-wise soft-thresholding */
        for (int k = 0; k < n; ++k) {
            const double value = v[k];

            if (value > threshold)
                x[k] = value - threshold;
            else if (value < -threshold)
                x[k] = value + threshold;
            else
                x[k] = 0;
        }

        node = 1;
    }

    /* walk through the remaining nodes in the order they are stored */
    for (; node < nodes; ++node) {
        const double *triple = ind + 3 * node;
        const int first = (int) triple[0] - 1;  /* 0-based, inclusive */
        const int last = (int) triple[1];       /* 0-based, exclusive */
        double sumOfSquares = 0;

        for (int k = first; k < last; ++k)
            sumOfSquares += x[k] * x[k];

        const double norm = sqrt(sumOfSquares);
        const double threshold = mult * triple[2];

        if (norm > threshold) {
            /* shrink the whole group towards zero */
            const double scale = (norm - threshold) / norm;

            for (int k = first; k < last; ++k)
                x[k] *= scale;
        } else {
            /* the group does not survive the threshold */
            for (int k = first; k < last; ++k)
                x[k] = 0;
        }
    }
}
00088 
00089 void altra_mt(double *X, double *V, int n, int k, double *ind, int nodes, double mult)
00090 {
00091     int i, j;
00092 
00093     double *x=(double *)malloc(sizeof(double)*k);
00094     double *v=(double *)malloc(sizeof(double)*k);
00095 
00096     for (i=0;i<n;i++){
00097         /*
00098          * copy a row of V to v
00099          *         
00100          */
00101         for(j=0;j<k;j++)
00102             v[j]=V[j*n + i];
00103 
00104         altra(x, v, k, ind, nodes, mult);
00105 
00106         /*
00107          * copy the solution to X         
00108          */        
00109         for(j=0;j<k;j++)
00110             X[j*n+i]=x[j];
00111     }
00112 
00113     free(x);
00114     free(v);
00115 }
00116 
/*
 * computeLambda2Max: store in *lambda2_max the largest value of
 * (Euclidean norm of group) / (group weight) over all nodes, or 0 when no
 * node yields a positive ratio.
 *
 *   x      vector the groups index into
 *   n      not used by this routine
 *   ind    nodes consecutive triples (first, last, weight) with 1-based
 *          inclusive element indices
 *
 * NOTE(review): unlike altra and treeNorm, a leading (-1, -1) triple gets
 * no special treatment here - presumably callers never pass one; confirm.
 */
void computeLambda2Max(double *lambda2_max, double *x, int n, double *ind, int nodes)
{
    *lambda2_max = 0;

    for (int node = 0; node < nodes; ++node) {
        const double *triple = &ind[3 * node];
        const int first = (int) triple[0] - 1;  /* 0-based, inclusive */
        const int last = (int) triple[1];       /* 0-based, exclusive */
        double sumOfSquares = 0;

        for (int k = first; k < last; ++k)
            sumOfSquares += x[k] * x[k];

        /* norm of the group relative to its weight */
        const double scaledNorm = sqrt(sumOfSquares) / triple[2];

        if (scaledNorm > *lambda2_max)
            *lambda2_max = scaledNorm;
    }
}
00139 
/*
 * treeNorm: weighted sum of group norms of a strided vector.
 *
 *   x      base pointer; logical element k is read from x[k*ldx]
 *   ldx    stride between consecutive logical elements
 *   n      number of logical elements
 *   ind    nodes consecutive triples (first, last, weight); first/last are
 *          1-based inclusive element indices of the node's group
 *
 * If the first triple starts with (-1, -1) it contributes
 * weight * sum_k |x_k| over all n elements.  Every other node contributes
 * weight * (Euclidean norm of its group).  The sum of all contributions is
 * returned.
 *
 * A first triple of the form (-1, not -1) prints an error and terminates
 * the process with exit(1).
 */
double treeNorm(double *x, int ldx, int n, double *ind, int nodes){

    int i, j;
    double tree_norm = 0;

    /*
     * test whether the first node is special
     */
    if ((int) ind[0]==-1){

        /*
         * a special first node must carry -1 in its second slot as well
         */
        if ((int) ind[1]!=-1){
            printf("\n Error! \n Check ind");
            exit(1);
        }

        for(j=0;j<n;j++)
            tree_norm+=fabs(x[j*ldx]);

        tree_norm = tree_norm * ind[2];

        i=1;
    }
    else{
        i=0;
    }

    /*
     * sequentially process each remaining node
     */
    for(;i < nodes; i++){
        /*
         * A group covers the logical elements first..last-1 (0-based), the
         * same range altra and computeLambda2Max use.  Both the start and
         * the step have to be scaled by ldx so that the same elements are
         * addressed as in the L1 branch above.
         */
        int first = (int) ind[3*i] - 1;
        int last = (int) ind[3*i+1];
        double sumOfSquares = 0;

        for(j=first;j<last;j++){
            double value = x[j*ldx];
            sumOfSquares += value * value;
        }

        tree_norm = tree_norm + ind[3*i+2]*sqrt(sumOfSquares);
    }

    return tree_norm;
}
00196 
00197 double findLambdaMax(double *v, int n, double *ind, int nodes){
00198 
00199     int i;
00200     double lambda=0,squaredWeight=0, lambda1,lambda2;
00201     double *x=(double *)malloc(sizeof(double)*n);
00202     double *ind2=(double *)malloc(sizeof(double)*nodes*3);
00203     int num=0;
00204 
00205     for(i=0;i<n;i++){
00206         lambda+=v[i]*v[i];
00207     }
00208 
00209     if ( (int)ind[0]==-1 )
00210         squaredWeight=n*ind[2]*ind[2];
00211     else
00212         squaredWeight=ind[2]*ind[2];
00213 
00214     for (i=1;i<nodes;i++){
00215         squaredWeight+=ind[3*i+2]*ind[3*i+2];
00216     }
00217 
00218     /* set lambda to an initial guess
00219     */
00220     lambda=sqrt(lambda/squaredWeight);
00221 
00222     /*
00223        printf("\n\n   lambda=%2.5f",lambda);
00224        */
00225 
00226     /*
00227      *copy ind to ind2,
00228      *and scale the weight 3*i+2
00229      */
00230     for(i=0;i<nodes;i++){
00231         ind2[3*i]=ind[3*i];
00232         ind2[3*i+1]=ind[3*i+1];
00233         ind2[3*i+2]=ind[3*i+2]*lambda;
00234     }
00235 
00236     /* test whether the solution is zero or not
00237     */
00238     altra(x, v, n, ind2, nodes);    
00239     for(i=0;i<n;i++){
00240         if (x[i]!=0)
00241             break;
00242     }
00243 
00244     if (i>=n) {
00245         /*x is a zero vector*/
00246         lambda2=lambda;
00247         lambda1=lambda;
00248 
00249         num=0;
00250 
00251         while(1){
00252             num++;
00253 
00254             lambda2=lambda;
00255             lambda1=lambda1/2;
00256             /* update ind2
00257             */
00258             for(i=0;i<nodes;i++){
00259                 ind2[3*i+2]=ind[3*i+2]*lambda1;
00260             }
00261 
00262             /* compute and test whether x is zero
00263             */
00264             altra(x, v, n, ind2, nodes);
00265             for(i=0;i<n;i++){
00266                 if (x[i]!=0)
00267                     break;
00268             }
00269 
00270             if (i<n){
00271                 break;
00272                 /*x is not zero
00273                  *we have found lambda1
00274                  */
00275             }
00276         }
00277 
00278     }
00279     else{
00280         /*x is a non-zero vector*/
00281         lambda2=lambda;
00282         lambda1=lambda;
00283 
00284         num=0;
00285         while(1){
00286             num++;            
00287 
00288             lambda1=lambda2;
00289             lambda2=lambda2*2;
00290             /* update ind2
00291             */
00292             for(i=0;i<nodes;i++){
00293                 ind2[3*i+2]=ind[3*i+2]*lambda2;
00294             }
00295 
00296             /* compute and test whether x is zero
00297             */
00298             altra(x, v, n, ind2, nodes);
00299             for(i=0;i<n;i++){
00300                 if (x[i]!=0)
00301                     break;
00302             }
00303 
00304             if (i>=n){
00305                 break;
00306                 /*x is a zero vector
00307                  *we have found lambda2
00308                  */
00309             }
00310         }
00311     }    
00312 
00313     /*
00314        printf("\n num=%d, lambda1=%2.5f, lambda2=%2.5f",num, lambda1,lambda2);
00315        */
00316 
00317     while ( fabs(lambda2-lambda1) > lambda2 * 1e-10 ){
00318 
00319         num++;
00320 
00321         lambda=(lambda1+lambda2)/2;
00322 
00323         /* update ind2
00324         */
00325         for(i=0;i<nodes;i++){
00326             ind2[3*i+2]=ind[3*i+2]*lambda;
00327         }
00328 
00329         /* compute and test whether x is zero
00330         */
00331         altra(x, v, n, ind2, nodes);
00332         for(i=0;i<n;i++){
00333             if (x[i]!=0)
00334                 break;
00335         }
00336 
00337         if (i>=n){
00338             lambda2=lambda;
00339         }
00340         else{
00341             lambda1=lambda;
00342         }
00343 
00344         /*
00345            printf("\n lambda1=%2.5f, lambda2=%2.5f",lambda1,lambda2);
00346            */
00347     }
00348 
00349     /*
00350        printf("\n num=%d",num);
00351 
00352        printf("   lambda1=%2.5f, lambda2=%2.5f",lambda1,lambda2);
00353 
00354 */
00355 
00356     free(x);
00357     free(ind2);
00358 
00359     return lambda2;
00360 }
00361 
00362 double findLambdaMax_mt(double *V, int n, int k, double *ind, int nodes)
00363 {
00364     int i, j;
00365 
00366     double *v=(double *)malloc(sizeof(double)*k);
00367     double lambda;
00368 
00369     double lambdaMax=0;
00370 
00371     for (i=0;i<n;i++){
00372         /*
00373          * copy a row of V to v
00374          *         
00375          */
00376         for(j=0;j<k;j++)
00377             v[j]=V[j*n + i];
00378 
00379         lambda = findLambdaMax(v, k, ind, nodes);
00380 
00381         /*
00382            printf("\n   lambda=%5.2f",lambda);        
00383            */
00384 
00385         if (lambda>lambdaMax)
00386             lambdaMax=lambda;
00387     }
00388 
00389     /*
00390        printf("\n *lambdaMax=%5.2f",*lambdaMax);
00391        */
00392 
00393     free(v);
00394     return lambdaMax;
00395 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation