00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #include <shogun/io/SGIO.h>
00012 #include <shogun/labels/MulticlassLabels.h>
00013 #include <shogun/multiclass/GMNPSVM.h>
00014 #include <shogun/multiclass/GMNPLib.h>
00015 #include <shogun/multiclass/MulticlassOneVsRestStrategy.h>
00016
00017 #define INDEX(ROW,COL,DIM) (((COL)*(DIM))+(ROW))
00018 #define MINUS_INF INT_MIN
00019 #define PLUS_INF INT_MAX
00020 #define KDELTA(A,B) (A==B)
00021 #define KDELTA4(A1,A2,A3,A4) ((A1==A2)||(A1==A3)||(A1==A4)||(A2==A3)||(A2==A4)||(A3==A4))
00022
00023 using namespace shogun;
00024
00025 CGMNPSVM::CGMNPSVM()
00026 : CMulticlassSVM(new CMulticlassOneVsRestStrategy())
00027 {
00028 init();
00029 }
00030
00031 CGMNPSVM::CGMNPSVM(float64_t C, CKernel* k, CLabels* lab)
00032 : CMulticlassSVM(new CMulticlassOneVsRestStrategy(), C, k, lab)
00033 {
00034 init();
00035 }
00036
00037 CGMNPSVM::~CGMNPSVM()
00038 {
00039 if (m_basealphas != NULL) SG_FREE(m_basealphas);
00040 }
00041
00042 void
00043 CGMNPSVM::init()
00044 {
00045 m_parameters->add_matrix(&m_basealphas,
00046 &m_basealphas_y, &m_basealphas_x,
00047 "m_basealphas",
00048 "Is the basic untransformed alpha.");
00049
00050 m_basealphas = NULL, m_basealphas_y = 0, m_basealphas_x = 0;
00051 }
00052
/** Trains the multiclass machine by solving the GMNP quadratic program
 * with the IMDM solver in CGMNPLib, then converting the solver's
 * "virtual example" alphas into one CSVM per class.
 *
 * @param data optional training features; when given, must match the
 *             number of labels and is used to (re)initialize the kernel
 * @return true on successful training
 */
