00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #include <shogun/lib/common.h>
00012 #include <shogun/io/SGIO.h>
00013 #include <shogun/multiclass/MulticlassSVM.h>
00014 #include <shogun/multiclass/MulticlassOneVsRestStrategy.h>
00015
00016 using namespace shogun;
00017
00018 CMulticlassSVM::CMulticlassSVM()
00019 :CKernelMulticlassMachine(new CMulticlassOneVsRestStrategy(), NULL, new CSVM(0), NULL), m_C(0)
00020 {
00021 init();
00022 }
00023
00024 CMulticlassSVM::CMulticlassSVM(CMulticlassStrategy *strategy)
00025 :CKernelMulticlassMachine(strategy, NULL, new CSVM(0), NULL), m_C(0)
00026 {
00027 init();
00028 }
00029
00030 CMulticlassSVM::CMulticlassSVM(
00031 CMulticlassStrategy *strategy, float64_t C, CKernel* k, CLabels* lab)
00032 : CKernelMulticlassMachine(strategy, k, new CSVM(C, k, lab), lab), m_C(C)
00033 {
00034 init();
00035 }
00036
00037 CMulticlassSVM::~CMulticlassSVM()
00038 {
00039 }
00040
00041 void CMulticlassSVM::init()
00042 {
00043 SG_ADD(&m_C, "C", "C regularization constant",MS_AVAILABLE);
00044 }
00045
00046 bool CMulticlassSVM::create_multiclass_svm(int32_t num_classes)
00047 {
00048 if (num_classes>0)
00049 {
00050 int32_t num_svms=m_multiclass_strategy->get_num_machines();
00051
00052 m_machines->reset_array();
00053 for (index_t i=0; i<num_svms; ++i)
00054 m_machines->push_back(NULL);
00055
00056 return true;
00057 }
00058 return false;
00059 }
00060
00061 bool CMulticlassSVM::set_svm(int32_t num, CSVM* svm)
00062 {
00063 if (m_machines->get_num_elements()>0 && m_machines->get_num_elements()>num && num>=0 && svm)
00064 {
00065 m_machines->set_element(svm, num);
00066 return true;
00067 }
00068 return false;
00069 }
00070
00071 bool CMulticlassSVM::init_machines_for_apply(CFeatures* data)
00072 {
00073 if (is_data_locked())
00074 {
00075 SG_ERROR("CKernelMachine::apply(CFeatures*) cannot be called when "
00076 "data_lock was called before. Call data_unlock to allow.");
00077 }
00078
00079 if (!m_kernel)
00080 SG_ERROR("No kernel assigned!\n");
00081
00082 CFeatures* lhs=m_kernel->get_lhs();
00083 if (!lhs && m_kernel->get_kernel_type()!=K_COMBINED)
00084 SG_ERROR("%s: No left hand side specified\n", get_name());
00085
00086 if (m_kernel->get_kernel_type()!=K_COMBINED && !lhs->get_num_vectors())
00087 {
00088 SG_ERROR("%s: No vectors on left hand side (%s). This is probably due to"
00089 " an implementation error in %s, where it was forgotten to set "
00090 "the data (m_svs) indices\n", get_name(),
00091 data->get_name());
00092 }
00093
00094 if (data && m_kernel->get_kernel_type()!=K_COMBINED)
00095 m_kernel->init(lhs, data);
00096 SG_UNREF(lhs);
00097
00098 for (int32_t i=0; i<m_machines->get_num_elements(); i++)
00099 {
00100 CSVM *the_svm = (CSVM *)m_machines->get_element(i);
00101 ASSERT(the_svm);
00102 the_svm->set_kernel(m_kernel);
00103 SG_UNREF(the_svm);
00104 }
00105
00106 return true;
00107 }
00108
00109 bool CMulticlassSVM::load(FILE* modelfl)
00110 {
00111 bool result=true;
00112 char char_buffer[1024];
00113 int32_t int_buffer;
00114 float64_t double_buffer;
00115 int32_t line_number=1;
00116 int32_t svm_idx=-1;
00117
00118 SG_SET_LOCALE_C;
00119
00120 if (fscanf(modelfl,"%15s\n", char_buffer)==EOF)
00121 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00122 else
00123 {
00124 char_buffer[15]='\0';
00125 if (strcmp("%MultiClassSVM", char_buffer)!=0)
00126 SG_ERROR( "error in multiclass svm file, line nr:%d\n", line_number);
00127
00128 line_number++;
00129 }
00130
00131 int_buffer=0;
00132 if (fscanf(modelfl," num_classes=%d; \n", &int_buffer) != 1)
00133 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00134
00135 if (!feof(modelfl))
00136 line_number++;
00137
00138 if (int_buffer < 2)
00139 SG_ERROR("less than 2 classes - how is this multiclass?\n");
00140
00141 create_multiclass_svm(int_buffer);
00142
00143 int_buffer=0;
00144 if (fscanf(modelfl," num_svms=%d; \n", &int_buffer) != 1)
00145 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00146
00147 if (!feof(modelfl))
00148 line_number++;
00149
00150 if (m_machines->get_num_elements() != int_buffer)
00151 SG_ERROR("Mismatch in number of svms: m_num_svms=%d vs m_num_svms(file)=%d\n", m_machines->get_num_elements(), int_buffer);
00152
00153 if (fscanf(modelfl," kernel='%s'; \n", char_buffer) != 1)
00154 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00155
00156 if (!feof(modelfl))
00157 line_number++;
00158
00159 for (int32_t n=0; n<m_machines->get_num_elements(); n++)
00160 {
00161 svm_idx=-1;
00162 if (fscanf(modelfl,"\n%4s %d of %d\n", char_buffer, &svm_idx, &int_buffer)==EOF)
00163 {
00164 result=false;
00165 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00166 }
00167 else
00168 {
00169 char_buffer[4]='\0';
00170 if (strncmp("%SVM", char_buffer, 4)!=0)
00171 {
00172 result=false;
00173 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00174 }
00175
00176 if (svm_idx != n)
00177 SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx);
00178
00179 line_number++;
00180 }
00181
00182 int_buffer=0;
00183 if (fscanf(modelfl,"numsv%d=%d;\n", &svm_idx, &int_buffer) != 2)
00184 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00185
00186 if (svm_idx != n)
00187 SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx);
00188
00189 if (!feof(modelfl))
00190 line_number++;
00191
00192 SG_INFO("loading %ld support vectors for svm %d\n",int_buffer, svm_idx);
00193 CSVM* svm=new CSVM(int_buffer);
00194
00195 double_buffer=0;
00196
00197 if (fscanf(modelfl," b%d=%lf; \n", &svm_idx, &double_buffer) != 2)
00198 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00199
00200 if (svm_idx != n)
00201 SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx);
00202
00203 if (!feof(modelfl))
00204 line_number++;
00205
00206 svm->set_bias(double_buffer);
00207
00208 if (fscanf(modelfl,"alphas%d=[\n", &svm_idx) != 1)
00209 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00210
00211 if (svm_idx != n)
00212 SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx);
00213
00214 if (!feof(modelfl))
00215 line_number++;
00216
00217 for (int32_t i=0; i<svm->get_num_support_vectors(); i++)
00218 {
00219 double_buffer=0;
00220 int_buffer=0;
00221
00222 if (fscanf(modelfl,"\t[%lf,%d]; \n", &double_buffer, &int_buffer) != 2)
00223 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00224
00225 if (!feof(modelfl))
00226 line_number++;
00227
00228 svm->set_support_vector(i, int_buffer);
00229 svm->set_alpha(i, double_buffer);
00230 }
00231
00232 if (fscanf(modelfl,"%2s", char_buffer) == EOF)
00233 {
00234 result=false;
00235 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00236 }
00237 else
00238 {
00239 char_buffer[3]='\0';
00240 if (strcmp("];", char_buffer)!=0)
00241 {
00242 result=false;
00243 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00244 }
00245 line_number++;
00246 }
00247
00248 set_svm(n, svm);
00249 }
00250
00251 svm_proto()->svm_loaded=result;
00252
00253 SG_RESET_LOCALE;
00254 return result;
00255 }
00256
00257 bool CMulticlassSVM::save(FILE* modelfl)
00258 {
00259 SG_SET_LOCALE_C;
00260
00261 if (!m_kernel)
00262 SG_ERROR("Kernel not defined!\n");
00263
00264 if (m_machines->get_num_elements()<1)
00265 SG_ERROR("Multiclass SVM not trained!\n");
00266
00267 SG_INFO( "Writing model file...");
00268 fprintf(modelfl,"%%MultiClassSVM\n");
00269 fprintf(modelfl,"num_classes=%d;\n", m_multiclass_strategy->get_num_classes());
00270 fprintf(modelfl,"num_svms=%d;\n", m_machines->get_num_elements());
00271 fprintf(modelfl,"kernel='%s';\n", m_kernel->get_name());
00272
00273 for (int32_t i=0; i<m_machines->get_num_elements(); i++)
00274 {
00275 CSVM* svm=get_svm(i);
00276 ASSERT(svm);
00277 fprintf(modelfl,"\n%%SVM %d of %d\n", i, m_machines->get_num_elements()-1);
00278 fprintf(modelfl,"numsv%d=%d;\n", i, svm->get_num_support_vectors());
00279 fprintf(modelfl,"b%d=%+10.16e;\n",i,svm->get_bias());
00280
00281 fprintf(modelfl, "alphas%d=[\n", i);
00282
00283 for(int32_t j=0; j<svm->get_num_support_vectors(); j++)
00284 {
00285 fprintf(modelfl,"\t[%+10.16e,%d];\n",
00286 svm->get_alpha(j), svm->get_support_vector(j));
00287 }
00288
00289 fprintf(modelfl, "];\n");
00290 }
00291
00292 SG_RESET_LOCALE;
00293 SG_DONE();
00294 return true ;
00295 }