00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012 #ifndef _WEIGHTEDDEGREEPOSITIONSTRINGKERNEL_H___
00013 #define _WEIGHTEDDEGREEPOSITIONSTRINGKERNEL_H___
00014
00015 #include <shogun/lib/common.h>
00016 #include <shogun/kernel/StringKernel.h>
00017 #include <shogun/kernel/WeightedDegreeStringKernel.h>
00018 #include <shogun/lib/Trie.h>
00019
00020 namespace shogun
00021 {
00022
00023 class CSVM;
00024
00048 class CWeightedDegreePositionStringKernel: public CStringKernel<char>
00049 {
00050 public:
00052 CWeightedDegreePositionStringKernel();
00053
00061 CWeightedDegreePositionStringKernel(
00062 int32_t size, int32_t degree,
00063 int32_t max_mismatch=0, int32_t mkl_stepsize=1);
00064
00075 CWeightedDegreePositionStringKernel(
00076 int32_t size, float64_t* weights, int32_t degree,
00077 int32_t max_mismatch, int32_t* shift, int32_t shift_len,
00078 int32_t mkl_stepsize=1);
00079
00086 CWeightedDegreePositionStringKernel(
00087 CStringFeatures<char>* l, CStringFeatures<char>* r, int32_t degree);
00088
00089 virtual ~CWeightedDegreePositionStringKernel();
00090
00097 virtual bool init(CFeatures* l, CFeatures* r);
00098
00100 virtual void cleanup();
00101
00106 virtual EKernelType get_kernel_type() { return K_WEIGHTEDDEGREEPOS; }
00107
00112 virtual const char* get_name() const { return "WeightedDegreePositionStringKernel"; }
00113
00121 inline virtual bool init_optimization(
00122 int32_t p_count, int32_t *IDX, float64_t * alphas)
00123 {
00124 return init_optimization(p_count, IDX, alphas, -1);
00125 }
00126
00138 virtual bool init_optimization(
00139 int32_t count, int32_t *IDX, float64_t * alphas, int32_t tree_num,
00140 int32_t upto_tree=-1);
00141
00146 virtual bool delete_optimization();
00147
00153 inline virtual float64_t compute_optimized(int32_t idx)
00154 {
00155 ASSERT(get_is_initialized());
00156 ASSERT(alphabet);
00157 ASSERT(alphabet->get_alphabet()==DNA || alphabet->get_alphabet()==RNA);
00158 return compute_by_tree(idx);
00159 }
00160
00165 static void* compute_batch_helper(void* p);
00166
00177 virtual void compute_batch(
00178 int32_t num_vec, int32_t* vec_idx, float64_t* target,
00179 int32_t num_suppvec, int32_t* IDX, float64_t* alphas,
00180 float64_t factor=1.0);
00181
00185 inline virtual void clear_normal()
00186 {
00187 if ((opt_type==FASTBUTMEMHUNGRY) && (tries.get_use_compact_terminal_nodes()))
00188 {
00189 tries.set_use_compact_terminal_nodes(false) ;
00190 SG_DEBUG( "disabling compact trie nodes with FASTBUTMEMHUNGRY\n") ;
00191 }
00192
00193 if (get_is_initialized())
00194 {
00195 if (opt_type==SLOWBUTMEMEFFICIENT)
00196 tries.delete_trees(true);
00197 else if (opt_type==FASTBUTMEMHUNGRY)
00198 tries.delete_trees(false);
00199 else
00200 SG_ERROR( "unknown optimization type\n");
00201
00202 set_is_initialized(false);
00203 }
00204 }
00205
00211 inline virtual void add_to_normal(int32_t idx, float64_t weight)
00212 {
00213 add_example_to_tree(idx, weight);
00214 set_is_initialized(true);
00215 }
00216
00221 inline virtual int32_t get_num_subkernels()
00222 {
00223 if (position_weights!=NULL)
00224 return (int32_t) ceil(1.0*seq_length/mkl_stepsize) ;
00225 if (length==0)
00226 return (int32_t) ceil(1.0*get_degree()/mkl_stepsize);
00227 return (int32_t) ceil(1.0*get_degree()*length/mkl_stepsize) ;
00228 }
00229
00235 inline void compute_by_subkernel(
00236 int32_t idx, float64_t * subkernel_contrib)
00237 {
00238 if (get_is_initialized())
00239 {
00240 compute_by_tree(idx, subkernel_contrib);
00241 return ;
00242 }
00243
00244 SG_ERROR( "CWeightedDegreePositionStringKernel optimization not initialized\n") ;
00245 }
00246
00252 inline const float64_t* get_subkernel_weights(int32_t& num_weights)
00253 {
00254 num_weights = get_num_subkernels() ;
00255
00256 SG_FREE(weights_buffer);
00257 weights_buffer = SG_MALLOC(float64_t, num_weights);
00258
00259 if (position_weights!=NULL)
00260 for (int32_t i=0; i<num_weights; i++)
00261 weights_buffer[i] = position_weights[i*mkl_stepsize] ;
00262 else
00263 for (int32_t i=0; i<num_weights; i++)
00264 weights_buffer[i] = weights[i*mkl_stepsize] ;
00265
00266 return weights_buffer ;
00267 }
00268
00274 virtual void set_subkernel_weights(SGVector<float64_t> w)
00275 {
00276 float64_t* weights2=w.vector;
00277 int32_t num_weights2=w.vlen;
00278
00279 int32_t num_weights = get_num_subkernels() ;
00280 if (num_weights!=num_weights2)
00281 SG_ERROR( "number of weights do not match\n") ;
00282
00283 if (position_weights!=NULL)
00284 for (int32_t i=0; i<num_weights; i++)
00285 for (int32_t j=0; j<mkl_stepsize; j++)
00286 {
00287 if (i*mkl_stepsize+j<seq_length)
00288 position_weights[i*mkl_stepsize+j] = weights2[i] ;
00289 }
00290 else if (length==0)
00291 {
00292 for (int32_t i=0; i<num_weights; i++)
00293 for (int32_t j=0; j<mkl_stepsize; j++)
00294 if (i*mkl_stepsize+j<get_degree())
00295 weights[i*mkl_stepsize+j] = weights2[i] ;
00296 }
00297 else
00298 {
00299 for (int32_t i=0; i<num_weights; i++)
00300 for (int32_t j=0; j<mkl_stepsize; j++)
00301 if (i*mkl_stepsize+j<get_degree()*length)
00302 weights[i*mkl_stepsize+j] = weights2[i] ;
00303 }
00304 }
00305
00306
00312 float64_t* compute_abs_weights(int32_t & len);
00313
00318 bool is_tree_initialized() { return tree_initialized; }
00319
00324 inline int32_t get_max_mismatch() { return max_mismatch; }
00325
00330 inline int32_t get_degree() { return degree; }
00331
00337 inline float64_t *get_degree_weights(int32_t& d, int32_t& len)
00338 {
00339 d=degree;
00340 len=length;
00341 return weights;
00342 }
00343
00349 inline float64_t *get_weights(int32_t& num_weights)
00350 {
00351 if (position_weights!=NULL)
00352 {
00353 num_weights = seq_length ;
00354 return position_weights ;
00355 }
00356 if (length==0)
00357 num_weights = degree ;
00358 else
00359 num_weights = degree*length ;
00360 return weights;
00361 }
00362
00368 inline float64_t *get_position_weights(int32_t& len)
00369 {
00370 len=seq_length;
00371 return position_weights;
00372 }
00373
00378 void set_shifts(SGVector<int32_t> shifts);
00379
00384 bool set_weights(SGMatrix<float64_t> new_weights);
00385
00390 virtual bool set_wd_weights();
00391
00397 virtual void set_position_weights(SGVector<float64_t> pws);
00398
00406 bool set_position_weights_lhs(float64_t* pws, int32_t len, int32_t num);
00407
00415 bool set_position_weights_rhs(float64_t* pws, int32_t len, int32_t num);
00416
00421 bool init_block_weights();
00422
00427 bool init_block_weights_from_wd();
00428
00433 bool init_block_weights_from_wd_external();
00434
00439 bool init_block_weights_const();
00440
00445 bool init_block_weights_linear();
00446
00451 bool init_block_weights_sqpoly();
00452
00457 bool init_block_weights_cubicpoly();
00458
00463 bool init_block_weights_exp();
00464
00469 bool init_block_weights_log();
00470
00475 bool delete_position_weights()
00476 {
00477 SG_FREE(position_weights);
00478 position_weights=NULL;
00479 return true;
00480 }
00481
00486 bool delete_position_weights_lhs()
00487 {
00488 SG_FREE(position_weights_lhs);
00489 position_weights_lhs=NULL;
00490 return true;
00491 }
00492
00497 bool delete_position_weights_rhs()
00498 {
00499 SG_FREE(position_weights_rhs);
00500 position_weights_rhs=NULL;
00501 return true;
00502 }
00503
00509 virtual float64_t compute_by_tree(int32_t idx);
00510
00516 virtual void compute_by_tree(int32_t idx, float64_t* LevelContrib);
00517
00530 float64_t* compute_scoring(
00531 int32_t max_degree, int32_t& num_feat, int32_t& num_sym,
00532 float64_t* target, int32_t num_suppvec, int32_t* IDX,
00533 float64_t* weights);
00534
00543 char* compute_consensus(
00544 int32_t &num_feat, int32_t num_suppvec, int32_t* IDX,
00545 float64_t* alphas);
00546
00558 float64_t* extract_w(
00559 int32_t max_degree, int32_t& num_feat, int32_t& num_sym,
00560 float64_t* w_result, int32_t num_suppvec, int32_t* IDX,
00561 float64_t* alphas);
00562
00575 float64_t* compute_POIM(
00576 int32_t max_degree, int32_t& num_feat, int32_t& num_sym,
00577 float64_t* poim_result, int32_t num_suppvec, int32_t* IDX,
00578 float64_t* alphas, float64_t* distrib);
00579
00586 void prepare_POIM2(
00587 float64_t* distrib, int32_t num_sym, int32_t num_feat);
00588
00595 void compute_POIM2(int32_t max_degree, CSVM* svm);
00596
00602 void get_POIM2(float64_t** poim, int32_t* result_len);
00603
00605 void cleanup_POIM2();
00606
00607 protected:
00609 void create_empty_tries();
00610
00616 virtual void add_example_to_tree(
00617 int32_t idx, float64_t weight);
00618
00625 void add_example_to_single_tree(
00626 int32_t idx, float64_t weight, int32_t tree_num);
00627
00636 virtual float64_t compute(int32_t idx_a, int32_t idx_b);
00637
00646 float64_t compute_with_mismatch(
00647 char* avec, int32_t alen, char* bvec, int32_t blen);
00648
00657 float64_t compute_without_mismatch(
00658 char* avec, int32_t alen, char* bvec, int32_t blen);
00659
00668 float64_t compute_without_mismatch_matrix(
00669 char* avec, int32_t alen, char* bvec, int32_t blen);
00670
00681 float64_t compute_without_mismatch_position_weights(
00682 char* avec, float64_t *posweights_lhs, int32_t alen,
00683 char* bvec, float64_t *posweights_rhs, int32_t blen);
00684
00686 virtual void remove_lhs();
00687
00696 virtual void load_serializable_post() throw (ShogunException);
00697
00698 private:
00701 void init();
00702
00703 protected:
00705 float64_t* weights;
00707 int32_t weights_degree;
00709 int32_t weights_length;
00710
00712 float64_t* position_weights;
00714 int32_t position_weights_len;
00715
00717 float64_t* position_weights_lhs;
00719 int32_t position_weights_lhs_len;
00721 float64_t* position_weights_rhs;
00723 int32_t position_weights_rhs_len;
00725 bool* position_mask;
00726
00728 float64_t* weights_buffer;
00730 int32_t mkl_stepsize;
00731
00733 int32_t degree;
00735 int32_t length;
00736
00738 int32_t max_mismatch;
00740 int32_t seq_length;
00741
00743 int32_t *shift;
00745 int32_t shift_len;
00747 int32_t max_shift;
00748
00750 bool block_computation;
00751
00753 float64_t* block_weights;
00755 EWDKernType type;
00757 int32_t which_degree;
00758
00760 CTrie<DNATrie> tries;
00762 CTrie<POIMTrie> poim_tries;
00763
00765 bool tree_initialized;
00767 bool use_poim_tries;
00768
00770 float64_t* m_poim_distrib;
00772 float64_t* m_poim;
00773
00775 int32_t m_poim_num_sym;
00777 int32_t m_poim_num_feat;
00779 int32_t m_poim_result_len;
00780
00782 CAlphabet* alphabet;
00783 };
00784 }
00785 #endif