SHOGUN  4.2.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
WeightedMajorityVote.cpp
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 2013 Viktor Gal
8  * Copyright (C) 2013 Viktor Gal
9  */
10 
12 #include <shogun/base/Parameter.h>
13 #include <shogun/lib/SGMatrix.h>
14 #include <map>
15 
16 using namespace shogun;
17 
20 {
21  init();
22  register_parameters();
23 }
24 
27 {
28  init();
29  register_parameters();
30  m_weights = weights;
31 }
32 
34 {
35 
36 }
37 
39 {
40  REQUIRE(m_weights.vlen == ensemble_result.num_cols, "The number of results and weights does not match!");
41  SGVector<float64_t> mv(ensemble_result.num_rows);
42  for (index_t i = 0; i < ensemble_result.num_rows; ++i)
43  {
44  SGVector<float64_t> rv = ensemble_result.get_row_vector(i);
45  mv[i] = combine(rv);
46  }
47 
48  return mv;
49 }
50 
52 {
53  return weighted_combine(ensemble_result);
54 }
55 
57 {
58  REQUIRE(m_weights.vlen == ensemble_result.vlen, "The number of results and weights does not match!");
59  std::map<index_t, float64_t> freq;
60  std::map<index_t, float64_t>::iterator it;
61  index_t max_label = -100;
63 
64  for (index_t i = 0; i < ensemble_result.vlen; ++i)
65  {
66  if (CMath::is_nan(ensemble_result[i]))
67  continue;
68 
69  it = freq.find(ensemble_result[i]);
70  if (it == freq.end())
71  {
72  freq.insert(std::make_pair(ensemble_result[i], m_weights[i]));
73  if (max < m_weights[i])
74  {
75  max_label = ensemble_result[i];
76  max = m_weights[i];
77  }
78  }
79  else
80  {
81  it->second += m_weights[i];
82  if (max < it->second)
83  {
84  max_label = it->first;
85  max = it->second;
86  }
87  }
88  }
89 
90  return max_label;
91 }
92 
94 {
95  m_weights = w;
96 }
97 
99 {
100  return m_weights;
101 }
102 
103 void CWeightedMajorityVote::init()
104 {
106 }
107 
108 void CWeightedMajorityVote::register_parameters() // adds m_weights via SG_ADD under the name "weights" (flag MS_AVAILABLE)
109 {
110  SG_ADD(&m_weights, "weights", "Weights for the majority vote", MS_AVAILABLE);
111 }
virtual SGVector< float64_t > combine(const SGMatrix< float64_t > &ensemble_result) const
SGVector< float64_t > get_weights() const
int32_t index_t
Definition: common.h:62
void set_weights(SGVector< float64_t > &w)
#define REQUIRE(x,...)
Definition: SGIO.h:206
index_t num_cols
Definition: SGMatrix.h:376
index_t num_rows
Definition: SGMatrix.h:374
static const float64_t ALMOST_NEG_INFTY
almost negative (log) infinity
Definition: Math.h:2052
index_t vlen
Definition: SGVector.h:494
double float64_t
Definition: common.h:50
all classes and functions are contained in the shogun namespace
Definition: class_list.h:18
static int is_nan(double f)
checks whether a float is NaN
Definition: Math.cpp:234
CombinationRule abstract class. The CombinationRule defines an interface for how to combine the classif...
Matrix::Scalar max(Matrix m)
Definition: Redux.h:68
#define SG_ADD(...)
Definition: SGObject.h:84
SGVector< T > get_row_vector(index_t row) const
Definition: SGMatrix.cpp:1084
virtual float64_t weighted_combine(const SGVector< float64_t > &ensemble_result) const

SHOGUN Machine Learning Toolbox - Documentation