00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #include <shogun/lib/common.h>
00012 #include <shogun/kernel/CommUlongStringKernel.h>
00013 #include <shogun/kernel/SqrtDiagKernelNormalizer.h>
00014 #include <shogun/features/StringFeatures.h>
00015 #include <shogun/io/SGIO.h>
00016
00017 using namespace shogun;
00018
00019 CCommUlongStringKernel::CCommUlongStringKernel(int32_t size, bool us)
00020 : CStringKernel<uint64_t>(size), use_sign(us)
00021 {
00022 properties |= KP_LINADD;
00023 clear_normal();
00024
00025 set_normalizer(new CSqrtDiagKernelNormalizer());
00026 }
00027
00028 CCommUlongStringKernel::CCommUlongStringKernel(
00029 CStringFeatures<uint64_t>* l, CStringFeatures<uint64_t>* r, bool us,
00030 int32_t size)
00031 : CStringKernel<uint64_t>(size), use_sign(us)
00032 {
00033 properties |= KP_LINADD;
00034 clear_normal();
00035 set_normalizer(new CSqrtDiagKernelNormalizer());
00036 init(l,r);
00037 }
00038
00039 CCommUlongStringKernel::~CCommUlongStringKernel()
00040 {
00041 cleanup();
00042 }
00043
00044 void CCommUlongStringKernel::remove_lhs()
00045 {
00046 delete_optimization();
00047
00048 #ifdef SVMLIGHT
00049 if (lhs)
00050 cache_reset();
00051 #endif
00052
00053 lhs = NULL ;
00054 rhs = NULL ;
00055 }
00056
00057 void CCommUlongStringKernel::remove_rhs()
00058 {
00059 #ifdef SVMLIGHT
00060 if (rhs)
00061 cache_reset();
00062 #endif
00063
00064 rhs = lhs;
00065 }
00066
00067 bool CCommUlongStringKernel::init(CFeatures* l, CFeatures* r)
00068 {
00069 CStringKernel<uint64_t>::init(l,r);
00070 return init_normalizer();
00071 }
00072
00073 void CCommUlongStringKernel::cleanup()
00074 {
00075 delete_optimization();
00076 clear_normal();
00077 CKernel::cleanup();
00078 }
00079
00080 float64_t CCommUlongStringKernel::compute(int32_t idx_a, int32_t idx_b)
00081 {
00082 int32_t alen, blen;
00083 bool free_avec, free_bvec;
00084 uint64_t* avec=((CStringFeatures<uint64_t>*) lhs)->get_feature_vector(idx_a, alen, free_avec);
00085 uint64_t* bvec=((CStringFeatures<uint64_t>*) rhs)->get_feature_vector(idx_b, blen, free_bvec);
00086
00087 float64_t result=0;
00088
00089 int32_t left_idx=0;
00090 int32_t right_idx=0;
00091
00092 if (use_sign)
00093 {
00094 while (left_idx < alen && right_idx < blen)
00095 {
00096 if (avec[left_idx]==bvec[right_idx])
00097 {
00098 uint64_t sym=avec[left_idx];
00099
00100 while (left_idx< alen && avec[left_idx]==sym)
00101 left_idx++;
00102
00103 while (right_idx< blen && bvec[right_idx]==sym)
00104 right_idx++;
00105
00106 result++;
00107 }
00108 else if (avec[left_idx]<bvec[right_idx])
00109 left_idx++;
00110 else
00111 right_idx++;
00112 }
00113 }
00114 else
00115 {
00116 while (left_idx < alen && right_idx < blen)
00117 {
00118 if (avec[left_idx]==bvec[right_idx])
00119 {
00120 int32_t old_left_idx=left_idx;
00121 int32_t old_right_idx=right_idx;
00122
00123 uint64_t sym=avec[left_idx];
00124
00125 while (left_idx< alen && avec[left_idx]==sym)
00126 left_idx++;
00127
00128 while (right_idx< blen && bvec[right_idx]==sym)
00129 right_idx++;
00130
00131 result+=((float64_t) (left_idx-old_left_idx)) * ((float64_t) (right_idx-old_right_idx));
00132 }
00133 else if (avec[left_idx]<bvec[right_idx])
00134 left_idx++;
00135 else
00136 right_idx++;
00137 }
00138 }
00139 ((CStringFeatures<uint64_t>*) lhs)->free_feature_vector(avec, idx_a, free_avec);
00140 ((CStringFeatures<uint64_t>*) rhs)->free_feature_vector(bvec, idx_b, free_bvec);
00141
00142 return result;
00143 }
00144
00145 void CCommUlongStringKernel::add_to_normal(int32_t vec_idx, float64_t weight)
00146 {
00147 int32_t t=0;
00148 int32_t j=0;
00149 int32_t k=0;
00150 int32_t last_j=0;
00151 int32_t len=-1;
00152 bool free_vec;
00153 uint64_t* vec=((CStringFeatures<uint64_t>*) lhs)->get_feature_vector(vec_idx, len, free_vec);
00154
00155 if (vec && len>0)
00156 {
00157 uint64_t* dic= SG_MALLOC(uint64_t, len+dictionary.get_num_elements());
00158 float64_t* dic_weights= SG_MALLOC(float64_t, len+dictionary.get_num_elements());
00159
00160 if (use_sign)
00161 {
00162 for (j=1; j<len; j++)
00163 {
00164 if (vec[j]==vec[j-1])
00165 continue;
00166
00167 merge_dictionaries(t, j, k, vec, dic, dic_weights, weight, vec_idx);
00168 }
00169
00170 merge_dictionaries(t, j, k, vec, dic, dic_weights, weight, vec_idx);
00171
00172 while (k<dictionary.get_num_elements())
00173 {
00174 dic[t]=dictionary[k];
00175 dic_weights[t]=dictionary_weights[k];
00176 t++;
00177 k++;
00178 }
00179 }
00180 else
00181 {
00182 for (j=1; j<len; j++)
00183 {
00184 if (vec[j]==vec[j-1])
00185 continue;
00186
00187 merge_dictionaries(t, j, k, vec, dic, dic_weights, weight*(j-last_j), vec_idx);
00188 last_j = j;
00189 }
00190
00191 merge_dictionaries(t, j, k, vec, dic, dic_weights, weight*(j-last_j), vec_idx);
00192
00193 while (k<dictionary.get_num_elements())
00194 {
00195 dic[t]=dictionary[k];
00196 dic_weights[t]=dictionary_weights[k];
00197 t++;
00198 k++;
00199 }
00200 }
00201
00202 dictionary.set_array(dic, t, len+dictionary.get_num_elements());
00203 dictionary_weights.set_array(dic_weights, t, len+dictionary.get_num_elements());
00204 }
00205 ((CStringFeatures<uint64_t>*) lhs)->free_feature_vector(vec, vec_idx, free_vec);
00206
00207 set_is_initialized(true);
00208 }
00209
00210 void CCommUlongStringKernel::clear_normal()
00211 {
00212 dictionary.resize_array(0);
00213 dictionary_weights.resize_array(0);
00214 set_is_initialized(false);
00215 }
00216
00217 bool CCommUlongStringKernel::init_optimization(
00218 int32_t count, int32_t *IDX, float64_t * weights)
00219 {
00220 clear_normal();
00221
00222 if (count<=0)
00223 {
00224 set_is_initialized(true);
00225 SG_DEBUG( "empty set of SVs\n");
00226 return true;
00227 }
00228
00229 SG_DEBUG( "initializing CCommUlongStringKernel optimization\n");
00230
00231 for (int32_t i=0; i<count; i++)
00232 {
00233 if ( (i % (count/10+1)) == 0)
00234 SG_PROGRESS(i, 0, count);
00235
00236 add_to_normal(IDX[i], weights[i]);
00237 }
00238
00239 SG_PRINT( "Done. \n");
00240
00241 set_is_initialized(true);
00242 return true;
00243 }
00244
00245 bool CCommUlongStringKernel::delete_optimization()
00246 {
00247 SG_DEBUG( "deleting CCommUlongStringKernel optimization\n");
00248 clear_normal();
00249 return true;
00250 }
00251
00252
00253
00254 float64_t CCommUlongStringKernel::compute_optimized(int32_t i)
00255 {
00256 float64_t result = 0;
00257 int32_t j, last_j=0;
00258 int32_t old_idx = 0;
00259
00260 if (!get_is_initialized())
00261 {
00262 SG_ERROR( "CCommUlongStringKernel optimization not initialized\n");
00263 return 0 ;
00264 }
00265
00266
00267
00268 int32_t alen = -1;
00269 bool free_avec;
00270 uint64_t* avec=((CStringFeatures<uint64_t>*) rhs)->
00271 get_feature_vector(i, alen, free_avec);
00272
00273 if (avec && alen>0)
00274 {
00275 if (use_sign)
00276 {
00277 for (j=1; j<alen; j++)
00278 {
00279 if (avec[j]==avec[j-1])
00280 continue;
00281
00282 int32_t idx = CMath::binary_search_max_lower_equal(&(dictionary.get_array()[old_idx]), dictionary.get_num_elements()-old_idx, avec[j-1]);
00283
00284 if (idx!=-1)
00285 {
00286 if (dictionary[idx+old_idx] == avec[j-1])
00287 result += dictionary_weights[idx+old_idx];
00288
00289 old_idx+=idx;
00290 }
00291 }
00292
00293 int32_t idx = CMath::binary_search(&(dictionary.get_array()[old_idx]), dictionary.get_num_elements()-old_idx, avec[alen-1]);
00294 if (idx!=-1)
00295 result += dictionary_weights[idx+old_idx];
00296 }
00297 else
00298 {
00299 for (j=1; j<alen; j++)
00300 {
00301 if (avec[j]==avec[j-1])
00302 continue;
00303
00304 int32_t idx = CMath::binary_search_max_lower_equal(&(dictionary.get_array()[old_idx]), dictionary.get_num_elements()-old_idx, avec[j-1]);
00305
00306 if (idx!=-1)
00307 {
00308 if (dictionary[idx+old_idx] == avec[j-1])
00309 result += dictionary_weights[idx+old_idx]*(j-last_j);
00310
00311 old_idx+=idx;
00312 }
00313
00314 last_j = j;
00315 }
00316
00317 int32_t idx = CMath::binary_search(&(dictionary.get_array()[old_idx]), dictionary.get_num_elements()-old_idx, avec[alen-1]);
00318 if (idx!=-1)
00319 result += dictionary_weights[idx+old_idx]*(alen-last_j);
00320 }
00321 }
00322
00323 ((CStringFeatures<uint64_t>*) rhs)->free_feature_vector(avec, i, free_avec);
00324
00325 return normalizer->normalize_rhs(result, i);
00326 }