Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012 #ifndef _KMEANS_H__
00013 #define _KMEANS_H__
00014
00015 #include <stdio.h>
00016 #include "lib/common.h"
00017 #include "lib/io.h"
00018 #include "features/SimpleFeatures.h"
00019 #include "distance/Distance.h"
00020 #include "classifier/DistanceMachine.h"
00021
00022 namespace shogun
00023 {
00024 class CDistanceMachine;
00025
00039 class CKMeans : public CDistanceMachine
00040 {
00041 public:
00043 CKMeans();
00044
00050 CKMeans(int32_t k, CDistance* d);
00051 virtual ~CKMeans();
00052
00057 virtual inline EClassifierType get_classifier_type() { return CT_KMEANS; }
00058
00067 virtual bool train(CFeatures* data=NULL);
00068
00074 virtual bool load(FILE* srcfile);
00075
00081 virtual bool save(FILE* dstfile);
00082
00087 inline void set_k(int32_t p_k)
00088 {
00089 ASSERT(p_k>0);
00090 this->k=p_k;
00091 }
00092
00097 inline int32_t get_k()
00098 {
00099 return k;
00100 }
00101
00106 inline void set_max_iter(int32_t iter)
00107 {
00108 ASSERT(iter>0);
00109 max_iter=iter;
00110 }
00111
00116 inline float64_t get_max_iter()
00117 {
00118 return max_iter;
00119 }
00120
00126 inline void get_radi(float64_t*& radi, int32_t& num)
00127 {
00128 radi=R;
00129 num=k;
00130 }
00131
00138 inline void get_centers(float64_t*& centers, int32_t& dim, int32_t& num)
00139 {
00140 centers=mus;
00141 dim=dimensions;
00142 num=k;
00143 }
00144
00150 inline void get_radiuses(float64_t** radii, int32_t* num)
00151 {
00152 size_t sz=sizeof(*R)*k;
00153 *radii=(float64_t*) malloc(sz);
00154 ASSERT(*radii);
00155
00156 memcpy(*radii, R, sz);
00157 *num=k;
00158 }
00159
00166 inline void get_cluster_centers(
00167 float64_t** centers, int32_t* dim, int32_t* num)
00168 {
00169 size_t sz=sizeof(*mus)*dimensions*k;
00170 *centers=(float64_t*) malloc(sz);
00171 ASSERT(*centers);
00172
00173 memcpy(*centers, mus, sz);
00174 *dim=dimensions;
00175 *num=k;
00176 }
00177
00182 inline int32_t get_dimensions()
00183 {
00184 return dimensions;
00185 }
00186
00187 protected:
00193 void clustknb(bool use_old_mus, float64_t *mus_start);
00194
00199 virtual CLabels* classify()
00200 {
00201 SG_NOTIMPLEMENTED;
00202 return NULL;
00203 }
00204
00210 virtual CLabels* classify(CFeatures* data)
00211 {
00212 SG_NOTIMPLEMENTED;
00213 return NULL;
00214 }
00215
00216
00217
00219 inline virtual const char* get_name() const { return "KMeans"; }
00220
00221 protected:
00223 int32_t max_iter;
00224
00226 int32_t k;
00227
00229 int32_t dimensions;
00230
00232 float64_t* R;
00233
00235 float64_t* mus;
00236
00237 private:
00239 float64_t* Weights;
00240 };
00241 }
00242 #endif
00243