SHOGUN  6.1.3
WeightedMajorityVote.cpp
Go to the documentation of this file.
/*
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 3 of the License, or
 * (at your option) any later version.
 *
 * Written (W) 2013 Viktor Gal
 * Copyright (C) 2013 Viktor Gal
 */
10
12 #include <shogun/base/Parameter.h>
13 #include <shogun/lib/SGMatrix.h>
14 #include <map>
15
16 using namespace shogun;
17
20 {
21  init();
22  register_parameters();
23 }
24
27 {
28  init();
29  register_parameters();
30  m_weights = weights;
31 }
32
{
	// Intentionally empty: members such as m_weights are destroyed automatically.
}
37
39 {
40  REQUIRE(m_weights.vlen == ensemble_result.num_cols, "The number of results and weights does not match!");
41  SGVector<float64_t> mv(ensemble_result.num_rows);
42  for (index_t i = 0; i < ensemble_result.num_rows; ++i)
43  {
44  SGVector<float64_t> rv = ensemble_result.get_row_vector(i);
45  mv[i] = combine(rv);
46  }
47
48  return mv;
49 }
50
52 {
53  return weighted_combine(ensemble_result);
54 }
55
57 {
58  REQUIRE(m_weights.vlen == ensemble_result.vlen, "The number of results and weights does not match!");
59  std::map<index_t, float64_t> freq;
60  std::map<index_t, float64_t>::iterator it;
61  index_t max_label = -100;
63
64  for (index_t i = 0; i < ensemble_result.vlen; ++i)
65  {
66  if (CMath::is_nan(ensemble_result[i]))
67  continue;
68
69  it = freq.find(ensemble_result[i]);
70  if (it == freq.end())
71  {
72  freq.insert(std::make_pair(ensemble_result[i], m_weights[i]));
73  if (max < m_weights[i])
74  {
75  max_label = ensemble_result[i];
76  max = m_weights[i];
77  }
78  }
79  else
80  {
81  it->second += m_weights[i];
82  if (max < it->second)
83  {
84  max_label = it->first;
85  max = it->second;
86  }
87  }
88  }
89
90  return max_label;
91 }
92
94 {
95  m_weights = w;
96 }
97
99 {
100  return m_weights;
101 }
102
void CWeightedMajorityVote::init()
{
	// Member initialisation, called by the constructors before
	// register_parameters().
	// NOTE(review): this body is empty in this listing, which appears to drop a
	// line here; presumably m_weights is meant to be reset to an empty vector —
	// confirm against the original source.
}
107
108 void CWeightedMajorityVote::register_parameters()
109 {
110  SG_ADD(&m_weights, "weights", "Weights for the majority vote", MS_AVAILABLE);
111 }
virtual SGVector< float64_t > combine(const SGMatrix< float64_t > &ensemble_result) const
SGVector< T > get_row_vector(index_t row) const
Definition: SGMatrix.cpp:1211
SGVector< float64_t > get_weights() const
int32_t index_t
Definition: common.h:72
void set_weights(SGVector< float64_t > &w)
#define REQUIRE(x,...)
Definition: SGIO.h:181
static const float64_t ALMOST_NEG_INFTY
almost neg (log) infinity
Definition: Math.h:1872
double float64_t
Definition: common.h:60
index_t num_rows
Definition: SGMatrix.h:495
index_t num_cols
Definition: SGMatrix.h:497
all of classes and functions are contained in the shogun namespace
Definition: class_list.h:18
static int is_nan(double f)
checks whether a float is nan
Definition: Math.cpp:210
CombinationRule abstract class The CombinationRule defines an interface to how to combine the classif...