SHOGUN  v3.0.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
MahalanobisDistance.cpp
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 2012 Fernando José Iglesias García
8  * Copyright (C) 2012 Fernando José Iglesias García
9  */
10 
11 #ifdef HAVE_LAPACK
12 
13 #include <shogun/lib/common.h>
14 #include <shogun/io/SGIO.h>
19 
20 using namespace shogun;
21 
23 {
24  init();
25 }
26 
28 : CRealDistance()
29 {
30  init();
31  init(l, r);
32 }
33 
35 {
36  cleanup();
37 }
38 
39 bool CMahalanobisDistance::init(CFeatures* l, CFeatures* r)
40 {
41  CRealDistance::init(l, r);
42 
43 
44  if ( l == r)
45  {
46  mean = ((CDenseFeatures<float64_t>*) l)->get_mean();
47  icov = ((CDenseFeatures<float64_t>*) l)->get_cov();
48  }
49  else
50  {
53  }
54 
56 
57  return true;
58 }
59 
61 {
62 }
63 
64 float64_t CMahalanobisDistance::compute(int32_t idx_a, int32_t idx_b)
65 {
66 
68  get_feature_vector(idx_b);
69 
72 
73  if (use_mean)
74  diff = mean.clone();
75  else
76  {
77  avec = ((CDenseFeatures<float64_t>*) lhs)->get_feature_vector(idx_a);
78  diff=avec.clone();
79  }
80 
81  ASSERT(diff.vlen == bvec.vlen)
82 
83  for (int32_t i=0; i < diff.vlen; i++)
84  diff[i] = bvec.vector[i] - diff[i];
85 
86  SGVector<float64_t> v = diff.clone();
87  cblas_dgemv(CblasColMajor, CblasNoTrans,
89  diff.vlen, diff.vector, 1, 0.0, v.vector, 1);
90 
91  float64_t result = cblas_ddot(v.vlen, v.vector, 1, diff.vector, 1);
92 
93  if (!use_mean)
94  ((CDenseFeatures<float64_t>*) lhs)->free_feature_vector(avec, idx_a);
95 
96  ((CDenseFeatures<float64_t>*) rhs)->free_feature_vector(bvec, idx_b);
97 
98  if (disable_sqrt)
99  return result;
100  else
101  return CMath::sqrt(result);
102 }
103 
104 void CMahalanobisDistance::init()
105 {
106  disable_sqrt=false;
107  use_mean=false;
108 
109  m_parameters->add(&disable_sqrt, "disable_sqrt", "If sqrt shall not be applied.");
110  m_parameters->add(&use_mean, "use_mean", "If distance shall be computed between mean vector and vector from rhs or between lhs and rhs.");
111 }
112 
113 #endif /* HAVE_LAPACK */

SHOGUN Machine Learning Toolbox - Documentation