00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2009 Jonas Behr 00008 * Copyright (C) 2009 Fraunhofer Institute FIRST and Max-Planck-Society 00009 */ 00010 #ifndef __SEGMENT_LOSS__ 00011 #define __SEGMENT_LOSS__ 00012 00013 #include <shogun/lib/common.h> 00014 #include <shogun/base/SGObject.h> 00015 #include <shogun/lib/Array.h> 00016 #include <shogun/lib/Array2.h> 00017 #include <shogun/lib/Array3.h> 00018 00019 00020 namespace shogun 00021 { 00022 template <class T> class CArray; 00023 template <class T> class CArray2; 00024 template <class T> class CArray3; 00026 class CSegmentLoss : public CSGObject 00027 { 00028 public: 00029 00032 CSegmentLoss(); 00033 00034 virtual ~CSegmentLoss(); 00035 00042 float32_t get_segment_loss(int32_t from_pos, int32_t to_pos, int32_t segment_id); 00043 00050 float32_t get_segment_loss_extend(int32_t from_pos, int32_t to_pos, int32_t segment_id); 00051 00058 void set_segment_loss(float64_t* segment_loss, int32_t m, int32_t n); 00059 00064 void set_segment_ids(CArray<int32_t>* segment_ids); 00065 00072 void set_segment_mask(CArray<float64_t>* segment_mask); 00073 00078 void set_num_segment_types(int32_t num_segment_types) 00079 { 00080 m_num_segment_types = num_segment_types; 00081 } 00082 00088 void compute_loss(int32_t* all_pos, int32_t len); 00089 00093 inline virtual const char* get_name() const { return "SegmentLoss"; } 00094 protected: 00095 00097 CArray2<float32_t> m_segment_loss_matrix; 00098 00103 CArray3<float64_t> m_segment_loss; 00104 00106 CArray<int32_t>* m_segment_ids; 00107 00109 CArray<float64_t>* m_segment_mask; 00110 00112 int32_t m_num_segment_types; 00113 }; 00114 00115 inline float32_t CSegmentLoss::get_segment_loss(int32_t from_pos, int32_t to_pos, int32_t segment_id) 00116 { 00117 00118 /* int32_t from_pos_shift = from_pos ; 00119 if (print) 00120 SG_PRINT("# pos=%i,%i segment_id=%i, m_segment_ids[from-2]=%i (%1.1f), m_segment_ids[from-1]=%i (%1.1f), m_segment_ids[from]=%i (%1.1f), m_segment_ids[from+1]=%i (%1.1f), \n", 00121 from_pos_shift, to_pos, segment_id, 00122 m_segment_ids->element(from_pos_shift-2), m_segment_loss_matrix.element(segment_id, from_pos_shift-2)-m_segment_loss_matrix.element(segment_id, to_pos), 00123 m_segment_ids->element(from_pos_shift-1), m_segment_loss_matrix.element(segment_id, from_pos_shift-1)-m_segment_loss_matrix.element(segment_id, to_pos), 00124 m_segment_ids->element(from_pos_shift), m_segment_loss_matrix.element(segment_id, from_pos_shift)-m_segment_loss_matrix.element(segment_id, to_pos), 00125 m_segment_ids->element(from_pos_shift+1), m_segment_loss_matrix.element(segment_id, from_pos_shift+1)-m_segment_loss_matrix.element(segment_id, to_pos)) ; 00126 while(1) 00127 { 00128 while (m_segment_ids->element(from_pos_shift)==m_segment_ids->element(from_pos_shift+1) && from_pos_shift<to_pos) 00129 from_pos_shift++ ; 00130 if (print) 00131 SG_PRINT("# pos=%i,%i segment_id=%i, m_segment_ids[from-2]=%i (%1.1f), m_segment_ids[from-1]=%i (%1.1f), m_segment_ids[from]=%i (%1.1f), m_segment_ids[from+1]=%i (%1.1f), \n", 00132 from_pos_shift, to_pos, segment_id, 00133 m_segment_ids->element(from_pos_shift-2), m_segment_loss_matrix.element(segment_id, from_pos_shift-2)-m_segment_loss_matrix.element(segment_id, to_pos), 00134 m_segment_ids->element(from_pos_shift-1), m_segment_loss_matrix.element(segment_id, from_pos_shift-1)-m_segment_loss_matrix.element(segment_id, to_pos), 00135 m_segment_ids->element(from_pos_shift), m_segment_loss_matrix.element(segment_id, from_pos_shift)-m_segment_loss_matrix.element(segment_id, to_pos), 00136 m_segment_ids->element(from_pos_shift+1), m_segment_loss_matrix.element(segment_id, from_pos_shift+1)-m_segment_loss_matrix.element(segment_id, to_pos)) ; 00137 00138 if (from_pos_shift>=to_pos) 00139 { 00140 //SG_PRINT("break") ; 00141 break ; 00142 } 00143 else from_pos_shift++ ; 00144 } 00145 if (print) 00146 SG_PRINT("break\n") ; */ 00147 00148 float32_t diff_contrib = m_segment_loss_matrix.element(segment_id, from_pos)-m_segment_loss_matrix.element(segment_id, to_pos); 00149 diff_contrib += m_segment_mask->element(to_pos-1)*m_segment_loss.element(segment_id, m_segment_ids->element(to_pos-1), 0); 00150 return diff_contrib; 00151 } 00152 00153 inline float32_t CSegmentLoss::get_segment_loss_extend(int32_t from_pos, int32_t to_pos, int32_t segment_id) 00154 { 00155 int32_t from_pos_shift = from_pos ; 00156 00157 /*SG_PRINT("segment_id=%i, m_segment_ids[from-2]=%i (%1.1f), m_segment_ids[from-1]=%i (%1.1f), m_segment_ids[from]=%i (%1.1f), m_segment_ids[from+1]=%i (%1.1f), \n", 00158 segment_id, 00159 m_segment_ids->element(from_pos_shift-2), m_segment_loss_matrix.element(segment_id, from_pos_shift-2)-m_segment_loss_matrix.element(segment_id, to_pos), 00160 m_segment_ids->element(from_pos_shift-1), m_segment_loss_matrix.element(segment_id, from_pos_shift-1)-m_segment_loss_matrix.element(segment_id, to_pos), 00161 m_segment_ids->element(from_pos_shift), m_segment_loss_matrix.element(segment_id, from_pos_shift)-m_segment_loss_matrix.element(segment_id, to_pos), 00162 m_segment_ids->element(from_pos_shift+1), m_segment_loss_matrix.element(segment_id, from_pos_shift+1)-m_segment_loss_matrix.element(segment_id, to_pos)) ;*/ 00163 00164 while (from_pos_shift<to_pos && m_segment_ids->element(from_pos_shift)==m_segment_ids->element(from_pos_shift+1)) 00165 from_pos_shift++ ; 00166 00167 /*SG_PRINT("segment_id=%i, m_segment_ids[from-2]=%i (%1.1f), m_segment_ids[from-1]=%i (%1.1f), m_segment_ids[from]=%i (%1.1f), m_segment_ids[from+1]=%i (%1.1f), \n", 00168 segment_id, 00169 m_segment_ids->element(from_pos_shift-2), m_segment_loss_matrix.element(segment_id, from_pos_shift-2)-m_segment_loss_matrix.element(segment_id, to_pos), 00170 m_segment_ids->element(from_pos_shift-1), m_segment_loss_matrix.element(segment_id, from_pos_shift-1)-m_segment_loss_matrix.element(segment_id, to_pos), 00171 m_segment_ids->element(from_pos_shift), m_segment_loss_matrix.element(segment_id, from_pos_shift)-m_segment_loss_matrix.element(segment_id, to_pos), 00172 m_segment_ids->element(from_pos_shift+1), m_segment_loss_matrix.element(segment_id, from_pos_shift+1)-m_segment_loss_matrix.element(segment_id, to_pos)) ;*/ 00173 00174 float32_t diff_contrib = m_segment_loss_matrix.element(segment_id, from_pos_shift)-m_segment_loss_matrix.element(segment_id, to_pos); 00175 //diff_contrib += m_segment_mask->element(to_pos)*m_segment_loss.element(segment_id, m_segment_ids->element(to_pos), 0); 00176 00177 //if (from_pos_shift!=from_pos) 00178 // SG_PRINT("shifting from %i to %i, to_pos=%i, loss=%1.1f\n", from_pos, from_pos_shift, to_pos, diff_contrib) ; 00179 00180 return diff_contrib; 00181 } 00182 } 00183 #endif