Go to the documentation of this file.00001 #include <stdio.h>
00002 #include <string.h>
00003
00004 #include <shogun/mathematics/Math.h>
00005 #include <shogun/lib/config.h>
00006 #include <shogun/io/SGIO.h>
00007 #include <shogun/structure/SegmentLoss.h>
00008 #include <shogun/lib/Array.h>
00009 #include <shogun/lib/Array2.h>
00010 #include <shogun/lib/Array3.h>
00011 #include <shogun/base/SGObject.h>
00012
00013
00014 using namespace shogun;
00015
00016 CSegmentLoss::CSegmentLoss()
00017 :CSGObject(),
00018 m_segment_loss_matrix(1,1),
00019 m_segment_loss(1,1,2),
00020 m_segment_ids(NULL),
00021 m_segment_mask(NULL),
00022 m_num_segment_types(0)
00023 {
00024 }
00025 CSegmentLoss::~CSegmentLoss()
00026 {
00027 }
00028
00029 void CSegmentLoss::set_segment_loss(float64_t* segment_loss, int32_t m, int32_t n)
00030 {
00031
00032 if (2*m!=n)
00033 SG_ERROR( "segment_loss should be 2 x quadratic matrix: %i!=%i\n", 2*m, n) ;
00034
00035 m_num_segment_types = m;
00036
00037 m_segment_loss.set_array(segment_loss, m, n/2, 2, true, true) ;
00038 }
00039
00040 void CSegmentLoss::set_segment_ids(CArray<int32_t>* segment_ids)
00041 {
00042 m_segment_ids = segment_ids;
00043 }
00044
00045 void CSegmentLoss::set_segment_mask(CArray<float64_t>* segment_mask)
00046 {
00047 m_segment_mask = segment_mask;
00048 }
00049
00050 void CSegmentLoss::compute_loss(int32_t* all_pos, int32_t len)
00051 {
00052 #ifdef DEBUG
00053 SG_PRINT("compute loss: len: %i, m_num_segment_types: %i\n", len, m_num_segment_types);
00054 SG_PRINT("m_segment_mask->element(0):%f \n", m_segment_mask->element(0));
00055 SG_PRINT("m_segment_ids->element(0):%i \n", m_segment_ids->element(0));
00056 #endif
00057 ASSERT(m_segment_ids->get_dim1()==len);
00058 ASSERT(m_segment_mask->get_dim1()==len);
00059
00060 m_segment_loss_matrix.resize_array(m_num_segment_types,len);
00061
00062 for (int seg_type=0; seg_type<m_num_segment_types; seg_type++)
00063 {
00064 float32_t value = 0;
00065 int32_t last_id = -1;
00066 int32_t last_pos = all_pos[len-1];
00067 for (int pos=len-1;pos>=0; pos--)
00068 {
00069 int32_t cur_id = m_segment_ids->element(pos);
00070 if (cur_id!=last_id)
00071 {
00072
00073 value += m_segment_mask->element(pos)*m_segment_loss.element(cur_id, seg_type, 0);
00074 last_id = cur_id;
00075 }
00076
00077 value += m_segment_mask->element(pos)*m_segment_loss.element(cur_id, seg_type, 1)*(last_pos-all_pos[pos]);
00078 last_pos = all_pos[pos];
00079 m_segment_loss_matrix.element(seg_type, pos)=value;
00080 }
00081 }
00082 #ifdef DEBUG
00083 m_segment_loss_matrix.display_array();
00084 #endif
00085 }
00086