SHOGUN 4.1.0
LinearARDKernel.cpp
/*
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 3 of the License, or
 * (at your option) any later version.
 *
 * Written (W) 2015 Wu Lin
 * Written (W) 2012 Jacob Walker
 *
 * Adapted from WeightedDegreeRBFKernel.cpp
 */

#include <shogun/kernel/LinearARDKernel.h>

#ifdef HAVE_LINALG_LIB
#include <shogun/mathematics/linalg/linalg.h>
#endif

using namespace shogun;

CLinearARDKernel::CLinearARDKernel() : CDotKernel()
{
    initialize();
}

CLinearARDKernel::~CLinearARDKernel()
{
    CKernel::cleanup();
}

void CLinearARDKernel::initialize()
{
    m_ARD_type=KT_SCALAR;
    m_weights=SGMatrix<float64_t>(1,1);
    m_weights.set_const(1.0);
    SG_ADD(&m_weights, "weights", "Feature weights", MS_AVAILABLE,
            GRADIENT_AVAILABLE);
    SG_ADD((int *)(&m_ARD_type), "type", "ARD kernel type", MS_NOT_AVAILABLE);
}

SGVector<float64_t> CLinearARDKernel::get_feature_vector(int32_t idx, CFeatures* hs)
{
    REQUIRE(hs, "Features not set!\n");
    CDenseFeatures<float64_t>* dense_hs=dynamic_cast<CDenseFeatures<float64_t>*>(hs);
    if (dense_hs)
        return dense_hs->get_feature_vector(idx);

    CDotFeatures* dot_hs=dynamic_cast<CDotFeatures*>(hs);
    REQUIRE(dot_hs, "Kernel only supports DotFeatures\n");
    return dot_hs->get_computed_dot_feature_vector(idx);
}

#ifdef HAVE_LINALG_LIB
CLinearARDKernel::CLinearARDKernel(int32_t size) : CDotKernel(size)
{
    initialize();
}

CLinearARDKernel::CLinearARDKernel(CDotFeatures* l,
        CDotFeatures* r, int32_t size) : CDotKernel(size)
{
    initialize();
    init(l,r);
}

bool CLinearARDKernel::init(CFeatures* l, CFeatures* r)
{
    cleanup();
    CDotKernel::init(l, r);
    int32_t dim=((CDotFeatures*) l)->get_dim_feature_space();
    if (m_ARD_type==KT_FULL)
    {
        REQUIRE(m_weights.num_cols==dim, "Dimension mismatch between features (%d) and weights (%d)\n",
            dim, m_weights.num_cols);
    }
    else if (m_ARD_type==KT_DIAG)
    {
        REQUIRE(m_weights.num_rows==dim, "Dimension mismatch between features (%d) and weights (%d)\n",
            dim, m_weights.num_rows);
    }
    return init_normalizer();
}
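// Expected shape of m_weights for each ARD type (enforced by init() above and
// by set_weights() below): KT_SCALAR uses a 1-by-1 matrix, KT_DIAG a dim-by-1
// column vector, and KT_FULL a d-by-dim matrix, where dim is the feature-space
// dimension of the dot features.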
SGMatrix<float64_t> CLinearARDKernel::compute_right_product(SGVector<float64_t> right_vec,
    float64_t& scalar_weight)
{
    SGMatrix<float64_t> right;

    if (m_ARD_type==KT_SCALAR)
    {
        right=SGMatrix<float64_t>(right_vec.vector,right_vec.vlen,1,false);
        scalar_weight*=m_weights[0];
    }
    else
    {
        SGMatrix<float64_t> rtmp(right_vec.vector,right_vec.vlen,1,false);

        if (m_ARD_type==KT_DIAG)
            right=linalg::elementwise_product(m_weights, rtmp);
        else if (m_ARD_type==KT_FULL)
            right=linalg::matrix_product(m_weights, rtmp);
        else
            SG_ERROR("Unsupported ARD type\n");
    }
    return right;
}
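// compute_helper() evaluates the kernel value for a pair of feature vectors
// x=avec and y=bvec under the current ARD parameterization:
//   KT_SCALAR: k(x,y) = w^2 * (x . y)         (single weight w)
//   KT_DIAG:   k(x,y) = (w .* x) . (w .* y)   (one weight per dimension)
//   KT_FULL:   k(x,y) = (W x) . (W y)         (d-by-p weight matrix W)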
float64_t CLinearARDKernel::compute_helper(SGVector<float64_t> avec, SGVector<float64_t> bvec)
{
    SGMatrix<float64_t> left;
    SGMatrix<float64_t> left_transpose;
    float64_t scalar_weight=1.0;
    if (m_ARD_type==KT_SCALAR)
    {
        left=SGMatrix<float64_t>(avec.vector,1,avec.vlen,false);
        scalar_weight=m_weights[0];
    }
    else
    {
        SGMatrix<float64_t> ltmp(avec.vector,avec.vlen,1,false);
        if (m_ARD_type==KT_DIAG)
            left_transpose=linalg::elementwise_product(m_weights, ltmp);
        else if (m_ARD_type==KT_FULL)
            left_transpose=linalg::matrix_product(m_weights, ltmp);
        else
            SG_ERROR("Unsupported ARD type\n");
        left=SGMatrix<float64_t>(left_transpose.matrix,1,left_transpose.num_rows,false);
    }
    SGMatrix<float64_t> right=compute_right_product(bvec, scalar_weight);
    SGMatrix<float64_t> res=linalg::matrix_product(left, right);
    return res[0]*scalar_weight;
}
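// compute_gradient_helper() returns the derivative of the kernel value with
// respect to a single entry of m_weights, multiplied by `scale`:
//   KT_SCALAR: dk/dw       = 2*w*(x . y)
//   KT_DIAG:   dk/dw_i     = 2*w_i*x_i*y_i
//   KT_FULL:   dk/dW_(i,j) = (W x)_i * y_j + (W y)_i * x_j
// For KT_FULL, `index` is the column-major linear index of (i,j) in the
// d-by-p matrix W.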
float64_t CLinearARDKernel::compute_gradient_helper(SGVector<float64_t> avec,
    SGVector<float64_t> bvec, float64_t scale, index_t index)
{
    float64_t result=0.0;

    if (m_ARD_type==KT_DIAG)
    {
        result=2.0*avec[index]*bvec[index]*m_weights[index];
    }
    else
    {
        SGMatrix<float64_t> left(avec.vector,1,avec.vlen,false);
        SGMatrix<float64_t> right(bvec.vector,bvec.vlen,1,false);
        SGMatrix<float64_t> res;

        if (m_ARD_type==KT_SCALAR)
        {
            res=linalg::matrix_product(left, right);
            result=2.0*res[0]*m_weights[0];
        }
        else if (m_ARD_type==KT_FULL)
        {
            //index is a linearized index of m_weights (column-major)
            //m_weights is a d-by-p matrix, where p is #dimension of features
            int32_t row_index=index%m_weights.num_rows;
            int32_t col_index=index/m_weights.num_rows;
            SGVector<float64_t> row_vec=m_weights.get_row_vector(row_index);
            SGMatrix<float64_t> row_vec_r(row_vec.vector,row_vec.vlen,1,false);

            res=linalg::matrix_product(left, row_vec_r);
            result=res[0]*bvec[col_index];

            SGMatrix<float64_t> row_vec_l(row_vec.vector,1,row_vec.vlen,false);
            res=linalg::matrix_product(row_vec_l, right);
            result+=res[0]*avec[col_index];
        }
        else
        {
            SG_ERROR("Unsupported ARD type\n");
        }
    }
    return result*scale;
}

void CLinearARDKernel::check_weight_gradient_index(index_t index)
{
    REQUIRE(lhs, "Left features not set!\n");
    REQUIRE(rhs, "Right features not set!\n");

    if (m_ARD_type!=KT_SCALAR)
    {
        REQUIRE(index>=0, "Index (%d) must be non-negative\n", index);
        if (m_ARD_type==KT_DIAG)
        {
            REQUIRE(index<m_weights.num_rows, "Index (%d) must be within #dimension of weights (%d)\n",
                index, m_weights.num_rows);
        }
        else if (m_ARD_type==KT_FULL)
        {
            int32_t row_index=index%m_weights.num_rows;
            int32_t col_index=index/m_weights.num_rows;
            REQUIRE(row_index<m_weights.num_rows,
                "Row index (%d) must be within #row of weights (%d)\n",
                row_index, m_weights.num_rows);
            REQUIRE(col_index<m_weights.num_cols,
                "Column index (%d) must be within #column of weights (%d)\n",
                col_index, m_weights.num_cols);
        }
    }
}

SGMatrix<float64_t> CLinearARDKernel::get_weights()
{
    return m_weights;
}

void CLinearARDKernel::set_weights(SGMatrix<float64_t> weights)
{
    REQUIRE(weights.num_cols>0 && weights.num_rows>0,
        "Weight Matrix (%d-by-%d) must not be empty\n",
        weights.num_rows, weights.num_cols);
    if (weights.num_cols>1)
        m_ARD_type=KT_FULL;
    else
    {
        if (weights.num_rows==1)
            m_ARD_type=KT_SCALAR;
        else
            m_ARD_type=KT_DIAG;
    }
    m_weights=weights;
}

void CLinearARDKernel::set_scalar_weights(float64_t weight)
{
    REQUIRE(weight>0, "Scalar (%f) weight should be positive\n", weight);
    SGMatrix<float64_t> weights(1,1);
    weights(0,0)=weight;
    set_weights(weights);
}

void CLinearARDKernel::set_vector_weights(SGVector<float64_t> weights)
{
    SGMatrix<float64_t> weights_mat(weights.vlen,1);
    std::copy(weights.vector, weights.vector+weights.vlen, weights_mat.matrix);
    set_weights(weights_mat);
}

void CLinearARDKernel::set_matrix_weights(SGMatrix<float64_t> weights)
{
    set_weights(weights);
}

#endif //HAVE_LINALG_LIB
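Below is a minimal usage sketch that is not part of LinearARDKernel.cpp itself; it assumes a Shogun 4.x build with HAVE_LINALG_LIB enabled and uses the standard CDenseFeatures / CKernel API. The data values and the cache size of 10 are arbitrary, illustrative choices.

#include <shogun/base/init.h>
#include <shogun/features/DenseFeatures.h>
#include <shogun/kernel/LinearARDKernel.h>

using namespace shogun;

int main()
{
    init_shogun_with_defaults();

    // three 2-dimensional feature vectors, stored column-wise
    SGMatrix<float64_t> data(2, 3);
    for (index_t i=0; i<data.num_rows*data.num_cols; i++)
        data.matrix[i]=i+1;
    CDenseFeatures<float64_t>* feats=new CDenseFeatures<float64_t>(data);

    // one weight per input dimension selects the KT_DIAG parameterization
    CLinearARDKernel* kernel=new CLinearARDKernel(10);
    SGVector<float64_t> w(2);
    w[0]=0.5;
    w[1]=2.0;
    kernel->set_vector_weights(w);

    // attach the features and evaluate the full kernel matrix
    kernel->init(feats, feats);
    SGMatrix<float64_t> K=kernel->get_kernel_matrix();
    K.display_matrix("K");

    SG_UNREF(kernel);
    exit_shogun();
    return 0;
}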