SHOGUN  v2.0.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
SegmentLoss.h
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 2009 Jonas Behr
8  * Copyright (C) 2009 Fraunhofer Institute FIRST and Max-Planck-Society
9  */
10 #ifndef __SEGMENT_LOSS__
11 #define __SEGMENT_LOSS__
12 
13 #include <shogun/lib/common.h>
14 #include <shogun/base/SGObject.h>
16 
17 
18 namespace shogun
19 {
20  template <class T> class CDynamicArray;
22 class CSegmentLoss : public CSGObject
23 {
24  public:
25 
28  CSegmentLoss();
29 
30  virtual ~CSegmentLoss();
31 
38  float32_t get_segment_loss(int32_t from_pos, int32_t to_pos, int32_t segment_id);
39 
46  float32_t get_segment_loss_extend(int32_t from_pos, int32_t to_pos, int32_t segment_id);
47 
54  void set_segment_loss(float64_t* segment_loss, int32_t m, int32_t n);
55 
60  void set_segment_ids(CDynamicArray<int32_t>* segment_ids);
61 
68  void set_segment_mask(CDynamicArray<float64_t>* segment_mask);
69 
74  void set_num_segment_types(int32_t num_segment_types)
75  {
76  m_num_segment_types = num_segment_types;
77  }
78 
84  void compute_loss(int32_t* all_pos, int32_t len);
85 
89  inline virtual const char* get_name() const { return "SegmentLoss"; }
90  protected:
91 
94 
100 
103 
106 
109 };
110 
111 inline float32_t CSegmentLoss::get_segment_loss(int32_t from_pos, int32_t to_pos, int32_t segment_id)
112 {
113 
114  /* int32_t from_pos_shift = from_pos ;
115  if (print)
116  SG_PRINT("# pos=%i,%i segment_id=%i, m_segment_ids[from-2]=%i (%1.1f), m_segment_ids[from-1]=%i (%1.1f), m_segment_ids[from]=%i (%1.1f), m_segment_ids[from+1]=%i (%1.1f), \n",
117  from_pos_shift, to_pos, segment_id,
118  m_segment_ids->element(from_pos_shift-2), m_segment_loss_matrix.element(segment_id, from_pos_shift-2)-m_segment_loss_matrix.element(segment_id, to_pos),
119  m_segment_ids->element(from_pos_shift-1), m_segment_loss_matrix.element(segment_id, from_pos_shift-1)-m_segment_loss_matrix.element(segment_id, to_pos),
120  m_segment_ids->element(from_pos_shift), m_segment_loss_matrix.element(segment_id, from_pos_shift)-m_segment_loss_matrix.element(segment_id, to_pos),
121  m_segment_ids->element(from_pos_shift+1), m_segment_loss_matrix.element(segment_id, from_pos_shift+1)-m_segment_loss_matrix.element(segment_id, to_pos)) ;
122  while(1)
123  {
124  while (m_segment_ids->element(from_pos_shift)==m_segment_ids->element(from_pos_shift+1) && from_pos_shift<to_pos)
125  from_pos_shift++ ;
126  if (print)
127  SG_PRINT("# pos=%i,%i segment_id=%i, m_segment_ids[from-2]=%i (%1.1f), m_segment_ids[from-1]=%i (%1.1f), m_segment_ids[from]=%i (%1.1f), m_segment_ids[from+1]=%i (%1.1f), \n",
128  from_pos_shift, to_pos, segment_id,
129  m_segment_ids->element(from_pos_shift-2), m_segment_loss_matrix.element(segment_id, from_pos_shift-2)-m_segment_loss_matrix.element(segment_id, to_pos),
130  m_segment_ids->element(from_pos_shift-1), m_segment_loss_matrix.element(segment_id, from_pos_shift-1)-m_segment_loss_matrix.element(segment_id, to_pos),
131  m_segment_ids->element(from_pos_shift), m_segment_loss_matrix.element(segment_id, from_pos_shift)-m_segment_loss_matrix.element(segment_id, to_pos),
132  m_segment_ids->element(from_pos_shift+1), m_segment_loss_matrix.element(segment_id, from_pos_shift+1)-m_segment_loss_matrix.element(segment_id, to_pos)) ;
133 
134  if (from_pos_shift>=to_pos)
135  {
136  //SG_PRINT("break") ;
137  break ;
138  }
139  else from_pos_shift++ ;
140  }
141  if (print)
142  SG_PRINT("break\n") ; */
143 
144  float32_t diff_contrib = m_segment_loss_matrix.element(segment_id, from_pos)-m_segment_loss_matrix.element(segment_id, to_pos);
145  diff_contrib += m_segment_mask->element(to_pos-1)*m_segment_loss.element(segment_id, m_segment_ids->element(to_pos-1), 0);
146  return diff_contrib;
147 }
148 
149 inline float32_t CSegmentLoss::get_segment_loss_extend(int32_t from_pos, int32_t to_pos, int32_t segment_id)
150 {
151  int32_t from_pos_shift = from_pos ;
152 
153  /*SG_PRINT("segment_id=%i, m_segment_ids[from-2]=%i (%1.1f), m_segment_ids[from-1]=%i (%1.1f), m_segment_ids[from]=%i (%1.1f), m_segment_ids[from+1]=%i (%1.1f), \n",
154  segment_id,
155  m_segment_ids->element(from_pos_shift-2), m_segment_loss_matrix.element(segment_id, from_pos_shift-2)-m_segment_loss_matrix.element(segment_id, to_pos),
156  m_segment_ids->element(from_pos_shift-1), m_segment_loss_matrix.element(segment_id, from_pos_shift-1)-m_segment_loss_matrix.element(segment_id, to_pos),
157  m_segment_ids->element(from_pos_shift), m_segment_loss_matrix.element(segment_id, from_pos_shift)-m_segment_loss_matrix.element(segment_id, to_pos),
158  m_segment_ids->element(from_pos_shift+1), m_segment_loss_matrix.element(segment_id, from_pos_shift+1)-m_segment_loss_matrix.element(segment_id, to_pos)) ;*/
159 
160  while (from_pos_shift<to_pos && m_segment_ids->element(from_pos_shift)==m_segment_ids->element(from_pos_shift+1))
161  from_pos_shift++ ;
162 
163  /*SG_PRINT("segment_id=%i, m_segment_ids[from-2]=%i (%1.1f), m_segment_ids[from-1]=%i (%1.1f), m_segment_ids[from]=%i (%1.1f), m_segment_ids[from+1]=%i (%1.1f), \n",
164  segment_id,
165  m_segment_ids->element(from_pos_shift-2), m_segment_loss_matrix.element(segment_id, from_pos_shift-2)-m_segment_loss_matrix.element(segment_id, to_pos),
166  m_segment_ids->element(from_pos_shift-1), m_segment_loss_matrix.element(segment_id, from_pos_shift-1)-m_segment_loss_matrix.element(segment_id, to_pos),
167  m_segment_ids->element(from_pos_shift), m_segment_loss_matrix.element(segment_id, from_pos_shift)-m_segment_loss_matrix.element(segment_id, to_pos),
168  m_segment_ids->element(from_pos_shift+1), m_segment_loss_matrix.element(segment_id, from_pos_shift+1)-m_segment_loss_matrix.element(segment_id, to_pos)) ;*/
169 
170  float32_t diff_contrib = m_segment_loss_matrix.element(segment_id, from_pos_shift)-m_segment_loss_matrix.element(segment_id, to_pos);
171  //diff_contrib += m_segment_mask->element(to_pos)*m_segment_loss.element(segment_id, m_segment_ids->element(to_pos), 0);
172 
173  //if (from_pos_shift!=from_pos)
174  // SG_PRINT("shifting from %i to %i, to_pos=%i, loss=%1.1f\n", from_pos, from_pos_shift, to_pos, diff_contrib) ;
175 
176  return diff_contrib;
177 }
178 }
179 #endif

SHOGUN Machine Learning Toolbox - Documentation