SHOGUN 4.1.0
HierarchicalMultilabelModel.cpp
Go to the documentation of this file.
/*
 * This software is distributed under BSD 3-clause license (see LICENSE file).
 *
 * Copyright(C) 2014 Abinash Panda
 * Copyright(C) 2014 Thoralf Klein
 * Written(W) 2014 Abinash Panda
 */

#include <shogun/structure/HierarchicalMultilabelModel.h>
#include <shogun/structure/MultilabelSOLabels.h>
#include <shogun/features/DotFeatures.h>
#include <shogun/lib/DynamicArray.h>
#include <shogun/mathematics/Math.h>

using namespace shogun;

CHierarchicalMultilabelModel::CHierarchicalMultilabelModel()
    : CStructuredModel()
{
    init(SGVector<int32_t>(0), false);
}

CHierarchicalMultilabelModel::CHierarchicalMultilabelModel(CFeatures * features,
        CStructuredLabels * labels, SGVector<int32_t> taxonomy,
        bool leaf_nodes_mandatory)
    : CStructuredModel(features, labels)
{
    init(taxonomy, leaf_nodes_mandatory);
}

CStructuredLabels * CHierarchicalMultilabelModel::structured_labels_factory(
        int32_t num_labels)
{
    return new CMultilabelSOLabels(num_labels, m_num_classes);
}

CHierarchicalMultilabelModel::~CHierarchicalMultilabelModel()
{
    SG_FREE(m_children);
}

void CHierarchicalMultilabelModel::init(SGVector<int32_t> taxonomy,
        bool leaf_nodes_mandatory)
{
    SG_ADD(&m_num_classes, "num_classes", "Number of (binary) class assignments per label",
           MS_NOT_AVAILABLE);
    SG_ADD(&m_taxonomy, "taxonomy", "Taxonomy of the hierarchy of the labels",
           MS_NOT_AVAILABLE);
    SG_ADD(&m_leaf_nodes_mandatory, "leaf_nodes_mandatory", "Whether internal nodes belong "
           "to the output class or not", MS_NOT_AVAILABLE);
    SG_ADD(&m_root, "root", "Node-id of the ROOT element", MS_NOT_AVAILABLE);

    m_leaf_nodes_mandatory = leaf_nodes_mandatory;
    m_num_classes = 0;

    int32_t num_classes = 0;

    if (m_labels)
    {
        num_classes = ((CMultilabelSOLabels *)m_labels)->get_num_classes();
    }

    REQUIRE(num_classes == taxonomy.vlen, "Number of classes must be equal to the "
            "length of the taxonomy vector = %d\n", taxonomy.vlen);

    m_taxonomy = SGVector<int32_t>(num_classes);

    m_root = -1;
    int32_t root_node_count = 0;

    for (index_t i = 0; i < num_classes; i++)
    {
        REQUIRE(taxonomy[i] < num_classes && taxonomy[i] >= -1, "parent-id of node-id:%d is taxonomy[%d] = %d,"
                " but must be within [-1; %d-1] (-1 for the root node)\n", i, i,
                taxonomy[i], num_classes);
        m_taxonomy[i] = taxonomy[i];

        if (m_taxonomy[i] == -1)
        {
            m_root = i;
            root_node_count++;
        }
    }

    if (num_classes)
    {
        REQUIRE(m_root != -1 && root_node_count == 1, "A single ROOT element must be defined "
                "with parent-id = -1\n");
    }

    // store the children of every node as an array of vectors
    m_children = SG_MALLOC(SGVector<int32_t>, num_classes);

    for (int32_t i = 0; i < num_classes; i++)
    {
        SGVector<int32_t> child_id = m_taxonomy.find(i);
        m_children[i] = child_id;
    }
}
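
The taxonomy argument validated above is a flat parent-id vector: entry i holds the parent node-id of class i, and exactly one entry is -1, marking the root. As an illustration (not part of the library source; the hierarchy and the helper name are made up), the following sketch builds such a vector for a five-node tree:

#include <shogun/lib/SGVector.h>

using namespace shogun;

// Hypothetical hierarchy:     0 (root)
//                            / \
//                           1   2
//                          / \
//                         3   4
SGVector<int32_t> make_example_taxonomy()
{
    SGVector<int32_t> taxonomy(5);
    taxonomy[0] = -1;   // root node: parent-id is -1
    taxonomy[1] = 0;    // nodes 1 and 2 are children of the root
    taxonomy[2] = 0;
    taxonomy[3] = 1;    // leaves 3 and 4 are children of node 1
    taxonomy[4] = 1;
    return taxonomy;
}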
int32_t CHierarchicalMultilabelModel::get_dim() const
{
    int32_t num_classes = ((CMultilabelSOLabels *)m_labels)->get_num_classes();
    int32_t feats_dim = ((CDotFeatures *)m_features)->get_dim_feature_space();

    return num_classes * feats_dim;
}

SGVector<int32_t> CHierarchicalMultilabelModel::get_label_vector(
        SGVector<int32_t> sparse_label)
{
    int32_t num_classes = ((CMultilabelSOLabels *)m_labels)->get_num_classes();

    SGVector<int32_t> label_vector(num_classes);
    label_vector.zero();

    for (index_t i = 0; i < sparse_label.vlen; i++)
    {
        int32_t node_id = sparse_label[i];
        label_vector[node_id] = 1;

        // switch on every ancestor of the node up to the root
        for (int32_t parent_id = m_taxonomy[node_id]; parent_id != -1;
                parent_id = m_taxonomy[parent_id])
        {
            label_vector[parent_id] = 1;
        }
    }

    return label_vector;
}
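
get_label_vector() expands the sparse list of assigned node-ids into a dense 0/1 vector over all classes, also switching on every ancestor of each assigned node. A standalone sketch of that parent-walk, using plain C++ containers rather than the shogun types and the hypothetical taxonomy from the sketch above:

#include <vector>

// With taxonomy {-1, 0, 0, 1, 1} and sparse label {3, 2}, node 3 switches
// on its ancestors 1 and 0 and node 2 switches on ancestor 0, so the
// expanded vector is {1, 1, 1, 1, 0}.
std::vector<int> expand_label(const std::vector<int>& taxonomy,
                              const std::vector<int>& sparse_label)
{
    std::vector<int> dense(taxonomy.size(), 0);
    for (int node : sparse_label)
    {
        dense[node] = 1;
        for (int parent = taxonomy[node]; parent != -1; parent = taxonomy[parent])
            dense[parent] = 1;
    }
    return dense;
}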
SGVector<float64_t> CHierarchicalMultilabelModel::get_joint_feature_vector(
        int32_t feat_idx, CStructuredData * y)
{
    CSparseMultilabel * slabel = CSparseMultilabel::obtain_from_generic(y);
    SGVector<int32_t> slabel_data = slabel->get_data();
    SGVector<int32_t> label_vector = get_label_vector(slabel_data);

    SGVector<float64_t> psi(get_dim());
    psi.zero();

    CDotFeatures * dot_feats = (CDotFeatures *)m_features;
    SGVector<float64_t> x = dot_feats->get_computed_dot_feature_vector(feat_idx);
    int32_t feats_dim = dot_feats->get_dim_feature_space();

    for (index_t i = 0; i < label_vector.vlen; i++)
    {
        int32_t label = label_vector[i];

        if (label)
        {
            // copy x into the block of psi that belongs to class i
            int32_t offset = i * feats_dim;

            for (index_t j = 0; j < feats_dim; j++)
            {
                psi[offset + j] = x[j];
            }
        }
    }

    return psi;
}
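
In other words, the joint feature map stacks one copy of the input vector per class: block i of psi holds x whenever class i is switched on in the expanded label vector (i.e. it is assigned or has an assigned descendant), and zeros otherwise. With T classes and feature dimension d, so that dim(psi) = T * d as returned by get_dim(), this is

\[
\psi(x, y) \;=\; \big[\, y_1\, x^\top,\; y_2\, x^\top,\; \dots,\; y_T\, x^\top \,\big]^\top,
\qquad y_i \in \{0, 1\},
\]

where y_i denotes entry i of the expanded label vector produced by get_label_vector().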
float64_t CHierarchicalMultilabelModel::delta_loss(CStructuredData * y1,
        CStructuredData * y2)
{
    CSparseMultilabel * y1_slabel = CSparseMultilabel::obtain_from_generic(y1);
    CSparseMultilabel * y2_slabel = CSparseMultilabel::obtain_from_generic(y2);

    ASSERT(y1_slabel != NULL);
    ASSERT(y2_slabel != NULL);

    return delta_loss(get_label_vector(y1_slabel->get_data()),
                      get_label_vector(y2_slabel->get_data()));
}

float64_t CHierarchicalMultilabelModel::delta_loss(SGVector<int32_t> y1,
        SGVector<int32_t> y2)
{
    REQUIRE(y1.vlen == y2.vlen, "Size of both the vectors should be the same\n");

    float64_t loss = 0;

    for (index_t i = 0; i < y1.vlen; i++)
    {
        loss += delta_loss(y1[i], y2[i]);
    }

    return loss;
}

float64_t CHierarchicalMultilabelModel::delta_loss(int32_t y1, int32_t y2)
{
    return y1 != y2 ? 1 : 0;
}
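
Taken together, the three delta_loss() overloads compute the Hamming distance between the two expanded label vectors, i.e. a per-node 0/1 disagreement summed over all T nodes of the hierarchy:

\[
\Delta\big(y^{(1)}, y^{(2)}\big) \;=\; \sum_{i=1}^{T} \mathbf{1}\big[\, y^{(1)}_i \neq y^{(2)}_i \,\big].
\]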
void CHierarchicalMultilabelModel::init_primal_opt(
        float64_t regularization,
        SGMatrix<float64_t> &A, SGVector<float64_t> a,
        SGMatrix<float64_t> B, SGVector<float64_t> &b,
        SGVector<float64_t> &lb, SGVector<float64_t> &ub,
        SGMatrix<float64_t> &C)
{
    C = SGMatrix<float64_t>::create_identity_matrix(get_dim(), regularization);
}

CResultSet * CHierarchicalMultilabelModel::argmax(SGVector<float64_t> w,
        int32_t feat_idx, bool const training)
{
    CDotFeatures * dot_feats = (CDotFeatures *)m_features;
    int32_t feats_dim = dot_feats->get_dim_feature_space();

    CMultilabelSOLabels * multi_labs = (CMultilabelSOLabels *)m_labels;

    if (training)
    {
        m_num_classes = multi_labs->get_num_classes();
    }

    REQUIRE(m_num_classes > 0, "The model needs to be trained before using "
            "it for prediction\n");

    int32_t dim = get_dim();
    ASSERT(dim == w.vlen);

    // nodes_to_traverse is a dynamic list that keeps track of the nodes
    // that still have to be traversed
    CDynamicArray<int32_t> * nodes_to_traverse = new CDynamicArray<int32_t>();
    SG_REF(nodes_to_traverse);

    // start the traversal at the root node;
    // nodes are inserted at the back of the list
    nodes_to_traverse->push_back(m_root);

    SGVector<int32_t> y_pred_sparse(m_num_classes);
    int32_t count = 0;

    while (nodes_to_traverse->get_num_elements())
    {
        // extract the node at the front of the list
        int32_t node = nodes_to_traverse->get_element(0);
        nodes_to_traverse->delete_element(0);

        float64_t score = dot_feats->dense_dot(feat_idx, w.vector + node * feats_dim,
                          feats_dim);

        // if the score is greater than zero, all children of this node
        // are traversed next
        if (score > 0)
        {
            SGVector<int32_t> child_id = m_children[node];

            // insert the children at the back of the list
            for (index_t i = 0; i < child_id.vlen; i++)
            {
                nodes_to_traverse->push_back(child_id[i]);
            }

            if (m_leaf_nodes_mandatory)
            {
                // terminal nodes don't have any children
                if (child_id.vlen == 0)
                {
                    y_pred_sparse[count++] = node;
                }
            }
            else
            {
                y_pred_sparse[count++] = node;
            }
        }
    }

    y_pred_sparse.resize_vector(count);

    CResultSet * ret = new CResultSet();
    SG_REF(ret);
    ret->psi_computed = true;

    CSparseMultilabel * y_pred = new CSparseMultilabel(y_pred_sparse);
    SG_REF(y_pred);

    ret->psi_pred = get_joint_feature_vector(feat_idx, y_pred);
    ret->score = CMath::dot(w.vector, ret->psi_pred.vector, dim);
    ret->argmax = y_pred;

    if (training)
    {
        ret->delta = CStructuredModel::delta_loss(feat_idx, y_pred);
        ret->psi_truth = CStructuredModel::get_joint_feature_vector(feat_idx,
                         feat_idx);
        ret->score += (ret->delta - CMath::dot(w.vector,
                       ret->psi_truth.vector, dim));
    }

    SG_UNREF(nodes_to_traverse);

    return ret;
}
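
To put the pieces together, here is a rough end-to-end usage sketch. It is not taken from the library examples: the toy data, the leaf_nodes_mandatory flag and the choice of CStochasticSOSVM as the structured solver are illustrative assumptions; any structured-output solver that accepts a CStructuredModel should be usable in the same way.

#include <shogun/base/init.h>
#include <shogun/features/DenseFeatures.h>
#include <shogun/lib/SGMatrix.h>
#include <shogun/lib/SGVector.h>
#include <shogun/structure/HierarchicalMultilabelModel.h>
#include <shogun/structure/MultilabelSOLabels.h>
#include <shogun/structure/StochasticSOSVM.h>

using namespace shogun;

int main()
{
    init_shogun_with_defaults();

    // two toy examples with 3 dense features each (hypothetical data)
    SGMatrix<float64_t> feat_mat(3, 2);
    feat_mat.set_const(1.0);
    CDenseFeatures<float64_t>* features = new CDenseFeatures<float64_t>(feat_mat);

    // 5 classes arranged in the hypothetical taxonomy sketched earlier
    SGVector<int32_t> taxonomy(5);
    taxonomy[0] = -1; taxonomy[1] = 0; taxonomy[2] = 0;
    taxonomy[3] = 1;  taxonomy[4] = 1;

    // sparse ground-truth labels: example 0 -> node 3, example 1 -> node 2
    CMultilabelSOLabels* labels = new CMultilabelSOLabels(2, 5);
    SGVector<int32_t> lab0(1); lab0[0] = 3;
    SGVector<int32_t> lab1(1); lab1[0] = 2;
    labels->set_sparse_label(0, lab0);
    labels->set_sparse_label(1, lab1);

    CHierarchicalMultilabelModel* model = new CHierarchicalMultilabelModel(
            features, labels, taxonomy, false /* leaf_nodes_mandatory */);

    CStochasticSOSVM* solver = new CStochasticSOSVM(model, labels);
    SG_REF(solver);
    solver->train();

    CStructuredLabels* prediction = (CStructuredLabels*)solver->apply();
    SG_UNREF(prediction);
    SG_UNREF(solver);

    exit_shogun();
    return 0;
}

Each predicted label comes back as a CSparseMultilabel holding the node-ids that the greedy top-down traversal in argmax() switched on.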