general_altra.cpp

Go to the documentation of this file.
00001 /*   This program is free software: you can redistribute it and/or modify
00002  *   it under the terms of the GNU General Public License as published by
00003  *   the Free Software Foundation, either version 3 of the License, or
00004  *   (at your option) any later version.
00005  *
00006  *   This program is distributed in the hope that it will be useful,
00007  *   but WITHOUT ANY WARRANTY; without even the implied warranty of
00008  *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
00009  *   GNU General Public License for more details.
00010  *
00011  *   You should have received a copy of the GNU General Public License
00012  *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
00013  *
00014  *   Copyright (C) 2009 - 2012 Jun Liu and Jieping Ye 
00015  */
00016 
00017 #include <shogun/lib/slep/tree/general_altra.h>
00018 
00019 void general_altra(double *x, double *v, int n, double *G, double *ind, int nodes, double mult)
00020 {
00021 
00022     int i, j;
00023     double lambda,twoNorm, ratio;
00024 
00025     /*
00026      * test whether the first node is special
00027      */
00028     if ((int) ind[0]==-1){
00029 
00030         /*
00031          *Recheck whether ind[1] equals to zero
00032          */
00033         if ((int) ind[1]!=-1){
00034             printf("\n Error! \n Check ind");
00035             exit(1);
00036         }        
00037 
00038         lambda=mult*ind[2];
00039 
00040         for(j=0;j<n;j++){
00041             if (v[j]>lambda)
00042                 x[j]=v[j]-lambda;
00043             else
00044                 if (v[j]<-lambda)
00045                     x[j]=v[j]+lambda;
00046                 else
00047                     x[j]=0;
00048         }
00049 
00050         i=1;
00051     }
00052     else{
00053         memcpy(x, v, sizeof(double) * n);
00054         i=0;
00055     }
00056 
00057     /*
00058      * sequentially process each node
00059      *
00060      */
00061     for(;i < nodes; i++){
00062         /*
00063          * compute the L2 norm of this group         
00064          */
00065         twoNorm=0;
00066         for(j=(int) ind[3*i]-1;j< (int) ind[3*i+1];j++)
00067             twoNorm += x[(int) G[j]-1 ] * x[(int) G[j]-1 ];        
00068         twoNorm=sqrt(twoNorm);
00069 
00070         lambda=mult*ind[3*i+2];
00071         if (twoNorm>lambda){
00072             ratio=(twoNorm-lambda)/twoNorm;
00073 
00074             /*
00075              * shrinkage this group by ratio
00076              */
00077             for(j=(int) ind[3*i]-1;j<(int) ind[3*i+1];j++)
00078                 x[(int) G[j]-1 ]*=ratio;            
00079         }
00080         else{
00081             /*
00082              * threshold this group to zero
00083              */
00084             for(j=(int) ind[3*i]-1;j<(int) ind[3*i+1];j++)
00085                 x[(int) G[j]-1 ]=0;
00086         }
00087     }
00088 }
00089 
/*
 * general_altra_mt - run general_altra independently on every row of a
 * column-major n-by-k matrix.
 *
 * Row r of V consists of V[c*n + r] for c = 0..k-1. Each row is gathered
 * into a scratch vector, solved with general_altra(..., mult), and the
 * result is scattered back into the same positions of X.
 */
