SHOGUN  4.1.0
 全部  命名空间 文件 函数 变量 类型定义 枚举 枚举值 友元 宏定义  
MixtureModel.cpp
浏览该文件的文档.
1 /*
2  * Copyright (c) The Shogun Machine Learning Toolbox
3  * Written (w) 2014 Parijat Mazumdar
4  * All rights reserved.
5  *
6  * Redistribution and use in source and binary forms, with or without
7  * modification, are permitted provided that the following conditions are met:
8  *
9  * 1. Redistributions of source code must retain the above copyright notice, this
10  * list of conditions and the following disclaimer.
11  * 2. Redistributions in binary form must reproduce the above copyright notice,
12  * this list of conditions and the following disclaimer in the documentation
13  * and/or other materials provided with the distribution.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
16  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
17  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18  * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
19  * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
20  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
21  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
22  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
24  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25  *
26  * The views and conclusions contained in the software and documentation are those
27  * of the authors and should not be interpreted as representing official policies,
28  * either expressed or implied, of the Shogun Development Team.
29  */
30 
36 
37 using namespace shogun;
38 
40 {
41  init();
42 }
43 
45 {
46  init();
47  m_components=components;
48  SG_REF(components);
49  m_weights=weights;
50 }
51 
53 {
54  SG_UNREF(m_components);
55 }
56 
58 {
59  REQUIRE(m_components->get_num_elements()>0,"mixture componenents not specified\n")
60  REQUIRE(m_components->get_num_elements()==m_weights.vlen,"number of weights (%d) does not"
61  " match number of components (%d)\n",m_weights.vlen,m_components->get_num_elements())
62 
63  // set training features
64  if (data)
65  {
66  if (!data->has_property(FP_DOT))
67  SG_ERROR("Specified features are not of type CDotFeatures\n")
68  set_features(data);
69  }
70  else if (!features)
71  {
72  SG_ERROR("No features to train on.\n")
73  }
74 
75  // set training points in all components of the mixture
76  for (int32_t i=0;i<m_components->get_num_elements();i++)
77  {
79  comp->set_features(features);
80 
81  SG_UNREF(comp)
82  }
83 
84  CDotFeatures* dotdata=dynamic_cast<CDotFeatures *>(features);
85  REQUIRE(dotdata,"dynamic cast from CFeatures to CDotFeatures returned NULL")
86  int32_t num_vectors=dotdata->get_num_vectors();
87 
88  // set data for EM
90  em->data.alpha=SGMatrix<float64_t>(num_vectors,m_components->get_num_elements());
91  em->data.components=m_components;
92  em->data.weights=m_weights;
93 
94  // run EM
95  bool is_converged=em->iterate_em(m_max_iters,m_conv_tol);
96  if (!is_converged)
97  SG_WARNING("max iterations reached. No convergence yet!\n")
98 
99  SG_UNREF(em)
100  return true;
101 }
102 
104 {
105  REQUIRE(num_param==1,"number of parameters in mixture model is 1"
106  " (i.e. number of components). num_components should be 1. %d supplied\n",num_param)
107 
108  return CMath::log(get_num_components());
109 }
110 
// Takes a parameter index and an example index; as shown here the body uses
// neither argument and always returns 0.
// NOTE(review): the listing jumps from line 112 to line 114, so one statement
// of this body is not visible in this view -- confirm against the real file
// before relying on the description above.
111 float64_t CMixtureModel::get_log_derivative(int32_t num_param, int32_t num_example)
112 {
114  return 0;
115 }
116 
118 {
119  REQUIRE(features,"features not set\n")
120  REQUIRE(features->get_feature_class() == C_DENSE,"Dense features required\n")
121  REQUIRE(features->get_feature_type() == F_DREAL,"Real features required\n")
122 
123  SGVector<float64_t> log_likelihood_component(m_components->get_num_elements());
124  for (int32_t i=0;i<m_components->get_num_elements();i++)
125  {
127  log_likelihood_component[i]=ith_comp->get_log_likelihood_example(num_example)+CMath::log(m_weights[i]);
128 
129  SG_UNREF(ith_comp);
130  }
131 
132  return CMath::log_sum_exp(log_likelihood_component);
133 }
134 
136 {
137  return m_weights;
138 }
139 
141 {
142  m_weights=weights;
143 }
144 
146 {
147  SG_REF(m_components);
148  return m_components;
149 }
150 
152 {
// Replace the held component array with `components`.
// The new reference is taken BEFORE the old one is released: if the caller
// passes the array that is already held and this object owns its last
// reference, releasing first would destroy the array and the following
// SG_REF would then touch freed memory.
153  SG_REF(components);
154  if (m_components!=NULL)
155  SG_UNREF(m_components)
156 
157  m_components=components;
158 }
159 
161 {
162  return m_components->get_num_elements();
163 }
164 
166 {
167  REQUIRE(index<get_num_components(),"index supplied (%d) is greater than total mixture components (%d)\n"
168  ,index,get_num_components())
169  return CDistribution::obtain_from_generic(m_components->get_element(index));
170 }
171 
// Stores the iteration cap in m_max_iters; train() forwards that member to
// iterate_em() as the maximum number of EM iterations.
172 void CMixtureModel::set_max_iters(int32_t max_iters)
173 {
174  this->m_max_iters = max_iters;
175 }
176 
178 {
179  return m_max_iters;
180 }
181 
183 {
184  m_conv_tol=conv_tol;
185 }
186 
188 {
189  return m_conv_tol;
190 }
191 
193 {
194  // TBD
196  return SGVector<float64_t>();
197 }
198 
200 {
201  // TBD
203  return point;
204 }
205 
// Puts the object into its default state and then hands each member to SG_ADD.
// Called first by both constructors, before any caller-supplied values are stored.
206 void CMixtureModel::init()
207 {
208  m_components=NULL; // no component array held yet
209  m_weights=SGVector<float64_t>(); // default-constructed weight vector
210  m_conv_tol=1e-8; // default tolerance later passed to iterate_em()
211  m_max_iters=1000; // default iteration cap later passed to iterate_em()
212 
// One SG_ADD per member: address, name string, description, MS_NOT_AVAILABLE.
// NOTE(review): SG_ADD presumably registers the member with the SGObject
// parameter framework -- confirm against SGObject.h.
213  SG_ADD((CSGObject**)&m_components,"m_components","components of mixture",MS_NOT_AVAILABLE);
214  SG_ADD(&m_weights,"m_weights","weights of components",MS_NOT_AVAILABLE);
215  SG_ADD(&m_conv_tol,"m_conv_tol","convergence tolerance",MS_NOT_AVAILABLE);
216  SG_ADD(&m_max_iters,"m_max_iters","max number of iterations",MS_NOT_AVAILABLE);
217 }
SGVector< float64_t > cluster(SGVector< float64_t > point)
int32_t index_t
Definition: common.h:62
virtual void set_features(CFeatures *f)
Definition: Distribution.h:160
SGVector< float64_t > weights
Definition: MixModelData.h:52
static CDistribution * obtain_from_generic(CSGObject *object)
virtual int32_t get_num_vectors() const =0
SGMatrix< float64_t > alpha
Definition: MixModelData.h:48
#define SG_ERROR(...)
Definition: SGIO.h:129
void set_convergence_tolerance(float64_t epsilon)
#define REQUIRE(x,...)
Definition: SGIO.h:206
#define SG_NOTIMPLEMENTED
Definition: SGIO.h:139
CDynamicObjectArray * components
Definition: MixModelData.h:50
bool train(CFeatures *data=NULL)
Base class Distribution from which all methods implementing a distribution are derived.
Definition: Distribution.h:44
virtual float64_t get_log_derivative(int32_t num_param, int32_t num_example)
Features that support dot products among other operations.
Definition: DotFeatures.h:44
This is the implementation of EM specialized for Mixture models.
#define SG_REF(x)
Definition: SGObject.h:51
virtual float64_t get_log_likelihood_example(int32_t num_example)
index_t vlen
Definition: SGVector.h:494
Class SGObject is the base class of all shogun objects.
Definition: SGObject.h:112
double float64_t
Definition: common.h:50
index_t get_num_components() const
static T log_sum_exp(SGVector< T > values)
Definition: Math.h:1242
virtual EFeatureClass get_feature_class() const =0
Dynamic array class for CSGObject pointers that creates an array that can be used like a list or an a...
CDistribution * get_component(index_t index) const
SGVector< float64_t > sample()
#define SG_UNREF(x)
Definition: SGObject.h:52
All classes and functions are contained in the shogun namespace
Definition: class_list.h:18
bool iterate_em(int32_t max_iters=10000, float64_t epsilon=1e-8)
Definition: EMBase.h:72
int32_t get_max_iters() const
void set_max_iters(int32_t max_iters)
The class Features is the base class of all feature objects.
Definition: Features.h:68
CDynamicObjectArray * get_components() const
static float64_t log(float64_t v)
Definition: Math.h:922
float64_t get_log_model_parameter(int32_t num_param=1)
float64_t get_convergence_tolerance() const
CSGObject * get_element(int32_t index) const
void set_components(CDynamicObjectArray *components)
#define SG_WARNING(...)
Definition: SGIO.h:128
#define SG_ADD(...)
Definition: SGObject.h:81
SGVector< float64_t > get_weights() const
void set_weights(SGVector< float64_t > weights)
virtual float64_t get_log_likelihood_example(int32_t num_example)=0
bool has_property(EFeatureProperty p) const
Definition: Features.cpp:295
virtual EFeatureType get_feature_type() const =0

SHOGUN 机器学习工具包 - 项目文档