SHOGUN  4.1.0
 全部  命名空间 文件 函数 变量 类型定义 枚举 枚举值 友元 宏定义  
RandomForest.cpp
浏览该文件的文档.
1 /*
2  * Copyright (c) The Shogun Machine Learning Toolbox
3  * Written (w) 2014 Parijat Mazumdar
4  * All rights reserved.
5  *
6  * Redistribution and use in source and binary forms, with or without
7  * modification, are permitted provided that the following conditions are met:
8  *
9  * 1. Redistributions of source code must retain the above copyright notice, this
10  * list of conditions and the following disclaimer.
11  * 2. Redistributions in binary form must reproduce the above copyright notice,
12  * this list of conditions and the following disclaimer in the documentation
13  * and/or other materials provided with the distribution.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
16  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
17  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18  * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
19  * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
20  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
21  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
22  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
24  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25  *
26  * The views and conclusions contained in the software and documentation are those
27  * of the authors and should not be interpreted as representing official policies,
28  * either expressed or implied, of the Shogun Development Team.
29  */
30 
33 
34 using namespace shogun;
35 
38 {
39  init();
40 }
41 
42 CRandomForest::CRandomForest(int32_t rand_numfeats, int32_t num_bags)
44 {
45  init();
46 
47 
48  set_num_bags(num_bags);
49 
50  if (rand_numfeats>0)
51  dynamic_cast<CRandomCARTree*>(m_machine)->set_feature_subset_size(rand_numfeats);
52 }
53 
54 CRandomForest::CRandomForest(CFeatures* features, CLabels* labels, int32_t num_bags, int32_t rand_numfeats)
56 {
57  init();
58 
59  SG_REF(features);
60  m_features=features;
61  set_labels(labels);
62 
63  set_num_bags(num_bags);
64 
65  if (rand_numfeats>0)
66  dynamic_cast<CRandomCARTree*>(m_machine)->set_feature_subset_size(rand_numfeats);
67 }
68 
69 CRandomForest::CRandomForest(CFeatures* features, CLabels* labels, SGVector<float64_t> weights, int32_t num_bags, int32_t rand_numfeats)
71 {
72  init();
73 
74  SG_REF(features);
75  m_features=features;
76  set_labels(labels);
77  m_weights=weights;
78 
79  set_num_bags(num_bags);
80 
81  if (rand_numfeats>0)
82  dynamic_cast<CRandomCARTree*>(m_machine)->set_feature_subset_size(rand_numfeats);
83 }
84 
86 {
87 }
88 
90 {
91  SG_ERROR("Machine is set as CRandomCART and cannot be changed\n")
92 }
93 
95 {
96  m_weights=weights;
97 }
98 
100 {
101  return m_weights;
102 }
103 
105 {
106  REQUIRE(m_machine,"m_machine is NULL. It is expected to be RandomCARTree\n")
107  dynamic_cast<CRandomCARTree*>(m_machine)->set_feature_types(ft);
108 }
109 
111 {
112  REQUIRE(m_machine,"m_machine is NULL. It is expected to be RandomCARTree\n")
113  return dynamic_cast<CRandomCARTree*>(m_machine)->get_feature_types();
114 }
115 
117 {
118  REQUIRE(m_machine,"m_machine is NULL. It is expected to be RandomCARTree\n")
119  return dynamic_cast<CRandomCARTree*>(m_machine)->get_machine_problem_type();
120 }
121 
123 {
124  REQUIRE(m_machine,"m_machine is NULL. It is expected to be RandomCARTree\n")
125  dynamic_cast<CRandomCARTree*>(m_machine)->set_machine_problem_type(mode);
126 }
127 
128 void CRandomForest::set_num_random_features(int32_t rand_featsize)
129 {
130  REQUIRE(m_machine,"m_machine is NULL. It is expected to be RandomCARTree\n")
131  REQUIRE(rand_featsize>0,"feature subset size should be greater than 0\n")
132 
133  dynamic_cast<CRandomCARTree*>(m_machine)->set_feature_subset_size(rand_featsize);
134 }
135 
137 {
138  REQUIRE(m_machine,"m_machine is NULL. It is expected to be RandomCARTree\n")
139  return dynamic_cast<CRandomCARTree*>(m_machine)->get_feature_subset_size();
140 }
141 
143 {
144  REQUIRE(m,"Machine supplied is NULL\n")
145  REQUIRE(m_machine,"Reference Machine is NULL\n")
146 
147  CRandomCARTree* tree=dynamic_cast<CRandomCARTree*>(m);
148 
149  SGVector<float64_t> weights(idx.vlen);
150 
151  if (m_weights.vlen==0)
152  {
153  weights.fill_vector(weights.vector,weights.vlen,1.0);
154  }
155  else
156  {
157  for (int32_t i=0;i<idx.vlen;i++)
158  weights[i]=m_weights[idx[i]];
159  }
160 
161  tree->set_weights(weights);
162 
163  // equate the machine problem types - cloning does not do this
164  tree->set_machine_problem_type(dynamic_cast<CRandomCARTree*>(m_machine)->get_machine_problem_type());
165 }
166 
167 void CRandomForest::init()
168 {
170  m_weights=SGVector<float64_t>();
171 
172  SG_ADD(&m_weights,"m_weights","weights",MS_NOT_AVAILABLE)
173 }
static void fill_vector(T *vec, int32_t len, T value)
Definition: SGVector.cpp:223
void set_weights(SGVector< float64_t > w)
Definition: CARTree.cpp:169
void set_machine_problem_type(EProblemType mode)
Definition: CARTree.cpp:84
The class Labels models labels, i.e. class assignments of objects.
Definition: Labels.h:43
void set_num_random_features(int32_t rand_featsize)
void set_machine_problem_type(EProblemType mode)
#define SG_ERROR(...)
Definition: SGIO.h:129
#define REQUIRE(x,...)
Definition: SGIO.h:206
void set_feature_types(SGVector< bool > ft)
SGVector< bool > get_feature_types() const
#define SG_REF(x)
Definition: SGObject.h:51
A generic learning machine interface.
Definition: Machine.h:143
void set_weights(SGVector< float64_t > weights)
index_t vlen
Definition: SGVector.h:494
EProblemType
Definition: Machine.h:110
virtual void set_machine_parameters(CMachine *m, SGVector< index_t > idx)
void set_num_bags(int32_t num_bags)
This class implements the randomized CART algorithm used in the tree growing process of candidate trees i...
Definition: RandomCARTree.h:48
virtual EProblemType get_machine_problem_type() const
All classes and functions are contained in the shogun namespace
Definition: class_list.h:18
The class Features is the base class of all feature objects.
Definition: Features.h:68
virtual void set_machine(CMachine *machine)
Bagging algorithm, i.e. bootstrap aggregating
#define SG_ADD(...)
Definition: SGObject.h:81
virtual void set_labels(CLabels *lab)
Definition: Machine.cpp:65
int32_t get_num_random_features() const
SGVector< float64_t > get_weights() const

SHOGUN 机器学习工具包 - 项目文档