bool CGMNPSVM::train_machine(CFeatures* data)
{
	// Preconditions: a kernel, non-empty labels, and multiclass labels.
	ASSERT(m_kernel);
	ASSERT(m_labels && m_labels->get_num_labels());
	ASSERT(m_labels->get_label_type() == LT_MULTICLASS);

	if (data)
	{
		if (m_labels->get_num_labels() != data->get_num_vectors())
		{
			SG_ERROR("%s::train_machine(): Number of training vectors (%d) does"
					" not match number of labels (%d)\n", get_name(),
					data->get_num_vectors(), m_labels->get_num_labels());
		}
		m_kernel->init(data, data);
	}

	int32_t num_data = m_labels->get_num_labels();
	int32_t num_classes = m_multiclass_strategy->get_num_classes();
	// The solver works on "virtual" examples: one per (example, wrong
	// class) pair, hence num_data*(num_classes-1) of them.
	int32_t num_virtual_data= num_data*(num_classes-1);

	SG_INFO( "%d trainlabels, %d classes\n", num_data, num_classes);

	// GMNPLib expects 1-based class labels, so shift the 0-based ones.
	float64_t* vector_y = SG_MALLOC(float64_t, num_data);
	for (int32_t i=0; i<num_data; i++)
	{
		vector_y[i] = ((CMulticlassLabels*) m_labels)->get_label(i)+1;

	}

	// Solver settings: effectively unbounded iterations (tmax), no
	// absolute tolerance, relative tolerance from the SVM epsilon.
	float64_t C = get_C();
	int32_t tmax = 1000000000;
	float64_t tolabs = 0;
	float64_t tolrel = get_epsilon();

	// Soft margin is emulated by regularizing the diagonal with 1/(2C);
	// C == 0 means no regularization.
	float64_t reg_const=0;
	if( C!=0 )
		reg_const = 1/(2*C);


	float64_t* alpha = SG_MALLOC(float64_t, num_virtual_data);
	// Linear term of the QP; all-zero for GMNP.
	float64_t* vector_c = SG_MALLOC(float64_t, num_virtual_data);
	memset(vector_c, 0, num_virtual_data*sizeof(float64_t));

	// thlb: threshold on the lower bound — set huge so it never triggers
	// early stopping. History is allocated by the solver.
	float64_t thlb = 10000000000.0;
	int32_t t = 0;
	float64_t* History = NULL;
	int32_t verb = 0;

	CGMNPLib mnp(vector_y,m_kernel,num_data, num_virtual_data, num_classes, reg_const);

	// Solve the QP; 'alpha' receives one coefficient per virtual example.
	mnp.gmnp_imdm(vector_c, num_virtual_data, tmax,
			tolabs, tolrel, thlb, alpha, &t, &History, verb);


	// Per-class alphas, stored column-major: entry (class i, example j)
	// lives at j*num_classes+i.
	float64_t* all_alphas= SG_MALLOC(float64_t, num_classes*num_data);
	memset(all_alphas,0,num_classes*num_data*sizeof(float64_t));


	float64_t* all_bs=SG_MALLOC(float64_t, num_classes);
	memset(all_bs,0,num_classes*sizeof(float64_t));


	// Redistribute each virtual-example alpha onto the per-class
	// machines. get_indices2 maps virtual index j back to the original
	// example (inx1) and the competing class (inx2, 1-based). The
	// KDELTA difference is +1 when class i+1 is the true label of inx1,
	// -1 when it is the competing class, and 0 otherwise.
	for(int32_t i=0; i < num_classes; i++ )
	{
		for(int32_t j=0; j < num_virtual_data; j++ )
		{
			int32_t inx1=0;
			int32_t inx2=0;

			mnp.get_indices2( &inx1, &inx2, j );

			all_alphas[(inx1*num_classes)+i] +=
				alpha[j]*(KDELTA(vector_y[inx1],i+1)-KDELTA(i+1,inx2));
			all_bs[i] += alpha[j]*(KDELTA(vector_y[inx1],i+1)-KDELTA(i+1,inx2));
		}
	}

	create_multiclass_svm(num_classes);

	// Build one CSVM per class from the non-zero alphas.
	for (int32_t i=0; i<num_classes; i++)
	{
		int32_t num_sv=0;
		for (int32_t j=0; j<num_data; j++)
		{
			if (all_alphas[j*num_classes+i] != 0)
				num_sv++;
		}
		ASSERT(num_sv>0);
		SG_DEBUG("svm[%d] has %d sv, b=%f\n", i, num_sv, all_bs[i]);

		CSVM* svm=new CSVM(num_sv);

		int32_t k=0;
		for (int32_t j=0; j<num_data; j++)
		{
			if (all_alphas[j*num_classes+i] != 0)
			{
				svm->set_alpha(k, all_alphas[j*num_classes+i]);
				svm->set_support_vector(k, j);
				k++;
			}
		}

		svm->set_bias(all_bs[i]);
		set_svm(i, svm);
	}

	// Keep the raw (untransformed) solver alphas for inspection via
	// get_basealphas_ptr(). Layout mirrors the all_alphas indexing:
	// example-major with num_classes (=m_basealphas_y) stride; inx2 is
	// 1-based, hence the -1.
	if (m_basealphas != NULL) SG_FREE(m_basealphas);
	m_basealphas_y = num_classes, m_basealphas_x = num_data;
	m_basealphas = SG_MALLOC(float64_t, m_basealphas_y*m_basealphas_x);
	for (index_t i=0; i<m_basealphas_y*m_basealphas_x; i++)
		m_basealphas[i] = 0.0;

	for(index_t j=0; j<num_virtual_data; j++)
	{
		index_t inx1=0, inx2=0;

		mnp.get_indices2(&inx1, &inx2, j);
		m_basealphas[inx1*m_basealphas_y + (inx2-1)] = alpha[j];
	}

	// Release all temporary buffers, including the solver-allocated
	// convergence History.
	SG_FREE(vector_c);
	SG_FREE(alpha);
	SG_FREE(all_alphas);
	SG_FREE(all_bs);
	SG_FREE(vector_y);
	SG_FREE(History);

	return true;
}
00184
00185 float64_t*
00186 CGMNPSVM::get_basealphas_ptr(index_t* y, index_t* x)
00187 {
00188 if (y == NULL || x == NULL) return NULL;
00189
00190 *y = m_basealphas_y, *x = m_basealphas_x;
00191 return m_basealphas;
00192 }