00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
#include <shogun/mathematics/munkres.h>

#include <algorithm>
#include <cmath>
#include <iostream>

using namespace shogun;
00026
00027 bool Munkres::find_uncovered_in_matrix(double item, int &row, int &col)
00028 {
00029 for (row=0; row < matrix.num_rows; row++)
00030 if (!row_mask[row])
00031 for (col=0; col < matrix.num_cols; col++)
00032 if (!col_mask[col])
00033 if (matrix(row,col) == item)
00034 return true;
00035
00036 return false;
00037 }
00038
00039 bool Munkres::pair_in_list(const std::pair<int,int> &needle, const std::list<std::pair<int,int> > &haystack)
00040 {
00041 for (std::list<std::pair<int,int> >::const_iterator i=haystack.begin(); i != haystack.end(); i++)
00042 {
00043 if ( needle == *i )
00044 return true;
00045 }
00046
00047 return false;
00048 }
00049
00050 int Munkres::step1(void)
00051 {
00052 for (int row=0; row < matrix.num_rows; row++)
00053 for (int col=0; col < matrix.num_cols; col++)
00054 if (matrix(row,col) == 0)
00055 {
00056 bool isstarred=false;
00057 for (int nrow=0; nrow < matrix.num_rows; nrow++)
00058 if (mask_matrix(nrow,col) == STAR)
00059 {
00060 isstarred=true;
00061 break;
00062 }
00063
00064 if (!isstarred)
00065 {
00066 for (int ncol=0; ncol < matrix.num_cols; ncol++)
00067 if ( mask_matrix(row,ncol) == STAR )
00068 {
00069 isstarred=true;
00070 break;
00071 }
00072 }
00073
00074 if (!isstarred)
00075 {
00076 mask_matrix(row,col)=STAR;
00077 }
00078 }
00079
00080 return 2;
00081 }
00082
00083 int Munkres::step2(void)
00084 {
00085 int rows=matrix.num_rows;
00086 int cols=matrix.num_cols;
00087 int covercount=0;
00088 for (int row=0; row < rows; row++)
00089 for (int col=0; col < cols; col++)
00090 if (mask_matrix(row,col) == STAR)
00091 {
00092 col_mask[col]=true;
00093 covercount++;
00094 }
00095
00096 int k=rows < cols ? rows : cols;
00097
00098 if (covercount >= k)
00099 {
00100 return 0;
00101 }
00102
00103 return 3;
00104 }
00105
00106 int Munkres::step3(void)
00107 {
00108
00109
00110
00111
00112
00113
00114
00115 if (find_uncovered_in_matrix(0, saverow, savecol))
00116 {
00117 mask_matrix(saverow,savecol)=PRIME;
00118 }
00119 else
00120 {
00121 return 5;
00122 }
00123
00124 for (int ncol = 0; ncol < matrix.num_cols; ncol++)
00125 if (mask_matrix(saverow,ncol) == STAR)
00126 {
00127 row_mask[saverow]=true;
00128 col_mask[ncol]=false;
00129 return 3;
00130 }
00131
00132 return 4;
00133 }
00134
00135 int Munkres::step4(void)
00136 {
00137 int rows=matrix.num_rows;
00138 int cols=matrix.num_cols;
00139
00140 std::list<std::pair<int,int> > seq;
00141
00142 std::pair<int,int> z0(saverow, savecol);
00143 std::pair<int,int> z1(-1,-1);
00144 std::pair<int,int> z2n(-1,-1);
00145 seq.insert(seq.end(), z0);
00146 int row, col=savecol;
00147
00148
00149
00150
00151
00152
00153
00154
00155
00156
00157
00158
00159 bool madepair;
00160 do
00161 {
00162 madepair=false;
00163 for (row=0; row < rows; row++)
00164 if (mask_matrix(row,col) == STAR)
00165 {
00166 z1.first=row;
00167 z1.second=col;
00168 if (pair_in_list(z1, seq))
00169 continue;
00170
00171 madepair=true;
00172 seq.insert(seq.end(), z1);
00173 break;
00174 }
00175
00176 if (!madepair)
00177 break;
00178
00179 madepair=false;
00180
00181 for (col = 0; col < cols; col++)
00182 if (mask_matrix(row,col) == PRIME)
00183 {
00184 z2n.first=row;
00185 z2n.second=col;
00186 if (pair_in_list(z2n, seq))
00187 continue;
00188 madepair=true;
00189 seq.insert(seq.end(), z2n);
00190 break;
00191 }
00192 }
00193 while (madepair);
00194
00195 for (std::list<std::pair<int,int> >::iterator i=seq.begin();
00196 i != seq.end();
00197 i++)
00198 {
00199
00200 if (mask_matrix(i->first,i->second) == STAR)
00201 mask_matrix(i->first,i->second)=NORMAL;
00202
00203
00204
00205 if (mask_matrix(i->first,i->second) == PRIME)
00206 mask_matrix(i->first,i->second)=STAR;
00207 }
00208
00209
00210 for (int rowi=0; rowi < mask_matrix.num_rows; rowi++)
00211 {
00212 for (int coli=0; coli < mask_matrix.num_cols; coli++)
00213 {
00214 if (mask_matrix(rowi,coli) == PRIME)
00215 mask_matrix(rowi,coli) = NORMAL;
00216 }
00217 }
00218
00219 for (int i=0; i < rows; i++)
00220 row_mask[i]=false;
00221
00222 for (int i=0; i < cols; i++)
00223 col_mask[i]=false;
00224
00225
00226 return 2;
00227 }
00228
00229 int Munkres::step5(void)
00230 {
00231 int rows=matrix.num_rows;
00232 int cols=matrix.num_cols;
00233
00234
00235
00236
00237
00238
00239
00240
00241 double h=0;
00242 for (int row=0; row < rows; row++)
00243 {
00244 if (!row_mask[row])
00245 for (int col=0; col < cols; col++)
00246 if (!col_mask[col])
00247 if ((h > matrix(row,col) && matrix(row,col) != 0) || h == 0)
00248 h = matrix(row,col);
00249 }
00250
00251 for (int row=0; row < rows; row++)
00252 if (row_mask[row])
00253 for (int col=0; col < cols; col++)
00254 matrix(row,col)+=h;
00255
00256 for (int col=0; col < cols; col++)
00257 if (!col_mask[col])
00258 for (int row=0; row < rows; row++)
00259 matrix(row,col)-=h;
00260
00261 return 3;
00262 }
00263
00264 void Munkres::solve(SGMatrix<double> &m)
00265 {
00266
00267
00268
00269
00270
00271
00272
00273 double highValue=0;
00274 for (int row=0; row < m.num_rows; row++)
00275 {
00276 for (int col=0; col < m.num_cols; col++)
00277 {
00278 if (m(row,col) != INFINITY && m(row,col) > highValue)
00279 highValue = m(row,col);
00280 }
00281 }
00282 highValue++;
00283
00284 for (int row=0; row < m.num_rows; row++)
00285 for (int col=0; col < m.num_cols; col++)
00286 if (m(row,col) == INFINITY)
00287 m(row,col)=highValue;
00288
00289 bool notdone=true;
00290 int step=1;
00291
00292 mask_matrix.zero();
00293 std::copy(m.matrix, m.matrix + m.num_cols*m.num_rows, matrix.matrix);
00294
00295
00296 row_mask=SG_MALLOC(bool, matrix.num_rows);
00297 col_mask=SG_MALLOC(bool, matrix.num_cols);
00298 for (int i=0; i < matrix.num_rows; i++)
00299 row_mask[i] = false;
00300
00301 for (int i=0; i < matrix.num_cols; i++)
00302 col_mask[i] = false;
00303
00304 while (notdone)
00305 {
00306 switch (step)
00307 {
00308 case 0:
00309 notdone=false;
00310 break;
00311 case 1:
00312 step=step1();
00313 break;
00314 case 2:
00315 step=step2();
00316 break;
00317 case 3:
00318 step=step3();
00319 break;
00320 case 4:
00321 step=step4();
00322 break;
00323 case 5:
00324 step=step5();
00325 break;
00326 }
00327 }
00328
00329
00330 for (int row=0; row < matrix.num_rows; row++)
00331 for (int col=0; col < matrix.num_cols; col++)
00332 if (mask_matrix(row,col) == STAR)
00333 matrix(row,col)=0;
00334 else
00335 matrix(row,col)=-1;
00336
00337 std::copy(matrix.matrix, matrix.matrix + m.num_cols*m.num_rows, m.matrix);
00338
00339 SG_FREE(row_mask);
00340 SG_FREE(col_mask);
00341 }
00342