void general_altra_mt(double *X, double *V, int n, int k, double *G, double *ind, int nodes, double mult)
{
    double *row_out = (double *)malloc(sizeof(double) * k);
    double *row_in = (double *)malloc(sizeof(double) * k);

    for (int r = 0; r < n; r++) {
        /* Gather row r (stride n in column-major storage). */
        for (int c = 0, pos = r; c < k; c++, pos += n)
            row_in[c] = V[pos];

        general_altra(row_out, row_in, k, G, ind, nodes, mult);

        /* Scatter the solved row back into X. */
        for (int c = 0, pos = r; c < k; c++, pos += n)
            X[pos] = row_out[c];
    }

    free(row_out);
    free(row_in);
}
00117 
00118 void general_computeLambda2Max(double *lambda2_max, double *x, int n, double *G, double *ind, int nodes)
00119 {
00120     int i, j;
00121     double twoNorm;
00122 
00123     *lambda2_max=0;
00124 
00125 
00126 
00127     for(i=0;i < nodes; i++){
00128         /*
00129          * compute the L2 norm of this group         
00130          */
00131         twoNorm=0;
00132         for(j=(int) ind[3*i]-1;j< (int) ind[3*i+1];j++)
00133             twoNorm += x[(int) G[j]-1 ] * x[(int) G[j]-1 ];        
00134         twoNorm=sqrt(twoNorm);
00135 
00136         twoNorm=twoNorm/ind[3*i+2];
00137 
00138         if (twoNorm >*lambda2_max )
00139             *lambda2_max=twoNorm;        
00140     }
00141 }
00142 
00143 double general_treeNorm(double *x, int ldx, int n, double *G, double *ind, int nodes)
00144 {
00145 
00146     int i, j;
00147     double twoNorm, lambda;
00148 
00149     double tree_norm=0;
00150 
00151     /*
00152      * test whether the first node is special
00153      */
00154     if ((int) ind[0]==-1){
00155 
00156         /*
00157          *Recheck whether ind[1] equals to zero
00158          */
00159         if ((int) ind[1]!=-1){
00160             printf("\n Error! \n Check ind");
00161             exit(1);
00162         }        
00163 
00164         lambda=ind[2];
00165 
00166         for(j=0;j<n;j+=ldx){
00167             tree_norm+=fabs(x[j]);
00168         }
00169 
00170         tree_norm=tree_norm * lambda;
00171 
00172         i=1;
00173     }
00174     else{
00175         i=0;
00176     }
00177 
00178     /*
00179      * sequentially process each node
00180      *
00181      */
00182     for(;i < nodes; i++){
00183         /*
00184          * compute the L2 norm of this group         
00185 
00186 */
00187         twoNorm=0;
00188         for(j=(int) ind[3*i]-1;j< (int) ind[3*i+1];j++)
00189             twoNorm += x[(int) G[j]-1 ] * x[(int) G[j]-1 ];        
00190         twoNorm=sqrt(twoNorm);
00191 
00192         lambda=ind[3*i+2];
00193 
00194         tree_norm=tree_norm + lambda*twoNorm;
00195     }
00196     return tree_norm;
00197 }
00198 
/*
 * general_findLambda_isZero - write ind[3i+2]*lambda into the weight slots
 * ind2[3i+2], solve general_altra(x, v, n, G, ind2, nodes, 1.0), and return
 * 1 if every entry of the solution x is zero, 0 otherwise.
 *
 * ind2 must already contain the group bounds ind2[3i] and ind2[3i+1]; only
 * the weights are rewritten here. mult is 1.0 because the weights are
 * already scaled by lambda.
 * NOTE(review): the original code relied on the header's default for mult;
 * confirm that default is 1.0.
 */
static int general_findLambda_isZero(double *x, double *v, int n, double *G,
                                     double *ind, double *ind2, int nodes, double lambda)
{
    for (int i = 0; i < nodes; i++)
        ind2[3*i+2] = ind[3*i+2] * lambda;

    general_altra(x, v, n, G, ind2, nodes, 1.0);

    for (int i = 0; i < n; i++)
        if (x[i] != 0)
            return 0;
    return 1;
}

