SGSparseVector.h

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2012 Fernando José Iglesias García
00008  * Written (W) 2010,2012 Soeren Sonnenburg
00009  * Copyright (C) 2010 Berlin Institute of Technology
00010  * Copyright (C) 2012 Soeren Sonnenburg
00011  */
00012 
00013 #ifndef __SGSPARSEVECTOR_H__
00014 #define __SGSPARSEVECTOR_H__
00015 
00016 #include <shogun/lib/config.h>
00017 #include <shogun/lib/DataType.h>
00018 #include <shogun/lib/SGReferencedData.h>
00019 #include <map>
00020 
00021 namespace shogun
00022 {
00023 
00024 
00026 template <class T> struct SGSparseVectorEntry
00027 {
00029     index_t feat_index;
00031     T entry;
00032 };
00033 
00039 template <class T> class SGSparseVector : public SGReferencedData
00040 {
00041 public:
00043     SGSparseVector() : SGReferencedData()
00044     {
00045         init_data();
00046     }
00047 
00054     SGSparseVector(SGSparseVectorEntry<T>* feats, index_t num_entries,
00055             bool ref_counting=true) :
00056             SGReferencedData(ref_counting),
00057             num_feat_entries(num_entries), features(feats)
00058     {
00059     }
00060 
00062     SGSparseVector(index_t num_entries, bool ref_counting=true) :
00063         SGReferencedData(ref_counting),
00064         num_feat_entries(num_entries)
00065     {
00066         features = SG_MALLOC(SGSparseVectorEntry<T>, num_feat_entries);
00067     }
00068 
00070     SGSparseVector(const SGSparseVector& orig) :
00071         SGReferencedData(orig)
00072     {
00073         copy_data(orig);
00074     }
00075 
00076     virtual ~SGSparseVector()
00077     {
00078         unref();
00079     }
00080 
00092     T dense_dot(T alpha, T* vec, int32_t dim, T b)
00093     {
00094         ASSERT(vec);
00095         T result=b;
00096 
00097         if (features)
00098         {
00099             for (int32_t i=0; i<num_feat_entries; i++)
00100             {
00101                 result+=alpha*vec[features[i].feat_index]
00102                     *features[i].entry;
00103             }
00104         }
00105 
00106         return result;
00107     }
00108 
00116     T sparse_dot(const SGSparseVector<T>& v)
00117     {
00118         return sparse_dot(*this, v);
00119     }
00120 
00128     static T sparse_dot(const SGSparseVector<T>& a, const SGSparseVector<T>& b)
00129     {
00130         if (a.num_feat_entries == 0 || b.num_feat_entries == 0)
00131             return 0;
00132 
00133         int32_t cmp = cmp_dot_prod_symmetry_fast(a.num_feat_entries, b.num_feat_entries);
00134 
00135         if (cmp == 0) // symmetric
00136         {
00137             return dot_prod_symmetric(a, b);
00138         }
00139         else if (cmp > 0) // b has more element
00140         {
00141             return dot_prod_asymmetric(a, b);
00142         }
00143         else // a has more element
00144         {
00145             return dot_prod_asymmetric(b, a);
00146         }
00147     }
00148 
00149 protected:
00150 
00151     virtual void copy_data(const SGReferencedData& orig)
00152     {
00153         num_feat_entries = ((SGSparseVector*)(&orig))->num_feat_entries;
00154         features = ((SGSparseVector*)(&orig))->features;
00155     }
00156 
00157     virtual void init_data()
00158     {
00159         num_feat_entries = 0;
00160         features = NULL;
00161     }
00162 
00163     virtual void free_data()
00164     {
00165         num_feat_entries = 0;
00166         SG_FREE(features);
00167     }
00168 
00169     static int32_t floor_log(index_t n)
00170     {
00171         register int32_t i;
00172         for (i = 0; n != 0; i++)
00173             n >>= 1;
00174 
00175         return i;
00176     }
00177 
00178     static int32_t cmp_dot_prod_symmetry_fast(index_t alen, index_t blen)
00179     {
00180         if (alen > blen) // no need for floats here
00181         {
00182             return (blen * floor_log(alen) < alen) ? -1 : 0;
00183         }
00184         else // alen <= blen
00185         {
00186             return (alen * floor_log(blen) < blen) ? 1 : 0;
00187         }
00188     }
00189 
00190     static T dot_prod_asymmetric(const SGSparseVector<T>& a, const SGSparseVector<T>& b)
00191     {
00192         T dot_prod = 0;
00193         for(index_t b_idx = 0; b_idx < b.num_feat_entries; ++b_idx)
00194         {
00195             const T tmp = b.features[b_idx].entry;
00196             if (a.features[a.num_feat_entries-1].feat_index < b.features[b_idx].feat_index)
00197                 break;
00198             for (index_t a_idx = 0; a_idx < a.num_feat_entries; ++a_idx)
00199             {
00200                 if (a.features[a_idx].feat_index == b.features[b_idx].feat_index)
00201                     dot_prod += tmp * a.features[a_idx].entry;
00202             }
00203         }
00204         return dot_prod;
00205     }
00206 
00207     static T dot_prod_symmetric(const SGSparseVector<T>& a, const SGSparseVector<T>& b)
00208     {
00209         ASSERT(a.num_feat_entries > 0 && b.num_feat_entries > 0);
00210         T dot_prod = 0;
00211         index_t a_idx = 0, b_idx = 0;
00212         while (true)
00213         {
00214             if (a.features[a_idx].feat_index == b.features[b_idx].feat_index)
00215             {
00216                 dot_prod += a.features[a_idx].entry * b.features[b_idx].entry;
00217 
00218                 a_idx++;
00219                 if (a.num_feat_entries == a_idx)
00220                     break;
00221                 b_idx++;
00222                 if (b.num_feat_entries == b_idx)
00223                     break;
00224             }
00225             else if (a.features[a_idx].feat_index < b.features[b_idx].feat_index)
00226             {
00227                 a_idx++;
00228                 if (a.num_feat_entries == a_idx)
00229                     break;
00230             }
00231             else
00232             {
00233                 b_idx++;
00234                 if (b.num_feat_entries == b_idx)
00235                     break;
00236             }
00237         }
00238         return dot_prod;
00239     }
00240 
00241 public:
00243     index_t num_feat_entries;
00244 
00246     SGSparseVectorEntry<T>* features;
00247 
00248 };
00249 
00250 }
00251 
00252 #endif // __SGSPARSEVECTOR_H__
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation