SHOGUN  v3.0.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
CustomDistance.h
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 1999-2009 Soeren Sonnenburg
8  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
9  */
10 
11 #ifndef _CUSTOMDISTANCE_H___
12 #define _CUSTOMDISTANCE_H___
13 
15 #include <shogun/lib/common.h>
18 
19 namespace shogun
20 {
30 {
31  public:
34 
41 
/** constructor taking the distance values as a float64_t SGMatrix
 *
 * NOTE(review): the implementation is not visible in this header; presumably
 * it copies the matrix into the internal float32_t storage via one of the
 * set_*_distance_matrix_* helpers below -- confirm in CustomDistance.cpp.
 */
45  CCustomDistance(const SGMatrix<float64_t> distance_matrix);
46 
// NOTE(review): the head of this declaration (listing line 57) is not shown;
// it takes a raw float64_t buffer plus its row and column counts.
58  const float64_t* dm, int32_t rows, int32_t cols);
59 
// NOTE(review): the head of this declaration (listing line 70) is not shown;
// it is the float32_t counterpart of the declaration above.
71  const float32_t* dm, int32_t rows, int32_t cols);
72 
/** destructor */
73  virtual ~CCustomDistance();
74 
/** (re)initialize for a rows x cols matrix without real features
 *
 * Called by the set_*_from_full_generic helpers right after dmatrix has been
 * filled. What it sets up internally is not visible here -- confirm in .cpp.
 *
 * @param rows number of rows
 * @param cols number of columns
 * @return success flag
 */
85  virtual bool dummy_init(int32_t rows, int32_t cols);
86 
/** initialize the distance with lhs/rhs features
 *
 * NOTE(review): semantics not visible in this header -- confirm in .cpp.
 *
 * @param l features of left-hand side
 * @param r features of right-hand side
 * @return success flag
 */
93  virtual bool init(CFeatures* l, CFeatures* r);
94 
/** clean up the distance (implementation not visible here) */
96  virtual void cleanup();
97 
103 
108  virtual EFeatureType get_feature_type() { return F_ANY; }
109 
114  virtual EFeatureClass get_feature_class() { return C_ANY; }
115 
120  virtual const char* get_name() const { return "CustomDistance"; }
121 
133  const float64_t* dm, int32_t len)
134  {
// NOTE(review): the forwarding statement (listing line 135) is not shown;
// presumably it calls the templated triangle->triangle helper with (dm, len)
// -- confirm against the full header.
136  }
137 
149  const float32_t* dm, int32_t len)
150  {
// NOTE(review): the forwarding statement (listing line 151) is not shown;
// presumably the float32_t twin of the overload above -- confirm.
152  }
153 
/* Generic worker: store a packed triangular matrix of len elements.
 *
 * dm is copied element-by-element (with T -> float32_t conversion) into a
 * freshly allocated dmatrix of len floats; upper_diagonal is set to true and
 * both num_rows and num_cols become the side length derived from len.
 * Returns false if len is not a triangular number cols*(cols+1)/2.
 *
 * NOTE(review): the line carrying the function name (listing line 165) is not
 * shown in this listing.
 */
164  template <class T>
166  const T* dm, int64_t len)
167  {
168  ASSERT(dm)
169  ASSERT(len>0)
170 
// Side length: positive root of cols^2 + cols - 2*len = 0,
// i.e. cols = -1/2 + sqrt(1/4 + 2*len).
171  int64_t cols = (int64_t) floor(-0.5 + CMath::sqrt(0.25+2*len));
172 
173  int64_t int32_max=2147483647;
174 
// num_rows/num_cols are int32_t, so the side length must fit in 32 bits.
175  if (cols> int32_max)
// NOTE(review): BUG -- the format has two %d but only one argument is passed,
// and int32_max is int64_t while %d expects int (undefined behavior). Should
// be e.g. SG_ERROR("Matrix larger than %d x %d\n", (int32_t) int32_max,
// (int32_t) int32_max).
176  SG_ERROR("Matrix larger than %d x %d\n", int32_max)
177 
// Reject len values that are not exactly a triangular number (floor() above
// rounds non-triangular lengths down).
178  if (cols*(cols+1)/2 != len)
179  {
180  SG_ERROR("dm should be a vector containing a lower triangle matrix, with len=cols*(cols+1)/2 elements\n")
// NOTE(review): presumably unreachable if SG_ERROR aborts/throws -- confirm.
181  return false;
182  }
183 
184  cleanup_custom();
// NOTE(review): cols is int64_t but is printed with %d.
185  SG_DEBUG("using custom distance of size %dx%d\n", cols,cols)
186 
187  dmatrix= SG_MALLOC(float32_t, len);
188 
189  upper_diagonal=true;
190  num_rows=cols;
191  num_cols=cols;
192 
// Packed layout is copied verbatim; each T value is converted to float32_t.
193  for (int64_t i=0; i<len; i++)
194  dmatrix[i]=dm[i];
195 
// NOTE(review): listing line 196 is not shown; the sibling *_from_full_generic
// helpers call dummy_init(...) at this point -- confirm this one does too.
197  return true;
198  }
199 
// float64_t overload (function name on listing line 210, not shown): forwards
// to the templated triangle-from-full helper, which stores values as float32_t.
211  const float64_t* dm, int32_t rows, int32_t cols)
212  {
213  return set_triangle_distance_matrix_from_full_generic(dm, rows, cols);
214  }
215 
// float32_t overload (function name on listing line 226, not shown): same
// forwarding as above.
227  const float32_t* dm, int32_t rows, int32_t cols)
228  {
229  return set_triangle_distance_matrix_from_full_generic(dm, rows, cols);
230  }
231 
/* Generic worker: pack the upper triangle of a square full matrix.
 *
 * dm is read in column-major order (element (row,col) at dm[col*rows+row]);
 * only entries with col >= row are read. They are stored row-major as a
 * packed upper triangle of cols*(cols+1)/2 float32_t values, upper_diagonal
 * is set to true, and dummy_init(rows, cols) is called. Always returns true.
 *
 * NOTE(review): the line carrying the function name (listing line 241) is not
 * shown. Unlike the triangle->triangle helper there is no ASSERT(dm) here.
 */