/*
 * general_findLambdaMax - find (to relative tolerance 1e-10) the smallest
 * lambda for which general_altra, run with every node weight scaled by
 * lambda, returns an all-zero solution for input v.
 *
 * Search outline:
 *  1. Initial guess: sqrt(||v||^2 / sum of squared weights). The special
 *     element-wise node, if present (ind[0] == -1), counts n times.
 *  2. Bracket [lambda1, lambda2], where lambda1 gives a non-zero solution
 *     and lambda2 gives a zero one. The guess is halved or doubled until the
 *     outcome changes.
 *  3. Bisect until the bracket is narrow, then return lambda2.
 * This assumes the zero/non-zero outcome is monotone in lambda, which is
 * presumably guaranteed for non-negative weights -- TODO confirm.
 *
 * Returns 0 immediately when v is identically zero (or n == 0). In that
 * case the solution is zero for every lambda >= 0, and without the early
 * exit the halving search would never terminate (0/2 == 0).
 *
 * v     : input vector of length n.
 * G, ind, nodes : group description, as in general_altra.
 */
double general_findLambdaMax(double *v, int n, double *G, double *ind, int nodes)
{
    double sumSquares = 0;
    for (int i = 0; i < n; i++)
        sumSquares += v[i] * v[i];

    if (sumSquares == 0)
        return 0;

    double squaredWeight;
    if ((int) ind[0] == -1)
        squaredWeight = n * ind[2] * ind[2];
    else
        squaredWeight = ind[2] * ind[2];
    for (int i = 1; i < nodes; i++)
        squaredWeight += ind[3*i+2] * ind[3*i+2];

    /* Initial guess for lambda. */
    double lambda = sqrt(sumSquares / squaredWeight);

    double *x = (double *)malloc(sizeof(double) * n);
    double *ind2 = (double *)malloc(sizeof(double) * nodes * 3);

    /* Group bounds never change; only the weights are rescaled per trial. */
    for (int i = 0; i < nodes; i++) {
        ind2[3*i] = ind[3*i];
        ind2[3*i+1] = ind[3*i+1];
    }

    double lambda1 = lambda;
    double lambda2 = lambda;

    if (general_findLambda_isZero(x, v, n, G, ind, ind2, nodes, lambda)) {
        /*
         * Guess too large: halve until the solution becomes non-zero. The
         * last zero-producing value becomes the upper bracket.
         */
        do {
            lambda2 = lambda1;
            lambda1 = lambda1 / 2;
        } while (general_findLambda_isZero(x, v, n, G, ind, ind2, nodes, lambda1));
    } else {
        /*
         * Guess too small: double until the solution becomes zero. The last
         * non-zero-producing value becomes the lower bracket.
         */
        do {
            lambda1 = lambda2;
            lambda2 = lambda2 * 2;
        } while (!general_findLambda_isZero(x, v, n, G, ind, ind2, nodes, lambda2));
    }

    /* Bisection: lambda1 yields non-zero, lambda2 yields zero. */
    while (fabs(lambda2 - lambda1) > lambda2 * 1e-10) {
        lambda = (lambda1 + lambda2) / 2;
        if (general_findLambda_isZero(x, v, n, G, ind, ind2, nodes, lambda))
            lambda2 = lambda;
        else
            lambda1 = lambda;
    }

    free(x);
    free(ind2);

    return lambda2;
}
00363 
/*
 * general_findLambdaMax_mt - largest value of general_findLambdaMax over
 * the n rows of a column-major n-by-k matrix V.
 *
 * Row r consists of V[c*n + r] for c = 0..k-1. The result starts at 0 and
 * is only raised by strictly larger row values, so it is 0 when n <= 0 or
 * when no row yields a positive value.
 */
double general_findLambdaMax_mt(double *V, int n, int k, double *G, double *ind, int nodes)
{
    double *row = (double *)malloc(sizeof(double) * k);
    double best = 0;

    for (int r = 0; r < n; r++) {
        /* Gather row r (stride n in column-major storage). */
        for (int c = 0, pos = r; c < k; c++, pos += n)
            row[c] = V[pos];

        const double candidate = general_findLambdaMax(row, k, G, ind, nodes);
        if (candidate > best)
            best = candidate;
    }

    free(row);
    return best;
}
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation