00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2009 Jonas Behr 00008 * Copyright (C) 2009 Fraunhofer Institute FIRST and Max-Planck-Society 00009 */ 00010 #ifndef __SEGMENT_LOSS__ 00011 #define __SEGMENT_LOSS__ 00012 00013 #include <shogun/lib/common.h> 00014 #include <shogun/base/SGObject.h> 00015 #include <shogun/lib/DynamicArray.h> 00016 00017 00018 namespace shogun 00019 { 00020 template <class T> class CDynamicArray; 00022 class CSegmentLoss : public CSGObject 00023 { 00024 public: 00025 00028 CSegmentLoss(); 00029 00030 virtual ~CSegmentLoss(); 00031 00038 float32_t get_segment_loss(int32_t from_pos, int32_t to_pos, int32_t segment_id); 00039 00046 float32_t get_segment_loss_extend(int32_t from_pos, int32_t to_pos, int32_t segment_id); 00047 00054 void set_segment_loss(float64_t* segment_loss, int32_t m, int32_t n); 00055 00060 void set_segment_ids(CDynamicArray<int32_t>* segment_ids); 00061 00068 void set_segment_mask(CDynamicArray<float64_t>* segment_mask); 00069 00074 void set_num_segment_types(int32_t num_segment_types) 00075 { 00076 m_num_segment_types = num_segment_types; 00077 } 00078 00084 void compute_loss(int32_t* all_pos, int32_t len); 00085 00089 virtual const char* get_name() const { return "SegmentLoss"; } 00090 protected: 00091 00093 CDynamicArray<float32_t> m_segment_loss_matrix; // 2d 00094 00099 CDynamicArray<float64_t> m_segment_loss; // 3d 00100 00102 CDynamicArray<int32_t>* m_segment_ids; 00103 00105 CDynamicArray<float64_t>* m_segment_mask; 00106 00108 int32_t m_num_segment_types; 00109 }; 00110 00111 inline float32_t CSegmentLoss::get_segment_loss(int32_t from_pos, int32_t to_pos, int32_t segment_id) 00112 { 00113 00114 /* int32_t from_pos_shift = from_pos ; 00115 if (print) 00116 SG_PRINT("# pos=%i,%i segment_id=%i, m_segment_ids[from-2]=%i (%1.1f), m_segment_ids[from-1]=%i (%1.1f), m_segment_ids[from]=%i (%1.1f), m_segment_ids[from+1]=%i (%1.1f), \n", 00117 from_pos_shift, to_pos, segment_id, 00118 m_segment_ids->element(from_pos_shift-2), m_segment_loss_matrix.element(segment_id, from_pos_shift-2)-m_segment_loss_matrix.element(segment_id, to_pos), 00119 m_segment_ids->element(from_pos_shift-1), m_segment_loss_matrix.element(segment_id, from_pos_shift-1)-m_segment_loss_matrix.element(segment_id, to_pos), 00120 m_segment_ids->element(from_pos_shift), m_segment_loss_matrix.element(segment_id, from_pos_shift)-m_segment_loss_matrix.element(segment_id, to_pos), 00121 m_segment_ids->element(from_pos_shift+1), m_segment_loss_matrix.element(segment_id, from_pos_shift+1)-m_segment_loss_matrix.element(segment_id, to_pos)) ; 00122 while(1) 00123 { 00124 while (m_segment_ids->element(from_pos_shift)==m_segment_ids->element(from_pos_shift+1) && from_pos_shift<to_pos) 00125 from_pos_shift++ ; 00126 if (print) 00127 SG_PRINT("# pos=%i,%i segment_id=%i, m_segment_ids[from-2]=%i (%1.1f), m_segment_ids[from-1]=%i (%1.1f), m_segment_ids[from]=%i (%1.1f), m_segment_ids[from+1]=%i (%1.1f), \n", 00128 from_pos_shift, to_pos, segment_id, 00129 m_segment_ids->element(from_pos_shift-2), m_segment_loss_matrix.element(segment_id, from_pos_shift-2)-m_segment_loss_matrix.element(segment_id, to_pos), 00130 m_segment_ids->element(from_pos_shift-1), m_segment_loss_matrix.element(segment_id, from_pos_shift-1)-m_segment_loss_matrix.element(segment_id, to_pos), 00131 m_segment_ids->element(from_pos_shift), m_segment_loss_matrix.element(segment_id, from_pos_shift)-m_segment_loss_matrix.element(segment_id, to_pos), 00132 m_segment_ids->element(from_pos_shift+1), m_segment_loss_matrix.element(segment_id, from_pos_shift+1)-m_segment_loss_matrix.element(segment_id, to_pos)) ; 00133 00134 if (from_pos_shift>=to_pos) 00135 { 00136 //SG_PRINT("break") ; 00137 break ; 00138 } 00139 else from_pos_shift++ ; 00140 } 00141 if (print) 00142 SG_PRINT("break\n") ; */ 00143 00144 float32_t diff_contrib = m_segment_loss_matrix.element(segment_id, from_pos)-m_segment_loss_matrix.element(segment_id, to_pos); 00145 diff_contrib += m_segment_mask->element(to_pos-1)*m_segment_loss.element(segment_id, m_segment_ids->element(to_pos-1), 0); 00146 return diff_contrib; 00147 } 00148 00149 inline float32_t CSegmentLoss::get_segment_loss_extend(int32_t from_pos, int32_t to_pos, int32_t segment_id) 00150 { 00151 int32_t from_pos_shift = from_pos ; 00152 00153 /*SG_PRINT("segment_id=%i, m_segment_ids[from-2]=%i (%1.1f), m_segment_ids[from-1]=%i (%1.1f), m_segment_ids[from]=%i (%1.1f), m_segment_ids[from+1]=%i (%1.1f), \n", 00154 segment_id, 00155 m_segment_ids->element(from_pos_shift-2), m_segment_loss_matrix.element(segment_id, from_pos_shift-2)-m_segment_loss_matrix.element(segment_id, to_pos), 00156 m_segment_ids->element(from_pos_shift-1), m_segment_loss_matrix.element(segment_id, from_pos_shift-1)-m_segment_loss_matrix.element(segment_id, to_pos), 00157 m_segment_ids->element(from_pos_shift), m_segment_loss_matrix.element(segment_id, from_pos_shift)-m_segment_loss_matrix.element(segment_id, to_pos), 00158 m_segment_ids->element(from_pos_shift+1), m_segment_loss_matrix.element(segment_id, from_pos_shift+1)-m_segment_loss_matrix.element(segment_id, to_pos)) ;*/ 00159 00160 while (from_pos_shift<to_pos && m_segment_ids->element(from_pos_shift)==m_segment_ids->element(from_pos_shift+1)) 00161 from_pos_shift++ ; 00162 00163 /*SG_PRINT("segment_id=%i, m_segment_ids[from-2]=%i (%1.1f), m_segment_ids[from-1]=%i (%1.1f), m_segment_ids[from]=%i (%1.1f), m_segment_ids[from+1]=%i (%1.1f), \n", 00164 segment_id, 00165 m_segment_ids->element(from_pos_shift-2), m_segment_loss_matrix.element(segment_id, from_pos_shift-2)-m_segment_loss_matrix.element(segment_id, to_pos), 00166 m_segment_ids->element(from_pos_shift-1), m_segment_loss_matrix.element(segment_id, from_pos_shift-1)-m_segment_loss_matrix.element(segment_id, to_pos), 00167 m_segment_ids->element(from_pos_shift), m_segment_loss_matrix.element(segment_id, from_pos_shift)-m_segment_loss_matrix.element(segment_id, to_pos), 00168 m_segment_ids->element(from_pos_shift+1), m_segment_loss_matrix.element(segment_id, from_pos_shift+1)-m_segment_loss_matrix.element(segment_id, to_pos)) ;*/ 00169 00170 float32_t diff_contrib = m_segment_loss_matrix.element(segment_id, from_pos_shift)-m_segment_loss_matrix.element(segment_id, to_pos); 00171 //diff_contrib += m_segment_mask->element(to_pos)*m_segment_loss.element(segment_id, m_segment_ids->element(to_pos), 0); 00172 00173 //if (from_pos_shift!=from_pos) 00174 // SG_PRINT("shifting from %i to %i, to_pos=%i, loss=%1.1f\n", from_pos, from_pos_shift, to_pos, diff_contrib) ; 00175 00176 return diff_contrib; 00177 } 00178 } 00179 #endif