SHOGUN  4.2.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
DisjointSet.cpp
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 2013 Shell Hu
8  * Copyright (C) 2013 Shell Hu
9  */
10 
12 #include <shogun/base/Parameter.h>
13 
14 using namespace shogun;
15 
17  : CSGObject()
18 {
19  SG_UNSTABLE("CDisjointSet::CDisjointSet()", "\n");
20 
21  init();
22 }
23 
24 CDisjointSet::CDisjointSet(int32_t num_elements)
25  : CSGObject()
26 {
27  init();
28  m_num_elements = num_elements;
29  m_parent = SGVector<int32_t>(num_elements);
30  m_rank = SGVector<int32_t>(num_elements);
31 }
32 
33 void CDisjointSet::init()
34 {
35  SG_ADD(&m_num_elements, "num_elements", "Number of elements", MS_NOT_AVAILABLE);
36  SG_ADD(&m_parent, "parent", "Parent pointers", MS_NOT_AVAILABLE);
37  SG_ADD(&m_rank, "rank", "Rank of each element", MS_NOT_AVAILABLE);
38  SG_ADD(&m_is_connected, "is_connected", "Whether disjoint sets have been linked", MS_NOT_AVAILABLE);
39 
40  m_is_connected = false;
41  m_num_elements = -1;
42 }
43 
45 {
46  REQUIRE(m_num_elements > 0, "%s::make_sets(): m_num_elements <= 0.\n", get_name());
47 
48  m_parent.range_fill();
49  m_rank.zero();
50 }
51 
52 int32_t CDisjointSet::find_set(int32_t x)
53 {
54  ASSERT(x >= 0 && x < m_num_elements);
55 
56  // path compression
57  if (x != m_parent[x])
58  m_parent[x] = find_set(m_parent[x]);
59 
60  return m_parent[x];
61 }
62 
63 int32_t CDisjointSet::link_set(int32_t xroot, int32_t yroot)
64 {
65  ASSERT(xroot >= 0 && xroot < m_num_elements);
66  ASSERT(yroot >= 0 && yroot < m_num_elements);
67  ASSERT(m_parent[xroot] == xroot && m_parent[yroot] == yroot);
68  ASSERT(xroot != yroot);
69 
70  // union by rank
71  if (m_rank[xroot] > m_rank[yroot])
72  {
73  m_parent[yroot] = xroot;
74  return xroot;
75  }
76  else
77  {
78  m_parent[xroot] = yroot;
79  if (m_rank[xroot] == m_rank[yroot])
80  m_rank[yroot] += 1;
81 
82  return yroot;
83  }
84 }
85 
86 bool CDisjointSet::union_set(int32_t x, int32_t y)
87 {
88  ASSERT(x >= 0 && x < m_num_elements);
89  ASSERT(y >= 0 && y < m_num_elements);
90 
91  int32_t xroot = find_set(x);
92  int32_t yroot = find_set(y);
93 
94  if (xroot == yroot)
95  return true;
96 
97  link_set(xroot, yroot);
98  return false;
99 }
100 
101 bool CDisjointSet::is_same_set(int32_t x, int32_t y)
102 {
103  ASSERT(x >= 0 && x < m_num_elements);
104  ASSERT(y >= 0 && y < m_num_elements);
105 
106  if (find_set(x) == find_set(y))
107  return true;
108 
109  return false;
110 }
111 
113 {
114  REQUIRE(m_num_elements > 0, "%s::get_unique_labeling(): m_num_elements <= 0.\n", get_name());
115 
116  if (out_labels.size() != m_num_elements)
117  out_labels.resize_vector(m_num_elements);
118 
119  SGVector<int32_t> roots(m_num_elements);
120  SGVector<int32_t> flags(m_num_elements);
121  SGVector<int32_t>::fill_vector(flags.vector, flags.vlen, -1);
122  int32_t unilabel = 0;
123 
124  for (int32_t i = 0; i < m_num_elements; i++)
125  {
126  roots[i] = find_set(i);
127  // if roots[i] never be found
128  if (flags[roots[i]] < 0)
129  flags[roots[i]] = unilabel++;
130  }
131 
132  for (int32_t i = 0; i < m_num_elements; i++)
133  out_labels[i] = flags[roots[i]];
134 
135  return unilabel;
136 }
137 
139 {
140  REQUIRE(m_num_elements > 0, "%s::get_num_sets(): m_num_elements <= 0.\n", get_name());
141 
142  return get_unique_labeling(SGVector<int32_t>(m_num_elements));
143 }
144 
146 {
147  return m_is_connected;
148 }
149 
150 void CDisjointSet::set_connected(bool is_connected)
151 {
152  m_is_connected = is_connected;
153 }
154 
void range_fill(T start=0)
Definition: SGVector.cpp:171
virtual const char * get_name() const
Definition: DisjointSet.h:42
bool union_set(int32_t x, int32_t y)
Definition: DisjointSet.cpp:86
static void fill_vector(T *vec, int32_t len, T value)
Definition: SGVector.cpp:221
#define REQUIRE(x,...)
Definition: SGIO.h:206
int32_t size() const
Definition: SGVector.h:113
index_t vlen
Definition: SGVector.h:494
#define ASSERT(x)
Definition: SGIO.h:201
Class SGObject is the base class of all shogun objects.
Definition: SGObject.h:115
int32_t get_unique_labeling(SGVector< int32_t > out_labels)
All classes and functions are contained in the shogun namespace.
Definition: class_list.h:18
void set_connected(bool is_connected)
bool is_same_set(int32_t x, int32_t y)
void resize_vector(int32_t n)
Definition: SGVector.cpp:257
int32_t link_set(int32_t xroot, int32_t yroot)
Definition: DisjointSet.cpp:63
int32_t find_set(int32_t x)
Definition: DisjointSet.cpp:52
#define SG_ADD(...)
Definition: SGObject.h:84
#define SG_UNSTABLE(func,...)
Definition: SGIO.h:132

SHOGUN Machine Learning Toolbox - Documentation