KRR.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2006 Mikio L. Braun
00008  * Written (W) 1999-2009 Soeren Sonnenburg
00009  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
00010  */
00011 
00012 #include "lib/config.h"
00013 
00014 #ifdef HAVE_LAPACK
00015 #include "regression/KRR.h"
00016 #include "lib/lapack.h"
00017 #include "lib/Mathematics.h"
00018 
00019 using namespace shogun;
00020 
00021 CKRR::CKRR()
00022 : CKernelMachine()
00023 {
00024     alpha=NULL;
00025     tau=1e-6;
00026 }
00027 
00028 CKRR::CKRR(float64_t t, CKernel* k, CLabels* lab)
00029 : CKernelMachine()
00030 {
00031     tau=t;
00032     set_labels(lab);
00033     set_kernel(k);
00034     alpha=NULL;
00035 }
00036 
00037 
00038 CKRR::~CKRR()
00039 {
00040     delete[] alpha;
00041 }
00042 
00043 bool CKRR::train(CFeatures* data)
00044 {
00045     delete[] alpha;
00046 
00047     ASSERT(labels);
00048     if (data)
00049     {
00050         if (labels->get_num_labels() != data->get_num_vectors())
00051             SG_ERROR("Number of training vectors does not match number of labels\n");
00052         kernel->init(data, data);
00053     }
00054     ASSERT(kernel && kernel->has_features());
00055 
00056     // Get kernel matrix
00057     int32_t m=0;
00058     int32_t n=0;
00059     float64_t *K = kernel->get_kernel_matrix<float64_t>(m, n, NULL);
00060     ASSERT(K && m>0 && n>0);
00061 
00062     for(int32_t i=0; i < n; i++)
00063         K[i+i*n]+=tau;
00064 
00065     // Get labels
00066     int32_t numlabels=0;
00067     alpha=labels->get_labels(numlabels);
00068     if (!alpha)
00069         SG_ERROR("No labels set\n");
00070 
00071     if (numlabels!=n)
00072     {
00073         SG_ERROR("Number of labels does not match number of kernel"
00074                 " columns (num_labels=%d cols=%d\n", numlabels, n);
00075     }
00076 
00077     clapack_dposv(CblasRowMajor,CblasUpper, n, 1, K, n, alpha, n);
00078 
00079     delete[] K;
00080     return true;
00081 }
00082 
00083 bool CKRR::load(FILE* srcfile)
00084 {
00085     SG_SET_LOCALE_C;
00086     SG_RESET_LOCALE;
00087     return false;
00088 }
00089 
00090 bool CKRR::save(FILE* dstfile)
00091 {
00092     SG_SET_LOCALE_C;
00093     SG_RESET_LOCALE;
00094     return false;
00095 }
00096 
00097 CLabels* CKRR::classify()
00098 {
00099     ASSERT(kernel);
00100 
00101     // Get kernel matrix
00102     int32_t m=0;
00103     int32_t n=0;
00104     float64_t* K=kernel->get_kernel_matrix<float64_t>(m, n, NULL);
00105     ASSERT(K && m>0 && n>0);
00106     float64_t* Yh=new float64_t[n];
00107 
00108     // predict
00109     // K is symmetric, CblasColMajor is same as CblasRowMajor 
00110     // and used that way in the origin call:
00111     // dgemv('T', m, n, 1.0, K, m, alpha, 1, 0.0, Yh, 1);
00112     int m_int = (int) m;
00113     int n_int = (int) n;
00114     cblas_dgemv(CblasColMajor, CblasTrans, m_int, n_int, 1.0, (double*) K,
00115         m_int, (double*) alpha, 1, 0.0, (double*) Yh, 1);
00116 
00117     delete[] K;
00118 
00119     CLabels* output=new CLabels(n);
00120     output->set_labels(Yh, n);
00121 
00122     delete[] Yh;
00123 
00124     return output;
00125 }
00126 
00127 float64_t CKRR::classify_example(int32_t num)
00128 {
00129     ASSERT(kernel);
00130 
00131     // Get kernel matrix
00132     int32_t m=0;
00133     int32_t n=0;
00134     // TODO: use get_kernel_column instead of computing the whole matrix!
00135     float64_t* K=kernel->get_kernel_matrix<float64_t>(m, n, NULL);
00136     ASSERT(K && m>0 && n>0);
00137     float64_t Yh;
00138 
00139     // predict
00140     Yh = CMath::dot(K + m*num, alpha, m);
00141 
00142     delete[] K;
00143     return Yh;
00144 }
00145 
00146 #endif
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation