00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
#include <shogun/clustering/Hierarchical.h>

#include <limits>

#include <shogun/distance/Distance.h>
#include <shogun/labels/Labels.h>
#include <shogun/features/Features.h>
#include <shogun/mathematics/Math.h>
#include <shogun/base/Parallel.h>

#ifdef HAVE_PTHREAD
#include <pthread.h>
#endif

using namespace shogun;
00023
#ifndef DOXYGEN_SHOULD_SKIP_THIS
/** Two example indices identifying one entry of the pairwise distance list
 * built by CHierarchical::train_machine(); that function always stores the
 * smaller index in idx1. */
struct pair
{
	/** index of the first example of the pair */
	int32_t idx1;
	/** index of the second example of the pair */
	int32_t idx2;
};
#endif // DOXYGEN_SHOULD_SKIP_THIS
00033
00034 CHierarchical::CHierarchical()
00035 : CDistanceMachine(), merges(3), dimensions(0), assignment(NULL),
00036 table_size(0), pairs(NULL), merge_distance(NULL)
00037 {
00038 }
00039
00040 CHierarchical::CHierarchical(int32_t merges_, CDistance* d)
00041 : CDistanceMachine(), merges(merges_), dimensions(0), assignment(NULL),
00042 table_size(0), pairs(NULL), merge_distance(NULL)
00043 {
00044 set_distance(d);
00045 }
00046
00047 CHierarchical::~CHierarchical()
00048 {
00049 SG_FREE(merge_distance);
00050 SG_FREE(assignment);
00051 SG_FREE(pairs);
00052 }
00053
00054 EMachineType CHierarchical::get_classifier_type()
00055 {
00056 return CT_HIERARCHICAL;
00057 }
00058
00059 bool CHierarchical::train_machine(CFeatures* data)
00060 {
00061 ASSERT(distance);
00062
00063 if (data)
00064 distance->init(data, data);
00065
00066 CFeatures* lhs=distance->get_lhs();
00067 ASSERT(lhs);
00068
00069 int32_t num=lhs->get_num_vectors();
00070 ASSERT(num>0);
00071
00072 const int32_t num_pairs=num*(num-1)/2;
00073
00074 SG_FREE(merge_distance);
00075 merge_distance=SG_MALLOC(float64_t, num);
00076 SGVector<float64_t>::fill_vector(merge_distance, num, -1.0);
00077
00078 SG_FREE(assignment);
00079 assignment=SG_MALLOC(int32_t, num);
00080 SGVector<int32_t>::range_fill_vector(assignment, num);
00081
00082 SG_FREE(pairs);
00083 pairs=SG_MALLOC(int32_t, 2*num);
00084 SGVector<int32_t>::fill_vector(pairs, 2*num, -1);
00085
00086 pair* index=SG_MALLOC(pair, num_pairs);
00087 float64_t* distances=SG_MALLOC(float64_t, num_pairs);
00088
00089 int32_t offs=0;
00090 for (int32_t i=0; i<num; i++)
00091 {
00092 for (int32_t j=i+1; j<num; j++)
00093 {
00094 distances[offs]=distance->distance(i,j);
00095 index[offs].idx1=i;
00096 index[offs].idx2=j;
00097 offs++;
00098 }
00099 SG_PROGRESS(i, 0, num-1);
00100 }
00101
00102 CMath::qsort_index<float64_t,pair>(distances, index, (num-1)*num/2);
00103
00104
00105 int32_t k=-1;
00106 int32_t l=0;
00107 for (; l<num && (num-l)>=merges && k<num_pairs-1; l++)
00108 {
00109 while (k<num_pairs-1)
00110 {
00111 k++;
00112
00113 int32_t i=index[k].idx1;
00114 int32_t j=index[k].idx2;
00115 int32_t c1=assignment[i];
00116 int32_t c2=assignment[j];
00117
00118 if (c1==c2)
00119 continue;
00120
00121 SG_PROGRESS(k, 0, num_pairs-1);
00122
00123 if (c1<c2)
00124 {
00125 pairs[2*l]=c1;
00126 pairs[2*l+1]=c2;
00127 }
00128 else
00129 {
00130 pairs[2*l]=c2;
00131 pairs[2*l+1]=c1;
00132 }
00133 merge_distance[l]=distances[k];
00134
00135 int32_t c=num+l;
00136 for (int32_t m=0; m<num; m++)
00137 {
00138 if (assignment[m] == c1 || assignment[m] == c2)
00139 assignment[m] = c;
00140 }
00141 #ifdef DEBUG_HIERARCHICAL
00142 SG_PRINT("l=%04i i=%04i j=%04i c1=%+04d c2=%+04d c=%+04d dist=%6.6f\n", l,i,j, c1,c2,c, merge_distance[l]);
00143 #endif
00144 break;
00145 }
00146 }
00147
00148 assignment_size=num;
00149 table_size=l-1;
00150 ASSERT(table_size>0);
00151 SG_FREE(distances);
00152 SG_FREE(index);
00153 SG_UNREF(lhs)
00154
00155 return true;
00156 }
00157
00158 bool CHierarchical::load(FILE* srcfile)
00159 {
00160 SG_SET_LOCALE_C;
00161 SG_RESET_LOCALE;
00162 return false;
00163 }
00164
00165 bool CHierarchical::save(FILE* dstfile)
00166 {
00167 SG_SET_LOCALE_C;
00168 SG_RESET_LOCALE;
00169 return false;
00170 }
00171
00172
00173 int32_t CHierarchical::get_merges()
00174 {
00175 return merges;
00176 }
00177
00178 SGVector<int32_t> CHierarchical::get_assignment()
00179 {
00180 return SGVector<int32_t>(assignment,table_size, false);
00181 }
00182
00183 SGVector<float64_t> CHierarchical::get_merge_distances()
00184 {
00185 return SGVector<float64_t>(merge_distance,merges, false);
00186 }
00187
00188 SGMatrix<int32_t> CHierarchical::get_cluster_pairs()
00189 {
00190 return SGMatrix<int32_t>(pairs,2,merges, false);
00191 }
00192
00193
00194 void CHierarchical::store_model_features()
00195 {
00196
00197 }