00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #include <shogun/multiclass/ConjugateIndex.h>
00012 #ifdef HAVE_LAPACK
00013 #include <shogun/machine/Machine.h>
00014 #include <shogun/features/Features.h>
00015 #include <shogun/labels/Labels.h>
00016 #include <shogun/labels/MulticlassLabels.h>
00017 #include <shogun/mathematics/lapack.h>
00018 #include <shogun/mathematics/Math.h>
00019 #include <shogun/lib/Signal.h>
00020
00021 using namespace shogun;
00022
00023 CConjugateIndex::CConjugateIndex() : CMachine()
00024 {
00025 m_classes = NULL;
00026 m_features = NULL;
00027 };
00028
00029 CConjugateIndex::CConjugateIndex(CFeatures* train_features, CLabels* train_labels) : CMachine()
00030 {
00031 m_features = NULL;
00032 set_features(train_features);
00033 set_labels(train_labels);
00034 m_classes = NULL;
00035 };
00036
00037 CConjugateIndex::~CConjugateIndex()
00038 {
00039 clean_classes();
00040 SG_UNREF(m_features);
00041 };
00042
00043 void CConjugateIndex::set_features(CFeatures* features)
00044 {
00045 ASSERT(features->get_feature_class()==C_DENSE);
00046 SG_REF(features);
00047 SG_UNREF(m_features);
00048 m_features = (CDenseFeatures<float64_t>*)features;
00049 }
00050
00051 CDenseFeatures<float64_t>* CConjugateIndex::get_features()
00052 {
00053 SG_REF(m_features);
00054 return m_features;
00055 }
00056
00057 void CConjugateIndex::clean_classes()
00058 {
00059 if (m_classes)
00060 {
00061 for (int32_t i=0; i<m_num_classes; i++)
00062 m_classes[i]=SGMatrix<float64_t>();
00063
00064 delete[] m_classes;
00065 }
00066 }
00067
00068 bool CConjugateIndex::train_machine(CFeatures* data)
00069 {
00070 if (data)
00071 set_features(data);
00072
00073 ASSERT(m_labels);
00074 ASSERT(m_labels->get_label_type()==LT_MULTICLASS);
00075
00076 m_num_classes = ((CMulticlassLabels*) m_labels)->get_num_classes();
00077 ASSERT(m_num_classes>=2);
00078 clean_classes();
00079
00080 int32_t num_vectors;
00081 int32_t num_features;
00082 float64_t* feature_matrix = m_features->get_feature_matrix(num_features,num_vectors);
00083
00084 m_classes = new SGMatrix<float64_t>[m_num_classes]();
00085 for (int32_t i=0; i<m_num_classes; i++)
00086 m_classes[i] = SGMatrix<float64_t>(num_features,num_features);
00087
00088 m_feature_vector = SGVector<float64_t>(num_features);
00089
00090 SG_PROGRESS(0,0,m_num_classes-1);
00091
00092 for (int32_t label=0; label<m_num_classes; label++)
00093 {
00094 int32_t count = 0;
00095 for (int32_t i=0; i<num_vectors; i++)
00096 {
00097 if (((CMulticlassLabels*) m_labels)->get_int_label(i) == label)
00098 count++;
00099 }
00100
00101 SGMatrix<float64_t> class_feature_matrix(num_features,count);
00102 SGMatrix<float64_t> matrix(count,count);
00103 SGMatrix<float64_t> helper_matrix(num_features,count);
00104
00105 count = 0;
00106 for (int32_t i=0; i<num_vectors; i++)
00107 {
00108 if (((CMulticlassLabels*) m_labels)->get_label(i) == label)
00109 {
00110 memcpy(class_feature_matrix.matrix+count*num_features,
00111 feature_matrix+i*num_features,
00112 sizeof(float64_t)*num_features);
00113 count++;
00114 }
00115 }
00116
00117 cblas_dgemm(CblasColMajor,CblasTrans,CblasNoTrans,
00118 count,count,num_features,
00119 1.0,class_feature_matrix.matrix,num_features,
00120 class_feature_matrix.matrix,num_features,
00121 0.0,matrix.matrix,count);
00122
00123 SGMatrix<float64_t>::inverse(matrix);
00124
00125 cblas_dgemm(CblasColMajor,CblasNoTrans,CblasTrans,
00126 count,num_features,count,
00127 1.0,matrix.matrix,count,
00128 class_feature_matrix.matrix,num_features,
00129 0.0,helper_matrix.matrix,count);
00130
00131 cblas_dgemm(CblasColMajor,CblasNoTrans,CblasNoTrans,
00132 num_features,num_features,count,
00133 1.0,class_feature_matrix.matrix,num_features,
00134 helper_matrix.matrix,count,
00135 0.0,m_classes[label].matrix,num_features);
00136
00137 SG_PROGRESS(label+1,0,m_num_classes);
00138 }
00139 SG_DONE();
00140
00141 return true;
00142 };
00143
00144 CMulticlassLabels* CConjugateIndex::apply_multiclass(CFeatures* data)
00145 {
00146 if (data)
00147 set_features(data);
00148
00149 ASSERT(m_features);
00150
00151 ASSERT(m_classes);
00152 ASSERT(m_num_classes>1);
00153 ASSERT(m_features->get_num_features()==m_feature_vector.vlen);
00154
00155 int32_t num_vectors = m_features->get_num_vectors();
00156
00157 CMulticlassLabels* predicted_labels = new CMulticlassLabels(num_vectors);
00158
00159 for (int32_t i=0; i<num_vectors;i++)
00160 {
00161 SG_PROGRESS(i,0,num_vectors-1);
00162 predicted_labels->set_label(i,apply_one(i));
00163 }
00164 SG_DONE();
00165
00166 return predicted_labels;
00167 };
00168
00169 float64_t CConjugateIndex::conjugate_index(SGVector<float64_t> feature_vector, int32_t label)
00170 {
00171 int32_t num_features = feature_vector.vlen;
00172 float64_t norm = cblas_ddot(num_features,feature_vector.vector,1,
00173 feature_vector.vector,1);
00174
00175 cblas_dgemv(CblasColMajor,CblasNoTrans,
00176 num_features,num_features,
00177 1.0,m_classes[label].matrix,num_features,
00178 feature_vector.vector,1,
00179 0.0,m_feature_vector.vector,1);
00180
00181 float64_t product = cblas_ddot(num_features,feature_vector.vector,1,
00182 m_feature_vector.vector,1);
00183 return product/norm;
00184 };
00185
00186 float64_t CConjugateIndex::apply_one(int32_t index)
00187 {
00188 int32_t predicted_label = 0;
00189 float64_t max_conjugate_index = 0.0;
00190 float64_t current_conjugate_index;
00191
00192 SGVector<float64_t> feature_vector = m_features->get_feature_vector(index);
00193 for (int32_t i=0; i<m_num_classes; i++)
00194 {
00195 current_conjugate_index = conjugate_index(feature_vector,i);
00196
00197 if (current_conjugate_index > max_conjugate_index)
00198 {
00199 max_conjugate_index = current_conjugate_index;
00200 predicted_label = i;
00201 }
00202 }
00203
00204 return predicted_label;
00205 };
00206
00207 #endif