00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 1999-2008 Gunnar Raetsch 00008 * Written (W) 2007-2009 Soeren Sonnenburg 00009 * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society 00010 */ 00011 00012 #ifndef _KMEANS_H__ 00013 #define _KMEANS_H__ 00014 00015 #include <stdio.h> 00016 #include <shogun/lib/common.h> 00017 #include <shogun/io/SGIO.h> 00018 #include <shogun/features/DenseFeatures.h> 00019 #include <shogun/distance/Distance.h> 00020 #include <shogun/machine/DistanceMachine.h> 00021 00022 namespace shogun 00023 { 00024 class CDistanceMachine; 00025 00039 class CKMeans : public CDistanceMachine 00040 { 00041 public: 00043 CKMeans(); 00044 00050 CKMeans(int32_t k, CDistance* d); 00051 virtual ~CKMeans(); 00052 00053 00054 MACHINE_PROBLEM_TYPE(PT_MULTICLASS) 00055 00056 00060 virtual EMachineType get_classifier_type() { return CT_KMEANS; } 00061 00067 virtual bool load(FILE* srcfile); 00068 00074 virtual bool save(FILE* dstfile); 00075 00080 void set_k(int32_t p_k); 00081 00086 int32_t get_k(); 00087 00092 void set_max_iter(int32_t iter); 00093 00098 float64_t get_max_iter(); 00099 00104 SGVector<float64_t> get_radiuses(); 00105 00110 SGMatrix<float64_t> get_cluster_centers(); 00111 00116 int32_t get_dimensions(); 00117 00119 virtual const char* get_name() const { return "KMeans"; } 00120 00121 protected: 00127 void clustknb(bool use_old_mus, float64_t *mus_start); 00128 00137 virtual bool train_machine(CFeatures* data=NULL); 00138 00140 virtual void store_model_features(); 00141 00142 virtual bool train_require_labels() const { return false; } 00143 00144 private: 00145 void init(); 00146 00147 protected: 00149 int32_t max_iter; 00150 00152 int32_t k; 00153 00155 int32_t dimensions; 00156 00158 SGVector<float64_t> R; 00159 00160 private: 00161 /* temporary variable for weighting over the train data */ 00162 SGVector<float64_t> Weights; 00163 00164 /* temp variable for cluster centers */ 00165 SGMatrix<float64_t> mus; 00166 00167 }; 00168 } 00169 #endif 00170