SegmentLoss.h

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2009 Jonas Behr
00008  * Copyright (C) 2009 Fraunhofer Institute FIRST and Max-Planck-Society
00009  */
00010 #ifndef __SEGMENT_LOSS__
00011 #define __SEGMENT_LOSS__
00012 
00013 #include <shogun/lib/common.h>
00014 #include <shogun/base/SGObject.h>
00015 #include <shogun/lib/DynamicArray.h>
00016 
00017 
00018 namespace shogun
00019 {
00020     template <class T> class CDynamicArray;
00022 class CSegmentLoss : public CSGObject
00023 {
00024         public:
00025 
00028         CSegmentLoss();
00029 
00030         virtual ~CSegmentLoss();
00031 
00038         float32_t get_segment_loss(int32_t from_pos, int32_t to_pos, int32_t segment_id);
00039 
00046         float32_t get_segment_loss_extend(int32_t from_pos, int32_t to_pos, int32_t segment_id);
00047 
00054         void set_segment_loss(float64_t* segment_loss, int32_t m, int32_t n);
00055 
00060         void set_segment_ids(CDynamicArray<int32_t>* segment_ids);
00061 
00068         void set_segment_mask(CDynamicArray<float64_t>* segment_mask);
00069 
00074         void set_num_segment_types(int32_t num_segment_types)
00075         {
00076             m_num_segment_types = num_segment_types;
00077         }
00078 
00084         void compute_loss(int32_t* all_pos, int32_t len);
00085 
00089         virtual const char* get_name() const { return "SegmentLoss"; }
00090     protected:
00091 
00093         CDynamicArray<float32_t> m_segment_loss_matrix; // 2d
00094 
00099         CDynamicArray<float64_t> m_segment_loss; // 3d
00100 
00102         CDynamicArray<int32_t>* m_segment_ids;
00103 
00105         CDynamicArray<float64_t>* m_segment_mask;
00106 
00108         int32_t m_num_segment_types;
00109 };
00110 
00111 inline float32_t CSegmentLoss::get_segment_loss(int32_t from_pos, int32_t to_pos, int32_t segment_id)
00112 {
00113 
00114     /*  int32_t from_pos_shift = from_pos ;
00115         if (print)
00116         SG_PRINT("# pos=%i,%i  segment_id=%i, m_segment_ids[from-2]=%i (%1.1f), m_segment_ids[from-1]=%i (%1.1f), m_segment_ids[from]=%i (%1.1f), m_segment_ids[from+1]=%i (%1.1f), \n",
00117                  from_pos_shift, to_pos, segment_id,
00118                  m_segment_ids->element(from_pos_shift-2),  m_segment_loss_matrix.element(segment_id, from_pos_shift-2)-m_segment_loss_matrix.element(segment_id, to_pos),
00119                  m_segment_ids->element(from_pos_shift-1), m_segment_loss_matrix.element(segment_id, from_pos_shift-1)-m_segment_loss_matrix.element(segment_id, to_pos),
00120                  m_segment_ids->element(from_pos_shift), m_segment_loss_matrix.element(segment_id, from_pos_shift)-m_segment_loss_matrix.element(segment_id, to_pos),
00121                  m_segment_ids->element(from_pos_shift+1),  m_segment_loss_matrix.element(segment_id, from_pos_shift+1)-m_segment_loss_matrix.element(segment_id, to_pos)) ;
00122     while(1)
00123     {
00124         while (m_segment_ids->element(from_pos_shift)==m_segment_ids->element(from_pos_shift+1) && from_pos_shift<to_pos)
00125             from_pos_shift++ ;
00126         if (print)
00127             SG_PRINT("# pos=%i,%i  segment_id=%i, m_segment_ids[from-2]=%i (%1.1f), m_segment_ids[from-1]=%i (%1.1f), m_segment_ids[from]=%i (%1.1f), m_segment_ids[from+1]=%i (%1.1f), \n",
00128                      from_pos_shift, to_pos, segment_id,
00129                      m_segment_ids->element(from_pos_shift-2),  m_segment_loss_matrix.element(segment_id, from_pos_shift-2)-m_segment_loss_matrix.element(segment_id, to_pos),
00130                      m_segment_ids->element(from_pos_shift-1), m_segment_loss_matrix.element(segment_id, from_pos_shift-1)-m_segment_loss_matrix.element(segment_id, to_pos),
00131                      m_segment_ids->element(from_pos_shift), m_segment_loss_matrix.element(segment_id, from_pos_shift)-m_segment_loss_matrix.element(segment_id, to_pos),
00132                      m_segment_ids->element(from_pos_shift+1),  m_segment_loss_matrix.element(segment_id, from_pos_shift+1)-m_segment_loss_matrix.element(segment_id, to_pos)) ;
00133 
00134         if (from_pos_shift>=to_pos)
00135         {
00136             //SG_PRINT("break") ;
00137             break ;
00138         }
00139         else from_pos_shift++ ;
00140         }
00141     if (print)
00142     SG_PRINT("break\n") ; */
00143 
00144     float32_t diff_contrib = m_segment_loss_matrix.element(segment_id, from_pos)-m_segment_loss_matrix.element(segment_id, to_pos);
00145     diff_contrib += m_segment_mask->element(to_pos-1)*m_segment_loss.element(segment_id, m_segment_ids->element(to_pos-1), 0);
00146     return diff_contrib;
00147 }
00148 
00149 inline float32_t CSegmentLoss::get_segment_loss_extend(int32_t from_pos, int32_t to_pos, int32_t segment_id)
00150 {
00151     int32_t from_pos_shift = from_pos ;
00152 
00153     /*SG_PRINT("segment_id=%i, m_segment_ids[from-2]=%i (%1.1f), m_segment_ids[from-1]=%i (%1.1f), m_segment_ids[from]=%i (%1.1f), m_segment_ids[from+1]=%i (%1.1f), \n",
00154              segment_id,
00155              m_segment_ids->element(from_pos_shift-2),  m_segment_loss_matrix.element(segment_id, from_pos_shift-2)-m_segment_loss_matrix.element(segment_id, to_pos),
00156              m_segment_ids->element(from_pos_shift-1), m_segment_loss_matrix.element(segment_id, from_pos_shift-1)-m_segment_loss_matrix.element(segment_id, to_pos),
00157              m_segment_ids->element(from_pos_shift), m_segment_loss_matrix.element(segment_id, from_pos_shift)-m_segment_loss_matrix.element(segment_id, to_pos),
00158              m_segment_ids->element(from_pos_shift+1),  m_segment_loss_matrix.element(segment_id, from_pos_shift+1)-m_segment_loss_matrix.element(segment_id, to_pos)) ;*/
00159 
00160     while (from_pos_shift<to_pos && m_segment_ids->element(from_pos_shift)==m_segment_ids->element(from_pos_shift+1))
00161         from_pos_shift++ ;
00162 
00163     /*SG_PRINT("segment_id=%i, m_segment_ids[from-2]=%i (%1.1f), m_segment_ids[from-1]=%i (%1.1f), m_segment_ids[from]=%i (%1.1f), m_segment_ids[from+1]=%i (%1.1f), \n",
00164              segment_id,
00165              m_segment_ids->element(from_pos_shift-2),  m_segment_loss_matrix.element(segment_id, from_pos_shift-2)-m_segment_loss_matrix.element(segment_id, to_pos),
00166              m_segment_ids->element(from_pos_shift-1), m_segment_loss_matrix.element(segment_id, from_pos_shift-1)-m_segment_loss_matrix.element(segment_id, to_pos),
00167              m_segment_ids->element(from_pos_shift), m_segment_loss_matrix.element(segment_id, from_pos_shift)-m_segment_loss_matrix.element(segment_id, to_pos),
00168              m_segment_ids->element(from_pos_shift+1),  m_segment_loss_matrix.element(segment_id, from_pos_shift+1)-m_segment_loss_matrix.element(segment_id, to_pos)) ;*/
00169 
00170     float32_t diff_contrib = m_segment_loss_matrix.element(segment_id, from_pos_shift)-m_segment_loss_matrix.element(segment_id, to_pos);
00171     //diff_contrib += m_segment_mask->element(to_pos)*m_segment_loss.element(segment_id, m_segment_ids->element(to_pos), 0);
00172 
00173     //if (from_pos_shift!=from_pos)
00174     //  SG_PRINT("shifting from %i to %i, to_pos=%i, loss=%1.1f\n", from_pos, from_pos_shift, to_pos, diff_contrib) ;
00175 
00176     return diff_contrib;
00177 }
00178 }
00179 #endif
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation