Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013 #ifndef __SGSPARSEVECTOR_H__
00014 #define __SGSPARSEVECTOR_H__
00015
00016 #include <shogun/lib/config.h>
00017 #include <shogun/lib/DataType.h>
00018 #include <shogun/lib/SGReferencedData.h>
00019 #include <map>
00020
00021 namespace shogun
00022 {
00023
00024
00026 template <class T> struct SGSparseVectorEntry
00027 {
00029 index_t feat_index;
00031 T entry;
00032 };
00033
00039 template <class T> class SGSparseVector : public SGReferencedData
00040 {
00041 public:
00043 SGSparseVector() : SGReferencedData()
00044 {
00045 init_data();
00046 }
00047
00054 SGSparseVector(SGSparseVectorEntry<T>* feats, index_t num_entries,
00055 bool ref_counting=true) :
00056 SGReferencedData(ref_counting),
00057 num_feat_entries(num_entries), features(feats)
00058 {
00059 }
00060
00062 SGSparseVector(index_t num_entries, bool ref_counting=true) :
00063 SGReferencedData(ref_counting),
00064 num_feat_entries(num_entries)
00065 {
00066 features = SG_MALLOC(SGSparseVectorEntry<T>, num_feat_entries);
00067 }
00068
00070 SGSparseVector(const SGSparseVector& orig) :
00071 SGReferencedData(orig)
00072 {
00073 copy_data(orig);
00074 }
00075
00076 virtual ~SGSparseVector()
00077 {
00078 unref();
00079 }
00080
00092 T dense_dot(T alpha, T* vec, int32_t dim, T b)
00093 {
00094 ASSERT(vec);
00095 T result=b;
00096
00097 if (features)
00098 {
00099 for (int32_t i=0; i<num_feat_entries; i++)
00100 {
00101 result+=alpha*vec[features[i].feat_index]
00102 *features[i].entry;
00103 }
00104 }
00105
00106 return result;
00107 }
00108
00116 T sparse_dot(const SGSparseVector<T>& v)
00117 {
00118 return sparse_dot(*this, v);
00119 }
00120
00128 static T sparse_dot(const SGSparseVector<T>& a, const SGSparseVector<T>& b)
00129 {
00130 if (a.num_feat_entries == 0 || b.num_feat_entries == 0)
00131 return 0;
00132
00133 int32_t cmp = cmp_dot_prod_symmetry_fast(a.num_feat_entries, b.num_feat_entries);
00134
00135 if (cmp == 0)
00136 {
00137 return dot_prod_symmetric(a, b);
00138 }
00139 else if (cmp > 0)
00140 {
00141 return dot_prod_asymmetric(a, b);
00142 }
00143 else
00144 {
00145 return dot_prod_asymmetric(b, a);
00146 }
00147 }
00148
00149 protected:
00150
00151 virtual void copy_data(const SGReferencedData& orig)
00152 {
00153 num_feat_entries = ((SGSparseVector*)(&orig))->num_feat_entries;
00154 features = ((SGSparseVector*)(&orig))->features;
00155 }
00156
00157 virtual void init_data()
00158 {
00159 num_feat_entries = 0;
00160 features = NULL;
00161 }
00162
00163 virtual void free_data()
00164 {
00165 num_feat_entries = 0;
00166 SG_FREE(features);
00167 }
00168
00169 static int32_t floor_log(index_t n)
00170 {
00171 register int32_t i;
00172 for (i = 0; n != 0; i++)
00173 n >>= 1;
00174
00175 return i;
00176 }
00177
00178 static int32_t cmp_dot_prod_symmetry_fast(index_t alen, index_t blen)
00179 {
00180 if (alen > blen)
00181 {
00182 return (blen * floor_log(alen) < alen) ? -1 : 0;
00183 }
00184 else
00185 {
00186 return (alen * floor_log(blen) < blen) ? 1 : 0;
00187 }
00188 }
00189
00190 static T dot_prod_asymmetric(const SGSparseVector<T>& a, const SGSparseVector<T>& b)
00191 {
00192 T dot_prod = 0;
00193 for(index_t b_idx = 0; b_idx < b.num_feat_entries; ++b_idx)
00194 {
00195 const T tmp = b.features[b_idx].entry;
00196 if (a.features[a.num_feat_entries-1].feat_index < b.features[b_idx].feat_index)
00197 break;
00198 for (index_t a_idx = 0; a_idx < a.num_feat_entries; ++a_idx)
00199 {
00200 if (a.features[a_idx].feat_index == b.features[b_idx].feat_index)
00201 dot_prod += tmp * a.features[a_idx].entry;
00202 }
00203 }
00204 return dot_prod;
00205 }
00206
00207 static T dot_prod_symmetric(const SGSparseVector<T>& a, const SGSparseVector<T>& b)
00208 {
00209 ASSERT(a.num_feat_entries > 0 && b.num_feat_entries > 0);
00210 T dot_prod = 0;
00211 index_t a_idx = 0, b_idx = 0;
00212 while (true)
00213 {
00214 if (a.features[a_idx].feat_index == b.features[b_idx].feat_index)
00215 {
00216 dot_prod += a.features[a_idx].entry * b.features[b_idx].entry;
00217
00218 a_idx++;
00219 if (a.num_feat_entries == a_idx)
00220 break;
00221 b_idx++;
00222 if (b.num_feat_entries == b_idx)
00223 break;
00224 }
00225 else if (a.features[a_idx].feat_index < b.features[b_idx].feat_index)
00226 {
00227 a_idx++;
00228 if (a.num_feat_entries == a_idx)
00229 break;
00230 }
00231 else
00232 {
00233 b_idx++;
00234 if (b.num_feat_entries == b_idx)
00235 break;
00236 }
00237 }
00238 return dot_prod;
00239 }
00240
00241 public:
00243 index_t num_feat_entries;
00244
00246 SGSparseVectorEntry<T>* features;
00247
00248 };
00249
00250 }
00251
00252 #endif // __SGSPARSEVECTOR_H__