240  template <class T>
242  const T* dm, int32_t rows, int32_t cols)
243  {
244  ASSERT(rows==cols)
245 
246  cleanup_custom();
247  SG_DEBUG("using custom distance of size %dx%d\n", cols,cols)
248 
// Element count computed in 64 bit to avoid int32_t overflow.
249  dmatrix= SG_MALLOC(float32_t, int64_t(cols)*(cols+1)/2);
250 
251  upper_diagonal=true;
252  num_rows=cols;
253  num_cols=cols;
254 
255  for (int64_t row=0; row<num_rows; row++)
256  {
257  for (int64_t col=row; col<num_cols; col++)
258  {
// Row `row` of the packed upper triangle starts after
// sum_{r<row}(n-r) = row*n - row*(row-1)/2 entries; within the row the
// offset is col-row, giving row*n - row*(row+1)/2 + col.
259  int64_t idx=row * num_cols - row*(row+1)/2 + col;
260  dmatrix[idx]= (float32_t) dm[col*num_rows+row];
261  }
262  }
263  dummy_init(rows, cols);
264  return true;
265  }
266 
// float64_t overload (function name on listing line 276, not shown): forwards
// to the templated full-from-full helper, which stores values as float32_t.
277  const float64_t* dm, int32_t rows, int32_t cols)
278  {
279  return set_full_distance_matrix_from_full_generic(dm, rows, cols);
280  }
281 
// float32_t overload (function name on listing line 291, not shown): same
// forwarding as above.
292  const float32_t* dm, int32_t rows, int32_t cols)
293  {
294  return set_full_distance_matrix_from_full_generic(dm, rows, cols);
295  }
296 
304  template <class T>
305  bool set_full_distance_matrix_from_full_generic(const T* dm, int32_t rows, int32_t cols)
306  {
307  cleanup_custom();
308  SG_DEBUG("using custom distance of size %dx%d\n", rows,cols)
309 
310  dmatrix=SG_MALLOC(float32_t, rows*cols);
311 
312  upper_diagonal=false;
313  num_rows=rows;
314  num_cols=cols;
315 
316  for (int32_t row=0; row<num_rows; row++)
317  {
318  for (int32_t col=0; col<num_cols; col++)
319  {
320  dmatrix[row * num_cols + col]=dm[col*num_rows+row];
321  }
322  }
323 
324  dummy_init(rows, cols);
325  return true;
326  }
327 
332  virtual int32_t get_num_vec_lhs()
333  {
334  return num_rows;
335  }
336 
341  virtual int32_t get_num_vec_rhs()
342  {
343  return num_cols;
344  }
345 
350  virtual bool has_features()
351  {
352  return (num_rows>0) && (num_cols>0);
353  }
354 
355  protected:
/** distance between lhs vector `row` and rhs vector `col`
 *
 * NOTE(review): implementation not visible here; presumably a lookup into
 * dmatrix that honours upper_diagonal (packed vs. full storage) -- confirm.
 *
 * @param row row index
 * @param col column index
 * @return distance value
 */
362  virtual float64_t compute(int32_t row, int32_t col);
363 
364  private:
/** private initializer (implementation not visible here -- confirm in .cpp) */
365  void init();
366 
/** called first by every set_*_distance_matrix_* helper, before dmatrix is
 *  reallocated; presumably frees the previous matrix -- confirm in .cpp */
368  void cleanup_custom();
369 
370  protected:
/** number of rows of the stored matrix (returned by get_num_vec_lhs) */
374  int32_t num_rows;
/** number of columns of the stored matrix (returned by get_num_vec_rhs) */
376  int32_t num_cols;
379 };
380 
381 }
382 #endif /* _CUSTOMDISTANCE_H___ */

SHOGUN Machine Learning Toolbox - Documentation